core/frame: add proper function for audio frame truncation
[nihav.git] / nihav-core / src / frame.rs
index 29db180082eb265f999f4993d702e5236e195265..31cf8e81991a519d26827e13913175a6d3cd1dd6 100644 (file)
@@ -2,9 +2,10 @@
 use std::cmp::max;
 //use std::collections::HashMap;
 use std::fmt;
-use std::sync::Arc;
+pub use std::sync::Arc;
 pub use crate::formats::*;
 pub use crate::refs::*;
+use std::str::FromStr;
 
 /// Audio stream information.
 #[allow(dead_code)]
@@ -53,12 +54,15 @@ pub struct NAVideoInfo {
     pub flipped:    bool,
     /// Picture pixel format.
     pub format:     NAPixelFormaton,
+    /// Declared bits per sample.
+    pub bits:       u8,
 }
 
 impl NAVideoInfo {
     /// Constructs a new `NAVideoInfo` instance.
     pub fn new(w: usize, h: usize, flip: bool, fmt: NAPixelFormaton) -> Self {
-        NAVideoInfo { width: w, height: h, flipped: flip, format: fmt }
+        let bits = fmt.get_total_depth();
+        NAVideoInfo { width: w, height: h, flipped: flip, format: fmt, bits }
     }
     /// Returns picture width.
     pub fn get_width(&self)  -> usize { self.width as usize }
@@ -146,6 +150,10 @@ pub struct NAVideoBuffer<T> {
 }
 
 impl<T: Clone> NAVideoBuffer<T> {
+    /// Constructs video buffer from the provided components.
+    pub fn from_raw_parts(info: NAVideoInfo, data: NABufferRef<Vec<T>>, offs: Vec<usize>, strides: Vec<usize>) -> Self {
+        Self { info, data, offs, strides }
+    }
     /// Returns the component offset (0 for all unavailable offsets).
     pub fn get_offset(&self, idx: usize) -> usize {
         if idx >= self.offs.len() { 0 }
@@ -233,6 +241,8 @@ impl<T: Clone> NAAudioBuffer<T> {
     pub fn get_chmap(&self) -> &NAChannelMap { &self.chmap }
     /// Returns an immutable reference to the data.
     pub fn get_data(&self) -> &Vec<T> { self.data.as_ref() }
+    /// Returns reference to the data.
+    pub fn get_data_ref(&self) -> NABufferRef<Vec<T>> { self.data.clone() }
     /// Returns a mutable reference to the data.
     pub fn get_data_mut(&mut self) -> Option<&mut Vec<T>> { self.data.as_mut() }
     /// Clones current `NAAudioBuffer` into a new one.
@@ -245,6 +255,12 @@ impl<T: Clone> NAAudioBuffer<T> {
     }
     /// Return the length of frame in samples.
     pub fn get_length(&self) -> usize { self.len }
+    /// Truncates buffer length if possible.
+    ///
+    /// In case when new length is larger than old length nothing is done.
+    pub fn truncate(&mut self, new_len: usize) {
+        self.len = self.len.min(new_len);
+    }
 
     fn print_contents(&self, datatype: &str) {
         println!("Audio buffer with {} data, stride {}, step {}", datatype, self.stride, self.step);
@@ -374,6 +390,17 @@ impl NABufferType {
             _ => 0,
         }
     }
+    /// Truncates audio frame duration if possible.
+    pub fn truncate_audio(&mut self, len: usize) {
+        match *self {
+            NABufferType::AudioU8(ref mut ab)     => ab.truncate(len),
+            NABufferType::AudioI16(ref mut ab)    => ab.truncate(len),
+            NABufferType::AudioI32(ref mut ab)    => ab.truncate(len),
+            NABufferType::AudioF32(ref mut ab)    => ab.truncate(len),
+            NABufferType::AudioPacked(ref mut ab) => ab.truncate(len),
+            _ => {},
+        };
+    }
     /// Returns the distance between starts of two channels.
     pub fn get_audio_stride(&self) -> usize {
         match *self {
@@ -655,6 +682,10 @@ pub fn alloc_audio_buffer(ainfo: NAAudioInfo, nsamples: usize, chmap: NAChannelM
                 let data: Vec<i16> = vec![0; length];
                 let buf: NAAudioBuffer<i16> = NAAudioBuffer { data: NABufferRef::new(data), info: ainfo, offs, chmap, len: nsamples, stride, step };
                 Ok(NABufferType::AudioI16(buf))
+            } else if ainfo.format.get_bits() == 32 && ainfo.format.is_signed() {
+                let data: Vec<i32> = vec![0; length];
+                let buf: NAAudioBuffer<i32> = NAAudioBuffer { data: NABufferRef::new(data), info: ainfo, offs, chmap, len: nsamples, stride, step };
+                Ok(NABufferType::AudioI32(buf))
             } else {
                 Err(AllocatorError::TooLargeDimensions)
             }
@@ -677,7 +708,7 @@ pub fn alloc_data_buffer(size: usize) -> Result<NABufferType, AllocatorError> {
 }
 
 /// Creates a clone of current buffer.
-pub fn copy_buffer(buf: NABufferType) -> NABufferType {
+pub fn copy_buffer(buf: &NABufferType) -> NABufferType {
     buf.clone()
 }
 
@@ -857,21 +888,6 @@ pub const DUMMY_CODEC_INFO: NACodecInfo = NACodecInfo {
                                 properties: NACodecTypeInfo::None,
                                 extradata: None };
 
-/// A list of accepted option values.
-#[derive(Debug,Clone)]
-pub enum NAValue {
-    /// Empty value.
-    None,
-    /// Integer value.
-    Int(i32),
-    /// Long integer value.
-    Long(i64),
-    /// String value.
-    String(String),
-    /// Binary data value.
-    Data(Arc<Vec<u8>>),
-}
-
 /// A list of recognized frame types.
 #[derive(Debug,Clone,Copy,PartialEq)]
 #[allow(dead_code)]
@@ -936,31 +952,35 @@ impl NATimeInfo {
     pub fn set_duration(&mut self, dur: Option<u64>) { self.duration = dur; }
 
     /// Converts time in given scale into timestamp in given base.
+    #[allow(clippy::collapsible_if)]
     pub fn time_to_ts(time: u64, base: u64, tb_num: u32, tb_den: u32) -> u64 {
-        let tb_num = tb_num as u64;
-        let tb_den = tb_den as u64;
-        let tmp = time.checked_mul(tb_num);
+        let tb_num = u64::from(tb_num);
+        let tb_den = u64::from(tb_den);
+        let tmp = time.checked_mul(tb_den);
         if let Some(tmp) = tmp {
-            tmp / base / tb_den
+            tmp / base / tb_num
         } else {
-            let tmp = time.checked_mul(tb_num);
-            if let Some(tmp) = tmp {
-                tmp / base / tb_den
+            if tb_num < base {
+                let coarse = time / tb_num;
+                if let Some(tmp) = coarse.checked_mul(tb_den) {
+                    tmp / base
+                } else {
+                    (coarse / base) * tb_den
+                }
             } else {
                 let coarse = time / base;
-                let tmp = coarse.checked_mul(tb_num);
-                if let Some(tmp) = tmp {
-                    tmp / tb_den
+                if let Some(tmp) = coarse.checked_mul(tb_den) {
+                    tmp / tb_num
                 } else {
-                    (coarse / tb_den) * tb_num
+                    (coarse / tb_num) * tb_den
                 }
             }
         }
     }
     /// Converts timestamp in given base into time in given scale.
     pub fn ts_to_time(ts: u64, base: u64, tb_num: u32, tb_den: u32) -> u64 {
-        let tb_num = tb_num as u64;
-        let tb_den = tb_den as u64;
+        let tb_num = u64::from(tb_num);
+        let tb_den = u64::from(tb_den);
         let tmp = ts.checked_mul(base);
         if let Some(tmp) = tmp {
             let tmp2 = tmp.checked_mul(tb_num);
@@ -978,6 +998,182 @@ impl NATimeInfo {
             }
         }
     }
+    fn get_cur_ts(&self) -> u64 { self.pts.unwrap_or_else(|| self.dts.unwrap_or(0)) }
+    fn get_cur_millis(&self) -> u64 {
+        let ts = self.get_cur_ts();
+        Self::ts_to_time(ts, 1000, self.tb_num, self.tb_den)
+    }
+    /// Checks whether the current time information is earler than provided reference time.
+    pub fn less_than(&self, time: NATimePoint) -> bool {
+        if self.pts.is_none() && self.dts.is_none() {
+            return true;
+        }
+        match time {
+            NATimePoint::PTS(rpts) => self.get_cur_ts() < rpts,
+            NATimePoint::Milliseconds(ms) => self.get_cur_millis() < ms,
+            NATimePoint::None => false,
+        }
+    }
+    /// Checks whether the current time information is the same as provided reference time.
+    pub fn equal(&self, time: NATimePoint) -> bool {
+        if self.pts.is_none() && self.dts.is_none() {
+            return time == NATimePoint::None;
+        }
+        match time {
+            NATimePoint::PTS(rpts) => self.get_cur_ts() == rpts,
+            NATimePoint::Milliseconds(ms) => self.get_cur_millis() == ms,
+            NATimePoint::None => false,
+        }
+    }
+}
+
+/// Time information for specifying durations or seek positions.
+#[derive(Clone,Copy,Debug,PartialEq)]
+pub enum NATimePoint {
+    /// Time in milliseconds.
+    Milliseconds(u64),
+    /// Stream timestamp.
+    PTS(u64),
+    /// No time information present.
+    None,
+}
+
+impl Default for NATimePoint {
+    fn default() -> Self {
+        NATimePoint::None
+    }
+}
+
+impl fmt::Display for NATimePoint {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match *self {
+            NATimePoint::Milliseconds(millis) => {
+                let tot_s = millis / 1000;
+                let ms = millis % 1000;
+                if tot_s < 60 {
+                    if ms != 0 {
+                        return write!(f, "{}.{:03}", tot_s, ms);
+                    } else {
+                        return write!(f, "{}", tot_s);
+                    }
+                }
+                let tot_m = tot_s / 60;
+                let s = tot_s % 60;
+                if tot_m < 60 {
+                    if ms != 0 {
+                        return write!(f, "{}:{:02}.{:03}", tot_m, s, ms);
+                    } else {
+                        return write!(f, "{}:{:02}", tot_m, s);
+                    }
+                }
+                let h = tot_m / 60;
+                let m = tot_m % 60;
+                if ms != 0 {
+                    write!(f, "{}:{:02}:{:02}.{:03}", h, m, s, ms)
+                } else {
+                    write!(f, "{}:{:02}:{:02}", h, m, s)
+                }
+            },
+            NATimePoint::PTS(pts) => {
+                write!(f, "{}pts", pts)
+            },
+            NATimePoint::None => {
+                write!(f, "none")
+            },
+        }
+    }
+}
+
+impl FromStr for NATimePoint {
+    type Err = FormatParseError;
+
+    /// Parses the string into time information.
+    ///
+    /// Accepted formats are `<u64>pts`, `<u64>ms` or `[hh:][mm:]ss[.ms]`.
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        if s.is_empty() {
+            return Err(FormatParseError {});
+        }
+        if !s.ends_with("pts") {
+            if s.ends_with("ms") {
+                let str_b = s.as_bytes();
+                let num = std::str::from_utf8(&str_b[..str_b.len() - 2]).unwrap();
+                let ret = num.parse::<u64>();
+                if let Ok(val) = ret {
+                    return Ok(NATimePoint::Milliseconds(val));
+                } else {
+                    return Err(FormatParseError {});
+                }
+            }
+            let mut parts = s.split(':');
+            let mut hrs = None;
+            let mut mins = None;
+            let mut secs = parts.next();
+            if let Some(part) = parts.next() {
+                std::mem::swap(&mut mins, &mut secs);
+                secs = Some(part);
+            }
+            if let Some(part) = parts.next() {
+                std::mem::swap(&mut hrs, &mut mins);
+                std::mem::swap(&mut mins, &mut secs);
+                secs = Some(part);
+            }
+            if parts.next().is_some() {
+                return Err(FormatParseError {});
+            }
+            let hours = if let Some(val) = hrs {
+                    let ret = val.parse::<u64>();
+                    if ret.is_err() { return Err(FormatParseError {}); }
+                    let val = ret.unwrap();
+                    if val > 1000 { return Err(FormatParseError {}); }
+                    val
+                } else { 0 };
+            let minutes = if let Some(val) = mins {
+                    let ret = val.parse::<u64>();
+                    if ret.is_err() { return Err(FormatParseError {}); }
+                    let val = ret.unwrap();
+                    if val >= 60 { return Err(FormatParseError {}); }
+                    val
+                } else { 0 };
+            let (seconds, millis) = if let Some(val) = secs {
+                    let mut parts = val.split('.');
+                    let ret = parts.next().unwrap().parse::<u64>();
+                    if ret.is_err() { return Err(FormatParseError {}); }
+                    let seconds = ret.unwrap();
+                    if mins.is_some() && seconds >= 60 { return Err(FormatParseError {}); }
+                    let millis = if let Some(val) = parts.next() {
+                            let mut mval = 0;
+                            let mut base = 0;
+                            for ch in val.chars() {
+                                if ch >= '0' && ch <= '9' {
+                                    mval = mval * 10 + u64::from((ch as u8) - b'0');
+                                    base += 1;
+                                    if base > 3 { break; }
+                                } else {
+                                    return Err(FormatParseError {});
+                                }
+                            }
+                            while base < 3 {
+                                mval *= 10;
+                                base += 1;
+                            }
+                            mval
+                        } else { 0 };
+                    (seconds, millis)
+                } else { unreachable!(); };
+            let tot_secs = hours * 60 * 60 + minutes * 60 + seconds;
+            Ok(NATimePoint::Milliseconds(tot_secs * 1000 + millis))
+        } else {
+            let str_b = s.as_bytes();
+            let num = std::str::from_utf8(&str_b[..str_b.len() - 3]).unwrap();
+            let ret = num.parse::<u64>();
+            if let Ok(val) = ret {
+                Ok(NATimePoint::PTS(val))
+            } else {
+                Err(FormatParseError {})
+            }
+        }
+    }
 }
 
 /// Decoded frame information.
@@ -1112,12 +1308,15 @@ pub struct NAStream {
     pub tb_num:         u32,
     /// Timebase denominator.
     pub tb_den:         u32,
+    /// Duration in timebase units (zero if not available).
+    pub duration:       u64,
 }
 
 /// A specialised reference-counted `NAStream` type.
 pub type NAStreamRef = Arc<NAStream>;
 
 /// Downscales the timebase by its greatest common denominator.
+#[allow(clippy::comparison_chain)]
 pub fn reduce_timebase(tb_num: u32, tb_den: u32) -> (u32, u32) {
     if tb_num == 0 { return (tb_num, tb_den); }
     if (tb_den % tb_num) == 0 { return (1, tb_den / tb_num); }
@@ -1135,9 +1334,9 @@ pub fn reduce_timebase(tb_num: u32, tb_den: u32) -> (u32, u32) {
 
 impl NAStream {
     /// Constructs a new `NAStream` instance.
-    pub fn new(mt: StreamType, id: u32, info: NACodecInfo, tb_num: u32, tb_den: u32) -> Self {
+    pub fn new(mt: StreamType, id: u32, info: NACodecInfo, tb_num: u32, tb_den: u32, duration: u64) -> Self {
         let (n, d) = reduce_timebase(tb_num, tb_den);
-        NAStream { media_type: mt, id, num: 0, info: info.into_ref(), tb_num: n, tb_den: d }
+        NAStream { media_type: mt, id, num: 0, info: info.into_ref(), tb_num: n, tb_den: d, duration }
     }
     /// Returns stream id.
     pub fn get_id(&self) -> u32 { self.id }
@@ -1157,6 +1356,8 @@ impl NAStream {
         self.tb_num = n;
         self.tb_den = d;
     }
+    /// Returns stream duration.
+    pub fn get_duration(&self) -> usize { self.num }
     /// Converts current instance into a reference-counted one.
     pub fn into_ref(self) -> NAStreamRef { Arc::new(self) }
 }
@@ -1167,6 +1368,18 @@ impl fmt::Display for NAStream {
     }
 }
 
+/// Side data that may accompany demuxed data.
+#[derive(Clone)]
+pub enum NASideData {
+    /// Palette information.
+    ///
+    /// This side data contains a flag signalling that palette has changed since previous time and a reference to the current palette.
+    /// Palette is stored in 8-bit RGBA format.
+    Palette(bool, Arc<[u8; 1024]>),
+    /// Generic user data.
+    UserData(Arc<Vec<u8>>),
+}
+
 /// Packet with compressed data.
 #[allow(dead_code)]
 pub struct NAPacket {
@@ -1177,6 +1390,8 @@ pub struct NAPacket {
     /// Keyframe flag.
     pub keyframe:       bool,
 //    options:        HashMap<String, NAValue<'a>>,
+    /// Packet side data (e.g. palette for paletted formats).
+    pub side_data:      Vec<NASideData>,
 }
 
 impl NAPacket {
@@ -1184,7 +1399,11 @@ impl NAPacket {
     pub fn new(str: NAStreamRef, ts: NATimeInfo, kf: bool, vec: Vec<u8>) -> Self {
 //        let mut vec: Vec<u8> = Vec::new();
 //        vec.resize(size, 0);
-        NAPacket { stream: str, ts, keyframe: kf, buffer: NABufferRef::new(vec) }
+        NAPacket { stream: str, ts, keyframe: kf, buffer: NABufferRef::new(vec), side_data: Vec::new() }
+    }
+    /// Constructs a new `NAPacket` instance reusing a buffer reference.
+    pub fn new_from_refbuf(str: NAStreamRef, ts: NATimeInfo, kf: bool, buffer: NABufferRef<Vec<u8>>) -> Self {
+        NAPacket { stream: str, ts, keyframe: kf, buffer, side_data: Vec::new() }
     }
     /// Returns information about the stream packet belongs to.
     pub fn get_stream(&self) -> NAStreamRef { self.stream.clone() }
@@ -1200,6 +1419,13 @@ impl NAPacket {
     pub fn is_keyframe(&self) -> bool { self.keyframe }
     /// Returns a reference to packet data.
     pub fn get_buffer(&self) -> NABufferRef<Vec<u8>> { self.buffer.clone() }
+    /// Adds side data for a packet.
+    pub fn add_side_data(&mut self, side_data: NASideData) { self.side_data.push(side_data); }
+    /// Assigns packet to a new stream.
+    pub fn reassign(&mut self, str: NAStreamRef, ts: NATimeInfo) {
+        self.stream = str;
+        self.ts = ts;
+    }
 }
 
 impl Drop for NAPacket {
@@ -1217,3 +1443,21 @@ impl fmt::Display for NAPacket {
         write!(f, "{}", ostr)
     }
 }
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn test_time_parse() {
+        assert_eq!(NATimePoint::PTS(42).to_string(), "42pts");
+        assert_eq!(NATimePoint::Milliseconds(4242000).to_string(), "1:10:42");
+        assert_eq!(NATimePoint::Milliseconds(42424242).to_string(), "11:47:04.242");
+        let ret = NATimePoint::from_str("42pts");
+        assert_eq!(ret.unwrap(), NATimePoint::PTS(42));
+        let ret = NATimePoint::from_str("1:2:3");
+        assert_eq!(ret.unwrap(), NATimePoint::Milliseconds(3723000));
+        let ret = NATimePoint::from_str("1:2:3.42");
+        assert_eq!(ret.unwrap(), NATimePoint::Milliseconds(3723420));
+    }
+}