allow RDFT use inverse FFT and remove reversing results in Bink audio decoder
[nihav.git] / nihav-core / src / dsp / fft.rs
1 use std::f32::{self, consts};
2 use std::ops::{Not, Neg, Add, AddAssign, Sub, SubAssign, Mul, MulAssign};
3 use std::fmt;
4
/// Single-precision complex value used by the FFT and RDFT code.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct FFTComplex {
    /// Real component.
    pub re: f32,
    /// Imaginary component.
    pub im: f32,
}
10
11 impl FFTComplex {
12 pub fn exp(val: f32) -> Self {
13 FFTComplex { re: val.cos(), im: val.sin() }
14 }
15 pub fn rotate(self) -> Self {
16 FFTComplex { re: -self.im, im: self.re }
17 }
18 pub fn scale(self, scale: f32) -> Self {
19 FFTComplex { re: self.re * scale, im: self.im * scale }
20 }
21 }
22
23 impl Neg for FFTComplex {
24 type Output = FFTComplex;
25 fn neg(self) -> Self::Output {
26 FFTComplex { re: -self.re, im: -self.im }
27 }
28 }
29
30 impl Not for FFTComplex {
31 type Output = FFTComplex;
32 fn not(self) -> Self::Output {
33 FFTComplex { re: self.re, im: -self.im }
34 }
35 }
36
37 impl Add for FFTComplex {
38 type Output = FFTComplex;
39 fn add(self, other: Self) -> Self::Output {
40 FFTComplex { re: self.re + other.re, im: self.im + other.im }
41 }
42 }
43
44 impl AddAssign for FFTComplex {
45 fn add_assign(&mut self, other: Self) {
46 self.re += other.re;
47 self.im += other.im;
48 }
49 }
50
51 impl Sub for FFTComplex {
52 type Output = FFTComplex;
53 fn sub(self, other: Self) -> Self::Output {
54 FFTComplex { re: self.re - other.re, im: self.im - other.im }
55 }
56 }
57
58 impl SubAssign for FFTComplex {
59 fn sub_assign(&mut self, other: Self) {
60 self.re -= other.re;
61 self.im -= other.im;
62 }
63 }
64
65 impl Mul for FFTComplex {
66 type Output = FFTComplex;
67 fn mul(self, other: Self) -> Self::Output {
68 FFTComplex { re: self.re * other.re - self.im * other.im,
69 im: self.im * other.re + self.re * other.im }
70 }
71 }
72
73 impl MulAssign for FFTComplex {
74 fn mul_assign(&mut self, other: Self) {
75 let re = self.re * other.re - self.im * other.im;
76 let im = self.im * other.re + self.re * other.im;
77 self.re = re;
78 self.im = im;
79 }
80 }
81
82 impl fmt::Display for FFTComplex {
83 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
84 write!(f, "({}, {})", self.re, self.im)
85 }
86 }
87
88 pub const FFTC_ZERO: FFTComplex = FFTComplex { re: 0.0, im: 0.0 };
89
/// Algorithm used by an `FFT` context.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FFTMode {
    /// Direct evaluation of the DFT sum (no precomputed tables).
    Matrix,
    /// Recursive radix-2 Cooley-Tukey.
    CooleyTukey,
    /// Recursive split-radix.
    SplitRadix,
}
96
97 pub struct FFT {
98 table: Vec<FFTComplex>,
99 perms: Vec<usize>,
100 swaps: Vec<usize>,
101 bits: u32,
102 mode: FFTMode,
103 }
104
impl FFT {
    /// Recursive radix-2 Cooley-Tukey pass over `data`, which holds `1 << bits` points.
    ///
    /// `data` must already be in the permuted (bit-reversed) order that `do_fft` /
    /// `do_fft_inplace` produce from `perms` / `swaps`. `forward` selects the twiddle
    /// sign: the forward direction multiplies by `self.table` entries, the inverse by
    /// their conjugates. No normalisation is applied.
    fn do_fft_inplace_ct(&mut self, data: &mut [FFTComplex], bits: u32, forward: bool) {
        if bits == 0 { return; }
        if bits == 1 {
            // 2-point butterfly.
            let sum01 = data[0] + data[1];
            let dif01 = data[0] - data[1];
            data[0] = sum01;
            data[1] = dif01;
            return;
        }
        if bits == 2 {
            // Hard-coded 4-point transform. `rotate()` multiplies by i, so only the sign
            // of the rotated term differs between the forward and inverse directions.
            let sum01 = data[0] + data[1];
            let dif01 = data[0] - data[1];
            let sum23 = data[2] + data[3];
            let dif23 = data[2] - data[3];
            if forward {
                data[0] = sum01 + sum23;
                data[1] = dif01 - dif23.rotate();
                data[2] = sum01 - sum23;
                data[3] = dif01 + dif23.rotate();
            } else {
                data[0] = sum01 + sum23;
                data[1] = dif01 + dif23.rotate();
                data[2] = sum01 - sum23;
                data[3] = dif01 - dif23.rotate();
            }
            return;
        }

        // Transform both halves, then combine them with this level's twiddles.
        let hsize = (1 << (bits - 1)) as usize;
        self.do_fft_inplace_ct(&mut data[0..hsize], bits - 1, forward);
        self.do_fft_inplace_ct(&mut data[hsize..], bits - 1, forward);
        // new_fft lays the twiddles for a level with half-size `hsize` out starting at index `hsize`.
        let offs = hsize;
        {
            // k == 0: the twiddle is 1, so no multiplication is needed.
            let e = data[0];
            let o = data[hsize];
            data[0] = e + o;
            data[hsize] = e - o;
        }
        if forward {
            for k in 1..hsize {
                let e = data[k];
                let o = data[k + hsize] * self.table[offs + k];
                data[k] = e + o;
                data[k + hsize] = e - o;
            }
        } else {
            for k in 1..hsize {
                let e = data[k];
                // `!` conjugates the twiddle for the inverse direction.
                let o = data[k + hsize] * !self.table[offs + k];
                data[k] = e + o;
                data[k + hsize] = e - o;
            }
        }
    }

