1 use nihav_core::frame::*;
2 use nihav_codec_support::codecs::{MV, ZERO_MV};
3 use super::super::vp78dsp::*;
4 use super::blocks::{SrcBlock, get_block_difference};
5 use crate::codecs::vpenc::motion_est::*;
6 pub use crate::codecs::vpenc::motion_est::MVSearchMode;
// Factory trait: turns a motion-search mode selector into a boxed
// `MVSearch` implementation that is also `Send`.
// NOTE(review): this view of the file is missing lines (the embedded line
// numbers jump from 9 to 12), so the closing brace is not visible here.
8 pub trait MVSearchModeCreate {
// Returns a newly allocated searcher for this mode.
9 fn create_search(&self) -> Box<dyn MVSearch + Send>;
// Maps each `MVSearchMode` variant to its searcher:
//   SEA     -> EliminationSearch (successive-elimination exhaustive search)
//   Diamond -> DiaSearch
//   Hexagon -> HexSearch
//   EPZS    -> EPZSearch
// Every call constructs a fresh searcher via its `new()`.
// NOTE(review): the `match` header (embedded line 14) and any remaining
// arms/closing lines (19-23) are not visible in this view.
12 impl MVSearchModeCreate for MVSearchMode {
13 fn create_search(&self) -> Box<dyn MVSearch + Send> {
15 MVSearchMode::SEA => Box::new(EliminationSearch::new()),
16 MVSearchMode::Diamond => Box::new(DiaSearch::new()),
17 MVSearchMode::Hexagon => Box::new(HexSearch::new()),
18 MVSearchMode::EPZS => Box::new(EPZSearch::new()),
// Sentinel meaning "no distance computed yet". It is used as the initial
// best distance in the searches and as the reset value of the pattern
// searchers' per-point distance cache.
// NOTE(review): `u32::MAX` is the modern spelling of `std::u32::MAX`.
24 const MAX_DIST: u32 = std::u32::MAX;
// Early-termination threshold for a 16x16 macroblock: a candidate whose
// distance is <= this value is accepted immediately. The 8x8 and 4x4 searches
// mostly use DIST_THRESH / 4 and DIST_THRESH / 16 (scaled by block area).
25 const DIST_THRESH: u32 = 256;
// Not referenced in the code visible here. Presumably a caller compares an
// 8x8 search result against it; confirm at the use site.
26 pub const LARGE_BLK8_DIST: u32 = 256;
// Conversion of a motion vector expressed in whole pixels into the finer MV
// units used elsewhere (inferred from the method name and the x8 scaling below).
// NOTE(review): the `trait FromPixels` header line is not visible in this view.
29 fn from_pixels(self) -> Self;
// Scales both components by 8. Presumably MVs are stored with 8 sub-steps per
// pixel; confirm against how MVs are consumed by the MC routines.
32 impl FromPixels for MV {
33 fn from_pixels(self) -> MV {
34 MV { x: self.x * 8, y: self.y * 8 }
// Interface implemented by every motion-vector search strategy.
// NOTE(review): the `trait MVSearch` header line is not visible in this view.
// Hook run before searching. EliminationSearch uses it to precompute block
// sums of the reference frame; the pattern searchers implement it as a no-op.
39 fn preinit(&mut self, mv_est: &MVEstimator);
// Searches for the best MV of the macroblock at (mb_x, mb_y).
// Returns (chosen MV, its distance).
40 fn search_mb(&mut self, mv_est: &mut MVEstimator, cur_mb: &SrcBlock, mb_x: usize, mb_y: usize) -> (MV, u32);
// Searches for the best MV of the 8x8 block `ref_blk` whose top-left pixel is
// (xpos, ypos). `cand_mvs` are predictor candidates; of the implementations
// here, only EPZSearch reads them. Returns (chosen MV, its distance).
41 fn search_blk8(&mut self, mv_est: &mut MVEstimator, ref_blk: &[u8; 64], xpos: usize, ypos: usize, cand_mvs: &[MV]) -> (MV, u32);
// Same as `search_blk8`, for a 4x4 block.
42 fn search_blk4(&mut self, mv_est: &mut MVEstimator, ref_blk: &[u8; 16], xpos: usize, ypos: usize, cand_mvs: &[MV]) -> (MV, u32);
// Searcher selected by `MVSearchMode::SEA`: tries every offset in the MV
// range but uses precomputed block sums to skip candidates cheaply before
// running the full distance computation.
// NOTE(review): the field list is not visible in this view. The methods use
// `msa` (one table per plane, presumably of BLOCK_SIZE x BLOCK_SIZE pixel
// sums filled by `preinit`) and `stride` (row length of each table).
46 pub struct EliminationSearch {
51 impl EliminationSearch {
// Side length of the square windows whose pixel sums `preinit` precomputes.
52 const BLOCK_SIZE: usize = 4;
// Creates an empty searcher; tables are filled later by `preinit`.
52 pub fn new() -> Self { Self::default() }
// Rough (cheap) distance for a 16x16 luma candidate at (xpos, ypos).
// The candidate's luma sum is rebuilt from the precomputed window sums taken
// on a BLOCK_SIZE grid, which tiles the 16x16 area with non-overlapping 4x4
// windows. Each chroma sum is rebuilt the same way over 8x8 at (xpos/2,
// ypos/2). The result is the sum of the absolute differences between those
// sums and `bavg` (the current MB's luma / chroma-1 / chroma-2 sums).
// Despite the "avg" naming, all of these values are sums, not averages.
// NOTE(review): the chroma offset assumes chroma planes are subsampled by 2
// both ways. The luma accumulator declaration/update (embedded lines 57-58,
// 61) and the chroma loop header (65) are not visible in this view.
54 fn get_rdist(&self, xpos: usize, ypos: usize, bavg: &[u16; 3]) -> i32 {
55 let luma_off = xpos + ypos * self.stride[0];
56 let chroma_off = (xpos / 2) + (ypos / 2) * self.stride[1];
59 for row in self.msa[0][luma_off..].chunks(self.stride[0]).take(16).step_by(Self::BLOCK_SIZE) {
60 for &el in row.iter().take(16).step_by(Self::BLOCK_SIZE) {
64 let mut chroma_avg = [0; 2];
66 for row in self.msa[chroma + 1][chroma_off..].chunks(self.stride[1]).take(8).step_by(Self::BLOCK_SIZE) {
67 for &el in row.iter().take(8).step_by(Self::BLOCK_SIZE) {
68 chroma_avg[chroma] += el;
// Widen to i32 before subtracting so differences of u16 sums cannot wrap.
73 (i32::from(bavg[0]) - i32::from(luma_avg)).abs() +
74 (i32::from(bavg[1]) - i32::from(chroma_avg[0])).abs() +
75 (i32::from(bavg[2]) - i32::from(chroma_avg[1])).abs()
79 impl MVSearch for EliminationSearch {
// For every plane of the reference frame, sums each BLOCK_SIZE x BLOCK_SIZE
// window at every top-left position where the window fits:
// (width + 1 - BLOCK_SIZE) positions per row, (height + 1 - BLOCK_SIZE) rows.
// Each sum is at most 16 * 255 = 4080, so u16 accumulation cannot overflow.
// NOTE(review): `width + 1 - BLOCK_SIZE` (and the height equivalent)
// underflows when a plane is smaller than 3 pixels in that dimension
// (panic in debug builds). The lines that store `sum` into `msa` and advance
// `off` per row are not visible in this view.
80 fn preinit(&mut self, mv_est: &MVEstimator) {
81 let data = mv_est.ref_frame.get_data();
82 for (plane, msa) in self.msa.iter_mut().enumerate() {
83 let (width, height) = mv_est.ref_frame.get_dimensions(plane);
// Row length of this plane's sum table (number of valid window columns).
84 self.stride[plane] = width + 1 - Self::BLOCK_SIZE;
86 msa.reserve(self.stride[plane] * (height + 1 - Self::BLOCK_SIZE));
88 let mut off = mv_est.ref_frame.get_offset(plane);
// Stride of the source picture plane (distinct from the sum-table stride above).
89 let stride = mv_est.ref_frame.get_stride(plane);
90 for _ in 0..(height + 1 - Self::BLOCK_SIZE) {
91 for x in 0..(width + 1 - Self::BLOCK_SIZE) {
93 for j in 0..Self::BLOCK_SIZE {
94 for i in 0..Self::BLOCK_SIZE {
95 sum += u16::from(data[off + x + i + j * stride]);
// Exhaustive macroblock search with successive elimination.
// 1. Sum the current MB's luma and both chroma planes into `bavg`. The largest
//    sum is 256 luma pixels * 255 = 65280, which fits u16.
// 2. Visit offsets dy, dx in the order 0, -1, +1, -2, +2, ... up to
//    +/-mv_range, so positions closest to the co-located MB are tried first.
// 3. Skip offsets whose 16x16 luma block would leave the reference frame.
// 4. Skip offsets whose rough distance (`get_rdist`) exceeds `rough_dist`.
// 5. Otherwise compute the real distance, using `best_dist` as the pruning
//    bound, keep the best, and return at once when it is <= DIST_THRESH.
// NOTE(review): several lines are not visible in this view: the chroma loop
// header, the `continue`s after the bounds checks, the `rough_dist` update,
// the assignment of `cur_mv` from (dx, dy), and the best-MV update/final return.
104 fn search_mb(&mut self, mv_est: &mut MVEstimator, cur_mb: &SrcBlock, mb_x: usize, mb_y: usize) -> (MV, u32) {
105 let mut best_dist = MAX_DIST;
106 let mut best_mv = ZERO_MV;
108 let mut cur_mv = ZERO_MV;
// Bounds are checked against luma dimensions only.
110 let (width, height) = mv_est.ref_frame.get_dimensions(0);
111 let mut bavg = [0; 3];
112 for blk in cur_mb.luma_blocks() {
113 bavg[0] += blk.iter().fold(0u16, |acc, &x| acc + u16::from(x));
116 for blk in cur_mb.chroma_blocks(chroma) {
117 bavg[chroma + 1] += blk.iter().fold(0u16, |acc, &x| acc + u16::from(x));
120 let mut rough_dist = std::i32::MAX;
121 for ytry in 0..mv_est.mv_range * 2 + 1 {
// Even tries map to 0, +1, +2, ...; odd tries map to -1, -2, ...
122 let dy = if (ytry & 1) == 0 { ytry >> 1 } else { -((ytry + 1) >> 1) };
123 let ypos = (mb_y as isize) * 16 + (dy as isize);
124 if ypos < 0 || (ypos + 16) > (height as isize) {
127 let ypos = ypos as usize;
129 for xtry in 0..mv_est.mv_range * 2 + 1 {
130 let dx = if (xtry & 1) == 0 { xtry >> 1 } else { -((xtry + 1) >> 1) };
131 let xpos = (mb_x as isize) * 16 + (dx as isize);
132 if xpos < 0 || (xpos + 16) > (width as isize) {
135 let xpos = xpos as usize;
137 let rdist = self.get_rdist(xpos, ypos, &bavg);
138 if rdist > rough_dist {
145 let dist = mv_est.sad_mb(cur_mb, mb_x, mb_y, cur_mv, best_dist);
147 if dist < best_dist {
150 if dist <= DIST_THRESH {
151 return (best_mv, best_dist);
// Exhaustive 8x8 search with successive elimination. This is the same scheme
// as `search_mb`, restricted to one 8x8 block of plane 0.
// The block sum is at most 64 * 255 = 16320 (fits u16). The rough distance
// is the absolute difference between that sum and the candidate's sum, which
// is rebuilt from the 2x2 grid of precomputed 4x4 window sums. Candidate
// predictors are ignored, and the early-exit threshold is DIST_THRESH / 4.
// NOTE(review): the `continue`s, the `cur_avg` declaration/accumulation, the
// `rough_dist` update, the `cur_mv` assignment and the best-MV update are on
// lines not visible in this view.
158 fn search_blk8(&mut self, mv_est: &mut MVEstimator, ref_blk: &[u8; 64], xpos_orig: usize, ypos_orig: usize, _cand_mvs: &[MV]) -> (MV, u32) {
159 let mut best_dist = MAX_DIST;
160 let mut best_mv = ZERO_MV;
162 let mut cur_mv = ZERO_MV;
164 let (width, height) = mv_est.ref_frame.get_dimensions(0);
165 let bavg = ref_blk.iter().fold(0u16, |acc, &x| acc + u16::from(x));
167 let mut rough_dist = std::i32::MAX;
168 for ytry in 0..mv_est.mv_range * 2 + 1 {
// Search order 0, -1, +1, -2, +2, ...
169 let dy = if (ytry & 1) == 0 { ytry >> 1 } else { -((ytry + 1) >> 1) };
170 let ypos = (ypos_orig as isize) + (dy as isize);
171 if ypos < 0 || (ypos + 8) > (height as isize) {
174 let ypos = ypos as usize;
176 for xtry in 0..mv_est.mv_range * 2 + 1 {
177 let dx = if (xtry & 1) == 0 { xtry >> 1 } else { -((xtry + 1) >> 1) };
178 let xpos = (xpos_orig as isize) + (dx as isize);
179 if xpos < 0 || (xpos + 8) > (width as isize) {
182 let xpos = xpos as usize;
184 let luma_off = xpos + ypos * self.stride[0];
186 for row in self.msa[0][luma_off..].chunks(self.stride[0]).take(8).step_by(Self::BLOCK_SIZE) {
187 for &el in row.iter().take(8).step_by(Self::BLOCK_SIZE) {
192 let rdist = (i32::from(cur_avg) - i32::from(bavg)).abs();
193 if rdist > rough_dist {
200 let dist = mv_est.sad_blk8(ref_blk, xpos_orig, ypos_orig, cur_mv, best_dist);
202 if dist < best_dist {
205 if dist <= DIST_THRESH / 4 {
206 return (best_mv, best_dist);
// Exhaustive 4x4 search with successive elimination. A 4x4 block matches
// exactly one precomputed window, so the candidate's sum is a single `msa`
// lookup. The rough distance is |candidate sum - block sum|, and the
// early-exit threshold is DIST_THRESH / 16. Candidate predictors are ignored.
// NOTE(review): the `continue`s, the `rough_dist` update, the `cur_mv`
// assignment and the best-MV update are on lines not visible in this view.
213 fn search_blk4(&mut self, mv_est: &mut MVEstimator, ref_blk: &[u8; 16], xpos_orig: usize, ypos_orig: usize, _cand_mvs: &[MV]) -> (MV, u32) {
214 let mut best_dist = MAX_DIST;
215 let mut best_mv = ZERO_MV;
217 let mut cur_mv = ZERO_MV;
219 let (width, height) = mv_est.ref_frame.get_dimensions(0);
// At most 16 * 255 = 4080, so u16 is sufficient.
220 let bavg = ref_blk.iter().fold(0u16, |acc, &x| acc + u16::from(x));
222 let mut rough_dist = std::i32::MAX;
223 for ytry in 0..mv_est.mv_range * 2 + 1 {
// Search order 0, -1, +1, -2, +2, ...
224 let dy = if (ytry & 1) == 0 { ytry >> 1 } else { -((ytry + 1) >> 1) };
225 let ypos = (ypos_orig as isize) + (dy as isize);
226 if ypos < 0 || (ypos + 4) > (height as isize) {
229 let ypos = ypos as usize;
231 for xtry in 0..mv_est.mv_range * 2 + 1 {
232 let dx = if (xtry & 1) == 0 { xtry >> 1 } else { -((xtry + 1) >> 1) };
233 let xpos = (xpos_orig as isize) + (dx as isize);
234 if xpos < 0 || (xpos + 4) > (width as isize) {
237 let xpos = xpos as usize;
239 let luma_off = xpos + ypos * self.stride[0];
240 let cur_avg = self.msa[0][luma_off];
242 let rdist = (i32::from(cur_avg) - i32::from(bavg)).abs();
243 if rdist > rough_dist {
250 let dist = mv_est.sad_blk4(ref_blk, xpos_orig, ypos_orig, cur_mv, best_dist);
252 if dist < best_dist {
255 if dist <= DIST_THRESH / 16 {
256 return (best_mv, best_dist);
// Generates a pattern-based MV searcher type named `$struct_name` that walks
// the fixed offset pattern `$patterns` (an array of MV offsets whose length
// sizes the state arrays).
// State:
//   point - current candidate positions (pattern offsets around a centre)
//   dist  - cached distance for each point; MAX_DIST means "not evaluated"
//   steps - the pattern offsets themselves
// The actual search loop lives in `search_template!`, which is not defined in
// this view (presumably in the glob-imported motion_est module).
265 macro_rules! pattern_search {
266 ($struct_name: ident, $patterns: expr) => {
267 pub struct $struct_name {
268 point: [MV; $patterns.len()],
269 dist: [u32; $patterns.len()],
270 steps: &'static [MV; $patterns.len()],
274 pub fn new() -> Self {
277 dist: [MAX_DIST; $patterns.len()],
// Re-centres the pattern on the zero MV and forgets all cached distances.
281 fn reset(&mut self) {
282 self.point = $patterns;
283 self.dist = [MAX_DIST; $patterns.len()];
// Re-centres the pattern on `start`. The per-point assignment (presumably
// start + step) and any use of `dist` are on lines not visible in this view.
285 fn set_new_point(&mut self, start: MV, dist: u32) {
286 for (dst, &src) in self.point.iter_mut().zip(self.steps.iter()) {
289 self.dist = [MAX_DIST; $patterns.len()];
// Moves the pattern by `step`. Any new point that coincides with an old point
// keeps that point's cached distance, so it is not re-evaluated. The pattern
// is tiny, so the nested O(n^2) match is cheap.
// NOTE(review): the loop body shifting each point (embedded lines 297-298) is
// not visible in this view.
292 fn update(&mut self, step: MV) {
293 let mut new_point = self.point;
294 let mut new_dist = [MAX_DIST; $patterns.len()];
296 for point in new_point.iter_mut() {
300 for (new_point, new_dist) in new_point.iter_mut().zip(new_dist.iter_mut()) {
301 for (&old_point, &old_dist) in self.point.iter().zip(self.dist.iter()) {
302 if *new_point == old_point {
303 *new_dist = old_dist;
308 self.point = new_point;
309 self.dist = new_dist;
// Pattern searchers need no precomputation and ignore candidate MVs. The
// early-exit threshold is scaled by block area: 16x16 -> DIST_THRESH,
// 8x8 -> /4, 4x4 -> /16.
313 impl MVSearch for $struct_name {
314 fn preinit(&mut self, _mv_est: &MVEstimator) {}
315 fn search_mb(&mut self, mv_est: &mut MVEstimator, cur_mb: &SrcBlock, mb_x: usize, mb_y: usize) -> (MV, u32) {
316 search_template!(self, mv_est, cur_mb, mb_x, mb_y, sad_mb, DIST_THRESH)
318 fn search_blk8(&mut self, mv_est: &mut MVEstimator, ref_blk: &[u8; 64], xpos: usize, ypos: usize, _cand_mvs: &[MV]) -> (MV, u32) {
319 search_template!(self, mv_est, ref_blk, xpos, ypos, sad_blk8, DIST_THRESH / 4)
321 fn search_blk4(&mut self, mv_est: &mut MVEstimator, ref_blk: &[u8; 16], xpos: usize, ypos: usize, _cand_mvs: &[MV]) -> (MV, u32) {
322 search_template!(self, mv_est, ref_blk, xpos, ypos, sad_blk4, DIST_THRESH / 16)
// Searchers for MVSearchMode::Diamond and MVSearchMode::Hexagon.
// DIA_PATTERN and HEX_PATTERN are not defined in this view (presumably they
// come from the glob-imported motion_est module).
328 pattern_search!(DiaSearch, DIA_PATTERN);
329 pattern_search!(HexSearch, HEX_PATTERN);
// Searcher for MVSearchMode::EPZS. For 8x8/4x4 blocks it first tests the zero
// MV and the supplied predictor candidates, then falls back to a diamond
// pattern search. Its state layout is identical to a `pattern_search!`
// expansion over DIA_PATTERN.
// NOTE(review): this struct and its inherent methods duplicate the macro
// output; the rest of the struct is not visible in this view.
331 pub struct EPZSearch {
332 point: [MV; DIA_PATTERN.len()],
333 dist: [u32; DIA_PATTERN.len()],
334 steps: &'static [MV; DIA_PATTERN.len()],
// Inherent pattern-state helpers, the same as those generated by
// `pattern_search!` for DIA_PATTERN (the `impl EPZSearch` header is not
// visible in this view).
338 pub fn new() -> Self {
341 dist: [MAX_DIST; DIA_PATTERN.len()],
// Re-centres the diamond on the zero MV and forgets cached distances.
345 fn reset(&mut self) {
346 self.point = DIA_PATTERN;
347 self.dist = [MAX_DIST; DIA_PATTERN.len()];
// Re-centres the diamond on `start`. The per-point assignment is on a line not
// visible in this view.
349 fn set_new_point(&mut self, start: MV, dist: u32) {
350 for (dst, &src) in self.point.iter_mut().zip(self.steps.iter()) {
353 self.dist = [MAX_DIST; DIA_PATTERN.len()];
// Moves the diamond by `step`, carrying over the cached distance of any point
// that coincides with a previous one.
356 fn update(&mut self, step: MV) {
357 let mut new_point = self.point;
358 let mut new_dist = [MAX_DIST; DIA_PATTERN.len()];
360 for point in new_point.iter_mut() {
364 for (new_point, new_dist) in new_point.iter_mut().zip(new_dist.iter_mut()) {
365 for (&old_point, &old_dist) in self.point.iter().zip(self.dist.iter()) {
366 if *new_point == old_point {
367 *new_dist = old_dist;
372 self.point = new_point;
373 self.dist = new_dist;
377 impl MVSearch for EPZSearch {
// No precomputation is needed.
378 fn preinit(&mut self, _mv_est: &MVEstimator) {}
// Macroblocks use the plain diamond pattern search (no predictor candidates).
379 fn search_mb(&mut self, mv_est: &mut MVEstimator, cur_mb: &SrcBlock, mb_x: usize, mb_y: usize) -> (MV, u32) {
380 search_template!(self, mv_est, cur_mb, mb_x, mb_y, sad_mb, DIST_THRESH)
// Predictive 8x8 search.
// 1. Evaluate the zero MV.
// 2. If it is not good enough, evaluate every candidate except the first
//    (`cand_mvs[1..]`).
// 3. If the best so far is still above the threshold, refine it with the
//    diamond search, seeded with best_mv/best_dist. The meaning of the trailing
//    `false` argument depends on `search_template!`, which is not visible here.
// NOTE(review): the candidate checks compare against the unscaled DIST_THRESH
// (256), while the template call and the other 8x8 searches use
// DIST_THRESH / 4. Blocks whose distance is in (64, 256] therefore skip
// refinement here; confirm whether this is intended.
// NOTE(review): `cand_mvs[1..]` panics if `cand_mvs` is empty, so callers must
// pass at least one entry. Presumably index 0 is the zero MV already tested.
// The best-MV update, early return and final return are on lines not visible
// in this view.
382 fn search_blk8(&mut self, mv_est: &mut MVEstimator, ref_blk: &[u8; 64], xpos: usize, ypos: usize, cand_mvs: &[MV]) -> (MV, u32) {
383 let mut best_mv = ZERO_MV;
384 let zero_dist = mv_est.sad_blk8(ref_blk, xpos, ypos, best_mv, MAX_DIST);
385 let mut best_dist = zero_dist;
386 if best_dist > DIST_THRESH {
387 for &cmv in cand_mvs[1..].iter() {
388 let dist = mv_est.sad_blk8(ref_blk, xpos, ypos, cmv, best_dist);
389 if dist < best_dist {
392 if best_dist <= DIST_THRESH {
397 if best_dist > DIST_THRESH {
398 return search_template!(self, mv_est, ref_blk, xpos, ypos, sad_blk8, DIST_THRESH / 4, best_mv, best_dist, false);
// Predictive 4x4 search. It follows the same steps as `search_blk8`: the zero
// MV, then candidates 1.., then diamond refinement via `search_template!`.
// NOTE(review): the candidate checks use the unscaled DIST_THRESH (256), while
// the refinement uses DIST_THRESH / 16 (16), so most 4x4 blocks with distance
// <= 256 are never refined; confirm whether this is intended.
// NOTE(review): `cand_mvs[1..]` panics on an empty slice. The update/return
// lines are not visible in this view.
403 fn search_blk4(&mut self, mv_est: &mut MVEstimator, ref_blk: &[u8; 16], xpos: usize, ypos: usize, cand_mvs: &[MV]) -> (MV, u32) {
404 let mut best_mv = ZERO_MV;
405 let zero_dist = mv_est.sad_blk4(ref_blk, xpos, ypos, best_mv, MAX_DIST);
406 let mut best_dist = zero_dist;
407 if best_dist > DIST_THRESH {
408 for &cmv in cand_mvs[1..].iter() {
409 let dist = mv_est.sad_blk4(ref_blk, xpos, ypos, cmv, best_dist);
410 if dist < best_dist {
413 if best_dist <= DIST_THRESH {
418 if best_dist > DIST_THRESH {
419 return search_template!(self, mv_est, ref_blk, xpos, ypos, sad_blk4, DIST_THRESH / 16, best_mv, best_dist, false);
// Motion-compensation helper shared by all searchers. It builds predicted
// blocks from `ref_frame` and measures their distance to the source.
// NOTE(review): some fields are not visible in this view (the methods also
// use `blk8`, `blk4` and `mv_range`).
426 pub struct MVEstimator {
// Scratch macroblock; `sad_mb` writes the last motion-compensated MB here.
427 pub ref_mb: SrcBlock,
// Scratch buffer handed to the mc_block* routines as their temporary block.
430 mc_buf: NAVideoBufferRef<u8>,
// Reference picture that predictions are fetched from.
431 ref_frame: NAVideoBufferRef<u8>,
// Creates an estimator over `ref_frame`, using `mc_buf` as MC scratch space.
// `mv_range` presumably initialises the field the searches loop over
// (+/-mv_range); the field initialisers are not visible in this view.
438 pub fn new(ref_frame: NAVideoBufferRef<u8>, mc_buf: NAVideoBufferRef<u8>, mv_range: i16) -> Self {
440 ref_mb: SrcBlock::new(),
// Writes the motion-compensated prediction of macroblock (mb_x, mb_y) for
// `cur_mv` into `dst`:
//   - luma: 16x16 from plane 0, with both MV components doubled;
//   - chroma: 8x8 from planes 1 and 2 at half the MB coordinates, MV as-is.
// Presumably the doubling accounts for luma being at twice the chroma
// resolution in the units mc_block* expect; confirm.
// Panics (unwrap) if mutable access to `mc_buf` data is unavailable.
448 pub fn get_mb(&mut self, dst: &mut SrcBlock, mb_x: usize, mb_y: usize, cur_mv: MV) {
449 let tmp_blk = self.mc_buf.get_data_mut().unwrap();
450 mc_block16x16(&mut dst.luma, 0, 16, mb_x * 16, mb_y * 16, cur_mv.x * 2, cur_mv.y * 2, self.ref_frame.clone(), 0, tmp_blk);
451 mc_block8x8(&mut dst.chroma[0], 0, 8, mb_x * 8, mb_y * 8, cur_mv.x, cur_mv.y, self.ref_frame.clone(), 1, tmp_blk);
452 mc_block8x8(&mut dst.chroma[1], 0, 8, mb_x * 8, mb_y * 8, cur_mv.x, cur_mv.y, self.ref_frame.clone(), 2, tmp_blk);
// Writes the motion-compensated 8x8 block of `plane` at (x, y) for `cur_mv`
// into `dst`.
// NOTE(review): `cur_mv` is declared `mut`; the lines adjusting it (embedded
// 455-458, presumably plane-dependent scaling) are not visible in this view.
454 pub fn get_blk8(&mut self, dst: &mut [u8; 64], plane: usize, x: usize, y: usize, mut cur_mv: MV) {
459 mc_block8x8(dst, 0, 8, x, y, cur_mv.x, cur_mv.y, self.ref_frame.clone(), plane, self.mc_buf.get_data_mut().unwrap());
// Distance between `refblk` and the plane-0 8x8 prediction at (x, y) for
// `cur_mv`, whose components are doubled as in the luma path of `get_mb`.
// NOTE(review): `sad8x8` is a sum of squared differences, not absolute ones.
// `_best_dist` is ignored, so there is no early termination here.
461 fn sad_blk8(&mut self, refblk: &[u8; 64], x: usize, y: usize, cur_mv: MV, _best_dist: u32) -> u32 {
462 mc_block8x8(&mut self.blk8, 0, 8, x, y, cur_mv.x * 2, cur_mv.y * 2, self.ref_frame.clone(), 0, self.mc_buf.get_data_mut().unwrap());
464 sad8x8(&self.blk8, refblk)
// Writes the motion-compensated 4x4 block of `plane` at (x, y) for `cur_mv`
// into `dst`.
// NOTE(review): the lines adjusting the `mut cur_mv` (embedded 467-470) are
// not visible in this view.
466 pub fn get_blk4(&mut self, dst: &mut [u8; 16], plane: usize, x: usize, y: usize, mut cur_mv: MV) {
471 mc_block4x4(dst, 0, 4, x, y, cur_mv.x, cur_mv.y, self.ref_frame.clone(), plane, self.mc_buf.get_data_mut().unwrap());
// Distance between `refblk` and the plane-0 4x4 prediction at (x, y) for
// `cur_mv` (components doubled). It is a sum of squared differences, and
// `_best_dist` is unused.
473 fn sad_blk4(&mut self, refblk: &[u8; 16], x: usize, y: usize, cur_mv: MV, _best_dist: u32) -> u32 {
474 mc_block4x4(&mut self.blk4, 0, 4, x, y, cur_mv.x * 2, cur_mv.y * 2, self.ref_frame.clone(), 0, self.mc_buf.get_data_mut().unwrap());
475 sad4x4(&self.blk4, refblk)
// Distance between `cur_mb` and its motion-compensated prediction for `cur_mv`.
// The prediction is built into `self.ref_mb` with the same scaling as
// `get_mb`. The residual is then measured per 4x4 block over luma and both
// chroma planes, and the computation bails out once the running total exceeds
// `best_dist`, because the caller only needs to know that this candidate loses.
// NOTE(review): the `dist` accumulator, the `dist += sad(&diff)` lines and the
// early-exit/`break 'chroma_loop` statements are not visible in this view.
477 fn sad_mb(&mut self, cur_mb: &SrcBlock, mb_x: usize, mb_y: usize, cur_mv: MV, best_dist: u32) -> u32 {
478 let tmp_blk = self.mc_buf.get_data_mut().unwrap();
480 mc_block16x16(&mut self.ref_mb.luma, 0, 16, mb_x * 16, mb_y * 16, cur_mv.x * 2, cur_mv.y * 2, self.ref_frame.clone(), 0, tmp_blk);
481 mc_block8x8(&mut self.ref_mb.chroma[0], 0, 8, mb_x * 8, mb_y * 8, cur_mv.x, cur_mv.y, self.ref_frame.clone(), 1, tmp_blk);
482 mc_block8x8(&mut self.ref_mb.chroma[1], 0, 8, mb_x * 8, mb_y * 8, cur_mv.x, cur_mv.y, self.ref_frame.clone(), 2, tmp_blk);
// Reused residual buffer for one 4x4 block.
484 let mut diff = [0; 16];
485 for (sblk, dblk) in self.ref_mb.luma_blocks().zip(cur_mb.luma_blocks()) {
486 get_block_difference(&mut diff, &sblk, &dblk);
488 if dist > best_dist {
// Labelled so the inner early exit can leave both chroma loops at once.
492 'chroma_loop: for chroma in 0..2 {
493 for (sblk, dblk) in self.ref_mb.chroma_blocks(chroma).zip(cur_mb.chroma_blocks(chroma)) {
494 get_block_difference(&mut diff, &sblk, &dblk);
496 if dist > best_dist {
// Sum of squares of a 4x4 residual block.
// NOTE(review): despite the name this is an SSD-style metric, not a sum of
// absolute differences. Each square is computed in i32; with pixel
// differences in [-255, 255] the total is at most 16 * 65025, well within u32.
505 fn sad(diff: &[i16; 16]) -> u32 {
506 diff.iter().fold(0u32, |acc, &x| acc + ((i32::from(x) * i32::from(x)) as u32))
// Sum of squared differences between two 8x8 pixel blocks (not a SAD, despite
// the name). Worst case 64 * 255^2 = 4_161_600 fits u32.
// NOTE(review): the `sum` declaration and return lines are not visible here.
508 fn sad8x8(blk1: &[u8; 64], blk2: &[u8; 64]) -> u32 {
510 for (&a, &b) in blk1.iter().zip(blk2.iter()) {
511 let diff = i32::from(a) - i32::from(b);
512 sum += (diff * diff) as u32;
// Sum of squared differences between two 4x4 pixel blocks (not a SAD, despite
// the name). Worst case 16 * 255^2 = 1_040_400 fits u32.
// NOTE(review): the `sum` declaration and return lines are not visible here.
516 fn sad4x4(blk1: &[u8; 16], blk2: &[u8; 16]) -> u32 {
518 for (&a, &b) in blk1.iter().zip(blk2.iter()) {
519 let diff = i32::from(a) - i32::from(b);
520 sum += (diff * diff) as u32;