h264: fix direct 8x8 inference mode
[nihav.git] / nihav-itu / src / codecs / h264 / types.rs
index 0a0cc64cb934fee76d984082144995c81cf4f800..3456b990b03e3bc5ecd17cce3af7a0af97688b42 100644 (file)
@@ -445,6 +445,8 @@ pub struct SliceState {
     pub left_c:         [[u8; 9]; 2],
 }
 
+const BLK4_TO_D8: [usize; 16] = [ 0, 0, 3, 3, 0, 0, 3, 3, 12, 12, 15, 15, 12, 12, 15, 15 ];
+
 impl SliceState {
     pub fn new() -> Self {
         Self {
@@ -765,9 +767,15 @@ impl SliceState {
         self.fill_mv (0, 0, 16, 16, 0, mv);
         self.fill_ref(0, 0, 16, 16, 0, ref_idx);
     }
-    pub fn predict_direct_mb(&mut self, frame_refs: &FrameRefs, temporal_mv: bool, cur_id: u16) {
+    pub fn predict_direct_mb(&mut self, frame_refs: &FrameRefs, temporal_mv: bool, direct_8x8: bool, cur_id: u16) {
         let (col_mb, _, _) = frame_refs.get_colocated_info(self.mb_x, self.mb_y);
-        if col_mb.mb_type.is_16x16_ref() || !temporal_mv {
+        if direct_8x8 {
+            for blk4 in 0..16 {
+                let (mv0, ref0, mv1, ref1) = self.get_direct_mv(frame_refs, temporal_mv, cur_id, BLK4_TO_D8[blk4]);
+                self.get_cur_blk4(blk4).mv = [mv0, mv1];
+                self.get_cur_blk8(blk4_to_blk8(blk4)).ref_idx = [ref0, ref1];
+            }
+        } else if col_mb.mb_type.is_16x16_ref() || !temporal_mv {
             let (mv0, ref0, mv1, ref1) = self.get_direct_mv(frame_refs, temporal_mv, cur_id, 0);
             self.apply_to_blk4(|blk4| blk4.mv = [mv0, mv1]);
             self.apply_to_blk8(|blk8| blk8.ref_idx = [ref0, ref1]);
@@ -779,8 +787,9 @@ impl SliceState {
             }
         }
     }
-    pub fn predict_direct_sub(&mut self, frame_refs: &FrameRefs, temporal_mv: bool, cur_id: u16, blk4: usize) {
-        let (mv0, ref0, mv1, ref1) = self.get_direct_mv(frame_refs, temporal_mv, cur_id, blk4);
+    pub fn predict_direct_sub(&mut self, frame_refs: &FrameRefs, temporal_mv: bool, direct8x8: bool, cur_id: u16, blk4: usize) {
+        let src_blk = if !direct8x8 { blk4 } else { BLK4_TO_D8[blk4] };
+        let (mv0, ref0, mv1, ref1) = self.get_direct_mv(frame_refs, temporal_mv, cur_id, src_blk);
         self.get_cur_blk4(blk4).mv = [mv0, mv1];
         self.get_cur_blk8(blk4_to_blk8(blk4)).ref_idx = [ref0, ref1];
     }