]> git.nihav.org Git - nihav.git/blobdiff - nihav-core/src/io/byteio.rs
core/io: make bitstream reader clone()able
[nihav.git] / nihav-core / src / io / byteio.rs
index 880e89b49191f3d06e828d4e691e3744136843a1..a10ccb9d0eef6f3f5169bf8895941bf6bbc46f84 100644 (file)
@@ -1,6 +1,7 @@
-use std::io::SeekFrom;
+pub use std::io::SeekFrom;
 use std::fs::File;
 use std::io::prelude::*;
+use std::ptr;
 
 #[derive(Debug)]
 pub enum ByteIOError {
@@ -24,7 +25,7 @@ pub trait ByteIO {
     fn write_buf(&mut self, buf: &[u8]) -> ByteIOResult<()>;
     fn tell(&mut self) -> u64;
     fn seek(&mut self, pos: SeekFrom) -> ByteIOResult<u64>;
-    fn is_eof(&mut self) -> bool;
+    fn is_eof(&self) -> bool;
     fn is_seekable(&mut self) -> bool;
     fn size(&mut self) -> i64;
 }
@@ -36,7 +37,6 @@ pub struct ByteReader<'a> {
 
 pub struct MemoryReader<'a> {
     buf:      &'a [u8],
-    size:     usize,
     pos:      usize,
 }
 
@@ -47,20 +47,20 @@ pub struct FileReader<'a> {
 
 macro_rules! read_int {
     ($s: ident, $inttype: ty, $size: expr, $which: ident) => ({
-        let mut buf = [0; $size];
-        $s.read_buf(&mut buf)?;
         unsafe {
-            Ok((*(buf.as_ptr() as *const $inttype)).$which())
+            let mut buf: $inttype = 0;
+            $s.read_buf(&mut *(&mut buf as *mut $inttype as *mut [u8; $size]))?;
+            Ok(buf.$which())
         }
     })
 }
 
 macro_rules! peek_int {
     ($s: ident, $inttype: ty, $size: expr, $which: ident) => ({
-        let mut buf = [0; $size];
-        $s.peek_buf(&mut buf)?;
         unsafe {
-            Ok((*(buf.as_ptr() as *const $inttype)).$which())
+            let mut buf: $inttype = 0;
+            $s.peek_buf(&mut *(&mut buf as *mut $inttype as *mut [u8; $size]))?;
+            Ok(buf.$which())
         }
     })
 }
@@ -70,7 +70,9 @@ macro_rules! read_int_func {
         pub fn $s(src: &[u8]) -> ByteIOResult<$inttype> {
             if src.len() < $size { return Err(ByteIOError::ReadError); }
             unsafe {
-                Ok((*(src.as_ptr() as *const $inttype)).$which())
+                let mut buf: $inttype = 0;
+                ptr::copy_nonoverlapping(src.as_ptr(), &mut buf as *mut $inttype as *mut u8, std::mem::size_of::<$inttype>());
+                Ok(buf.$which())
             }
         }
     }
@@ -85,15 +87,58 @@ read_int_func!(read_u64le, u64, 8, to_le);
 
 pub fn read_u24be(src: &[u8]) -> ByteIOResult<u32> {
     if src.len() < 3 { return Err(ByteIOError::ReadError); }
-    Ok(((src[0] as u32) << 16) | ((src[1] as u32) << 8) | (src[2] as u32))
+    Ok((u32::from(src[0]) << 16) | (u32::from(src[1]) << 8) | u32::from(src[2]))
 }
 pub fn read_u24le(src: &[u8]) -> ByteIOResult<u32> {
     if src.len() < 3 { return Err(ByteIOError::ReadError); }
-    Ok(((src[2] as u32) << 16) | ((src[1] as u32) << 8) | (src[0] as u32))
+    Ok((u32::from(src[2]) << 16) | (u32::from(src[1]) << 8) | u32::from(src[0]))
 }
