]>
Commit | Line | Data |
---|---|---|
971ae306 KS |
1 | use super::{VQElement, VQElementSum}; |
2 | ||
/// Very simple pseudo-random number generator for internal needs.
///
/// The state is a 16-bit register: every step shifts it left by one bit and,
/// if the bit shifted out was set, XORs the result with `0x1B2B`.
struct RNG {
    /// Current 16-bit generator state.
    seed: u16,
}

impl RNG {
    /// Creates a generator with the fixed start state `0x1234`, so every
    /// instance yields the same sequence.
    fn new() -> Self {
        Self { seed: 0x1234 }
    }

    /// Advances the state by one step and returns its low eight bits.
    fn next(&mut self) -> u8 {
        let top_bit_set = (self.seed & 0x8000) != 0;
        // A left shift on `u16` silently drops the top bit, which is the same
        // as masking with 0x7FFF and doubling.
        self.seed <<= 1;
        if top_bit_set {
            self.seed ^= 0x1B2B;
        }
        self.seed as u8
    }
}
19 | ||
/// A distinct input value together with the number of times it occurs.
struct Entry<T> {
    /// The value itself.
    val: T,
    /// How many input points share this value; used as a weight.
    count: u64,
}
24 | ||
25 | struct Cluster<T: VQElement, TS: VQElementSum<T>> { | |
26 | centroid: T, | |
27 | dist: u64, | |
28 | count: u64, | |
29 | sum: TS, | |
30 | } | |
31 | ||
32 | impl<T: VQElement, TS: VQElementSum<T>> Cluster<T, TS> { | |
33 | fn new(centroid: T) -> Self { | |
34 | Self { | |
35 | centroid, | |
36 | dist: 0, | |
37 | count: 0, | |
38 | sum: TS::zero(), | |
39 | } | |
40 | } | |
41 | fn reset(&mut self) { | |
42 | self.count = 0; | |
43 | self.sum = TS::zero(); | |
44 | self.dist = 0; | |
45 | } | |
46 | fn add_point(&mut self, entry: &Entry<T>) { | |
47 | self.sum.add(entry.val, entry.count); | |
48 | self.count += entry.count; | |
49 | } | |
50 | fn add_dist(&mut self, entry: &Entry<T>) { | |
51 | self.dist += u64::from(self.centroid.dist(entry.val)) * entry.count; | |
52 | } | |
53 | fn calc_centroid(&mut self) { | |
4347bba1 KS |
54 | if self.count > 0 { |
55 | self.centroid = self.sum.get_centroid(); | |
56 | } | |
971ae306 KS |
57 | } |
58 | fn calc_dist(&mut self) { | |
971ae306 KS |
59 | } |
60 | } | |
61 | ||
62 | pub struct ELBG<T: VQElement, TS: VQElementSum<T>> { | |
63 | clusters: Vec<Cluster<T, TS>>, | |
64 | } | |
65 | ||
66 | impl<T: VQElement+Default, TS: VQElementSum<T>> ELBG<T, TS> { | |
67 | pub fn new(initial_cb: &[T]) -> Self { | |
68 | let mut clusters = Vec::with_capacity(initial_cb.len()); | |
69 | for elem in initial_cb.iter() { | |
70 | let cluster = Cluster::new(*elem); | |
71 | clusters.push(cluster); | |
72 | } | |
73 | Self { | |
74 | clusters, | |
75 | } | |
76 | } | |
77 | fn new_split(old_index: usize, entries: &[Entry<T>], indices: &[usize]) -> Option<(T, T)> { | |
78 | let mut max = T::min_cw(); | |
79 | let mut min = T::max_cw(); | |
80 | let mut found = false; | |
81 | for (entry, idx) in entries.iter().zip(indices) { | |
82 | if *idx == old_index { | |
83 | max = max.max(entry.val); | |
84 | min = min.min(entry.val); | |
85 | found = true; | |
86 | } | |
87 | } | |
88 | if !found { | |
89 | return None; | |
90 | } | |
91 | let mut ts0 = TS::zero(); | |
92 | let mut ts1 = TS::zero(); | |
93 | ts0.add(min, 2); ts0.add(max, 1); | |
94 | ts1.add(min, 1); ts1.add(max, 2); | |
95 | Some((ts0.get_centroid(), ts1.get_centroid())) | |
96 | } | |
97 | fn old_centre(&self, old_index1: usize, old_index2: usize, entries: &[Entry<T>], indices: &[usize]) -> T { | |
98 | let mut max = T::min_cw(); | |
99 | let mut min = T::max_cw(); | |
100 | let mut found = false; | |
101 | for (entry, idx) in entries.iter().zip(indices) { | |
102 | if *idx == old_index1 || *idx == old_index2 { | |
103 | max = max.max(entry.val); | |
104 | min = min.min(entry.val); | |
105 | found = true; | |
106 | } | |
107 | } | |
108 | if !found { | |
109 | max = self.clusters[old_index1].centroid.max(self.clusters[old_index2].centroid); | |
110 | min = self.clusters[old_index1].centroid.min(self.clusters[old_index2].centroid); | |
111 | } | |
112 | let mut ts = TS::zero(); | |
113 | ts.add(min, 2); ts.add(max, 1); | |
114 | ts.get_centroid() | |
115 | } | |
116 | fn estimate_old(old_idx0: usize, old_idx1: usize, c: T, entries: &[Entry<T>], indices: &[usize]) -> u64 { | |
117 | let mut clu: Cluster<T, TS> = Cluster::new(c); | |
118 | let mut count = 0; | |
119 | for (entry, idx) in entries.iter().zip(indices) { | |
120 | if *idx == old_idx0 || *idx == old_idx1 { | |
121 | clu.add_dist(entry); | |
122 | count += entry.count; | |
123 | } | |
124 | } | |
125 | clu.count = count; | |
126 | clu.calc_dist(); | |
127 | clu.dist | |
128 | } | |
129 | fn estimate_new(c0: T, c1: T, old_idx: usize, entries: &[Entry<T>], indices: &[usize]) -> u64 { | |
130 | let mut clu0: Cluster<T, TS> = Cluster::new(c0); | |
131 | let mut clu1: Cluster<T, TS> = Cluster::new(c1); | |
132 | let mut count0 = 0; | |
133 | let mut count1 = 0; | |
134 | for (entry, idx) in entries.iter().zip(indices) { | |
135 | if *idx == old_idx { | |
136 | if c0.dist(entry.val) < c1.dist(entry.val) { | |
137 | clu0.add_dist(entry); | |
138 | count0 += entry.count; | |
139 | } else { | |
140 | clu1.add_dist(entry); | |
141 | count1 += entry.count; | |
142 | } | |
143 | } | |
144 | } | |
145 | clu0.count = count0; | |
146 | clu1.count = count1; | |
147 | clu0.calc_dist(); | |
148 | clu1.calc_dist(); | |
149 | clu0.dist + clu1.dist | |
150 | } | |
b7c882c1 | 151 | #[allow(clippy::cognitive_complexity)] |
f808017e | 152 | pub fn quantise(&mut self, src: &[T], dst: &mut [T]) -> usize { |
03011b99 | 153 | if src.is_empty() || dst.len() != self.clusters.len() { |
f808017e | 154 | return 0; |
971ae306 KS |
155 | } |
156 | let mut old_cb = vec![T::default(); self.clusters.len()]; | |
157 | let mut prev_dist = std::u64::MAX; | |
158 | let mut dist = std::u64::MAX / 2; | |
159 | let mut indices = Vec::with_capacity(src.len()); | |
160 | let mut elements = Vec::with_capacity(src.len()); | |
161 | elements.extend_from_slice(src); | |
162 | for comp in 0..T::num_components() { | |
163 | T::sort_by_component(elements.as_mut_slice(), comp); | |
164 | } | |
165 | let mut entries = Vec::with_capacity(elements.len() / 2); | |
166 | let mut lastval = elements[0]; | |
167 | let mut run = 1; | |
168 | for point in elements.iter().skip(1) { | |
169 | if &lastval == point { | |
170 | run += 1; | |
171 | } else { | |
172 | entries.push(Entry { val: lastval, count: run }); | |
173 | lastval = *point; | |
174 | run = 1; | |
175 | } | |
176 | } | |
177 | entries.push(Entry { val: lastval, count: run }); | |
178 | drop(elements); | |
179 | ||
f808017e | 180 | let mut cw_count = 0; |
971ae306 KS |
181 | let mut low_u: Vec<usize> = Vec::with_capacity(self.clusters.len()); |
182 | let mut high_u: Vec<usize> = Vec::with_capacity(self.clusters.len()); | |
183 | let mut rng = RNG::new(); | |
184 | let mut iterations = 0usize; | |
185 | let mut do_elbg_step = true; | |
c5a6ae87 | 186 | while (iterations < 20) && (dist < prev_dist - prev_dist / 100) { |
971ae306 | 187 | prev_dist = dist; |
f808017e KS |
188 | |
189 | cw_count = 0; | |
190 | for cluster in self.clusters.iter() { | |
191 | if cluster.count == 0 { | |
192 | continue; | |
193 | } | |
194 | old_cb[cw_count] = cluster.centroid; | |
195 | cw_count += 1; | |
196 | } | |
197 | for cluster in self.clusters.iter_mut() { | |
198 | cluster.reset(); | |
971ae306 | 199 | } |
f808017e | 200 | |
971ae306 | 201 | // put points into the nearest clusters |
37952415 | 202 | indices.clear(); |
971ae306 KS |
203 | for entry in entries.iter() { |
204 | let mut bestidx = 0; | |
205 | let mut bestdist = std::u32::MAX; | |
206 | for (i, cluster) in self.clusters.iter().enumerate() { | |
207 | let dist = entry.val.dist(cluster.centroid); | |
208 | if bestdist > dist { | |
209 | bestdist = dist; | |
210 | bestidx = i; | |
211 | if dist == 0 { | |
212 | break; | |
213 | } | |
214 | } | |
215 | } | |
216 | indices.push(bestidx); | |
217 | self.clusters[bestidx].add_point(entry); | |
218 | } | |
219 | // calculate params | |
220 | for cluster in self.clusters.iter_mut() { | |
221 | cluster.calc_centroid(); | |
222 | } | |
223 | dist = 0; | |
224 | for (idx, entry) in indices.iter().zip(entries.iter()) { | |
225 | self.clusters[*idx].add_dist(entry); | |
226 | } | |
227 | for cluster in self.clusters.iter_mut() { | |
228 | cluster.calc_dist(); | |
229 | dist += cluster.dist; | |
230 | } | |
231 | ||
232 | let dmean = dist / (dst.len() as u64); | |
37952415 KS |
233 | low_u.clear(); |
234 | high_u.clear(); | |
971ae306 KS |
235 | let mut used = vec![false; dst.len()]; |
236 | for (i, cluster) in self.clusters.iter().enumerate() { | |
237 | if cluster.dist < dmean { | |
238 | low_u.push(i); | |
239 | } else if cluster.dist > dmean * 2 { | |
240 | high_u.push(i); | |
241 | used[i] = true; | |
242 | } | |
243 | } | |
244 | ||
245 | if do_elbg_step { | |
246 | do_elbg_step = false; | |
247 | for low_idx in low_u.iter() { | |
03011b99 | 248 | if high_u.is_empty() { |
971ae306 KS |
249 | break; |
250 | } | |
251 | let high_idx_idx = (rng.next() as usize) % high_u.len(); | |
252 | let high_idx = high_u[high_idx_idx]; | |
253 | let mut closest_idx = *low_idx; | |
254 | let mut closest_dist = std::u32::MAX; | |
255 | let low_centr = self.clusters[*low_idx].centroid; | |
256 | for i in 0..dst.len() {//low_u.iter() { | |
257 | if i == *low_idx || used[i] { | |
258 | continue; | |
259 | } | |
260 | let dist = self.clusters[i].centroid.dist(low_centr); | |
261 | if closest_dist > dist { | |
262 | closest_dist = dist; | |
263 | closest_idx = i; | |
264 | } | |
265 | } | |
266 | if closest_idx == *low_idx { | |
267 | continue; | |
268 | } | |
269 | let old_dist = self.clusters[*low_idx].dist + self.clusters[closest_idx].dist + self.clusters[high_idx].dist; | |
270 | let old_centr = self.old_centre(*low_idx, closest_idx, entries.as_slice(), indices.as_slice()); | |
271 | let ret = Self::new_split(high_idx, entries.as_slice(), indices.as_slice()); | |
272 | if let Some((centr0, centr1)) = ret { | |
273 | let dist_o = if old_dist > self.clusters[high_idx].dist { | |
274 | Self::estimate_old(*low_idx, closest_idx, old_centr, entries.as_slice(), indices.as_slice()) | |
275 | } else { 0 }; | |
276 | let dist_n = Self::estimate_new(centr0, centr1, high_idx, entries.as_slice(), indices.as_slice()); | |
277 | if dist_o + dist_n < old_dist { | |
278 | self.clusters[*low_idx ].centroid = old_centr; | |
279 | self.clusters[closest_idx].centroid = centr0; | |
280 | self.clusters[high_idx ].centroid = centr1; | |
281 | used[*low_idx] = true; | |
282 | used[closest_idx] = true; | |
283 | used[high_idx] = true; | |
284 | high_u.remove(high_idx_idx); | |
285 | do_elbg_step = true; | |
286 | } | |
287 | } | |
288 | } | |
289 | } | |
290 | iterations += 1; | |
291 | } | |
292 | if dist < prev_dist { | |
f808017e KS |
293 | cw_count = 0; |
294 | for cluster in self.clusters.iter() { | |
295 | if cluster.count == 0 { | |
296 | continue; | |
297 | } | |
298 | old_cb[cw_count] = cluster.centroid; | |
299 | cw_count += 1; | |
971ae306 KS |
300 | } |
301 | } | |
302 | dst.copy_from_slice(&old_cb); | |
f808017e | 303 | cw_count |
971ae306 KS |
304 | } |
305 | } |