add RDFT
authorKostya Shishkov <kostya.shishkov@gmail.com>
Fri, 25 Jan 2019 12:34:20 +0000 (13:34 +0100)
committerKostya Shishkov <kostya.shishkov@gmail.com>
Fri, 25 Jan 2019 12:34:20 +0000 (13:34 +0100)
nihav-core/src/dsp/fft.rs

index 98f890467389b6bbb13ab0378411d204b750fcfe..b7bf77a0dd34795527a828269e160feef31c5443 100644 (file)
@@ -425,6 +425,81 @@ impl FFTBuilder {
     }
 }
 
+pub struct RDFT {
+    table:  Vec<FFTComplex>,
+    fft:    FFT,
+    fwd:    bool,
+    size:   usize,
+}
+
+fn crossadd(a: &FFTComplex, b: &FFTComplex) -> FFTComplex {
+    FFTComplex { re: a.re + b.re, im: a.im - b.im }
+}
+
+impl RDFT {
+    pub fn do_rdft(&mut self, src: &[FFTComplex], dst: &mut [FFTComplex]) {
+        dst.copy_from_slice(src);
+        self.do_rdft_inplace(dst);
+    }
+    pub fn do_rdft_inplace(&mut self, buf: &mut [FFTComplex]) {
+        if !self.fwd {
+            for n in 0..self.size/2 {
+                let in0 = buf[n + 1];
+                let in1 = buf[self.size - n - 1];
+
+                let t0 = crossadd(&in0, &in1);
+                let t1 = FFTComplex { re: in1.im + in0.im, im: in1.re - in0.re };
+                let tab = self.table[n];
+                let t2 = FFTComplex { re: t1.im * tab.im + t1.re * tab.re, im: t1.im * tab.re - t1.re * tab.im };
+
+                buf[n + 1] = FFTComplex { re: t0.im - t2.im, im: t0.re - t2.re }; // (t0 - t2).conj().rotate()
+                buf[self.size - n - 1] = (t0 + t2).rotate();
+            }
+            let a = buf[0].re;
+            let b = buf[0].im;
+            buf[0].re = a - b;
+            buf[0].im = a + b;
+        }
+        self.fft.do_fft_inplace(buf, true);
+        if self.fwd {
+            for n in 0..self.size/2 {
+                let in0 = buf[n + 1];
+                let in1 = buf[self.size - n - 1];
+
+                let t0 = crossadd(&in0, &in1).scale(0.5);
+                let t1 = FFTComplex { re: in0.im + in1.im, im: in0.re - in1.re };
+                let t2 = t1 * self.table[n];
+
+                buf[n + 1] = crossadd(&t0, &t2);
+                buf[self.size - n - 1] = FFTComplex { re: t0.re - t2.re, im: -(t0.im + t2.im) }; 
+            }
+            let a = buf[0].re;
+            let b = buf[0].im;
+            buf[0].re = a + b;
+            buf[0].im = a - b;
+        } else {
+            for n in 0..self.size {
+                buf[n] = FFTComplex{ re: buf[n].im, im: buf[n].re };
+            }
+        }
+    }
+}
+
+pub struct RDFTBuilder {
+}
+
+impl RDFTBuilder {
+    pub fn new_rdft(mode: FFTMode, size: usize, forward: bool) -> RDFT {
+        let mut table: Vec<FFTComplex> = Vec::with_capacity(size / 4);
+        let (base, scale) = if forward { (consts::PI / (size as f32), 0.5) } else { (-consts::PI / (size as f32), 1.0) };
+        for i in 0..size/2 {
+            table.push(FFTComplex::exp(base * ((i + 1) as f32)).scale(scale));
+        }
+        let fft = FFTBuilder::new_fft(mode, size);
+        RDFT { table, fft, size, fwd: forward }
+    }
+}
+
 
 #[cfg(test)]
 mod test {
@@ -472,4 +547,29 @@ mod test {
             assert!((fout1[i].im - fout3[i].im).abs() < 1.0);
         }
     }
+
+    #[test]
+    fn test_rdft() {
+        let mut fin:   [FFTComplex; 128] = [FFTC_ZERO; 128];
+        let mut fout1: [FFTComplex; 128] = [FFTC_ZERO; 128];
+        let mut rdft = RDFTBuilder::new_rdft(FFTMode::SplitRadix,  fin.len(), true);
+        let mut seed: u32 = 42;
+        for i in 0..fin.len() {
+            seed = seed.wrapping_mul(1664525).wrapping_add(1013904223);
+            let val = (seed >> 16) as i16;
+            fin[i].re = (val as f32) / 256.0;
+            seed = seed.wrapping_mul(1664525).wrapping_add(1013904223);
+            let val = (seed >> 16) as i16;
+            fin[i].im = (val as f32) / 256.0;
+        }
+        rdft.do_rdft(&fin, &mut fout1);
+        let mut irdft = RDFTBuilder::new_rdft(FFTMode::SplitRadix,  fin.len(), false);
+        irdft.do_rdft_inplace(&mut fout1);
+
+        for i in 0..fin.len() {
+            let tst = fout1[i].scale(0.5/(fout1.len() as f32));
+            assert!((tst.re - fin[i].re).abs() < 1.0);
+            assert!((tst.im - fin[i].im).abs() < 1.0);
+        }
+    }
 }