mov: some fixes for MP4 parsing
[nihav.git] / nihav-commonfmt / src / demuxers / mov.rs
index 24431f9a7416e0a9fae179fd7256f005c942167b..1c52dfe24a06ef57d27aa1b05c138a394efef895 100644 (file)
@@ -239,7 +239,7 @@ fn read_cmov(dmx: &mut MOVDemuxer, strmgr: &mut StreamManager, size: u64) -> Dem
     dmx.duration = ddmx.duration;
     dmx.tb_den = ddmx.tb_den;
     std::mem::swap(&mut dmx.pal, &mut ddmx.pal);
-    
+
     Ok(size)
 }
 
@@ -292,7 +292,7 @@ fn read_tkhd(track: &mut Track, br: &mut ByteReader, size: u64) -> DemuxerResult
     let _mtime              = br.read_u32be()?;
     let track_id            = br.read_u32be()?;
                               br.read_skip(4)?;
-    let _duration           = br.read_u32be()?;
+    let duration            = br.read_u32be()?;
                               br.read_skip(8)?;
     let _layer              = br.read_u16be()?;
     let _alt_group          = br.read_u16be()?;
@@ -304,6 +304,7 @@ fn read_tkhd(track: &mut Track, br: &mut ByteReader, size: u64) -> DemuxerResult
     track.width  = width  >> 16;
     track.height = height >> 16;
     track.track_id = track_id;
+    track.duration = duration;
 
     track.tkhd_found = true;
     Ok(KNOWN_TKHD_SIZE)
@@ -334,7 +335,7 @@ fn read_hdlr(track: &mut Track, br: &mut ByteReader, size: u64) -> DemuxerResult
     let _comp_flags         = br.read_u32be()?;
     let _comp_flags_mask    = br.read_u32be()?;
 
