]> git.nihav.org Git - nihav-encoder.git/commitdiff
add global palette calculation feature/hack
authorKostya Shishkov <kostya.shishkov@gmail.com>
Wed, 15 Apr 2026 20:01:37 +0000 (22:01 +0200)
committerKostya Shishkov <kostya.shishkov@gmail.com>
Wed, 15 Apr 2026 20:01:37 +0000 (22:01 +0200)
Cargo.toml
src/main.rs
src/palettise.rs
src/transcoder.rs

index 1bf45dc7b3ed72e21a767124f8f79dedd8047031..3885990d1f645681372446c04d174cb3f046448b 100644 (file)
@@ -6,7 +6,7 @@ edition = "2018"
 
 [dependencies]
 nihav_core = { path="../nihav-core" }
-nihav_codec_support = { path="../nihav-codec-support" }
+nihav_codec_support = { path="../nihav-codec-support", features=["vq"] }
 nihav_registry = { path="../nihav-registry" }
 nihav_allstuff = { path="../nihav-allstuff" }
 
index 9dee3199e6351caed097d6b8f874eeab1861b893..9bdde8c1e5990e60298926be89ab7a808f4a948c 100644 (file)
@@ -300,6 +300,9 @@ fn main() {
                 }
                 printed_info = true;
             },
