//! RealVideo 4 encoder — motion estimation.
//! (from nihav.git: nihav-realmedia/src/codecs/rv40enc/motion_est.rs)
1use nihav_core::frame::*;
2use nihav_codec_support::codecs::{MV, ZERO_MV};
3use std::str::FromStr;
4use super::dsp::{RefMBData, luma_mc, chroma_mc};
5
/// Available motion vector search strategies.
#[derive(Clone,Copy,PartialEq)]
pub enum MVSearchMode {
    /// No real search: always reports a zero vector with a huge distortion.
    Dummy,
    /// Diamond pattern search.
    Diamond,
    /// Hexagon pattern search.
    Hexagon,
    /// Uneven multi-hexagon search.
    UMH,
}
13
14impl MVSearchMode {
15 pub const fn get_possible_modes() -> &'static [&'static str] {
16 &["diamond", "hexagon", "umh"]
17 }
18 fn create(self) -> Box<dyn MVSearch+Send> {
19 match self {
20 MVSearchMode::Dummy => Box::new(DummySearcher{}),
21 MVSearchMode::Diamond => Box::new(DiaSearch::new()),
22 MVSearchMode::Hexagon => Box::new(HexSearch::new()),
23 MVSearchMode::UMH => Box::new(UnevenHexSearch::new()),
24 }
25 }
26}
27
28impl Default for MVSearchMode {
29 fn default() -> Self { MVSearchMode::Hexagon }
30}
31
32impl std::fmt::Display for MVSearchMode {
33 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
34 match *self {
35 MVSearchMode::Diamond => write!(f, "diamond"),
36 MVSearchMode::Hexagon => write!(f, "hexagon"),
37 MVSearchMode::UMH => write!(f, "umh"),
38 MVSearchMode::Dummy => write!(f, "dummy"),
39 }
40 }
41}
42
43impl FromStr for MVSearchMode {
44 type Err = ();
45 fn from_str(s: &str) -> Result<Self, Self::Err> {
46 match s {
47 "diamond" => Ok(MVSearchMode::Diamond),
48 "hexagon" => Ok(MVSearchMode::Hexagon),
49 "umh" => Ok(MVSearchMode::UMH),
50 "dummy" => Ok(MVSearchMode::Dummy),
51 _ => Err(()),
52 }
53 }
54}
55
/// Sentinel distortion meaning "not yet computed" / worst possible score.
const MAX_DIST: u32 = std::u32::MAX;
/// Macroblock distortion at or below this is considered good enough to stop searching.
const DIST_THRESH: u32 = 256;
58
/// Conversion from full-pel units into the quarter-pel units used by motion compensation.
trait FromPixels {
    fn from_pixels(self) -> Self;
}

impl FromPixels for MV {
    fn from_pixels(self) -> MV {
        // full-pel -> quarter-pel: multiply both components by 4
        MV { x: self.x * 4, y: self.y * 4 }
    }
}
68
// Diamond search pattern; the first entry is the pattern centre.
const DIA_PATTERN: [MV; 9] = [
    ZERO_MV,
    MV {x: -2, y: 0},
    MV {x: -1, y: 1},
    MV {x: 0, y: 2},
    MV {x: 1, y: 1},
    MV {x: 2, y: 0},
    MV {x: 1, y: -1},
    MV {x: 0, y: -2},
    MV {x: -1, y: -1}
];

// Hexagon search pattern: centre plus six hexagon vertices.
const HEX_PATTERN: [MV; 7] = [
    ZERO_MV,
    MV {x: -2, y: 0},
    MV {x: -1, y: 2},
    MV {x: 1, y: 2},
    MV {x: 2, y: 0},
    MV {x: 1, y: -2},
    MV {x: -1, y: -2}
];

