h264: fix direct 8x8 inference mode
[nihav.git] / nihav-itu / src / codecs / h264 / types.rs
index 6b5d0100c5810583ce9faa24c0bb5f94bab998d9..3456b990b03e3bc5ecd17cce3af7a0af97688b42 100644 (file)
@@ -1,3 +1,4 @@
+use nihav_core::frame::NASimpleVideoFrame;
 use nihav_codec_support::codecs::{MV, ZERO_MV};
 use nihav_codec_support::data::GenericCache;
 use super::FrameRefs;
@@ -224,13 +225,15 @@ impl CompactMBType {
     pub fn is_inter(self) -> bool {
         !self.is_intra() && !self.is_skip() && self != CompactMBType::PCM
     }
-    pub fn is_16x16(self) -> bool {
+    pub fn is_16x16_ref(self) -> bool {
         match self {
-            CompactMBType::P16x8 | CompactMBType::P8x16 |
-            CompactMBType::P8x8 | CompactMBType::P8x8Ref0 |
-            CompactMBType::B16x8 | CompactMBType::B8x16 |
-            CompactMBType::B8x8 => false,
-            _ => true,
+            CompactMBType::Intra4x4 |
+            CompactMBType::Intra8x8 |
+            CompactMBType::Intra16x16 |
+            CompactMBType::PCM |
+            CompactMBType::P16x16 |
+            CompactMBType::B16x16 => true,
+            _ => false,
         }
     }
 }
@@ -431,12 +434,19 @@ pub struct SliceState {
     pub blk8:           GenericCache<Blk8Data>,
     pub blk4:           GenericCache<Blk4Data>,
 
-    pub deblock:        GenericCache<u8>,
+    pub deblock:        [u8; 16],
 
     pub has_top:        bool,
     pub has_left:       bool,
+
+    pub top_line_y:     Vec<u8>,
+    pub left_y:         [u8; 17], // first element is top-left
+    pub top_line_c:     [Vec<u8>; 2],
+    pub left_c:         [[u8; 9]; 2],
 }
 
+const BLK4_TO_D8: [usize; 16] = [ 0, 0, 3, 3, 0, 0, 3, 3, 12, 12, 15, 15, 12, 12, 15, 15 ];
+
 impl SliceState {
     pub fn new() -> Self {
         Self {
@@ -449,10 +459,15 @@ impl SliceState {
             blk8:       GenericCache::new(0, 0, Blk8Data::default()),
             blk4:       GenericCache::new(0, 0, Blk4Data::default()),
 
-            deblock:    GenericCache::new(0, 0, 0),
+            deblock:    [0; 16],
 
             has_top:    false,
             has_left:   false,
+
+            top_line_y: Vec::new(),
+            left_y:     [0; 17],
+            top_line_c: [Vec::new(), Vec::new()],
+            left_c:     [[0; 9]; 2],
         }
     }
     pub fn reset(&mut self, mb_w: usize, mb_h: usize, mb_pos: usize) {
@@ -470,42 +485,65 @@ impl SliceState {
         self.blk8  = GenericCache::new(2, mb_w * 2 + 2, Blk8Data::default());
         self.blk4  = GenericCache::new(4, mb_w * 4 + 2, Blk4Data::default());
 
-        self.deblock  = GenericCache::new(4, mb_w * 4 + 1, 0);
-
         self.has_top  = false;
         self.has_left = false;
+
+        self.top_line_y.resize(mb_w * 16 + 1, 0x80);
+        self.top_line_c[0].resize(mb_w *  8 + 1, 0x80);
+        self.top_line_c[1].resize(mb_w *  8 + 1, 0x80);
+        self.left_y = [0x80; 17];
+        self.left_c = [[0x80; 9]; 2];
+    }
+    pub fn save_ipred_context(&mut self, frm: &NASimpleVideoFrame<u8>) {
+        let dstoff = self.mb_x * 16;
+        let srcoff = frm.offset[0] + self.mb_x * 16 + self.mb_y * 16 * frm.stride[0];
+        self.left_y[0] = self.top_line_y[dstoff + 15];
+        self.top_line_y[dstoff..][..16].copy_from_slice(&frm.data[srcoff + frm.stride[0] * 15..][..16]);
+        for (dst, src) in self.left_y[1..].iter_mut().zip(frm.data[srcoff..].chunks(frm.stride[0])) {
+            *dst = src[15];
+        }
+        for chroma in 0..2 {
+            let cstride = frm.stride[chroma + 1];
+            let dstoff = self.mb_x * 8;
+            let srcoff = frm.offset[chroma + 1] + self.mb_x * 8 + self.mb_y * 8 * cstride;
+            self.left_c[chroma][0] = self.top_line_c[chroma][dstoff + 7];
+            self.top_line_c[chroma][dstoff..][..8].copy_from_slice(&frm.data[srcoff + cstride * 7..][..8]);
+            for (dst, src) in self.left_c[chroma][1..].iter_mut().zip(frm.data[srcoff..].chunks(cstride)) {
+                *dst = src[7];
+            }
+        }
     }
-    pub fn fill_deblock(&mut self, deblock_mode: u8, is_s: bool) {
+    pub fn fill_deblock(&mut self, frefs: &FrameRefs, deblock_mode: u8, is_s: bool) {
         if deblock_mode == 1 {
             return;
         }
 
+        self.deblock = [0; 16];
+
         let tx8x8 = self.get_cur_mb().transform_8x8;
 
-        let mut idx = self.deblock.xpos + self.mb_x * 4;
         let cur_mbt     = self.get_cur_mb().mb_type;
         let left_mbt    = self.get_left_mb().mb_type;
         let mut top_mbt = self.get_top_mb().mb_type;
         for y in 0..4 {
-            if tx8x8 && (y & 1) != 0 {
-                continue;
-            }
             let can_do_top = y != 0 || (self.mb_y != 0 && (self.has_top || deblock_mode != 2));
-            if can_do_top {
+            if can_do_top && (!tx8x8 || (y & 1) == 0) {
                 if is_s || cur_mbt.is_intra() || top_mbt.is_intra() {
                     let val = if y == 0 { 0x40 } else { 0x30 };
-                    for el in self.deblock.data[idx..][..4].iter_mut() { *el |= val; }
+                    for el in self.deblock[y * 4..][..4].iter_mut() { *el |= val; }
                 } else {
                     for x in 0..4 {
-                        if self.get_cur_blk4(x).ncoded != 0 || self.get_top_blk4(x).ncoded != 0 {
-                            self.deblock.data[idx + x] |= 0x20;
+                        let blk4 = x + y * 4;
+                        let blk8 = x / 2 + (y / 2) * 2;
+                        if self.get_cur_blk4(blk4).ncoded != 0 || self.get_top_blk4(blk4).ncoded != 0 {
+                            self.deblock[y * 4 + x] |= 0x20;
                         } else {
-                            let cur_mv = self.get_cur_blk4(x).mv;
-                            let top_mv = self.get_top_blk4(x).mv;
-                            let cur_ref = self.get_cur_blk8(x / 2).ref_idx;
-                            let top_ref = self.get_top_blk8(x / 2).ref_idx;
-                            if mvdiff4(cur_mv[0], top_mv[0]) || mvdiff4(cur_mv[1], top_mv[1]) || cur_ref != top_ref {
-                                self.deblock.data[idx + x] |= 0x10;
+                            let cur_mv = self.get_cur_blk4(blk4).mv;
+                            let top_mv = self.get_top_blk4(blk4).mv;
+                            let cur_ref = self.get_cur_blk8(blk8).ref_idx;
+                            let top_ref = if (y & 1) == 0 { self.get_top_blk8(blk8).ref_idx } else { cur_ref };
+                            if mvdiff4(cur_mv[0], top_mv[0]) || mvdiff4(cur_mv[1], top_mv[1]) || !frefs.cmp_refs(cur_ref, top_ref) {
+                                self.deblock[y * 4 + x] |= 0x10;
                             }
                         }
                     }
@@ -513,32 +551,30 @@ impl SliceState {
             }
             let mut lleft_mbt = left_mbt;
             for x in 0..4 {
-                if tx8x8 && (x & 1) != 0 {
-                    continue;
-                }
+                let skip_8 = tx8x8 && (x & 1) != 0;
                 let can_do_left = x > 0 || self.has_left || (self.mb_x != 0 && deblock_mode != 2);
                 if !can_do_left {
                     continue;
                 }
                 let blk4 = x + y * 4;
                 let blk8 = x / 2 + (y / 2) * 2;
-                if is_s || cur_mbt.is_intra() || lleft_mbt.is_intra() {
-                    self.deblock.data[idx + x] |= if x == 0 { 4 } else { 3 };
-                } else if self.get_cur_blk4(blk4).ncoded != 0 || self.get_top_blk4(blk4).ncoded != 0 {
-                    self.deblock.data[idx + x] |= 2;
+                if skip_8 {
+                } else if is_s || cur_mbt.is_intra() || lleft_mbt.is_intra() {
+                    self.deblock[y * 4 + x] |= if x == 0 { 4 } else { 3 };
+                } else if self.get_cur_blk4(blk4).ncoded != 0 || self.get_left_blk4(blk4).ncoded != 0 {
+                    self.deblock[y * 4 + x] |= 2;
                 } else {
                     let cur_mv  = self.get_cur_blk4(blk4).mv;
                     let left_mv = self.get_left_blk4(blk4).mv;
                     let cur_ref  = self.get_cur_blk8(blk8).ref_idx;
-                    let left_ref = self.get_left_blk8(blk8).ref_idx;
-                    if mvdiff4(cur_mv[0], left_mv[0]) || mvdiff4(cur_mv[1], left_mv[1]) || cur_ref != left_ref {
-                        self.deblock.data[idx + x] |= 1;
+                    let left_ref = if (x & 1) == 0 { self.get_left_blk8(blk8).ref_idx } else { cur_ref };
+                    if mvdiff4(cur_mv[0], left_mv[0]) || mvdiff4(cur_mv[1], left_mv[1]) || !frefs.cmp_refs(cur_ref, left_ref) {
+                        self.deblock[y * 4 + x] |= 1;
                     }
                 }
                 lleft_mbt = cur_mbt;
             }
             top_mbt = cur_mbt;
-            idx += self.deblock.stride;
         }
     }
     pub fn next_mb(&mut self) {
@@ -551,8 +587,6 @@ impl SliceState {
             self.blk8.update_row();
             self.blk4.update_row();
 
-            self.deblock.update_row();
-
             self.has_left = false;
         }
         self.has_top = self.mb_x + self.mb_y * self.mb_w >= self.mb_start + self.mb_w;
@@ -733,9 +767,15 @@ impl SliceState {
         self.fill_mv (0, 0, 16, 16, 0, mv);
         self.fill_ref(0, 0, 16, 16, 0, ref_idx);
     }
-    pub fn predict_direct_mb(&mut self, frame_refs: &FrameRefs, temporal_mv: bool, cur_id: u16) {
+    pub fn predict_direct_mb(&mut self, frame_refs: &FrameRefs, temporal_mv: bool, direct_8x8: bool, cur_id: u16) {
         let (col_mb, _, _) = frame_refs.get_colocated_info(self.mb_x, self.mb_y);
-        if col_mb.mb_type.is_16x16() || !temporal_mv {
+        if direct_8x8 {
+            for blk4 in 0..16 {
+                let (mv0, ref0, mv1, ref1) = self.get_direct_mv(frame_refs, temporal_mv, cur_id, BLK4_TO_D8[blk4]);
+                self.get_cur_blk4(blk4).mv = [mv0, mv1];
+                self.get_cur_blk8(blk4_to_blk8(blk4)).ref_idx = [ref0, ref1];
+            }
+        } else if col_mb.mb_type.is_16x16_ref() || !temporal_mv {
             let (mv0, ref0, mv1, ref1) = self.get_direct_mv(frame_refs, temporal_mv, cur_id, 0);
             self.apply_to_blk4(|blk4| blk4.mv = [mv0, mv1]);
             self.apply_to_blk8(|blk8| blk8.ref_idx = [ref0, ref1]);
@@ -747,8 +787,9 @@ impl SliceState {
             }
         }
     }
-    pub fn predict_direct_sub(&mut self, frame_refs: &FrameRefs, temporal_mv: bool, cur_id: u16, blk4: usize) {
-        let (mv0, ref0, mv1, ref1) = self.get_direct_mv(frame_refs, temporal_mv, cur_id, blk4);
+    pub fn predict_direct_sub(&mut self, frame_refs: &FrameRefs, temporal_mv: bool, direct8x8: bool, cur_id: u16, blk4: usize) {
+        let src_blk = if !direct8x8 { blk4 } else { BLK4_TO_D8[blk4] };
+        let (mv0, ref0, mv1, ref1) = self.get_direct_mv(frame_refs, temporal_mv, cur_id, src_blk);
         self.get_cur_blk4(blk4).mv = [mv0, mv1];
         self.get_cur_blk8(blk4_to_blk8(blk4)).ref_idx = [ref0, ref1];
     }