]> git.nihav.org Git - nihav.git/blame - nihav-realmedia/src/codecs/rv40enc/motion_est.rs
avimux: do not record palette change chunks in OpenDML index
[nihav.git] / nihav-realmedia / src / codecs / rv40enc / motion_est.rs
CommitLineData
4965a5e5
KS
1use nihav_core::frame::*;
2use nihav_codec_support::codecs::{MV, ZERO_MV};
3use std::str::FromStr;
4use super::dsp::{RefMBData, luma_mc, chroma_mc};
5
/// Motion vector search strategy used by the RV40 encoder.
#[derive(Clone,Copy,PartialEq,Default)]
pub enum MVSearchMode {
    /// Performs no real search and always reports the zero MV
    /// (see `DummySearcher`); not listed among user-selectable modes.
    Dummy,
    /// Small diamond pattern search.
    Diamond,
    /// Hexagon pattern search (the default mode).
    #[default]
    Hexagon,
    /// Uneven multi-hexagon search (more thorough, more SAD evaluations).
    UMH,
}
14
impl MVSearchMode {
    /// Lists the mode names accepted by the `FromStr` implementation
    /// (the internal-only dummy mode is deliberately omitted).
    pub const fn get_possible_modes() -> &'static [&'static str] {
        &["diamond", "hexagon", "umh"]
    }
    /// Instantiates the searcher object implementing this mode.
    fn create(self) -> Box<dyn MVSearch+Send> {
        match self {
            MVSearchMode::Dummy => Box::new(DummySearcher{}),
            MVSearchMode::Diamond => Box::new(DiaSearch::new()),
            MVSearchMode::Hexagon => Box::new(HexSearch::new()),
            MVSearchMode::UMH => Box::new(UnevenHexSearch::new()),
        }
    }
}
28
4965a5e5
KS
29impl std::fmt::Display for MVSearchMode {
30 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
31 match *self {
32 MVSearchMode::Diamond => write!(f, "diamond"),
33 MVSearchMode::Hexagon => write!(f, "hexagon"),
34 MVSearchMode::UMH => write!(f, "umh"),
35 MVSearchMode::Dummy => write!(f, "dummy"),
36 }
37 }
38}
39
40impl FromStr for MVSearchMode {
41 type Err = ();
42 fn from_str(s: &str) -> Result<Self, Self::Err> {
43 match s {
44 "diamond" => Ok(MVSearchMode::Diamond),
45 "hexagon" => Ok(MVSearchMode::Hexagon),
46 "umh" => Ok(MVSearchMode::UMH),
47 "dummy" => Ok(MVSearchMode::Dummy),
48 _ => Err(()),
49 }
50 }
51}
52
// Largest possible distortion; doubles as the "not yet evaluated" marker
// in the pattern-search state arrays.
const MAX_DIST: u32 = u32::MAX;
// Distortion threshold below which a search terminates early.
const DIST_THRESH: u32 = 256;
55
#[allow(clippy::wrong_self_convention)]
trait FromPixels {
    /// Converts a full-pel value into the quarter-pel scale used by
    /// the motion compensation routines.
    fn from_pixels(self) -> Self;
}

impl FromPixels for MV {
    fn from_pixels(self) -> MV {
        MV { x: self.x * 4, y: self.y * 4 }
    }
}
66
// Diamond search pattern: the centre plus eight points at radius 2.
const DIA_PATTERN: [MV; 9] = [
    ZERO_MV,
    MV {x: -2, y: 0},
    MV {x: -1, y: 1},
    MV {x: 0, y: 2},
    MV {x: 1, y: 1},
    MV {x: 2, y: 0},
    MV {x: 1, y: -1},
    MV {x: 0, y: -2},
    MV {x: -1, y: -1}
];

// Hexagon search pattern: the centre plus six surrounding points.
const HEX_PATTERN: [MV; 7] = [
    ZERO_MV,
    MV {x: -2, y: 0},
    MV {x: -1, y: 2},
    MV {x: 1, y: 2},
    MV {x: 2, y: 0},
    MV {x: 1, y: -2},
    MV {x: -1, y: -2}
];

