]> git.nihav.org Git - nihav.git/commitdiff
codec_support: add module for generic vector quantisation
authorKostya Shishkov <kostya.shishkov@gmail.com>
Sat, 30 May 2020 10:07:17 +0000 (12:07 +0200)
committerKostya Shishkov <kostya.shishkov@gmail.com>
Sat, 30 May 2020 10:07:17 +0000 (12:07 +0200)
nihav-codec-support/Cargo.toml
nihav-codec-support/src/lib.rs
nihav-codec-support/src/vq/generic_elbg.rs [new file with mode: 0644]
nihav-codec-support/src/vq/generic_mediancut.rs [new file with mode: 0644]
nihav-codec-support/src/vq/mod.rs [new file with mode: 0644]

index c302a167ae1d3f02e0bfdae89be78deef06b9d16..d0d7117b4fb4b1605e3ba4e1d86539c9e5694a5e 100644 (file)
@@ -18,3 +18,4 @@ dct = ["dsp"]
 fft = ["dsp"]
 mdct = ["fft", "dsp"]
 dsp_window = ["dsp"]
+vq = []
index d2c889f476d9d0655608939796cf3d0b9646fe50..9139a2ae0b4fc10eeec9fd3c083829ed26e42eb9 100644 (file)
@@ -18,4 +18,7 @@ pub mod imgwrite;
 
 pub mod test;
 
+#[cfg(feature="vq")]
+pub mod vq;
+
 extern crate nihav_core;