    /// Recursive split-radix pass over `data`, which holds `1 << bits` points.
    ///
    /// `data` must already be in the order built by `gen_sr_perms`. The first half is
    /// transformed as a half-size FFT and the last two quarters as quarter-size FFTs.
    /// They are then combined with the twiddle pairs from `self.table`, conjugated when
    /// `forward` is false. No normalisation is applied.
    fn do_fft_inplace_splitradix(&mut self, data: &mut [FFTComplex], bits: u32, forward: bool) {
        if bits == 0 { return; }
        if bits == 1 {
            // 2-point butterfly.
            let sum01 = data[0] + data[1];
            let dif01 = data[0] - data[1];
            data[0] = sum01;
            data[1] = dif01;
            return;
        }
        if bits == 2 {
            // Hard-coded 4-point transform. gen_sr_perms leaves blocks of <= 4 points
            // unpermuted, so the input is in natural order and pairs as (0, 2) and (1, 3).
            let sum01 = data[0] + data[2];
            let dif01 = data[0] - data[2];
            let sum23 = data[1] + data[3];
            let dif23 = data[1] - data[3];
            if forward {
                data[0] = sum01 + sum23;
                data[1] = dif01 - dif23.rotate();
                data[2] = sum01 - sum23;
                data[3] = dif01 + dif23.rotate();
            } else {
                data[0] = sum01 + sum23;
                data[1] = dif01 + dif23.rotate();
                data[2] = sum01 - sum23;
                data[3] = dif01 - dif23.rotate();
            }
            return;
        }
        let qsize = (1 << (bits - 2)) as usize;
        let hsize = (1 << (bits - 1)) as usize;
        let q3size = qsize + hsize;

        // One half-size transform and two quarter-size transforms.
        self.do_fft_inplace_splitradix(&mut data[0 ..hsize], bits - 1, forward);
        self.do_fft_inplace_splitradix(&mut data[hsize ..q3size], bits - 2, forward);
        self.do_fft_inplace_splitradix(&mut data[q3size..], bits - 2, forward);
        // new_fft stores this level's interleaved (k, 3k) twiddle pairs starting at index `hsize`.
        let off = hsize;
        if forward {
            {
                // k == 0: both twiddles are 1.
                let t3 = data[0 + hsize] + data[0 + q3size];
                let t4 = (data[0 + hsize] - data[0 + q3size]).rotate();
                let e1 = data[0];
                let e2 = data[0 + qsize];
                data[0] = e1 + t3;
                data[0 + qsize] = e2 - t4;
                data[0 + hsize] = e1 - t3;
                data[0 + q3size] = e2 + t4;
            }
            for k in 1..qsize {
                let t1 = self.table[off + k * 2 + 0] * data[k + hsize];
                let t2 = self.table[off + k * 2 + 1] * data[k + q3size];
                let t3 = t1 + t2;
                let t4 = (t1 - t2).rotate();
                let e1 = data[k];
                let e2 = data[k + qsize];
                data[k] = e1 + t3;
                data[k + qsize] = e2 - t4;
                data[k + hsize] = e1 - t3;
                data[k + qsize * 3] = e2 + t4;
            }
        } else {
            {
                // k == 0: both twiddles are 1; the sign of the rotated term flips for the inverse.
                let t3 = data[0 + hsize] + data[0 + q3size];
                let t4 = (data[0 + hsize] - data[0 + q3size]).rotate();
                let e1 = data[0];
                let e2 = data[0 + qsize];
                data[0] = e1 + t3;
                data[0 + qsize] = e2 + t4;
                data[0 + hsize] = e1 - t3;
                data[0 + q3size] = e2 - t4;
            }
            for k in 1..qsize {
                // Conjugated twiddles for the inverse direction.
                let t1 = !self.table[off + k * 2 + 0] * data[k + hsize];
                let t2 = !self.table[off + k * 2 + 1] * data[k + q3size];
                let t3 = t1 + t2;
                let t4 = (t1 - t2).rotate();
                let e1 = data[k];
                let e2 = data[k + qsize];
                data[k] = e1 + t3;
                data[k + qsize] = e2 + t4;
                data[k + hsize] = e1 - t3;
                data[k + qsize * 3] = e2 - t4;
            }
        }
    }