-    if comp_type == mktag!(b"mhlr") {
+    if comp_type == mktag!(b"mhlr") || comp_type == 0 {
         if comp_subtype == mktag!(b"vide") {
             track.stream_type = StreamType::Video;
         } else if comp_subtype == mktag!(b"soun") {
@@ -417,7 +418,7 @@ fn read_stbl(track: &mut Track, br: &mut ByteReader, size: u64) -> DemuxerResult
 
 const STBL_CHUNK_HANDLERS: &[TrackChunkHandler] = &[
     TrackChunkHandler { ctype: mktag!(b"stsd"), parse: read_stsd },
-    TrackChunkHandler { ctype: mktag!(b"stts"), parse: skip_chunk },
+    TrackChunkHandler { ctype: mktag!(b"stts"), parse: read_stts },
     TrackChunkHandler { ctype: mktag!(b"stss"), parse: read_stss },
     TrackChunkHandler { ctype: mktag!(b"stsc"), parse: read_stsc },
     TrackChunkHandler { ctype: mktag!(b"stsz"), parse: read_stsz },
@@ -428,6 +429,14 @@ const STBL_CHUNK_HANDLERS: &[TrackChunkHandler] = &[
 fn parse_audio_edata(br: &mut ByteReader, start_pos: u64, size: u64) -> DemuxerResult<Option<Vec<u8>>> {
     let read_part = br.tell() - start_pos;
     if read_part + 8 < size {
+        let mut buf = [0; 8];
+                              br.peek_buf(&mut buf)?;
+        if &buf[4..8] != b"wave" {
+            let mut buf = vec![0; (size - read_part) as usize];
+                              br.read_buf(&mut buf)?;
+            return Ok(Some(buf));
+        }
+
         let csize           = br.read_u32be()? as u64;
         let ctag            = br.read_u32be()?;
         validate!(read_part + csize <= size);
@@ -554,7 +563,8 @@ fn read_stsd(track: &mut Track, br: &mut ByteReader, size: u64) -> DemuxerResult
             let edata = if br.tell() - start_pos + 4 < size {
 //todo skip various common atoms
                     let edata_size  = br.read_u32be()? as usize;
-                    let mut buf = vec![0; edata_size];
+                    validate!(edata_size >= 4);
+                    let mut buf = vec![0; edata_size - 4];
                                   br.read_buf(buf.as_mut_slice())?;
                     Some(buf)
                 } else {
@@ -582,8 +592,10 @@ fn read_stsd(track: &mut Track, br: &mut ByteReader, size: u64) -> DemuxerResult
                 } else {
                     "unknown"
                 };
-//todo adjust format for various PCM kinds
-            let soniton = NASoniton::new(sample_size as u8, SONITON_FLAG_SIGNED | SONITON_FLAG_BE);
+            let mut soniton = NASoniton::new(sample_size as u8, SONITON_FLAG_SIGNED | SONITON_FLAG_BE);
+            if &fcc == b"raw " && sample_size == 8 {
+                soniton.signed = false;
+            }
             let block_align = 1;
             if sver == 1 {
                 let samples_per_packet      = br.read_u32be()?;
@@ -612,11 +624,46 @@ fn read_stsd(track: &mut Track, br: &mut ByteReader, size: u64) -> DemuxerResult
     };
     let read_size = br.tell() - start_pos;
     validate!(read_size <= size);
-    track.stream = Some(NAStream::new(track.stream_type, track.track_no, codec_info, 1, track.tb_den));
+    track.stream = Some(NAStream::new(track.stream_type, track.track_no, codec_info, 1, track.tb_den, u64::from(track.duration)));
     track.stsd_found = true;
     Ok(read_size)
 }
 
+fn read_stts(track: &mut Track, br: &mut ByteReader, size: u64) -> DemuxerResult<u64> {
+    validate!(size >= 8);
+    let start_pos = br.tell();
+    let version             = br.read_byte()?;
+    validate!(version == 0);
+    let _flags              = br.read_u24be()?;
+    let entries             = br.read_u32be()? as usize;
+    validate!(entries as u64 <= (size - 8) / 8);
+    if entries == 0 {
+    } else if entries == 1 {
+        let _count          = br.read_u32be()?;
+        let tb_num          = br.read_u32be()?;
+        if let Some(ref mut stream) = track.stream {
+            let tb_den = stream.tb_den;
+            let (tb_num, tb_den) = reduce_timebase(tb_num, tb_den);
+            stream.duration /= u64::from(stream.tb_den / tb_den);
+            stream.tb_num = tb_num;
+            stream.tb_den = tb_den;
+            track.tb_num = tb_num;
+            track.tb_den = tb_den;
+        }
+    } else {
+        track.time_to_sample.truncate(0);
+        track.time_to_sample.reserve(entries);
+        for _ in 0..entries {
+            let count       = br.read_u32be()?;
+            let mult        = br.read_u32be()?;
+            track.time_to_sample.push((count, mult));
+        }
+    }
+    let read_size = br.tell() - start_pos;
+    validate!(read_size <= size);
+    Ok(read_size)
+}
+
 fn read_stss(track: &mut Track, br: &mut ByteReader, size: u64) -> DemuxerResult<u64> {
     let version             = br.read_byte()?;
     validate!(version == 0);
@@ -708,7 +755,9 @@ struct Track {
     track_id:       u32,
     track_str_id:   usize,
     track_no:       u32,
+    tb_num:         u32,
     tb_den:         u32,
+    duration:       u32,
     depth:          u8,
     tkhd_found:     bool,
     stsd_found:     bool,
@@ -722,6 +771,7 @@ struct Track {
     keyframes:      Vec<u32>,
     chunk_sizes:    Vec<u32>,
     chunk_offsets:  Vec<u64>,
+    time_to_sample: Vec<(u32, u32)>,
     sample_map:     Vec<(u32, u32)>,
     sample_size:    u32,
     frame_samples:  usize,
@@ -731,6 +781,48 @@ struct Track {
     samples_left:   usize,
     last_offset:    u64,
     pal:            Option<Arc<[u8; 1024]>>,
+    timesearch:     TimeSearcher,
+}
+
+#[derive(Default)]
+struct TimeSearcher {
+    idx:        usize,
+    base:       u64,
+    sbase:      u32,
+    cur_len:    u32,
+    cur_mul:    u32,
+}
+
+impl TimeSearcher {
+    fn new() -> Self { Self::default() }
+    fn reset(&mut self) {
+        *self = Self::default();
+    }
+    fn map_time(&mut self, sample: u32, tts: &Vec<(u32, u32)>) -> u64 {
+        if tts.is_empty() {
+            u64::from(sample)
+        } else if sample >= self.sbase {
+            let mut sample = sample - self.sbase;
+            if self.idx == 0 {
+                let (cur_len, cur_mul) = tts[0];
+                self.cur_len = cur_len;
+                self.cur_mul = cur_mul;
+                self.idx += 1;
+            }
+            while self.idx < tts.len() && sample > self.cur_len {
+                sample -= self.cur_len;
+                self.sbase += self.cur_len;
+                self.base += u64::from(self.cur_len) * u64::from(self.cur_mul);
+                self.cur_len = tts[self.idx].0;
+                self.cur_mul = tts[self.idx].1;
+                self.idx += 1;
+            }
+            self.base + u64::from(sample) * u64::from(self.cur_mul)
+        } else {
+            self.reset();
+            self.map_time(sample, tts)
+        }
+    }
 }
 
 impl Track {
@@ -741,7 +833,9 @@ impl Track {
             track_id:       0,
             track_str_id:   0,
             track_no,
+            tb_num: 1,
             tb_den,
+            duration:       0,
             stream_type:    StreamType::None,
             width:          0,
             height:         0,
@@ -752,6 +846,7 @@ impl Track {
             keyframes:      Vec::new(),
             chunk_sizes:    Vec::new(),
             chunk_offsets:  Vec::new(),
+            time_to_sample: Vec::new(),
             sample_map:     Vec::new(),
             sample_size:    0,
             frame_samples:  0,
@@ -762,6 +857,7 @@ impl Track {
             samples_left:   0,
             last_offset:    0,
             pal:            None,
+            timesearch:     TimeSearcher::new(),
         }
     }
     read_chunk_list!(track; "trak", read_trak, TRAK_CHUNK_HANDLERS);
@@ -772,14 +868,11 @@ impl Track {
         if !self.keyframes.is_empty() {
             seek_index.mode = SeekIndexMode::Present;
         }
+        let mut tsearch = TimeSearcher::new();
         for kf_time in self.keyframes.iter() {
-            let pts = u64::from(*kf_time - 1);
-            let time = NATimeInfo::ts_to_time(pts, 1000, 1, self.tb_den);
-            let idx = (*kf_time - 1) as usize;
-            if idx < self.chunk_offsets.len() {
-                let pos = self.chunk_offsets[idx];
-                seek_index.add_entry(self.track_no as u32, SeekEntry { time, pts, pos });
-            }
+            let pts = tsearch.map_time(*kf_time - 1, &self.time_to_sample);
+            let time = NATimeInfo::ts_to_time(pts, 1000, self.tb_num, self.tb_den);
+            seek_index.add_entry(self.track_no as u32, SeekEntry { time, pts: u64::from(*kf_time - 1), pos: 0 });
         }
     }
     fn calculate_chunk_size(&self, nsamp: usize) -> usize {
@@ -815,7 +908,8 @@ impl Track {
         }
     }
     fn get_next_chunk(&mut self) -> Option<(NATimeInfo, u64, usize)> {
-        let pts = NATimeInfo::new(Some(self.cur_sample as u64), None, None, 1, self.tb_den);
+        let pts_val = self.timesearch.map_time(self.cur_sample as u32, &self.time_to_sample);
+        let pts = NATimeInfo::new(Some(pts_val), None, None, 1, self.tb_den);
 //todo dts decoding
         if self.chunk_offsets.len() == self.chunk_sizes.len() { // simple one-to-one mapping
             if self.cur_sample >= self.chunk_sizes.len() {
@@ -845,8 +939,14 @@ impl Track {
             self.last_offset += size as u64;
             if self.stream_type == StreamType::Video {
                 self.samples_left -= 1;
-            } else if self.frame_samples != 0 {
-                self.samples_left -= self.frame_samples.min(self.samples_left);
+            } else if self.frame_samples != 0 && self.bsize != 0 {
+                let nblocks = size / self.bsize;
+                if nblocks > 0 {
+                    let consumed = (nblocks * self.frame_samples).min(self.samples_left);
+                    self.samples_left -= consumed;
+                } else {
+                    self.samples_left = 0;
+                }
             } else {
                 self.samples_left = 0;
             }
@@ -883,7 +983,7 @@ impl Track {
             let mut cur_samps = 0;
             let (mut next_idx, mut next_samples) = cmap.next().unwrap();
             loop {
-                if self.cur_chunk == next_idx as usize {
+                if self.cur_chunk + 1 == next_idx as usize {
                     self.samples_left = cur_samps;
                     cur_samps = next_samples as usize;
                     if let Some((new_idx, new_samples)) = cmap.next() {
@@ -899,10 +999,11 @@ impl Track {
                 self.cur_chunk += 1;
             }
             csamp -= cur_samps;
-            for sample_no in csamp..self.cur_chunk {
+            for sample_no in csamp..self.cur_sample {
                 self.last_offset += self.get_size(sample_no) as u64;
             }
-            self.samples_left = self.cur_sample - csamp - cur_samps;
+            self.samples_left = csamp + cur_samps - self.cur_sample;
+            self.cur_chunk += 1;
         }
     }
 }
@@ -947,7 +1048,7 @@ impl<'a> DemuxCore<'a> for MOVDemuxer<'a> {
         Err(DemuxerError::EOF)
     }
 
-    fn seek(&mut self, time: u64, seek_index: &SeekIndex) -> DemuxerResult<()> {
+    fn seek(&mut self, time: NATimePoint, seek_index: &SeekIndex) -> DemuxerResult<()> {
         let ret = seek_index.find_pos(time);
         if ret.is_none() {
             return Err(DemuxerError::SeekError);
@@ -958,6 +1059,13 @@ impl<'a> DemuxCore<'a> for MOVDemuxer<'a> {
         }
         Ok(())
     }
+    fn get_duration(&self) -> u64 {
+        if self.tb_den != 0 {
+            u64::from(self.duration) * 1000 / u64::from(self.tb_den)
+        } else {
+            0
+        }
+    }
 }
 
 impl<'a> NAOptionHandler for MOVDemuxer<'a> {