fix clippy warnings
[nihav.git] / nihav-realmedia / src / codecs / rv40enc / motion_est.rs
CommitLineData
4965a5e5
KS
1use nihav_core::frame::*;
2use nihav_codec_support::codecs::{MV, ZERO_MV};
3use std::str::FromStr;
4use super::dsp::{RefMBData, luma_mc, chroma_mc};
5
/// Selectable motion vector search algorithm.
#[derive(Clone,Copy,PartialEq,Default)]
pub enum MVSearchMode {
    /// No real search: always reports the zero MV (see `DummySearcher`).
    Dummy,
    /// Diamond-pattern search.
    Diamond,
    /// Hexagon-pattern search (the default mode).
    #[default]
    Hexagon,
    /// Uneven multi-hexagon search.
    UMH,
}
14
15impl MVSearchMode {
16 pub const fn get_possible_modes() -> &'static [&'static str] {
17 &["diamond", "hexagon", "umh"]
18 }
19 fn create(self) -> Box<dyn MVSearch+Send> {
20 match self {
21 MVSearchMode::Dummy => Box::new(DummySearcher{}),
22 MVSearchMode::Diamond => Box::new(DiaSearch::new()),
23 MVSearchMode::Hexagon => Box::new(HexSearch::new()),
24 MVSearchMode::UMH => Box::new(UnevenHexSearch::new()),
25 }
26 }
27}
28
4965a5e5
KS
29impl std::fmt::Display for MVSearchMode {
30 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
31 match *self {
32 MVSearchMode::Diamond => write!(f, "diamond"),
33 MVSearchMode::Hexagon => write!(f, "hexagon"),
34 MVSearchMode::UMH => write!(f, "umh"),
35 MVSearchMode::Dummy => write!(f, "dummy"),
36 }
37 }
38}
39
40impl FromStr for MVSearchMode {
41 type Err = ();
42 fn from_str(s: &str) -> Result<Self, Self::Err> {
43 match s {
44 "diamond" => Ok(MVSearchMode::Diamond),
45 "hexagon" => Ok(MVSearchMode::Hexagon),
46 "umh" => Ok(MVSearchMode::UMH),
47 "dummy" => Ok(MVSearchMode::Dummy),
48 _ => Err(()),
49 }
50 }
51}
52
/// Sentinel distance: "not yet computed / worst possible score".
/// Uses the associated constant instead of the deprecated `std::u32::MAX`.
const MAX_DIST: u32 = u32::MAX;
/// Distances at or below this threshold terminate a search early.
const DIST_THRESH: u32 = 256;
55
/// Conversion from full-pixel coordinates into the finer MV units
/// (the `MV` impl multiplies both components by 4).
trait FromPixels {
    fn from_pixels(self) -> Self;
}
59
60impl FromPixels for MV {
61 fn from_pixels(self) -> MV {
62 MV { x: self.x * 4, y: self.y * 4 }
63 }
64}
65
/// Diamond search pattern: centre plus eight surrounding offsets.
/// Offsets are in pattern units and get scaled before use
/// (`from_pixels()` or `scale()` by the caller).
const DIA_PATTERN: [MV; 9] = [
    ZERO_MV,
    MV {x: -2, y: 0},
    MV {x: -1, y: 1},
    MV {x: 0, y: 2},
    MV {x: 1, y: 1},
    MV {x: 2, y: 0},
    MV {x: 1, y: -1},
    MV {x: 0, y: -2},
    MV {x: -1, y: -1}
];
77
/// Hexagon search pattern: centre plus six vertices.
const HEX_PATTERN: [MV; 7] = [
    ZERO_MV,
    MV {x: -2, y: 0},
    MV {x: -1, y: 2},
    MV {x: 1, y: 2},
    MV {x: 2, y: 0},
    MV {x: 1, y: -2},
    MV {x: -1, y: -2}
];
87
/// Final one-step refinement offsets (small cross around the best point).
const REFINEMENT: [MV; 4] = [
    MV {x: -1, y: 0},
    MV {x: 0, y: 1},
    MV {x: 1, y: 0},
    MV {x: 0, y: -1}
];
94
// Shared body for the fixed-pattern searchers generated by `pattern_search!`.
// The short form starts a full search from the zero MV; the long form allows
// resuming from a given start MV/distance and skipping the full-pel stage.
// Pattern state lives in `$self` (`point`/`dist`/`steps`); `MAX_DIST` entries
// mark points whose cost has not been computed yet, so they survive pattern
// moves via `update()` without being re-evaluated.
macro_rules! search_template {
    ($self: expr, $mv_est: expr, $cur_blk: expr, $mb_x: expr, $mb_y: expr, $sad_func: ident, $threshold: expr) => ({
        search_template!($self, $mv_est, $cur_blk, $mb_x, $mb_y, $sad_func, $threshold, ZERO_MV, MAX_DIST, true)
    });
    ($self: expr, $mv_est: expr, $cur_blk: expr, $mb_x: expr, $mb_y: expr, $sad_func: ident, $threshold: expr, $start_mv: expr, $best_dist: expr, $fullpel_stage: expr) => ({
        let mut best_dist = $best_dist;
        let mut best_mv = $start_mv;

        let mut min_dist;
        let mut min_idx;

        if $fullpel_stage {
            $self.reset();
            loop {
                // evaluate every not-yet-scored pattern point; the running
                // best lets the metric abort early, and hitting the threshold
                // stops scoring altogether
                let mut cur_best_dist = best_dist;
                for (dist, &point) in $self.dist.iter_mut().zip($self.point.iter()) {
                    if *dist == MAX_DIST {
                        *dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, point.from_pixels(), cur_best_dist);
                        cur_best_dist = cur_best_dist.min(*dist);
                        if *dist <= $threshold {
                            break;
                        }
                    }
                }
                // locate the best-scoring pattern point
                min_dist = $self.dist[0];
                min_idx = 0;
                for (i, &dist) in $self.dist.iter().enumerate().skip(1) {
                    if dist < min_dist {
                        min_dist = dist;
                        min_idx = i;
                        if dist <= $threshold {
                            break;
                        }
                    }
                }
                // stop when good enough, centred, no longer improving,
                // or about to leave the allowed MV range
                if min_dist <= $threshold || min_idx == 0 || best_dist == min_dist || $self.point[min_idx].x.abs() >= $mv_est.mv_range || $self.point[min_idx].y.abs() >= $mv_est.mv_range {
                    break;
                }
                best_dist = min_dist;
                $self.update($self.steps[min_idx]);
            }
            best_dist = min_dist;
            best_mv = $self.point[min_idx];
            if best_dist <= $threshold {
                return (best_mv.from_pixels(), best_dist);
            }
            // one-step cross refinement, still in full-pel units
            for &step in REFINEMENT.iter() {
                let mv = best_mv + step;
                let dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, mv.from_pixels(), MAX_DIST);
                if best_dist > dist {
                    best_dist = dist;
                    best_mv = mv;
                }
            }
            best_mv = best_mv.from_pixels();
            if best_dist <= $threshold {
                return (best_mv, best_dist);
            }
        }

        // subpel refinement
        // (same loop as above but points are already in MV units, hence no
        // from_pixels() and the range limit scaled by 8)
        $self.set_new_point(best_mv, best_dist);
        loop {
            let mut cur_best_dist = best_dist;
            for (dist, &point) in $self.dist.iter_mut().zip($self.point.iter()) {
                if *dist == MAX_DIST {
                    *dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, point, cur_best_dist);
                    cur_best_dist = cur_best_dist.min(*dist);
                    if *dist <= $threshold {
                        break;
                    }
                }
            }
            min_dist = $self.dist[0];
            min_idx = 0;
            for (i, &dist) in $self.dist.iter().enumerate().skip(1) {
                if dist < min_dist {
                    min_dist = dist;
                    min_idx = i;
                    if dist <= $threshold {
                        break;
                    }
                }
            }
            if min_dist <= $threshold || min_idx == 0 || best_dist == min_dist || $self.point[min_idx].x.abs() >= $mv_est.mv_range * 8 || $self.point[min_idx].y.abs() >= $mv_est.mv_range * 8 {
                break;
            }
            best_dist = min_dist;
            $self.update($self.steps[min_idx]);
        }
        best_dist = min_dist;
        best_mv = $self.point[min_idx];
        if best_dist <= $threshold {
            return (best_mv, best_dist);
        }
        // final one-step cross refinement in MV units
        for &step in REFINEMENT.iter() {
            let mv = best_mv + step;
            let dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, mv, MAX_DIST);
            if best_dist > dist {
                best_dist = dist;
                best_mv = mv;
            }
        }
        (best_mv, best_dist)
    });
}
201
// Generates a pattern-search struct (e.g. `DiaSearch`, `HexSearch`) holding
// the current pattern points, their cached scores and the static step table,
// plus an `MVSearch` impl built on `search_template!`.
macro_rules! pattern_search {
    ($struct_name: ident, $patterns: expr) => {
        pub struct $struct_name {
            // current candidate positions (pattern shifted to the search centre)
            point: [MV; $patterns.len()],
            // cached score per point; MAX_DIST = not evaluated yet
            dist: [u32; $patterns.len()],
            // the immutable pattern offsets used for re-centering
            steps: &'static [MV; $patterns.len()],
        }

        impl $struct_name {
            pub fn new() -> Self {
                Self {
                    point: $patterns,
                    dist: [MAX_DIST; $patterns.len()],
                    steps: &$patterns,
                }
            }
            /// Restores the initial pattern around the zero MV with no cached scores.
            fn reset(&mut self) {
                self.point = $patterns;
                self.dist = [MAX_DIST; $patterns.len()];
            }
            /// Re-centres the pattern on `start`, keeping only the centre score.
            fn set_new_point(&mut self, start: MV, dist: u32) {
                for (dst, &src) in self.point.iter_mut().zip(self.steps.iter()) {
                    *dst = src + start;
                }
                self.dist = [MAX_DIST; $patterns.len()];
                self.dist[0] = dist;
            }
            /// Shifts the whole pattern by `step`, carrying over the scores of
            /// points that overlap between the old and new positions.
            fn update(&mut self, step: MV) {
                let mut new_point = self.point;
                let mut new_dist = [MAX_DIST; $patterns.len()];

                for point in new_point.iter_mut() {
                    *point += step;
                }

                for (new_point, new_dist) in new_point.iter_mut().zip(new_dist.iter_mut()) {
                    for (&old_point, &old_dist) in self.point.iter().zip(self.dist.iter()) {
                        if *new_point == old_point {
                            *new_dist = old_dist;
                            break;
                        }
                    }
                }
                self.point = new_point;
                self.dist = new_dist;
            }
        }

        impl MVSearch for $struct_name {
            fn search_mb(&mut self, mv_est: &mut MVEstimator, cur_mb: &RefMBData, mb_x: usize, mb_y: usize, _cand_mvs: &[MV]) -> (MV, u32) {
                search_template!(self, mv_est, cur_mb, mb_x, mb_y, sad_mb, DIST_THRESH)
            }
            fn search_blk8(&mut self, mv_est: &mut MVEstimator, ref_blk: &RefMBData, xpos: usize, ypos: usize, _cand_mvs: &[MV]) -> (MV, u32) {
                // quarter of the MB threshold since an 8x8 block is a quarter of the area
                search_template!(self, mv_est, ref_blk, xpos, ypos, sad_blk8, DIST_THRESH / 4)
            }
        }
    }
}
260
// Instantiate the two fixed-pattern searchers.
pattern_search!(DiaSearch, DIA_PATTERN);
pattern_search!(HexSearch, HEX_PATTERN);
263
264const LARGE_HEX_PATTERN: [MV; 16] = [
265 MV { x: -4, y: 0 },
266 MV { x: -4, y: 1 },
267 MV { x: -4, y: 2 },
268 MV { x: -2, y: 3 },
269 MV { x: 0, y: 4 },
270 MV { x: 2, y: 3 },
271 MV { x: 4, y: 2 },
272 MV { x: 4, y: 1 },
273 MV { x: 4, y: 0 },
274 MV { x: 4, y: -1 },
275 MV { x: 4, y: -2 },
276 MV { x: -2, y: -3 },
277 MV { x: 0, y: -4 },
278 MV { x: -2, y: -3 },
279 MV { x: -4, y: -2 },
280 MV { x: -4, y: -1 }
281];
282
/// Cross pattern with a wider horizontal reach than vertical
/// (horizontal motion is typically larger), used by the UMH cross stage.
const UNSYMM_CROSS: [MV; 4] = [
    MV { x: -2, y: 0 },
    MV { x: 0, y: 1 },
    MV { x: 2, y: 0 },
    MV { x: 0, y: -1 }
];
289
/// Fixed-capacity (16 entries) set that keeps insertion order and
/// silently ignores duplicates and overflow.
#[derive(Default)]
struct UniqueSet<T:Copy+Default> {
    list: [T; 16],
    count: usize,
}

