remove trailing whitespaces
[nihav.git] / nihav-core / src / dsp / dct.rs
CommitLineData
5a990253
KS
1use std::f32::consts;
2
/// Kind of trigonometric transform performed by [`DCT`].
///
/// Variant names follow the usual Roman-numeral numbering of the DCT/DST
/// families, which is why `non_camel_case_types` is allowed here.
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DCTMode {
    /// Discrete cosine transform, type I.
    DCT_I,
    /// Discrete cosine transform, type II.
    DCT_II,
    /// Discrete cosine transform, type III.
    DCT_III,
    /// Discrete cosine transform, type IV.
    DCT_IV,
    /// Discrete sine transform, type I.
    DST_I,
    /// Discrete sine transform, type II.
    DST_II,
    /// Discrete sine transform, type III.
    DST_III,
    /// Discrete sine transform, type IV.
    DST_IV,
}
15
16#[allow(dead_code)]
17pub struct DCT {
18 tmp: Vec<f32>,
19 tab: Vec<f32>,
20 swaps: Vec<usize>,
21 perms: Vec<usize>,
22 mode: DCTMode,
23 size: usize,
24 is_pow2: bool,
25 perm_tab: Vec<usize>,
26}
27
28impl DCT {
29 pub fn new(mode: DCTMode, size: usize) -> Self {
30 let bits = 31 - (size as u32).leading_zeros();
31 let is_pow2 = (size & (size - 1)) == 0;
32 let mut tmp: Vec<f32>;
33 let mut swaps: Vec<usize> = Vec::new();
34 let mut perms: Vec<usize>;
35 let mut perm_tab: Vec<usize>;
36 tmp = Vec::with_capacity(size);
37 tmp.resize(size, 0.0);
38 if !is_pow2 {
39 perms = Vec::new();
40 perm_tab = Vec::new();
41 } else {
42 perms = Vec::with_capacity(size);
43 for i in 0..size { perms.push(swp_idx(i, bits)); }
44 gen_swaps_for_perm(&mut swaps, &perms);
45
46 perm_tab = Vec::with_capacity(size);
47 perm_tab.push(0); // padding
48 perm_tab.push(0); // size = 1
49 perm_tab.push(0); // size = 2
50 perm_tab.push(1);
fdb4b2fb 51 for blen in 2..=bits {
5a990253
KS
52 let ssize = 1 << blen;
53 for i in 0..ssize { perm_tab.push(swp_idx(i, blen)); }
54 }
55 }
56 let mut tab: Vec<f32>;
57 match mode {
58 DCTMode::DCT_II |
59 DCTMode::DST_II |
60 DCTMode::DCT_III |
61 DCTMode::DST_III |
62 DCTMode::DCT_IV => {
63 tab = Vec::with_capacity(size * 2);
64 tab.push(1.0); // padding
65 tab.push(0.0);
66 tab.push((consts::PI / 8.0).sin()); // size = 1
67 tab.push((consts::PI / 8.0).cos());
68 if bits > 1 {
fdb4b2fb 69 for blen in 1..=bits {
5a990253
KS
70 let tsize = 1 << blen;
71 let base = consts::PI / ((tsize * 8) as f32);
72 for i in 0..tsize {
73 let phi = ((i * 2 + 1) as f32) * base;
74 tab.push(phi.sin());
75 tab.push(phi.cos());
76 }
77 }
78 }
79 },
80/* DCTMode::DST_IV => {
81 },*/
82 _ => { tab = Vec::new(); },
83 };
d24468d9 84
5a990253
KS
85 Self { tmp, tab, mode, size, swaps, perms, is_pow2, perm_tab }
86 }
87 fn can_do_fast(&mut self) -> bool {
88 if !self.is_pow2 { return false; }
89 match self.mode {
90 DCTMode::DCT_I | DCTMode::DST_I | DCTMode::DST_IV => false,
91 _ => true,
92 }
93 }
94 fn inplace_fast_dct(&mut self, dst: &mut [f32]) {
95 match self.mode {
96 DCTMode::DCT_II => {
97 dct_II_inplace(dst, self.size, 1, &self.tab, &self.perm_tab);
98 },
99 DCTMode::DST_II => {
100 dst_II_inplace(dst, self.size, 1, &self.tab, &self.perm_tab);
101 },
102 DCTMode::DCT_III => {
103 dct_III_inplace(dst, self.size, 1, &self.tab, &self.perm_tab);
104 },
105 DCTMode::DST_III => {
106 dst_III_inplace(dst, self.size, 1, &self.tab, &self.perm_tab);
107 },
108 DCTMode::DCT_IV => {
109 dct_IV_inplace(dst, self.size, 1, &self.tab, &self.perm_tab);
110 },
111 _ => unreachable!(),
112 };
113 }
114 pub fn do_dct(&mut self, src: &[f32], dst: &mut [f32]) {
115 if self.can_do_fast() {
116 for (i, ni) in self.perms.iter().enumerate() { dst[i] = src[*ni]; }
117 self.inplace_fast_dct(dst);
118 } else {
119 do_ref_dct(self.mode, src, dst, self.size);
120 }
121 }
122 pub fn do_dct_inplace(&mut self, buf: &mut [f32]) {
123 if self.can_do_fast() {
124 swap_buf(buf, &self.swaps);
125 self.inplace_fast_dct(buf);
126 } else {
127 self.tmp.copy_from_slice(&buf[0..self.size]);
128 do_ref_dct(self.mode, &self.tmp, buf, self.size);
129 }
130 }
131 pub fn get_scale(&self) -> f32 {
132 let fsize = self.size as f32;
133 match self.mode {
134 DCTMode::DCT_I => 2.0 / (fsize - 1.0),
135 DCTMode::DST_I => 2.0 / (fsize + 1.0),
136 DCTMode::DCT_II => 1.0,
137 DCTMode::DCT_III=> 1.0,
138 DCTMode::DST_II => 1.0,
139 DCTMode::DST_III=> 1.0,
140 DCTMode::DCT_IV => 1.0,
141 _ => 2.0 / fsize,
142 }
143 }
144}
145
/// Returns `inval` with its 32 bits in reverse order (bit 0 becomes bit 31).
///
/// Works one nibble at a time. Each 4-bit group, taken from the least
/// significant end, is mirrored through a lookup table and appended to the
/// result, so the result ends up holding the nibbles in reverse order.
fn reverse_bits(inval: u32) -> u32 {
    const NIBBLE_REV: [u8; 16] = [
        0b0000, 0b1000, 0b0100, 0b1100, 0b0010, 0b1010, 0b0110, 0b1110,
        0b0001, 0b1001, 0b0101, 0b1101, 0b0011, 0b1011, 0b0111, 0b1111,
    ];

    (0..8u32).fold(0u32, |acc, nibble| {
        let chunk = (inval >> (nibble * 4)) & 0xF;
        (acc << 4) | u32::from(NIBBLE_REV[chunk as usize])
    })
}
160
161fn swp_idx(idx: usize, bits: u32) -> usize {
162 let s = reverse_bits(idx as u32) as usize;
163 s >> (32 - bits)
164}
165
/// Converts the gather permutation `perms` (slot `i` must receive element
/// `perms[i]`) into a sequence of in-place swaps, appended to `swaps`.
///
/// Entry `i` of the emitted sequence is the position exchanged with position
/// `i` when the swaps are applied in order (see `swap_buf`). Positions that
/// are already in place map to themselves. A trailing run of such positions
/// is not emitted, so the sequence may be shorter than `perms`.
///
/// `perms` is expected to be a permutation of `0..perms.len()`. Otherwise the
/// search for an element can run past the end of the buffer and panic.
fn gen_swaps_for_perm(swaps: &mut Vec<usize>, perms: &[usize]) {
    // current[i]: which original element currently sits in slot i.
    let mut current: Vec<usize> = (0..perms.len()).collect();
    // Start of the pending run of already-placed slots, if any.
    let mut pending: Option<usize> = None;

    for (pos, &want) in perms.iter().enumerate() {
        if current[pos] == want {
            if pending.is_none() {
                pending = Some(pos);
            }
            continue;
        }
        // Flush the run of untouched slots as identity swaps.
        if let Some(start) = pending.take() {
            swaps.extend(start..pos);
        }
        let mut from = pos + 1;
        while current[from] != want {
            from += 1;
        }
        current.swap(pos, from);
        swaps.push(from);
    }
}
188
/// Applies a swap sequence (as produced by `gen_swaps_for_perm`) to `buf` in
/// place. For each `i`, in order, element `i` is exchanged with element
/// `swaps[i]`. Identity entries are skipped.
fn swap_buf(buf: &mut [f32], swaps: &[usize]) {
    let moves = swaps
        .iter()
        .enumerate()
        .filter(|&(pos, &other)| pos != other);
    for (pos, &other) in moves {
        buf.swap(other, pos);
    }
}
196
197fn do_ref_dct(mode: DCTMode, src: &[f32], dst: &mut [f32], size: usize) {
198 match mode {
199 DCTMode::DCT_I => dct_I_ref(src, dst, size),
200 DCTMode::DST_I => dst_I_ref(src, dst, size),
201 DCTMode::DCT_II => dct_II_ref(src, dst, size),
202 DCTMode::DST_II => dst_II_ref(src, dst, size),
203 DCTMode::DCT_III => dct_III_ref(src, dst, size),
204 DCTMode::DST_III => dst_III_ref(src, dst, size),
205 DCTMode::DCT_IV => dct_IV_ref(src, dst, size),
206 DCTMode::DST_IV => dst_IV_ref(src, dst, size),
207 };
208}
209
/// Reference O(N²) DCT-I without normalisation:
/// `X[k] = (x[0] + (-1)^k * x[N-1]) / 2 + sum_{n=1}^{N-2} x[n] * cos(PI * n * k / (N - 1))`.
#[allow(non_snake_case)] // Roman-numeral transform type in the name.
fn dct_I_ref(src: &[f32], dst: &mut [f32], size: usize) {
    let base = consts::PI / ((size - 1) as f32);
    for k in 0..size {
        let first = src[0];
        let last = src[size - 1];
        let edge = if k % 2 == 1 { -last } else { last };
        dst[k] = (1..size - 1).fold((first + edge) * 0.5, |acc, n| {
            acc + src[n] * (base * ((n * k) as f32)).cos()
        });
    }
}
221
/// Reference O(N²) DST-I without normalisation:
/// `X[k] = sum_n x[n] * sin(PI * (n + 1) * (k + 1) / (N + 1))`.
#[allow(non_snake_case)] // Roman-numeral transform type in the name.
fn dst_I_ref(src: &[f32], dst: &mut [f32], size: usize) {
    let base = consts::PI / ((size + 1) as f32);
    for k in 0..size {
        dst[k] = (0..size).fold(0.0, |acc, n| {
            acc + src[n] * (base * (((n + 1) * (k + 1)) as f32)).sin()
        });
    }
}
233
/// Reference O(N²) orthonormal DCT-II:
/// `X[k] = s_k * sum_n x[n] * cos(PI / N * (n + 1/2) * k)`, where
/// `s_0 = sqrt(1/N)` and `s_k = sqrt(2/N)` for k > 0.
#[allow(non_snake_case)] // Roman-numeral transform type in the name.
fn dct_II_ref(src: &[f32], dst: &mut [f32], size: usize) {
    let n_f = size as f32;
    let base = consts::PI / n_f;
    let dc_scale = (1.0 / n_f).sqrt();
    let ac_scale = (2.0 / n_f).sqrt();
    for k in 0..size {
        let acc = (0..size).fold(0.0, |acc, n| {
            acc + src[n] * (base * ((n as f32) + 0.5) * (k as f32)).cos()
        });
        let scale = if k == 0 { dc_scale } else { ac_scale };
        dst[k] = acc * scale;
    }
}
245
/// Reference O(N²) orthonormal DST-II:
/// `X[k] = sqrt(2/N) * sum_n x[n] * sin(PI / N * (n + 1/2) * (k + 1))`,
/// with the last output additionally divided by sqrt(2).
#[allow(non_snake_case)] // Roman-numeral transform type in the name.
fn dst_II_ref(src: &[f32], dst: &mut [f32], size: usize) {
    let base = consts::PI / (size as f32);
    let norm = (2.0 / (size as f32)).sqrt();
    for k in 0..size {
        let freq = (k + 1) as f32;
        let acc = (0..size).fold(0.0, |acc, n| {
            acc + src[n] * (base * ((n as f32) + 0.5) * freq).sin()
        });
        dst[k] = acc * norm;
    }
    dst[size - 1] /= consts::SQRT_2;
}
259
/// Reference O(N²) orthonormal DCT-III:
/// `X[k] = sqrt(2/N) * (x[0] / sqrt(2) + sum_{n>=1} x[n] * cos(PI / N * n * (k + 1/2)))`.
#[allow(non_snake_case)] // Roman-numeral transform type in the name.
fn dct_III_ref(src: &[f32], dst: &mut [f32], size: usize) {
    let base = consts::PI / (size as f32);
    let norm = (2.0 / (size as f32)).sqrt();
    for k in 0..size {
        let phase = (k as f32) + 0.5;
        let acc = (1..size).fold(src[0] / consts::SQRT_2, |acc, n| {
            acc + src[n] * (base * (n as f32) * phase).cos()
        });
        dst[k] = acc * norm;
    }
}
272
/// Reference O(N²) orthonormal DST-III:
/// `X[k] = sqrt(2/N) * (sum_{n=0}^{N-2} x[n] * sin(PI / N * (n + 1) * (k + 1/2))
///                      + x[N-1] / sqrt(2) * sin(PI * (k + 1/2)))`.
#[allow(non_snake_case)] // Roman-numeral transform type in the name.
fn dst_III_ref(src: &[f32], dst: &mut [f32], size: usize) {
    let base = consts::PI / (size as f32);
    let norm = (2.0 / (size as f32)).sqrt();
    for k in 0..size {
        let phase = (k as f32) + 0.5;
        let partial = (0..size - 1).fold(0.0, |acc, n| {
            acc + src[n] * (base * ((n + 1) as f32) * phase).sin()
        });
        let last = src[size - 1] / consts::SQRT_2 * (base * (size as f32) * phase).sin();
        dst[k] = (partial + last) * norm;
    }
}
286
/// Reference O(N²) DCT-IV without normalisation:
/// `X[k] = sum_n x[n] * cos(PI / N * (n + 1/2) * (k + 1/2))`.
#[allow(non_snake_case)] // Roman-numeral transform type in the name.
fn dct_IV_ref(src: &[f32], dst: &mut [f32], size: usize) {
    let base = consts::PI / (size as f32);
    for k in 0..size {
        let phase = (k as f32) + 0.5;
        dst[k] = (0..size).fold(0.0, |acc, n| {
            acc + src[n] * (base * ((n as f32) + 0.5) * phase).cos()
        });
    }
}
299
/// Reference O(N²) DST-IV without normalisation:
/// `X[k] = sum_n x[n] * sin(PI / N * (n + 1/2) * (k + 1/2))`.
#[allow(non_snake_case)] // Roman-numeral transform type in the name.
fn dst_IV_ref(src: &[f32], dst: &mut [f32], size: usize) {
    let base = consts::PI / (size as f32);
    for k in 0..size {
        let phase = (k as f32) + 0.5;
        dst[k] = (0..size).fold(0.0, |acc, n| {
            acc + src[n] * (base * ((n as f32) + 0.5) * phase).sin()
        });
    }
}
312
/// `cos(PI/8) / sqrt(2)`: rotation coefficient used by the 4-point DCT-II butterfly.
const DCT_II_C0: f32 = 0.653_281_482_438_188_263_93;
/// `cos(3*PI/8) / sqrt(2)`: rotation coefficient used by the 4-point DCT-II butterfly.
const DCT_II_C1: f32 = 0.270_598_050_073_098_492_20;
315
316#[allow(non_snake_case)]
317fn dct_II_inplace(buf: &mut [f32], size: usize, step: usize, tab: &[f32], perm_tab: &[usize]) {
318 match size {
319 0 | 1 => {},
320 2 => {
321 let i0 = buf[0];
322 let i1 = buf[step];
323 buf[0] = (i0 + i1) / consts::SQRT_2;
324 buf[step] = (i0 - i1) / consts::SQRT_2;
325 },
326 4 => {
327 let i0 = buf[0 * step];
328 let i1 = buf[2 * step];
329 let i2 = buf[1 * step];
330 let i3 = buf[3 * step];
331 let t0 = (i0 + i3) * 0.5;
332 let t1 = (i1 + i2) * 0.5;
333 buf[0 * step] = t0 + t1;
334 buf[2 * step] = t0 - t1;
335 let t0 = i0 - i3;
336 let t1 = i1 - i2;
337 buf[1 * step] = DCT_II_C0 * t0 + DCT_II_C1 * t1;
338 buf[3 * step] = DCT_II_C1 * t0 - DCT_II_C0 * t1;
339 },
340 _ => {
341 let hsize = size >> 1;
342 for i in 0..hsize {
343 let i0 = buf[i * step];
344 let i1 = buf[(size - 1 - i) * step];
345 if (i & 1) == 0 {
346 buf[i * step] = (i0 + i1) / consts::SQRT_2;
347 buf[(size - 1 - i) * step] = (i0 - i1) / consts::SQRT_2;
348 } else {
349 buf[i * step] = (i1 - i0) / consts::SQRT_2;
350 buf[(size - 1 - i) * step] = (i1 + i0) / consts::SQRT_2;
351 }
352 }
353 dct_II_inplace(buf, hsize, step * 2, tab, perm_tab);
354 dct_II_part2_inplace(&mut buf[step..], hsize, step * 2, tab, perm_tab);
355 },
356 };
357}
358
/// Second stage of the DCT-II recursion. `dct_II_inplace` applies it to the
/// odd-slot half of its buffer (`size` samples at stride `step`).
///
/// Steps:
/// 1. Rotate sample pairs, addressed through the bit-reversal permutation for
///    length `size` in `perm_tab`, by twiddles from `tab`.
/// 2. Run a half-size DCT-II on the even slots and a half-size DST-II on the
///    odd slots.
/// 3. Recombine neighbouring slots with 1/sqrt(2) butterflies of alternating
///    sign.
///
/// NOTE(review): the exact derivation of this split is not documented here.
/// It presumably follows a standard DCT-II odd-part factorisation; confirm
/// before restructuring.
#[allow(non_snake_case)] // Roman-numeral transform type in the name.
fn dct_II_part2_inplace(buf: &mut [f32], size: usize, step: usize, tab: &[f32], perm_tab: &[usize]) {
    let hsize = size >> 1;
// TODO: optimise for size = 4
    for i in 0..hsize {
        // perm_tab[size..2 * size] is the bit-reversal permutation for length `size`.
        let i0 = buf[perm_tab[size + i] * step];
        let i1 = buf[perm_tab[size + size - i - 1] * step];
        // (c0, c1) = (sin, cos) of (2i + 1) * PI / (4 * size), per the table layout in DCT::new.
        let c0 = tab[size + i * 2 + 0];
        let c1 = tab[size + i * 2 + 1];
        buf[perm_tab[size + i] * step] = c0 * i0 + c1 * i1;
        buf[perm_tab[size + size - i - 1] * step] = c0 * i1 - c1 * i0;
    }

    dct_II_inplace(buf, hsize, step * 2, tab, perm_tab);
    dst_II_inplace(&mut buf[step..], hsize, step * 2, tab, perm_tab);

    buf[(size - 1) * step] = -buf[(size - 1) * step];
    // Combine slots (2i - 1, 2i); the odd slot's sign flips for even i.
    for i in 1..hsize {
        let (i0, i1) = if (i & 1) == 0 {
            (buf[i * step * 2], -buf[i * step * 2 - step])
        } else {
            (buf[i * step * 2], buf[i * step * 2 - step])
        };
        buf[i * step * 2 - step] = (i0 + i1) / consts::SQRT_2;
        buf[i * step * 2] = (i0 - i1) / consts::SQRT_2;
    }
}
386
387#[allow(non_snake_case)]
388fn dst_II_inplace(buf: &mut [f32], size: usize, step: usize, tab: &[f32], perm_tab: &[usize]) {
389 if size <= 1 { return; }
390 let hsize = size >> 1;
391 for i in hsize..size { buf[i * step] = -buf[i * step]; }
392 dct_II_inplace(buf, size, step, tab, perm_tab);
393 for i in 0..hsize {
394 let idx0 = i * step;
395 let idx1 = (size - 1 - i) * step;
e243ceb4 396 buf.swap(idx0, idx1);
5a990253
KS
397 }
398}
399
400#[allow(non_snake_case)]
401fn dct_III_inplace(buf: &mut [f32], size: usize, step: usize, tab: &[f32], perm_tab: &[usize]) {
402 if size <= 1 { return; }
403 let hsize = size >> 1;
404 dct_III_inplace(buf, hsize, step, tab, perm_tab);
405 dct_IV_inplace(&mut buf[step*hsize..], hsize, step, tab, perm_tab);
406 for i in 0..(size >> 2) {
e243ceb4 407 buf.swap((size - 1 - i) * step, (hsize + i) * step);
5a990253
KS
408 }
409 for i in 0..hsize {
410 let i0 = buf[i * step] / consts::SQRT_2;
411 let i1 = buf[(size-i-1) * step] / consts::SQRT_2;
412 buf[i * step] = i0 + i1;
413 buf[(size-i-1) * step] = i0 - i1;
414 }
415}
416
417#[allow(non_snake_case)]
418fn dst_III_inplace(buf: &mut [f32], size: usize, step: usize, tab: &[f32], perm_tab: &[usize]) {
419 if size <= 1 { return; }
420 let hsize = size >> 1;
421 for i in 0..hsize {
422 let idx0 = i * step;
423 let idx1 = (size - 1 - i) * step;
e243ceb4 424 buf.swap(idx0, idx1);
5a990253
KS
425 }
426 dct_III_inplace(buf, size, step, tab, perm_tab);
427 for i in 0..hsize { buf[i * 2 * step + step] = -buf[i * 2 * step + step]; }
428}
429
430#[allow(non_snake_case)]
431fn dct_IV_inplace(buf: &mut [f32], size: usize, step: usize, tab: &[f32], perm_tab: &[usize]) {
432 if size <= 1 { return; }
433 let hsize = size >> 1;
434
435 for i in 0..hsize {
436 let idx0 = perm_tab[size + i];
437 let idx1 = size - 1 - idx0;
438 let i0 = buf[idx0 * step];
439 let i1 = buf[idx1 * step];
440 let c0 = tab[size + i * 2 + 1];
441 let c1 = tab[size + i * 2 + 0];
442 buf[idx0 * step] = c0 * i0 + c1 * i1;
443 buf[idx1 * step] = c0 * i1 - c1 * i0;
444 }
445 for i in (hsize+1..size).step_by(2) {
446 buf[i] = -buf[i];
447 }
448 dct_II_inplace(buf, hsize, step * 2, tab, perm_tab);
449 dct_II_inplace(&mut buf[step..], hsize, step * 2, tab, perm_tab);
450 for i in 0..(size >> 2) {
e243ceb4 451 buf.swap((size - 1 - i * 2) * step, (i * 2 + 1) * step);
5a990253
KS
452 }
453 for i in (3..size).step_by(4) {
454 buf[i] = -buf[i];
455 }
456 buf[0] *= consts::SQRT_2;
457 buf[(size - 1) * step] *= -consts::SQRT_2;
458 for i in 0..hsize-1 {
459 let i0 = buf[(i * 2 + 2) * step];
460 let i1 = buf[(i * 2 + 1) * step];
461 buf[(i * 2 + 2) * step] = i0 + i1;
462 buf[(i * 2 + 1) * step] = i0 - i1;
463 }
464 for i in 0..size {
465 buf[i * step] /= consts::SQRT_2;
466 }
467}
468
#[cfg(test)]
mod test {
    use super::*;

