1 use super::{VQElement, VQElementSum};
// A very simple RNG for internal use.
9 fn new() -> Self { Self { seed: 0x1234 } }
10 fn next(&mut self) -> u8 {
11 if (self.seed & 0x8000) != 0 {
12 self.seed = ((self.seed & 0x7FFF) * 2) ^ 0x1B2B;
25 struct Cluster<T: VQElement, TS: VQElementSum<T>> {
32 impl<T: VQElement, TS: VQElementSum<T>> Cluster<T, TS> {
33 fn new(centroid: T) -> Self {
43 self.sum = TS::zero();
46 fn add_point(&mut self, entry: &Entry<T>) {
47 self.sum.add(entry.val, entry.count);
48 self.count += entry.count;
50 fn add_dist(&mut self, entry: &Entry<T>) {
51 self.dist += u64::from(self.centroid.dist(entry.val)) * entry.count;
53 fn calc_centroid(&mut self) {
55 self.centroid = self.sum.get_centroid();
58 fn calc_dist(&mut self) {
62 pub struct ELBG<T: VQElement, TS: VQElementSum<T>> {
63 clusters: Vec<Cluster<T, TS>>,
66 impl<T: VQElement+Default, TS: VQElementSum<T>> ELBG<T, TS> {
67 pub fn new(initial_cb: &[T]) -> Self {
68 let mut clusters = Vec::with_capacity(initial_cb.len());
69 for elem in initial_cb.iter() {
70 let cluster = Cluster::new(*elem);
71 clusters.push(cluster);
77 fn new_split(old_index: usize, entries: &[Entry<T>], indices: &[usize]) -> Option<(T, T)> {
78 let mut max = T::min_cw();
79 let mut min = T::max_cw();
80 let mut found = false;
81 for (entry, idx) in entries.iter().zip(indices) {
82 if *idx == old_index {
83 max = max.max(entry.val);
84 min = min.min(entry.val);
91 let mut ts0 = TS::zero();
92 let mut ts1 = TS::zero();
93 ts0.add(min, 2); ts0.add(max, 1);
94 ts1.add(min, 1); ts1.add(max, 2);
95 Some((ts0.get_centroid(), ts1.get_centroid()))
97 fn old_centre(&self, old_index1: usize, old_index2: usize, entries: &[Entry<T>], indices: &[usize]) -> T {
98 let mut max = T::min_cw();
99 let mut min = T::max_cw();
100 let mut found = false;
101 for (entry, idx) in entries.iter().zip(indices) {
102 if *idx == old_index1 || *idx == old_index2 {
103 max = max.max(entry.val);
104 min = min.min(entry.val);
109 max = self.clusters[old_index1].centroid.max(self.clusters[old_index2].centroid);
110 min = self.clusters[old_index1].centroid.min(self.clusters[old_index2].centroid);
112 let mut ts = TS::zero();
113 ts.add(min, 2); ts.add(max, 1);
116 fn estimate_old(old_idx0: usize, old_idx1: usize, c: T, entries: &[Entry<T>], indices: &[usize]) -> u64 {
117 let mut clu: Cluster<T, TS> = Cluster::new(c);
119 for (entry, idx) in entries.iter().zip(indices) {
120 if *idx == old_idx0 || *idx == old_idx1 {
122 count += entry.count;
129 fn estimate_new(c0: T, c1: T, old_idx: usize, entries: &[Entry<T>], indices: &[usize]) -> u64 {
130 let mut clu0: Cluster<T, TS> = Cluster::new(c0);
131 let mut clu1: Cluster<T, TS> = Cluster::new(c1);
134 for (entry, idx) in entries.iter().zip(indices) {
136 if c0.dist(entry.val) < c1.dist(entry.val) {
137 clu0.add_dist(entry);
138 count0 += entry.count;
140 clu1.add_dist(entry);
141 count1 += entry.count;
149 clu0.dist + clu1.dist
151 #[allow(clippy::cognitive_complexity)]
152 pub fn quantise(&mut self, src: &[T], dst: &mut [T]) -> usize {
153 if src.is_empty() || dst.len() != self.clusters.len() {
156 let mut old_cb = vec![T::default(); self.clusters.len()];
157 let mut prev_dist = std::u64::MAX;
158 let mut dist = std::u64::MAX / 2;
159 let mut indices = Vec::with_capacity(src.len());
160 let mut elements = Vec::with_capacity(src.len());
161 elements.extend_from_slice(src);
162 for comp in 0..T::num_components() {
163 T::sort_by_component(elements.as_mut_slice(), comp);
165 let mut entries = Vec::with_capacity(elements.len() / 2);
166 let mut lastval = elements[0];
168 for point in elements.iter().skip(1) {
169 if &lastval == point {
172 entries.push(Entry { val: lastval, count: run });
177 entries.push(Entry { val: lastval, count: run });
180 let mut cw_count = 0;
181 let mut low_u: Vec<usize> = Vec::with_capacity(self.clusters.len());
182 let mut high_u: Vec<usize> = Vec::with_capacity(self.clusters.len());
183 let mut rng = RNG::new();
184 let mut iterations = 0usize;
185 let mut do_elbg_step = true;
186 while (iterations < 20) && (dist < prev_dist - prev_dist / 100) {
190 for cluster in self.clusters.iter() {
191 if cluster.count == 0 {
194 old_cb[cw_count] = cluster.centroid;
197 for cluster in self.clusters.iter_mut() {
201 // put points into the nearest clusters
203 for entry in entries.iter() {
205 let mut bestdist = std::u32::MAX;
206 for (i, cluster) in self.clusters.iter().enumerate() {
207 let dist = entry.val.dist(cluster.centroid);
216 indices.push(bestidx);
217 self.clusters[bestidx].add_point(entry);
220 for cluster in self.clusters.iter_mut() {
221 cluster.calc_centroid();
224 for (idx, entry) in indices.iter().zip(entries.iter()) {
225 self.clusters[*idx].add_dist(entry);
227 for cluster in self.clusters.iter_mut() {
229 dist += cluster.dist;
232 let dmean = dist / (dst.len() as u64);
235 let mut used = vec![false; dst.len()];
236 for (i, cluster) in self.clusters.iter().enumerate() {
237 if cluster.dist < dmean {
239 } else if cluster.dist > dmean * 2 {
246 do_elbg_step = false;
247 for low_idx in low_u.iter() {
248 if high_u.is_empty() {
251 let high_idx_idx = (rng.next() as usize) % high_u.len();
252 let high_idx = high_u[high_idx_idx];
253 let mut closest_idx = *low_idx;
254 let mut closest_dist = std::u32::MAX;
255 let low_centr = self.clusters[*low_idx].centroid;
256 for i in 0..dst.len() {//low_u.iter() {
257 if i == *low_idx || used[i] {
260 let dist = self.clusters[i].centroid.dist(low_centr);
261 if closest_dist > dist {
266 if closest_idx == *low_idx {
269 let old_dist = self.clusters[*low_idx].dist + self.clusters[closest_idx].dist + self.clusters[high_idx].dist;
270 let old_centr = self.old_centre(*low_idx, closest_idx, entries.as_slice(), indices.as_slice());
271 let ret = Self::new_split(high_idx, entries.as_slice(), indices.as_slice());
272 if let Some((centr0, centr1)) = ret {
273 let dist_o = if old_dist > self.clusters[high_idx].dist {
274 Self::estimate_old(*low_idx, closest_idx, old_centr, entries.as_slice(), indices.as_slice())
276 let dist_n = Self::estimate_new(centr0, centr1, high_idx, entries.as_slice(), indices.as_slice());
277 if dist_o + dist_n < old_dist {
278 self.clusters[*low_idx ].centroid = old_centr;
279 self.clusters[closest_idx].centroid = centr0;
280 self.clusters[high_idx ].centroid = centr1;
281 used[*low_idx] = true;
282 used[closest_idx] = true;
283 used[high_idx] = true;
284 high_u.remove(high_idx_idx);
292 if dist < prev_dist {
294 for cluster in self.clusters.iter() {
295 if cluster.count == 0 {
298 old_cb[cw_count] = cluster.centroid;
302 dst.copy_from_slice(&old_cb);