core/scale: add conversion into paletted format
[nihav.git] / nihav-core / src / scale / palette / palettise.rs
/// Scans the whole palette for the entry closest to `pix`.
///
/// Closeness is the sum of squared per-component differences over the first
/// three components of `pix`. An exact match is returned as soon as it is
/// seen; otherwise the index of the first entry with the smallest distance
/// is returned.
///
/// # Panics
///
/// Panics if `pix` holds fewer than three components.
pub fn find_nearest(pix: &[u8], pal: &[[u8; 3]; 256]) -> usize {
    let p0 = i32::from(pix[0]);
    let p1 = i32::from(pix[1]);
    let p2 = i32::from(pix[2]);

    let mut nearest = 0;
    let mut nearest_dist = std::i32::MAX;
    for (idx, &[c0, c1, c2]) in pal.iter().enumerate() {
        let (d0, d1, d2) = (p0 - i32::from(c0), p1 - i32::from(c1), p2 - i32::from(c2));
        let dist = d0 * d0 + d1 * d1 + d2 * d2;
        // A zero distance means all three components match exactly.
        if dist == 0 {
            return idx;
        }
        if dist < nearest_dist {
            nearest_dist = dist;
            nearest = idx;
        }
    }
    nearest
}
20
21 pub struct LocalSearch {
22 pal: [[u8; 3]; 256],
23 db: Vec<Vec<[u8; 4]>>,
24 }
25
26 impl LocalSearch {
27 fn quant(key: [u8; 3]) -> usize {
28 (((key[0] >> 3) as usize) << 10) |
29 (((key[1] >> 3) as usize) << 5) |
30 ((key[2] >> 3) as usize)
31 }
32 pub fn new(in_pal: &[[u8; 3]; 256]) -> Self {
33 let mut db = Vec::with_capacity(1 << 15);
34 let pal = *in_pal;
35 for _ in 0..(1 << 15) {
36 db.push(Vec::new());
37 }
38 for (i, palentry) in pal.iter().enumerate() {
39 let r0 = (palentry[0] >> 3) as usize;
40 let g0 = (palentry[1] >> 3) as usize;
41 let b0 = (palentry[2] >> 3) as usize;
42 for r in r0.saturating_sub(1)..=(r0 + 1).min(31) {
43 for g in g0.saturating_sub(1)..=(g0 + 1).min(31) {
44 for b in b0.saturating_sub(1)..=(b0 + 1).min(31) {
45 let idx = (r << 10) | (g << 5) | b;
46 db[idx].push([palentry[0], palentry[1], palentry[2], i as u8]);
47 }
48 }
49 }
50 }
51 Self { pal, db }
52 }
53 fn dist(a: &[u8; 4], b: [u8; 3]) -> u32 {
54 let d0 = i32::from(a[0]) - i32::from(b[0]);
55 let d1 = i32::from(a[1]) - i32::from(b[1]);
56 let d2 = i32::from(a[2]) - i32::from(b[2]);
57 (d0 * d0 + d1 * d1 + d2 * d2) as u32
58 }
59 pub fn search(&self, pix: [u8; 3]) -> usize {
60 let idx = Self::quant(pix);
61 let mut best_dist = std::u32::MAX;
62 let mut best_idx = 0;
63 let mut count = 0;
64 for clr in self.db[idx].iter() {
65 let dist = Self::dist(clr, pix);
66 count += 1;
67 if best_dist > dist {
68 best_dist = dist;
69 best_idx = clr[3] as usize;
70 if dist == 0 { break; }
71 }
72 }
73 if count > 0 {
74 best_idx
75 } else {
76 find_nearest(&pix, &self.pal)
77 }
78 }
79 }
80
/// One node of the palette k-d tree.
///
/// For an internal node `key` is the split point and `comp` selects which
/// component of it is compared; `child0`/`child1` are indices into the node
/// list. For a leaf `key` is the palette colour, `idx` its palette index and
/// both children are zero. Node 0 is always the root, so a zero `child0`
/// unambiguously marks a leaf.
struct KDNode {
    key: [u8; 3],
    comp: u8,
    idx: u8,
    child0: usize,
    child1: usize,
}

/// Palette lookup by descending a k-d tree built over the palette colours.
pub struct KDTree {
    nodes: Vec<KDNode>,
}

/// Average of `a` and `b` rounded down, computed without overflowing `u8`.
fn avg_u8(a: u8, b: u8) -> u8 {
    (a & b) + ((a ^ b) >> 1)
}