// One-step offsets used for the final refinement pass around the winner.
const REFINEMENT: [MV; 4] = [
    MV {x: -1, y: 0},
    MV {x: 0, y: 1},
    MV {x: 1, y: 0},
    MV {x: 0, y: -1}
];
97
// Shared driver for the pattern-based searches (diamond/hexagon).
//
// The short form runs a complete search from the zero vector; the long form
// resumes from `$start_mv`/`$best_dist` and can skip the full-pel stage.
// `$sad_func` is the distortion method of `MVEstimator` to call (`sad_mb`
// or `sad_blk8`).  Returns `(best MV in quarter-pel units, distortion)`.
macro_rules! search_template {
    ($self: expr, $mv_est: expr, $cur_blk: expr, $mb_x: expr, $mb_y: expr, $sad_func: ident, $threshold: expr) => ({
        search_template!($self, $mv_est, $cur_blk, $mb_x, $mb_y, $sad_func, $threshold, ZERO_MV, MAX_DIST, true)
    });
    ($self: expr, $mv_est: expr, $cur_blk: expr, $mb_x: expr, $mb_y: expr, $sad_func: ident, $threshold: expr, $start_mv: expr, $best_dist: expr, $fullpel_stage: expr) => ({
        let mut best_dist = $best_dist;
        let mut best_mv = $start_mv;

        let mut min_dist;
        let mut min_idx;

        if $fullpel_stage {
            $self.reset();
            loop {
                // compute distortion for pattern points not cached yet;
                // points carried over from the previous iteration keep theirs
                let mut cur_best_dist = best_dist;
                for (dist, &point) in $self.dist.iter_mut().zip($self.point.iter()) {
                    if *dist == MAX_DIST {
                        *dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, point.from_pixels(), cur_best_dist);
                        cur_best_dist = cur_best_dist.min(*dist);
                        if *dist <= $threshold {
                            break;
                        }
                    }
                }
                // pick the best point of the current pattern
                min_dist = $self.dist[0];
                min_idx = 0;
                for (i, &dist) in $self.dist.iter().enumerate().skip(1) {
                    if dist < min_dist {
                        min_dist = dist;
                        min_idx = i;
                        if dist <= $threshold {
                            break;
                        }
                    }
                }
                // stop when good enough, the centre wins, there is no
                // improvement, or the winner would leave the MV range;
                // otherwise recentre the pattern on the winner
                if min_dist <= $threshold || min_idx == 0 || best_dist == min_dist || $self.point[min_idx].x.abs() >= $mv_est.mv_range || $self.point[min_idx].y.abs() >= $mv_est.mv_range {
                    break;
                }
                best_dist = min_dist;
                $self.update($self.steps[min_idx]);
            }
            best_dist = min_dist;
            best_mv = $self.point[min_idx];
            if best_dist <= $threshold {
                return (best_mv.from_pixels(), best_dist);
            }
            // one-pixel refinement around the full-pel winner
            for &step in REFINEMENT.iter() {
                let mv = best_mv + step;
                let dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, mv.from_pixels(), MAX_DIST);
                if best_dist > dist {
                    best_dist = dist;
                    best_mv = mv;
                }
            }
            best_mv = best_mv.from_pixels();
            if best_dist <= $threshold {
                return (best_mv, best_dist);
            }
        }

        // subpel refinement
        // (same pattern walk as above, but points are now in quarter-pel units)
        $self.set_new_point(best_mv, best_dist);
        loop {
            let mut cur_best_dist = best_dist;
            for (dist, &point) in $self.dist.iter_mut().zip($self.point.iter()) {
                if *dist == MAX_DIST {
                    *dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, point, cur_best_dist);
                    cur_best_dist = cur_best_dist.min(*dist);
                    if *dist <= $threshold {
                        break;
                    }
                }
            }
            min_dist = $self.dist[0];
            min_idx = 0;
            for (i, &dist) in $self.dist.iter().enumerate().skip(1) {
                if dist < min_dist {
                    min_dist = dist;
                    min_idx = i;
                    if dist <= $threshold {
                        break;
                    }
                }
            }
            if min_dist <= $threshold || min_idx == 0 || best_dist == min_dist || $self.point[min_idx].x.abs() >= $mv_est.mv_range * 8 || $self.point[min_idx].y.abs() >= $mv_est.mv_range * 8 {
                break;
            }
            best_dist = min_dist;
            $self.update($self.steps[min_idx]);
        }
        best_dist = min_dist;
        best_mv = $self.point[min_idx];
        if best_dist <= $threshold {
            return (best_mv, best_dist);
        }
        // final quarter-pel refinement pass
        for &step in REFINEMENT.iter() {
            let mv = best_mv + step;
            let dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, mv, MAX_DIST);
            if best_dist > dist {
                best_dist = dist;
                best_mv = mv;
            }
        }
        (best_mv, best_dist)
    });
}
204
// Generates a pattern-search structure that holds the current pattern points
// with their cached distortions, plus the `MVSearch` implementation driving
// them through `search_template!`.
macro_rules! pattern_search {
    ($struct_name: ident, $patterns: expr) => {
        pub struct $struct_name {
            // current pattern positions (centre first)
            point: [MV; $patterns.len()],
            // cached distortion per point; MAX_DIST = not computed yet
            dist: [u32; $patterns.len()],
            // the pattern offsets used to shift the points
            steps: &'static [MV; $patterns.len()],
        }

        impl $struct_name {
            pub fn new() -> Self {
                Self {
                    point: $patterns,
                    dist: [MAX_DIST; $patterns.len()],
                    steps: &$patterns,
                }
            }
            // restart: pattern back at the origin, no cached distortions
            fn reset(&mut self) {
                self.point = $patterns;
                self.dist = [MAX_DIST; $patterns.len()];
            }
            // recentre the pattern on `start`, keeping only the centre distortion
            fn set_new_point(&mut self, start: MV, dist: u32) {
                for (dst, &src) in self.point.iter_mut().zip(self.steps.iter()) {
                    *dst = src + start;
                }
                self.dist = [MAX_DIST; $patterns.len()];
                self.dist[0] = dist;
            }
            // shift the whole pattern by `step`, carrying over the cached
            // distortions of any points shared between old and new patterns
            fn update(&mut self, step: MV) {
                let mut new_point = self.point;
                let mut new_dist = [MAX_DIST; $patterns.len()];

                for point in new_point.iter_mut() {
                    *point += step;
                }

                for (new_point, new_dist) in new_point.iter_mut().zip(new_dist.iter_mut()) {
                    for (&old_point, &old_dist) in self.point.iter().zip(self.dist.iter()) {
                        if *new_point == old_point {
                            *new_dist = old_dist;
                            break;
                        }
                    }
                }
                self.point = new_point;
                self.dist = new_dist;
            }
        }

        impl MVSearch for $struct_name {
            fn search_mb(&mut self, mv_est: &mut MVEstimator, cur_mb: &RefMBData, mb_x: usize, mb_y: usize, _cand_mvs: &[MV]) -> (MV, u32) {
                search_template!(self, mv_est, cur_mb, mb_x, mb_y, sad_mb, DIST_THRESH)
            }
            fn search_blk8(&mut self, mv_est: &mut MVEstimator, ref_blk: &RefMBData, xpos: usize, ypos: usize, _cand_mvs: &[MV]) -> (MV, u32) {
                // 8x8 blocks cover a quarter of the MB area, so scale the threshold
                search_template!(self, mv_est, ref_blk, xpos, ypos, sad_blk8, DIST_THRESH / 4)
            }
        }
    }
}
263
// Instantiate the two fixed-pattern searchers.
pattern_search!(DiaSearch, DIA_PATTERN);
pattern_search!(HexSearch, HEX_PATTERN);
266
267const LARGE_HEX_PATTERN: [MV; 16] = [
268 MV { x: -4, y: 0 },
269 MV { x: -4, y: 1 },
270 MV { x: -4, y: 2 },
271 MV { x: -2, y: 3 },
272 MV { x: 0, y: 4 },
273 MV { x: 2, y: 3 },
274 MV { x: 4, y: 2 },
275 MV { x: 4, y: 1 },
276 MV { x: 4, y: 0 },
277 MV { x: 4, y: -1 },
278 MV { x: 4, y: -2 },
279 MV { x: -2, y: -3 },
280 MV { x: 0, y: -4 },
281 MV { x: -2, y: -3 },
282 MV { x: -4, y: -2 },
283 MV { x: -4, y: -1 }
284];
285
// Cross pattern whose horizontal reach is double the vertical one
// (hence "unsymmetrical"); used by step 2 of the UMH search.
const UNSYMM_CROSS: [MV; 4] = [
    MV { x: -2, y: 0 },
    MV { x: 0, y: 1 },
    MV { x: 2, y: 0 },
    MV { x: 0, y: -1 }
];
292
/// Fixed-capacity set keeping up to 16 distinct values in insertion order.
#[derive(Default)]
struct UniqueSet<T:Copy+Default> {
    list: [T; 16],
    count: usize,
}

