split nihav-codec-support crate from nihav-core
[nihav.git] / nihav-codec-support / src / dsp / dct.rs
1 //! Discrete 1-D cosine and sine transforms.
2 use std::f32::consts;
3
/// A list of DCT and DST modes.
///
/// `DCT_II`, `DST_II`, `DCT_III`, `DST_III` and `DCT_IV` have fast power-of-two kernels;
/// `DCT_I`, `DST_I` and `DST_IV` always use the direct reference transform.
// Variant names deliberately mirror the conventional "DCT-II" / "DST-IV" notation.
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DCTMode {
    /// Type-I discrete cosine transform.
    DCT_I,
    /// Type-II discrete cosine transform.
    DCT_II,
    /// Type-III discrete cosine transform (inverse of type II).
    DCT_III,
    /// Type-IV discrete cosine transform.
    DCT_IV,
    /// Type-I discrete sine transform.
    DST_I,
    /// Type-II discrete sine transform.
    DST_II,
    /// Type-III discrete sine transform (inverse of type II).
    DST_III,
    /// Type-IV discrete sine transform.
    DST_IV,
}
17
18 /// DCT/DST working context.
19 #[allow(dead_code)]
20 pub struct DCT {
21 tmp: Vec<f32>,
22 tab: Vec<f32>,
23 swaps: Vec<usize>,
24 perms: Vec<usize>,
25 mode: DCTMode,
26 size: usize,
27 is_pow2: bool,
28 perm_tab: Vec<usize>,
29 }
30
31 impl DCT {
32 /// Constructs a new context for the selected DCT or DST operation.
33 pub fn new(mode: DCTMode, size: usize) -> Self {
34 let bits = 31 - (size as u32).leading_zeros();
35 let is_pow2 = (size & (size - 1)) == 0;
36 let mut tmp: Vec<f32>;
37 let mut swaps: Vec<usize> = Vec::new();
38 let mut perms: Vec<usize>;
39 let mut perm_tab: Vec<usize>;
40 tmp = Vec::with_capacity(size);
41 tmp.resize(size, 0.0);
42 if !is_pow2 {
43 perms = Vec::new();
44 perm_tab = Vec::new();
45 } else {
46 perms = Vec::with_capacity(size);
47 for i in 0..size { perms.push(swp_idx(i, bits)); }
48 gen_swaps_for_perm(&mut swaps, &perms);
49
50 perm_tab = Vec::with_capacity(size);
51 perm_tab.push(0); // padding
52 perm_tab.push(0); // size = 1
53 perm_tab.push(0); // size = 2
54 perm_tab.push(1);
55 for blen in 2..=bits {
56 let ssize = 1 << blen;
57 for i in 0..ssize { perm_tab.push(swp_idx(i, blen)); }
58 }
59 }
60 let mut tab: Vec<f32>;
61 match mode {
62 DCTMode::DCT_II |
63 DCTMode::DST_II |
64 DCTMode::DCT_III |
65 DCTMode::DST_III |
66 DCTMode::DCT_IV => {
67 tab = Vec::with_capacity(size * 2);
68 tab.push(1.0); // padding
69 tab.push(0.0);
70 tab.push((consts::PI / 8.0).sin()); // size = 1
71 tab.push((consts::PI / 8.0).cos());
72 if bits > 1 {
73 for blen in 1..=bits {
74 let tsize = 1 << blen;
75 let base = consts::PI / ((tsize * 8) as f32);
76 for i in 0..tsize {
77 let phi = ((i * 2 + 1) as f32) * base;
78 tab.push(phi.sin());
79 tab.push(phi.cos());
80 }
81 }
82 }
83 },
84 /* DCTMode::DST_IV => {
85 },*/
86 _ => { tab = Vec::new(); },
87 };
88
89 Self { tmp, tab, mode, size, swaps, perms, is_pow2, perm_tab }
90 }
91 fn can_do_fast(&mut self) -> bool {
92 if !self.is_pow2 { return false; }
93 match self.mode {
94 DCTMode::DCT_I | DCTMode::DST_I | DCTMode::DST_IV => false,
95 _ => true,
96 }
97 }
98 fn inplace_fast_dct(&mut self, dst: &mut [f32]) {
99 match self.mode {
100 DCTMode::DCT_II => {
101 dct_II_inplace(dst, self.size, 1, &self.tab, &self.perm_tab);
102 },
103 DCTMode::DST_II => {
104 dst_II_inplace(dst, self.size, 1, &self.tab, &self.perm_tab);
105 },
106 DCTMode::DCT_III => {
107 dct_III_inplace(dst, self.size, 1, &self.tab, &self.perm_tab);
108 },
109 DCTMode::DST_III => {
110 dst_III_inplace(dst, self.size, 1, &self.tab, &self.perm_tab);
111 },
112 DCTMode::DCT_IV => {
113 dct_IV_inplace(dst, self.size, 1, &self.tab, &self.perm_tab);
114 },
115 _ => unreachable!(),
116 };
117 }
118 /// Performs DCT/DST.
119 pub fn do_dct(&mut self, src: &[f32], dst: &mut [f32]) {
120 if self.can_do_fast() {
121 for (i, ni) in self.perms.iter().enumerate() { dst[i] = src[*ni]; }
122 self.inplace_fast_dct(dst);
123 } else {
124 do_ref_dct(self.mode, src, dst, self.size);
125 }
126 }
127 /// Performs inplace DCT/DST.
128 pub fn do_dct_inplace(&mut self, buf: &mut [f32]) {
129 if self.can_do_fast() {
130 swap_buf(buf, &self.swaps);
131 self.inplace_fast_dct(buf);
132 } else {
133 self.tmp.copy_from_slice(&buf[0..self.size]);
134 do_ref_dct(self.mode, &self.tmp, buf, self.size);
135 }
136 }
137 /// Returns the scale for output normalisation.
138 pub fn get_scale(&self) -> f32 {
139 let fsize = self.size as f32;
140 match self.mode {
141 DCTMode::DCT_I => 2.0 / (fsize - 1.0),
142 DCTMode::DST_I => 2.0 / (fsize + 1.0),
143 DCTMode::DCT_II => 1.0,
144 DCTMode::DCT_III=> 1.0,
145 DCTMode::DST_II => 1.0,
146 DCTMode::DST_III=> 1.0,
147 DCTMode::DCT_IV => 1.0,
148 _ => 2.0 / fsize,
149 }
150 }
151 }
152
/// Reverses the order of all 32 bits of `inval` (bit 0 becomes bit 31 and so on).
fn reverse_bits(inval: u32) -> u32 {
    // Bit-reversed value of every 4-bit nibble.
    const REV_TAB: [u8; 16] = [
        0b0000, 0b1000, 0b0100, 0b1100, 0b0010, 0b1010, 0b0110, 0b1110,
        0b0001, 0b1001, 0b0101, 0b1101, 0b0011, 0b1011, 0b0111, 0b1111,
    ];

    // Consume nibbles from the low end and shift their mirrored form in from the right,
    // so the lowest input nibble ends up as the highest output nibble.
    let mut ret = 0;
    let mut val = inval;
    for _ in 0..8 {
        ret = (ret << 4) | u32::from(REV_TAB[(val & 0xF) as usize]);
        val >>= 4;
    }
    ret
}