impl KDTree {
    /// `root` value passed to `build` for the top-level node, which has no
    /// parent to be linked to.
    const NO_PARENT: usize = 1024;

    /// Builds the tree for `pal`.
    pub fn new(pal: &[[u8; 3]; 256]) -> Self {
        // Tag every colour with its palette index so it survives reordering.
        let mut npal = [[0u8; 4]; 256];
        for (i, (dst, src)) in npal.iter_mut().zip(pal.iter()).enumerate() {
            *dst = [src[0], src[1], src[2], i as u8];
        }
        let mut tree = Self { nodes: Vec::with_capacity(512) };
        tree.build(&mut npal, 0, 256, Self::NO_PARENT, false);
        tree
    }

    /// Appends `node` and links it to its parent `root` as `child0` or
    /// `child1` (no link when `root` is `NO_PARENT`). Returns the index of
    /// the new node.
    fn attach(&mut self, node: KDNode, root: usize, child0: bool) -> usize {
        let cur_node = self.nodes.len();
        self.nodes.push(node);
        if root != Self::NO_PARENT {
            if child0 {
                self.nodes[root].child0 = cur_node;
            } else {
                self.nodes[root].child1 = cur_node;
            }
        }
        cur_node
    }

    /// Recursively builds the subtree for `pal[start..end]` and links it to
    /// node `root` as its first (`child0 == true`) or second child.
    fn build(&mut self, pal: &mut [[u8; 4]; 256], start: usize, end: usize, root: usize, child0: bool) {
        if start >= end {
            return;
        }
        if start + 1 == end {
            let key = [pal[start][0], pal[start][1], pal[start][2]];
            let leaf = KDNode { key, comp: 0, idx: pal[start][3], child0: 0, child1: 0 };
            self.attach(leaf, root, child0);
            return;
        }

        // Bounding box of the colours in this range.
        let mut min = [255u8; 3];
        let mut max = [0u8; 3];
        for entry in pal[start..end].iter() {
            for comp in 0..3 {
                min[comp] = min[comp].min(entry[comp]);
                max[comp] = max[comp].max(entry[comp]);
            }
        }
        let dr = max[0] - min[0];
        let dg = max[1] - min[1];
        let db = max[2] - min[2];
        let med = [avg_u8(min[0], max[0]), avg_u8(min[1], max[1]), avg_u8(min[2], max[2])];

        // Split along a component with the largest range. Ties must never
        // fall back to a component with a smaller range: splitting on a
        // constant component cannot separate anything and makes entries
        // unreachable for `search`.
        let comp = if dg >= dr && dg >= db {
            1
        } else if dr >= db {
            0
        } else {
            2
        };

        let len = end - start;
        let below = Self::reorder(&mut pal[start..end], comp, med[comp]);
        // `below` is 0 or `len` only when no component separates the range
        // (all colours identical); halve it so the recursion still shrinks.
        let pivot = start + if below == 0 || below == len { len / 2 } else { below };

        let inner = KDNode { key: med, comp: comp as u8, idx: 0, child0: 0, child1: 0 };
        let cur_node = self.attach(inner, root, child0);
        self.build(pal, start, pivot, cur_node, true);
        self.build(pal, pivot, end, cur_node, false);
    }

    /// Partitions `pal` so that every entry whose component `comp` is
    /// `<= med` precedes every entry whose component is `> med`.
    ///
    /// Returns the number of entries in the first group, i.e. the index of
    /// the first entry of the second group.
    fn reorder(pal: &mut [[u8; 4]], comp: usize, med: u8) -> usize {
        let mut split = 0;
        for i in 0..pal.len() {
            // Every entry is examined, so none can end up on the wrong side.
            if pal[i][comp] <= med {
                pal.swap(split, i);
                split += 1;
            }
        }
        split
    }

    /// Returns the palette index stored in the leaf reached for `pix`.
    ///
    /// The tree is descended once, comparing one component per node, with no
    /// backtracking: the result is the entry sharing the final region with
    /// `pix`, which is not necessarily the closest entry overall.
    pub fn search(&self, pix: [u8; 3]) -> usize {
        let mut node = &self.nodes[0];
        while node.child0 != 0 {
            let comp = usize::from(node.comp);
            let next = if pix[comp] <= node.key[comp] { node.child0 } else { node.child1 };
            node = &self.nodes[next];
        }
        usize::from(node.idx)
    }
}