// One-step cross used for the final refinement around the best point.
const REFINEMENT: [MV; 4] = [
    MV {x: -1, y: 0},
    MV {x: 0, y: 1},
    MV {x: 1, y: 0},
    MV {x: 0, y: -1}
];
95
/// Shared body of the pattern searches (diamond/hexagon).
///
/// The short form starts a two-stage search from the zero MV; the long
/// form takes an explicit start MV, initial distortion and a flag that
/// selects whether the full-pel stage is run before subpel refinement.
/// In both stages the pattern is walked until no better centre is found,
/// distortions below `$threshold` cause an immediate return, and a final
/// one-step cross (`REFINEMENT`) polish is applied.
macro_rules! search_template {
    ($self: expr, $mv_est: expr, $cur_blk: expr, $mb_x: expr, $mb_y: expr, $sad_func: ident, $threshold: expr) => ({
        search_template!($self, $mv_est, $cur_blk, $mb_x, $mb_y, $sad_func, $threshold, ZERO_MV, MAX_DIST, true)
    });
    ($self: expr, $mv_est: expr, $cur_blk: expr, $mb_x: expr, $mb_y: expr, $sad_func: ident, $threshold: expr, $start_mv: expr, $best_dist: expr, $fullpel_stage: expr) => ({
        let mut best_dist = $best_dist;
        let mut best_mv = $start_mv;

        let mut min_dist;
        let mut min_idx;

        if $fullpel_stage {
            $self.reset();
            loop {
                // evaluate the points that have not been scored yet
                // (dist == MAX_DIST marks an unevaluated point)
                let mut cur_best_dist = best_dist;
                for (dist, &point) in $self.dist.iter_mut().zip($self.point.iter()) {
                    if *dist == MAX_DIST {
                        *dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, point.from_pixels(), cur_best_dist);
                        cur_best_dist = cur_best_dist.min(*dist);
                        if *dist <= $threshold {
                            break;
                        }
                    }
                }
                // pick the best point of the current pattern
                min_dist = $self.dist[0];
                min_idx = 0;
                for (i, &dist) in $self.dist.iter().enumerate().skip(1) {
                    if dist < min_dist {
                        min_dist = dist;
                        min_idx = i;
                        if dist <= $threshold {
                            break;
                        }
                    }
                }
                // stop when good enough, the centre wins, no improvement
                // was made or the search would leave the allowed MV range
                if min_dist <= $threshold || min_idx == 0 || best_dist == min_dist || $self.point[min_idx].x.abs() >= $mv_est.mv_range || $self.point[min_idx].y.abs() >= $mv_est.mv_range {
                    break;
                }
                best_dist = min_dist;
                $self.update($self.steps[min_idx]);
            }
            best_dist = min_dist;
            best_mv = $self.point[min_idx];
            if best_dist <= $threshold {
                return (best_mv.from_pixels(), best_dist);
            }
            // one-step full-pel refinement around the winner
            for &step in REFINEMENT.iter() {
                let mv = best_mv + step;
                let dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, mv.from_pixels(), MAX_DIST);
                if best_dist > dist {
                    best_dist = dist;
                    best_mv = mv;
                }
            }
            best_mv = best_mv.from_pixels();
            if best_dist <= $threshold {
                return (best_mv, best_dist);
            }
        }

        // subpel refinement: same walk as above but on quarter-pel MVs
        // (note the wider mv_range * 8 bound)
        $self.set_new_point(best_mv, best_dist);
        loop {
            let mut cur_best_dist = best_dist;
            for (dist, &point) in $self.dist.iter_mut().zip($self.point.iter()) {
                if *dist == MAX_DIST {
                    *dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, point, cur_best_dist);
                    cur_best_dist = cur_best_dist.min(*dist);
                    if *dist <= $threshold {
                        break;
                    }
                }
            }
            min_dist = $self.dist[0];
            min_idx = 0;
            for (i, &dist) in $self.dist.iter().enumerate().skip(1) {
                if dist < min_dist {
                    min_dist = dist;
                    min_idx = i;
                    if dist <= $threshold {
                        break;
                    }
                }
            }
            if min_dist <= $threshold || min_idx == 0 || best_dist == min_dist || $self.point[min_idx].x.abs() >= $mv_est.mv_range * 8 || $self.point[min_idx].y.abs() >= $mv_est.mv_range * 8 {
                break;
            }
            best_dist = min_dist;
            $self.update($self.steps[min_idx]);
        }
        best_dist = min_dist;
        best_mv = $self.point[min_idx];
        if best_dist <= $threshold {
            return (best_mv, best_dist);
        }
        // final quarter-pel one-step refinement
        for &step in REFINEMENT.iter() {
            let mv = best_mv + step;
            let dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, mv, MAX_DIST);
            if best_dist > dist {
                best_dist = dist;
                best_mv = mv;
            }
        }
        (best_mv, best_dist)
    });
}
202
/// Generates a pattern-based searcher type (`$struct_name`) for the given
/// step pattern. The generated type caches per-point distortions so that
/// points shared between consecutive pattern positions are not re-scored.
macro_rules! pattern_search {
    ($struct_name: ident, $patterns: expr) => {
        pub struct $struct_name {
            point: [MV; $patterns.len()],
            dist: [u32; $patterns.len()],
            steps: &'static [MV; $patterns.len()],
        }

        impl $struct_name {
            pub fn new() -> Self {
                Self {
                    point: $patterns,
                    dist: [MAX_DIST; $patterns.len()],
                    steps: &$patterns,
                }
            }
            // Restarts the search from the zero MV with no cached scores.
            fn reset(&mut self) {
                self.point = $patterns;
                self.dist = [MAX_DIST; $patterns.len()];
            }
            // Re-centres the pattern on `start`, whose distortion is known.
            fn set_new_point(&mut self, start: MV, dist: u32) {
                for (dst, &src) in self.point.iter_mut().zip(self.steps.iter()) {
                    *dst = src + start;
                }
                self.dist = [MAX_DIST; $patterns.len()];
                self.dist[0] = dist;
            }
            // Shifts the whole pattern by `step`, carrying over distortions
            // of the points present both before and after the shift.
            fn update(&mut self, step: MV) {
                let mut new_point = self.point;
                let mut new_dist = [MAX_DIST; $patterns.len()];

                for point in new_point.iter_mut() {
                    *point += step;
                }

                for (new_point, new_dist) in new_point.iter_mut().zip(new_dist.iter_mut()) {
                    for (&old_point, &old_dist) in self.point.iter().zip(self.dist.iter()) {
                        if *new_point == old_point {
                            *new_dist = old_dist;
                            break;
                        }
                    }
                }
                self.point = new_point;
                self.dist = new_dist;
            }
        }

        impl MVSearch for $struct_name {
            fn search_mb(&mut self, mv_est: &mut MVEstimator, cur_mb: &RefMBData, mb_x: usize, mb_y: usize, _cand_mvs: &[MV]) -> (MV, u32) {
                search_template!(self, mv_est, cur_mb, mb_x, mb_y, sad_mb, DIST_THRESH)
            }
            fn search_blk8(&mut self, mv_est: &mut MVEstimator, ref_blk: &RefMBData, xpos: usize, ypos: usize, _cand_mvs: &[MV]) -> (MV, u32) {
                // 8x8 block covers a quarter of the MB area, so the
                // early-out threshold is scaled accordingly
                search_template!(self, mv_est, ref_blk, xpos, ypos, sad_blk8, DIST_THRESH / 4)
            }
        }
    }
}
261
// Concrete pattern searchers for the diamond and hexagon patterns.
pattern_search!(DiaSearch, DIA_PATTERN);
pattern_search!(HexSearch, HEX_PATTERN);
264
265const LARGE_HEX_PATTERN: [MV; 16] = [
266 MV { x: -4, y: 0 },
267 MV { x: -4, y: 1 },
268 MV { x: -4, y: 2 },
269 MV { x: -2, y: 3 },
270 MV { x: 0, y: 4 },
271 MV { x: 2, y: 3 },
272 MV { x: 4, y: 2 },
273 MV { x: 4, y: 1 },
274 MV { x: 4, y: 0 },
275 MV { x: 4, y: -1 },
276 MV { x: 4, y: -2 },
277 MV { x: -2, y: -3 },
278 MV { x: 0, y: -4 },
279 MV { x: -2, y: -3 },
280 MV { x: -4, y: -2 },
281 MV { x: -4, y: -1 }
282];
283
// Unsymmetrical cross for the UMH search: horizontal reach is twice the
// vertical one (horizontal motion is more common in typical video).
const UNSYMM_CROSS: [MV; 4] = [
    MV { x: -2, y: 0 },
    MV { x: 0, y: 1 },
    MV { x: 2, y: 0 },
    MV { x: 0, y: -1 }
];
290
/// Fixed-capacity (16-entry) set that preserves insertion order and
/// silently ignores both duplicates and additions past capacity.
#[derive(Default)]
struct UniqueSet<T:Copy+Default> {
    list: [T; 16],
    count: usize,
}