    /// Transforms `src` into `dst`, leaving `src` untouched.
    ///
    /// `forward` selects the forward transform (kernel `exp(-2*pi*i*n*k/N)`) or the inverse
    /// (positive exponent). Neither direction is normalised. `Matrix` mode evaluates the DFT
    /// sum directly over `src.len()` points. The other modes gather `src` into `dst` through
    /// `perms` and then run their in-place pass over `1 << bits` points. `dst` must be at
    /// least as long as `src`.
    pub fn do_fft(&mut self, src: &[FFTComplex], dst: &mut [FFTComplex], forward: bool) {
        match self.mode {
            FFTMode::Matrix => {
                let base = if forward { -consts::PI * 2.0 / (src.len() as f32) }
                           else { consts::PI * 2.0 / (src.len() as f32) };
                for k in 0..src.len() {
                    let mut sum = FFTC_ZERO;
                    for n in 0..src.len() {
                        let w = FFTComplex::exp(base * ((n * k) as f32));
                        sum += src[n] * w;
                    }
                    dst[k] = sum;
                }
            },
            FFTMode::CooleyTukey => {
                let bits = self.bits;
                for k in 0..src.len() { dst[k] = src[self.perms[k]]; }
                self.do_fft_inplace_ct(dst, bits, forward);
            },
            FFTMode::SplitRadix => {
                let bits = self.bits;
                for k in 0..src.len() { dst[k] = src[self.perms[k]]; }
                self.do_fft_inplace_splitradix(dst, bits, forward);
            },
        };
    }