+pub fn read_f32be(src: &[u8]) -> ByteIOResult<f32> { Ok(f32::from_bits(read_u32be(src)?)) }
+pub fn read_f32le(src: &[u8]) -> ByteIOResult<f32> { Ok(f32::from_bits(read_u32le(src)?)) }
+pub fn read_f64be(src: &[u8]) -> ByteIOResult<f64> { Ok(f64::from_bits(read_u64be(src)?)) }
+pub fn read_f64le(src: &[u8]) -> ByteIOResult<f64> { Ok(f64::from_bits(read_u64le(src)?)) }
+
+macro_rules! write_int_func {
+    ($s: ident, $inttype: ty, $size: expr, $which: ident) => {
+        pub fn $s(dst: &mut [u8], val: $inttype) -> ByteIOResult<()> {
+            if dst.len() < $size { return Err(ByteIOError::WriteError); }
+            unsafe {
+                let val = val.$which();
+                ptr::copy_nonoverlapping(&val as *const $inttype as *const u8, dst.as_mut_ptr(), std::mem::size_of::<$inttype>());
+            }
+            Ok(())
+        }
+    }
+}
+
+write_int_func!(write_u16be, u16, 2, to_be);
+write_int_func!(write_u16le, u16, 2, to_le);
+write_int_func!(write_u32be, u32, 4, to_be);
+write_int_func!(write_u32le, u32, 4, to_le);
+write_int_func!(write_u64be, u64, 8, to_be);
+write_int_func!(write_u64le, u64, 8, to_le);
+
+pub fn write_u24be(dst: &mut [u8], val: u32) -> ByteIOResult<()> {
+    if dst.len() < 3 { return Err(ByteIOError::WriteError); }
+    dst[0] = (val >> 16) as u8;
+    dst[1] = (val >>  8) as u8;
+    dst[2] = (val >>  0) as u8;
+    Ok(())
+}
+pub fn write_u24le(dst: &mut [u8], val: u32) -> ByteIOResult<()> {
+    if dst.len() < 3 { return Err(ByteIOError::WriteError); }
+    dst[0] = (val >>  0) as u8;
+    dst[1] = (val >>  8) as u8;
+    dst[2] = (val >> 16) as u8;
+    Ok(())
+}
+pub fn write_f32be(dst: &mut [u8], val: f32) -> ByteIOResult<()> { write_u32be(dst, val.to_bits()) }
+pub fn write_f32le(dst: &mut [u8], val: f32) -> ByteIOResult<()> { write_u32le(dst, val.to_bits()) }
+pub fn write_f64be(dst: &mut [u8], val: f64) -> ByteIOResult<()> { write_u64be(dst, val.to_bits()) }
+pub fn write_f64le(dst: &mut [u8], val: f64) -> ByteIOResult<()> { write_u64le(dst, val.to_bits()) }
 
 impl<'a> ByteReader<'a> {
-    pub fn new(io: &'a mut ByteIO) -> Self { ByteReader { io: io } }
+    pub fn new(io: &'a mut ByteIO) -> Self { ByteReader { io } }
 
     pub fn read_buf(&mut self, buf: &mut [u8])  -> ByteIOResult<usize> {
         self.io.read_buf(buf)
@@ -126,13 +171,13 @@ impl<'a> ByteReader<'a> {
     pub fn read_u24be(&mut self) -> ByteIOResult<u32> {
         let p16 = self.read_u16be()?;
         let p8 = self.read_byte()?;
-        Ok(((p16 as u32) << 8) | (p8 as u32))
+        Ok((u32::from(p16) << 8) | u32::from(p8))
     }
 
     pub fn peek_u24be(&mut self) -> ByteIOResult<u32> {
         let mut src: [u8; 3] = [0; 3];
         self.peek_buf(&mut src)?;
-        Ok(((src[0] as u32) << 16) | ((src[1] as u32) << 8) | (src[2] as u32))
+        Ok((u32::from(src[0]) << 16) | (u32::from(src[1]) << 8) | u32::from(src[2]))
     }
 
     pub fn read_u32be(&mut self) -> ByteIOResult<u32> {
@@ -151,6 +196,22 @@ impl<'a> ByteReader<'a> {
         peek_int!(self, u64, 8, to_be)
     }
 
+    pub fn read_f32be(&mut self) -> ByteIOResult<f32> {
+        Ok(f32::from_bits(self.read_u32be()?))
+    }
+
+    pub fn peek_f32be(&mut self) -> ByteIOResult<f32> {
+        Ok(f32::from_bits(self.peek_u32be()?))
+    }
+
+    pub fn read_f64be(&mut self) -> ByteIOResult<f64> {
+        Ok(f64::from_bits(self.read_u64be()?))
+    }
+
+    pub fn peek_f64be(&mut self) -> ByteIOResult<f64> {
+        Ok(f64::from_bits(self.peek_u64be()?))
+    }
+
     pub fn read_u16le(&mut self) -> ByteIOResult<u16> {
         read_int!(self, u16, 2, to_le)
     }
@@ -162,13 +223,13 @@ impl<'a> ByteReader<'a> {
     pub fn read_u24le(&mut self) -> ByteIOResult<u32> {
         let p8 = self.read_byte()?;
         let p16 = self.read_u16le()?;
-        Ok(((p16 as u32) << 8) | (p8 as u32))
+        Ok((u32::from(p16) << 8) | u32::from(p8))
     }
 
     pub fn peek_u24le(&mut self) -> ByteIOResult<u32> {
         let mut src: [u8; 3] = [0; 3];
         self.peek_buf(&mut src)?;
-        Ok((src[0] as u32) | ((src[1] as u32) << 8) | ((src[2] as u32) << 16))
+        Ok(u32::from(src[0]) | (u32::from(src[1]) << 8) | (u32::from(src[2]) << 16))
     }
 
     pub fn read_u32le(&mut self) -> ByteIOResult<u32> {
@@ -187,6 +248,22 @@ impl<'a> ByteReader<'a> {
         peek_int!(self, u64, 8, to_le)
     }
 
+    pub fn read_f32le(&mut self) -> ByteIOResult<f32> {
+        Ok(f32::from_bits(self.read_u32le()?))
+    }
+
+    pub fn peek_f32le(&mut self) -> ByteIOResult<f32> {
+        Ok(f32::from_bits(self.peek_u32le()?))
+    }
+
+    pub fn read_f64le(&mut self) -> ByteIOResult<f64> {
+        Ok(f64::from_bits(self.read_u64le()?))
+    }
+
+    pub fn peek_f64le(&mut self) -> ByteIOResult<f64> {
+        Ok(f64::from_bits(self.peek_u64le()?))
+    }
+
     pub fn read_skip(&mut self, len: usize) -> ByteIOResult<()> {
         if self.io.is_seekable() {
             self.io.seek(SeekFrom::Current(len as i64))?;
@@ -200,7 +277,7 @@ impl<'a> ByteReader<'a> {
             }
             while ssize > 0 {
                 self.io.read_byte()?;
-                ssize = ssize - 1;
+                ssize -= 1;
             }
         }
         Ok(())
@@ -214,7 +291,7 @@ impl<'a> ByteReader<'a> {
         self.io.seek(pos)
     }
 
-    pub fn is_eof(&mut self) -> bool {
+    pub fn is_eof(&self) -> bool {
         self.io.is_eof()
     }
 
@@ -225,19 +302,19 @@ impl<'a> ByteReader<'a> {
     pub fn left(&mut self) -> i64 {
         let size = self.io.size();
         if size == -1 { return -1; }
-        return size - (self.io.tell() as i64)
+        size - (self.io.tell() as i64)
     }
 }
 
 impl<'a> MemoryReader<'a> {
 
     pub fn new_read(buf: &'a [u8]) -> Self {
-        MemoryReader { buf: buf, size: buf.len(), pos: 0 }
+        MemoryReader { buf, pos: 0 }
     }
 
     fn real_seek(&mut self, pos: i64) -> ByteIOResult<u64> {
-        if pos < 0 || (pos as usize) > self.size {
-            return Err(ByteIOError::WrongRange)
+        if pos < 0 || (pos as usize) > self.buf.len() {
+            return Err(ByteIOError::WrongRange);
         }
         self.pos = pos as usize;
         Ok(pos as u64)
@@ -248,7 +325,7 @@ impl<'a> ByteIO for MemoryReader<'a> {
     fn read_byte(&mut self) -> ByteIOResult<u8> {
         if self.is_eof() { return Err(ByteIOError::EOF); }
         let res = self.buf[self.pos];
-        self.pos = self.pos + 1;
+        self.pos += 1;
         Ok(res)
     }
 
@@ -258,11 +335,10 @@ impl<'a> ByteIO for MemoryReader<'a> {
     }
 
     fn peek_buf(&mut self, buf: &mut [u8]) -> ByteIOResult<usize> {
-        let copy_size = if self.size - self.pos < buf.len() { self.size } else { buf.len() };
+        let copy_size = if self.buf.len() - self.pos < buf.len() { self.buf.len() } else { buf.len() };
         if copy_size == 0 { return Err(ByteIOError::EOF); }
-        for i in 0..copy_size {
-            buf[i] = self.buf[self.pos + i];
-        }
+        let dst = &mut buf[0..copy_size];
+        dst.copy_from_slice(&self.buf[self.pos..][..copy_size]);
         Ok(copy_size)
     }
 
@@ -289,8 +365,8 @@ impl<'a> ByteIO for MemoryReader<'a> {
     }
 
     fn seek(&mut self, pos: SeekFrom) -> ByteIOResult<u64> {
-        let cur_pos  = self.pos  as i64;
-        let cur_size = self.size as i64;
+        let cur_pos  = self.pos       as i64;
+        let cur_size = self.buf.len() as i64;
         match pos {
             SeekFrom::Start(x)   => self.real_seek(x as i64),
             SeekFrom::Current(x) => self.real_seek(cur_pos + x),
@@ -298,8 +374,8 @@ impl<'a> ByteIO for MemoryReader<'a> {
         }
     }
 
-    fn is_eof(&mut self) -> bool {
-        self.pos >= self.size
+    fn is_eof(&self) -> bool {
+        self.pos >= self.buf.len()
     }
 
     fn is_seekable(&mut self) -> bool {
@@ -314,16 +390,16 @@ impl<'a> ByteIO for MemoryReader<'a> {
 impl<'a> FileReader<'a> {
 
     pub fn new_read(file: &'a mut File) -> Self {
-        FileReader { file: file, eof : false }
+        FileReader { file, eof : false }
     }
 }
 
 impl<'a> ByteIO for FileReader<'a> {
     fn read_byte(&mut self) -> ByteIOResult<u8> {
         let mut byte : [u8; 1] = [0];
-        let err = self.file.read(&mut byte);
-        if let Err(_) = err { return Err(ByteIOError::ReadError); }
-        let sz = err.unwrap();
+        let ret = self.file.read(&mut byte);
+        if ret.is_err() { return Err(ByteIOError::ReadError); }
+        let sz = ret.unwrap();
         if sz == 0 { self.eof = true; return Err(ByteIOError::EOF); }
         Ok (byte[0])
     }
@@ -335,17 +411,17 @@ impl<'a> ByteIO for FileReader<'a> {
     }
 
     fn read_buf(&mut self, buf: &mut [u8]) -> ByteIOResult<usize> {
-        let res = self.file.read(buf);
-        if let Err(_) = res { return Err(ByteIOError::ReadError); }
-        let sz = res.unwrap();
+        let ret = self.file.read(buf);
+        if ret.is_err() { return Err(ByteIOError::ReadError); }
+        let sz = ret.unwrap();
         if sz < buf.len() { self.eof = true; return Err(ByteIOError::EOF); }
         Ok(sz)
     }
 
     fn read_buf_some(&mut self, buf: &mut [u8]) -> ByteIOResult<usize> {
-        let res = self.file.read(buf);
-        if let Err(_) = res { return Err(ByteIOError::ReadError); }
-        let sz = res.unwrap();
+        let ret = self.file.read(buf);
+        if ret.is_err() { return Err(ByteIOError::ReadError); }
+        let sz = ret.unwrap();
         if sz < buf.len() { self.eof = true; }
         Ok(sz)
     }
@@ -373,7 +449,7 @@ impl<'a> ByteIO for FileReader<'a> {
         }
     }
 
-    fn is_eof(&mut self) -> bool {
+    fn is_eof(&self) -> bool {
         self.eof
     }
 
@@ -393,7 +469,6 @@ pub struct ByteWriter<'a> {
 
 pub struct MemoryWriter<'a> {
     buf:      &'a mut [u8],
-    size:     usize,
     pos:      usize,
 }
 
@@ -402,7 +477,7 @@ pub struct FileWriter {
 }
 
 impl<'a> ByteWriter<'a> {
-    pub fn new(io: &'a mut ByteIO) -> Self { ByteWriter { io: io } }
+    pub fn new(io: &'a mut ByteIO) -> Self { ByteWriter { io } }
 
     pub fn write_buf(&mut self, buf: &[u8])  -> ByteIOResult<()> {
         self.io.write_buf(buf)
@@ -444,13 +519,29 @@ impl<'a> ByteWriter<'a> {
     }
 
     pub fn write_u64be(&mut self, val: u64) -> ByteIOResult<()> {
-        self.write_u32be(((val >> 32) & 0xFFFFFFFF) as u32)?;
-        self.write_u32be((val & 0xFFFFFFFF) as u32)
+        self.write_u32be((val >> 32) as u32)?;
+        self.write_u32be(val as u32)
     }
 
     pub fn write_u64le(&mut self, val: u64) -> ByteIOResult<()> {
-        self.write_u32le((val & 0xFFFFFFFF) as u32)?;
-        self.write_u32le(((val >> 32) & 0xFFFFFFFF) as u32)
+        self.write_u32le(val as u32)?;
+        self.write_u32le((val >> 32) as u32)
+    }
+
+    pub fn write_f32be(&mut self, val: f32) -> ByteIOResult<()> {
+        self.write_u32be(val.to_bits())
+    }
+
+    pub fn write_f32le(&mut self, val: f32) -> ByteIOResult<()> {
+        self.write_u32le(val.to_bits())
+    }
+
+    pub fn write_f64be(&mut self, val: f64) -> ByteIOResult<()> {
+        self.write_u64be(val.to_bits())
+    }
+
+    pub fn write_f64le(&mut self, val: f64) -> ByteIOResult<()> {
+        self.write_u64le(val.to_bits())
     }
 
     pub fn tell(&mut self) -> u64 {
@@ -471,12 +562,11 @@ impl<'a> ByteWriter<'a> {
 impl<'a> MemoryWriter<'a> {
 
     pub fn new_write(buf: &'a mut [u8]) -> Self {
-        let len = buf.len();
-        MemoryWriter { buf: buf, size: len, pos: 0 }
+        MemoryWriter { buf, pos: 0 }
     }
 
     fn real_seek(&mut self, pos: i64) -> ByteIOResult<u64> {
-        if pos < 0 || (pos as usize) > self.size {
+        if pos < 0 || (pos as usize) > self.buf.len() {
             return Err(ByteIOError::WrongRange)
         }
         self.pos = pos as usize;
@@ -511,7 +601,7 @@ impl<'a> ByteIO for MemoryWriter<'a> {
     }
 
     fn write_buf(&mut self, buf: &[u8]) -> ByteIOResult<()> {
-        if self.pos + buf.len() > self.size { return Err(ByteIOError::WriteError); }
+        if self.pos + buf.len() > self.buf.len() { return Err(ByteIOError::WriteError); }
         for i in 0..buf.len() {
             self.buf[self.pos + i] = buf[i];
         }
@@ -524,8 +614,8 @@ impl<'a> ByteIO for MemoryWriter<'a> {
     }
 
     fn seek(&mut self, pos: SeekFrom) -> ByteIOResult<u64> {
-        let cur_pos  = self.pos  as i64;
-        let cur_size = self.size as i64;
+        let cur_pos  = self.pos       as i64;
+        let cur_size = self.buf.len() as i64;
         match pos {
             SeekFrom::Start(x)   => self.real_seek(x as i64),
             SeekFrom::Current(x) => self.real_seek(cur_pos + x),
@@ -533,8 +623,8 @@ impl<'a> ByteIO for MemoryWriter<'a> {
         }
     }
 
-    fn is_eof(&mut self) -> bool {
-        self.pos >= self.size
+    fn is_eof(&self) -> bool {
+        self.pos >= self.buf.len()
     }
 
     fn is_seekable(&mut self) -> bool {
@@ -548,7 +638,7 @@ impl<'a> ByteIO for MemoryWriter<'a> {
 
 impl FileWriter {
     pub fn new_write(file: File) -> Self {
-        FileWriter { file: file }
+        FileWriter { file }
     }
 }
 
@@ -597,7 +687,7 @@ impl ByteIO for FileWriter {
         }
     }
 
-    fn is_eof(&mut self) -> bool {
+    fn is_eof(&self) -> bool {
         false
     }
 
@@ -626,7 +716,7 @@ mod test {
         assert_eq!(reader.read_u24le().unwrap(), 0x010101u32);
         assert_eq!(reader.read_u32le().unwrap(), 0x01010101u32);
         assert_eq!(reader.read_u64le().unwrap(), 0x0101010101010101u64);
-        let mut file = File::open("assets/MaoMacha.asx").unwrap();
+        let mut file = File::open("assets/Misc/MaoMacha.asx").unwrap();
         let mut fr = FileReader::new_read(&mut file);
         let mut br2 = ByteReader::new(&mut fr);
         assert_eq!(br2.read_byte().unwrap(), 0x30);