VP7 encoder
[nihav.git] / nihav-duck / src / codecs / vp7enc / motion_est.rs
diff --git a/nihav-duck/src/codecs/vp7enc/motion_est.rs b/nihav-duck/src/codecs/vp7enc/motion_est.rs
new file mode 100644 (file)
index 0000000..882e3e7
--- /dev/null
@@ -0,0 +1,523 @@
+use nihav_core::frame::*;
+use nihav_codec_support::codecs::{MV, ZERO_MV};
+use super::super::vp78dsp::*;
+use super::blocks::{SrcBlock, get_block_difference};
+use crate::codecs::vpenc::motion_est::*;
+pub use crate::codecs::vpenc::motion_est::MVSearchMode;
+
+pub trait MVSearchModeCreate {
+    fn create_search(&self) -> Box<dyn MVSearch + Send>;
+}
+
+impl MVSearchModeCreate for MVSearchMode {
+    fn create_search(&self) -> Box<dyn MVSearch + Send> {
+        match *self {
+            MVSearchMode::SEA       => Box::new(EliminationSearch::new()),
+            MVSearchMode::Diamond   => Box::new(DiaSearch::new()),
+            MVSearchMode::Hexagon   => Box::new(HexSearch::new()),
+            MVSearchMode::EPZS      => Box::new(EPZSearch::new()),
+            _ => unreachable!(),
+        }
+    }
+}
+
+const MAX_DIST: u32 = std::u32::MAX;
+const DIST_THRESH: u32 = 256;
+pub const LARGE_BLK8_DIST: u32 = 256;
+
+trait FromPixels {
+    fn from_pixels(self) -> Self;
+}
+
+impl FromPixels for MV {
+    fn from_pixels(self) -> MV {
+        MV { x: self.x * 8, y: self.y * 8 }
+    }
+}
+
+pub trait MVSearch {
+    fn preinit(&mut self, mv_est: &MVEstimator);
+    fn search_mb(&mut self, mv_est: &mut MVEstimator, cur_mb: &SrcBlock, mb_x: usize, mb_y: usize) -> (MV, u32);
+    fn search_blk8(&mut self, mv_est: &mut MVEstimator, ref_blk: &[u8; 64], xpos: usize, ypos: usize, cand_mvs: &[MV]) -> (MV, u32);
+    fn search_blk4(&mut self, mv_est: &mut MVEstimator, ref_blk: &[u8; 16], xpos: usize, ypos: usize, cand_mvs: &[MV]) -> (MV, u32);
+}
+
+#[derive(Default)]
+pub struct EliminationSearch {
+    msa:    [Vec<u16>; 3],
+    stride: [usize; 3],
+}
+
+impl EliminationSearch {
+    const BLOCK_SIZE: usize = 4;
+    pub fn new() -> Self { Self::default() }
+    fn get_rdist(&self, xpos: usize, ypos: usize, bavg: &[u16; 3]) -> i32 {
+        let luma_off = xpos + ypos * self.stride[0];
+        let chroma_off = (xpos / 2) + (ypos / 2) * self.stride[1];
+
+        let mut luma_avg = 0;
+        for row in self.msa[0][luma_off..].chunks(self.stride[0]).take(16).step_by(Self::BLOCK_SIZE) {
+            for &el in row.iter().take(16).step_by(Self::BLOCK_SIZE) {
+                luma_avg += el;
+            }
+        }
+        let mut chroma_avg = [0; 2];
+        for chroma in 0..1 {
+            for row in self.msa[chroma + 1][chroma_off..].chunks(self.stride[1]).take(8).step_by(Self::BLOCK_SIZE) {
+                for &el in row.iter().take(8).step_by(Self::BLOCK_SIZE) {
+                    chroma_avg[chroma] += el;
+                }
+            }
+        }
+
+        (i32::from(bavg[0]) - i32::from(luma_avg)).abs() +
+        (i32::from(bavg[1]) - i32::from(chroma_avg[0])).abs() +
+        (i32::from(bavg[2]) - i32::from(chroma_avg[1])).abs()
+    }
+}
+
+impl MVSearch for EliminationSearch {
+    fn preinit(&mut self, mv_est: &MVEstimator) {
+        let data = mv_est.ref_frame.get_data();
+        for (plane, msa) in self.msa.iter_mut().enumerate() {
+            let (width, height) = mv_est.ref_frame.get_dimensions(plane);
+            self.stride[plane] = width + 1 - Self::BLOCK_SIZE;
+            msa.clear();
+            msa.reserve(self.stride[plane] * (height + 1 - Self::BLOCK_SIZE));
+
+            let mut off = mv_est.ref_frame.get_offset(plane);
+            let stride = mv_est.ref_frame.get_stride(plane);
+            for _ in 0..(height + 1 - Self::BLOCK_SIZE) {
+                for x in 0..(width + 1 - Self::BLOCK_SIZE) {
+                    let mut sum = 0;
+                    for j in 0..Self::BLOCK_SIZE {
+                        for i in 0..Self::BLOCK_SIZE {
+                            sum += u16::from(data[off + x + i + j * stride]);
+                        }
+                    }
+                    msa.push(sum);
+                }
+                off += stride;
+            }
+        }
+    }
+    fn search_mb(&mut self, mv_est: &mut MVEstimator, cur_mb: &SrcBlock, mb_x: usize, mb_y: usize) -> (MV, u32) {
+        let mut best_dist = MAX_DIST;
+        let mut best_mv = ZERO_MV;
+
+        let mut cur_mv = ZERO_MV;
+
+        let (width, height) = mv_est.ref_frame.get_dimensions(0);
+        let mut bavg = [0; 3];
+        for blk in cur_mb.luma_blocks() {
+            bavg[0] += blk.iter().fold(0u16, |acc, &x| acc + u16::from(x));
+        }
+        for chroma in 0..2 {
+            for blk in cur_mb.chroma_blocks(chroma) {
+                bavg[chroma + 1] += blk.iter().fold(0u16, |acc, &x| acc + u16::from(x));
+            }
+        }
+        let mut rough_dist = std::i32::MAX;
+        for ytry in 0..mv_est.mv_range * 2 + 1 {
+            let dy = if (ytry & 1) == 0 { ytry >> 1 } else { -((ytry + 1) >> 1) };
+            let ypos = (mb_y as isize) * 16 + (dy as isize);
+            if ypos < 0 || (ypos + 16) > (height as isize) {
+                continue;
+            }
+            let ypos = ypos as usize;
+            cur_mv.y = dy * 8;
+            for xtry in 0..mv_est.mv_range * 2 + 1 {
+                let dx = if (xtry & 1) == 0 { xtry >> 1 } else { -((xtry + 1) >> 1) };
+                let xpos = (mb_x as isize) * 16 + (dx as isize);
+                if xpos < 0 || (xpos + 16) > (width as isize) {
+                    continue;
+                }
+                let xpos = xpos as usize;
+
+                let rdist = self.get_rdist(xpos, ypos, &bavg);
+                if rdist > rough_dist {
+                    continue;
+                }
+                rough_dist = rdist;
+
+                cur_mv.x = dx * 8;
+
+                let dist = mv_est.sad_mb(cur_mb, mb_x, mb_y, cur_mv, best_dist);
+
+                if dist < best_dist {
+                    best_dist = dist;
+                    best_mv = cur_mv;
+                    if dist <= DIST_THRESH {
+                        return (best_mv, best_dist);
+                    }
+                }
+            }
+        }
+        (best_mv, best_dist)
+    }
+    fn search_blk8(&mut self, mv_est: &mut MVEstimator, ref_blk: &[u8; 64], xpos_orig: usize, ypos_orig: usize, _cand_mvs: &[MV]) -> (MV, u32) {
+        let mut best_dist = MAX_DIST;
+        let mut best_mv = ZERO_MV;
+
+        let mut cur_mv = ZERO_MV;
+
+        let (width, height) = mv_est.ref_frame.get_dimensions(0);
+        let bavg = ref_blk.iter().fold(0u16, |acc, &x| acc + u16::from(x));
+
+        let mut rough_dist = std::i32::MAX;
+        for ytry in 0..mv_est.mv_range * 2 + 1 {
+            let dy = if (ytry & 1) == 0 { ytry >> 1 } else { -((ytry + 1) >> 1) };
+            let ypos = (ypos_orig as isize) + (dy as isize);
+            if ypos < 0 || (ypos + 8) > (height as isize) {
+                continue;
+            }
+            let ypos = ypos as usize;
+            cur_mv.y = dy * 8;
+            for xtry in 0..mv_est.mv_range * 2 + 1 {
+                let dx = if (xtry & 1) == 0 { xtry >> 1 } else { -((xtry + 1) >> 1) };
+                let xpos = (xpos_orig as isize) + (dx as isize);
+                if xpos < 0 || (xpos + 8) > (width as isize) {
+                    continue;
+                }
+                let xpos = xpos as usize;
+
+                let luma_off = xpos + ypos * self.stride[0];
+                let mut cur_avg = 0;
+                for row in self.msa[0][luma_off..].chunks(self.stride[0]).take(8).step_by(Self::BLOCK_SIZE) {
+                    for &el in row.iter().take(8).step_by(Self::BLOCK_SIZE) {
+                        cur_avg += el;
+                    }
+                }
+
+                let rdist = (i32::from(cur_avg) - i32::from(bavg)).abs();
+                if rdist > rough_dist {
+                    continue;
+                }
+                rough_dist = rdist;
+
+                cur_mv.x = dx * 8;
+
+                let dist = mv_est.sad_blk8(ref_blk, xpos_orig, ypos_orig, cur_mv, best_dist);
+
+                if dist < best_dist {
+                    best_dist = dist;
+                    best_mv = cur_mv;
+                    if dist <= DIST_THRESH / 4 {
+                        return (best_mv, best_dist);
+                    }
+                }
+            }
+        }
+        (best_mv, best_dist)
+    }
+    fn search_blk4(&mut self, mv_est: &mut MVEstimator, ref_blk: &[u8; 16], xpos_orig: usize, ypos_orig: usize, _cand_mvs: &[MV]) -> (MV, u32) {
+        let mut best_dist = MAX_DIST;
+        let mut best_mv = ZERO_MV;
+
+        let mut cur_mv = ZERO_MV;
+
+        let (width, height) = mv_est.ref_frame.get_dimensions(0);
+        let bavg = ref_blk.iter().fold(0u16, |acc, &x| acc + u16::from(x));
+
+        let mut rough_dist = std::i32::MAX;
+        for ytry in 0..mv_est.mv_range * 2 + 1 {
+            let dy = if (ytry & 1) == 0 { ytry >> 1 } else { -((ytry + 1) >> 1) };
+            let ypos = (ypos_orig as isize) + (dy as isize);
+            if ypos < 0 || (ypos + 4) > (height as isize) {
+                continue;
+            }
+            let ypos = ypos as usize;
+            cur_mv.y = dy * 8;
+            for xtry in 0..mv_est.mv_range * 2 + 1 {
+                let dx = if (xtry & 1) == 0 { xtry >> 1 } else { -((xtry + 1) >> 1) };
+                let xpos = (xpos_orig as isize) + (dx as isize);
+                if xpos < 0 || (xpos + 4) > (width as isize) {
+                    continue;
+                }
+                let xpos = xpos as usize;
+
+                let luma_off = xpos + ypos * self.stride[0];
+                let cur_avg = self.msa[0][luma_off];
+
+                let rdist = (i32::from(cur_avg) - i32::from(bavg)).abs();
+                if rdist > rough_dist {
+                    continue;
+                }
+                rough_dist = rdist;
+
+                cur_mv.x = dx * 8;
+
+                let dist = mv_est.sad_blk4(ref_blk, xpos_orig, ypos_orig, cur_mv, best_dist);
+
+                if dist < best_dist {
+                    best_dist = dist;
+                    best_mv = cur_mv;
+                    if dist <= DIST_THRESH / 16 {
+                        return (best_mv, best_dist);
+                    }
+                }
+            }
+        }
+        (best_mv, best_dist)
+    }
+}
+
+macro_rules! pattern_search {
+    ($struct_name: ident, $patterns: expr) => {
+        pub struct $struct_name {
+            point:  [MV; $patterns.len()],
+            dist:   [u32; $patterns.len()],
+            steps:  &'static [MV; $patterns.len()],
+        }
+
+        impl $struct_name {
+            pub fn new() -> Self {
+                Self {
+                    point:  $patterns,
+                    dist:   [MAX_DIST; $patterns.len()],
+                    steps:  &$patterns,
+                }
+            }
+            fn reset(&mut self) {
+                self.point = $patterns;
+                self.dist  = [MAX_DIST; $patterns.len()];
+            }
+            fn set_new_point(&mut self, start: MV, dist: u32) {
+                for (dst, &src) in self.point.iter_mut().zip(self.steps.iter()) {
+                    *dst = src + start;
+                }
+                self.dist = [MAX_DIST; $patterns.len()];
+                self.dist[0] = dist;
+            }
+            fn update(&mut self, step: MV) {
+                let mut new_point = self.point;
+                let mut new_dist = [MAX_DIST; $patterns.len()];
+
+                for point in new_point.iter_mut() {
+                    *point += step;
+                }
+
+                for (new_point, new_dist) in new_point.iter_mut().zip(new_dist.iter_mut()) {
+                    for (&old_point, &old_dist) in self.point.iter().zip(self.dist.iter()) {
+                        if *new_point == old_point {
+                            *new_dist = old_dist;
+                            break;
+                        }
+                    }
+                }
+                self.point = new_point;
+                self.dist  = new_dist;
+            }
+        }
+
+        impl MVSearch for $struct_name {
+            fn preinit(&mut self, _mv_est: &MVEstimator) {}
+            fn search_mb(&mut self, mv_est: &mut MVEstimator, cur_mb: &SrcBlock, mb_x: usize, mb_y: usize) -> (MV, u32) {
+                search_template!(self, mv_est, cur_mb, mb_x, mb_y, sad_mb, DIST_THRESH)
+            }
+            fn search_blk8(&mut self, mv_est: &mut MVEstimator, ref_blk: &[u8; 64], xpos: usize, ypos: usize, _cand_mvs: &[MV]) -> (MV, u32) {
+                search_template!(self, mv_est, ref_blk, xpos, ypos, sad_blk8, DIST_THRESH / 4)
+            }
+            fn search_blk4(&mut self, mv_est: &mut MVEstimator, ref_blk: &[u8; 16], xpos: usize, ypos: usize, _cand_mvs: &[MV]) -> (MV, u32) {
+                search_template!(self, mv_est, ref_blk, xpos, ypos, sad_blk4, DIST_THRESH / 16)
+            }
+        }
+    }
+}
+
+pattern_search!(DiaSearch, DIA_PATTERN);
+pattern_search!(HexSearch, HEX_PATTERN);
+
+pub struct EPZSearch {
+    point:  [MV; DIA_PATTERN.len()],
+    dist:   [u32; DIA_PATTERN.len()],
+    steps:  &'static [MV; DIA_PATTERN.len()],
+}
+
+impl EPZSearch {
+    pub fn new() -> Self {
+        Self {
+            point:  DIA_PATTERN,
+            dist:   [MAX_DIST; DIA_PATTERN.len()],
+            steps:  &DIA_PATTERN,
+        }
+    }
+    fn reset(&mut self) {
+        self.point = DIA_PATTERN;
+        self.dist  = [MAX_DIST; DIA_PATTERN.len()];
+    }
+    fn set_new_point(&mut self, start: MV, dist: u32) {
+        for (dst, &src) in self.point.iter_mut().zip(self.steps.iter()) {
+            *dst = src + start;
+        }
+        self.dist = [MAX_DIST; DIA_PATTERN.len()];
+        self.dist[0] = dist;
+    }
+    fn update(&mut self, step: MV) {
+        let mut new_point = self.point;
+        let mut new_dist = [MAX_DIST; DIA_PATTERN.len()];
+
+        for point in new_point.iter_mut() {
+            *point += step;
+        }
+
+        for (new_point, new_dist) in new_point.iter_mut().zip(new_dist.iter_mut()) {
+            for (&old_point, &old_dist) in self.point.iter().zip(self.dist.iter()) {
+                if *new_point == old_point {
+                    *new_dist = old_dist;
+                    break;
+                }
+            }
+        }
+        self.point = new_point;
+        self.dist  = new_dist;
+    }
+}
+
+impl MVSearch for EPZSearch {
+    fn preinit(&mut self, _mv_est: &MVEstimator) {}
+    fn search_mb(&mut self, mv_est: &mut MVEstimator, cur_mb: &SrcBlock, mb_x: usize, mb_y: usize) -> (MV, u32) {
+        search_template!(self, mv_est, cur_mb, mb_x, mb_y, sad_mb, DIST_THRESH)
+    }
+    fn search_blk8(&mut self, mv_est: &mut MVEstimator, ref_blk: &[u8; 64], xpos: usize, ypos: usize, cand_mvs: &[MV]) -> (MV, u32) {
+        let mut best_mv = ZERO_MV;
+        let zero_dist = mv_est.sad_blk8(ref_blk, xpos, ypos, best_mv, MAX_DIST);
+        let mut best_dist = zero_dist;
+        if best_dist > DIST_THRESH {
+            for &cmv in cand_mvs[1..].iter() {
+                let dist = mv_est.sad_blk8(ref_blk, xpos, ypos, cmv, best_dist);
+                if dist < best_dist {
+                    best_dist = dist;
+                    best_mv   = cmv;
+                    if best_dist <= DIST_THRESH {
+                        break;
+                    }
+                }
+            }
+            if best_dist > DIST_THRESH {
+                return search_template!(self, mv_est, ref_blk, xpos, ypos, sad_blk8, DIST_THRESH / 4, best_mv, best_dist, false);
+            }
+        }
+        (best_mv, best_dist)
+    }
+    fn search_blk4(&mut self, mv_est: &mut MVEstimator, ref_blk: &[u8; 16], xpos: usize, ypos: usize, cand_mvs: &[MV]) -> (MV, u32) {
+        let mut best_mv = ZERO_MV;
+        let zero_dist = mv_est.sad_blk4(ref_blk, xpos, ypos, best_mv, MAX_DIST);
+        let mut best_dist = zero_dist;
+        if best_dist > DIST_THRESH {
+            for &cmv in cand_mvs[1..].iter() {
+                let dist = mv_est.sad_blk4(ref_blk, xpos, ypos, cmv, best_dist);
+                if dist < best_dist {
+                    best_dist = dist;
+                    best_mv   = cmv;
+                    if best_dist <= DIST_THRESH {
+                        break;
+                    }
+                }
+            }
+            if best_dist > DIST_THRESH {
+                return search_template!(self, mv_est, ref_blk, xpos, ypos, sad_blk4, DIST_THRESH / 16, best_mv, best_dist, false);
+            }
+        }
+        (best_mv, best_dist)
+    }
+}
+
+pub struct MVEstimator {
+    pub ref_mb:     SrcBlock,
+    pub blk8:       [u8; 64],
+    pub blk4:       [u8; 16],
+    mc_buf:         NAVideoBufferRef<u8>,
+    ref_frame:      NAVideoBufferRef<u8>,
+    mv_range:       i16,
+count: usize,
+}
+
+#[allow(dead_code)]
+impl MVEstimator {
+    pub fn new(ref_frame: NAVideoBufferRef<u8>, mc_buf: NAVideoBufferRef<u8>, mv_range: i16) -> Self {
+        Self {
+            ref_mb:         SrcBlock::new(),
+            blk8:           [0; 64],
+            blk4:           [0; 16],
+            ref_frame, mc_buf,
+            mv_range,
+count: 0,
+        }
+    }
+    pub fn get_mb(&mut self, dst: &mut SrcBlock, mb_x: usize, mb_y: usize, cur_mv: MV) {
+        let tmp_blk = self.mc_buf.get_data_mut().unwrap();
+        mc_block16x16(&mut dst.luma, 0, 16, mb_x * 16, mb_y * 16, cur_mv.x * 2, cur_mv.y * 2, self.ref_frame.clone(), 0, tmp_blk);
+        mc_block8x8(&mut dst.chroma[0], 0, 8, mb_x * 8, mb_y * 8, cur_mv.x, cur_mv.y, self.ref_frame.clone(), 1, tmp_blk);
+        mc_block8x8(&mut dst.chroma[1], 0, 8, mb_x * 8, mb_y * 8, cur_mv.x, cur_mv.y, self.ref_frame.clone(), 2, tmp_blk);
+    }
+    pub fn get_blk8(&mut self, dst: &mut [u8; 64], plane: usize, x: usize, y: usize, mut cur_mv: MV) {
+        if plane == 0 {
+            cur_mv.x *= 2;
+            cur_mv.y *= 2;
+        }
+        mc_block8x8(dst, 0, 8, x, y, cur_mv.x, cur_mv.y, self.ref_frame.clone(), plane, self.mc_buf.get_data_mut().unwrap());
+    }
+    fn sad_blk8(&mut self, refblk: &[u8; 64], x: usize, y: usize, cur_mv: MV, _best_dist: u32) -> u32 {
+        mc_block8x8(&mut self.blk8, 0, 8, x, y, cur_mv.x * 2, cur_mv.y * 2, self.ref_frame.clone(), 0, self.mc_buf.get_data_mut().unwrap());
+self.count += 1;
+        sad8x8(&self.blk8, refblk)
+    }
+    pub fn get_blk4(&mut self, dst: &mut [u8; 16], plane: usize, x: usize, y: usize, mut cur_mv: MV) {
+        if plane == 0 {
+            cur_mv.x *= 2;
+            cur_mv.y *= 2;
+        }
+        mc_block4x4(dst, 0, 4, x, y, cur_mv.x, cur_mv.y, self.ref_frame.clone(), plane, self.mc_buf.get_data_mut().unwrap());
+    }
+    fn sad_blk4(&mut self, refblk: &[u8; 16], x: usize, y: usize, cur_mv: MV, _best_dist: u32) -> u32 {
+        mc_block4x4(&mut self.blk4, 0, 4, x, y, cur_mv.x * 2, cur_mv.y * 2, self.ref_frame.clone(), 0, self.mc_buf.get_data_mut().unwrap());
+        sad4x4(&self.blk4, refblk)
+    }
+    fn sad_mb(&mut self, cur_mb: &SrcBlock, mb_x: usize, mb_y: usize, cur_mv: MV, best_dist: u32) -> u32 {
+        let tmp_blk = self.mc_buf.get_data_mut().unwrap();
+
+        mc_block16x16(&mut self.ref_mb.luma, 0, 16, mb_x * 16, mb_y * 16, cur_mv.x * 2, cur_mv.y * 2, self.ref_frame.clone(), 0, tmp_blk);
+        mc_block8x8(&mut self.ref_mb.chroma[0], 0, 8, mb_x * 8, mb_y * 8, cur_mv.x, cur_mv.y, self.ref_frame.clone(), 1, tmp_blk);
+        mc_block8x8(&mut self.ref_mb.chroma[1], 0, 8, mb_x * 8, mb_y * 8, cur_mv.x, cur_mv.y, self.ref_frame.clone(), 2, tmp_blk);
+        let mut dist = 0;
+        let mut diff = [0; 16];
+        for (sblk, dblk) in self.ref_mb.luma_blocks().zip(cur_mb.luma_blocks()) {
+            get_block_difference(&mut diff, &sblk, &dblk);
+            dist += sad(&diff);
+            if dist > best_dist {
+                break;
+            }
+        }
+        'chroma_loop: for chroma in 0..2 {
+            for (sblk, dblk) in self.ref_mb.chroma_blocks(chroma).zip(cur_mb.chroma_blocks(chroma)) {
+                get_block_difference(&mut diff, &sblk, &dblk);
+                dist += sad(&diff);
+                if dist > best_dist {
+                    break 'chroma_loop;
+                }
+            }
+        }
+        dist
+    }
+}
+
+fn sad(diff: &[i16; 16]) -> u32 {
+    diff.iter().fold(0u32, |acc, &x| acc + ((i32::from(x) * i32::from(x)) as u32))
+}
+fn sad8x8(blk1: &[u8; 64], blk2: &[u8; 64]) -> u32 {
+    let mut sum = 0u32;
+    for (&a, &b) in blk1.iter().zip(blk2.iter()) {
+        let diff = i32::from(a) - i32::from(b);
+        sum += (diff * diff) as u32;
+    }
+    sum
+}
+fn sad4x4(blk1: &[u8; 16], blk2: &[u8; 16]) -> u32 {
+    let mut sum = 0u32;
+    for (&a, &b) in blk1.iter().zip(blk2.iter()) {
+        let diff = i32::from(a) - i32::from(b);
+        sum += (diff * diff) as u32;
+    }
+    sum
+}