impl<T:Copy+Default+PartialEq> UniqueSet<T> {
    /// Creates an empty set.
    fn new() -> Self { Self::default() }
    /// Forgets all stored values (storage itself is left untouched).
    fn clear(&mut self) { self.count = 0; }
    /// Returns the stored values in insertion order.
    fn get_list(&self) -> &[T] { &self.list[..self.count] }
    /// Inserts `val` unless it is already present or the set is full.
    fn add(&mut self, val: T) {
        let is_new = !self.list[..self.count].contains(&val);
        if is_new && self.count < self.list.len() {
            self.list[self.count] = val;
            self.count += 1;
        }
    }
}
310
/// Small motion vector helpers used by the UMH search steps.
trait MVOps {
    /// Multiplies both components by `scale`.
    fn scale(self, scale: i16) -> Self;
    /// Checks that both components lie within `[-range, range]`.
    fn is_in_range(self, range: i16) -> bool;
}

impl MVOps for MV {
    fn scale(self, scale: i16) -> MV {
        MV { x: self.x * scale, y: self.y * scale }
    }
    fn is_in_range(self, range: i16) -> bool {
        self.x.abs() <= range && self.y.abs() <= range
    }
}
324
// Probes every point of `$pattern` (scaled by `$scale`) around `$start` and
// returns `(best MV, best distortion, whether the centre moved)`.
// Stops early once the distortion drops below `$dist_thr`.
macro_rules! single_search_step {
    ($start:expr, $best_dist:expr, $mv_est:expr, $sad_func:ident, $ref_blk:expr, $xpos:expr, $ypos:expr, $pattern:expr, $scale:expr, $dist_thr:expr) => {{
        let mut best_mv = $start;
        let mut best_dist = $best_dist;
        for point in $pattern.iter() {
            let mv = point.scale($scale) + $start;
            // skip candidates outside the allowed MV range (quarter-pel units)
            if !mv.is_in_range($mv_est.mv_range * 4) {
                continue;
            }
            let dist = $mv_est.$sad_func($ref_blk, $xpos, $ypos, mv, best_dist);
            if dist < best_dist {
                best_mv = mv;
                best_dist = dist;
                if best_dist < $dist_thr {
                    break;
                }
            }
        }
        (best_mv, best_dist, best_mv != $start)
    }}
}
346
347struct UnevenHexSearch {
348 mv_list: UniqueSet<MV>,
349}
350
351impl UnevenHexSearch {
352 fn new() -> Self {
353 Self {
354 mv_list: UniqueSet::new(),
355 }
356 }
357 fn get_cand_mv(&mut self, cand_mvs: &[MV]) -> MV {
358 self.mv_list.clear();
359 for &mv in cand_mvs.iter() {
360 self.mv_list.add(mv);
361 }
362 match self.mv_list.count {
363 1 => self.mv_list.list[0],
364 3 => MV::pred(self.mv_list.list[0], self.mv_list.list[1], self.mv_list.list[2]),
365 _ => {
366 let sum = self.mv_list.get_list().iter().fold((0i32, 0i32),
367 |acc, mv| (acc.0 + i32::from(mv.x), acc.1 + i32::from(mv.y)));
368 MV {x: (sum.0 / (self.mv_list.count as i32)) as i16,
369 y: (sum.1 / (self.mv_list.count as i32)) as i16}
370 },
371 }
372 }
373}
374
// The uneven multi-hexagon search proper: refine the predicted start vector,
// then walk an unsymmetrical cross, a shrinking multi-hexagon grid, and
// finally small hexagon/diamond patterns.  Every stage returns early once
// the distortion drops below `$cutoff`.
macro_rules! umh_search_template {
    ($cand_mv:expr, $cutoff:expr, $mv_est:expr, $sad_func:ident, $ref_blk:expr, $xpos:expr, $ypos:expr) => {{
        let cand_mv = $cand_mv;
        let best_dist = $mv_est.$sad_func($ref_blk, $xpos, $ypos, cand_mv, MAX_DIST);
        if best_dist < $cutoff {
            return (cand_mv, best_dist);
        }

        // step 1 - small refinement search
        let (mut cand_mv, mut best_dist, _) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, DIA_PATTERN, 1, $cutoff);
        if best_dist < $cutoff {
            return (cand_mv, best_dist);
        }

        // step 2 - unsymmetrical cross search (repeated while it improves)
        loop {
            let (mv, dist, changed) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, UNSYMM_CROSS, 4, $cutoff);
            if !changed {
                break;
            }
            cand_mv = mv;
            best_dist = dist;
            if best_dist < $cutoff {
                return (mv, dist);
            }
        }

        // step 3 - multi-hexagon grid search (scale shrinks 4 -> 2 -> 1)
        let mut scale = 4;
        while scale > 0 {
            let (mv, dist, changed) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, LARGE_HEX_PATTERN, scale, $cutoff);
            if !changed {
                break;
            }
            cand_mv = mv;
            best_dist = dist;
            if best_dist < $cutoff {
                return (mv, dist);
            }
            scale >>= 1;
        }
        // step 4 - final hexagon search, with an extra diamond pass if the
        // result is still above the cutoff
        let (cand_mv, best_dist, _) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, HEX_PATTERN, 1, $cutoff);
        if best_dist > $cutoff {
            let (mv, dist, _) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, DIA_PATTERN, 1, $cutoff);
            (mv, dist)
        } else {
            (cand_mv, best_dist)
        }
    }}
}
426
impl MVSearch for UnevenHexSearch {
    fn search_mb(&mut self, mv_est: &mut MVEstimator, cur_mb: &RefMBData, mb_x: usize, mb_y: usize, cand_mvs: &[MV]) -> (MV, u32) {
        // start from the combined candidate vector, full macroblock cutoff
        let cand_mv = self.get_cand_mv(cand_mvs);
        let cutoff = mv_est.cutoff_thr;
        umh_search_template!(cand_mv, cutoff, mv_est, sad_mb, cur_mb, mb_x, mb_y)
    }
    fn search_blk8(&mut self, mv_est: &mut MVEstimator, ref_blk: &RefMBData, xpos: usize, ypos: usize, cand_mvs: &[MV]) -> (MV, u32) {
        // 8x8 blocks cover a quarter of the MB area, so scale the cutoff down
        let cand_mv = self.get_cand_mv(cand_mvs);
        let cutoff = mv_est.cutoff_thr / 4;
        umh_search_template!(cand_mv, cutoff, mv_est, sad_blk8, ref_blk, xpos, ypos)
    }
}
439
/// Computes prediction distortion against a reference picture for the searches.
struct MVEstimator<'a> {
    /// Reference picture used for motion compensation.
    pic: &'a NAVideoBuffer<u8>,
    /// Motion vector range limit (full-pel units).
    mv_range: i16,
    /// Distortion threshold for early search termination.
    cutoff_thr: u32,
}
445
// Despite the name this accumulates the sum of squared differences (SSD),
// not the sum of absolute differences.
macro_rules! sad {
    ($src1:expr, $src2:expr) => {
        $src1.iter().zip($src2.iter()).fold(0u32, |acc, (&a, &b)|
            acc + (((i32::from(a) - i32::from(b)) * (i32::from(a) - i32::from(b))) as u32))
    }
}
452
impl<'a> MVEstimator<'a> {
    /// Distortion (SSD) of a whole 16x16 macroblock predicted with `mv`
    /// (quarter-pel units) against `ref_mb`; bails out as soon as the
    /// running sum exceeds `cur_best_dist`.
    fn sad_mb(&self, ref_mb: &RefMBData, mb_x: usize, mb_y: usize, mv: MV, cur_best_dist: u32) -> u32 {
        let mut dst = RefMBData::new();
        luma_mc(&mut dst.y, 16, self.pic, mb_x * 16, mb_y * 16, mv, true);

        let mut dist = 0;
        for (dline, sline) in dst.y.chunks(16).zip(ref_mb.y.chunks(16)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }
        // chroma planes are only interpolated if luma stayed under the bound
        chroma_mc(&mut dst.u, 8, self.pic, mb_x * 8, mb_y * 8, 1, mv, true);
        dist += sad!(dst.u, ref_mb.u);
        if dist > cur_best_dist {
            return dist;
        }
        chroma_mc(&mut dst.v, 8, self.pic, mb_x * 8, mb_y * 8, 2, mv, true);
        dist += sad!(dst.v, ref_mb.v);

        dist
    }
    /// Distortion (SSD) of one 8x8 luma block (plus the matching 4x4 chroma
    /// blocks) at absolute position (`xpos`, `ypos`).  `ref_mb` holds the
    /// whole source macroblock, so the block offset inside it is derived
    /// from the low bits of the position.
    fn sad_blk8(&self, ref_mb: &RefMBData, xpos: usize, ypos: usize, mv: MV, cur_best_dist: u32) -> u32 {
        let mut cur_y = [0; 64];
        let mut cur_u = [0; 16];
        let mut cur_v = [0; 16];

        let mut dist = 0;

        // offset of this 8x8 block inside the 16x16 (stride 16) reference MB
        let y_off = (xpos & 8) + (ypos & 8) * 16;
        luma_mc(&mut cur_y, 8, self.pic, xpos, ypos, mv, false);
        for (dline, sline) in cur_y.chunks(8).zip(ref_mb.y[y_off..].chunks(16)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }

        // matching 4x4 offset in the 8x8 (stride 8) chroma planes
        let c_off = (xpos & 8) / 2 + (ypos & 8) * 4;
        chroma_mc(&mut cur_u, 4, self.pic, xpos / 2, ypos / 2, 1, mv, false);
        for (dline, sline) in cur_u.chunks(4).zip(ref_mb.u[c_off..].chunks(8)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }
        chroma_mc(&mut cur_v, 4, self.pic, xpos / 2, ypos / 2, 2, mv, false);
        for (dline, sline) in cur_v.chunks(4).zip(ref_mb.v[c_off..].chunks(8)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }

        dist
    }
}
510
/// Common interface of all motion search strategies.
trait MVSearch {
    /// Searches the best vector for a 16x16 macroblock; returns (MV, distortion).
    fn search_mb(&mut self, mv_est: &mut MVEstimator, ref_mb: &RefMBData, mb_x: usize, mb_y: usize, cand_mvs: &[MV]) -> (MV, u32);
    /// Searches the best vector for an 8x8 block; returns (MV, distortion).
    fn search_blk8(&mut self, mv_est: &mut MVEstimator, ref_blk: &RefMBData, xpos: usize, ypos: usize, cand_mvs: &[MV]) -> (MV, u32);
}
515
516struct DummySearcher {}
517
518impl MVSearch for DummySearcher {
519 fn search_mb(&mut self, _mv_est: &mut MVEstimator, _ref_mb: &RefMBData, _mb_x: usize, _mb_y: usize, _cand_mvs: &[MV]) -> (MV, u32) {
520 (ZERO_MV, std::u32::MAX / 2)
521 }
522 fn search_blk8(&mut self, _mv_est: &mut MVEstimator, _ref_mb: &RefMBData, _xpos: usize, _ypos: usize, _cand_mvs: &[MV]) -> (MV, u32) {
523 (ZERO_MV, std::u32::MAX / 2)
524 }
525}
526
/// High-level motion estimator driving the currently selected search strategy.
pub struct MotionEstimator {
    /// Motion vector search range limit.
    pub range: i16,
    /// Distortion threshold for early search termination.
    pub thresh: u32,
    // currently selected search mode (kept in sync with `srch`)
    mode: MVSearchMode,
    // the searcher object implementing `mode`
    srch: Box<dyn MVSearch+Send>,
}
533
534impl MotionEstimator {
535 pub fn new() -> Self {
536 let mode = MVSearchMode::default();
537 Self {
538 range: 64,
539 thresh: 32,
540 mode,
541 srch: mode.create(),
542 }
543 }
544 pub fn get_mode(&self) -> MVSearchMode { self.mode }
545 pub fn set_mode(&mut self, new_mode: MVSearchMode) {
546 if self.mode != new_mode {
547 self.mode = new_mode;
548 self.srch = self.mode.create();
549 }
550 }
551 pub fn search_mb_p(&mut self, pic: &NAVideoBuffer<u8>, refmb: &RefMBData, mb_x: usize, mb_y: usize, cand_mvs: &[MV]) -> (MV, u32) {
552 let mut mv_est = MVEstimator {
553 mv_range: self.range,
554 cutoff_thr: self.thresh,
555 pic,
556 };
557 self.srch.search_mb(&mut mv_est, refmb, mb_x, mb_y, cand_mvs)
558 }
559 pub fn search_blk8(&mut self, pic: &NAVideoBuffer<u8>, refmb: &RefMBData, xoff: usize, yoff: usize, cand_mvs: &[MV]) -> (MV, u32) {
560 let mut mv_est = MVEstimator {
561 mv_range: self.range,
562 cutoff_thr: self.thresh,
563 pic,
564 };
565 self.srch.search_blk8(&mut mv_est, refmb, xoff, yoff, cand_mvs)
566 }
567}
568
/// Bidirectional (B-frame) motion vector refiner for a single macroblock.
pub struct SearchB<'a> {
    /// Forward (previous) reference picture.
    ref_p: &'a NAVideoBuffer<u8>,
    /// Backward (next) reference picture.
    ref_n: &'a NAVideoBuffer<u8>,
    // macroblock origin in pixels
    xpos: usize,
    ypos: usize,
    /// Weights used when averaging the two predictions.
    ratios: [u32; 2],
    // scratch buffers: forward prediction, backward prediction, their average
    tmp1: RefMBData,
    tmp2: RefMBData,
    pred_blk: RefMBData,
}
579
impl<'a> SearchB<'a> {
    /// Creates a searcher for the macroblock at (`mb_x`, `mb_y`).
    pub fn new(ref_p: &'a NAVideoBuffer<u8>, ref_n: &'a NAVideoBuffer<u8>, mb_x: usize, mb_y: usize, ratios: [u32; 2]) -> Self {
        Self {
            ref_p, ref_n,
            xpos: mb_x * 16,
            ypos: mb_y * 16,
            ratios,
            tmp1: RefMBData::new(),
            tmp2: RefMBData::new(),
            pred_blk: RefMBData::new(),
        }
    }
    /// Jointly refines the forward/backward vector pair starting from
    /// `cand_mvs`: a full-pel diamond walk over both vectors until no
    /// improvement, then one quarter-pel refinement pass.
    /// Returns `(forward MV, backward MV)`.
    pub fn search_mb(&mut self, ref_mb: &RefMBData, cand_mvs: [MV; 2]) -> (MV, MV) {
        let mut best_cand = cand_mvs;
        let mut best_dist = self.interp_b_dist(ref_mb, best_cand, MAX_DIST);

        loop {
            let mut improved = false;
            // try all combinations of diamond offsets for both vectors
            for &fmv_add in DIA_PATTERN.iter() {
                for &bmv_add in DIA_PATTERN.iter() {
                    let cand = [best_cand[0] + fmv_add.from_pixels(),
                                best_cand[1] + bmv_add.from_pixels()];
                    let dist = self.interp_b_dist(ref_mb, cand, best_dist);
                    if dist < best_dist {
                        best_dist = dist;
                        best_cand = cand;
                        improved = true;
                    }
                }
            }
            if !improved {
                break;
            }
        }

        // quarter-pel refinement of both vectors
        for &fmv_add in REFINEMENT.iter() {
            for &bmv_add in REFINEMENT.iter() {
                let cand = [best_cand[0] + fmv_add, best_cand[1] + bmv_add];
                let dist = self.interp_b_dist(ref_mb, cand, best_dist);
                if dist < best_dist {
                    best_dist = dist;
                    best_cand = cand;
                }
            }
        }

        (best_cand[0], best_cand[1])
    }
    /// SSD between `ref_mb` and the weighted average of the two motion
    /// compensated predictions; exits early once `cur_best_dist` is exceeded.
    fn interp_b_dist(&mut self, ref_mb: &RefMBData, cand_mv: [MV; 2], cur_best_dist: u32) -> u32 {
        let [fmv, bmv] = cand_mv;
        luma_mc(&mut self.tmp1.y, 16, self.ref_p, self.xpos, self.ypos, fmv, true);
        chroma_mc(&mut self.tmp1.u, 8, self.ref_p, self.xpos / 2, self.ypos / 2, 1, fmv, true);
        chroma_mc(&mut self.tmp1.v, 8, self.ref_p, self.xpos / 2, self.ypos / 2, 2, fmv, true);
        luma_mc(&mut self.tmp2.y, 16, self.ref_n, self.xpos, self.ypos, bmv, true);
        chroma_mc(&mut self.tmp2.u, 8, self.ref_n, self.xpos / 2, self.ypos / 2, 1, bmv, true);
        chroma_mc(&mut self.tmp2.v, 8, self.ref_n, self.xpos / 2, self.ypos / 2, 2, bmv, true);
        self.pred_blk.avg(&self.tmp1, self.ratios[0], &self.tmp2, self.ratios[1]);

        let mut dist = 0;
        for (dline, sline) in self.pred_blk.y.chunks(16).zip(ref_mb.y.chunks(16)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }
        dist += sad!(self.pred_blk.u, ref_mb.u);
        if dist > cur_best_dist {
            return dist;
        }
        dist += sad!(self.pred_blk.v, ref_mb.v);

        dist
    }
}
654
// One 4-point butterfly of the 4x4 Hadamard transform; applied once per row
// and once per column to obtain transform-domain differences for SATD.
macro_rules! hadamard {
    ($s0:expr, $s1:expr, $s2:expr, $s3:expr, $d0:expr, $d1:expr, $d2:expr, $d3:expr) => {
        let t0 = $s0 + $s1;
        let t1 = $s0 - $s1;
        let t2 = $s2 + $s3;
        let t3 = $s2 - $s3;
        $d0 = t0 + t2;
        $d2 = t0 - t2;
        $d1 = t1 + t3;
        $d3 = t1 - t3;
    }
}
667
/// Estimates frame coding complexity on half-resolution copies of the input
/// frames (used for frame type decisions).
pub struct FrameComplexityEstimate {
    // half-resolution reference, current and next frames
    ref_frm: NAVideoBufferRef<u8>,
    cur_frm: NAVideoBufferRef<u8>,
    nxt_frm: NAVideoBufferRef<u8>,
    // full input dimensions (internal buffers are width/2 x height/2)
    width: usize,
    height: usize,
}
675
impl FrameComplexityEstimate {
    /// Creates an estimator with small placeholder buffers; call `resize()`
    /// with the real dimensions before using it.
    pub fn new() -> Self {
        let vinfo = NAVideoInfo::new(24, 24, false, YUV420_FORMAT);
        let vt = alloc_video_buffer(vinfo, 4).unwrap();
        let buf = vt.get_vbuf().unwrap();
        Self {
            ref_frm: buf.clone(),
            cur_frm: buf.clone(),
            nxt_frm: buf,
            width: 0,
            height: 0,
        }
    }
    /// Reallocates the internal half-resolution buffers when the input size
    /// changes; the reference frame is filled with mid-grey (0x80).
    pub fn resize(&mut self, width: usize, height: usize) {
        if width != self.width || height != self.height {
            self.width = width;
            self.height = height;

            let vinfo = NAVideoInfo::new(self.width / 2, self.height / 2, false, YUV420_FORMAT);
            let vt = alloc_video_buffer(vinfo, 4).unwrap();
            self.ref_frm = vt.get_vbuf().unwrap();
            let frm = self.ref_frm.get_data_mut().unwrap();
            for el in frm.iter_mut() {
                *el = 0x80;
            }
            let vt = alloc_video_buffer(vinfo, 4).unwrap();
            self.cur_frm = vt.get_vbuf().unwrap();
            let vt = alloc_video_buffer(vinfo, 4).unwrap();
            self.nxt_frm = vt.get_vbuf().unwrap();
        }
    }
    /// Stores a downscaled copy of `frm` as the current frame.
    pub fn set_current(&mut self, frm: &NAVideoBuffer<u8>) {
        Self::downscale(&mut self.cur_frm, frm);
    }
    /// Returns a complexity metric for coding the current frame as `ftype`:
    /// intra SATD for I-frames, motion-compensated SATD versus the reference
    /// for P-frames, 0 for anything else.
    pub fn get_complexity(&self, ftype: FrameType) -> u32 {
        match ftype {
            FrameType::I => Self::calculate_i_cplx(&self.cur_frm),
            FrameType::P => Self::calculate_mv_diff(&self.ref_frm, &self.cur_frm),
            _ => 0,
        }
    }
    /// Heuristically decides whether `frm1` should become a B-frame, given
    /// the following frame `frm2`.
    pub fn decide_b_frame(&mut self, frm1: &NAVideoBuffer<u8>, frm2: &NAVideoBuffer<u8>) -> bool {
        Self::downscale(&mut self.cur_frm, frm1);
        Self::downscale(&mut self.nxt_frm, frm2);
        let diff_ref_cur = Self::calculate_mv_diff(&self.ref_frm, &self.cur_frm);
        let diff_cur_nxt = Self::calculate_mv_diff(&self.cur_frm, &self.nxt_frm);

        // simple rule - if complexity ref->cur and cur->next is about the same this should be a B-frame
        let ddiff = diff_ref_cur.max(diff_cur_nxt) - diff_ref_cur.min(diff_cur_nxt);
        if ddiff < 256 {
            true
        } else {
            // also accept when the difference is less than about a quarter
            // of the smaller complexity (order > 2 <=> ddiff * 4 < min)
            let mut order = 0;
            while (ddiff << order) < diff_ref_cur.min(diff_cur_nxt) {
                order += 1;
            }
            order > 2
        }
    }
    /// Promotes the current frame to be the new reference.
    pub fn update_ref(&mut self) {
        std::mem::swap(&mut self.ref_frm, &mut self.cur_frm);
    }

