codec_support/vq: do not move empty group centroid
[nihav.git] / nihav-codec-support / src / vq / generic_elbg.rs
1 use super::{VQElement, VQElementSum};
2
/// Tiny deterministic pseudo-random generator for internal use (it picks a
/// random high-distortion cluster during the ELBG shift step).
///
/// The state is a 16-bit register. Each step shifts it left by one bit and,
/// when the bit shifted out was set, XORs in the constant `0x1B2B`.
struct RNG {
    seed: u16,
}

impl RNG {
    /// Creates a generator with the fixed starting state `0x1234`, so every
    /// run yields the same sequence.
    fn new() -> Self {
        RNG { seed: 0x1234 }
    }

    /// Advances the state by one step and returns the low byte of the new state.
    fn next(&mut self) -> u8 {
        let carry = (self.seed & 0x8000) != 0;
        let shifted = (self.seed & 0x7FFF) << 1;
        self.seed = if carry { shifted ^ 0x1B2B } else { shifted };
        (self.seed & 0xFF) as u8
    }
}
19
/// One distinct input vector together with the number of times it occurs in
/// the training data (runs of identical vectors are collapsed into a single
/// weighted entry).
struct Entry<T> {
    /// How many input vectors are equal to `val`.
    count: u64,
    /// The input vector itself.
    val: T,
}
24
25 struct Cluster<T: VQElement, TS: VQElementSum<T>> {
26 centroid: T,
27 dist: u64,
28 count: u64,
29 sum: TS,
30 }
31
32 impl<T: VQElement, TS: VQElementSum<T>> Cluster<T, TS> {
33 fn new(centroid: T) -> Self {
34 Self {
35 centroid,
36 dist: 0,
37 count: 0,
38 sum: TS::zero(),
39 }
40 }
41 fn reset(&mut self) {
42 self.count = 0;
43 self.sum = TS::zero();
44 self.dist = 0;
45 }
46 fn add_point(&mut self, entry: &Entry<T>) {
47 self.sum.add(entry.val, entry.count);
48 self.count += entry.count;
49 }
50 fn add_dist(&mut self, entry: &Entry<T>) {
51 self.dist += u64::from(self.centroid.dist(entry.val)) * entry.count;
52 }
53 fn calc_centroid(&mut self) {
54 if self.count > 0 {
55 self.centroid = self.sum.get_centroid();
56 }
57 }
58 fn calc_dist(&mut self) {
59 if self.count != 0 {
60 self.dist = (self.dist + self.count / 2) / self.count;
61 }
62 }
63 }
64
65 pub struct ELBG<T: VQElement, TS: VQElementSum<T>> {
66 clusters: Vec<Cluster<T, TS>>,
67 }
68
69 impl<T: VQElement+Default, TS: VQElementSum<T>> ELBG<T, TS> {
70 pub fn new(initial_cb: &[T]) -> Self {
71 let mut clusters = Vec::with_capacity(initial_cb.len());
72 for elem in initial_cb.iter() {
73 let cluster = Cluster::new(*elem);
74 clusters.push(cluster);
75 }
76 Self {
77 clusters,
78 }
79 }
80 fn new_split(old_index: usize, entries: &[Entry<T>], indices: &[usize]) -> Option<(T, T)> {
81 let mut max = T::min_cw();
82 let mut min = T::max_cw();
83 let mut found = false;
84 for (entry, idx) in entries.iter().zip(indices) {
85 if *idx == old_index {
86 max = max.max(entry.val);
87 min = min.min(entry.val);
88 found = true;
89 }
90 }
91 if !found {
92 return None;
93 }
94 let mut ts0 = TS::zero();
95 let mut ts1 = TS::zero();
96 ts0.add(min, 2); ts0.add(max, 1);
97 ts1.add(min, 1); ts1.add(max, 2);
98 Some((ts0.get_centroid(), ts1.get_centroid()))
99 }
100 fn old_centre(&self, old_index1: usize, old_index2: usize, entries: &[Entry<T>], indices: &[usize]) -> T {
101 let mut max = T::min_cw();
102 let mut min = T::max_cw();
103 let mut found = false;
104 for (entry, idx) in entries.iter().zip(indices) {
105 if *idx == old_index1 || *idx == old_index2 {
106 max = max.max(entry.val);
107 min = min.min(entry.val);
108 found = true;
109 }
110 }
111 if !found {
112 max = self.clusters[old_index1].centroid.max(self.clusters[old_index2].centroid);
113 min = self.clusters[old_index1].centroid.min(self.clusters[old_index2].centroid);
114 }
115 let mut ts = TS::zero();
116 ts.add(min, 2); ts.add(max, 1);
117 ts.get_centroid()
118 }
119 fn estimate_old(old_idx0: usize, old_idx1: usize, c: T, entries: &[Entry<T>], indices: &[usize]) -> u64 {
120 let mut clu: Cluster<T, TS> = Cluster::new(c);
121 let mut count = 0;
122 for (entry, idx) in entries.iter().zip(indices) {
123 if *idx == old_idx0 || *idx == old_idx1 {
124 clu.add_dist(entry);
125 count += entry.count;
126 }
127 }
128 clu.count = count;
129 clu.calc_dist();
130 clu.dist
131 }
132 fn estimate_new(c0: T, c1: T, old_idx: usize, entries: &[Entry<T>], indices: &[usize]) -> u64 {
133 let mut clu0: Cluster<T, TS> = Cluster::new(c0);
134 let mut clu1: Cluster<T, TS> = Cluster::new(c1);
135 let mut count0 = 0;
136 let mut count1 = 0;
137 for (entry, idx) in entries.iter().zip(indices) {
138 if *idx == old_idx {
139 if c0.dist(entry.val) < c1.dist(entry.val) {
140 clu0.add_dist(entry);
141 count0 += entry.count;
142 } else {
143 clu1.add_dist(entry);
144 count1 += entry.count;
145 }
146 }
147 }
148 clu0.count = count0;
149 clu1.count = count1;
150 clu0.calc_dist();
151 clu1.calc_dist();
152 clu0.dist + clu1.dist
153 }
154 pub fn quantise(&mut self, src: &[T], dst: &mut [T]) {
155 if src.len() < 1 || dst.len() != self.clusters.len() {
156 return;
157 }
158 let mut old_cb = vec![T::default(); self.clusters.len()];
159 let mut prev_dist = std::u64::MAX;
160 let mut dist = std::u64::MAX / 2;
161 let mut indices = Vec::with_capacity(src.len());
162 let mut elements = Vec::with_capacity(src.len());
163 elements.extend_from_slice(src);
164 for comp in 0..T::num_components() {
165 T::sort_by_component(elements.as_mut_slice(), comp);
166 }
167 let mut entries = Vec::with_capacity(elements.len() / 2);
168 let mut lastval = elements[0];
169 let mut run = 1;
170 for point in elements.iter().skip(1) {
171 if &lastval == point {
172 run += 1;
173 } else {
174 entries.push(Entry { val: lastval, count: run });
175 lastval = *point;
176 run = 1;
177 }
178 }
179 entries.push(Entry { val: lastval, count: run });
180 drop(elements);
181
182 let mut low_u: Vec<usize> = Vec::with_capacity(self.clusters.len());
183 let mut high_u: Vec<usize> = Vec::with_capacity(self.clusters.len());
184 let mut rng = RNG::new();
185 let mut iterations = 0usize;
186 let mut do_elbg_step = true;
187 while (iterations < 20) && (dist < prev_dist - prev_dist / 1000) {
188 prev_dist = dist;
189 for i in 0..dst.len() {
190 old_cb[i] = self.clusters[i].centroid;
191 self.clusters[i].reset();
192 }
193 // put points into the nearest clusters
194 indices.truncate(0);
195 for entry in entries.iter() {
196 let mut bestidx = 0;
197 let mut bestdist = std::u32::MAX;
198 for (i, cluster) in self.clusters.iter().enumerate() {
199 let dist = entry.val.dist(cluster.centroid);
200 if bestdist > dist {
201 bestdist = dist;
202 bestidx = i;
203 if dist == 0 {
204 break;
205 }
206 }
207 }
208 indices.push(bestidx);
209 self.clusters[bestidx].add_point(entry);
210 }
211 // calculate params
212 for cluster in self.clusters.iter_mut() {
213 cluster.calc_centroid();
214 }
215 dist = 0;
216 for (idx, entry) in indices.iter().zip(entries.iter()) {
217 self.clusters[*idx].add_dist(entry);
218 }
219 for cluster in self.clusters.iter_mut() {
220 cluster.calc_dist();
221 dist += cluster.dist;
222 }
223
224 let dmean = dist / (dst.len() as u64);
225 low_u.truncate(0);
226 high_u.truncate(0);
227 let mut used = vec![false; dst.len()];
228 for (i, cluster) in self.clusters.iter().enumerate() {
229 if cluster.dist < dmean {
230 low_u.push(i);
231 } else if cluster.dist > dmean * 2 {
232 high_u.push(i);
233 used[i] = true;
234 }
235 }
236
237 if do_elbg_step {
238 do_elbg_step = false;
239 for low_idx in low_u.iter() {
240 if high_u.len() == 0 {
241 break;
242 }
243 let high_idx_idx = (rng.next() as usize) % high_u.len();
244 let high_idx = high_u[high_idx_idx];
245 let mut closest_idx = *low_idx;
246 let mut closest_dist = std::u32::MAX;
247 let low_centr = self.clusters[*low_idx].centroid;
248 for i in 0..dst.len() {//low_u.iter() {
249 if i == *low_idx || used[i] {
250 continue;
251 }
252 let dist = self.clusters[i].centroid.dist(low_centr);
253 if closest_dist > dist {
254 closest_dist = dist;
255 closest_idx = i;
256 }
257 }
258 if closest_idx == *low_idx {
259 continue;
260 }
261 let old_dist = self.clusters[*low_idx].dist + self.clusters[closest_idx].dist + self.clusters[high_idx].dist;
262 let old_centr = self.old_centre(*low_idx, closest_idx, entries.as_slice(), indices.as_slice());
263 let ret = Self::new_split(high_idx, entries.as_slice(), indices.as_slice());
264 if let Some((centr0, centr1)) = ret {
265 let dist_o = if old_dist > self.clusters[high_idx].dist {
266 Self::estimate_old(*low_idx, closest_idx, old_centr, entries.as_slice(), indices.as_slice())
267 } else { 0 };
268 let dist_n = Self::estimate_new(centr0, centr1, high_idx, entries.as_slice(), indices.as_slice());
269 if dist_o + dist_n < old_dist {
270 self.clusters[*low_idx ].centroid = old_centr;
271 self.clusters[closest_idx].centroid = centr0;
272 self.clusters[high_idx ].centroid = centr1;
273 used[*low_idx] = true;
274 used[closest_idx] = true;
275 used[high_idx] = true;
276 high_u.remove(high_idx_idx);
277 do_elbg_step = true;
278 }
279 }
280 }
281 }
282 iterations += 1;
283 }
284 if dist < prev_dist {
285 for i in 0..dst.len() {
286 old_cb[i] = self.clusters[i].centroid;
287 }
288 }
289 dst.copy_from_slice(&old_cb);
290 }
291 }