]> git.nihav.org Git - nihav-encoder.git/commitdiff
refactor code for more self-contained encoders and decoders
authorKostya Shishkov <kostya.shishkov@gmail.com>
Sat, 15 Mar 2025 17:59:30 +0000 (18:59 +0100)
committerKostya Shishkov <kostya.shishkov@gmail.com>
Sat, 15 Mar 2025 17:59:30 +0000 (18:59 +0100)
src/main.rs
src/transcoder.rs

index 24a395ee77a4443f2877e18f5bcb48a9e2cc71eb..c83b4bf1c428da5caafeb431d85887210188caf6 100644 (file)
@@ -11,7 +11,6 @@ use nihav_core::codecs::*;
 use nihav_core::demuxers::*;
 use nihav_core::muxers::*;
 use nihav_core::reorder::*;
-use nihav_core::scale::*;
 use nihav_registry::detect;
 use nihav_registry::register;
 use std::env;
@@ -84,56 +83,6 @@ fn print_options(name: &str, options: &[NAOptionDefinition]) {
     }
 }
 
-fn encode_frame(dst_id: u32, encoder: &mut Box<dyn NAEncoder>, cvt: &mut OutputConvert, frm: NAFrameRef, scale_opts: &[(String, String)]) -> bool {
-    let buf = frm.get_buffer();
-    let cbuf = if let NABufferType::None = buf {
-        if (encoder.get_capabilities() & ENC_CAPS_SKIPFRAME) == 0 {
-            match cvt {
-                OutputConvert::Video(_, ref mut dbuf) => dbuf.clone(),
-                _ => {
-                    println!("encoder does not support skip frames, skipping");
-                    return true;
-                },
-            }
-        } else {
-            buf
-        }
-    } else {
-        match cvt {
-            OutputConvert::None => buf,
-            OutputConvert::Video(ref mut scaler, ref mut dbuf) => {
-                let cur_ifmt = get_scale_fmt_from_pic(&buf);
-                let last_ifmt = scaler.get_in_fmt();
-                if cur_ifmt != last_ifmt {
-                    let ofmt = scaler.get_out_fmt();
-                    let ret = NAScale::new_with_options(cur_ifmt, ofmt, scale_opts);
-                    if ret.is_err() {
-                        println!("error re-initialising scaler for {} -> {}", cur_ifmt, ofmt);
-                        return false;
-                    }
-                    *scaler = ret.unwrap();
-                }
-                let ret = scaler.convert(&buf, dbuf);
-                if ret.is_err() {
-                    println!("error converting frame for encoding");
-                    return false;
-                }
-                dbuf.clone()
-            },
-            OutputConvert::Audio(ref mut acvt) => {
-                if !acvt.queue_frame(buf, frm.get_time_information()) {
-                    println!("error converting audio for stream {}", dst_id);
-                    return false;
-                }
-                return true;
-            },
-        }
-    };
-    let cfrm = NAFrame::new(frm.get_time_information(), frm.frame_type, frm.key, frm.get_info(), cbuf);
-    encoder.encode(&cfrm).unwrap();
-    true
-}
-
 macro_rules! next_arg {
     ($args: expr, $arg_idx: expr) => {
         if $arg_idx + 1 >= $args.len() {
@@ -459,14 +408,14 @@ fn main() {
             let decfunc = full_reg.dec_reg.find_decoder(info.get_name());
             let str_id = (s.get_num() + is_off) as u32;
             if let Some(create_dec) = decfunc {
-                let mut dec = (create_dec)();
+                let mut decoder = (create_dec)();
                 let mut dsupp = Box::new(NADecoderSupport::new());
-                let ret = dec.init(&mut dsupp, info.clone());
+                let ret = decoder.init(&mut dsupp, info.clone());
                 if ret.is_err() {
                     println!("Error initialising decoder '{}' for stream {}", info.get_name(), str_id);
                     return;
                 }
-                transcoder.apply_decoder_options(dec.as_mut(), str_id);
+                transcoder.apply_decoder_options(decoder.as_mut(), str_id);
                 let desc = register::get_codec_description(info.get_name());
                 let has_b = if let Some(desc) = desc {
                         desc.has_reorder()
@@ -474,8 +423,8 @@ fn main() {
                         println!("No codec description found, using B-frame reorderer.");
                         true
                     };
-                let reord: Box<dyn FrameReorderer> = if has_b { Box::new(IPBReorderer::new()) } else { Box::new(NoReorderer::new()) };
-                transcoder.decoders.push(Some((dsupp, dec, reord)));
+                let reorderer: Box<dyn FrameReorderer> = if has_b { Box::new(IPBReorderer::new()) } else { Box::new(NoReorderer::new()) };
+                transcoder.decoders.push(Some(DecodeContext{ dsupp, decoder, reorderer }));
             } else {
                 println!("No decoder for stream {} ({}) is found", str_id, info.get_name());
                 transcoder.decoders.push(None);
@@ -639,9 +588,9 @@ fn main() {
                 };
                 transcoder.queue.queue_packet(pkt);
             },
-            OutputMode::Encode(dst_id, ref mut encoder, ref mut cvt) => {
-                if let Some((ref mut dsupp, ref mut decoder, ref mut reorderer)) = transcoder.decoders[src_id] {
-                    let ret = decoder.decode(dsupp, &pkt);
+            OutputMode::Encode(dst_id, ref mut encoder) => {
+                if let Some(ref mut dec_ctx) = transcoder.decoders[src_id] {
+                    let ret = dec_ctx.decoder.decode(&mut dec_ctx.dsupp, &pkt);
                     if let (true, Err(DecoderError::MissingReference)) = (transcoder.start != NATimePoint::None, &ret) {
                         continue;
                     }
@@ -650,10 +599,9 @@ fn main() {
                         break;
                     }
                     let frm = ret.unwrap();
-                    let tinfo = frm.get_info();
-                    reorderer.add_frame(frm);
-                    while let Some(frm) = reorderer.get_frame() {
-                        if !encode_frame(dst_id, encoder, cvt, frm, &transcoder.scale_opts) {
+                    dec_ctx.reorderer.add_frame(frm);
+                    while let Some(frm) = dec_ctx.reorderer.get_frame() {
+                        if !encoder.encode_frame(dst_id, frm, &transcoder.scale_opts) {
                             break;
                         }
                         while let Ok(Some(pkt)) = encoder.get_packet() {
@@ -667,19 +615,6 @@ fn main() {
                             transcoder.queue.queue_packet(pkt);
                         }
                     }
-                    if let OutputConvert::Audio(ref mut acvt) = cvt {
-                        while let Some(ofrm) = acvt.get_frame(tinfo.clone()) {
-                            if encoder.encode(&ofrm).is_err() {
-                                break;
-                            }
-                            while let Ok(Some(pkt)) = encoder.get_packet() {
-                                if transcoder.end != NATimePoint::None && !pkt.ts.less_than(transcoder.end) { break 'main_loop; }
-                                let pkt_size = pkt.get_buffer().len();
-                                adata_size += pkt_size;
-                                transcoder.queue.queue_packet(pkt);
-                            }
-                        }
-                    }
                 } else {
                     println!("no decoder for stream {}", src_id);
                     break;
@@ -696,10 +631,10 @@ fn main() {
     }
     'reord_flush_loop: for stream in ism.iter() {
         let src_id = stream.get_num();
-        if let OutputMode::Encode(dst_id, ref mut encoder, ref mut cvt) = transcoder.encoders[src_id] {
-            if let Some((_, _, ref mut reorderer)) = transcoder.decoders[src_id] {
-                while let Some(frm) = reorderer.get_last_frames() {
-                    if !encode_frame(dst_id, encoder, cvt, frm, &transcoder.scale_opts) {
+        if let OutputMode::Encode(dst_id, ref mut encoder) = transcoder.encoders[src_id] {
+            if let Some(ref mut dec_ctx) = transcoder.decoders[src_id] {
+                while let Some(frm) = dec_ctx.reorderer.get_last_frames() {
+                    if !encoder.encode_frame(dst_id, frm, &transcoder.scale_opts) {
                         break;
                     }
                     while let Ok(Some(pkt)) = encoder.get_packet() {
@@ -712,15 +647,11 @@ fn main() {
     }
     /*'flush_loop:*/ for enc in transcoder.encoders.iter_mut() {
         match enc {
-            OutputMode::Encode(str_id, ref mut encoder, _) => {
-                let ret = encoder.flush();
+            OutputMode::Encode(str_id, ref mut encoder) => {
+                let ret = encoder.flush(&mut transcoder.queue);
                 if ret.is_err() {
                     println!("error flushing encoder for stream {}", str_id);
                     break;
-                } else {
-                    while let Ok(Some(pkt)) = encoder.get_packet() {
-                        transcoder.queue.queue_packet(pkt);
-                    }
                 }
             },
             _ => {},
index f57d45e8ca523ccc43305f1ab56361809eb62275..32eb2b5663596a0bf98c0fa488a2f2272bb1a7ca 100644 (file)
@@ -114,17 +114,121 @@ pub struct OutputStreamOptions {
     pub enc_opts:       Vec<OptionArgs>,
 }
 
-pub enum OutputConvert {
-    Video(NAScale, NABufferType),
-    Audio(AudioConverter),
-    None,
+pub struct DecodeContext {
+    pub dsupp:      Box<NADecoderSupport>,
+    pub decoder:    Box<dyn NADecoder>,
+    pub reorderer:  Box<dyn FrameReorderer>,
+}
+
+pub trait EncoderInterface {
+    fn encode_frame(&mut self, dst_id: u32, frm: NAFrameRef, scale_opts: &[(String, String)]) -> bool;
+    fn flush(&mut self, queue: &mut OutputQueue) -> EncoderResult<()>;
+    fn get_packet(&mut self) -> EncoderResult<Option<NAPacket>>;
+}
+
+pub struct AudioEncodeContext {
+    pub encoder:    Box<dyn NAEncoder>,
+    pub cvt:        Option<AudioConverter>,
+}
+
+impl EncoderInterface for AudioEncodeContext {
+    fn encode_frame(&mut self, dst_id: u32, frm: NAFrameRef, _scale_opts: &[(String, String)]) -> bool {
+        let buf = frm.get_buffer();
+        let cbuf = if let NABufferType::None = buf {
+                buf
+            } else if let Some(ref mut acvt) = self.cvt {
+                if !acvt.queue_frame(buf, frm.get_time_information()) {
+                    println!("error converting audio for stream {}", dst_id);
+                    return false;
+                }
+
+                while let Some(ofrm) = acvt.get_frame(frm.get_info().clone()) {
+                    if self.encoder.encode(&ofrm).is_err() {
+                        return false;
+                    }
+                }
+
+                return true;
+            } else {
+                buf
+            };
+        let cfrm = NAFrame::new(frm.get_time_information(), frm.frame_type, frm.key, frm.get_info(), cbuf);
+        self.encoder.encode(&cfrm).unwrap();
+        true        
+    }
+    fn flush(&mut self, queue: &mut OutputQueue) -> EncoderResult<()> {
+        self.encoder.flush()?;
+        while let Ok(Some(pkt)) = self.encoder.get_packet() {
+            queue.queue_packet(pkt);
+        }
+        Ok(())
+    }
+    fn get_packet(&mut self) -> EncoderResult<Option<NAPacket>> {
+        self.encoder.get_packet()
+    }
+}
+
+pub struct VideoEncodeContext {
+    pub encoder:    Box<dyn NAEncoder>,
+    pub scaler:     Option<NAScale>,
+    pub scaler_buf: NABufferType,
+}
+
+impl EncoderInterface for VideoEncodeContext {
+    fn encode_frame(&mut self, dst_id: u32, frm: NAFrameRef, scale_opts: &[(String, String)]) -> bool {
+        let buf = frm.get_buffer();
+        let cbuf = if let NABufferType::None = buf {
+            if (self.encoder.get_capabilities() & ENC_CAPS_SKIPFRAME) == 0 {
+                if let NABufferType::None = self.scaler_buf {
+                    println!("encoder does not support skip frames, skipping");
+                    return true;
+                } else {
+                    self.scaler_buf.clone()
+                }
+            } else {
+                buf
+            }
+        } else if let Some(ref mut scaler) = self.scaler {
+            let cur_ifmt = get_scale_fmt_from_pic(&buf);
+            let last_ifmt = scaler.get_in_fmt();
+            if cur_ifmt != last_ifmt {
+                let ofmt = scaler.get_out_fmt();
+                let ret = NAScale::new_with_options(cur_ifmt, ofmt, scale_opts);
+                if ret.is_err() {
+                    println!("error re-initialising scaler for {} -> {}", cur_ifmt, ofmt);
+                    return false;
+                }
+                *scaler = ret.unwrap();
+            }
+            let ret = scaler.convert(&buf, &mut self.scaler_buf);
+            if ret.is_err() {
+                println!("error converting frame for encoding stream {dst_id}");
+                return false;
+            }
+            self.scaler_buf.clone()
+        } else {
+            buf
+        };
+        let cfrm = NAFrame::new(frm.get_time_information(), frm.frame_type, frm.key, frm.get_info(), cbuf);
+        self.encoder.encode(&cfrm).unwrap();
+        true
+    }
+    fn flush(&mut self, queue: &mut OutputQueue) -> EncoderResult<()> {
+        self.encoder.flush()?;
+        while let Ok(Some(pkt)) = self.encoder.get_packet() {
+            queue.queue_packet(pkt);
+        }
+        Ok(())
+    }
+    fn get_packet(&mut self) -> EncoderResult<Option<NAPacket>> {
+        self.encoder.get_packet()
+    }
 }
 
-#[allow(clippy::large_enum_variant)]
 pub enum OutputMode {
     Drop,
     Copy(u32),
-    Encode(u32, Box<dyn NAEncoder>, OutputConvert),
+    Encode(u32, Box<dyn EncoderInterface>),
 }
 
 #[derive(Default)]
@@ -203,7 +307,7 @@ pub struct Transcoder {
     pub istr_opts:      Vec<InputStreamOptions>,
     pub ostr_opts:      Vec<OutputStreamOptions>,
     pub scale_opts:     Vec<(String, String)>,
-    pub decoders:       Vec<Option<(Box<NADecoderSupport>, Box<dyn NADecoder>, Box<dyn FrameReorderer>)>>,
+    pub decoders:       Vec<Option<DecodeContext>>,
     pub encoders:       Vec<OutputMode>,
     pub no_video:       bool,
     pub no_audio:       bool,
@@ -599,11 +703,26 @@ impl Transcoder {
                 }
                 let ret_eparams = ret_eparams.unwrap();
 
+                let name = format!("output stream {}", out_id);
+                parse_and_apply_options!(encoder, &oopts.enc_opts, name);
+
+                if self.calc_len && self.nframes.len() > iidx {
+                    encoder.set_options(&[NAOption{name: "nframes", value: NAValue::Int(self.nframes[iidx] as i64)}]);
+                }
+
+                let ret = encoder.init(out_id, ret_eparams);
+                if ret.is_err() {
+                    println!("error initialising encoder");
+                    return RegisterResult::Failed;
+                }
+
+                parse_and_apply_options!(encoder, &oopts.enc_opts, name);
+
 //todo check for params mismatch
-                let cvt = match (&iformat, &ret_eparams.format) {
+                let enc_ctx: Box<dyn EncoderInterface> = match (&iformat, &ret_eparams.format) {
                         (NACodecTypeInfo::Video(svinfo), NACodecTypeInfo::Video(dvinfo)) => {
                             if svinfo == dvinfo && !forced_out {
-                                OutputConvert::None
+                                Box::new(VideoEncodeContext { encoder, scaler: None, scaler_buf: NABufferType::None })
                             } else {
                                 let ofmt = ScaleInfo { fmt: dvinfo.format, width: dvinfo.width, height: dvinfo.height };
                                 let ret = NAScale::new_with_options(ofmt, ofmt, &self.scale_opts);
@@ -617,14 +736,14 @@ impl Transcoder {
                                     println!("cannot create scaler buffer");
                                     return RegisterResult::Failed;
                                 }
-                                let cvt_buf = ret.unwrap();
-                                OutputConvert::Video(scaler, cvt_buf)
+                                let scaler_buf = ret.unwrap();
+                                Box::new(VideoEncodeContext { encoder, scaler: Some(scaler), scaler_buf })
                             }
                         },
                         (NACodecTypeInfo::Audio(sainfo), NACodecTypeInfo::Audio(dainfo)) => {
                             let icodec = istr.get_info().get_name();
                             if (sainfo == dainfo) && (icodec != "pcm" || oopts.enc_name.as_str() == "pcm") {
-                                OutputConvert::None
+                                Box::new(AudioEncodeContext { encoder, cvt: None })
                             } else {
                                 let dchmap = match dainfo.channels {
                                         1 => NAChannelMap::from_ms_mapping(0x4),
@@ -636,28 +755,13 @@ println!("can't generate default channel map for {} channels", dainfo.channels);
                                     };
                                 let acvt = AudioConverter::new(sainfo, dainfo, dchmap);
 //todo channelmap
-                                OutputConvert::Audio(acvt)
+                                Box::new(AudioEncodeContext { encoder, cvt: Some(acvt) })
                             }
                         },
-                        _ => OutputConvert::None,
+                        _ => unreachable!(),
                     };
-                let name = format!("output stream {}", out_id);
-                parse_and_apply_options!(encoder, &oopts.enc_opts, name);
-
-                if self.calc_len && self.nframes.len() > iidx {
-                    encoder.set_options(&[NAOption{name: "nframes", value: NAValue::Int(self.nframes[iidx] as i64)}]);
-                }
-
-                let ret = encoder.init(out_id, ret_eparams);
-                if ret.is_err() {
-                    println!("error initialising encoder");
-                    return RegisterResult::Failed;
-                }
                 out_sm.add_stream_ref(ret.unwrap());
-
-                parse_and_apply_options!(encoder, &oopts.enc_opts, name);
-
-                self.encoders.push(OutputMode::Encode(out_id, encoder, cvt));
+                self.encoders.push(OutputMode::Encode(out_id, enc_ctx));
             } else {
 println!("encoder {} is not supported by output (expected {})", istr.id, istr.get_info().get_name());
                 return RegisterResult::Failed;
@@ -689,11 +793,16 @@ println!("encoder {} is not supported by output (expected {})", istr.id, istr.ge
             }
             let ret_eparams = ret_eparams.unwrap();
 
+            let ret = encoder.init(out_id, ret_eparams);
+            if ret.is_err() {
+                println!("error initialising encoder");
+                return RegisterResult::Failed;
+            }
 //todo check for params mismatch
-            let cvt = match (&oopts.enc_params.format, &ret_eparams.format) {
+            let enc_ctx: Box<dyn EncoderInterface> = match (&oopts.enc_params.format, &ret_eparams.format) {
                     (NACodecTypeInfo::Video(svinfo), NACodecTypeInfo::Video(dvinfo)) => {
                         if svinfo == dvinfo {
-                            OutputConvert::None
+                            Box::new(VideoEncodeContext { encoder, scaler: None, scaler_buf: NABufferType::None })
                         } else {
                             let ofmt = ScaleInfo { fmt: dvinfo.format, width: dvinfo.width, height: dvinfo.height };
                             let ret = NAScale::new_with_options(ofmt, ofmt, &self.scale_opts);
@@ -707,13 +816,13 @@ println!("encoder {} is not supported by output (expected {})", istr.id, istr.ge
                                 println!("cannot create scaler buffer");
                                 return RegisterResult::Failed;
                             }
-                            let cvt_buf = ret.unwrap();
-                            OutputConvert::Video(scaler, cvt_buf)
+                            let scaler_buf = ret.unwrap();
+                            Box::new(VideoEncodeContext { encoder, scaler: Some(scaler), scaler_buf })
                         }
                     },
                     (NACodecTypeInfo::Audio(sainfo), NACodecTypeInfo::Audio(dainfo)) => {
                         if sainfo == dainfo {
-                            OutputConvert::None
+                            Box::new(AudioEncodeContext { encoder, cvt: None })
                         } else {
                             let dchmap = match dainfo.channels {
                                     1 => NAChannelMap::from_ms_mapping(0x4),
@@ -725,18 +834,13 @@ println!("can't generate default channel map for {} channels", dainfo.channels);
                                 };
 //todo channelmap
                             let acvt = AudioConverter::new(sainfo, dainfo, dchmap);
-                            OutputConvert::Audio(acvt)
+                            Box::new(AudioEncodeContext { encoder, cvt: Some(acvt) })
                         }
                     },
-                    _ => OutputConvert::None,
+                    _ => unreachable!(),
                 };
-            let ret = encoder.init(out_id, ret_eparams);
-            if ret.is_err() {
-                println!("error initialising encoder");
-                return RegisterResult::Failed;
-            }
             out_sm.add_stream_ref(ret.unwrap());
-            self.encoders.push(OutputMode::Encode(out_id, encoder, cvt));
+            self.encoders.push(OutputMode::Encode(out_id, enc_ctx));
             self.ostr_opts.push(oopts);
         }
         RegisterResult::Ok