    /// Transforms `data` in place over `1 << bits` points.
    ///
    /// The stored `swaps` sequence is applied first to bring `data` into the order the
    /// selected algorithm expects. The sequence is empty for `Matrix`, because new_fft
    /// builds no permutation for that mode. Same direction and scaling conventions as
    /// `do_fft`.
    pub fn do_fft_inplace(&mut self, data: &mut [FFTComplex], forward: bool) {
        for idx in 0..self.swaps.len() {
            let nidx = self.swaps[idx];
            if idx != nidx {
                let t = data[nidx];
                data[nidx] = data[idx];
                data[idx] = t;
            }
        }
        match self.mode {
            FFTMode::Matrix => {
                let size = (1 << self.bits) as usize;
                let base = if forward { -consts::PI * 2.0 / (size as f32) }
                           else { consts::PI * 2.0 / (size as f32) };
                // Every output depends on every input, so results go to a temporary first.
                let mut res: Vec<FFTComplex> = Vec::with_capacity(size);
                for k in 0..size {
                    let mut sum = FFTC_ZERO;
                    for n in 0..size {
                        let w = FFTComplex::exp(base * ((n * k) as f32));
                        sum += data[n] * w;
                    }
                    res.push(sum);
                }
                for k in 0..size {
                    data[k] = res[k];
                }
            },
            FFTMode::CooleyTukey => {
                let bits = self.bits;
                self.do_fft_inplace_ct(data, bits, forward);
            },
            FFTMode::SplitRadix => {
                let bits = self.bits;
                self.do_fft_inplace_splitradix(data, bits, forward);
            },
        };
    }
}
310
/// Namespace for the `new_fft` constructor; carries no state.
pub struct FFTBuilder {}
313
/// Reverses the order of all 32 bits of `inval` (bit 0 becomes bit 31 and so on).
fn reverse_bits(inval: u32) -> u32 {
    (0..32u32).fold(0u32, |acc, bit| acc | (((inval >> bit) & 1) << (31 - bit)))
}
328
329 fn swp_idx(idx: usize, bits: u32) -> usize {
330 let s = reverse_bits(idx as u32) as usize;
331 s >> (32 - bits)
332 }
333
/// Rearranges the first `size` entries of `swaps` into split-radix input order.
///
/// Entries at positions `4k` and `4k + 2` go, in order, to the first half. Entries at
/// `4k + 1` go to the third quarter and entries at `4k + 3` to the last quarter. Each
/// part is then reordered recursively. Blocks of four or fewer entries are left as is.
fn gen_sr_perms(swaps: &mut [usize], size: usize) {
    if size <= 4 {
        return;
    }
    let half = size / 2;
    let quarter = size / 4;
    let three_q = 3 * size / 4;

    let evens: Vec<usize> = (0..quarter * 2).map(|j| swaps[2 * j]).collect();
    let odds_1: Vec<usize> = (0..quarter).map(|k| swaps[4 * k + 1]).collect();
    let odds_3: Vec<usize> = (0..quarter).map(|k| swaps[4 * k + 3]).collect();

    swaps[..half].copy_from_slice(&evens);
    swaps[half..half + quarter].copy_from_slice(&odds_1);
    swaps[three_q..three_q + quarter].copy_from_slice(&odds_3);

    gen_sr_perms(&mut swaps[..half], half);
    gen_sr_perms(&mut swaps[half..three_q], quarter);
    gen_sr_perms(&mut swaps[three_q..], quarter);
}
352
/// Appends to `swaps` a sequence that realises the permutation `perms` in place.
///
/// Applying `data.swap(idx, swaps[idx])` for each `idx` in order (as `do_fft_inplace`
/// does) leaves `data[k] == old_data[perms[k]]`. Entries that are already in place are
/// emitted as identity swaps. A trailing run of in-place entries is omitted, so `swaps`
/// may be shorter than `perms`. `perms` must be a permutation of `0..perms.len()`.
///
/// A reverse index (`pos`) finds where each value currently sits, which keeps
/// construction O(n) instead of scanning forward for every displaced element.
fn gen_swaps_for_perm(swaps: &mut Vec<usize>, perms: &[usize]) {
    let len = perms.len();
    // cur[i]: value currently stored at position i; pos[v]: position currently holding v.
    let mut cur: Vec<usize> = (0..len).collect();
    let mut pos: Vec<usize> = (0..len).collect();
    let mut run_start = 0;
    let mut run_len = 0;
    for idx in 0..len {
        let want = perms[idx];
        if cur[idx] == want {
            if run_len == 0 { run_start = idx; }
            run_len += 1;
        } else {
            // Flush the pending run of already-placed entries as identity swaps.
            swaps.extend(run_start..run_start + run_len);
            run_len = 0;
            // Positions before idx are final, so `want` necessarily lives further right.
            let spos = pos[want];
            let displaced = cur[idx];
            cur[spos] = displaced;
            pos[displaced] = spos;
            cur[idx] = want;
            pos[want] = idx;
            swaps.push(spos);
        }
    }
}
375
376 impl FFTBuilder {
377 pub fn new_fft(mode: FFTMode, size: usize) -> FFT {
378 let mut swaps: Vec<usize>;
379 let mut perms: Vec<usize>;
380 let mut table: Vec<FFTComplex>;
381 let bits = 31 - (size as u32).leading_zeros();
382 match mode {
383 FFTMode::Matrix => {
384 swaps = Vec::new();
385 perms = Vec::new();
386 table = Vec::new();
387 },
388 FFTMode::CooleyTukey => {
389 perms = Vec::with_capacity(size);
390 for i in 0..size {
391 perms.push(swp_idx(i, bits));
392 }
393 swaps = Vec::with_capacity(size);
394 table = Vec::with_capacity(size);
395 for _ in 0..4 { table.push(FFTC_ZERO); }
396 for b in 3..(bits+1) {
397 let hsize = (1 << (b - 1)) as usize;
398 let base = -consts::PI / (hsize as f32);
399 for k in 0..hsize {
400 table.push(FFTComplex::exp(base * (k as f32)));
401 }
402 }
403 },
404 FFTMode::SplitRadix => {
405 perms = Vec::with_capacity(size);
406 for i in 0..size {
407 perms.push(i);
408 }
409 gen_sr_perms(perms.as_mut_slice(), 1 << bits);
410 swaps = Vec::with_capacity(size);
411 table = Vec::with_capacity(size);
412 for _ in 0..4 { table.push(FFTC_ZERO); }
413 for b in 3..(bits+1) {
414 let qsize = (1 << (b - 2)) as usize;
415 let base = -consts::PI / ((qsize * 2) as f32);
416 for k in 0..qsize {
417 table.push(FFTComplex::exp(base * ((k * 1) as f32)));
418 table.push(FFTComplex::exp(base * ((k * 3) as f32)));
419 }
420 }
421 },
422 };
423 gen_swaps_for_perm(&mut swaps, &perms);
424 FFT { mode: mode, swaps: swaps, perms: perms, bits: bits, table: table }
425 }
426 }
427
428 pub struct RDFT {
429 table: Vec<FFTComplex>,
430 fft: FFT,
431 fwd: bool,
432 size: usize,
433 fwd_fft: bool,
434 }
435
436 fn crossadd(a: &FFTComplex, b: &FFTComplex) -> FFTComplex {
437 FFTComplex { re: a.re + b.re, im: a.im - b.im }
438 }
439
impl RDFT {
    /// Copies `src` into `dst` and transforms `dst` in place with `do_rdft_inplace`.
    /// The slices must have equal length; `copy_from_slice` panics otherwise.
    pub fn do_rdft(&mut self, src: &[FFTComplex], dst: &mut [FFTComplex]) {
        dst.copy_from_slice(src);
        self.do_rdft_inplace(dst);
    }
    /// Runs the real-valued transform on `buf` (treated as `self.size` complex values) in place.
    ///
    /// Inverse direction (`fwd == false`):
    /// * each symmetric pair (`n + 1`, `size - n - 1`) is combined with `self.table[n]`;
    /// * `buf[0]` is rewritten as `(re - im, re + im)`;
    /// * the inner FFT runs in direction `self.fwd_fft`;
    /// * the real and imaginary parts of every output are swapped.
    ///
    /// Forward direction:
    /// * the inner FFT runs first;
    /// * each symmetric pair is post-processed with `self.table[n]`;
    /// * `buf[0]` becomes `(re + im, re - im)`.
    ///
    /// NOTE(review): the `buf[0]` fold presumably packs the DC and Nyquist terms into one
    /// value, as is usual for packed real FFTs — confirm against callers.
    pub fn do_rdft_inplace(&mut self, buf: &mut [FFTComplex]) {
        if !self.fwd {
            // Pre-processing for the inverse: combine each symmetric pair with its twiddle.
            for n in 0..self.size/2 {
                let in0 = buf[n + 1];
                let in1 = buf[self.size - n - 1];

                let t0 = crossadd(&in0, &in1);
                let t1 = FFTComplex { re: in1.im + in0.im, im: in1.re - in0.re };
                let tab = self.table[n];
                // t1 multiplied by the conjugate of `tab`, with re/im of the product swapped.
                let t2 = FFTComplex { re: t1.im * tab.im + t1.re * tab.re, im: t1.im * tab.re - t1.re * tab.im };

                buf[n + 1] = FFTComplex { re: t0.im - t2.im, im: t0.re - t2.re }; // (t0 - t2).conj().rotate()
                buf[self.size - n - 1] = (t0 + t2).rotate();
            }
            let a = buf[0].re;
            let b = buf[0].im;
            buf[0].re = a - b;
            buf[0].im = a + b;
        }
        self.fft.do_fft_inplace(buf, self.fwd_fft);
        if self.fwd {
            // Post-processing for the forward transform, pair by pair.
            for n in 0..self.size/2 {
                let in0 = buf[n + 1];
                let in1 = buf[self.size - n - 1];

                let t0 = crossadd(&in0, &in1).scale(0.5);
                let t1 = FFTComplex { re: in0.im + in1.im, im: in0.re - in1.re };
                let t2 = t1 * self.table[n];

                buf[n + 1] = crossadd(&t0, &t2);
                buf[self.size - n - 1] = FFTComplex { re: t0.re - t2.re, im: -(t0.im + t2.im) };
            }
            let a = buf[0].re;
            let b = buf[0].im;
            buf[0].re = a + b;
            buf[0].im = a - b;
        } else {
            // Inverse output: swap the real and imaginary parts of every value.
            for n in 0..self.size {
                buf[n] = FFTComplex{ re: buf[n].im, im: buf[n].re };
            }
        }
    }
}
488
/// Namespace for the `new_rdft` constructor; carries no state.
pub struct RDFTBuilder {}
491
492 impl RDFTBuilder {
493 pub fn new_rdft(mode: FFTMode, size: usize, forward: bool, forward_fft: bool) -> RDFT {
494 let mut table: Vec<FFTComplex> = Vec::with_capacity(size / 4);
495 let (base, scale) = if forward { (consts::PI / (size as f32), 0.5) } else { (-consts::PI / (size as f32), 1.0) };
496 for i in 0..size/2 {
497 table.push(FFTComplex::exp(base * ((i + 1) as f32)).scale(scale));
498 }
499 let fft = FFTBuilder::new_fft(mode, size);
500 RDFT { table, fft, size, fwd: forward, fwd_fft: forward_fft }
501 }
502 }
503
504
#[cfg(test)]
mod test {
    use super::*;