+            "--generate-palette" => {
+                transcoder.gen_pal = true;
+            },
             "--output" | "-o" => {
                 next_arg!(args, arg_idx);
                 transcoder.output_name = args[arg_idx].clone();
@@ -451,7 +454,7 @@ fn main() {
     }
 
     let mut demuxers = Vec::with_capacity(1);
-    if !transcoder.create_demuxers(&mut demuxers, &full_reg, true) {
+    if !transcoder.create_demuxers(&mut demuxers, &full_reg, true, true) {
         return;
     }
 
@@ -575,13 +578,157 @@ fn main() {
         }
         // this is necessary since not all demuxers allow to seek even back to the start
         demuxers.clear();
-        if !transcoder.create_demuxers(&mut demuxers, &full_reg, false) {
+        if !transcoder.create_demuxers(&mut demuxers, &full_reg, false, false) {
             println!("failed to re-create demuxer(s)");
             transcoder.debug_log(DebugLog::GENERAL, "failed to re-create demuxer(s)");
             return;
         }
     }
 
+    if transcoder.gen_pal {
+        transcoder.debug_log(DebugLog::GENERAL, "Calculating global palette(s)");
+
+        let mut nenc = 0;
+        let mut plt_map = vec![0; ism.get_num_streams()];
+        let mut rev_map = vec![0; ism.get_num_streams()];
+        for (pmap, strm) in plt_map.iter_mut().zip(ism.iter()) {
+            if strm.get_media_type() == StreamType::Video {
+                *pmap = nenc;
+                rev_map.push(strm.get_num());
+                nenc += 1;
+            }
+        }
+        if nenc > 0 {
+            let mut ccounters = Vec::with_capacity(nenc);
+            for _ in 0..nenc {
+                ccounters.push(crate::palettise::ColourCounter::new());
+            }
+            let mut cur_dmx = 0;
+            let mut last_known_time = None;
+            let mut time = Instant::now();
+            let show_interval = Duration::from_millis(100);
+            'dec_loop: loop {
+                let mut pktres = Err(DemuxerError::EOF);
+                let mut src_dmx = 0;
+                loop {
+                    if !demuxers.iter().any(|(_, eof)| !eof) {
+                        break;
+                    }
+                    let mut got_res = false;
+                    if !demuxers[cur_dmx].1 {
+                        pktres = demuxers[cur_dmx].0.get_frame();
+                        got_res = true;
+                        src_dmx = cur_dmx;
+                    }
+                    cur_dmx += 1;
+                    if cur_dmx >= demuxers.len() {
+                        cur_dmx = 0;
+                    }
+                    if got_res {
+                        break;
+                    }
+                }
+
+                if let Err(DemuxerError::EOF) = pktres { break; }
+                if pktres.is_err() {
+                    println!("demuxing error");
+                    transcoder.debug_log(DebugLog::DEMUX, &format!("Demuxing error: {:?}", pktres.err().unwrap()));
+                    break;
+                }
+                let mut pkt = pktres.unwrap();
+                if pkt.get_stream().get_media_type() != StreamType::Video { continue; }
+                if transcoder.start != NATimePoint::None && pkt.ts.less_than(transcoder.start) { continue; }
+                let src_id = pkt.get_stream().get_num() + is_offset[src_dmx];
+                let ts = pkt.ts;
+                if let Some(pts) = ts.pts {
+                    last_known_time = Some(NATimeInfo::rescale_ts(pts, pkt.ts.tb_num, pkt.ts.tb_den, 1, 1000));
+                }
+                let newstream = ism.get_stream(src_id).unwrap();
+                pkt.reassign(newstream, ts);
+                transcoder.debug_log(DebugLog::DEMUX, &format!(" Got packet from input stream {src_id} ts {ts:?}"));
+
+                if let Some(ref mut dec_ctx) = transcoder.decoders[src_id] {
+                    let ret = dec_ctx.decoder.decode(&mut dec_ctx.dsupp, &pkt);
+                    if let Some(ref mut dbg) = transcoder.debug {
+                        match &ret {
+                            Ok(_) => dbg.log(DebugLog::DECODE, &format!(" Packet from stream {src_id} decoded")),
+                            Err(err) => dbg.log(DebugLog::DECODE, &format!(" Packet from stream {src_id} decode error {err:?}")),
+                        }
+                    }
+                    if let (true, Err(DecoderError::MissingReference)) = (transcoder.start != NATimePoint::None, &ret) {
+                        continue;
+                    }
+                    if ret.is_err() {
+                        println!("error decoding stream {}", src_id);
+                        if !ignerr {
+                            break;
+                        } else {
+                            continue;
+                        }
+                    }
+                    let frm = ret.unwrap();
+                    dec_ctx.reorderer.add_frame(frm);
+                    let ccount = &mut ccounters[plt_map[src_id]];
+                    while let Some(frm) = dec_ctx.reorderer.get_frame() {
+                        if transcoder.end != NATimePoint::None && !frm.ts.less_than(transcoder.end) { break 'dec_loop; }
+                        if let Err(err) = ccount.add_frame(frm.get_buffer()) {
+                            println!("palettiser error {err}");
+                        }
+                    }
+                } else {
+                    println!("no decoder for stream {}", src_id);
+                    transcoder.debug_log(DebugLog::DECODE, &format!("No decoder for input stream {src_id}"));
+                    break;
+                }
+                if transcoder.verbose > 0 && time.elapsed() >= show_interval {
+                    if let Some(time) = last_known_time {
+                        print!(" palettisation {}", format_time(time));
+                    } else {
+                        print!(" ???");
+                    }
+                    if !duration_string.is_empty() {
+                        print!(" / {}", duration_string);
+                    }
+                    print!("\r");
+                    std::io::stdout().flush().unwrap();
+                    time = Instant::now();
+                }
+            }
+            transcoder.debug_log(DebugLog::GENERAL, "Demuxing done, flushing leftover frames");
+            'reord_flush_loop: for stream in ism.iter() {
+                let src_id = stream.get_num();
+                if let Some(ref mut dec_ctx) = transcoder.decoders[src_id] {
+                    let ccount = &mut ccounters[plt_map[src_id]];
+                    while let Some(frm) = dec_ctx.reorderer.get_last_frames() {
+                        if transcoder.end != NATimePoint::None && !frm.ts.less_than(transcoder.end) { break 'reord_flush_loop; }
+                        if let Err(err) = ccount.add_frame(frm.get_buffer()) {
+                            println!("palettiser error {err}");
+                        }
+                    }
+                }
+            }
+            for (&id, plt) in rev_map.iter().zip(ccounters.iter()) {
+                let pal = plt.get_pal();
+                transcoder.qsupport.glbl_pal.push((id, pal));
+            }
+            if transcoder.verbose > 0 {
+                println!();
+            }
+            transcoder.debug_log(DebugLog::GENERAL, "Resetting state after palettisation");
+            // this is necessary since not all demuxers allow to seek even back to the start
+            demuxers.clear();
+            if !transcoder.create_demuxers(&mut demuxers, &full_reg, false, false) {
+                println!("failed to re-create demuxer(s)");
+                transcoder.debug_log(DebugLog::GENERAL, "failed to re-create demuxer(s)");
+                return;
+            }
+            transcoder.decoders.clear();
+            if !transcoder.create_decoders(&full_reg, &is_offset, &mut demuxers, ignerr) {
+                return;
+            }
+        }
+    }
+
     let mux_caps = mux_creator.get_capabilities();
     let mut out_sm = StreamManager::new();
     if !transcoder.negotiate_stream_map(&ism, mux_caps, &mut out_sm, &full_reg.enc_reg) {
index c1b75c31513eb1a25066e479f57218f68bcfa5cb..9889ee16b09680af62321186c4ad5a0b3d105fa7 100644 (file)
@@ -1,5 +1,10 @@
+use std::cmp::*;
+use std::convert::TryInto;
+
 use nihav_core::frame::*;
+use nihav_core::scale::{NAScale, ScaleInfo, get_scale_fmt_from_pic};
 use nihav_codec_support::codecs::qt_pal::*;
+use nihav_codec_support::vq::*;
 
 use crate::transcoder::OptionArgs;
 
@@ -400,3 +405,275 @@ pub fn create_palettiser(enc_opts: &[OptionArgs]) -> Option<Palettiser> {
         None
     }
 }
+
+#[derive(Clone,Copy,Default)]
+struct Colour {
+    clr:    [u8; 3],
+    count:  u64,
+}
+
+impl Colour {
+    fn new(clr: [u8; 3]) -> Self {
+        Self { clr, count: 1 }
+    }
+}
+
+impl std::cmp::PartialEq for Colour {
+    fn eq(&self, other: &Colour) -> bool {
+        self.clr == other.clr
+    }
+}
+
+impl VQElement for Colour {
+    fn dist(&self, rval: Self) -> u32 {
+        let rd = u32::from(self.clr[0].abs_diff(rval.clr[0]));
+        let gd = u32::from(self.clr[1].abs_diff(rval.clr[1]));
+        let bd = u32::from(self.clr[2].abs_diff(rval.clr[2]));
+        rd * rd + gd * gd + bd * bd
+    }
+    fn min_cw() -> Self { Colour::new([0x00; 3]) }
+    fn max_cw() -> Self { Colour::new([0xFF; 3]) }
+    fn min(&self, rval: Self) -> Self {
+        Colour::new([self.clr[0].min(rval.clr[0]),
+                     self.clr[1].min(rval.clr[1]),
+                     self.clr[2].min(rval.clr[2])])
+    }
+    fn max(&self, rval: Self) -> Self {
+        Colour::new([self.clr[0].max(rval.clr[0]),
+                     self.clr[1].max(rval.clr[1]),
+                     self.clr[2].max(rval.clr[2])])
+    }
+    fn num_components() -> usize { 3 }
+    fn sort_by_component(arr: &mut [Self], component: usize) {
+        arr.sort_unstable_by(|a, b| a.clr[component].cmp(&b.clr[component]));
+    }
+    fn max_dist_component(min: &Self, max: &Self) -> usize {
+        let rd = u32::from(min.clr[0].abs_diff(max.clr[0]));
+        let gd = u32::from(min.clr[1].abs_diff(max.clr[1]));
+        let bd = u32::from(min.clr[2].abs_diff(max.clr[2]));
+        if gd >= rd && gd >= bd {
+            1
+        } else if rd >= gd && rd >= bd {
+            0
+        } else if bd >= rd && bd >= gd {
+            2
+        } else {
+            1
+        }
+    }
+}
+
+#[derive(Default)]
+struct ColourSum {
+    clr:    [u64; 3],
+    tot:    u64,
+}
+
+impl VQElementSum<Colour> for ColourSum {
+    fn zero() -> Self { Self::default() }
+    fn add(&mut self, rval: Colour, _count: u64) {
+        for (dst, &src) in self.clr.iter_mut().zip(rval.clr.iter()) {
+            *dst += u64::from(src) * rval.count;
+        }
+        self.tot += rval.count;
+    }
+    fn get_centroid(&self) -> Colour {
+        if self.tot > 0 {
+            Colour {
+                clr: [(self.clr[0] / self.tot) as u8,
+                      (self.clr[1] / self.tot) as u8,
+                      (self.clr[2] / self.tot) as u8],
+                count: self.tot,
+            }
+        } else {
+            Colour::default()
+        }
+    }
+}
+
+const BUCKET_BITS: u8 = 4;
+const NBUCKETS: usize = 1 << (3 * (8 - BUCKET_BITS));
+const AVG_BITS: u8 = 2;
+const RESCALE_LIMIT: usize = 1 << (3 * (8 - AVG_BITS));
+
+#[derive(Clone,Default)]
+struct Bucket {
+    clrs:   Vec<Colour>,
+}
+
+impl Bucket {
+    fn key(clr: [u8; 3]) -> usize {
+        (usize::from(clr[0] >> (8 - BUCKET_BITS)) << (2 * BUCKET_BITS)) |
+        (usize::from(clr[1] >> (8 - BUCKET_BITS)) << BUCKET_BITS) |
+         usize::from(clr[2] >> (8 - BUCKET_BITS))
+    }
+    fn ileave(clr: [u8; 3]) -> u32 {
+        const UNP3: [u32; 16] = [
+            0x000, 0x001, 0x008, 0x009, 0x040, 0x041, 0x048, 0x049,
+            0x200, 0x201, 0x208, 0x209, 0x240, 0x241, 0x248, 0x249
+        ];
+        (UNP3[usize::from(clr[0] >> 4)] << 14) |
+        (UNP3[usize::from(clr[0] & 0xF)] << 2) |
+        (UNP3[usize::from(clr[1] >> 4)] << 13) |
+        (UNP3[usize::from(clr[1] & 0xF)] << 1) |
+        (UNP3[usize::from(clr[2] >> 4)] << 12) |
+         UNP3[usize::from(clr[2] & 0xF)]
+    }
+    fn clr_diff(a: [u8; 3], b: [u8; 3]) -> u8 {
+        (a[0] ^ b[0]) | (a[1] ^ b[1]) | (a[2] ^ b[2])
+    }
+    fn average(clrs: &[Colour]) -> Colour {
+        // xxx: maybe simply reuse ColourSum?
+        let maxsum = clrs.iter().fold(0u64, |acc, el| acc + el.count);
+        let mut sum = 0;
+        let mut acc = [0u64; 3];
+        const MASK: u8 = (1 << AVG_BITS) - 1;
+        for clr in clrs.iter() {
+            let weight = clr.count;
+            acc[0] += u64::from(clr.clr[0] & MASK) * weight;
+            acc[1] += u64::from(clr.clr[1] & MASK) * weight;
+            acc[2] += u64::from(clr.clr[2] & MASK) * weight;
+            sum += weight;
+        }
+        let new_clr = std::array::from_fn(|i| (clrs[0].clr[i] & !MASK) | ((acc[i] / sum) as u8));
+        Colour { clr: new_clr, count: maxsum }
+    }
+
+    fn get_num_entries(&self) -> usize { self.clrs.len() }
+    fn add(&mut self, clr: [u8; 3]) -> bool {
+        for el in self.clrs.iter_mut() {
+            if el.clr == clr {
+                el.count += 1;
+                return false;
+            }
+        }
+        self.clrs.push(Colour::new(clr));
+        true
+    }
+    fn rescale(&mut self) {
+        if self.clrs.is_empty() {
+            return;
+        }
+        self.clrs.sort_unstable_by(|a, b| Self::ileave(a.clr).cmp(&Self::ileave(b.clr)));
+        let mut new_clrs = Vec::with_capacity(self.clrs.len() / 4);
+        let mut start_idx = 0;
+        while start_idx < self.clrs.len() {
+            let mut end_idx = start_idx;
+            while end_idx < self.clrs.len() {
+                let diff = Self::clr_diff(self.clrs[start_idx].clr, self.clrs[end_idx].clr);
+                if diff >= (1 << AVG_BITS) {
+                    break;
+                }
+                end_idx += 1;
+            }
+            new_clrs.push(Self::average(&self.clrs[start_idx..end_idx]));
+            start_idx = end_idx;
+        }
+        self.clrs = new_clrs;
+    }
+}
+
+pub struct ColourCounter {
+    buckets:    Vec<Bucket>,
+    nentries:   usize,
+    scaler:     Option<NAScale>,
+    sc_buf:     NABufferType,
+    oinfo:      NAVideoInfo,
+}
+
+impl ColourCounter {
+    pub fn new() -> Self {
+        Self {
+            buckets:    vec![Bucket::default(); NBUCKETS],
+            nentries:   0,
+            scaler:     None,
+            sc_buf:     NABufferType::None,
+            oinfo:      NAVideoInfo{ width: 0, height: 0, flipped: false, format: RGB24_FORMAT, bits: 24 },
+        }
+    }
+    pub fn add(&mut self, clr: [u8; 3]) {
+        let idx = Bucket::key(clr);
+        if self.buckets[idx].add(clr) {
+            self.nentries += 1;
+            if self.nentries > RESCALE_LIMIT {
+                self.nentries = 0;
+                for bucket in self.buckets.iter_mut() {
+                    if bucket.clrs.len() > RESCALE_LIMIT / NBUCKETS / 2 {
+                        bucket.rescale();
+                    }
+                    self.nentries += bucket.get_num_entries();
+                }
+            }
+        }
+    }
+    fn add_rgb24(&mut self, buf: &NAVideoBuffer<u8>, width: usize, height: usize) {
+        let src = buf.get_data();
+        let stride = buf.get_stride(0);
+        for line in src.chunks_exact(stride).take(height) {
+            for pix in line.chunks_exact(3).take(width) {
+                let clr = pix.try_into().unwrap();
+                self.add(clr);
+            }
+        }
+    }
+    pub fn add_frame(&mut self, buf: NABufferType) -> Result<(), &'static str> {
+        if matches!(buf, NABufferType::None) {
+            return Ok(());
+        }
+        if let Some(vinfo) = buf.get_video_info() {
+            if vinfo.format == RGB24_FORMAT {
+                self.add_rgb24(&buf.get_vbuf().unwrap(), vinfo.width, vinfo.height);
+            } else {
+                if vinfo.width != self.oinfo.width || vinfo.height != self.oinfo.height || self.scaler.is_none() {
+                    let ifmt = get_scale_fmt_from_pic(&buf);
+                    if self.oinfo.width == 0 {
+                        self.oinfo.width  = vinfo.width;
+                        self.oinfo.height = vinfo.height;
+                    }
+                    let ofmt = ScaleInfo {
+                            width:  self.oinfo.width,
+                            height: self.oinfo.height,
+                            fmt:    RGB24_FORMAT,
+                        };
+                    self.scaler = Some(NAScale::new(ifmt, ofmt).map_err(|_| "error creating rgb24 scaler")?);
+                    self.sc_buf = alloc_video_buffer(self.oinfo, 0).map_err(|_| "error creating rgb24 scale buf")?;
+                }
+                if let Some(ref mut scaler) = self.scaler {
+                    scaler.convert(&buf, &mut self.sc_buf).map_err(|_| "rgb24 scaling failed")?;
+                    self.add_rgb24(&self.sc_buf.get_vbuf().unwrap(), self.oinfo.width, self.oinfo.height);
+                } else {
+                    return Err("rgb24 scaler not present!");
+                }
+            }
+            Ok(())
+        } else {
+            Err("not a video frame")
+        }
+    }
+    pub fn get_pal(&self) -> [[u8; 3]; 256] {
+        let mut pal = [[0; 3]; 256];
+        if self.nentries <= 256 {
+            let mut dst = pal.iter_mut();
+            for bucket in self.buckets.iter() {
+                for clr in bucket.clrs.iter() {
+                    *dst.next().unwrap() = clr.clr;
+                }
+            }
+        } else {
+            let mut in_clrs = Vec::with_capacity(self.nentries);
+            for bucket in self.buckets.iter() {
+                for clr in bucket.clrs.iter() {
+                    in_clrs.push(*clr);
+                }
+            }
+            let mut ppal = [Colour::default(); 256];
+            let _prim_clrs = quantise_median_cut::<Colour, ColourSum>(&in_clrs, &mut ppal);
+            let mut elbg: ELBG<Colour, ColourSum> = ELBG::new(&ppal);
+            elbg.quantise(&in_clrs, &mut ppal);
+            for (dst, src) in pal.iter_mut().zip(ppal.iter()) {
+                *dst = src.clr;
+            }
+        }
+        pal
+    }
+}
index 5b7b0f5f9fd2a56f1794cd1f7e9a9218c03034db..371e4191a9527ac9a8c0c82ea78d46e720747cb2 100644 (file)
@@ -533,6 +533,7 @@ pub struct QuirkSupport {
     pub nframes:        Vec<usize>,
     pub global_tb:      (u32, u32),
     pub fixed_rate:     bool,
+    pub glbl_pal:       Vec<(usize, [[u8; 3]; 256])>,
 }
 
 #[derive(Default)]
@@ -554,6 +555,7 @@ pub struct Transcoder {
     pub start:          NATimePoint,
     pub end:            NATimePoint,
     pub verbose:        u8,
+    pub gen_pal:        bool,
 
     pub qsupport:       QuirkSupport,
 
@@ -968,8 +970,17 @@ impl Transcoder {
         }
         let ret_eparams = ret_eparams.unwrap();
 
-        let plt = create_palettiser(&oopts.enc_opts);
+        let mut plt = create_palettiser(&oopts.enc_opts);
         oopts.enc_opts.retain(|opt| !opt.name.starts_with("pal."));
+        for (src_id, gpal) in qsupp.glbl_pal.iter() {
+            if *src_id == iidx {
+                if let Some(ref mut p) = plt {
+                    p.set_pal(gpal);
+                } else {
+                    plt = Some(Palettiser::new(crate::palettise::PaletteSearchMode::default(), gpal));
+                }
+            }
+        }
 
         let name = format!("output stream {}", out_id);
         parse_and_apply_options!(encoder, &oopts.enc_opts, name);
@@ -1273,13 +1284,15 @@ println!("encoder {} is not supported by output (expected {})", istr.id, istr.ge
             },
         }
     }
