core/compr: make a faster version of Inflate::uncompress()
authorKostya Shishkov <kostya.shishkov@gmail.com>
Fri, 23 Dec 2022 07:39:09 +0000 (08:39 +0100)
committerKostya Shishkov <kostya.shishkov@gmail.com>
Fri, 23 Dec 2022 07:39:09 +0000 (08:39 +0100)
nihav-core/src/compr/deflate.rs

index 71e0dfed331d8292ac9675d74e06a1c77c8a67ef..8fc6a13e873f534f93bd35ef575f95012b50ea77 100644 (file)
@@ -172,6 +172,23 @@ impl<'a> CurrentSource<'a> {
         }
         Ok(())
     }
+    fn skip_bytes(&mut self, nbytes: usize) -> BitReaderResult<()> {
+        self.align();
+        let cached = usize::from(self.br.bits / 8);
+        if nbytes <= cached {
+            self.skip((nbytes as u32) * 8)?;
+        } else {
+            self.skip((cached as u32) * 8)?;
+            self.br.bits = 0;
+            self.br.bitbuf = 0;
+            self.br.pos += nbytes - cached;
+            if self.br.pos > self.src.len() {
+                return Err(BitReaderError::BitstreamEnd);
+            }
+            self.refill();
+        }
+        Ok(())
+    }
     fn align(&mut self) {
         let b = self.br.bits & 7;
         if b != 0 {
@@ -181,6 +198,9 @@ impl<'a> CurrentSource<'a> {
     fn left(&self) -> isize {
         ((self.src.len() as isize) - (self.br.pos as isize)) * 8 + (self.br.bits as isize)
     }
+    fn tell(&self) -> usize {
+        self.br.pos - usize::from(self.br.bits / 8)
+    }
 }
 
 impl<'a, S: Copy> CodebookReader<S> for CurrentSource<'a> {
@@ -772,8 +792,210 @@ impl Inflate {
 
     ///! Decompresses input data into output returning the uncompressed data length.
     pub fn uncompress(src: &[u8], dst: &mut [u8]) -> DecompressResult<usize> {
-        let mut inflate = Self::new();
-        inflate.decompress_data(src, dst, false)
+        let mut csrc = CurrentSource::new(src, BitReaderState::default());
+        if src.len() > 2 {
+            let cm    = src[0] & 0xF;
+            let cinfo = src[0] >> 4;
+            let hdr   = (u16::from(src[0]) << 8) | u16::from(src[1]);
+            if cm == 8 && cinfo <= 7 && (hdr % 31) == 0 {
+                csrc.skip(16).unwrap();
+            }
+        }
+
+        let mut fix_len_cb = None;
+
+        let mut dst_idx = 0;
+        let mut final_block = false;
+        while !final_block {
+            final_block = csrc.read_bool()?;
+
+            let bmode = csrc.read(2)?;
+            match bmode {
+                0 => {
+                                  csrc.align();
+                    let len     = csrc.read(16)? as usize;
+                    let inv_len = csrc.read(16)? as usize;
+                    if (len ^ inv_len) != 0xFFFF {
+                        return Err(DecompressError::InvalidHeader);
+                    }
+                    let src_pos = csrc.tell();
+                    if src_pos + len > src.len() {
+                        return Err(DecompressError::ShortData);
+                    }
+                    if dst_idx + len > dst.len() {
+                        return Err(DecompressError::OutputFull);
+                    }
+                    dst[dst_idx..][..len].copy_from_slice(&src[src_pos..][..len]);
+                    dst_idx += len;
+                                  csrc.skip_bytes(len)?;
+                },
+                1 => {
+                    if fix_len_cb.is_none() {
+                        let mut cr = FixedLenCodeReader {};
+                        fix_len_cb = Some(Codebook::new(&mut cr, CodebookMode::LSB).unwrap());
+                    }
+                    if let Some(ref len_cb) = &fix_len_cb {
+                        loop {
+                            let val = csrc.read_cb(len_cb)?;
+                            if val < 256 {
+                                if dst_idx >= dst.len() {
+                                    return Err(DecompressError::OutputFull);
+                                }
+                                dst[dst_idx] = val as u8;
+                                dst_idx += 1;
+                            } else if val == 256 {
+                                break;
+                            } else {
+                                let len_idx = (val - 257) as usize;
+                                if len_idx >= LENGTH_BASE.len() {
+                                    return Err(DecompressError::InvalidData);
+                                }
+                                let len_bits = LENGTH_ADD_BITS[len_idx];
+                                let mut length = LENGTH_BASE[len_idx] as usize;
+                                if len_bits > 0 {
+                                    length += csrc.read(len_bits)? as usize;
+                                }
+                                let dist_idx = reverse_bits(csrc.read(5)?, 5) as usize;
+                                if dist_idx >= DIST_BASE.len() {
+                                    return Err(DecompressError::InvalidData);
+                                }
+                                let dist_bits = DIST_ADD_BITS[dist_idx];
+                                let mut dist = DIST_BASE[dist_idx] as usize;
+                                if dist_bits > 0 {
+                                    dist += csrc.read(dist_bits)? as usize;
+                                }
+
+                                if dst_idx + length > dst.len() {
+                                    return Err(DecompressError::OutputFull);
+                                }
+                                if dist > dst_idx {
+                                    return Err(DecompressError::InvalidData);
+                                }
+                                lz_copy(dst, dst_idx, dist, length);
+                                dst_idx += length;
+                            }
+                        }
+                    } else {
+                        unreachable!();
+                    }
+                },
+                2 => {
+                    let hlit = csrc.read(5)? as usize + 257;
+                    if hlit >= 287 {
+                        return Err(DecompressError::InvalidHeader);
+                    }
+                    let hdist = csrc.read(5)? as usize + 1;
+                    let hclen = csrc.read(4)? as usize + 4;
+                    let mut cur_len_idx = 0;
+                    let mut len_lengths = [0; 19];
+                    let mut all_lengths = [0; NUM_LITERALS + NUM_DISTS];
+
+                    for _ in 0..hclen {
+                        len_lengths[LEN_RECODE[cur_len_idx]] = csrc.read(3)? as u8;
+                        cur_len_idx += 1;
+                    }
+                    let mut len_codes = [ShortCodebookDesc { code: 0, bits: 0 }; 19];
+                    lengths_to_codes(&len_lengths, &mut len_codes)?;
+                    let mut cr = ShortCodebookDescReader::new(len_codes.to_vec());
+                    let ret = Codebook::new(&mut cr, CodebookMode::LSB);
+                    if ret.is_err() {
+                        return Err(DecompressError::InvalidHeader);
+                    }
+                    let dyn_len_cb = ret.unwrap();
+
+                    let mut cur_len_idx = 0;
+                    while cur_len_idx < hlit + hdist {
+                        let val = csrc.read_cb(&dyn_len_cb)?;
+                        if val < 16 {
+                            all_lengths[cur_len_idx] = val as u8;
+                            cur_len_idx += 1;
+                        } else {
+                            let mode = (val as usize) - 16;
+                            if mode > 2 {
+                                return Err(DecompressError::InvalidHeader);
+                            }
+                            let base = REPEAT_BASE[mode] as usize;
+                            let bits = REPEAT_BITS[mode];
+                            let len = base + (csrc.read(bits)? as usize);
+                            if cur_len_idx + len > hlit + hdist {
+                                return Err(DecompressError::InvalidHeader);
+                            }
+                            let rpt = if mode == 0 {
+                                    if cur_len_idx == 0 {
+                                        return Err(DecompressError::InvalidHeader);
+                                    }
+                                    all_lengths[cur_len_idx - 1]
+                                } else {
+                                    0
+                                };
+                            for _ in 0..len {
+                                all_lengths[cur_len_idx] = rpt;
+                                cur_len_idx += 1;
+                            }
+                        }
+                    }
+                    let (lit_lengths, dist_lengths) = all_lengths.split_at(hlit);
+
+                    let mut lit_codes = [ShortCodebookDesc { code: 0, bits: 0 }; NUM_LITERALS];
+                    lengths_to_codes(lit_lengths, &mut lit_codes)?;
+                    let mut cr = ShortCodebookDescReader::new(lit_codes.to_vec());
+                    let ret = Codebook::new(&mut cr, CodebookMode::LSB);
+                    if ret.is_err() { return Err(DecompressError::InvalidHeader); }
+                    let dyn_lit_cb = ret.unwrap();
+
+                    let mut dist_codes = [ShortCodebookDesc { code: 0, bits: 0 }; NUM_DISTS];
+                    lengths_to_codes(&dist_lengths[..hdist], &mut dist_codes)?;
+                    let mut cr = ShortCodebookDescReader::new(dist_codes.to_vec());
+                    let ret = Codebook::new(&mut cr, CodebookMode::LSB);
+                    if ret.is_err() { return Err(DecompressError::InvalidHeader); }
+                    let dyn_dist_cb = ret.unwrap();
+
+                    loop {
+                        let val = csrc.read_cb(&dyn_lit_cb)?;
+                        if val < 256 {
+                            if dst_idx >= dst.len() {
+                                return Err(DecompressError::OutputFull);
+                            }
+                            dst[dst_idx] = val as u8;
+                            dst_idx += 1;
+                        } else if val == 256 {
+                            break;
+                        } else {
+                            let len_idx = (val - 257) as usize;
+                            if len_idx >= LENGTH_BASE.len() {
+                                return Err(DecompressError::InvalidData);
+                            }
+                            let len_bits = LENGTH_ADD_BITS[len_idx];
+                            let mut length = LENGTH_BASE[len_idx] as usize;
+                            if len_bits > 0 {
+                                length += csrc.read(len_bits)? as usize;
+                            }
+
+                            let dist_idx = csrc.read_cb(&dyn_dist_cb)? as usize;
+                            if dist_idx >= DIST_BASE.len() {
+                                return Err(DecompressError::InvalidData);
+                            }
+                            let dist_bits = DIST_ADD_BITS[dist_idx];
+                            let mut dist = DIST_BASE[dist_idx] as usize;
+                            if dist_bits > 0 {
+                                dist += csrc.read(dist_bits)? as usize;
+                            }
+
+                            if dst_idx + length > dst.len() {
+                                return Err(DecompressError::OutputFull);
+                            }
+                            if dist > dst_idx {
+                                return Err(DecompressError::InvalidData);
+                            }
+                            lz_copy(dst, dst_idx, dist, length);
+                            dst_idx += length;
+                        }
+                    }
+                },
+                _ => return Err(DecompressError::InvalidHeader),
+            };
+        }
+        Ok(dst_idx)
     }
 }