]> git.nihav.org Git - nihav.git/blame - nihav-core/src/scale/palette/neuquant.rs
avimux: do not record palette change chunks in OpenDML index
[nihav.git] / nihav-core / src / scale / palette / neuquant.rs
CommitLineData
4b459d0b
KS
1use super::Pixel;
2
/// Palette generator built on a one-dimensional chain of 256 neurons (NeuQuant style).
///
/// Every node stores an RGB colour that ends up as one palette entry once training is done.
pub struct NeuQuantQuantiser {
    /// Colour held by each node, kept in floating point while the network learns.
    weights: [[f64; 3]; 256],
    /// Decaying estimate of how often each node is the nearest match for a sample.
    freq: [f64; 256],
    /// Per-node bias subtracted from the distance so seldom-chosen nodes can still win.
    bias: [f64; 256],
    /// Sampling factor: roughly one pixel out of every `factor` is fed to the network.
    factor: usize,
}

/// Count of leading nodes (black, then white) that stay fixed and are never trained.
const SPECIAL_NODES: usize = 2;
11impl NeuQuantQuantiser {
1ff7036b 12 #[allow(clippy::needless_range_loop)]
4b459d0b
KS
13 pub fn new(factor: usize) -> Self {
14 let mut weights = [[0.0; 3]; 256];
15 if SPECIAL_NODES > 1 {
16 weights[1] = [255.0; 3]; // for white
17 }
18 for i in SPECIAL_NODES..256 {
19 let w = 255.0 * ((i - SPECIAL_NODES) as f64) / ((256 - SPECIAL_NODES) as f64);
20 weights[i] = [w, w, w];
21 }
22 Self {
23 weights,
24 freq: [1.0 / 256.0; 256],
25 bias: [0.0; 256],
26 factor,
27 }
28 }
29 fn update_node(&mut self, idx: usize, clr: &[f64; 3], alpha: f64) {
30 self.weights[idx][0] -= alpha * (self.weights[idx][0] - clr[0]);
31 self.weights[idx][1] -= alpha * (self.weights[idx][1] - clr[1]);
32 self.weights[idx][2] -= alpha * (self.weights[idx][2] - clr[2]);
33 }
34 fn update_neighbours(&mut self, idx: usize, clr: &[f64; 3], alpha: f64, radius: usize) {
35 let low = idx.saturating_sub(radius).max(SPECIAL_NODES - 1);
36 let high = (idx + radius).min(self.weights.len() - 1);
37
38 let mut idx0 = idx + 1;
39 let mut idx1 = idx - 1;
40 let mut range = 0;
41 let sqradius = (radius * radius) as f64;
42 while (idx0 < high) || (idx1 > low) {
b36f412c 43 let sqrng = f64::from(range * range);
4b459d0b
KS
44 let a = alpha * (sqradius - sqrng) / sqradius;
45 range += 1;
46 if idx0 < high {
47 self.update_node(idx0, clr, a);
48 idx0 += 1;
49 }
50 if idx1 > low {
51 self.update_node(idx1, clr, a);
52 idx1 -= 1;
53 }
54 }
55 }
b7c882c1 56 #[allow(clippy::float_cmp)]
4b459d0b
KS
57 fn find_node(&mut self, clr: &[f64; 3]) -> usize {
58 for i in 0..SPECIAL_NODES {
59 if &self.weights[i] == clr {
60 return i;
61 }
62 }
08e0a8bf 63 let mut bestdist = f64::MAX;
4b459d0b 64 let mut distidx = 0;
08e0a8bf 65 let mut bestbias = f64::MAX;
4b459d0b
KS
66 let mut biasidx = 0;
67 for i in SPECIAL_NODES..256 {
68 let dist = (self.weights[i][0] - clr[0]) * (self.weights[i][0] - clr[0])
69 + (self.weights[i][1] - clr[1]) * (self.weights[i][1] - clr[1])
70 + (self.weights[i][2] - clr[2]) * (self.weights[i][2] - clr[2]);
71 if bestdist > dist {
72 bestdist = dist;
73 distidx = i;
74 }
75 let biasdiff = dist - self.bias[i];
76 if bestbias > biasdiff {
77 bestbias = biasdiff;
78 biasidx = i;
79 }
80 self.freq[i] -= self.freq[i] / 1024.0;
81 self.bias[i] += self.freq[i];
82 }
83 self.freq[distidx] += 1.0 / 1024.0;
84 self.bias[distidx] -= 1.0;
85 biasidx
86 }
87 pub fn learn(&mut self, src: &[Pixel]) {
88 let mut bias_radius = (256 / 8) << 6;
89 let alphadec = (30 + (self.factor - 1) / 3) as f64;
b36f412c 90 let initial_alpha = f64::from(1 << 10);
4b459d0b
KS
91
92 let npixels = src.len();
93
94 let mut radius = bias_radius >> 6;
95 if radius == 1 { radius = 0 };
96 let samples = npixels / self.factor;
97 let delta = samples / 100;
98 let mut alpha = initial_alpha;
99
100 let mut pos = 0;
101 const PRIMES: [usize; 4] = [ 499, 491, 487, 503 ];
102 let mut step = PRIMES[3];
103 for prime in PRIMES.iter().rev() {
104 if npixels % *prime != 0 {
105 step = *prime;
106 }
107 }
108
109 for i in 0..samples {
b36f412c 110 let clr = [f64::from(src[pos].r), f64::from(src[pos].g), f64::from(src[pos].b)];
4b459d0b
KS
111 let idx = self.find_node(&clr);
112 if idx >= SPECIAL_NODES {
113 let new_alpha = alphadec / initial_alpha;
114 self.update_node(idx, &clr, new_alpha);
115 if radius > 0 {
116 self.update_neighbours(idx, &clr, new_alpha, radius);
117 }
118 }
119 pos = (pos + step) % npixels;
120 if (i + 1) % delta == 0 {
121 alpha -= alpha / alphadec;
122 bias_radius -= bias_radius / 30;
123 radius = bias_radius >> 6;
124 if radius == 1 { radius = 0 };
125 }
126 }
127 }
128 pub fn make_pal(&self, pal: &mut [[u8; 3]; 256]) {
129 for (pal, node) in pal.iter_mut().zip(self.weights.iter()) {
130 pal[0] = (node[0] + 0.5).max(0.0).min(255.0) as u8;
131 pal[1] = (node[1] + 0.5).max(0.0).min(255.0) as u8;
132 pal[2] = (node[2] + 0.5).max(0.0).min(255.0) as u8;
133 }
134 }
135}