]>
Commit | Line | Data |
---|---|---|
1 | use super::{VQElement, VQElementSum}; | |
2 | ||
/// Very simple pseudo-random generator for internal needs.
///
/// The 16-bit state is shifted left by one bit on every step; whenever the bit
/// that gets shifted out was set, the state is additionally XORed with `0x1B2B`.
struct RNG {
    seed: u16,
}

impl RNG {
    /// Creates a generator with the fixed starting state `0x1234`.
    fn new() -> Self {
        Self { seed: 0x1234 }
    }

    /// Advances the state by one step and returns its low eight bits.
    fn next(&mut self) -> u8 {
        let top_bit_set = (self.seed & 0x8000) != 0;
        // Shifting a `u16` left by one drops the top bit, which is the same as
        // masking the top bit off and doubling the remainder.
        let shifted = self.seed << 1;
        self.seed = if top_bit_set { shifted ^ 0x1B2B } else { shifted };
        self.seed as u8
    }
}
19 | ||
/// A value paired with the number of times it occurs.
struct Entry<T> {
    /// The value itself.
    val: T,
    /// Occurrence count for `val`.
    count: u64,
}
24 | ||
25 | struct Cluster<T: VQElement, TS: VQElementSum<T>> { | |
26 | centroid: T, | |
27 | dist: u64, | |
28 | count: u64, | |
29 | sum: TS, | |
30 | } | |
31 | ||
32 | impl<T: VQElement, TS: VQElementSum<T>> Cluster<T, TS> { | |
33 | fn new(centroid: T) -> Self { | |
34 | Self { | |
35 | centroid, | |
36 | dist: 0, | |
37 | count: 0, | |
38 | sum: TS::zero(), | |
39 | } | |
40 | } | |
41 | fn reset(&mut self) { | |
42 | self.count = 0; | |
43 | self.sum = TS::zero(); | |
44 | self.dist = 0; | |
45 | } | |
46 | fn add_point(&mut self, entry: &Entry<T>) { | |
47 | self.sum.add(entry.val, entry.count); | |
48 | self.count += entry.count; | |
49 | } | |
50 | fn add_dist(&mut self, entry: &Entry<T>) { | |
51 | self.dist += u64::from(self.centroid.dist(entry.val)) * entry.count; | |
52 | } | |
53 | fn calc_centroid(&mut self) { | |
54 | self.centroid = self.sum.get_centroid(); | |
55 | } | |
56 | fn calc_dist(&mut self) { | |
57 | if self.count != 0 { | |
58 | self.dist = (self.dist + self.count / 2) / self.count; | |
59 | } | |
60 | } | |
61 | } | |
62 | ||
63 | pub struct ELBG<T: VQElement, TS: VQElementSum<T>> { | |
64 | clusters: Vec<Cluster<T, TS>>, | |
65 | } | |
66 | ||
67 | impl<T: VQElement+Default, TS: VQElementSum<T>> ELBG<T, TS> { | |
68 | pub fn new(initial_cb: &[T]) -> Self { | |
69 | let mut clusters = Vec::with_capacity(initial_cb.len()); | |
70 | for elem in initial_cb.iter() { | |
71 | let cluster = Cluster::new(*elem); | |
72 | clusters.push(cluster); | |
73 | } | |
74 | Self { | |
75 | clusters, | |
76 | } | |
77 | } | |
78 | fn new_split(old_index: usize, entries: &[Entry<T>], indices: &[usize]) -> Option<(T, T)> { | |
79 | let mut max = T::min_cw(); | |
80 | let mut min = T::max_cw(); | |
81 | let mut found = false; | |
82 | for (entry, idx) in entries.iter().zip(indices) { | |
83 | if *idx == old_index { | |
84 | max = max.max(entry.val); | |
85 | min = min.min(entry.val); | |
86 | found = true; | |
87 | } | |
88 | } | |
89 | if !found { | |
90 | return None; | |
91 | } | |
92 | let mut ts0 = TS::zero(); | |
93 | let mut ts1 = TS::zero(); | |
94 | ts0.add(min, 2); ts0.add(max, 1); | |
95 | ts1.add(min, 1); ts1.add(max, 2); | |
96 | Some((ts0.get_centroid(), ts1.get_centroid())) | |
97 | } | |
98 | fn old_centre(&self, old_index1: usize, old_index2: usize, entries: &[Entry<T>], indices: &[usize]) -> T { | |
99 | let mut max = T::min_cw(); | |
100 | let mut min = T::max_cw(); | |
101 | let mut found = false; | |
102 | for (entry, idx) in entries.iter().zip(indices) { | |
103 | if *idx == old_index1 || *idx == old_index2 { | |
104 | max = max.max(entry.val); | |
105 | min = min.min(entry.val); | |
106 | found = true; | |
107 | } | |
108 | } | |
109 | if !found { | |
110 | max = self.clusters[old_index1].centroid.max(self.clusters[old_index2].centroid); | |
111 | min = self.clusters[old_index1].centroid.min(self.clusters[old_index2].centroid); | |
112 | } | |
113 | let mut ts = TS::zero(); | |
114 | ts.add(min, 2); ts.add(max, 1); | |
115 | ts.get_centroid() | |
116 | } | |
117 | fn estimate_old(old_idx0: usize, old_idx1: usize, c: T, entries: &[Entry<T>], indices: &[usize]) -> u64 { | |
118 | let mut clu: Cluster<T, TS> = Cluster::new(c); | |
119 | let mut count = 0; | |
120 | for (entry, idx) in entries.iter().zip(indices) { | |
121 | if *idx == old_idx0 || *idx == old_idx1 { | |
122 | clu.add_dist(entry); | |
123 | count += entry.count; | |
124 | } | |
125 | } | |
126 | clu.count = count; | |
127 | clu.calc_dist(); | |
128 | clu.dist | |
129 | } | |
130 | fn estimate_new(c0: T, c1: T, old_idx: usize, entries: &[Entry<T>], indices: &[usize]) -> u64 { | |
131 | let mut clu0: Cluster<T, TS> = Cluster::new(c0); | |
132 | let mut clu1: Cluster<T, TS> = Cluster::new(c1); | |
133 | let mut count0 = 0; | |
134 | let mut count1 = 0; | |
135 | for (entry, idx) in entries.iter().zip(indices) { | |
136 | if *idx == old_idx { | |
137 | if c0.dist(entry.val) < c1.dist(entry.val) { | |
138 | clu0.add_dist(entry); | |
139 | count0 += entry.count; | |
140 | } else { | |
141 | clu1.add_dist(entry); | |
142 | count1 += entry.count; | |
143 | } | |
144 | } | |
145 | } | |
146 | clu0.count = count0; | |
147 | clu1.count = count1; | |
148 | clu0.calc_dist(); | |
149 | clu1.calc_dist(); | |
150 | clu0.dist + clu1.dist | |
151 | } | |
152 | pub fn quantise(&mut self, src: &[T], dst: &mut [T]) { | |
153 | if src.len() < 1 || dst.len() != self.clusters.len() { | |
154 | return; | |
155 | } | |
156 | let mut old_cb = vec![T::default(); self.clusters.len()]; | |
157 | let mut prev_dist = std::u64::MAX; | |
158 | let mut dist = std::u64::MAX / 2; | |
159 | let mut indices = Vec::with_capacity(src.len()); | |
160 | let mut elements = Vec::with_capacity(src.len()); | |
161 | elements.extend_from_slice(src); | |
162 | for comp in 0..T::num_components() { | |
163 | T::sort_by_component(elements.as_mut_slice(), comp); | |
164 | } | |
165 | let mut entries = Vec::with_capacity(elements.len() / 2); | |
166 | let mut lastval = elements[0]; | |
167 | let mut run = 1; | |
168 | for point in elements.iter().skip(1) { | |
169 | if &lastval == point { | |
170 | run += 1; | |
171 | } else { | |
172 | entries.push(Entry { val: lastval, count: run }); | |
173 | lastval = *point; | |
174 | run = 1; | |
175 | } | |
176 | } | |
177 | entries.push(Entry { val: lastval, count: run }); | |
178 | drop(elements); | |
179 | ||
180 | let mut low_u: Vec<usize> = Vec::with_capacity(self.clusters.len()); | |
181 | let mut high_u: Vec<usize> = Vec::with_capacity(self.clusters.len()); | |
182 | let mut rng = RNG::new(); | |
183 | let mut iterations = 0usize; | |
184 | let mut do_elbg_step = true; | |
185 | while (iterations < 20) && (dist < prev_dist - prev_dist / 1000) { | |
186 | prev_dist = dist; | |
187 | for i in 0..dst.len() { | |
188 | old_cb[i] = self.clusters[i].centroid; | |
189 | self.clusters[i].reset(); | |
190 | } | |
191 | // put points into the nearest clusters | |
192 | indices.truncate(0); | |
193 | for entry in entries.iter() { | |
194 | let mut bestidx = 0; | |
195 | let mut bestdist = std::u32::MAX; | |
196 | for (i, cluster) in self.clusters.iter().enumerate() { | |
197 | let dist = entry.val.dist(cluster.centroid); | |
198 | if bestdist > dist { | |
199 | bestdist = dist; | |
200 | bestidx = i; | |
201 | if dist == 0 { | |
202 | break; | |
203 | } | |
204 | } | |
205 | } | |
206 | indices.push(bestidx); | |
207 | self.clusters[bestidx].add_point(entry); | |
208 | } | |
209 | // calculate params | |
210 | for cluster in self.clusters.iter_mut() { | |
211 | cluster.calc_centroid(); | |
212 | } | |
213 | dist = 0; | |
214 | for (idx, entry) in indices.iter().zip(entries.iter()) { | |
215 | self.clusters[*idx].add_dist(entry); | |
216 | } | |
217 | for cluster in self.clusters.iter_mut() { | |
218 | cluster.calc_dist(); | |
219 | dist += cluster.dist; | |
220 | } | |
221 | ||
222 | let dmean = dist / (dst.len() as u64); | |
223 | low_u.truncate(0); | |
224 | high_u.truncate(0); | |
225 | let mut used = vec![false; dst.len()]; | |
226 | for (i, cluster) in self.clusters.iter().enumerate() { | |
227 | if cluster.dist < dmean { | |
228 | low_u.push(i); | |
229 | } else if cluster.dist > dmean * 2 { | |
230 | high_u.push(i); | |
231 | used[i] = true; | |
232 | } | |
233 | } | |
234 | ||
235 | if do_elbg_step { | |
236 | do_elbg_step = false; | |
237 | for low_idx in low_u.iter() { | |
238 | if high_u.len() == 0 { | |
239 | break; | |
240 | } | |
241 | let high_idx_idx = (rng.next() as usize) % high_u.len(); | |
242 | let high_idx = high_u[high_idx_idx]; | |
243 | let mut closest_idx = *low_idx; | |
244 | let mut closest_dist = std::u32::MAX; | |
245 | let low_centr = self.clusters[*low_idx].centroid; | |
246 | for i in 0..dst.len() {//low_u.iter() { | |
247 | if i == *low_idx || used[i] { | |
248 | continue; | |
249 | } | |
250 | let dist = self.clusters[i].centroid.dist(low_centr); | |
251 | if closest_dist > dist { | |
252 | closest_dist = dist; | |
253 | closest_idx = i; | |
254 | } | |
255 | } | |
256 | if closest_idx == *low_idx { | |
257 | continue; | |
258 | } | |
259 | let old_dist = self.clusters[*low_idx].dist + self.clusters[closest_idx].dist + self.clusters[high_idx].dist; | |
260 | let old_centr = self.old_centre(*low_idx, closest_idx, entries.as_slice(), indices.as_slice()); | |
261 | let ret = Self::new_split(high_idx, entries.as_slice(), indices.as_slice()); | |
262 | if let Some((centr0, centr1)) = ret { | |
263 | let dist_o = if old_dist > self.clusters[high_idx].dist { | |
264 | Self::estimate_old(*low_idx, closest_idx, old_centr, entries.as_slice(), indices.as_slice()) | |
265 | } else { 0 }; | |
266 | let dist_n = Self::estimate_new(centr0, centr1, high_idx, entries.as_slice(), indices.as_slice()); | |
267 | if dist_o + dist_n < old_dist { | |
268 | self.clusters[*low_idx ].centroid = old_centr; | |
269 | self.clusters[closest_idx].centroid = centr0; | |
270 | self.clusters[high_idx ].centroid = centr1; | |
271 | used[*low_idx] = true; | |
272 | used[closest_idx] = true; | |
273 | used[high_idx] = true; | |
274 | high_u.remove(high_idx_idx); | |
275 | do_elbg_step = true; | |
276 | } | |
277 | } | |
278 | } | |
279 | } | |
280 | iterations += 1; | |
281 | } | |
282 | if dist < prev_dist { | |
283 | for i in 0..dst.len() { | |
284 | old_cb[i] = self.clusters[i].centroid; | |
285 | } | |
286 | } | |
287 | dst.copy_from_slice(&old_cb); | |
288 | } | |
289 | } |