wav: try to read as much PCM data as possible
[nihav.git] / nihav-commonfmt / src / demuxers / wav.rs
CommitLineData
283abfa6
KS
1use nihav_core::demuxers::*;
2use nihav_registry::register;
3use nihav_core::demuxers::DemuxerError::*;
4
/// Packs four byte-sized values into a big-endian `u32` tag (FOURCC).
///
/// Two forms are accepted:
/// * `mktag!(a, b, c, d)` - four separate values, each convertible with `u32::from`;
/// * `mktag!(arr)` - any indexable expression whose elements 0..=3 are such values,
///   e.g. `mktag!(b"RIFF")`.
macro_rules! mktag {
    ($a:expr, $b:expr, $c:expr, $d:expr) => {
        (u32::from($a) << 24) | (u32::from($b) << 16) | (u32::from($c) << 8) | u32::from($d)
    };
    ($arr:expr) => {{
        // Borrow the argument once: pasting `$arr` four times would evaluate an
        // expression with side effects (or an expensive one) once per byte.
        let tag = &$arr;
        (u32::from(tag[0]) << 24) | (u32::from(tag[1]) << 16) | (u32::from(tag[2]) << 8) | u32::from(tag[3])
    }};
}
13
14struct WAVDemuxer<'a> {
15 src: &'a mut ByteReader<'a>,
16 data_pos: u64,
17 data_end: u64,
18 srate: u32,
19 block_size: usize,
20 is_pcm: bool,
21 avg_bytes: u32,
a480a0de 22 duration: u64,
34219db3
KS
23
24 force_tb_num: u32,
25 force_tb_den: u32,
283abfa6
KS
26}
27
28impl<'a> DemuxCore<'a> for WAVDemuxer<'a> {
29 fn open(&mut self, strmgr: &mut StreamManager, seek_index: &mut SeekIndex) -> DemuxerResult<()> {
30 let riff = self.src.read_u32be()?;
31 let riff_size = self.src.read_u32le()? as usize;
ac818eac 32 let riff_end = self.src.tell() + if riff_size > 0 { riff_size as u64 } else { u64::from(std::u32::MAX) };
283abfa6
KS
33 let wave = self.src.read_u32be()?;
34 validate!(riff == mktag!(b"RIFF"));
35 validate!(wave == mktag!(b"WAVE"));
36
37 seek_index.mode = SeekIndexMode::Automatic;
38
39 let mut fmt_parsed = false;
a480a0de 40 let mut duration = 0;
283abfa6
KS
41 while self.src.tell() < riff_end {
42 let ctype = self.src.read_tag()?;
43 let csize = self.src.read_u32le()? as usize;
44 match &ctype {
45 b"fmt " => {
46 validate!(!fmt_parsed);
47 self.parse_fmt(strmgr, csize)?;
48 fmt_parsed = true;
49 },
50 b"fact" => {
51 validate!(csize == 4);
a480a0de 52 duration = self.src.read_u32le()? as usize;
283abfa6
KS
53 },
54 b"data" => {
55 validate!(fmt_parsed);
56 self.data_pos = self.src.tell();
57 self.data_end = self.data_pos + (csize as u64);
a480a0de
KS
58
59 if duration != 0 {
60 self.duration = (duration as u64) * 1000 / u64::from(self.srate);
61 } else if self.avg_bytes > 0 {
62 self.duration = (self.data_end - self.data_pos) * 1000 / u64::from(self.avg_bytes);
63 } else {
64 self.duration = 0;
65 }
66
283abfa6
KS
67 return Ok(());
68 },
69 _ => {
70 self.src.read_skip(csize)?;
71 },
72 };
73 }
74 Err(DemuxerError::InvalidData)
75 }
76
77 fn get_frame(&mut self, strmgr: &mut StreamManager) -> DemuxerResult<NAPacket> {
78 if self.src.tell() >= self.data_end {
79 return Err(DemuxerError::EOF);
80 }
817e4872
KS
81 let strm = strmgr.get_stream(0);
82 if strm.is_none() { return Err(InvalidData); }
83 let stream = strm.unwrap();
b4bf2c3f
KS
84 let pts = if self.avg_bytes != 0 {
85 let pos = self.src.tell() - self.data_pos;
86 Some(pos * u64::from(self.srate) / u64::from(self.avg_bytes))
87 } else {
88 None
89 };
90 let ts = NATimeInfo::new(pts, None, None, 1, self.srate);
283abfa6 91 if self.is_pcm {
34219db3
KS
92 let bsize = if self.force_tb_num != 0 && self.force_tb_den != 0 {
93 let nbsize = u64::from(self.avg_bytes) * u64::from(self.force_tb_num) / u64::from(self.force_tb_den);
94 let mut nbsize = nbsize as usize + self.block_size - 1;
95 nbsize /= self.block_size;
96 nbsize *= self.block_size;
97 nbsize
98 } else {
99 let mut bsize = self.block_size;
100 while bsize < 256 {
101 bsize <<= 1;
102 }
103 bsize
104 };
283abfa6 105 let mut buf = vec![0; bsize];
f80d9ab9
KS
106 let mut tot_size = 0;
107 while let Ok(psize) = self.src.read_buf_some(&mut buf[tot_size..]) {
108 tot_size += psize;
109 if tot_size == buf.len() {
110 break;
111 }
112 }
113 buf.truncate(tot_size);
283abfa6
KS
114 Ok(NAPacket::new(stream, ts, true, buf))
115 } else {
116 self.src.read_packet(stream, ts, true, self.block_size)
117 }
118 }
119
24d99894 120 fn seek(&mut self, time: NATimePoint, _seek_index: &SeekIndex) -> DemuxerResult<()> {
283abfa6 121 if self.block_size != 0 && self.avg_bytes != 0 {
24d99894
KS
122 let seek_off = match time {
123 NATimePoint::Milliseconds(ms) => {
124 let seek_dst = u64::from(self.avg_bytes) * ms / 1000;
125 seek_dst / (self.block_size as u64) * (self.block_size as u64)
126 },
127 NATimePoint::PTS(pts) => (self.block_size as u64) * pts,
128 NATimePoint::None => return Ok(()),
129 };
283abfa6
KS
130 self.src.seek(SeekFrom::Start(self.data_pos + seek_off))?;
131 Ok(())
132 } else {
133 Err(DemuxerError::NotImplemented)
134 }
135 }
a480a0de
KS
136
137 fn get_duration(&self) -> u64 { self.duration }
283abfa6
KS
138}
139
34219db3
KS
140const WAV_OPTIONS: &[NAOptionDefinition] = &[
141 NAOptionDefinition {
142 name: "force_tb_num", description: "force timebase numerator for PCM blocks",
143 opt_type: NAOptionDefinitionType::Int(Some(1), None) },
144 NAOptionDefinition {
145 name: "force_tb_den", description: "force timebase denominator for PCM blocks",
146 opt_type: NAOptionDefinitionType::Int(Some(1), None) },
147];
148
787b8d03 149impl<'a> NAOptionHandler for WAVDemuxer<'a> {
34219db3
KS
150 fn get_supported_options(&self) -> &[NAOptionDefinition] { WAV_OPTIONS }
151 fn set_options(&mut self, options: &[NAOption]) {
152 for option in options.iter() {
153 match (option.name, &option.value) {
154 ("force_tb_num", NAValue::Int(ref ival)) => {
155 self.force_tb_num = *ival as u32;
156 },
157 ("force_tb_den", NAValue::Int(ref ival)) => {
158 self.force_tb_den = *ival as u32;
159 },
160 _ => {},
161 };
162 }
163 }
164 fn query_option_value(&self, name: &str) -> Option<NAValue> {
165 match name {
166 "force_tb_num" => Some(NAValue::Int(i64::from(self.force_tb_num))),
167 "force_tb_den" => Some(NAValue::Int(i64::from(self.force_tb_den))),
168 _ => None,
169 }
170 }
787b8d03
KS
171}
172
283abfa6
KS
173impl<'a> WAVDemuxer<'a> {
174 fn new(io: &'a mut ByteReader<'a>) -> Self {
175 WAVDemuxer {
176 src: io,
177 data_pos: 0,
178 data_end: 0,
179 srate: 0,
180 block_size: 0,
181 is_pcm: false,
182 avg_bytes: 0,
a480a0de 183 duration: 0,
34219db3
KS
184 force_tb_num: 0,
185 force_tb_den: 0,
283abfa6
KS
186 }
187 }
188 fn parse_fmt(&mut self, strmgr: &mut StreamManager, csize: usize) -> DemuxerResult<()> {
189 validate!(csize >= 14);
190 let format_tag = self.src.read_u16le()?;
191 let channels = self.src.read_u16le()?;
192 validate!(channels < 256);
193 let samples_per_sec = self.src.read_u32le()?;
194 let avg_bytes_per_sec = self.src.read_u32le()?;
195 let block_align = self.src.read_u16le()? as usize;
196 if block_align == 0 {
197 return Err(DemuxerError::NotImplemented);
198 }
199 let bits_per_sample = if csize >= 16 { self.src.read_u16le()? } else { 8 };
200 validate!(channels < 256);
201
61cab15b
KS
202 let edata = if csize > 16 {
203 validate!(csize >= 18);
204 let cb_size = self.src.read_u16le()? as usize;
205 let mut buf = vec![0; cb_size];
283abfa6 206 self.src.read_buf(buf.as_mut_slice())?;
61cab15b
KS
207 Some(buf)
208 } else {
209 None
210 };
283abfa6
KS
211
212 let cname = register::find_codec_from_wav_twocc(format_tag).unwrap_or("unknown");
213 let soniton = if cname == "pcm" {
214 if format_tag != 0x0003 {
215 if bits_per_sample == 8 {
216 NASoniton::new(8, 0)
217 } else {
218 NASoniton::new(bits_per_sample as u8, SONITON_FLAG_SIGNED)
219 }
220 } else {
221 NASoniton::new(bits_per_sample as u8, SONITON_FLAG_FLOAT)
222 }
223 } else {
224 NASoniton::new(bits_per_sample as u8, SONITON_FLAG_SIGNED)
225 };
226 let ahdr = NAAudioInfo::new(samples_per_sec, channels as u8, soniton, block_align);
227 let ainfo = NACodecInfo::new(cname, NACodecTypeInfo::Audio(ahdr), edata);
a480a0de 228 let res = strmgr.add_stream(NAStream::new(StreamType::Audio, 0, ainfo, 1, samples_per_sec, 0));
283abfa6
KS
229 if res.is_none() { return Err(MemoryError); }
230
231 self.srate = samples_per_sec;
232 self.block_size = block_align;
233 self.avg_bytes = avg_bytes_per_sec;
234 self.is_pcm = cname == "pcm";
a480a0de
KS
235 if self.is_pcm && self.avg_bytes == 0 {
236 self.avg_bytes = self.block_size as u32 * self.srate;
237 }
283abfa6
KS
238
239 Ok(())
240 }
241}
242
243pub struct WAVDemuxerCreator { }
244
245impl DemuxerCreator for WAVDemuxerCreator {
246 fn new_demuxer<'a>(&self, br: &'a mut ByteReader<'a>) -> Box<dyn DemuxCore<'a> + 'a> {
247 Box::new(WAVDemuxer::new(br))
248 }
249 fn get_name(&self) -> &'static str { "wav" }
250}
251
#[cfg(test)]
mod test {
    use super::*;
    use std::fs::File;

    /// Opens the sample file and demuxes it packet by packet until EOF,
    /// printing every packet; any other error fails the test.
    #[test]
    fn test_wav_demux() {
        // sample: https://samples.mplayerhq.hu/A-codecs/msadpcm-stereo/scatter.wav
        let mut file = File::open("assets/MS/scatter.wav").unwrap();
        let mut fr = FileReader::new_read(&mut file);
        let mut br = ByteReader::new(&mut fr);
        let mut dmx = WAVDemuxer::new(&mut br);
        let mut sm = StreamManager::new();
        let mut si = SeekIndex::new();
        dmx.open(&mut sm, &mut si).unwrap();

        loop {
            match dmx.get_frame(&mut sm) {
                Ok(pkt) => println!("Got {}", pkt),
                Err(DemuxerError::EOF) => break,
                Err(_) => panic!("error"),
            }
        }
    }
}