impl<T:Copy+Default+PartialEq> UniqueSet<T> {
    /// Creates an empty set.
    fn new() -> Self { Self::default() }
    /// Drops all stored elements.
    fn clear(&mut self) { self.count = 0; }
    /// Returns the stored elements in insertion order.
    fn get_list(&self) -> &[T] { &self.list[..self.count] }
    /// Stores a value unless it is already present or the set is full.
    fn add(&mut self, val: T) {
        if self.count >= self.list.len() || self.get_list().contains(&val) {
            return;
        }
        self.list[self.count] = val;
        self.count += 1;
    }
}
308
#[allow(clippy::wrong_self_convention)]
trait MVOps {
    /// Multiplies both components by `scale`.
    fn scale(self, scale: i16) -> Self;
    /// Checks that both components lie within [-range, range].
    fn is_in_range(self, range: i16) -> bool;
}

impl MVOps for MV {
    fn scale(self, scale: i16) -> MV {
        MV { x: self.x * scale, y: self.y * scale }
    }
    fn is_in_range(self, range: i16) -> bool {
        self.x.abs() <= range && self.y.abs() <= range
    }
}
323
/// Evaluates every point of `$pattern` (scaled by `$scale`) around
/// `$start`, skipping out-of-range MVs, and returns
/// `(best MV, best distortion, whether the centre moved)`.
macro_rules! single_search_step {
    ($start:expr, $best_dist:expr, $mv_est:expr, $sad_func:ident, $ref_blk:expr, $xpos:expr, $ypos:expr, $pattern:expr, $scale:expr, $dist_thr:expr) => {{
        let mut best_mv = $start;
        let mut best_dist = $best_dist;
        for point in $pattern.iter() {
            let mv = point.scale($scale) + $start;
            // MVs here are quarter-pel while mv_range is in full pels
            if !mv.is_in_range($mv_est.mv_range * 4) {
                continue;
            }
            let dist = $mv_est.$sad_func($ref_blk, $xpos, $ypos, mv, best_dist);
            if dist < best_dist {
                best_mv = mv;
                best_dist = dist;
                if best_dist < $dist_thr {
                    break;
                }
            }
        }
        (best_mv, best_dist, best_mv != $start)
    }}
}
345
/// Uneven multi-hexagon (UMH) searcher state; the set deduplicates
/// the candidate MVs used to pick the search start point.
struct UnevenHexSearch {
    mv_list: UniqueSet<MV>,
}
349
350impl UnevenHexSearch {
351 fn new() -> Self {
352 Self {
353 mv_list: UniqueSet::new(),
354 }
355 }
356 fn get_cand_mv(&mut self, cand_mvs: &[MV]) -> MV {
357 self.mv_list.clear();
358 for &mv in cand_mvs.iter() {
359 self.mv_list.add(mv);
360 }
361 match self.mv_list.count {
362 1 => self.mv_list.list[0],
363 3 => MV::pred(self.mv_list.list[0], self.mv_list.list[1], self.mv_list.list[2]),
364 _ => {
365 let sum = self.mv_list.get_list().iter().fold((0i32, 0i32),
366 |acc, mv| (acc.0 + i32::from(mv.x), acc.1 + i32::from(mv.y)));
367 MV {x: (sum.0 / (self.mv_list.count as i32)) as i16,
368 y: (sum.1 / (self.mv_list.count as i32)) as i16}
369 },
370 }
371 }
372}
373
/// Body of the UMH search shared by the MB and 8x8-block entry points:
/// score the predicted start MV, then refine with a small diamond, an
/// unsymmetrical cross, a shrinking multi-hexagon grid and a final
/// hexagon/diamond pass, returning as soon as the distortion drops
/// below `$cutoff`.
macro_rules! umh_search_template {
    ($cand_mv:expr, $cutoff:expr, $mv_est:expr, $sad_func:ident, $ref_blk:expr, $xpos:expr, $ypos:expr) => {{
        let cand_mv = $cand_mv;
        let best_dist = $mv_est.$sad_func($ref_blk, $xpos, $ypos, cand_mv, MAX_DIST);
        if best_dist < $cutoff {
            return (cand_mv, best_dist);
        }

        // step 1 - small refinement search
        let (mut cand_mv, mut best_dist, _) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, DIA_PATTERN, 1, $cutoff);
        if best_dist < $cutoff {
            return (cand_mv, best_dist);
        }

        // step 2 - unsymmetrical cross search
        loop {
            let (mv, dist, changed) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, UNSYMM_CROSS, 4, $cutoff);
            if !changed {
                break;
            }
            cand_mv = mv;
            best_dist = dist;
            if best_dist < $cutoff {
                return (mv, dist);
            }
        }

        // step 3 - multi-hexagon grid search (hexagon shrinks each round)
        let mut scale = 4;
        while scale > 0 {
            let (mv, dist, changed) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, LARGE_HEX_PATTERN, scale, $cutoff);
            if !changed {
                break;
            }
            cand_mv = mv;
            best_dist = dist;
            if best_dist < $cutoff {
                return (mv, dist);
            }
            scale >>= 1;
        }
        // step 4 - final hexagon search, with a diamond polish if the
        // result is still above the cutoff
        let (cand_mv, best_dist, _) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, HEX_PATTERN, 1, $cutoff);
        if best_dist > $cutoff {
            let (mv, dist, _) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, DIA_PATTERN, 1, $cutoff);
            (mv, dist)
        } else {
            (cand_mv, best_dist)
        }
    }}
}
425
impl MVSearch for UnevenHexSearch {
    fn search_mb(&mut self, mv_est: &mut MVEstimator, cur_mb: &RefMBData, mb_x: usize, mb_y: usize, cand_mvs: &[MV]) -> (MV, u32) {
        let cand_mv = self.get_cand_mv(cand_mvs);
        let cutoff = mv_est.cutoff_thr;
        umh_search_template!(cand_mv, cutoff, mv_est, sad_mb, cur_mb, mb_x, mb_y)
    }
    fn search_blk8(&mut self, mv_est: &mut MVEstimator, ref_blk: &RefMBData, xpos: usize, ypos: usize, cand_mvs: &[MV]) -> (MV, u32) {
        let cand_mv = self.get_cand_mv(cand_mvs);
        // an 8x8 block covers a quarter of the MB area, so scale the cutoff
        let cutoff = mv_est.cutoff_thr / 4;
        umh_search_template!(cand_mv, cutoff, mv_est, sad_blk8, ref_blk, xpos, ypos)
    }
}
438
/// Distortion evaluator bound to one reference picture, carrying the
/// search limits shared by all searcher implementations.
struct MVEstimator<'a> {
    pic: &'a NAVideoBuffer<u8>,     // reference picture to compensate from
    mv_range: i16,                  // MV component limit in full pels
    cutoff_thr: u32,                // distortion level considered good enough
}
444
// Despite the name, this computes the sum of squared differences (SSE)
// between two pixel slices, not the sum of absolute differences.
macro_rules! sad {
    ($src1:expr, $src2:expr) => {
        $src1.iter().zip($src2.iter()).fold(0u32, |acc, (&a, &b)|
            acc + (((i32::from(a) - i32::from(b)) * (i32::from(a) - i32::from(b))) as u32))
    }
}
451
impl<'a> MVEstimator<'a> {
    /// Distortion of a whole 16x16 macroblock (luma plus both chroma
    /// planes) compensated with `mv` against `ref_mb`.
    /// Returns early as soon as the total exceeds `cur_best_dist`.
    fn sad_mb(&self, ref_mb: &RefMBData, mb_x: usize, mb_y: usize, mv: MV, cur_best_dist: u32) -> u32 {
        let mut dst = RefMBData::new();
        luma_mc(&mut dst.y, 16, self.pic, mb_x * 16, mb_y * 16, mv, true);

        let mut dist = 0;
        for (dline, sline) in dst.y.chunks(16).zip(ref_mb.y.chunks(16)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }
        // chroma is interpolated only when luma alone did not already
        // exceed the current best distortion
        chroma_mc(&mut dst.u, 8, self.pic, mb_x * 8, mb_y * 8, 1, mv, true);
        dist += sad!(dst.u, ref_mb.u);
        if dist > cur_best_dist {
            return dist;
        }
        chroma_mc(&mut dst.v, 8, self.pic, mb_x * 8, mb_y * 8, 2, mv, true);
        dist += sad!(dst.v, ref_mb.v);

        dist
    }
    /// Distortion of one 8x8 luma block (and its 4x4 chroma blocks).
    /// `xpos`/`ypos` are absolute pel coordinates; their low bits select
    /// which quadrant of `ref_mb` the block is compared against.
    fn sad_blk8(&self, ref_mb: &RefMBData, xpos: usize, ypos: usize, mv: MV, cur_best_dist: u32) -> u32 {
        let mut cur_y = [0; 64];
        let mut cur_u = [0; 16];
        let mut cur_v = [0; 16];

        let mut dist = 0;

        // offset of this 8x8 block inside the 16x16 reference MB data
        let y_off = (xpos & 8) + (ypos & 8) * 16;
        luma_mc(&mut cur_y, 8, self.pic, xpos, ypos, mv, false);
        for (dline, sline) in cur_y.chunks(8).zip(ref_mb.y[y_off..].chunks(16)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }

        // matching 4x4 offset inside the 8x8 chroma planes
        let c_off = (xpos & 8) / 2 + (ypos & 8) * 4;
        chroma_mc(&mut cur_u, 4, self.pic, xpos / 2, ypos / 2, 1, mv, false);
        for (dline, sline) in cur_u.chunks(4).zip(ref_mb.u[c_off..].chunks(8)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }
        chroma_mc(&mut cur_v, 4, self.pic, xpos / 2, ypos / 2, 2, mv, false);
        for (dline, sline) in cur_v.chunks(4).zip(ref_mb.v[c_off..].chunks(8)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }

        dist
    }
}
509
/// Common interface of all motion searchers: both methods return the
/// best MV (quarter-pel) and its distortion.
trait MVSearch {
    /// Searches the best MV for a whole 16x16 macroblock.
    fn search_mb(&mut self, mv_est: &mut MVEstimator, ref_mb: &RefMBData, mb_x: usize, mb_y: usize, cand_mvs: &[MV]) -> (MV, u32);
    /// Searches the best MV for one 8x8 block at pel position (xpos, ypos).
    fn search_blk8(&mut self, mv_est: &mut MVEstimator, ref_blk: &RefMBData, xpos: usize, ypos: usize, cand_mvs: &[MV]) -> (MV, u32);
}
514
/// Searcher that performs no actual search and always reports the zero MV.
struct DummySearcher {}

