use NAPacketiser::attach_stream() where appropriate
[nihav-tool.git] / src / wavwriter.rs
index fbf825d88e79046a4f0dd01349472fb2175b92d0..e32a4ec69a23fb5e7a84aa72e5dd7c2d9d53d549 100644 (file)
@@ -8,6 +8,8 @@ use std::io::SeekFrom;
 pub struct WavWriter<'a> {
     io: Box<ByteWriter<'a>>,
     data_pos: u64,
+    be: bool,
+    bits: u8,
 }
 
 fn write_byte(wr: &mut ByteWriter, sample: u8) -> ByteIOResult<()> {
@@ -59,9 +61,10 @@ macro_rules! write_data {
 impl<'a> WavWriter<'a> {
     pub fn new(name: &str) -> Self {
         let file = File::create(name).unwrap();
+        let file = std::io::BufWriter::new(file);
         let fw   = Box::new(FileWriter::new_write(file));
         let io   = ByteWriter::new(Box::leak(fw));
-        WavWriter { io: Box::new(io), data_pos: 0 }
+        WavWriter { io: Box::new(io), data_pos: 0, be: false, bits: 0 }
     }
     pub fn write_header(&mut self, ainfo: NAAudioInfo) -> ByteIOResult<()> {
         let bits = ainfo.get_format().get_bits() as usize;
@@ -74,14 +77,14 @@ impl<'a> WavWriter<'a> {
         self.io.write_u32le(16)?;
         self.io.write_u16le(0x0001)?; // PCM
         self.io.write_u16le(ainfo.get_channels() as u16)?;
-        self.io.write_u32le(ainfo.get_sample_rate() as u32)?;
+        self.io.write_u32le(ainfo.get_sample_rate())?;
 
-        if bits < 16 {
-            self.io.write_u32le((ainfo.get_channels() as u32) * (ainfo.get_sample_rate() as u32))?;
+        if bits <= 8 {
+            self.io.write_u32le((ainfo.get_channels() as u32) * ainfo.get_sample_rate())?;
             self.io.write_u16le(ainfo.get_channels() as u16)?; // block align
             self.io.write_u16le(8)?;
         } else {
-            self.io.write_u32le(2 * (ainfo.get_channels() as u32) * (ainfo.get_sample_rate() as u32))?;
+            self.io.write_u32le(2 * (ainfo.get_channels() as u32) * ainfo.get_sample_rate())?;
             self.io.write_u16le((2 * ainfo.get_channels()) as u16)?; // block align
             self.io.write_u16le(16)?;
         }
@@ -89,6 +92,8 @@ impl<'a> WavWriter<'a> {
         self.io.write_buf(b"data")?;
         self.io.write_u32le(0)?;
 
+        self.bits = bits as u8;
+        self.be = ainfo.get_format().is_be();
         self.data_pos = self.io.tell();
         Ok(())
     }
@@ -107,7 +112,51 @@ impl<'a> WavWriter<'a> {
                 write_data!(&mut self.io, buf, write_f32);
             }
             NABufferType::AudioPacked(ref buf) => {
-                self.io.write_buf(buf.get_data().as_slice())?;
+                let data = buf.get_data();
+                match self.bits {
+                    _ if !self.be && (self.bits & 7) == 0 => {
+                        self.io.write_buf(data.as_slice())?;
+                    },
+                    8 => {
+                        self.io.write_buf(data.as_slice())?;
+                    },
+                    12 if !self.be => {
+                        let mut src = data.chunks_exact(3);
+                        for chunk in src.by_ref() {
+                            self.io.write_byte(chunk[0] << 4)?;
+                            self.io.write_byte((chunk[1] << 4) | (chunk[0] >> 4))?;
+                            self.io.write_byte(chunk[1] & 0xF0)?;
+                            self.io.write_byte(chunk[2])?;
+                        }
+                        let tail = src.remainder();
+                        if tail.len() == 2 {
+                            self.io.write_byte(tail[0] << 4)?;
+                            self.io.write_byte(tail[1] << 4)?;
+                        }
+                    }
+                    16 => {
+                        for samp in data.chunks(2) {
+                            self.io.write_byte(samp[1])?;
+                            self.io.write_byte(samp[0])?;
+                        }
+                    },
+                    24 => {
+                        for samp in data.chunks(3) {
+                            self.io.write_byte(samp[2])?;
+                            self.io.write_byte(samp[1])?;
+                            self.io.write_byte(samp[0])?;
+                        }
+                    },
+                    32 => {
+                        for samp in data.chunks(4) {
+                            self.io.write_byte(samp[3])?;
+                            self.io.write_byte(samp[2])?;
+                            self.io.write_byte(samp[1])?;
+                            self.io.write_byte(samp[0])?;
+                        }
+                    },
+                    _ => unimplemented!(),
+                };
             }
             _ => {},
         };
@@ -123,7 +172,8 @@ impl<'a> Drop for WavWriter<'a> {
             let res = self.io.seek(SeekFrom::Start(4));
             let res = self.io.write_u32le((size - 8) as u32);
             let res = self.io.seek(SeekFrom::Start(self.data_pos - 4));
-            let res = self.io.write_u32le(((size as u64) - self.data_pos) as u32);
+            let res = self.io.write_u32le((size - self.data_pos) as u32);
+            let res = self.io.flush();
         }
     }
 }