    /// Number of points used by every test transform.
    const TEST_SIZE: usize = 128;

    /// Fills `buf` with the deterministic test signal: an LCG seeded with 42, where each
    /// component is the top 16 bits of the state read as `i16` and divided by 256.
    fn fill_random(buf: &mut [FFTComplex]) {
        let mut seed: u32 = 42;
        let mut next_sample = || {
            seed = seed.wrapping_mul(1664525).wrapping_add(1013904223);
            ((seed >> 16) as i16 as f32) / 256.0
        };
        for sample in buf.iter_mut() {
            sample.re = next_sample();
            sample.im = next_sample();
        }
    }

    /// Asserts that both components of `a` and `b` differ by less than 1.0.
    fn assert_close(a: FFTComplex, b: FFTComplex) {
        assert!((a.re - b.re).abs() < 1.0);
        assert!((a.im - b.im).abs() < 1.0);
    }

    #[test]
    fn test_fft() {
        let mut fin = [FFTC_ZERO; TEST_SIZE];
        fill_random(&mut fin);

        let mut fft_matrix = FFTBuilder::new_fft(FFTMode::Matrix, TEST_SIZE);
        let mut fft_ct = FFTBuilder::new_fft(FFTMode::CooleyTukey, TEST_SIZE);
        let mut fft_sr = FFTBuilder::new_fft(FFTMode::SplitRadix, TEST_SIZE);
        let mut out_matrix = [FFTC_ZERO; TEST_SIZE];
        let mut out_ct = [FFTC_ZERO; TEST_SIZE];
        let mut out_sr = [FFTC_ZERO; TEST_SIZE];

        // Forward: both fast algorithms must agree with the direct DFT.
        fft_matrix.do_fft(&fin, &mut out_matrix, true);
        fft_ct.do_fft(&fin, &mut out_ct, true);
        fft_sr.do_fft(&fin, &mut out_sr, true);
        for (m, (c, s)) in out_matrix.iter().zip(out_ct.iter().zip(out_sr.iter())) {
            assert_close(*m, *c);
            assert_close(*m, *s);
        }

        // Inverse, in place: unnormalised, so 1/N scaling recovers the input.
        fft_matrix.do_fft_inplace(&mut out_matrix, false);
        fft_ct.do_fft_inplace(&mut out_ct, false);
        fft_sr.do_fft_inplace(&mut out_sr, false);
        let sc = 1.0 / (TEST_SIZE as f32);
        for i in 0..TEST_SIZE {
            assert_close(fin[i], out_matrix[i].scale(sc));
            assert_close(out_matrix[i], out_ct[i]);
            assert_close(out_matrix[i], out_sr[i]);
        }
    }

    #[test]
    fn test_rdft() {
        let mut fin = [FFTC_ZERO; TEST_SIZE];
        fill_random(&mut fin);
        let mut spectrum = [FFTC_ZERO; TEST_SIZE];

        let mut rdft = RDFTBuilder::new_rdft(FFTMode::SplitRadix, TEST_SIZE, true, true);
        rdft.do_rdft(&fin, &mut spectrum);
        let mut irdft = RDFTBuilder::new_rdft(FFTMode::SplitRadix, TEST_SIZE, false, true);
        irdft.do_rdft_inplace(&mut spectrum);

        // The round trip is checked after scaling by 0.5 / N.
        let norm = 0.5 / (spectrum.len() as f32);
        for (out, orig) in spectrum.iter().zip(fin.iter()) {
            assert_close(out.scale(norm), *orig);
        }
    }
}