h264: introduce inline assembly for CABAC get_bit() and bump compiler version
[nihav.git] / nihav-itu / src / codecs / h264 / cabac_coder.rs
index baa8caff030757b11561eb2869ba021e1fd91c48..3e9278e1f72c36dc647273bdc00ff9501aec7a9c 100644 (file)
@@ -1,8 +1,11 @@
 use nihav_core::codecs::{DecoderResult, DecoderError};
 use super::slice::SliceType;
+#[cfg(target_arch="x86_64")]
+use std::arch::asm;
 
 const NUM_CABAC_CONTEXTS: usize = 1024;
 
+#[repr(C)]
 pub struct CABAC<'a> {
     pub src:    &'a [u8],
     pub pos:    usize,
@@ -129,6 +132,7 @@ impl<'a> CABAC<'a> {
             false
         }
     }
+    #[cfg(not(target_arch="x86_64"))]
     pub fn decode_bit(&mut self, idx: usize) -> bool {
         let mut val_mps = (self.states[idx] & 0x80) != 0;
         let state_idx = (self.states[idx] & 0x3F) as usize;
@@ -153,6 +157,68 @@ impl<'a> CABAC<'a> {
         self.renorm();
         bit
     }
+    #[cfg(target_arch="x86_64")]
+    pub fn decode_bit(&mut self, idx: usize) -> bool {
+        unsafe {
+            // states offset - 0x18
+            // cod_range offset - 0x418
+            // cod_offset offset - 0x41A
+            let mut bit: u16;
+            asm!(
+                // unpack state
+                "movzx  {state_idx:e},  byte ptr [{ctx} + 0x18 + {idx}]",
+                "mov    {val_mps:x},    {state_idx:x}",
+                "and    {state_idx},    0x3F",
+                "and    {val_mps:r},    0x80",
+                "movzx  {tmp},          word ptr [{ctx} + 0x418]",
+                "mov    {bit:r},        {val_mps:r}",
+                "shr    {tmp},          6",
+                "and    {tmp},          3",
+                "lea    {range_lps:r},  {range_tab}[rip]",
+                "lea    {range_lps:r},  [{range_lps:r} + {state_idx} * 4]",
+                "movzx  {range_lps:x},  byte ptr [{range_lps:r} + {tmp}]",
+                // self.cod_range -= range_lps;
+                "sub    word ptr [{ctx} + 0x418], {range_lps:x}",
+                // determine bit value
+                "mov    {tmp:x},        word ptr [{ctx} + 0x41A]",
+                "cmp    {tmp:x},        word ptr [{ctx} + 0x418]",
+                "jl     1f",
+                "sub    {tmp:x},        word ptr [{ctx} + 0x418]",
+                "mov    word ptr [{ctx} + 0x418], {range_lps:x}",
+                "mov    word ptr [{ctx} + 0x41A], {tmp:x}",
+                "xor    {bit:l},        0x80",
+                "1:",
+                // update state[idx]
+                "cmp    {bit:x},        {val_mps:x}",
+                "jne    2f",
+                "lea    {tmp},          {trans_idx_mps}[rip]",
+                "jmp    3f",
+                "2:",
+                "lea    {tmp},          {trans_idx_lps}[rip]",
+                "cmp    {state_idx},    0",
+                "jnz    3f",
+                "xor    {val_mps:x},    0x80",
+                "3:",
+                "movzx  {tmp},          byte ptr [{tmp} + {state_idx}]",
+                "or     {tmp:x},        {val_mps:x}",
+                "mov    byte ptr [{ctx} + 0x18 + {idx}], {tmp:l}",
+
+                ctx = inout(reg) self => _,
+                idx = inout(reg) idx => _,
+                bit = out(reg) bit,
+                range_tab = sym RANGE_TBL_LPS,
+                trans_idx_mps = sym TRANS_IDX_MPS,
+                trans_idx_lps = sym TRANS_IDX_LPS,
+                val_mps = out(reg) _,
+                state_idx = out(reg) _,
+                tmp = out(reg) _,
+                range_lps = out(reg) _,
+            );
+
+            self.renorm();
+            bit != 0
+        }
+    }
     pub fn decode_bits(&mut self, mut start: usize, maxidx: usize, len: usize) -> u8 {
         let mut val = 0;
         for _ in 0..len {
@@ -195,7 +261,7 @@ impl<'a> CABAC<'a> {
     }
 }
 
-const RANGE_TBL_LPS: [u8; 64 * 4] = [
+static RANGE_TBL_LPS: [u8; 64 * 4] = [
     128, 176, 208, 240,
     128, 167, 197, 227,
     128, 158, 187, 216,
@@ -261,13 +327,13 @@ const RANGE_TBL_LPS: [u8; 64 * 4] = [
       6,   7,   8,   9,
       2,   2,   2,   2
 ];
-const TRANS_IDX_MPS: [u8; 64] = [
+static TRANS_IDX_MPS: [u8; 64] = [
      1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
     17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
     33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
     49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 62, 63
 ];
-const TRANS_IDX_LPS: [u8; 64] = [
+static TRANS_IDX_LPS: [u8; 64] = [
      0,  0,  1,  2,  2,  4,  4,  5,  6,  7,  8,  9,  9, 11, 11, 12,
     13, 13, 15, 15, 16, 16, 18, 18, 19, 19, 21, 21, 22, 22, 23, 24,
     24, 25, 26, 26, 27, 27, 28, 29, 29, 30, 30, 30, 31, 32, 32, 33,