impl MVSearch for DummySearcher {
    fn search_mb(&mut self, _mv_est: &mut MVEstimator, _ref_mb: &RefMBData, _mb_x: usize, _mb_y: usize, _cand_mvs: &[MV]) -> (MV, u32) {
        // half of u32::MAX — presumably to leave headroom for later cost
        // additions without overflow; TODO confirm against the callers
        (ZERO_MV, u32::MAX / 2)
    }
    fn search_blk8(&mut self, _mv_est: &mut MVEstimator, _ref_mb: &RefMBData, _xpos: usize, _ypos: usize, _cand_mvs: &[MV]) -> (MV, u32) {
        (ZERO_MV, u32::MAX / 2)
    }
}
525
/// Public front-end for motion estimation: holds the search parameters
/// and the currently selected searcher implementation.
pub struct MotionEstimator {
    pub range: i16,                 // MV component limit in full pels
    pub thresh: u32,                // distortion considered good enough to stop
    mode: MVSearchMode,
    srch: Box<dyn MVSearch+Send>,
}
532
533impl MotionEstimator {
534 pub fn new() -> Self {
535 let mode = MVSearchMode::default();
536 Self {
537 range: 64,
538 thresh: 32,
539 mode,
540 srch: mode.create(),
541 }
542 }
543 pub fn get_mode(&self) -> MVSearchMode { self.mode }
544 pub fn set_mode(&mut self, new_mode: MVSearchMode) {
545 if self.mode != new_mode {
546 self.mode = new_mode;
547 self.srch = self.mode.create();
548 }
549 }
550 pub fn search_mb_p(&mut self, pic: &NAVideoBuffer<u8>, refmb: &RefMBData, mb_x: usize, mb_y: usize, cand_mvs: &[MV]) -> (MV, u32) {
551 let mut mv_est = MVEstimator {
552 mv_range: self.range,
553 cutoff_thr: self.thresh,
554 pic,
555 };
556 self.srch.search_mb(&mut mv_est, refmb, mb_x, mb_y, cand_mvs)
557 }
558 pub fn search_blk8(&mut self, pic: &NAVideoBuffer<u8>, refmb: &RefMBData, xoff: usize, yoff: usize, cand_mvs: &[MV]) -> (MV, u32) {
559 let mut mv_est = MVEstimator {
560 mv_range: self.range,
561 cutoff_thr: self.thresh,
562 pic,
563 };
564 self.srch.search_blk8(&mut mv_est, refmb, xoff, yoff, cand_mvs)
565 }
566}
567
/// Joint forward/backward motion search for one B-frame macroblock.
pub struct SearchB<'a> {
    ref_p: &'a NAVideoBuffer<u8>,   // past (forward) reference picture
    ref_n: &'a NAVideoBuffer<u8>,   // future (backward) reference picture
    xpos: usize,                    // macroblock position in pels
    ypos: usize,
    ratios: [u32; 2],               // weights for averaging the two predictions
    tmp1: RefMBData,                // forward-compensated scratch block
    tmp2: RefMBData,                // backward-compensated scratch block
    pred_blk: RefMBData,            // weighted-average prediction scratch
}
578
impl<'a> SearchB<'a> {
    /// Creates a searcher for the macroblock at (`mb_x`, `mb_y`) using the
    /// given reference pair and averaging weights.
    pub fn new(ref_p: &'a NAVideoBuffer<u8>, ref_n: &'a NAVideoBuffer<u8>, mb_x: usize, mb_y: usize, ratios: [u32; 2]) -> Self {
        Self {
            ref_p, ref_n,
            xpos: mb_x * 16,
            ypos: mb_y * 16,
            ratios,
            tmp1: RefMBData::new(),
            tmp2: RefMBData::new(),
            pred_blk: RefMBData::new(),
        }
    }
    /// Jointly refines the forward/backward MV pair starting from
    /// `cand_mvs`: first full-pel diamond steps on both MVs at once until
    /// no improvement, then a final quarter-pel cross refinement.
    pub fn search_mb(&mut self, ref_mb: &RefMBData, cand_mvs: [MV; 2]) -> (MV, MV) {
        let mut best_cand = cand_mvs;
        let mut best_dist = self.interp_b_dist(ref_mb, best_cand, MAX_DIST);

        loop {
            let mut improved = false;
            for &fmv_add in DIA_PATTERN.iter() {
                for &bmv_add in DIA_PATTERN.iter() {
                    // diamond offsets are full-pel, MVs are quarter-pel
                    let cand = [best_cand[0] + fmv_add.from_pixels(),
                                best_cand[1] + bmv_add.from_pixels()];
                    let dist = self.interp_b_dist(ref_mb, cand, best_dist);
                    if dist < best_dist {
                        best_dist = dist;
                        best_cand = cand;
                        improved = true;
                    }
                }
            }
            if !improved {
                break;
            }
        }

        // quarter-pel polish of both MVs with the one-step cross
        for &fmv_add in REFINEMENT.iter() {
            for &bmv_add in REFINEMENT.iter() {
                let cand = [best_cand[0] + fmv_add, best_cand[1] + bmv_add];
                let dist = self.interp_b_dist(ref_mb, cand, best_dist);
                if dist < best_dist {
                    best_dist = dist;
                    best_cand = cand;
                }
            }
        }

        (best_cand[0], best_cand[1])
    }
    /// Distortion of the weighted average of both motion-compensated
    /// predictions against `ref_mb`; exits early past `cur_best_dist`.
    fn interp_b_dist(&mut self, ref_mb: &RefMBData, cand_mv: [MV; 2], cur_best_dist: u32) -> u32 {
        let [fmv, bmv] = cand_mv;
        luma_mc(&mut self.tmp1.y, 16, self.ref_p, self.xpos, self.ypos, fmv, true);
        chroma_mc(&mut self.tmp1.u, 8, self.ref_p, self.xpos / 2, self.ypos / 2, 1, fmv, true);
        chroma_mc(&mut self.tmp1.v, 8, self.ref_p, self.xpos / 2, self.ypos / 2, 2, fmv, true);
        luma_mc(&mut self.tmp2.y, 16, self.ref_n, self.xpos, self.ypos, bmv, true);
        chroma_mc(&mut self.tmp2.u, 8, self.ref_n, self.xpos / 2, self.ypos / 2, 1, bmv, true);
        chroma_mc(&mut self.tmp2.v, 8, self.ref_n, self.xpos / 2, self.ypos / 2, 2, bmv, true);
        self.pred_blk.avg(&self.tmp1, self.ratios[0], &self.tmp2, self.ratios[1]);

        let mut dist = 0;
        for (dline, sline) in self.pred_blk.y.chunks(16).zip(ref_mb.y.chunks(16)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }
        dist += sad!(self.pred_blk.u, ref_mb.u);
        if dist > cur_best_dist {
            return dist;
        }
        dist += sad!(self.pred_blk.v, ref_mb.v);

        dist
    }
}
653
// 4-point Hadamard butterfly: transforms ($s0..$s3) and stores the result
// into ($d0..$d3); source and destination expressions may alias.
macro_rules! hadamard {
    ($s0:expr, $s1:expr, $s2:expr, $s3:expr, $d0:expr, $d1:expr, $d2:expr, $d3:expr) => {
        let t0 = $s0 + $s1;
        let t1 = $s0 - $s1;
        let t2 = $s2 + $s3;
        let t3 = $s2 - $s3;
        $d0 = t0 + t2;
        $d2 = t0 - t2;
        $d1 = t1 + t3;
        $d3 = t1 - t3;
    }
}
666
/// Half-resolution frame analyser that estimates per-frame complexity
/// and decides whether a frame is a good B-frame candidate.
pub struct FrameComplexityEstimate {
    ref_frm: NAVideoBufferRef<u8>,  // downscaled previous reference frame
    cur_frm: NAVideoBufferRef<u8>,  // downscaled current frame
    nxt_frm: NAVideoBufferRef<u8>,  // downscaled next frame (for B decisions)
    width: usize,                   // full-resolution dimensions last seen
    height: usize,
}
674
impl FrameComplexityEstimate {
    /// Creates the estimator with tiny placeholder buffers; real buffers
    /// are allocated on the first `resize()` call.
    pub fn new() -> Self {
        let vinfo = NAVideoInfo::new(24, 24, false, YUV420_FORMAT);
        let vt = alloc_video_buffer(vinfo, 4).unwrap();
        let buf = vt.get_vbuf().unwrap();
        Self {
            ref_frm: buf.clone(),
            cur_frm: buf.clone(),
            nxt_frm: buf,
            width: 0,
            height: 0,
        }
    }
    /// (Re)allocates the half-resolution work frames when the coded
    /// dimensions change; the reference frame starts as mid-grey.
    pub fn resize(&mut self, width: usize, height: usize) {
        if width != self.width || height != self.height {
            self.width = width;
            self.height = height;

            let vinfo = NAVideoInfo::new(self.width / 2, self.height / 2, false, YUV420_FORMAT);
            let vt = alloc_video_buffer(vinfo, 4).unwrap();
            self.ref_frm = vt.get_vbuf().unwrap();
            let frm = self.ref_frm.get_data_mut().unwrap();
            for el in frm.iter_mut() {
                *el = 0x80;
            }
            let vt = alloc_video_buffer(vinfo, 4).unwrap();
            self.cur_frm = vt.get_vbuf().unwrap();
            let vt = alloc_video_buffer(vinfo, 4).unwrap();
            self.nxt_frm = vt.get_vbuf().unwrap();
        }
    }
    /// Downscales `frm` into the current-frame slot.
    pub fn set_current(&mut self, frm: &NAVideoBuffer<u8>) {
        Self::downscale(&mut self.cur_frm, frm);
    }
    /// Estimates the complexity of coding the current frame as the given
    /// frame type (intra SATD for I, motion-compensated SATD for P).
    pub fn get_complexity(&self, ftype: FrameType) -> u32 {
        match ftype {
            FrameType::I => Self::calculate_i_cplx(&self.cur_frm),
            FrameType::P => Self::calculate_mv_diff(&self.ref_frm, &self.cur_frm),
            _ => 0,
        }
    }
    /// Decides whether `frm1` should be coded as a B-frame given the next
    /// frame `frm2` (both are downscaled internally).
    pub fn decide_b_frame(&mut self, frm1: &NAVideoBuffer<u8>, frm2: &NAVideoBuffer<u8>) -> bool {
        Self::downscale(&mut self.cur_frm, frm1);
        Self::downscale(&mut self.nxt_frm, frm2);
        let diff_ref_cur = Self::calculate_mv_diff(&self.ref_frm, &self.cur_frm);
        let diff_cur_nxt = Self::calculate_mv_diff(&self.cur_frm, &self.nxt_frm);

        // simple rule - if complexity ref->cur and cur->next is about the same this should be a B-frame
        let ddiff = diff_ref_cur.max(diff_cur_nxt) - diff_ref_cur.min(diff_cur_nxt);
        if ddiff < 256 {
            true
        } else {
            // also accept when the relative difference is small, i.e. the
            // absolute gap is tiny compared to the smaller complexity
            let mut order = 0;
            while (ddiff << order) < diff_ref_cur.min(diff_cur_nxt) {
                order += 1;
            }
            order > 2
        }
    }
    /// Makes the current frame the new reference (cheap buffer swap).
    pub fn update_ref(&mut self) {
        std::mem::swap(&mut self.ref_frm, &mut self.cur_frm);
    }

