i263: decode B-frames somehow
[nihav.git] / src / codecs / blockdec.rs
index 0326f84984f984ea06368876a5c71bec6b6449e3..23e22d41e628d45193c1e10a5c6e10aad62a45ca 100644 (file)
@@ -1,5 +1,5 @@
 use std::mem;
-use std::ops::Add;
+use std::ops::{Add, Sub};
 use super::*;
 use super::blockdsp;
 use super::h263code::*;
@@ -10,6 +10,21 @@ pub enum Type {
     I, P, Skip, Special
 }
 
+#[allow(dead_code)]
+#[derive(Debug,Clone,Copy)]
+pub struct PBInfo {
+    trb:        u8,
+    dbquant:    u8,
+}
+
+impl PBInfo {
+    pub fn new(trb: u8, dbquant: u8) -> Self {
+        PBInfo{ trb: trb, dbquant: dbquant }
+    }
+    pub fn get_trb(&self) -> u8 { self.trb }
+    pub fn get_dbquant(&self) -> u8 { self.dbquant }
+}
+
 #[allow(dead_code)]
 #[derive(Debug,Clone,Copy)]
 pub struct PicInfo {
@@ -19,14 +34,14 @@ pub struct PicInfo {
     quant:  u8,
     apm:    bool,
     umv:    bool,
-    pb:     bool,
+    pb:     Option<PBInfo>,
     ts:     u8,
 }
 
 #[allow(dead_code)]
 impl PicInfo {
-    pub fn new(w: usize, h: usize, mode: Type, quant: u8, apm: bool, umv: bool, pb: bool, ts: u8) -> Self {
-        PicInfo{ w: w, h: h, mode: mode, quant: quant, apm: apm, umv: umv, pb: pb, ts: ts }
+    pub fn new(w: usize, h: usize, mode: Type, quant: u8, apm: bool, umv: bool, ts: u8, pb: Option<PBInfo>) -> Self {
+        PicInfo{ w: w, h: h, mode: mode, quant: quant, apm: apm, umv: umv, ts: ts, pb: pb }
     }
     pub fn get_width(&self) -> usize { self.w }
     pub fn get_height(&self) -> usize { self.h }
@@ -34,8 +49,9 @@ impl PicInfo {
     pub fn get_quant(&self) -> u8 { self.quant }
     pub fn get_apm(&self) -> bool { self.apm }
     pub fn get_umv(&self) -> bool { self.umv }
-    pub fn is_pb(&self) -> bool { self.pb }
+    pub fn is_pb(&self) -> bool { self.pb.is_some() }
     pub fn get_ts(&self) -> u8 { self.ts }
+    pub fn get_pbinfo(&self) -> PBInfo { self.pb.unwrap() }
 }
 
 #[derive(Debug,Clone,Copy)]
@@ -109,6 +125,19 @@ impl MV {
         }
         new_mv
     }
+    fn scale(&self, trb: u8, trd: u8) -> Self {
+        if (trd == 0) || (trb == 0) {
+            ZERO_MV
+        } else {
+            MV { x: (self.x * (trb as i16)) / (trd as i16), y: (self.y * (trb as i16)) / (trd as i16) }
+        }
+    }
+    fn b_sub(pvec: MV, fwdvec: MV, bvec: MV, trb: u8, trd: u8) -> Self {
+        let bscale = (trb as i16) - (trd as i16);
+        let x = if bvec.x != 0 { fwdvec.x - pvec.x } else if trd != 0 { bscale * pvec.x / (trd as i16) } else { 0 };
+        let y = if bvec.y != 0 { fwdvec.y - pvec.y } else if trd != 0 { bscale * pvec.y / (trd as i16) } else { 0 };
+        MV { x: x, y: y }
+    }
 }
 
 pub const ZERO_MV: MV = MV { x: 0, y: 0 };
@@ -118,6 +147,11 @@ impl Add for MV {
     fn add(self, other: MV) -> MV { MV { x: self.x + other.x, y: self.y + other.y } }
 }
 
+impl Sub for MV {
+    type Output = MV;
+    fn sub(self, other: MV) -> MV { MV { x: self.x - other.x, y: self.y - other.y } }
+}
+
 #[derive(Debug,Clone,Copy)]
 pub struct BlockInfo {
     intra:   bool,
@@ -129,8 +163,18 @@ pub struct BlockInfo {
     num_mv:  usize,
     bpart:   bool,
     b_cbp:   u8,
-    mv2:     [MV; 4],
+    mv2:     [MV; 2],
     num_mv2: usize,
+    fwd:     bool,
+}
+
+#[allow(dead_code)]
+#[derive(Debug,Clone,Copy)]
+pub struct BBlockInfo {
+    present: bool,
+    cbp:     u8,
+    num_mv:  usize,
+    fwd:     bool,
 }
 
 #[allow(dead_code)]
@@ -146,8 +190,9 @@ impl BlockInfo {
             num_mv:  0,
             bpart:   false,
             b_cbp:   0,
-            mv2:     [MV::new(0, 0), MV::new(0, 0), MV::new(0, 0), MV::new(0, 0)],
+            mv2:     [ZERO_MV, ZERO_MV],
             num_mv2: 0,
+            fwd:     false,
         }
     }
     pub fn is_intra(&self) -> bool { self.intra }
@@ -163,28 +208,44 @@ impl BlockInfo {
     pub fn get_mv2(&self, idx: usize) -> MV { self.mv2[idx] }
     pub fn set_mv(&mut self, mvs: &[MV]) {
         if mvs.len() > 0 { self.skip = false; }
-        self.bpart = true;
         let mut mv_arr: [MV; 4] = [MV::new(0, 0), MV::new(0, 0), MV::new(0, 0), MV::new(0, 0)];
         for i in 0..mvs.len() { mv_arr[i] = mvs[i]; }
         self.mv     = mv_arr;
         self.num_mv = mvs.len();
     }
-    pub fn set_mv2(&mut self, cbp: u8, mvs: &[MV]) {
-        self.bpart = true;
-        self.b_cbp = cbp;
-        let mut mv_arr: [MV; 4] = [MV::new(0, 0), MV::new(0, 0), MV::new(0, 0), MV::new(0, 0)];
+    pub fn set_bpart(&mut self, bbinfo: BBlockInfo) {
+        self.bpart = bbinfo.present;
+        self.b_cbp = bbinfo.cbp;
+        self.fwd   = bbinfo.fwd;
+        self.num_mv2 = bbinfo.get_num_mv();
+    }
+    pub fn set_b_mv(&mut self, mvs: &[MV]) {
+        if mvs.len() > 0 { self.skip = false; }
+        let mut mv_arr: [MV; 2] = [ZERO_MV, ZERO_MV];
         for i in 0..mvs.len() { mv_arr[i] = mvs[i]; }
-        self.mv2     = mv_arr;
-        self.num_mv2 = mvs.len();
+        self.mv2    = mv_arr;
+    }
+    pub fn is_b_fwd(&self) -> bool { self.fwd }
+}
+
+impl BBlockInfo {
+    pub fn new(present: bool, cbp: u8, num_mv: usize, fwd: bool) -> Self {
+        BBlockInfo {
+            present: present,
+            cbp:     cbp,
+            num_mv:  num_mv,
+            fwd:     fwd,
+        }
     }
+    pub fn get_num_mv(&self) -> usize { self.num_mv }
 }
 
 pub trait BlockDecoder {
     fn decode_pichdr(&mut self) -> DecoderResult<PicInfo>;
     fn decode_slice_header(&mut self, pinfo: &PicInfo) -> DecoderResult<Slice>;
     fn decode_block_header(&mut self, pinfo: &PicInfo, sinfo: &Slice) -> DecoderResult<BlockInfo>;
-    fn decode_block_intra(&mut self, info: &BlockInfo, no: usize, coded: bool, blk: &mut [i16; 64]) -> DecoderResult<()>;
-    fn decode_block_inter(&mut self, info: &BlockInfo, no: usize, coded: bool, blk: &mut [i16; 64]) -> DecoderResult<()>;
+    fn decode_block_intra(&mut self, info: &BlockInfo, quant: u8, no: usize, coded: bool, blk: &mut [i16; 64]) -> DecoderResult<()>;
+    fn decode_block_inter(&mut self, info: &BlockInfo, quant: u8, no: usize, coded: bool, blk: &mut [i16; 64]) -> DecoderResult<()>;
     fn calc_mv(&mut self, vec: MV);
     fn is_slice_end(&mut self) -> bool;
 }
@@ -286,6 +347,29 @@ fn copy_blocks(dst: &mut NAVideoBuffer<u8>, src: &NAVideoBuffer<u8>, xpos: usize
     blockdsp::copy_blocks(dst, src, xpos, ypos, srcx, srcy, w, h, 0, 1, mode, H263_INTERP_FUNCS);
 }
 
+fn avg_blocks(dst: &mut NAVideoBuffer<u8>, src: &NAVideoBuffer<u8>, xpos: usize, ypos: usize, w: usize, h: usize, mv: MV) {
+    let srcx = ((mv.x >> 1) as isize) + (xpos as isize);
+    let srcy = ((mv.y >> 1) as isize) + (ypos as isize);
+    let mode = ((mv.x & 1) + (mv.y & 1) * 2) as usize;
+
+    blockdsp::copy_blocks(dst, src, xpos, ypos, srcx, srcy, w, h, 0, 1, mode, H263_INTERP_AVG_FUNCS);
+}
+
+#[allow(dead_code)]
+#[derive(Clone,Copy)]
+struct BMB {
+    num_mv: usize,
+    mv_f:   [MV; 4],
+    mv_b:   [MV; 4],
+    fwd:    bool,
+    blk:    [[i16; 64]; 6],
+    cbp:    u8,
+}
+
+impl BMB {
+    fn new() -> Self { BMB {blk: [[0; 64]; 6], cbp: 0, fwd: false, mv_f: [ZERO_MV; 4], mv_b: [ZERO_MV; 4], num_mv: 0} }
+}
+
 pub struct DCT8x8VideoDecoder {
     w:          usize,
     h:          usize,
@@ -295,6 +379,9 @@ pub struct DCT8x8VideoDecoder {
     ftype:      Type,
     prev_frm:   Option<NAVideoBuffer<u8>>,
     cur_frm:    Option<NAVideoBuffer<u8>>,
+    last_ts:    u8,
+    has_b:      bool,
+    b_data:     Vec<BMB>,
 }
 
 #[allow(dead_code)]
@@ -304,6 +391,8 @@ impl DCT8x8VideoDecoder {
             w: 0, h: 0, mb_w: 0, mb_h: 0, num_mb: 0,
             ftype: Type::Special,
             prev_frm: None, cur_frm: None,
+            last_ts: 0,
+            has_b: false, b_data: Vec::new(),
         }
     }
 
@@ -312,7 +401,7 @@ impl DCT8x8VideoDecoder {
 
     pub fn parse_frame(&mut self, bd: &mut BlockDecoder) -> DecoderResult<NABufferType> {
         let pinfo = bd.decode_pichdr()?;
-        let mut mvi  = MVInfo::new();
+        let mut mvi = MVInfo::new();
 
 //todo handle res change
         self.w = pinfo.w;
@@ -321,10 +410,18 @@ impl DCT8x8VideoDecoder {
         self.mb_h = (pinfo.h + 15) >> 4;
         self.num_mb = self.mb_w * self.mb_h;
         self.ftype = pinfo.mode;
+        self.has_b = pinfo.is_pb();
+
+        if self.has_b {
+            self.b_data.truncate(0);
+        }
 
         mem::swap(&mut self.cur_frm, &mut self.prev_frm);
 //        if self.ftype == Type::I && !pinfo.is_pb() { self.prev_frm = None; }
 
+        let tsdiff = pinfo.ts.wrapping_sub(self.last_ts);
+        let bsdiff = if pinfo.is_pb() { pinfo.get_pbinfo().get_trb() } else { 0 };
+
         let fmt = formats::YUV420_FORMAT;
         let vinfo = NAVideoInfo::new(self.w, self.h, false, fmt);
         let bufret = alloc_video_buffer(vinfo, 4);
@@ -332,17 +429,6 @@ impl DCT8x8VideoDecoder {
         let mut bufinfo = bufret.unwrap();
         let mut buf = bufinfo.get_vbuf().unwrap();
 
-        let mut bbuf;
-
-        if self.prev_frm.is_some() && pinfo.is_pb() {
-            let bufret = alloc_video_buffer(vinfo, 4);
-            if let Err(_) = bufret { return Err(DecoderError::InvalidData); }
-            let mut bbufinfo = bufret.unwrap();
-            bbuf = Some(bbufinfo.get_vbuf().unwrap());
-        } else {
-            bbuf = None;
-        }
-
         let mut slice = Slice::get_default_slice(&pinfo);
         mvi.reset(self.mb_w, 0, pinfo.get_umv());
 
@@ -362,7 +448,7 @@ impl DCT8x8VideoDecoder {
 //println!("mb {}.{} CBP {:X} type {:?}, {} mvs skip {}", mb_x,mb_y, cbp, binfo.get_mode(), binfo.get_num_mvs(),binfo.is_skipped());
                 if binfo.is_intra() {
                     for i in 0..6 {
-                        bd.decode_block_intra(&binfo, i, (cbp & (1 << (5 - i))) != 0, &mut blk[i])?;
+                        bd.decode_block_intra(&binfo, binfo.get_q(), i, (cbp & (1 << (5 - i))) != 0, &mut blk[i])?;
                         h263_idct(&mut blk[i]);
                     }
                     blockdsp::put_blocks(&mut buf, mb_x, mb_y, &blk);
@@ -387,7 +473,7 @@ impl DCT8x8VideoDecoder {
 //println!("");
                     }
                     for i in 0..6 {
-                        bd.decode_block_inter(&binfo, i, ((cbp >> (5 - i)) & 1) != 0, &mut blk[i])?;
+                        bd.decode_block_inter(&binfo, binfo.get_q(), i, ((cbp >> (5 - i)) & 1) != 0, &mut blk[i])?;
                         h263_idct(&mut blk[i]);
                     }
                     blockdsp::add_blocks(&mut buf, mb_x, mb_y, &blk);
@@ -397,44 +483,109 @@ impl DCT8x8VideoDecoder {
                         copy_blocks(&mut buf, srcbuf, mb_x * 16, mb_y * 16, 16, 16, ZERO_MV);
                     }
                 }
-                if pinfo.is_pb() && binfo.has_b_part() {
-                    let mut blk: [[i16; 64]; 6] = [[0; 64]; 6];
+                if pinfo.is_pb() {
+                    let mut b_mb = BMB::new();
                     let cbp = binfo.get_cbp_b();
+                    let bq = (((pinfo.get_pbinfo().get_dbquant() + 5) as u16) * (binfo.get_q() as u16)) >> 2;
+                    let bquant;
+                    if bq < 1 { bquant = 1; }
+                    else if bq > 31 { bquant = 31; }
+                    else { bquant = bq as u8; }
+
+                    b_mb.cbp = cbp;
                     for i in 0..6 {
-                        bd.decode_block_inter(&binfo, i, (cbp & (1 << (5 - i))) != 0, &mut blk[i])?;
-                        h263_idct(&mut blk[i]);
+                        bd.decode_block_inter(&binfo, bquant, i, (cbp & (1 << (5 - i))) != 0, &mut b_mb.blk[i])?;
+                        h263_idct(&mut b_mb.blk[i]);
                     }
-                    if let Some(ref mut b_buf) = bbuf {
-/*                        let is_fwd = false;
-                        if binfo.get_num_mvs() == 1 { //todo scale
-                            let mv_f = MV::add_umv(binfo.get_mv(0), binfo.get_mv2(0), pinfo.get_umv());
-                            let mv_b = ZERO_MV//if component = 0 then scaled else mv_f - component
-                        } else {
-                        }*/
-                        if let Some(ref srcbuf) = self.prev_frm {
-                            copy_blocks(b_buf, srcbuf, mb_x * 16, mb_y * 16, 16, 16, ZERO_MV);
-                            blockdsp::add_blocks(b_buf, mb_x, mb_y, &blk);
+
+                    let is_fwd = binfo.is_b_fwd();
+                    b_mb.fwd = is_fwd;
+                    b_mb.num_mv = binfo.get_num_mvs();
+                    if binfo.get_num_mvs() == 0 {
+                        b_mb.num_mv = 1;
+                        b_mb.mv_f[0] = binfo.get_mv2(1);
+                        b_mb.mv_b[0] = binfo.get_mv2(0);
+                    } if binfo.get_num_mvs() == 1 {
+                        let src_mv = if is_fwd { ZERO_MV } else { binfo.get_mv(0).scale(bsdiff, tsdiff) };
+                        let mv_f = MV::add_umv(src_mv, binfo.get_mv2(0), pinfo.get_umv());
+                        let mv_b = MV::b_sub(binfo.get_mv(0), mv_f, binfo.get_mv2(0), bsdiff, tsdiff);
+                        b_mb.mv_f[0] = mv_f;
+                        b_mb.mv_b[0] = mv_b;
+                    } else {
+                        for blk_no in 0..4 {
+                            let src_mv = if is_fwd { ZERO_MV } else { binfo.get_mv(blk_no).scale(bsdiff, tsdiff) };
+                            let mv_f = MV::add_umv(src_mv, binfo.get_mv2(0), pinfo.get_umv());
+                            let mv_b = MV::b_sub(binfo.get_mv(blk_no), mv_f, binfo.get_mv2(0), bsdiff, tsdiff);
+                            b_mb.mv_f[blk_no] = mv_f;
+                            b_mb.mv_b[blk_no] = mv_b;
                         }
                     }
+                    self.b_data.push(b_mb);
                 }
-
             }
             mvi.update_row();
         }
         self.cur_frm = Some(buf);
-        if pinfo.is_pb() {
-            return Ok(NABufferType::Video(bbuf.unwrap()));
-        } 
-println!("unpacked all");
+        self.last_ts = pinfo.ts;
+//println!("unpacked all");
         Ok(bufinfo)
     }
 
-    pub fn get_stored_pframe(&mut self) -> DecoderResult<NABufferType> {
-        if let Some(_) = self.cur_frm {
-            let buf = self.cur_frm.clone().unwrap();
-            Ok(NABufferType::Video(buf))
-        } else {
-            Err(DecoderError::MissingReference)
+    pub fn get_bframe(&mut self) -> DecoderResult<NABufferType> {
+        if !self.has_b || !self.cur_frm.is_some() || !self.prev_frm.is_some() {
+            return Err(DecoderError::MissingReference);
+        }
+        self.has_b = false;
+
+        let fmt = formats::YUV420_FORMAT;
+        let vinfo = NAVideoInfo::new(self.w, self.h, false, fmt);
+        let bufret = alloc_video_buffer(vinfo, 4);
+        if let Err(_) = bufret { return Err(DecoderError::InvalidData); }
+        let mut bufinfo = bufret.unwrap();
+        let mut b_buf = bufinfo.get_vbuf().unwrap();
+
+        if let Some(ref bck_buf) = self.prev_frm {
+            if let Some(ref fwd_buf) = self.cur_frm {
+                recon_b_frame(&mut b_buf, bck_buf, fwd_buf, self.mb_w, self.mb_h, &self.b_data);
+            }
+        }
+
+        self.b_data.truncate(0);
+        Ok(bufinfo)
+    }
+}
+
+fn recon_b_frame(b_buf: &mut NAVideoBuffer<u8>, bck_buf: &NAVideoBuffer<u8>, fwd_buf: &NAVideoBuffer<u8>,
+                 mb_w: usize, mb_h: usize, b_data: &Vec<BMB>) {
+    let mut cur_mb = 0;
+    for mb_y in 0..mb_h {
+        for mb_x in 0..mb_w {
+            let num_mv = b_data[cur_mb].num_mv;
+            let is_fwd = b_data[cur_mb].fwd;
+            if num_mv == 0 {
+                copy_blocks(b_buf, bck_buf, mb_x * 16, mb_y * 16, 16, 16, ZERO_MV);
+                if !is_fwd {
+                    avg_blocks(b_buf, fwd_buf, mb_x * 16, mb_y * 16, 16, 16, ZERO_MV);
+                }
+            } else if num_mv == 1 {
+                copy_blocks(b_buf, bck_buf, mb_x * 16, mb_y * 16, 16, 16, b_data[cur_mb].mv_f[0]);
+                if !is_fwd {
+                    avg_blocks(b_buf, fwd_buf, mb_x * 16, mb_y * 16, 16, 16, b_data[cur_mb].mv_b[0]);
+                }
+            } else {
+                for blk_no in 0..4 {
+                    let xpos = mb_x * 16 + (blk_no & 1) * 8;
+                    let ypos = mb_y * 16 + (blk_no & 2) * 4;
+                    copy_blocks(b_buf, bck_buf, xpos, ypos, 8, 8, b_data[cur_mb].mv_f[blk_no]);
+                    if !is_fwd {
+                        avg_blocks(b_buf, fwd_buf, xpos, ypos, 8, 8, b_data[cur_mb].mv_b[blk_no]);
+                    }
+                }
+            }
+            if num_mv != 0 && b_data[cur_mb].cbp != 0 {
+                blockdsp::add_blocks(b_buf, mb_x, mb_y, &b_data[cur_mb].blk);
+            }
+            cur_mb += 1;
         }
     }
 }