    /// Runs `mode` followed by `invmode` on deterministic pseudo-random data
    /// and checks that, after applying the inverse transform's scale factor,
    /// the original samples are recovered to within 1e-2.
    fn test_pair(mode: DCTMode, invmode: DCTMode, size: usize) {
        println!("testing {:?} -> {:?}", mode, invmode);
        // Linear congruential generator, so every run sees the same input.
        let mut seed: u32 = 42;
        let mut next_sample = || {
            seed = seed.wrapping_mul(1664525).wrapping_add(1013904223);
            f32::from((seed >> 16) as i16) / 256.0
        };
        let fin: Vec<f32> = (0..size).map(|_| next_sample()).collect();
        let mut out = vec![0.0f32; size];

        let mut forward = DCT::new(mode, size);
        forward.do_dct(&fin, &mut out);
        let mut inverse = DCT::new(invmode, size);
        inverse.do_dct_inplace(&mut out);

        let scale = inverse.get_scale();
        for (&expected, &got) in fin.iter().zip(out.iter()) {
            assert!((expected - got * scale).abs() < 1.0e-2);
        }
    }

    #[test]
    fn test_dct() {
        let pairs = [
            (DCTMode::DCT_I, DCTMode::DCT_I),
            (DCTMode::DST_I, DCTMode::DST_I),
            (DCTMode::DCT_II, DCTMode::DCT_III),
            (DCTMode::DST_II, DCTMode::DST_III),
            (DCTMode::DCT_III, DCTMode::DCT_II),
            (DCTMode::DST_III, DCTMode::DST_II),
            (DCTMode::DCT_IV, DCTMode::DCT_IV),
            (DCTMode::DST_IV, DCTMode::DST_IV),
        ];
        for &(fwd, inv) in pairs.iter() {
            test_pair(fwd, inv, 32);
        }
    }
}