    // Translates a macroblock position plus MV into absolute pel
    // coordinates (negative results wrap to huge values and are filtered
    // out by the range checks at the call sites).
    fn add_mv(mb_x: usize, mb_y: usize, mv: MV) -> (usize, usize) {
        (((mb_x * 16) as isize + (mv.x as isize)) as usize,
         ((mb_y * 16) as isize + (mv.y as isize)) as usize)
    }
    // Sums intra-prediction SATD over all 4x4 blocks of the luma plane.
    fn calculate_i_cplx(frm: &NAVideoBuffer<u8>) -> u32 {
        let (w, h) = frm.get_dimensions(0);
        let src = frm.get_data();
        let stride = frm.get_stride(0);
        let mut sum = 0;
        let mut offset = 0;
        for y in (0..h).step_by(4) {
            for x in (0..w).step_by(4) {
                sum += Self::satd_i(src, offset + x, stride, x > 0, y > 0);
            }
            offset += stride * 4;
        }
        sum
    }
    // Sums motion-compensated SATD over all macroblocks of the frame pair.
    fn calculate_mv_diff(ref_frm: &NAVideoBuffer<u8>, cur_frm: &NAVideoBuffer<u8>) -> u32 {
        let (w, h) = ref_frm.get_dimensions(0);
        let mut sum = 0;
        for mb_y in 0..(h / 16) {
            for mb_x in 0..(w / 16) {
                sum += Self::satd_mb_diff(ref_frm, cur_frm, mb_x, mb_y);
            }
        }
        sum
    }
    // SATD of one macroblock against its best-matching (full-pel) position
    // in the reference frame.
    fn satd_mb_diff(ref_frm: &NAVideoBuffer<u8>, cur_frm: &NAVideoBuffer<u8>, mb_x: usize, mb_y: usize) -> u32 {
        let mv = Self::search_mv(ref_frm, cur_frm, mb_x, mb_y);
        let mut sum = 0;
        let src0 = ref_frm.get_data();
        let src1 = cur_frm.get_data();
        let stride = ref_frm.get_stride(0);
        let (src_x, src_y) = Self::add_mv(mb_x, mb_y, mv);
        for y in (0..16).step_by(4) {
            for x in (0..16).step_by(4) {
                sum += Self::satd(&src0[src_x + x + (src_y + y) * stride..],
                                  &src1[mb_x * 16 + x + (mb_y * 16 + y) * stride..],
                                  stride);
            }
        }
        sum
    }
    // Coarse full-pel motion search: diamond pattern at step sizes 4, 2, 1.
    fn search_mv(ref_frm: &NAVideoBuffer<u8>, cur_frm: &NAVideoBuffer<u8>, mb_x: usize, mb_y: usize) -> MV {
        let stride = ref_frm.get_stride(0);
        let (w, h) = ref_frm.get_dimensions(0);
        let (v_edge, h_edge) = (w - 16, h - 16);
        let ref_src = ref_frm.get_data();
        let cur_src = cur_frm.get_data();
        let cur_src = &cur_src[mb_x * 16 + mb_y * 16 * stride..];

        let mut best_mv = ZERO_MV;
        let mut best_dist = Self::sad(cur_src, ref_src, mb_x, mb_y, stride, best_mv);
        if best_dist == 0 {
            return best_mv;
        }

        for step in (0..=2).rev() {
            let mut changed = true;
            while changed {
                changed = false;
                for &mv in DIA_PATTERN[1..].iter() {
                    let cand_mv = best_mv + mv.scale(1 << step);
                    let (cx, cy) = Self::add_mv(mb_x, mb_y, cand_mv);
                    // out-of-frame candidates (including negative wraps)
                    // are skipped
                    if cx > v_edge || cy > h_edge {
                        continue;
                    }
                    let cand_dist = Self::sad(cur_src, ref_src, mb_x, mb_y, stride, cand_mv);
                    if cand_dist < best_dist {
                        best_dist = cand_dist;
                        best_mv = cand_mv;
                        if best_dist == 0 {
                            return best_mv;
                        }
                        changed = true;
                    }
                }
            }
        }
        best_mv
    }
    // Despite the name this computes the sum of squared differences (SSE)
    // of one 16x16 block pair.
    fn sad(cur_src: &[u8], src: &[u8], mb_x: usize, mb_y: usize, stride: usize, mv: MV) -> u32 {
        let (src_x, src_y) = Self::add_mv(mb_x, mb_y, mv);
        let mut sum = 0;
        for (line1, line2) in cur_src.chunks(stride).zip(src[src_x + src_y * stride..].chunks(stride)).take(16) {
            sum += line1[..16].iter().zip(line2[..16].iter()).fold(0u32,
                |acc, (&a, &b)| acc + u32::from(a.max(b) - a.min(b)) * u32::from(a.max(b) - a.min(b)));
        }
        sum
    }
    // SATD of the residual of a 4x4 block against a simple spatial
    // prediction built from the available neighbours.
    fn satd_i(src: &[u8], mut offset: usize, stride: usize, has_left: bool, has_top: bool) -> u32 {
        let mut diffs = [0; 16];
        match (has_left, has_top) {
            (true, true) => {
                // predict each pixel as the median of left, top and
                // top-left neighbours (sum minus min minus max = median)
                for row in diffs.chunks_exact_mut(4) {
                    let mut left = i16::from(src[offset - 1]);
                    let mut tl = i16::from(src[offset - stride - 1]);
                    for (x, dst) in row.iter_mut().enumerate() {
                        let cur = i16::from(src[offset + x]);
                        let top = i16::from(src[offset + x - stride]);

                        *dst = cur - (top + left + tl - top.min(left).min(tl) - top.max(left).max(tl));

                        left = cur;
                        tl = top;
                    }

                    offset += stride;
                }
            },
            (true, false) => {
                // left neighbour only - horizontal prediction
                for (dst, (left, cur)) in diffs.chunks_exact_mut(4).zip(
                        src[offset - 1..].chunks(stride).zip(src[offset..].chunks(stride))) {
                    for (dst, (&left, &cur)) in dst.iter_mut().zip(left.iter().zip(cur.iter())) {
                        *dst = i16::from(cur) - i16::from(left);
                    }
                }
            },
            (false, true) => {
                // top neighbour only - vertical prediction
                for (dst, (top, cur)) in diffs.chunks_exact_mut(4).zip(
                        src[offset - stride..].chunks(stride).zip(src[offset..].chunks(stride))) {
                    for (dst, (&top, &cur)) in dst.iter_mut().zip(top.iter().zip(cur.iter())) {
                        *dst = i16::from(cur) - i16::from(top);
                    }
                }
            },
            (false, false) => {
                // no neighbours - predict flat mid-grey
                for (dst, src) in diffs.chunks_exact_mut(4).zip(src[offset..].chunks(stride)) {
                    for (dst, &src) in dst.iter_mut().zip(src.iter()) {
                        *dst = i16::from(src) - 128;
                    }
                }
            },
        };
        // 4x4 Hadamard transform: rows then columns
        for row in diffs.chunks_exact_mut(4) {
            hadamard!(row[0], row[1], row[2], row[3], row[0], row[1], row[2], row[3]);
        }
        for i in 0..4 {
            hadamard!(diffs[i], diffs[i + 4], diffs[i + 8], diffs[i + 12],
                      diffs[i], diffs[i + 4], diffs[i + 8], diffs[i + 12]);
        }
        diffs.iter().fold(0u32, |acc, x| acc + (x.unsigned_abs() as u32))
    }
    // SATD of the difference between two 4x4 blocks.
    fn satd(src0: &[u8], src1: &[u8], stride: usize) -> u32 {
        let mut diffs = [0; 16];
        for (dst, (src0, src1)) in diffs.chunks_exact_mut(4).zip(
                src0.chunks(stride).zip(src1.chunks(stride))) {
            hadamard!(i16::from(src0[0]) - i16::from(src1[0]),
                      i16::from(src0[1]) - i16::from(src1[1]),
                      i16::from(src0[2]) - i16::from(src1[2]),
                      i16::from(src0[3]) - i16::from(src1[3]),
                      dst[0], dst[1], dst[2], dst[3]);
        }
        for i in 0..4 {
            hadamard!(diffs[i], diffs[i + 4], diffs[i + 8], diffs[i + 12],
                      diffs[i], diffs[i + 4], diffs[i + 8], diffs[i + 12]);
        }
        diffs.iter().fold(0u32, |acc, x| acc + (x.unsigned_abs() as u32))
    }
    // 2x downscale of all three planes by 2x2 box averaging with rounding.
    fn downscale(dst: &mut NAVideoBuffer<u8>, src: &NAVideoBuffer<u8>) {
        let dst = NASimpleVideoFrame::from_video_buf(dst).unwrap();
        let sdata = src.get_data();
        for plane in 0..3 {
            let cur_w = dst.width[plane];
            let cur_h = dst.height[plane];
            let doff = dst.offset[plane];
            let soff = src.get_offset(plane);
            let dstride = dst.stride[plane];
            let sstride = src.get_stride(plane);
            for (dline, sstrip) in dst.data[doff..].chunks_exact_mut(dstride).zip(
                    sdata[soff..].chunks_exact(sstride * 2)).take(cur_h) {
                let (line0, line1) = sstrip.split_at(sstride);
                for (dst, (src0, src1)) in dline.iter_mut().zip(
                        line0.chunks_exact(2).zip(line1.chunks_exact(2))).take(cur_w) {
                    *dst = ((u16::from(src0[0]) + u16::from(src0[1]) +
                             u16::from(src1[0]) + u16::from(src1[1]) + 2) >> 2) as u8;
                }
            }
        }
    }
}