/// Returns the lowest `bits` bits of `idx` in reversed order; higher bits of `idx` are
/// ignored. Mapping `0..1 << bits` through this gives the bit-reversal permutation.
///
/// `bits` must not exceed 32. `bits == 0` yields 0.
fn swp_idx(idx: usize, bits: u32) -> usize {
    // `checked_shr` covers `bits == 0`, i.e. a shift by the full 32-bit width. A plain `>>`
    // on the `usize` value would overflow there on targets with a 32-bit `usize`.
    reverse_bits(idx as u32).checked_shr(32 - bits).unwrap_or(0) as usize
}
172
/// Appends to `swaps` a sequence that realises `perms` in place.
///
/// Applying `buf.swap(i, swaps[i])` for `i = 0, 1, ...` in order (see `swap_buf`) leaves
/// `buf[i]` holding the element originally at `perms[i]`. Identity positions are recorded
/// as `swaps[i] == i`. `perms` must be a permutation of `0..perms.len()`.
fn gen_swaps_for_perm(swaps: &mut Vec<usize>, perms: &[usize]) {
    let len = perms.len();
    // cur[p]: original index of the element currently at position p.
    let mut cur: Vec<usize> = (0..len).collect();
    // pos[v]: current position of original element v (inverse of `cur`). This keeps
    // each lookup O(1) instead of scanning forward through `cur`.
    let mut pos: Vec<usize> = (0..len).collect();
    swaps.reserve(len);
    for (idx, &want) in perms.iter().enumerate() {
        let spos = pos[want];
        if spos != idx {
            let displaced = cur[idx];
            cur.swap(idx, spos);
            pos[displaced] = spos;
            pos[want] = idx;
        }
        swaps.push(spos);
    }
}
195
/// Applies a swap sequence produced by `gen_swaps_for_perm`: for every position `pos`,
/// exchanges `buf[pos]` with `buf[swaps[pos]]`, skipping entries that point at themselves.
fn swap_buf(buf: &mut [f32], swaps: &[usize]) {
    swaps
        .iter()
        .enumerate()
        .filter(|&(pos, &target)| pos != target)
        .for_each(|(pos, &target)| buf.swap(pos, target));
}
203
204 fn do_ref_dct(mode: DCTMode, src: &[f32], dst: &mut [f32], size: usize) {
205 match mode {
206 DCTMode::DCT_I => dct_I_ref(src, dst, size),
207 DCTMode::DST_I => dst_I_ref(src, dst, size),
208 DCTMode::DCT_II => dct_II_ref(src, dst, size),
209 DCTMode::DST_II => dst_II_ref(src, dst, size),
210 DCTMode::DCT_III => dct_III_ref(src, dst, size),
211 DCTMode::DST_III => dst_III_ref(src, dst, size),
212 DCTMode::DCT_IV => dct_IV_ref(src, dst, size),
213 DCTMode::DST_IV => dst_IV_ref(src, dst, size),
214 };
215 }
216
/// Direct, unnormalised DCT-I.
///
/// Computes `dst[k] = (src[0] + (-1)^k * src[N-1]) / 2 + sum_{n=1}^{N-2} src[n] * cos(PI*n*k/(N-1))`.
#[allow(non_snake_case)]
fn dct_I_ref(src: &[f32], dst: &mut [f32], size: usize) {
    let angle = consts::PI / ((size - 1) as f32);
    for k in 0..size {
        let last = src[size - 1];
        let edge = if k % 2 == 1 { -last } else { last };
        let start = (src[0] + edge) * 0.5;
        dst[k] = (1..size - 1).fold(start, |acc, n| acc + src[n] * (angle * ((n * k) as f32)).cos());
    }
}
228
/// Direct, unnormalised DST-I: `dst[k] = sum_n src[n] * sin(PI*(n+1)*(k+1)/(N+1))`.
#[allow(non_snake_case)]
fn dst_I_ref(src: &[f32], dst: &mut [f32], size: usize) {
    let angle = consts::PI / ((size + 1) as f32);
    for k in 0..size {
        let freq = k + 1;
        dst[k] = (0..size).fold(0.0f32, |acc, n| {
            acc + src[n] * (angle * (((n + 1) * freq) as f32)).sin()
        });
    }
}
240
/// Direct orthonormal DCT-II.
///
/// Term `k` is scaled by `sqrt(1/N)` for `k == 0` and by `sqrt(2/N)` otherwise.
#[allow(non_snake_case)]
fn dct_II_ref(src: &[f32], dst: &mut [f32], size: usize) {
    let fsize = size as f32;
    let base = consts::PI / fsize;
    let dc_gain = (1.0 / fsize).sqrt();
    let ac_gain = (2.0 / fsize).sqrt();
    for k in 0..size {
        let freq = k as f32;
        let acc = (0..size).fold(0.0f32, |acc, n| {
            acc + src[n] * (base * ((n as f32) + 0.5) * freq).cos()
        });
        let gain = if k == 0 { dc_gain } else { ac_gain };
        dst[k] = acc * gain;
    }
}
252
/// Direct orthonormal DST-II.
///
/// Every term is scaled by `sqrt(2/N)`, and the last one additionally by `1/sqrt(2)`.
#[allow(non_snake_case)]
fn dst_II_ref(src: &[f32], dst: &mut [f32], size: usize) {
    let fsize = size as f32;
    let base = consts::PI / fsize;
    let gain = (2.0 / fsize).sqrt();
    for k in 0..size {
        let kmod = (k + 1) as f32;
        let acc = (0..size).fold(0.0f32, |acc, n| {
            acc + src[n] * (base * ((n as f32) + 0.5) * kmod).sin()
        });
        dst[k] = acc * gain;
    }
    dst[size - 1] /= consts::SQRT_2;
}
266
/// Direct orthonormal DCT-III.
///
/// The DC input is weighted by `1/sqrt(2)`, and every output is scaled by `sqrt(2/N)`.
#[allow(non_snake_case)]
fn dct_III_ref(src: &[f32], dst: &mut [f32], size: usize) {
    let fsize = size as f32;
    let base = consts::PI / fsize;
    let gain = (2.0 / fsize).sqrt();
    for k in 0..size {
        let kmod = (k as f32) + 0.5;
        let dc = src[0] / consts::SQRT_2;
        let acc = (1..size).fold(dc, |acc, n| acc + src[n] * (base * (n as f32) * kmod).cos());
        dst[k] = acc * gain;
    }
}
279
/// Direct orthonormal DST-III.
///
/// The last input is weighted by `1/sqrt(2)`, and every output is scaled by `sqrt(2/N)`.
#[allow(non_snake_case)]
fn dst_III_ref(src: &[f32], dst: &mut [f32], size: usize) {
    let fsize = size as f32;
    let base = consts::PI / fsize;
    let gain = (2.0 / fsize).sqrt();
    for k in 0..size {
        let kmod = (k as f32) + 0.5;
        let last = size - 1;
        let body = (0..last).fold(0.0f32, |acc, n| {
            acc + src[n] * (base * ((n + 1) as f32) * kmod).sin()
        });
        let tail = src[last] / consts::SQRT_2 * (base * fsize * kmod).sin();
        dst[k] = (body + tail) * gain;
    }
}
293
/// Direct orthonormal DCT-IV:
/// `dst[k] = sqrt(2/N) * sum_n src[n] * cos(PI/N * (n + 1/2) * (k + 1/2))`.
///
/// `DCT::get_scale` reports `1.0` for `DCT_IV`, which matches the orthonormal fast
/// power-of-two kernel. This reference path serves non-power-of-two sizes, so it applies
/// the same `sqrt(2/N)` factor; without it a forward/inverse pair would be off by `N/2`.
/// With this scaling the transform is its own inverse.
#[allow(non_snake_case)]
fn dct_IV_ref(src: &[f32], dst: &mut [f32], size: usize) {
    let base = consts::PI / (size as f32);
    let norm = (2.0 / (size as f32)).sqrt();
    for k in 0..size {
        let mut sum = 0.0;
        let kmod = (k as f32) + 0.5;
        for n in 0..size {
            sum += src[n] * (base * ((n as f32) + 0.5) * kmod).cos();
        }
        dst[k] = sum * norm;
    }
}
306
/// Direct, unnormalised DST-IV: `dst[k] = sum_n src[n] * sin(PI/N * (n + 1/2) * (k + 1/2))`.
#[allow(non_snake_case)]
fn dst_IV_ref(src: &[f32], dst: &mut [f32], size: usize) {
    let base = consts::PI / (size as f32);
    for k in 0..size {
        let kmod = (k as f32) + 0.5;
        dst[k] = (0..size).fold(0.0f32, |acc, n| {
            acc + src[n] * (base * ((n as f32) + 0.5) * kmod).sin()
        });
    }
}
319
/// Rotation factor `cos(PI/8) / sqrt(2)` used by the length-4 case of `dct_II_inplace`.
const DCT_II_C0: f32 = 0.65328148243818826393; // cos(1*PI/8) / sqrt(2)
/// Rotation factor `cos(3*PI/8) / sqrt(2)` used by the length-4 case of `dct_II_inplace`.
const DCT_II_C1: f32 = 0.27059805007309849220; // cos(3*PI/8) / sqrt(2)
322
/// Recursive in-place DCT-II of `size` samples stored at `buf[0]`, `buf[step]`, `buf[2*step]`, ...
///
/// Lengths 0 and 1 are left untouched, and lengths 2 and 4 are computed directly. Longer
/// inputs are folded into mirrored sum/difference pairs. The even strided half is then
/// transformed recursively, and the odd strided half is handled by `dct_II_part2_inplace`.
/// `tab` and `perm_tab` are the twiddle and bit-reversal tables built by `DCT::new`.
/// NOTE(review): the strided input appears to be expected in bit-reversed order (`DCT::do_dct`
/// permutes by `perms` before the top-level call); confirm before calling it directly.
#[allow(non_snake_case)]
fn dct_II_inplace(buf: &mut [f32], size: usize, step: usize, tab: &[f32], perm_tab: &[usize]) {
    match size {
        0 | 1 => {},
        2 => {
            // Orthonormal 2-point butterfly.
            let i0 = buf[0];
            let i1 = buf[step];
            buf[0] = (i0 + i1) / consts::SQRT_2;
            buf[step] = (i0 - i1) / consts::SQRT_2;
        },
        4 => {
            // Inputs are read from slots 0, 2, 1, 3 (bit-reversed storage order).
            let i0 = buf[0 * step];
            let i1 = buf[2 * step];
            let i2 = buf[1 * step];
            let i3 = buf[3 * step];
            // Even outputs: scaled sums of the mirrored pairs.
            let t0 = (i0 + i3) * 0.5;
            let t1 = (i1 + i2) * 0.5;
            buf[0 * step] = t0 + t1;
            buf[2 * step] = t0 - t1;
            // Odd outputs: rotate the mirrored differences by the PI/8 and 3*PI/8 factors.
            let t0 = i0 - i3;
            let t1 = i1 - i2;
            buf[1 * step] = DCT_II_C0 * t0 + DCT_II_C1 * t1;
            buf[3 * step] = DCT_II_C1 * t0 - DCT_II_C0 * t1;
        },
        _ => {
            let hsize = size >> 1;
            // Fold slot i with its mirror slot size-1-i into a scaled sum/difference;
            // odd i swap the operand order (and hence the signs).
            for i in 0..hsize {
                let i0 = buf[i * step];
                let i1 = buf[(size - 1 - i) * step];
                if (i & 1) == 0 {
                    buf[i * step] = (i0 + i1) / consts::SQRT_2;
                    buf[(size - 1 - i) * step] = (i0 - i1) / consts::SQRT_2;
                } else {
                    buf[i * step] = (i1 - i0) / consts::SQRT_2;
                    buf[(size - 1 - i) * step] = (i1 + i0) / consts::SQRT_2;
                }
            }
            // Even strided half: half-length DCT-II; odd strided half: the second stage.
            dct_II_inplace(buf, hsize, step * 2, tab, perm_tab);
            dct_II_part2_inplace(&mut buf[step..], hsize, step * 2, tab, perm_tab);
        },
    };
}
365
/// Second stage of `dct_II_inplace`, applied to the odd strided half (`size` samples).
///
/// Rotates pairs selected through `perm_tab[size..2*size]` by the (sin, cos) twiddles in
/// `tab[size..2*size]`. It then runs a half-length DCT-II and DST-II on the even and odd
/// strided sub-halves, and merges them with sign fixes and a sum/difference butterfly.
/// Callers in this file invoke it with `size >= 4`.
#[allow(non_snake_case)]
fn dct_II_part2_inplace(buf: &mut [f32], size: usize, step: usize, tab: &[f32], perm_tab: &[usize]) {
    let hsize = size >> 1;
    // TODO: optimise for size = 4
    for i in 0..hsize {
        let i0 = buf[perm_tab[size + i] * step];
        let i1 = buf[perm_tab[size + size - i - 1] * step];
        // Pair i of the length-`size` twiddle block: sin then cos of (2i+1)*PI/(4*size).
        let c0 = tab[size + i * 2 + 0];
        let c1 = tab[size + i * 2 + 1];
        buf[perm_tab[size + i] * step] = c0 * i0 + c1 * i1;
        buf[perm_tab[size + size - i - 1] * step] = c0 * i1 - c1 * i0;
    }

    dct_II_inplace(buf, hsize, step * 2, tab, perm_tab);
    dst_II_inplace(&mut buf[step..], hsize, step * 2, tab, perm_tab);

    // Recombine: negate the last slot, then butterfly each odd slot (2i-1) with the
    // following even slot (2i), negating the odd operand for even i.
    buf[(size - 1) * step] = -buf[(size - 1) * step];
    for i in 1..hsize {
        let (i0, i1) = if (i & 1) == 0 {
            (buf[i * step * 2], -buf[i * step * 2 - step])
        } else {
            (buf[i * step * 2], buf[i * step * 2 - step])
        };
        buf[i * step * 2 - step] = (i0 + i1) / consts::SQRT_2;
        buf[i * step * 2] = (i0 - i1) / consts::SQRT_2;
    }
}
393
394 #[allow(non_snake_case)]
395 fn dst_II_inplace(buf: &mut [f32], size: usize, step: usize, tab: &[f32], perm_tab: &[usize]) {
396 if size <= 1 { return; }
397 let hsize = size >> 1;
398 for i in hsize..size { buf[i * step] = -buf[i * step]; }
399 dct_II_inplace(buf, size, step, tab, perm_tab);
400 for i in 0..hsize {
401 let idx0 = i * step;
402 let idx1 = (size - 1 - i) * step;
403 buf.swap(idx0, idx1);
404 }
405 }
406
/// Recursive in-place DCT-III of `size` samples stored at `buf[0]`, `buf[step]`, ...
///
/// Runs a half-length DCT-III over the first `size / 2` strided slots and a half-length
/// DCT-IV over the remaining ones. It then reverses the second half and merges both halves
/// with a sum/difference butterfly scaled by `1/sqrt(2)`.
/// NOTE(review): the half split presumably relies on the bit-reversed input order set up by
/// `DCT::do_dct`; confirm before calling it directly.
#[allow(non_snake_case)]
fn dct_III_inplace(buf: &mut [f32], size: usize, step: usize, tab: &[f32], perm_tab: &[usize]) {
    // Nothing to do for lengths 0 and 1.
    if size <= 1 { return; }
    let hsize = size >> 1;
    dct_III_inplace(buf, hsize, step, tab, perm_tab);
    dct_IV_inplace(&mut buf[step*hsize..], hsize, step, tab, perm_tab);
    // Reverse the order of the second half's strided slots.
    for i in 0..(size >> 2) {
        buf.swap((size - 1 - i) * step, (hsize + i) * step);
    }
    // Butterfly: slot i <- (a + b) / sqrt(2), slot size-1-i <- (a - b) / sqrt(2).
    for i in 0..hsize {
        let i0 = buf[i * step] / consts::SQRT_2;
        let i1 = buf[(size-i-1) * step] / consts::SQRT_2;
        buf[i * step] = i0 + i1;
        buf[(size-i-1) * step] = i0 - i1;
    }
}
423
424 #[allow(non_snake_case)]
425 fn dst_III_inplace(buf: &mut [f32], size: usize, step: usize, tab: &[f32], perm_tab: &[usize]) {
426 if size <= 1 { return; }
427 let hsize = size >> 1;
428 for i in 0..hsize {
429 let idx0 = i * step;
430 let idx1 = (size - 1 - i) * step;
431 buf.swap(idx0, idx1);
432 }
433 dct_III_inplace(buf, size, step, tab, perm_tab);
434 for i in 0..hsize { buf[i * 2 * step + step] = -buf[i * 2 * step + step]; }
435 }
436
437 #[allow(non_snake_case)]
438 fn dct_IV_inplace(buf: &mut [f32], size: usize, step: usize, tab: &[f32], perm_tab: &[usize]) {
439 if size <= 1 { return; }
440 let hsize = size >> 1;
441
442 for i in 0..hsize {
443 let idx0 = perm_tab[size + i];
444 let idx1 = size - 1 - idx0;
445 let i0 = buf[idx0 * step];
446 let i1 = buf[idx1 * step];
447 let c0 = tab[size + i * 2 + 1];
448 let c1 = tab[size + i * 2 + 0];
449 buf[idx0 * step] = c0 * i0 + c1 * i1;
450 buf[idx1 * step] = c0 * i1 - c1 * i0;
451 }
452 for i in (hsize+1..size).step_by(2) {
453 buf[i] = -buf[i];
454 }
455 dct_II_inplace(buf, hsize, step * 2, tab, perm_tab);
456 dct_II_inplace(&mut buf[step..], hsize, step * 2, tab, perm_tab);
457 for i in 0..(size >> 2) {
458 buf.swap((size - 1 - i * 2) * step, (i * 2 + 1) * step);
459 }
460 for i in (3..size).step_by(4) {
461 buf[i] = -buf[i];
462 }
463 buf[0] *= consts::SQRT_2;
464 buf[(size - 1) * step] *= -consts::SQRT_2;
465 for i in 0..hsize-1 {
466 let i0 = buf[(i * 2 + 2) * step];
467 let i1 = buf[(i * 2 + 1) * step];
468 buf[(i * 2 + 2) * step] = i0 + i1;
469 buf[(i * 2 + 1) * step] = i0 - i1;
470 }
471 for i in 0..size {
472 buf[i * step] /= consts::SQRT_2;
473 }
474 }
475
#[cfg(test)]
mod test {
    use super::*;