impl<T:Copy+Default+PartialEq> UniqueSet<T> {
    /// Creates an empty set.
    fn new() -> Self { Self::default() }
    /// Drops all stored values.
    fn clear(&mut self) { self.count = 0; }
    /// Returns the stored values in insertion order.
    fn get_list(&self) -> &[T] { &self.list[..self.count] }
    /// Appends `val` unless the set is full or already contains it.
    fn add(&mut self, val: T) {
        if self.count >= self.list.len() {
            return;
        }
        if self.get_list().contains(&val) {
            return;
        }
        self.list[self.count] = val;
        self.count += 1;
    }
}
307
/// Small helper operations on motion vectors used by the UMH search.
trait MVOps {
    /// Multiplies both components by `scale`.
    fn scale(self, scale: i16) -> Self;
    /// True when both components lie within `[-range, range]`.
    fn is_in_range(self, range: i16) -> bool;
}
312
313impl MVOps for MV {
314 fn scale(self, scale: i16) -> MV {
315 MV { x: self.x * scale, y: self.y * scale }
316 }
317 fn is_in_range(self, range: i16) -> bool {
318 self.x.abs() <= range && self.y.abs() <= range
319 }
320}
321
// One pass over `$pattern` (scaled by `$scale`) around `$start`: scores every
// in-range candidate, keeps the best, stops early below `$dist_thr`.
// Evaluates to `(best_mv, best_dist, improved)`.
macro_rules! single_search_step {
    ($start:expr, $best_dist:expr, $mv_est:expr, $sad_func:ident, $ref_blk:expr, $xpos:expr, $ypos:expr, $pattern:expr, $scale:expr, $dist_thr:expr) => {{
        let mut best_mv = $start;
        let mut best_dist = $best_dist;
        for point in $pattern.iter() {
            let mv = point.scale($scale) + $start;
            // mv_range is in full pixels, candidates are in quarter-pel units
            if !mv.is_in_range($mv_est.mv_range * 4) {
                continue;
            }
            let dist = $mv_est.$sad_func($ref_blk, $xpos, $ypos, mv, best_dist);
            if dist < best_dist {
                best_mv = mv;
                best_dist = dist;
                if best_dist < $dist_thr {
                    break;
                }
            }
        }
        (best_mv, best_dist, best_mv != $start)
    }}
}
343
/// Uneven multi-hexagon searcher; keeps a deduplicated candidate MV list
/// used to derive the search start point.
struct UnevenHexSearch {
    mv_list: UniqueSet<MV>,
}
347
impl UnevenHexSearch {
    fn new() -> Self {
        Self {
            mv_list: UniqueSet::new(),
        }
    }
    /// Derives a single starting MV from the candidate predictors:
    /// a single unique candidate is used as-is, exactly three go through
    /// `MV::pred` (presumably the median predictor — defined in
    /// codec_support), anything else is averaged component-wise.
    /// NOTE(review): an empty `cand_mvs` would divide by zero in the
    /// fallback arm — callers appear to always supply at least one
    /// candidate; confirm at call sites.
    fn get_cand_mv(&mut self, cand_mvs: &[MV]) -> MV {
        self.mv_list.clear();
        for &mv in cand_mvs.iter() {
            self.mv_list.add(mv);
        }
        match self.mv_list.count {
            1 => self.mv_list.list[0],
            3 => MV::pred(self.mv_list.list[0], self.mv_list.list[1], self.mv_list.list[2]),
            _ => {
                let sum = self.mv_list.get_list().iter().fold((0i32, 0i32),
                        |acc, mv| (acc.0 + i32::from(mv.x), acc.1 + i32::from(mv.y)));
                MV {x: (sum.0 / (self.mv_list.count as i32)) as i16,
                    y: (sum.1 / (self.mv_list.count as i32)) as i16}
            },
        }
    }
}
371
// Full UMH search pipeline starting from a predicted candidate MV:
// initial score, small diamond refinement, repeated unsymmetrical cross,
// shrinking multi-hexagon grid, then hexagon and (if still above the
// cutoff) diamond polishing. Returns early whenever the score drops
// below `$cutoff`.
macro_rules! umh_search_template {
    ($cand_mv:expr, $cutoff:expr, $mv_est:expr, $sad_func:ident, $ref_blk:expr, $xpos:expr, $ypos:expr) => {{
        let cand_mv = $cand_mv;
        let best_dist = $mv_est.$sad_func($ref_blk, $xpos, $ypos, cand_mv, MAX_DIST);
        if best_dist < $cutoff {
            return (cand_mv, best_dist);
        }

        // step 1 - small refinement search
        let (mut cand_mv, mut best_dist, _) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, DIA_PATTERN, 1, $cutoff);
        if best_dist < $cutoff {
            return (cand_mv, best_dist);
        }

        // step 2 - unsymmetrical cross search
        loop {
            let (mv, dist, changed) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, UNSYMM_CROSS, 4, $cutoff);
            if !changed {
                break;
            }
            cand_mv = mv;
            best_dist = dist;
            if best_dist < $cutoff {
                return (mv, dist);
            }
        }

        // step 3 - multi-hexagon grid search (hexagon shrinks 4 -> 2 -> 1)
        let mut scale = 4;
        while scale > 0 {
            let (mv, dist, changed) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, LARGE_HEX_PATTERN, scale, $cutoff);
            if !changed {
                break;
            }
            cand_mv = mv;
            best_dist = dist;
            if best_dist < $cutoff {
                return (mv, dist);
            }
            scale >>= 1;
        }
        // step 4 - final hexagon search, with a diamond pass if still not good enough
        let (cand_mv, best_dist, _) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, HEX_PATTERN, 1, $cutoff);
        if best_dist > $cutoff {
            let (mv, dist, _) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, DIA_PATTERN, 1, $cutoff);
            (mv, dist)
        } else {
            (cand_mv, best_dist)
        }
    }}
}
423
impl MVSearch for UnevenHexSearch {
    /// 16x16 macroblock search seeded from the candidate predictors.
    fn search_mb(&mut self, mv_est: &mut MVEstimator, cur_mb: &RefMBData, mb_x: usize, mb_y: usize, cand_mvs: &[MV]) -> (MV, u32) {
        let cand_mv = self.get_cand_mv(cand_mvs);
        let cutoff = mv_est.cutoff_thr;
        umh_search_template!(cand_mv, cutoff, mv_est, sad_mb, cur_mb, mb_x, mb_y)
    }
    /// 8x8 sub-block search; cutoff scaled to a quarter of the MB threshold.
    fn search_blk8(&mut self, mv_est: &mut MVEstimator, ref_blk: &RefMBData, xpos: usize, ypos: usize, cand_mvs: &[MV]) -> (MV, u32) {
        let cand_mv = self.get_cand_mv(cand_mvs);
        let cutoff = mv_est.cutoff_thr / 4;
        umh_search_template!(cand_mv, cutoff, mv_est, sad_blk8, ref_blk, xpos, ypos)
    }
}
436
/// Reference picture plus search limits shared by all searcher objects.
struct MVEstimator<'a> {
    /// Reference picture candidates are matched against.
    pic: &'a NAVideoBuffer<u8>,
    /// Maximum MV component magnitude, in full pixels.
    mv_range: i16,
    /// Score below which a search is considered good enough to stop.
    cutoff_thr: u32,
}
442
// Accumulates the per-pixel difference metric between two pixel slices.
// NOTE(review): despite the name this computes *squared* differences
// (SSD), not absolute ones — this is the metric used consistently
// throughout this module.
macro_rules! sad {
    ($src1:expr, $src2:expr) => {
        $src1.iter().zip($src2.iter()).fold(0u32, |acc, (&a, &b)|
            acc + (((i32::from(a) - i32::from(b)) * (i32::from(a) - i32::from(b))) as u32))
    }
}
449
impl<'a> MVEstimator<'a> {
    /// Scores MV `mv` for a whole 16x16 macroblock: motion-compensates
    /// luma and both chroma planes from `pic` and sums the `sad!` metric
    /// against `ref_mb`. Aborts (returning the partial sum) as soon as the
    /// running total exceeds `cur_best_dist`; chroma MC is skipped entirely
    /// when luma alone already exceeds it.
    fn sad_mb(&self, ref_mb: &RefMBData, mb_x: usize, mb_y: usize, mv: MV, cur_best_dist: u32) -> u32 {
        let mut dst = RefMBData::new();
        luma_mc(&mut dst.y, 16, self.pic, mb_x * 16, mb_y * 16, mv, true);

        let mut dist = 0;
        for (dline, sline) in dst.y.chunks(16).zip(ref_mb.y.chunks(16)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }
        chroma_mc(&mut dst.u, 8, self.pic, mb_x * 8, mb_y * 8, 1, mv, true);
        dist += sad!(dst.u, ref_mb.u);
        if dist > cur_best_dist {
            return dist;
        }
        chroma_mc(&mut dst.v, 8, self.pic, mb_x * 8, mb_y * 8, 2, mv, true);
        dist += sad!(dst.v, ref_mb.v);

        dist
    }
    /// Same as `sad_mb` but for one 8x8 sub-block; `xpos`/`ypos` are
    /// absolute pixel coordinates, and the low bits select which quadrant
    /// of the 16x16 `ref_mb` (stride 16 luma, stride 8 chroma) to compare.
    fn sad_blk8(&self, ref_mb: &RefMBData, xpos: usize, ypos: usize, mv: MV, cur_best_dist: u32) -> u32 {
        let mut cur_y = [0; 64];
        let mut cur_u = [0; 16];
        let mut cur_v = [0; 16];

        let mut dist = 0;

        // offset of this 8x8 block inside the 16x16 luma plane
        let y_off = (xpos & 8) + (ypos & 8) * 16;
        luma_mc(&mut cur_y, 8, self.pic, xpos, ypos, mv, false);
        for (dline, sline) in cur_y.chunks(8).zip(ref_mb.y[y_off..].chunks(16)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }

        // corresponding 4x4 offset inside the 8x8 chroma planes
        let c_off = (xpos & 8) / 2 + (ypos & 8) * 4;
        chroma_mc(&mut cur_u, 4, self.pic, xpos / 2, ypos / 2, 1, mv, false);
        for (dline, sline) in cur_u.chunks(4).zip(ref_mb.u[c_off..].chunks(8)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }
        chroma_mc(&mut cur_v, 4, self.pic, xpos / 2, ypos / 2, 2, mv, false);
        for (dline, sline) in cur_v.chunks(4).zip(ref_mb.v[c_off..].chunks(8)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }

        dist
    }
}
507
/// Common interface of all motion searchers: both methods return the best
/// motion vector found and its distortion score.
trait MVSearch {
    /// Searches for a whole 16x16 macroblock at MB coordinates.
    fn search_mb(&mut self, mv_est: &mut MVEstimator, ref_mb: &RefMBData, mb_x: usize, mb_y: usize, cand_mvs: &[MV]) -> (MV, u32);
    /// Searches for an 8x8 sub-block at absolute pixel coordinates.
    fn search_blk8(&mut self, mv_est: &mut MVEstimator, ref_blk: &RefMBData, xpos: usize, ypos: usize, cand_mvs: &[MV]) -> (MV, u32);
}
512
513struct DummySearcher {}
514
515impl MVSearch for DummySearcher {
516 fn search_mb(&mut self, _mv_est: &mut MVEstimator, _ref_mb: &RefMBData, _mb_x: usize, _mb_y: usize, _cand_mvs: &[MV]) -> (MV, u32) {
517 (ZERO_MV, std::u32::MAX / 2)
518 }
519 fn search_blk8(&mut self, _mv_est: &mut MVEstimator, _ref_mb: &RefMBData, _xpos: usize, _ypos: usize, _cand_mvs: &[MV]) -> (MV, u32) {
520 (ZERO_MV, std::u32::MAX / 2)
521 }
522}
523
/// Public motion estimation front-end owning the selected searcher.
pub struct MotionEstimator {
    /// Maximum MV component magnitude, in full pixels.
    pub range: i16,
    /// Early-termination distortion threshold.
    pub thresh: u32,
    mode: MVSearchMode,
    srch: Box<dyn MVSearch+Send>,
}
530
531impl MotionEstimator {
532 pub fn new() -> Self {
533 let mode = MVSearchMode::default();
534 Self {
535 range: 64,
536 thresh: 32,
537 mode,
538 srch: mode.create(),
539 }
540 }
541 pub fn get_mode(&self) -> MVSearchMode { self.mode }
542 pub fn set_mode(&mut self, new_mode: MVSearchMode) {
543 if self.mode != new_mode {
544 self.mode = new_mode;
545 self.srch = self.mode.create();
546 }
547 }
548 pub fn search_mb_p(&mut self, pic: &NAVideoBuffer<u8>, refmb: &RefMBData, mb_x: usize, mb_y: usize, cand_mvs: &[MV]) -> (MV, u32) {
549 let mut mv_est = MVEstimator {
550 mv_range: self.range,
551 cutoff_thr: self.thresh,
552 pic,
553 };
554 self.srch.search_mb(&mut mv_est, refmb, mb_x, mb_y, cand_mvs)
555 }
556 pub fn search_blk8(&mut self, pic: &NAVideoBuffer<u8>, refmb: &RefMBData, xoff: usize, yoff: usize, cand_mvs: &[MV]) -> (MV, u32) {
557 let mut mv_est = MVEstimator {
558 mv_range: self.range,
559 cutoff_thr: self.thresh,
560 pic,
561 };
562 self.srch.search_blk8(&mut mv_est, refmb, xoff, yoff, cand_mvs)
563 }
564}
565
/// Joint bidirectional (B-frame) motion search state for one macroblock.
pub struct SearchB<'a> {
    /// Forward (past) reference picture.
    ref_p: &'a NAVideoBuffer<u8>,
    /// Backward (future) reference picture.
    ref_n: &'a NAVideoBuffer<u8>,
    // macroblock top-left corner in pixels
    xpos: usize,
    ypos: usize,
    /// Weights used when averaging the two predictions.
    ratios: [u32; 2],
    // scratch buffers for the two motion-compensated predictions
    tmp1: RefMBData,
    tmp2: RefMBData,
    /// Weighted average of `tmp1`/`tmp2` compared against the source.
    pred_blk: RefMBData,
}
576
impl<'a> SearchB<'a> {
    /// Sets up a search for the macroblock at (`mb_x`, `mb_y`) between the
    /// two reference pictures with the given averaging weights.
    pub fn new(ref_p: &'a NAVideoBuffer<u8>, ref_n: &'a NAVideoBuffer<u8>, mb_x: usize, mb_y: usize, ratios: [u32; 2]) -> Self {
        Self {
            ref_p, ref_n,
            xpos: mb_x * 16,
            ypos: mb_y * 16,
            ratios,
            tmp1: RefMBData::new(),
            tmp2: RefMBData::new(),
            pred_blk: RefMBData::new(),
        }
    }
    /// Jointly refines the forward/backward MV pair: repeatedly tries all
    /// diamond-offset combinations (in pixel steps) for both vectors until
    /// no improvement, then does a final one-step refinement in MV units.
    pub fn search_mb(&mut self, ref_mb: &RefMBData, cand_mvs: [MV; 2]) -> (MV, MV) {
        let mut best_cand = cand_mvs;
        let mut best_dist = self.interp_b_dist(ref_mb, best_cand, MAX_DIST);

        loop {
            let mut improved = false;
            for &fmv_add in DIA_PATTERN.iter() {
                for &bmv_add in DIA_PATTERN.iter() {
                    let cand = [best_cand[0] + fmv_add.from_pixels(),
                                best_cand[1] + bmv_add.from_pixels()];
                    let dist = self.interp_b_dist(ref_mb, cand, best_dist);
                    if dist < best_dist {
                        best_dist = dist;
                        best_cand = cand;
                        improved = true;
                    }
                }
            }
            if !improved {
                break;
            }
        }

        // sub-pel polish: one-step cross offsets for both vectors
        for &fmv_add in REFINEMENT.iter() {
            for &bmv_add in REFINEMENT.iter() {
                let cand = [best_cand[0] + fmv_add, best_cand[1] + bmv_add];
                let dist = self.interp_b_dist(ref_mb, cand, best_dist);
                if dist < best_dist {
                    best_dist = dist;
                    best_cand = cand;
                }
            }
        }

        (best_cand[0], best_cand[1])
    }
    /// Builds the weighted bi-prediction for the MV pair and returns its
    /// `sad!` score against `ref_mb`, aborting once the running total
    /// exceeds `cur_best_dist`.
    fn interp_b_dist(&mut self, ref_mb: &RefMBData, cand_mv: [MV; 2], cur_best_dist: u32) -> u32 {
        let [fmv, bmv] = cand_mv;
        luma_mc(&mut self.tmp1.y, 16, self.ref_p, self.xpos, self.ypos, fmv, true);
        chroma_mc(&mut self.tmp1.u, 8, self.ref_p, self.xpos / 2, self.ypos / 2, 1, fmv, true);
        chroma_mc(&mut self.tmp1.v, 8, self.ref_p, self.xpos / 2, self.ypos / 2, 2, fmv, true);
        luma_mc(&mut self.tmp2.y, 16, self.ref_n, self.xpos, self.ypos, bmv, true);
        chroma_mc(&mut self.tmp2.u, 8, self.ref_n, self.xpos / 2, self.ypos / 2, 1, bmv, true);
        chroma_mc(&mut self.tmp2.v, 8, self.ref_n, self.xpos / 2, self.ypos / 2, 2, bmv, true);
        self.pred_blk.avg(&self.tmp1, self.ratios[0], &self.tmp2, self.ratios[1]);

        let mut dist = 0;
        for (dline, sline) in self.pred_blk.y.chunks(16).zip(ref_mb.y.chunks(16)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }
        dist += sad!(self.pred_blk.u, ref_mb.u);
        if dist > cur_best_dist {
            return dist;
        }
        dist += sad!(self.pred_blk.v, ref_mb.v);

        dist
    }
}
651
// 4-point Hadamard butterfly: writes the transform of ($s0..$s3)
// into ($d0..$d3). Sources and destinations may alias (used in-place).
macro_rules! hadamard {
    ($s0:expr, $s1:expr, $s2:expr, $s3:expr, $d0:expr, $d1:expr, $d2:expr, $d3:expr) => {
        let t0 = $s0 + $s1;
        let t1 = $s0 - $s1;
        let t2 = $s2 + $s3;
        let t3 = $s2 - $s3;
        $d0 = t0 + t2;
        $d2 = t0 - t2;
        $d1 = t1 + t3;
        $d3 = t1 - t3;
    }
}
664
/// Frame-type decision helper working on half-resolution copies of the
/// previous reference, current and next frames.
pub struct FrameComplexityEstimate {
    ref_frm: NAVideoBufferRef<u8>,
    cur_frm: NAVideoBufferRef<u8>,
    nxt_frm: NAVideoBufferRef<u8>,
    // full-resolution source dimensions currently configured via resize()
    width: usize,
    height: usize,
}
672
673impl FrameComplexityEstimate {
674 pub fn new() -> Self {
675 let vinfo = NAVideoInfo::new(24, 24, false, YUV420_FORMAT);
676 let vt = alloc_video_buffer(vinfo, 4).unwrap();
677 let buf = vt.get_vbuf().unwrap();
678 Self {
679 ref_frm: buf.clone(),
680 cur_frm: buf.clone(),
681 nxt_frm: buf,
682 width: 0,
683 height: 0,
684 }
685 }
686 pub fn resize(&mut self, width: usize, height: usize) {
687 if width != self.width || height != self.height {
688 self.width = width;
689 self.height = height;
690
691 let vinfo = NAVideoInfo::new(self.width / 2, self.height / 2, false, YUV420_FORMAT);
692 let vt = alloc_video_buffer(vinfo, 4).unwrap();
693 self.ref_frm = vt.get_vbuf().unwrap();
694 let frm = self.ref_frm.get_data_mut().unwrap();
695 for el in frm.iter_mut() {
696 *el = 0x80;
697 }
698 let vt = alloc_video_buffer(vinfo, 4).unwrap();
699 self.cur_frm = vt.get_vbuf().unwrap();
700 let vt = alloc_video_buffer(vinfo, 4).unwrap();
701 self.nxt_frm = vt.get_vbuf().unwrap();
702 }
703 }
704 pub fn set_current(&mut self, frm: &NAVideoBuffer<u8>) {
705 Self::downscale(&mut self.cur_frm, frm);
706 }
707 pub fn get_complexity(&self, ftype: FrameType) -> u32 {
708 match ftype {
709 FrameType::I => Self::calculate_i_cplx(&self.cur_frm),
710 FrameType::P => Self::calculate_mv_diff(&self.ref_frm, &self.cur_frm),
711 _ => 0,
712 }
713 }
714 pub fn decide_b_frame(&mut self, frm1: &NAVideoBuffer<u8>, frm2: &NAVideoBuffer<u8>) -> bool {
715 Self::downscale(&mut self.cur_frm, frm1);
716 Self::downscale(&mut self.nxt_frm, frm2);
717 let diff_ref_cur = Self::calculate_mv_diff(&self.ref_frm, &self.cur_frm);
718 let diff_cur_nxt = Self::calculate_mv_diff(&self.cur_frm, &self.nxt_frm);
719
720 // simple rule - if complexity ref->cur and cur->next is about the same this should be a B-frame
721 let ddiff = diff_ref_cur.max(diff_cur_nxt) - diff_ref_cur.min(diff_cur_nxt);
722 if ddiff < 256 {
723 true
724 } else {
725 let mut order = 0;
726 while (ddiff << order) < diff_ref_cur.min(diff_cur_nxt) {
727 order += 1;
728 }
729 order > 2
730 }
731 }
732 pub fn update_ref(&mut self) {
733 std::mem::swap(&mut self.ref_frm, &mut self.cur_frm);
734 }
735
736 fn add_mv(mb_x: usize, mb_y: usize, mv: MV) -> (usize, usize) {
737 (((mb_x * 16) as isize + (mv.x as isize)) as usize,
738 ((mb_y * 16) as isize + (mv.y as isize)) as usize)
739 }
740 fn calculate_i_cplx(frm: &NAVideoBuffer<u8>) -> u32 {
741 let (w, h) = frm.get_dimensions(0);
742 let src = frm.get_data();
743 let stride = frm.get_stride(0);
744 let mut sum = 0;
745 let mut offset = 0;
746 for y in (0..h).step_by(4) {
747 for x in (0..w).step_by(4) {
748 sum += Self::satd_i(src, offset + x, stride, x > 0, y > 0);
749 }
750 offset += stride * 4;
751 }
752 sum
753 }
754 fn calculate_mv_diff(ref_frm: &NAVideoBuffer<u8>, cur_frm: &NAVideoBuffer<u8>) -> u32 {
755 let (w, h) = ref_frm.get_dimensions(0);
756 let mut sum = 0;
757 for mb_y in 0..(h / 16) {
758 for mb_x in 0..(w / 16) {
759 sum += Self::satd_mb_diff(ref_frm, cur_frm, mb_x, mb_y);
760 }
761 }
762 sum
763 }
764 fn satd_mb_diff(ref_frm: &NAVideoBuffer<u8>, cur_frm: &NAVideoBuffer<u8>, mb_x: usize, mb_y: usize) -> u32 {
765 let mv = Self::search_mv(ref_frm, cur_frm, mb_x, mb_y);
766 let mut sum = 0;
767 let src0 = ref_frm.get_data();
768 let src1 = cur_frm.get_data();
769 let stride = ref_frm.get_stride(0);
770 let (src_x, src_y) = Self::add_mv(mb_x, mb_y, mv);
771 for y in (0..16).step_by(4) {
772 for x in (0..16).step_by(4) {
773 sum += Self::satd(&src0[src_x + x + (src_y + y) * stride..],
774 &src1[mb_x * 16 + x + (mb_y * 16 + y) * stride..],
775 stride);
776 }
777 }
778 sum
779 }
780 fn search_mv(ref_frm: &NAVideoBuffer<u8>, cur_frm: &NAVideoBuffer<u8>, mb_x: usize, mb_y: usize) -> MV {
781 let stride = ref_frm.get_stride(0);
782 let (w, h) = ref_frm.get_dimensions(0);
783 let (v_edge, h_edge) = (w - 16, h - 16);
784 let ref_src = ref_frm.get_data();
785 let cur_src = cur_frm.get_data();
786 let cur_src = &cur_src[mb_x * 16 + mb_y * 16 * stride..];
787
788 let mut best_mv = ZERO_MV;
789 let mut best_dist = Self::sad(cur_src, ref_src, mb_x, mb_y, stride, best_mv);
790 if best_dist == 0 {
791 return best_mv;
792 }
793
794 for step in (0..=2).rev() {
795 let mut changed = true;
796 while changed {
797 changed = false;
798 for &mv in DIA_PATTERN[1..].iter() {
799 let cand_mv = best_mv + mv.scale(1 << step);
800 let (cx, cy) = Self::add_mv(mb_x, mb_y, cand_mv);
801 if cx > v_edge || cy > h_edge {
802 continue;
803 }
804 let cand_dist = Self::sad(cur_src, ref_src, mb_x, mb_y, stride, cand_mv);
805 if cand_dist < best_dist {
806 best_dist = cand_dist;
807 best_mv = cand_mv;
808 if best_dist == 0 {
809 return best_mv;
810 }
811 changed = true;
812 }
813 }
814 }
815 }
816 best_mv
817 }
818 fn sad(cur_src: &[u8], src: &[u8], mb_x: usize, mb_y: usize, stride: usize, mv: MV) -> u32 {
819 let (src_x, src_y) = Self::add_mv(mb_x, mb_y, mv);
820 let mut sum = 0;
821 for (line1, line2) in cur_src.chunks(stride).zip(src[src_x + src_y * stride..].chunks(stride)).take(16) {
822 sum += line1[..16].iter().zip(line2[..16].iter()).fold(0u32,
823 |acc, (&a, &b)| acc + u32::from(a.max(b) - a.min(b)) * u32::from(a.max(b) - a.min(b)));
824 }
825 sum
826 }
827 fn satd_i(src: &[u8], mut offset: usize, stride: usize, has_left: bool, has_top: bool) -> u32 {
828 let mut diffs = [0; 16];
829 match (has_left, has_top) {
830 (true, true) => {
831 for row in diffs.chunks_exact_mut(4) {
832 let mut left = i16::from(src[offset - 1]);
833 let mut tl = i16::from(src[offset - stride - 1]);
834 for (x, dst) in row.iter_mut().enumerate() {
835 let cur = i16::from(src[offset + x]);
836 let top = i16::from(src[offset + x - stride]);
837
838 *dst = cur - (top + left + tl - top.min(left).min(tl) - top.max(left).max(tl));
839
840 left = cur;
841 tl = top;
842 }
843
844 offset += stride;
845 }
846 },
847 (true, false) => {
848 for (dst, (left, cur)) in diffs.chunks_exact_mut(4).zip(
849 src[offset - 1..].chunks(stride).zip(src[offset..].chunks(stride))) {
850 for (dst, (&left, &cur)) in dst.iter_mut().zip(left.iter().zip(cur.iter())) {
851 *dst = i16::from(cur) - i16::from(left);
852 }
853 }
854 },
855 (false, true) => {
856 for (dst, (top, cur)) in diffs.chunks_exact_mut(4).zip(
857 src[offset - stride..].chunks(stride).zip(src[offset..].chunks(stride))) {
858 for (dst, (&top, &cur)) in dst.iter_mut().zip(top.iter().zip(cur.iter())) {
859 *dst = i16::from(cur) - i16::from(top);
860 }
861 }
862 },
863 (false, false) => {
864 for (dst, src) in diffs.chunks_exact_mut(4).zip(src[offset..].chunks(stride)) {
865 for (dst, &src) in dst.iter_mut().zip(src.iter()) {
866 *dst = i16::from(src) - 128;
867 }
868 }
869 },
870 };
871 for row in diffs.chunks_exact_mut(4) {
872 hadamard!(row[0], row[1], row[2], row[3], row[0], row[1], row[2], row[3]);
873 }
874 for i in 0..4 {
875 hadamard!(diffs[i], diffs[i + 4], diffs[i + 8], diffs[i + 12],
876 diffs[i], diffs[i + 4], diffs[i + 8], diffs[i + 12]);
877 }
e6aaad5c 878 diffs.iter().fold(0u32, |acc, x| acc + (x.unsigned_abs() as u32))
4965a5e5
KS
879 }
880 fn satd(src0: &[u8], src1: &[u8], stride: usize) -> u32 {
881 let mut diffs = [0; 16];
882 for (dst, (src0, src1)) in diffs.chunks_exact_mut(4).zip(
883 src0.chunks(stride).zip(src1.chunks(stride))) {
884 hadamard!(i16::from(src0[0]) - i16::from(src1[0]),
885 i16::from(src0[1]) - i16::from(src1[1]),
886 i16::from(src0[2]) - i16::from(src1[2]),
887 i16::from(src0[3]) - i16::from(src1[3]),
888 dst[0], dst[1], dst[2], dst[3]);
889 }
890 for i in 0..4 {
891 hadamard!(diffs[i], diffs[i + 4], diffs[i + 8], diffs[i + 12],
892 diffs[i], diffs[i + 4], diffs[i + 8], diffs[i + 12]);
893 }
e6aaad5c 894 diffs.iter().fold(0u32, |acc, x| acc + (x.unsigned_abs() as u32))
4965a5e5
KS
895 }
896 fn downscale(dst: &mut NAVideoBuffer<u8>, src: &NAVideoBuffer<u8>) {
897 let dst = NASimpleVideoFrame::from_video_buf(dst).unwrap();
898 let sdata = src.get_data();
899 for plane in 0..3 {
900 let cur_w = dst.width[plane];
901 let cur_h = dst.height[plane];
902 let doff = dst.offset[plane];
903 let soff = src.get_offset(plane);
904 let dstride = dst.stride[plane];
905 let sstride = src.get_stride(plane);
906 for (dline, sstrip) in dst.data[doff..].chunks_exact_mut(dstride).zip(
907 sdata[soff..].chunks_exact(sstride * 2)).take(cur_h) {
908 let (line0, line1) = sstrip.split_at(sstride);
909 for (dst, (src0, src1)) in dline.iter_mut().zip(
910 line0.chunks_exact(2).zip(line1.chunks_exact(2))).take(cur_w) {
911 *dst = ((u16::from(src0[0]) + u16::from(src0[1]) +
912 u16::from(src1[0]) + u16::from(src1[1]) + 2) >> 2) as u8;
913 }
914 }
915 }
916 }
917}