    /// Applies the full-pel vector `mv` to the macroblock origin,
    /// returning absolute pixel coordinates.
    fn add_mv(mb_x: usize, mb_y: usize, mv: MV) -> (usize, usize) {
        (((mb_x * 16) as isize + (mv.x as isize)) as usize,
         ((mb_y * 16) as isize + (mv.y as isize)) as usize)
    }
    /// Sums 4x4 intra-prediction SATD over the whole first plane.
    fn calculate_i_cplx(frm: &NAVideoBuffer<u8>) -> u32 {
        let (w, h) = frm.get_dimensions(0);
        let src = frm.get_data();
        let stride = frm.get_stride(0);
        let mut sum = 0;
        let mut offset = 0;
        for y in (0..h).step_by(4) {
            for x in (0..w).step_by(4) {
                // edge blocks have fewer neighbours available for prediction
                sum += Self::satd_i(src, offset + x, stride, x > 0, y > 0);
            }
            offset += stride * 4;
        }
        sum
    }
    /// Sums per-macroblock motion-compensated SATD between the two frames.
    fn calculate_mv_diff(ref_frm: &NAVideoBuffer<u8>, cur_frm: &NAVideoBuffer<u8>) -> u32 {
        let (w, h) = ref_frm.get_dimensions(0);
        let mut sum = 0;
        for mb_y in 0..(h / 16) {
            for mb_x in 0..(w / 16) {
                sum += Self::satd_mb_diff(ref_frm, cur_frm, mb_x, mb_y);
            }
        }
        sum
    }
    /// SATD of one 16x16 macroblock against its motion-compensated prediction.
    fn satd_mb_diff(ref_frm: &NAVideoBuffer<u8>, cur_frm: &NAVideoBuffer<u8>, mb_x: usize, mb_y: usize) -> u32 {
        let mv = Self::search_mv(ref_frm, cur_frm, mb_x, mb_y);
        let mut sum = 0;
        let src0 = ref_frm.get_data();
        let src1 = cur_frm.get_data();
        let stride = ref_frm.get_stride(0);
        let (src_x, src_y) = Self::add_mv(mb_x, mb_y, mv);
        // accumulate SATD over the sixteen 4x4 sub-blocks
        for y in (0..16).step_by(4) {
            for x in (0..16).step_by(4) {
                sum += Self::satd(&src0[src_x + x + (src_y + y) * stride..],
                                  &src1[mb_x * 16 + x + (mb_y * 16 + y) * stride..],
                                  stride);
            }
        }
        sum
    }
    /// Crude full-pel motion search: diamond pattern at shrinking step sizes
    /// (scales 4, 2, 1), SSD metric, starting from the zero vector.
    fn search_mv(ref_frm: &NAVideoBuffer<u8>, cur_frm: &NAVideoBuffer<u8>, mb_x: usize, mb_y: usize) -> MV {
        let stride = ref_frm.get_stride(0);
        let (w, h) = ref_frm.get_dimensions(0);
        // largest top-left position a 16x16 block may start at
        let (v_edge, h_edge) = (w - 16, h - 16);
        let ref_src = ref_frm.get_data();
        let cur_src = cur_frm.get_data();
        let cur_src = &cur_src[mb_x * 16 + mb_y * 16 * stride..];

        let mut best_mv = ZERO_MV;
        let mut best_dist = Self::sad(cur_src, ref_src, mb_x, mb_y, stride, best_mv);
        if best_dist == 0 {
            return best_mv;
        }

        for step in (0..=2).rev() {
            let mut changed = true;
            while changed {
                changed = false;
                // probe every non-centre diamond point at the current scale
                for &mv in DIA_PATTERN[1..].iter() {
                    let cand_mv = best_mv + mv.scale(1 << step);
                    let (cx, cy) = Self::add_mv(mb_x, mb_y, cand_mv);
                    if cx > v_edge || cy > h_edge {
                        continue;
                    }
                    let cand_dist = Self::sad(cur_src, ref_src, mb_x, mb_y, stride, cand_mv);
                    if cand_dist < best_dist {
                        best_dist = cand_dist;
                        best_mv = cand_mv;
                        if best_dist == 0 {
                            return best_mv;
                        }
                        changed = true;
                    }
                }
            }
        }
        best_mv
    }
    /// Block difference of a 16x16 block — despite the name this is the sum
    /// of squared differences (SSD).
    fn sad(cur_src: &[u8], src: &[u8], mb_x: usize, mb_y: usize, stride: usize, mv: MV) -> u32 {
        let (src_x, src_y) = Self::add_mv(mb_x, mb_y, mv);
        let mut sum = 0;
        for (line1, line2) in cur_src.chunks(stride).zip(src[src_x + src_y * stride..].chunks(stride)).take(16) {
            sum += line1[..16].iter().zip(line2[..16].iter()).fold(0u32,
                |acc, (&a, &b)| acc + u32::from(a.max(b) - a.min(b)) * u32::from(a.max(b) - a.min(b)));
        }
        sum
    }
    /// SATD of a 4x4 block against a simple intra prediction built from the
    /// available neighbours: median(top, left, top-left) when both edges
    /// exist, plain left/top when only one does, flat 128 otherwise.
    fn satd_i(src: &[u8], mut offset: usize, stride: usize, has_left: bool, has_top: bool) -> u32 {
        let mut diffs = [0; 16];
        match (has_left, has_top) {
            (true, true) => {
                for row in diffs.chunks_exact_mut(4) {
                    let mut left = i16::from(src[offset - 1]);
                    let mut tl = i16::from(src[offset - stride - 1]);
                    for (x, dst) in row.iter_mut().enumerate() {
                        let cur = i16::from(src[offset + x]);
                        let top = i16::from(src[offset + x - stride]);

                        // sum minus min and max of three values == their median
                        *dst = cur - (top + left + tl - top.min(left).min(tl) - top.max(left).max(tl));

                        left = cur;
                        tl = top;
                    }

                    offset += stride;
                }
            },
            (true, false) => {
                for (dst, (left, cur)) in diffs.chunks_exact_mut(4).zip(
                        src[offset - 1..].chunks(stride).zip(src[offset..].chunks(stride))) {
                    for (dst, (&left, &cur)) in dst.iter_mut().zip(left.iter().zip(cur.iter())) {
                        *dst = i16::from(cur) - i16::from(left);
                    }
                }
            },
            (false, true) => {
                for (dst, (top, cur)) in diffs.chunks_exact_mut(4).zip(
                        src[offset - stride..].chunks(stride).zip(src[offset..].chunks(stride))) {
                    for (dst, (&top, &cur)) in dst.iter_mut().zip(top.iter().zip(cur.iter())) {
                        *dst = i16::from(cur) - i16::from(top);
                    }
                }
            },
            (false, false) => {
                for (dst, src) in diffs.chunks_exact_mut(4).zip(src[offset..].chunks(stride)) {
                    for (dst, &src) in dst.iter_mut().zip(src.iter()) {
                        *dst = i16::from(src) - 128;
                    }
                }
            },
        };
        // 2-D Hadamard: rows first, then columns
        for row in diffs.chunks_exact_mut(4) {
            hadamard!(row[0], row[1], row[2], row[3], row[0], row[1], row[2], row[3]);
        }
        for i in 0..4 {
            hadamard!(diffs[i], diffs[i + 4], diffs[i + 8], diffs[i + 12],
                      diffs[i], diffs[i + 4], diffs[i + 8], diffs[i + 12]);
        }
        diffs.iter().fold(0u32, |acc, x| acc + (x.abs() as u32))
    }
    /// SATD between two 4x4 blocks: 2-D Hadamard of the differences,
    /// then the sum of absolute transform coefficients.
    fn satd(src0: &[u8], src1: &[u8], stride: usize) -> u32 {
        let mut diffs = [0; 16];
        for (dst, (src0, src1)) in diffs.chunks_exact_mut(4).zip(
                src0.chunks(stride).zip(src1.chunks(stride))) {
            hadamard!(i16::from(src0[0]) - i16::from(src1[0]),
                      i16::from(src0[1]) - i16::from(src1[1]),
                      i16::from(src0[2]) - i16::from(src1[2]),
                      i16::from(src0[3]) - i16::from(src1[3]),
                      dst[0], dst[1], dst[2], dst[3]);
        }
        for i in 0..4 {
            hadamard!(diffs[i], diffs[i + 4], diffs[i + 8], diffs[i + 12],
                      diffs[i], diffs[i + 4], diffs[i + 8], diffs[i + 12]);
        }
        diffs.iter().fold(0u32, |acc, x| acc + (x.abs() as u32))
    }
    /// 2x downscale of all three planes with a 2x2 box filter (rounded average).
    fn downscale(dst: &mut NAVideoBuffer<u8>, src: &NAVideoBuffer<u8>) {
        let dst = NASimpleVideoFrame::from_video_buf(dst).unwrap();
        let sdata = src.get_data();
        for plane in 0..3 {
            let cur_w = dst.width[plane];
            let cur_h = dst.height[plane];
            let doff = dst.offset[plane];
            let soff = src.get_offset(plane);
            let dstride = dst.stride[plane];
            let sstride = src.get_stride(plane);
            // consume two source rows per destination row
            for (dline, sstrip) in dst.data[doff..].chunks_exact_mut(dstride).zip(
                    sdata[soff..].chunks_exact(sstride * 2)).take(cur_h) {
                let (line0, line1) = sstrip.split_at(sstride);
                for (dst, (src0, src1)) in dline.iter_mut().zip(
                        line0.chunks_exact(2).zip(line1.chunks_exact(2))).take(cur_w) {
                    *dst = ((u16::from(src0[0]) + u16::from(src0[1]) +
                             u16::from(src1[0]) + u16::from(src1[1]) + 2) >> 2) as u8;
                }
            }
        }
    }
}