]> git.nihav.org Git - nihav-encoder.git/commitdiff
allow palettisation for a defined number of colours
authorKostya Shishkov <kostya.shishkov@gmail.com>
Sat, 25 Apr 2026 08:44:36 +0000 (10:44 +0200)
committerKostya Shishkov <kostya.shishkov@gmail.com>
Sat, 25 Apr 2026 08:44:36 +0000 (10:44 +0200)
src/palettise.rs
src/transcoder.rs

index e36eef6763655fef9e724ac65db573473362e0b8..2025232cf5a479afdc74d5619fe8369a1b377f70 100644 (file)
@@ -28,7 +28,7 @@ fn find_nearest(pix: &[u8], pal: &[[u8; 3]]) -> usize {
     bestidx
 }
 
-fn gen_full_lut(dst: &mut [u8], pal: &[[u8; 3]; 256]) {
+fn gen_full_lut(dst: &mut [u8], pal: &[[u8; 3]]) {
     for (r, chunkr) in dst.chunks_exact_mut(1 << 16).enumerate() {
         for (g, chunkg) in chunkr.chunks_exact_mut(1 << 8).enumerate() {
             for (b, el) in chunkg.iter_mut().enumerate() {
@@ -40,6 +40,7 @@ fn gen_full_lut(dst: &mut [u8], pal: &[[u8; 3]; 256]) {
 
 struct LocalSearch {
     pal:        [[u8; 3]; 256],
+    nclrs:      usize,
     db:         Vec<Vec<[u8; 4]>>,
 }
 
@@ -49,13 +50,15 @@ impl LocalSearch {
         (((key[1] >> 3) as usize) << 5) |
          ((key[2] >> 3) as usize)
     }
-    fn new(in_pal: &[[u8; 3]; 256]) -> Self {
+    fn new(in_pal: &[[u8; 3]]) -> Self {
         let mut db = Vec::with_capacity(1 << 15);
-        let pal = *in_pal;
+        let mut pal = [[0; 3]; 256];
+        let nclrs = in_pal.len();
+        pal[..nclrs].copy_from_slice(in_pal);
         for _ in 0..(1 << 15) {
             db.push(Vec::new());
         }
-        for (i, palentry) in pal.iter().enumerate() {
+        for (i, palentry) in pal.iter().enumerate().take(in_pal.len()) {
             let r0 = (palentry[0] >> 3) as usize;
             let g0 = (palentry[1] >> 3) as usize;
             let b0 = (palentry[2] >> 3) as usize;
@@ -68,7 +71,7 @@ impl LocalSearch {
                 }
             }
         }
-        Self { pal, db }
+        Self { pal, nclrs, db }
     }
     fn dist(a: &[u8; 4], b: [u8; 3]) -> u32 {
         let d0 = i32::from(a[0]) - i32::from(b[0]);
@@ -93,7 +96,7 @@ impl LocalSearch {
         if count > 0 {
             best_idx
         } else {
-            find_nearest(&pix, &self.pal)
+            find_nearest(&pix, &self.pal[..self.nclrs])
         }
     }
 }
@@ -115,16 +118,17 @@ fn avg_u8(a: u8, b: u8) -> u8 {
 }
 
 impl KDTree {
-    fn new(pal: &[[u8; 3]; 256]) -> Self {
+    fn new(pal: &[[u8; 3]]) -> Self {
         let mut npal = [[0; 4]; 256];
-        for i in 0..256 {
-            npal[i][0] = pal[i][0];
-            npal[i][1] = pal[i][1];
-            npal[i][2] = pal[i][2];
-            npal[i][3] = i as u8;
-        }
-        let mut tree = Self { nodes: Vec::with_capacity(512) };
-        tree.build(&mut npal, 0, 256, 1024, false);
+        let nclrs = pal.len();
+        for (i, (dst, src)) in npal.iter_mut().zip(pal.iter()).enumerate() {
+            dst[0] = src[0];
+            dst[1] = src[1];
+            dst[2] = src[2];
+            dst[3] = i as u8;
+        }
+        let mut tree = Self { nodes: Vec::with_capacity(nclrs * 2) };
+        tree.build(&mut npal, 0, nclrs, 1024, false);
         tree
     }
     fn build(&mut self, pal: &mut [[u8; 4]; 256], start: usize, end: usize, root: usize, child0: bool) {
@@ -213,6 +217,7 @@ enum PMode {
 
 pub struct Palettiser {
     pal:    [[u8; 3]; 256],
+    nclrs:  usize,
     pmode:  PMode,
 }
 
@@ -244,17 +249,17 @@ impl LookupCache {
 #[allow(dead_code)]
 impl Palettiser {
     pub fn get_default_mode() -> PaletteSearchMode { PaletteSearchMode::Full }
-    pub fn new(mode: PaletteSearchMode, pal: &[[u8; 3]; 256]) -> Self {
+    pub fn new(mode: PaletteSearchMode, pal: &[[u8; 3]; 256], nclrs: usize) -> Self {
         let pmode = match mode {
                 PaletteSearchMode::Full => {
                     let mut tab = vec![0; 1 << 24];
-                    gen_full_lut(&mut tab, pal);
+                    gen_full_lut(&mut tab, &pal[..nclrs]);
                     PMode::Full(tab)
                 },
                 PaletteSearchMode::Local => PMode::Local(LocalSearch::new(pal)),
                 PaletteSearchMode::KDTree => PMode::Tree(KDTree::new(pal)),
             };
-        Self { pal: *pal, pmode }
+        Self { pal: *pal, nclrs, pmode }
     }
     pub fn search(&self, pix: [u8; 3]) -> usize {
         match &self.pmode {
@@ -266,12 +271,13 @@ impl Palettiser {
             PMode::Tree(kdt) => kdt.search(pix),
         }
     }
-    pub fn set_pal(&mut self, pal: &[[u8; 3]; 256]) {
+    pub fn set_pal(&mut self, pal: &[[u8; 3]; 256], nclrs: usize) {
         self.pal.copy_from_slice(pal);
+        self.nclrs = nclrs;
         match &mut self.pmode {
-            PMode::Full(ref mut tab) => { gen_full_lut(tab, pal); },
-            PMode::Local(ref mut ls) => { *ls = LocalSearch::new(pal); },
-            PMode::Tree(ref mut kdt) => { *kdt = KDTree::new(pal); },
+            PMode::Full(ref mut tab) => { gen_full_lut(tab, &pal[..nclrs]); },
+            PMode::Local(ref mut ls) => { *ls = LocalSearch::new(&pal[..nclrs]); },
+            PMode::Tree(ref mut kdt) => { *kdt = KDTree::new(&pal[..nclrs]); },
         }
     }
     pub fn palettise_frame(&self, pic_in: &NABufferType, pic_out: &mut NABufferType) -> Result<(), &'static str> {
@@ -376,6 +382,7 @@ pub fn create_palettiser(enc_opts: &[OptionArgs]) -> Option<Palettiser> {
     let mut pmode = None;
     let mut pal = std::array::from_fn(|i| [i as u8; 3]);
     let mut pal_is_some = false;
+    let mut nclrs = 256;
     for opt in enc_opts.iter() {
         match opt.name.as_str() {
             "pal.mode" => {
@@ -400,14 +407,17 @@ pub fn create_palettiser(enc_opts: &[OptionArgs]) -> Option<Palettiser> {
                             for (i, clr) in pal.iter_mut().enumerate() {
                                 *clr = [i as u8; 3];
                             }
+                            nclrs = 256;
                         },
                         "bw" => {
                             pal[0] = [0x00; 3];
                             pal[1] = [0xFF; 3];
+                            nclrs = 2;
                         },
                         "wb" => {
                             pal[0] = [0xFF; 3];
                             pal[1] = [0x00; 3];
+                            nclrs = 2;
                         },
                         "systematic" => {
                             for (i, clr) in pal.iter_mut().enumerate() {
@@ -419,21 +429,25 @@ pub fn create_palettiser(enc_opts: &[OptionArgs]) -> Option<Palettiser> {
                                         (g << 5) | (g << 2) | (g >> 1),
                                         b * 0x55];
                             }
+                            nclrs = 256;
                         },
                         "qt4" => {
                             for (dclr, sclr) in pal.iter_mut().zip(MOV_DEFAULT_PAL_2BIT.chunks_exact(4)) {
                                 dclr.copy_from_slice(&sclr[..3]);
                             }
+                            nclrs = 4;
                         },
                         "qt16" => {
                             for (dclr, sclr) in pal.iter_mut().zip(MOV_DEFAULT_PAL_4BIT.chunks_exact(4)) {
                                 dclr.copy_from_slice(&sclr[..3]);
                             }
+                            nclrs = 16;
                         },
                         "qt256" => {
                             for (dclr, sclr) in pal.iter_mut().zip(MOV_DEFAULT_PAL_8BIT.chunks_exact(4)) {
                                 dclr.copy_from_slice(&sclr[..3]);
                             }
+                            nclrs = 256;
                         },
                         _ => {
                             println!("invalid or unknown palette mode");
@@ -450,7 +464,7 @@ pub fn create_palettiser(enc_opts: &[OptionArgs]) -> Option<Palettiser> {
         }
     }
     if pmode.is_some() || pal_is_some {
-        Some(Palettiser::new(pmode.unwrap_or_default(), &pal))
+        Some(Palettiser::new(pmode.unwrap_or_default(), &pal, nclrs))
     } else {
         None
     }
@@ -626,13 +640,15 @@ impl Bucket {
 struct BucketCounter {
     buckets:    Vec<Bucket>,
     nentries:   usize,
+    nclrs:      usize,
 }
 
 impl BucketCounter {
-    fn new() -> Self {
+    fn new(nclrs: usize) -> Self {
         Self {
             buckets:    vec![Bucket::default(); NBUCKETS],
             nentries:   0,
+            nclrs,
         }
     }
     fn reset(&mut self) {
@@ -656,18 +672,19 @@ impl BucketCounter {
             }
         }
     }
-    fn get_pal(&self, debug: bool) -> [[u8; 3]; 256] {
+    fn get_pal(&self, debug: bool) -> ([[u8; 3]; 256], usize) {
         let mut pal = [[0; 3]; 256];
         if debug {
             println!("  {} entries in total", self.nentries);
         }
-        if self.nentries <= 256 {
+        if self.nentries <= self.nclrs {
             let mut dst = pal.iter_mut();
             for bucket in self.buckets.iter() {
                 for clr in bucket.clrs.iter() {
                     *dst.next().unwrap() = clr.clr;
                 }
             }
+            (pal, self.nentries)
         } else {
             let mut in_clrs = Vec::with_capacity(self.nentries);
             for bucket in self.buckets.iter() {
@@ -675,19 +692,19 @@ impl BucketCounter {
                     in_clrs.push(*clr);
                 }
             }
-            let mut ppal = [Colour::default(); 256];
+            let mut ppal = vec![Colour::default(); self.nclrs];
             let _prim_clrs = quantise_median_cut::<Colour, ColourSum>(&in_clrs, &mut ppal);
             let mut elbg: ELBG<Colour, ColourSum> = ELBG::new(&ppal);
-            elbg.quantise(&in_clrs, &mut ppal);
+            let nclrs = elbg.quantise(&in_clrs, &mut ppal);
             for (dst, src) in pal.iter_mut().zip(ppal.iter()) {
                 *dst = src.clr;
             }
+            (pal, nclrs)
         }
-        pal
     }
 }
 
-pub type PalSegment = (usize, [[u8; 3]; 256]);
+pub type PalSegment = (usize, [[u8; 3]; 256], usize);
 
 struct MultiCount {
     glbl_hist:  [u64; 1 << 15],
@@ -759,6 +776,7 @@ pub struct ColourCounter {
     ctype:      CounterType,
     debug:      bool,
     multi:      Option<Box<MultiCount>>,
+    nclrs:      usize,
 }
 
 impl ColourCounter {
@@ -766,6 +784,7 @@ impl ColourCounter {
         let mut counter_type = "full";
         let mut debug = false;
         let mut multi = None;
+        let mut nclrs = 256;
         for opt in options.iter() {
             match opt.name.as_str() {
                 "counter" => {
@@ -824,19 +843,34 @@ impl ColourCounter {
                         println!("group size requires a numeric argument");
                     }
                 },
+                "nclrs" | "num_colours" | "palette_size" => {
+                    if let Some(val) = opt.value.as_deref() {
+                        if let Ok(psize) = val.parse::<usize>() {
+                            if (2..=256).contains(&psize) {
+                                nclrs = psize;
+                            } else {
+                                println!("palette should contain 2..256 colours");
+                            }
+                        } else {
+                            println!("palette size requires a numeric argument");
+                        }
+                    } else {
+                        println!("palette size requires a numeric argument");
+                    }
+                },
                 _ => {},
             }
         }
         let ctype = match counter_type {
                 "full" => CounterType::Brawn(vec![0; 1 << 24]),
-                "bucket" => CounterType::Buckets(BucketCounter::new()),
+                "bucket" => CounterType::Buckets(BucketCounter::new(nclrs)),
                 _ => unreachable!(),
             };
         Self {
             scaler:     None,
             sc_buf:     NABufferType::None,
             oinfo:      NAVideoInfo{ width: 0, height: 0, flipped: false, format: RGB24_FORMAT, bits: 24 },
-            ctype, debug, multi,
+            ctype, debug, multi, nclrs,
         }
     }
     pub fn add(&mut self, clr: [u8; 3]) {
@@ -880,10 +914,10 @@ impl ColourCounter {
             if self.debug {
                 println!("  new pal segment {pal_start}..{pal_end}");
             }
-            let pal = self.get_pal();
+            let (pal, nclrs) = self.get_pal();
             self.reset();
             if let Some(ref mut multi) = self.multi {
-                multi.pals.push((multi.fstart, pal));
+                multi.pals.push((multi.fstart, pal, nclrs));
                 multi.fstart = multi.frameno - 1;
             }
         }
@@ -928,7 +962,7 @@ impl ColourCounter {
             Err("not a video frame")
         }
     }
-    pub fn get_pal(&self) -> [[u8; 3]; 256] {
+    pub fn get_pal(&self) -> ([[u8; 3]; 256], usize) {
         match self.ctype {
             CounterType::Buckets(ref bkt) => bkt.get_pal(self.debug),
             CounterType::Brawn(ref hist) => {
@@ -937,32 +971,33 @@ impl ColourCounter {
                 if self.debug {
                     println!("  {} entries in total", clrs.len());
                 }
-                if clrs.len() <= 256 {
+                if clrs.len() <= self.nclrs {
                     for (dclr, sclr) in pal.iter_mut().zip(clrs.iter()) {
                         *dclr = sclr.clr;
                     }
+                    (pal, clrs.len())
                 } else {
-                    let mut ppal = [Colour::default(); 256];
+                    let mut ppal = vec![Colour::default(); self.nclrs];
                     let _prim_clrs = quantise_median_cut::<Colour, ColourSum>(&clrs, &mut ppal);
                     let mut elbg: ELBG<Colour, ColourSum> = ELBG::new(&ppal);
-                    elbg.quantise(&clrs, &mut ppal);
+                    let nclrs = elbg.quantise(&clrs, &mut ppal);
                     for (dst, src) in pal.iter_mut().zip(ppal.iter()) {
                         *dst = src.clr;
                     }
+                    (pal, nclrs)
                 }
-                pal
             },
         }
     }
     pub fn get_multi_pals(&self) -> Vec<PalSegment> {
-        let last_pal = self.get_pal();
+        let (last_pal, nclrs) = self.get_pal();
         if let Some(ref multi) = self.multi {
             let mut pals = multi.pals.clone();
             if multi.fstart < multi.frameno {
-                pals.push((multi.fstart, last_pal));
+                pals.push((multi.fstart, last_pal, nclrs));
             }
             return pals;
         }
-        vec![(0, last_pal)]
+        vec![(0, last_pal, nclrs)]
     }
 }
index 7208334ff639080de52d081dd038fab2d4af1e29..7fcbf0db6c8b44fa3c3b0ec055fe15864eb4853d 100644 (file)
@@ -247,8 +247,8 @@ impl EncoderInterface for VideoEncodeContext {
         if let Some(ref mut plt) = self.plt {
             if !self.pals.is_empty() && !matches!(buf, NABufferType::None) {
                 if self.pal_frm >= self.pals[0].0 {
-                    let (_, pal) = self.pals.remove(0);
-                    plt.set_pal(&pal);
+                    let (_, pal, nclrs) = self.pals.remove(0);
+                    plt.set_pal(&pal, nclrs);
                 }
                 self.pal_frm += 1;
             }
@@ -986,9 +986,9 @@ impl Transcoder {
         for (src_id, pals) in qsupp.pals.iter() {
             if *src_id == iidx && !pals.is_empty() {
                 if let Some(ref mut p) = plt {
-                    p.set_pal(&pals[0].1);
+                    p.set_pal(&pals[0].1, pals[0].2);
                 } else {
-                    plt = Some(Palettiser::new(Palettiser::get_default_mode(), &pals[0].1));
+                    plt = Some(Palettiser::new(Palettiser::get_default_mode(), &pals[0].1, pals[0].2));
                 }
             }
         }