diff --git a/nihav-codec-support/src/vq/generic_elbg.rs b/nihav-codec-support/src/vq/generic_elbg.rs
new file mode 100644 (file)
index 0000000..b2c838d
--- /dev/null
@@ -0,0 +1,289 @@
+use super::{VQElement, VQElementSum};
+
+// very simple RNG for internal needs
+struct RNG {
+    seed: u16,
+}
+
+impl RNG {
+    fn new() -> Self { Self { seed: 0x1234 } }
+    fn next(&mut self) -> u8 {
+        if (self.seed & 0x8000) != 0 {
+            self.seed = (self.seed & 0x7FFF) * 2 ^ 0x1B2B;
+        } else {
+            self.seed <<= 1;
+        }
+        self.seed as u8
+    }
+}
+
+struct Entry<T> {
+    val:        T,
+    count:      u64,
+}
+
+struct Cluster<T: VQElement, TS: VQElementSum<T>> {
+    centroid:   T,
+    dist:       u64,
+    count:      u64,
+    sum:        TS,
+}
+
+impl<T: VQElement, TS: VQElementSum<T>> Cluster<T, TS> {
+    fn new(centroid: T) -> Self {
+        Self {
+            centroid,
+            dist:       0,
+            count:      0,
+            sum:        TS::zero(),
+        }
+    }
+    fn reset(&mut self) {
+        self.count = 0;
+        self.sum   = TS::zero();
+        self.dist  = 0;
+    }
+    fn add_point(&mut self, entry: &Entry<T>) {
+        self.sum.add(entry.val, entry.count);
+        self.count += entry.count;
+    }
+    fn add_dist(&mut self, entry: &Entry<T>) {
+        self.dist += u64::from(self.centroid.dist(entry.val)) * entry.count;
+    }
+    fn calc_centroid(&mut self) {
+        self.centroid = self.sum.get_centroid();
+    }
+    fn calc_dist(&mut self) {
+        if self.count != 0 {
+            self.dist = (self.dist + self.count / 2) / self.count;
+        }
+    }
+}
+
+pub struct ELBG<T: VQElement, TS: VQElementSum<T>> {
+    clusters:   Vec<Cluster<T, TS>>,
+}
+
+impl<T: VQElement+Default, TS: VQElementSum<T>> ELBG<T, TS> {
+    pub fn new(initial_cb: &[T]) -> Self {
+        let mut clusters = Vec::with_capacity(initial_cb.len());
+        for elem in initial_cb.iter() {
+            let cluster = Cluster::new(*elem);
+            clusters.push(cluster);
+        }
+        Self {
+            clusters,
+        }
+    }
+    fn new_split(old_index: usize, entries: &[Entry<T>], indices: &[usize]) -> Option<(T, T)> {
+        let mut max = T::min_cw();
+        let mut min = T::max_cw();
+        let mut found = false;
+        for (entry, idx) in entries.iter().zip(indices) {
+            if *idx == old_index {
+                max = max.max(entry.val);
+                min = min.min(entry.val);
+                found = true;
+            }
+        }
+        if !found {
+            return None;
+        }
+        let mut ts0 = TS::zero();
+        let mut ts1 = TS::zero();
+        ts0.add(min, 2); ts0.add(max, 1);
+        ts1.add(min, 1); ts1.add(max, 2);
+        Some((ts0.get_centroid(), ts1.get_centroid()))
+    }
+    fn old_centre(&self, old_index1: usize, old_index2: usize, entries: &[Entry<T>], indices: &[usize]) -> T {
+        let mut max = T::min_cw();
+        let mut min = T::max_cw();
+        let mut found = false;
+        for (entry, idx) in entries.iter().zip(indices) {
+            if *idx == old_index1 || *idx == old_index2 {
+                max = max.max(entry.val);
+                min = min.min(entry.val);
+                found = true;
+            }
+        }
+        if !found {
+            max = self.clusters[old_index1].centroid.max(self.clusters[old_index2].centroid);
+            min = self.clusters[old_index1].centroid.min(self.clusters[old_index2].centroid);
+        }
+        let mut ts = TS::zero();
+        ts.add(min, 2); ts.add(max, 1);
+        ts.get_centroid()
+    }
+    fn estimate_old(old_idx0: usize, old_idx1: usize, c: T, entries: &[Entry<T>], indices: &[usize]) -> u64 {
+        let mut clu: Cluster<T, TS> = Cluster::new(c);
+        let mut count = 0;
+        for (entry, idx) in entries.iter().zip(indices) {
+            if *idx == old_idx0 || *idx == old_idx1 {
+                clu.add_dist(entry);
+                count += entry.count;
+            }
+        }
+        clu.count = count;
+        clu.calc_dist();
+        clu.dist
+    }
+    fn estimate_new(c0: T, c1: T, old_idx: usize, entries: &[Entry<T>], indices: &[usize]) -> u64 {
+        let mut clu0: Cluster<T, TS> = Cluster::new(c0);
+        let mut clu1: Cluster<T, TS> = Cluster::new(c1);
+        let mut count0 = 0;
+        let mut count1 = 0;
+        for (entry, idx) in entries.iter().zip(indices) {
+            if *idx == old_idx {
+                if c0.dist(entry.val) < c1.dist(entry.val) {
+                    clu0.add_dist(entry);
+                    count0 += entry.count;
+                } else {
+                    clu1.add_dist(entry);
+                    count1 += entry.count;
+                }
+            }
+        }
+        clu0.count = count0;
+        clu1.count = count1;
+        clu0.calc_dist();
+        clu1.calc_dist();
+        clu0.dist + clu1.dist
+    }
+    pub fn quantise(&mut self, src: &[T], dst: &mut [T]) {
+        if src.len() < 1 || dst.len() != self.clusters.len() {
+            return;
+        }
+        let mut old_cb = vec![T::default(); self.clusters.len()];
+        let mut prev_dist = std::u64::MAX;
+        let mut dist = std::u64::MAX / 2;
+        let mut indices = Vec::with_capacity(src.len());
+        let mut elements = Vec::with_capacity(src.len());
+        elements.extend_from_slice(src);
+        for comp in 0..T::num_components() {
+            T::sort_by_component(elements.as_mut_slice(), comp);
+        }
+        let mut entries = Vec::with_capacity(elements.len() / 2);
+        let mut lastval = elements[0];
+        let mut run = 1;
+        for point in elements.iter().skip(1) {
+            if &lastval == point {
+                run += 1;
+            } else {
+                entries.push(Entry { val: lastval, count: run });
+                lastval = *point;
+                run = 1;
+            }
+        }
+        entries.push(Entry { val: lastval, count: run });
+        drop(elements);
+
+        let mut low_u:  Vec<usize> = Vec::with_capacity(self.clusters.len());
+        let mut high_u: Vec<usize> = Vec::with_capacity(self.clusters.len());
+        let mut rng = RNG::new();
+        let mut iterations = 0usize;
+        let mut do_elbg_step = true;
+        while (iterations < 20) && (dist < prev_dist - prev_dist / 1000) {
+            prev_dist = dist;
+            for i in 0..dst.len() {
+                old_cb[i] = self.clusters[i].centroid;
+                self.clusters[i].reset();
+            }
+            // put points into the nearest clusters
+            indices.truncate(0);
+            for entry in entries.iter() {
+                let mut bestidx = 0;
+                let mut bestdist = std::u32::MAX;
+                for (i, cluster) in self.clusters.iter().enumerate() {
+                    let dist = entry.val.dist(cluster.centroid);
+                    if bestdist > dist {
+                        bestdist = dist;
+                        bestidx = i;
+                        if dist == 0 {
+                            break;
+                        }
+                    }
+                }
+                indices.push(bestidx);
+                self.clusters[bestidx].add_point(entry);
+            }
+            // calculate params
+            for cluster in self.clusters.iter_mut() {
+                cluster.calc_centroid();
+            }
+            dist = 0;
+            for (idx, entry) in indices.iter().zip(entries.iter()) {
+                self.clusters[*idx].add_dist(entry);
+            }
+            for cluster in self.clusters.iter_mut() {
+                cluster.calc_dist();
+                dist += cluster.dist;
+            }
+
+            let dmean = dist / (dst.len() as u64);
+            low_u.truncate(0);
+            high_u.truncate(0);
+            let mut used = vec![false; dst.len()];
+            for (i, cluster) in self.clusters.iter().enumerate() {
+                if cluster.dist < dmean {
+                    low_u.push(i);
+                } else if cluster.dist > dmean * 2 {
+                    high_u.push(i);
+                    used[i] = true;
+                }
+            }
+
+            if do_elbg_step {
+                do_elbg_step = false;
+                for low_idx in low_u.iter() {
+                    if high_u.len() == 0 {
+                        break;
+                    }
+                    let high_idx_idx = (rng.next() as usize) % high_u.len();
+                    let high_idx = high_u[high_idx_idx];
+                    let mut closest_idx = *low_idx;
+                    let mut closest_dist = std::u32::MAX;
+                    let low_centr = self.clusters[*low_idx].centroid;
+                    for i in 0..dst.len() {//low_u.iter() {
+                        if i == *low_idx || used[i] {
+                            continue;
+                        }
+                        let dist = self.clusters[i].centroid.dist(low_centr);
+                        if closest_dist > dist {
+                            closest_dist = dist;
+                            closest_idx  = i;
+                        }
+                    }
+                    if closest_idx == *low_idx {
+                        continue;
+                    }
+                    let old_dist = self.clusters[*low_idx].dist + self.clusters[closest_idx].dist + self.clusters[high_idx].dist;
+                    let old_centr = self.old_centre(*low_idx, closest_idx, entries.as_slice(), indices.as_slice());
+                    let ret = Self::new_split(high_idx, entries.as_slice(), indices.as_slice());
+                    if let Some((centr0, centr1)) = ret {
+                        let dist_o = if old_dist > self.clusters[high_idx].dist {
+                                Self::estimate_old(*low_idx, closest_idx, old_centr, entries.as_slice(), indices.as_slice())
+                            } else { 0 };
+                        let dist_n = Self::estimate_new(centr0, centr1, high_idx, entries.as_slice(), indices.as_slice());
+                        if dist_o + dist_n < old_dist {
+                            self.clusters[*low_idx   ].centroid = old_centr;
+                            self.clusters[closest_idx].centroid = centr0;
+                            self.clusters[high_idx   ].centroid = centr1;
+                            used[*low_idx]    = true;
+                            used[closest_idx] = true;
+                            used[high_idx]    = true;
+                            high_u.remove(high_idx_idx);
+                            do_elbg_step = true;
+                        }
+                    }
+                }
+            }
+            iterations += 1;
+        }
+        if dist < prev_dist {
+            for i in 0..dst.len() {
+                old_cb[i] = self.clusters[i].centroid;
+            }
+        }
+        dst.copy_from_slice(&old_cb);
+    }
+}
diff --git a/nihav-codec-support/src/vq/generic_mediancut.rs b/nihav-codec-support/src/vq/generic_mediancut.rs
new file mode 100644 (file)
index 0000000..e6b35e6
--- /dev/null
@@ -0,0 +1,117 @@
+use super::{VQElement, VQElementSum};
+
+struct VQBox<'a, T: VQElement> {
+    points: &'a mut [T],
+    max:    T,
+    min:    T,
+}
+
+impl<'a, T: VQElement> VQBox<'a, T> {
+    fn new(points: &'a mut [T]) -> Self {
+        let mut max = T::min_cw();
+        let mut min = T::max_cw();
+        for point in points.iter() {
+            max = max.max(*point);
+            min = min.min(*point);
+        }
+        Self { points, max, min }
+    }
+    fn can_split(&self) -> bool {
+        self.max != self.min
+    }
+    fn calc_min_and_max(points: &[T]) -> (T, T) {
+        let mut max = T::min_cw();
+        let mut min = T::max_cw();
+        for point in points.iter() {
+            max = max.max(*point);
+            min = min.min(*point);
+        }
+        (min, max)
+    }
+    fn get_pivot(arr: &[T]) -> usize {
+        if arr.len() < 2 {
+            return 0;
+        }
+        let mut lastval = arr[0];
+        let mut pivot = 0;
+        let mut idx = 1;
+        for el in arr.iter().skip(1) {
+            if *el != lastval && (pivot == 0 || idx <= arr.len() / 2) {
+                pivot = idx;
+                lastval = *el;
+            }
+            idx += 1;
+        }
+        pivot
+    }
+    fn split(self) -> (VQBox<'a, T>, VQBox<'a, T>) {
+        let sort_c = T::max_dist_component(&self.min, &self.max);
+        T::sort_by_component(self.points, sort_c);
+        let pivot = Self::get_pivot(self.points);
+        let (part0, part1) = self.points.split_at_mut(pivot);
+        let (min0, max0) = Self::calc_min_and_max(part0);
+        let (min1, max1) = Self::calc_min_and_max(part1);
+        let box0 = VQBox { points: part0, max: max0, min: min0 };
+        let box1 = VQBox { points: part1, max: max1, min: min1 };
+
+        (box0, box1)
+    }
+}
+
+pub fn quantise_median_cut<T: VQElement, TS: VQElementSum<T>>(src: &[T], dst: &mut [T]) -> usize {
+    let mut points = Vec::with_capacity(src.len());
+    points.extend(src.into_iter());
+    for comp in 0..T::num_components() {
+        T::sort_by_component(points.as_mut_slice(), comp);
+    }
+    let box0 = VQBox::new(points.as_mut_slice());
+    let mut boxes: Vec<VQBox<T>> = Vec::with_capacity(dst.len());
+    boxes.push(box0);
+    let mut changed = true;
+    while changed && boxes.len() < dst.len() {
+        let end = boxes.len();
+        changed = false;
+        let mut split_largest = false;
+        for _ in 0..end {
+            let curbox = boxes.remove(0);
+            if curbox.can_split() {
+                let (box0, box1) = curbox.split();
+                boxes.push(box0);
+                boxes.push(box1);
+                changed = true;
+            } else {
+                boxes.push(curbox);
+                split_largest = true;
+                break;
+            }
+            if boxes.len() == dst.len() {
+                break;
+            }
+        }
+        if split_largest {
+            let mut maxidx = 0;
+            let mut lcount = 0;
+            for (i, cbox) in boxes.iter().enumerate() {
+                if cbox.can_split() && cbox.points.len() > lcount {
+                    lcount = cbox.points.len();
+                    maxidx = i;
+                }
+            }
+            if lcount > 0 {
+                let curbox = boxes.remove(maxidx);
+                let (box0, box1) = curbox.split();
+                boxes.push(box0);
+                boxes.push(box1);
+                changed = true;
+            }
+        }
+    }
+    for (dst, curbox) in dst.iter_mut().zip(boxes.iter()) {
+        let mut sum = TS::zero();
+        sum.add(curbox.min, 1);
+        sum.add(curbox.max, 1);
+        *dst = sum.get_centroid();
+    }
+
+    boxes.len()
+}
diff --git a/nihav-codec-support/src/vq/mod.rs b/nihav-codec-support/src/vq/mod.rs
new file mode 100644 (file)
index 0000000..38ab9e3
--- /dev/null
@@ -0,0 +1,23 @@
+//! Vector quantisation routines.
+mod generic_elbg;
+mod generic_mediancut;
+
+pub trait VQElement: Sized+Copy+PartialEq {
+    fn dist(&self, rval: Self) -> u32;
+    fn min_cw() -> Self;
+    fn max_cw() -> Self;
+    fn min(&self, rval: Self) -> Self;
+    fn max(&self, rval: Self) -> Self;
+    fn num_components() -> usize;
+    fn sort_by_component(arr: &mut [Self], component: usize);
+    fn max_dist_component(min: &Self, max: &Self) -> usize;
+}
+
+pub trait VQElementSum<T: VQElement> {
+    fn zero() -> Self;
+    fn add(&mut self, rval: T, count: u64);
+    fn get_centroid(&self) -> T;
+}
+
+pub use self::generic_elbg::ELBG;
+pub use self::generic_mediancut::quantise_median_cut;