h264: introduce inline assembly for CABAC get_bit() and bump compiler version
authorKostya Shishkov <kostya.shishkov@gmail.com>
Wed, 26 Jul 2023 16:22:42 +0000 (18:22 +0200)
committerKostya Shishkov <kostya.shishkov@gmail.com>
Wed, 26 Jul 2023 16:22:42 +0000 (18:22 +0200)
nihav-itu/Cargo.toml
nihav-itu/src/codecs/h264/cabac_coder.rs

index 19eaa50965232f9cbd34cf489e313890dcc9fb96..8badb20d49d1a9c288e8b4f48521261093a9a979 100644 (file)
@@ -3,6 +3,7 @@ name = "nihav_itu"
 version = "0.1.0"
 authors = ["Kostya Shishkov <kostya.shishkov@gmail.com>"]
 edition = "2018"
+rust-version = "1.69"
 
 [dependencies.nihav_core]
 path = "../nihav-core"
@@ -14,7 +15,7 @@ path = "../nihav-codec-support"
 nihav_commonfmt = { path = "../nihav-commonfmt", default-features=false, features = ["all_demuxers"] }
 
 [features]
-default = ["all_decoders"]
+default = ["all_decoders", "simd"]
 simd = [] #enable when the default rustc is >=1.62
 
 all_decoders = ["all_video_decoders"]
index baa8caff030757b11561eb2869ba021e1fd91c48..3e9278e1f72c36dc647273bdc00ff9501aec7a9c 100644 (file)
@@ -1,8 +1,11 @@
 use nihav_core::codecs::{DecoderResult, DecoderError};
 use super::slice::SliceType;
+#[cfg(target_arch="x86_64")]
+use std::arch::asm;
 
 const NUM_CABAC_CONTEXTS: usize = 1024;
 
+#[repr(C)]
 pub struct CABAC<'a> {
     pub src:    &'a [u8],
     pub pos:    usize,
@@ -129,6 +132,7 @@ impl<'a> CABAC<'a> {
             false
         }
     }
+    #[cfg(not(target_arch="x86_64"))]
     pub fn decode_bit(&mut self, idx: usize) -> bool {
         let mut val_mps = (self.states[idx] & 0x80) != 0;
         let state_idx = (self.states[idx] & 0x3F) as usize;
@@ -153,6 +157,68 @@ impl<'a> CABAC<'a> {
         self.renorm();
         bit
     }
+    #[cfg(target_arch="x86_64")]
+    pub fn decode_bit(&mut self, idx: usize) -> bool {
+        unsafe {
+            // states offset - 0x18
+            // cod_range offset - 0x418
+            // cod_offset offset - 0x41A
+            let mut bit: u16;
+            asm!(
+                // unpack state
+                "movzx  {state_idx:e},  byte ptr [{ctx} + 0x18 + {idx}]",
+                "mov    {val_mps:x},    {state_idx:x}",
+                "and    {state_idx},    0x3F",
+                "and    {val_mps:r},    0x80",
+                "movzx  {tmp},          word ptr [{ctx} + 0x418]",
+                "mov    {bit:r},        {val_mps:r}",
+                "shr    {tmp},          6",
+                "and    {tmp},          3",
+                "lea    {range_lps:r},  {range_tab}[rip]",
+                "lea    {range_lps:r},  [{range_lps:r} + {state_idx} * 4]",
+                "movzx  {range_lps:x},  byte ptr [{range_lps:r} + {tmp}]",
+                // self.cod_range -= range_lps;
+                "sub    word ptr [{ctx} + 0x418], {range_lps:x}",
+                // determine bit value
+                "mov    {tmp:x},        word ptr [{ctx} + 0x41A]",
+                "cmp    {tmp:x},        word ptr [{ctx} + 0x418]",
+                "jl     1f",
+                "sub    {tmp:x},        word ptr [{ctx} + 0x418]",
+                "mov    word ptr [{ctx} + 0x418], {range_lps:x}",
+                "mov    word ptr [{ctx} + 0x41A], {tmp:x}",
+                "xor    {bit:l},        0x80",
+                "1:",
+                // update state[idx]
+                "cmp    {bit:x},        {val_mps:x}",
+                "jne    2f",
+                "lea    {tmp},          {trans_idx_mps}[rip]",
+                "jmp    3f",
+                "2:",
+                "lea    {tmp},          {trans_idx_lps}[rip]",
+                "cmp    {state_idx},    0",
+                "jnz    3f",
+                "xor    {val_mps:x},    0x80",
+                "3:",
+                "movzx  {tmp},          byte ptr [{tmp} + {state_idx}]",
+                "or     {tmp:x},        {val_mps:x}",
+                "mov    byte ptr [{ctx} + 0x18 + {idx}], {tmp:l}",
+
+                ctx = inout(reg) self => _,
+                idx = inout(reg) idx => _,
+                bit = out(reg) bit,
+                range_tab = sym RANGE_TBL_LPS,
+                trans_idx_mps = sym TRANS_IDX_MPS,
+                trans_idx_lps = sym TRANS_IDX_LPS,
+                val_mps = out(reg) _,
+                state_idx = out(reg) _,
+                tmp = out(reg) _,
+                range_lps = out(reg) _,
+            );
+
+            self.renorm();
+            bit != 0
+        }
+    }
     pub fn decode_bits(&mut self, mut start: usize, maxidx: usize, len: usize) -> u8 {
         let mut val = 0;
         for _ in 0..len {
@@ -195,7 +261,7 @@ impl<'a> CABAC<'a> {
     }
 }
 
-const RANGE_TBL_LPS: [u8; 64 * 4] = [
+static RANGE_TBL_LPS: [u8; 64 * 4] = [
     128, 176, 208, 240,
     128, 167, 197, 227,
     128, 158, 187, 216,
@@ -261,13 +327,13 @@ const RANGE_TBL_LPS: [u8; 64 * 4] = [
       6,   7,   8,   9,
       2,   2,   2,   2
 ];
-const TRANS_IDX_MPS: [u8; 64] = [
+static TRANS_IDX_MPS: [u8; 64] = [
      1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
     17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
     33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
     49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 62, 63
 ];
-const TRANS_IDX_LPS: [u8; 64] = [
+static TRANS_IDX_LPS: [u8; 64] = [
      0,  0,  1,  2,  2,  4,  4,  5,  6,  7,  8,  9,  9, 11, 11, 12,
     13, 13, 15, 15, 16, 16, 18, 18, 19, 19, 21, 21, 22, 22, 23, 24,
     24, 25, 26, 26, 27, 27, 28, 29, 29, 30, 30, 30, 31, 32, 32, 33,