-    pub fn create_demuxers(&mut self, demuxers: &mut Vec<(DemuxerObject, bool)>, full_reg: &FullRegister, print_info: bool) -> bool {
+    pub fn create_demuxers(&mut self, demuxers: &mut Vec<(DemuxerObject, bool)>, full_reg: &FullRegister, print_info: bool, print_input: bool) -> bool {
         let mut isn_start = 0;
         for (i, (iname, ifmt)) in self.input_name.iter().zip(
                 self.input_fmt.iter()).enumerate() {
             match (iname, ifmt.as_ref().map(|s| s.as_str())) {
                 (Some(name), Some("imgseq")) => {
-                    println!("trying image sequence {}", name);
+                    if print_input {
+                        println!("trying image sequence {}", name);
+                    }
                     if let Some(ref mut dbg) = self.debug {
                         dbg.log(DebugLog::DEMUX, &format!("Input {i}: trying image sequence {name}"));
                     }
@@ -1313,7 +1326,9 @@ println!("encoder {} is not supported by output (expected {})", istr.id, istr.ge
                     demuxers.push((dmx, false))
                 },
                 (Some(name), _) => {
-                    print!("Input {i}: {name}");
+                    if print_input {
+                        print!("Input {i}: {name}");
+                    }
                     let res = File::open(name);
                     if let Some(ref mut dbg) = self.debug {
                         dbg.log(DebugLog::DEMUX, &format!("Input {i}: {name}"));
@@ -1352,7 +1367,9 @@ println!("encoder {} is not supported by output (expected {})", istr.id, istr.ge
                         }
                         return false;
                     }
-                    println!(" type {dmx}");
+                    if print_input {
+                        println!(" type {dmx}");
+                    }
                     if let Some(ref mut dbg) = self.debug {
                         dbg.log(DebugLog::DEMUX, &format!("input {i}: type {dmx}"));
                     }