    /// Applies `mode` out of place to a pseudo-random signal and then `invmode` in place.
    /// Checks that the result, multiplied by the inverse context's scale, matches the input.
    fn test_pair(mode: DCTMode, invmode: DCTMode, size: usize) {
        println!("testing {:?} -> {:?}", mode, invmode);
        // Linear congruential generator, so the signal is deterministic.
        let mut seed: u32 = 42;
        let fin: Vec<f32> = (0..size)
            .map(|_| {
                seed = seed.wrapping_mul(1664525).wrapping_add(1013904223);
                f32::from((seed >> 16) as i16) / 256.0
            })
            .collect();
        let mut out = vec![0.0f32; size];

        DCT::new(mode, size).do_dct(&fin, &mut out);
        let mut inverse = DCT::new(invmode, size);
        inverse.do_dct_inplace(&mut out);

        let scale = inverse.get_scale();
        for (&expected, &got) in fin.iter().zip(out.iter()) {
            assert!((expected - got * scale).abs() < 1.0e-2);
        }
    }

    #[test]
    fn test_dct() {
        let pairs = [
            (DCTMode::DCT_I, DCTMode::DCT_I),
            (DCTMode::DST_I, DCTMode::DST_I),
            (DCTMode::DCT_II, DCTMode::DCT_III),
            (DCTMode::DST_II, DCTMode::DST_III),
            (DCTMode::DCT_III, DCTMode::DCT_II),
            (DCTMode::DST_III, DCTMode::DST_II),
            (DCTMode::DCT_IV, DCTMode::DCT_IV),
            (DCTMode::DST_IV, DCTMode::DST_IV),
        ];
        for &(forward, inverse) in pairs.iter() {
            test_pair(forward, inverse, 32);
        }
    }
}