VP6 encoder
[nihav.git] / nihav-duck / src / codecs / vp6enc / models.rs
1 use nihav_codec_support::codecs::ZIGZAG;
2 use super::super::vp6data::*;
3
/// Probabilities used to code motion-vector data.
///
/// `VP56Models` keeps two of these in `mv_models`; `VP56Models::reset`
/// fills them from the `NZ_PROBS`, `RAW_PROBS` and `TREE_PROBS` tables and
/// sets `sign_prob` to 128.
#[derive(Clone, Copy, Default)]
pub struct VP56MVModel {
    /// Probability of the non-zero decision (name-based reading — confirm).
    pub nz_prob: u8,
    /// Probability of the sign decision.
    pub sign_prob: u8,
    /// Eight probabilities loaded from `RAW_PROBS`.
    pub raw_probs: [u8; 8],
    /// Seven probabilities loaded from `TREE_PROBS`.
    pub tree_probs: [u8; 7],
}
11
/// Ten macroblock-type probabilities.
///
/// `VP56Models::mbtype_models` stores a 3×10 grid of these (presumably
/// context × previous macroblock type — confirm against the coder).
#[derive(Clone, Copy, Default)]
pub struct VP56MBTypeModel {
    /// Probability vector for one grid cell.
    pub probs: [u8; 10],
}
16
/// Coefficient-coding probabilities.
///
/// `VP56Models` holds two of these (`coeff_models`; presumably luma and
/// chroma — confirm).
#[derive(Clone, Copy, Default)]
pub struct VP56CoeffModel {
    /// DC token probabilities, shaped `[6][6][5]`.
    pub dc_token_probs: [[[u8; 5]; 6]; 6],
    /// Eleven DC value probabilities (reset to 128 by `VP56Models::reset`).
    pub dc_value_probs: [u8; 11],
    /// AC value probabilities, shaped `[3][6][11]` (reset to 128 by `VP56Models::reset`).
    pub ac_val_probs: [[[u8; 11]; 6]; 3],
}
23
/// VP6-specific scan tables and zero-run probabilities.
#[derive(Clone)]
pub struct VP6Models {
    /// Scan order; `reset_scan` loads the default or the interlaced table.
    pub scan_order: [usize; 64],
    /// Position table; `reset_scan` sets it to the identity permutation.
    pub scan: [usize; 64],
    /// Zigzag table; `reset_scan` copies `ZIGZAG` into it.
    pub zigzag: [usize; 64],
    /// Two sets of fourteen zero-run probabilities.
    pub zero_run_probs: [[u8; 14]; 2],
}
31
/// Capacity (in symbols) of a single `VP6Huff` table.
const MAX_HUFF_ELEMS: usize = 12;

/// One Huffman code table.
///
/// Entry `i` describes symbol `i`. `bits[i]` is the code length, and
/// `codes[i]` holds the code right-aligned with its first bit in the most
/// significant of those `bits[i]` positions.
#[derive(Clone, Copy, Default)]
pub struct VP6Huff {
    pub codes: [u16; MAX_HUFF_ELEMS],
    pub bits: [u8; MAX_HUFF_ELEMS],
}
38
/// Entry of the scratch Huffman tree assembled by `VP6Huff::build`.
#[derive(Clone, Copy, Default)]
struct Node {
    /// Summed weight of every leaf below this node.
    weight: u16,
    /// Symbol index for a leaf; `-1` marks an internal node.
    sym: i8,
    /// Index of the child reached by appending a `0` bit.
    ch0: usize,
    /// Index of the child reached by appending a `1` bit.
    ch1: usize,
}
46
/// Scales weight `a` by probability `b`, computing `a * b / 256` rounded down.
///
/// A result of zero is bumped to 1 so every Huffman leaf keeps a non-zero
/// weight.
fn prob2weight(a: u8, b: u8) -> u8 {
    // At most 255 * 255 >> 8 == 254, so the narrowing cast is lossless.
    let scaled = ((u16::from(a) * u16::from(b)) >> 8) as u8;
    scaled.max(1)
}
55
impl VP6Huff {
    /// Builds the 12-symbol Huffman table for a coefficient-token tree from its
    /// 11 node probabilities.
    ///
    /// Each leaf weight is the product of the branch probabilities along its
    /// path, computed with `prob2weight`. For `u8`, `!p` is `255 - p`, which
    /// gives the opposite branch. The weights are then passed to `build`.
    /// Leaf 11 is reached via `probs[0]` then `probs[1]` (presumably the
    /// end-of-block token — confirm against the token numbering).
    pub fn build_codes(&mut self, probs: &[u8; 11]) {
        let mut weights = [0u8; 12];

        weights[11] = prob2weight( probs[0], probs[ 1]);
        weights[ 0] = prob2weight( probs[0], !probs[ 1]);
        weights[ 1] = prob2weight(!probs[0], probs[ 2]);
        // Intermediate subtree weights: each `*root` is the weight reaching an inner node.
        let lvroot = prob2weight(!probs[0], !probs[ 2]);
        let tworoot = prob2weight( lvroot, probs[ 3]);
        let hlroot = prob2weight( lvroot, !probs[ 3]);
        weights[ 2] = prob2weight( tworoot, probs[ 4]);
        let root34 = prob2weight( tworoot, !probs[ 4]);
        weights[ 3] = prob2weight( root34, probs[ 5]);
        weights[ 4] = prob2weight( root34, !probs[ 5]);
        let c1root = prob2weight( hlroot, probs[ 6]);
        let c34root = prob2weight( hlroot, !probs[ 6]);
        weights[ 5] = prob2weight( c1root, probs[ 7]);
        weights[ 6] = prob2weight( c1root, !probs[ 7]);
        let c3root = prob2weight( c34root, probs[ 8]);
        let c4root = prob2weight( c34root, !probs[ 8]);
        weights[ 7] = prob2weight( c3root, probs[ 9]);
        weights[ 8] = prob2weight( c3root, !probs[ 9]);
        weights[ 9] = prob2weight( c4root, probs[10]);
        weights[10] = prob2weight( c4root, !probs[10]);

        self.build(&weights);
    }
    /// Builds the 9-symbol Huffman table for zero-run lengths.
    ///
    /// Only `probs[0]..=probs[7]` of the 14 zero-run probabilities are read
    /// here. Leaf weights are products of branch probabilities, as in
    /// `build_codes`.
    pub fn build_codes_zero_run(&mut self, probs: &[u8; 14]) {
        let mut weights = [0u8; 9];

        let root = prob2weight( probs[0], probs[1]);
        weights[0] = prob2weight( root, probs[2]);
        weights[1] = prob2weight( root, !probs[2]);

        let root = prob2weight( probs[0], !probs[1]);
        weights[2] = prob2weight( root, probs[3]);
        weights[3] = prob2weight( root, !probs[3]);

        let root = prob2weight(!probs[0], probs[4]);
        weights[8] = prob2weight(!probs[0], !probs[4]);
        let root1 = prob2weight( root, probs[5]);
        let root2 = prob2weight( root, !probs[5]);
        weights[4] = prob2weight( root1, probs[6]);
        weights[5] = prob2weight( root1, !probs[6]);
        weights[6] = prob2weight( root2, probs[7]);
        weights[7] = prob2weight( root2, !probs[7]);

        self.build(&weights);
    }
    /// Builds Huffman codes from `weights` (symbol `i` has weight `weights[i]`)
    /// and stores them in `self.codes` / `self.bits`.
    ///
    /// The insertion and tie-breaking order below fixes the exact codes
    /// produced. Presumably this must match the decoder's construction
    /// bit-for-bit — do not reorder it.
    fn build(&mut self, weights: &[u8]) {
        let mut nodes = [Node::default(); MAX_HUFF_ELEMS * 2];
        let mut nlen = 0;

        // Insert one leaf per symbol, keeping `nodes[..nlen]` sorted by ascending
        // weight. A new leaf lands after existing nodes of equal weight. Weights
        // are visited in reverse, so `weights.len() - nlen - 1` is the index of `w`.
        for w in weights.iter().rev() {
            let weight = u16::from(*w);
            let mut pos = nlen;
            for i in 0..nlen {
                if nodes[i].weight > weight {
                    pos = i;
                    break;
                }
            }
            for j in (pos..nlen).rev() {
                nodes[j + 1] = nodes[j];
            }
            nodes[pos].weight = weight;
            nodes[pos].sym = (weights.len() - nlen - 1) as i8;
            nlen += 1;
        }

        // Repeatedly merge the two lightest unconsumed nodes (`low`, `low + 1`)
        // into an internal node. The new node is inserted past `low`, ahead of
        // the first node whose weight is not smaller (i.e. before equal weights).
        // Children always sit below `low`, which is never shifted, so the stored
        // child indices stay valid.
        let mut low = 0;
        for _ in 0..nlen-1 {
            let nnode = Node {
                weight: nodes[low + 0].weight + nodes[low + 1].weight,
                sym: -1,
                ch0: low + 0,
                ch1: low + 1,
            };
            low += 2;
            let mut pos = low;
            while (pos < nlen) && (nodes[pos].weight < nnode.weight) {
                pos += 1;
            }
            for j in (pos..nlen).rev() {
                nodes[j + 1] = nodes[j];
            }
            nodes[pos] = nnode;
            nlen += 1;
        }
        // The last node inserted is the root.
        self.get_codes(&nodes, nlen - 1, 0, 0);
        // `nlen` is now 2 * weights.len() - 1 (23 or 17 for the callers above),
        // so this loop does not run for them. For smaller inputs it would copy
        // symbol 0's code into the unused table slots.
        for i in nlen..self.codes.len() {
            self.codes[i] = self.codes[0];
            self.bits[i] = self.bits[0];
        }
    }
    /// Recursively walks the tree from `pos`, appending a `0` bit for `ch0` and
    /// a `1` bit for `ch1`, and records each leaf's code and length.
    fn get_codes(&mut self, nodes: &[Node], pos: usize, code: u16, len: u8) {
        if nodes[pos].sym >= 0 {
            self.codes[nodes[pos].sym as usize] = code;
            self.bits [nodes[pos].sym as usize] = len;
        } else {
            self.get_codes(nodes, nodes[pos].ch0, (code << 1) | 0, len + 1);
            self.get_codes(nodes, nodes[pos].ch1, (code << 1) | 1, len + 1);
        }
    }
}
161
162 #[derive(Clone,Copy,Default)]
163 pub struct VP6HuffModels {
164 pub dc_token_tree: [VP6Huff; 2],
165 pub ac_token_tree: [[[VP6Huff; 6]; 3]; 2],
166 pub zero_run_tree: [VP6Huff; 2],
167 }
168
169 impl VP6Models {
170 fn new() -> Self {
171 Self {
172 scan_order: [0; 64],
173 scan: [0; 64],
174 zigzag: [0; 64],
175 zero_run_probs: [[0; 14]; 2],
176 }
177 }
178 }
179
180 #[derive(Clone)]
181 pub struct VP56Models {
182 pub mv_models: [VP56MVModel; 2],
183 pub mbtype_models: [[VP56MBTypeModel; 10]; 3],
184 pub coeff_models: [VP56CoeffModel; 2],
185 pub prob_xmitted: [[u8; 20]; 3],
186 pub vp6models: VP6Models,
187 pub vp6huff: VP6HuffModels,
188 }
189
190 impl VP56Models {
191 pub fn new() -> Self {
192 Self {
193 mv_models: [VP56MVModel::default(); 2],
194 mbtype_models: [[VP56MBTypeModel::default(); 10]; 3],
195 coeff_models: [VP56CoeffModel::default(); 2],
196 prob_xmitted: [[0; 20]; 3],
197 vp6models: VP6Models::new(),
198 vp6huff: VP6HuffModels::default(),
199 }
200 }
201 pub fn reset(&mut self, interlaced: bool) {
202 for (i, mdl) in self.mv_models.iter_mut().enumerate() {
203 mdl.nz_prob = NZ_PROBS[i];
204 mdl.sign_prob = 128;
205 mdl.raw_probs.copy_from_slice(&RAW_PROBS[i]);
206 mdl.tree_probs.copy_from_slice(&TREE_PROBS[i]);
207 }
208
209 for mdl in self.coeff_models.iter_mut() {
210 mdl.dc_value_probs = [128; 11];
211 mdl.ac_val_probs = [[[128; 11]; 6]; 3];
212 }
213 self.vp6models.zero_run_probs.copy_from_slice(&ZERO_RUN_PROBS);
214 reset_scan(&mut self.vp6models, interlaced);
215 }
216 pub fn reset_mbtype_models(&mut self) {
217 const DEFAULT_XMITTED_PROBS: [[u8; 20]; 3] = [
218 [ 42, 69, 2, 1, 7, 1, 42, 44, 22, 6, 3, 1, 2, 0, 5, 1, 1, 0, 0, 0 ],
219 [ 8, 229, 1, 1, 8, 0, 0, 0, 0, 0, 2, 1, 1, 0, 0, 0, 1, 1, 0, 0 ],
220 [ 35, 122, 1, 1, 6, 1, 34, 46, 0, 0, 2, 1, 1, 0, 1, 0, 1, 1, 0, 0 ]
221 ];
222 self.prob_xmitted.copy_from_slice(&DEFAULT_XMITTED_PROBS);
223 }
224 }
225
226 pub fn reset_scan(model: &mut VP6Models, interlaced: bool) {
227 if !interlaced {
228 model.scan_order.copy_from_slice(&VP6_DEFAULT_SCAN_ORDER);
229 } else {
230 model.scan_order.copy_from_slice(&VP6_INTERLACED_SCAN_ORDER);
231 }
232 for i in 0..64 { model.scan[i] = i; }
233 model.zigzag.copy_from_slice(&ZIGZAG);
234 }
235
/// Counter of binary decisions from which an adapted probability is derived.
///
/// `zeroes` counts `false` decisions and `total` counts every decision
/// recorded through [`ProbCounter::add`].
#[derive(Clone,Copy,Default)]
pub struct ProbCounter {
    zeroes: u32,
    total: u32,
}

/// Cost, in eighths of a bit, of coding a decision whose probability is
/// `index / 256`.
///
/// For example, `PROB_BITS[1] == 64` (8 bits) and `PROB_BITS[128] == 8`
/// (1 bit). Index 0 is not a usable probability and holds 0.
pub const PROB_BITS: [u8; 256] = [
    0, 64, 56, 51, 48, 45, 43, 42,
    40, 39, 37, 36, 35, 34, 34, 33,
    32, 31, 31, 30, 29, 29, 28, 28,
    27, 27, 26, 26, 26, 25, 25, 24,
    24, 24, 23, 23, 23, 22, 22, 22,
    21, 21, 21, 21, 20, 20, 20, 20,
    19, 19, 19, 19, 18, 18, 18, 18,
    18, 17, 17, 17, 17, 17, 16, 16,
    16, 16, 16, 15, 15, 15, 15, 15,
    15, 14, 14, 14, 14, 14, 14, 14,
    13, 13, 13, 13, 13, 13, 13, 12,
    12, 12, 12, 12, 12, 12, 12, 11,
    11, 11, 11, 11, 11, 11, 11, 11,
    10, 10, 10, 10, 10, 10, 10, 10,
    10, 9, 9, 9, 9, 9, 9, 9,
    9, 9, 9, 8, 8, 8, 8, 8,
    8, 8, 8, 8, 8, 8, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 6, 6, 6, 6, 6, 6,
    6, 6, 6, 6, 6, 6, 6, 5,
    5, 5, 5, 5, 5, 5, 5, 5,
    5, 5, 5, 5, 5, 5, 4, 4,
    4, 4, 4, 4, 4, 4, 4, 4,
    4, 4, 4, 4, 4, 4, 3, 3,
    3, 3, 3, 3, 3, 3, 3, 3,
    3, 3, 3, 3, 3, 3, 3, 2,
    2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2, 2, 2, 2, 2,
    2, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0
];

impl ProbCounter {
    /// Records one decision; `false` counts as a zero.
    pub fn add(&mut self, b: bool) {
        if !b {
            self.zeroes += 1;
        }
        self.total += 1;
    }
    /// Converts the collected counts into a probability of a zero decision.
    ///
    /// The result is `zeroes * 256 / total`, capped at 254, cleared to an even
    /// value and floored at 1. Even values are presumably required because
    /// the bitstream carries 7-bit probabilities — TODO confirm against the
    /// writer. Returns 128 when nothing was recorded.
    pub fn to_prob(self) -> u8 {
        if self.total > 0 {
            // Widen first: in `u32`, `zeroes << 8` silently drops high bits once
            // more than 2^24 zeroes have been counted, giving a wrong ratio.
            let ratio = (u64::from(self.zeroes) << 8) / u64::from(self.total);
            ((ratio.min(254) & !1).max(1)) as u8
        } else {
            128
        }
    }
    /// Returns the probability to use from now on: the fresh estimate from
    /// [`ProbCounter::to_prob`], or `old_prob`.
    ///
    /// The new estimate is chosen only when it codes the recorded decisions
    /// more than 7 bits cheaper than `old_prob`. The margin presumably covers
    /// the cost of signalling the update — TODO confirm. With no recorded
    /// decisions, `old_prob` is returned.
    pub fn to_prob_worthy(&self, old_prob: u8) -> u8 {
        if self.total == 0 {
            return old_prob;
        }
        let new_prob = self.to_prob();
        let new_bits = Self::est_bits(new_prob, self.zeroes, self.total);
        let old_bits = Self::est_bits(old_prob, self.zeroes, self.total);

        if new_bits + 7 < old_bits {
            new_prob
        } else {
            old_prob
        }
    }
    /// Estimates, in whole bits (rounded up), the cost of coding `zeroes` zero
    /// decisions and `total - zeroes` one decisions with zero-probability `prob`.
    ///
    /// Arithmetic is done in `u64` so large counts cannot overflow. A `prob` of
    /// 0 is not a valid coding probability and would index past the end of
    /// `PROB_BITS` for the "one" side, so it is costed as 1.
    fn est_bits(prob: u8, zeroes: u32, total: u32) -> u64 {
        let prob = usize::from(prob.max(1));
        let zero_cost = u64::from(PROB_BITS[prob]) * u64::from(zeroes);
        let one_cost = u64::from(PROB_BITS[256 - prob]) * u64::from(total - zeroes);
        // PROB_BITS is in 1/8-bit units.
        (zero_cost + one_cost + 7) >> 3
    }
}
311
312 #[derive(Clone,Copy,Default)]
313 pub struct VP56MVModelStat {
314 pub nz_prob: ProbCounter,
315 pub sign_prob: ProbCounter,
316 pub raw_probs: [ProbCounter; 8],
317 pub tree_probs: [ProbCounter; 7],
318 }
319
320 #[derive(Clone,Copy,Default)]
321 pub struct VP56CoeffModelStat {
322 pub dc_token_probs: [[[ProbCounter; 5]; 6]; 6],
323 pub dc_value_probs: [ProbCounter; 11],
324 pub ac_val_probs: [[[ProbCounter; 11]; 6]; 3],
325 }
326
327 #[derive(Default)]
328 pub struct VP6ModelsStat {
329 pub zero_run_probs: [[ProbCounter; 14]; 2],
330 }
331
332 pub struct VP56ModelsStat {
333 pub mv_models: [VP56MVModelStat; 2],
334 pub mbtype_models: [[[usize; 10]; 10]; 3],
335 pub coeff_models: [VP56CoeffModelStat; 2],
336 pub vp6models: VP6ModelsStat,
337 }
338
339 impl VP56ModelsStat {
340 pub fn new() -> Self {
341 Self {
342 mv_models: [VP56MVModelStat::default(); 2],
343 mbtype_models: [[[0; 10]; 10]; 3],
344 coeff_models: [VP56CoeffModelStat::default(); 2],
345 vp6models: VP6ModelsStat::default(),
346 }
347 }
348 pub fn reset(&mut self) {
349 self.mv_models = [VP56MVModelStat::default(); 2];
350 self.mbtype_models = [[[0; 10]; 10]; 3];
351 self.coeff_models = [VP56CoeffModelStat::default(); 2];
352 self.vp6models = VP6ModelsStat::default();
353 }
354 pub fn generate(&self, dst: &mut VP56Models, is_intra: bool) {
355 if !is_intra {
356 for (dmv, smv) in dst.mv_models.iter_mut().zip(self.mv_models.iter()) {
357 dmv.nz_prob = smv.nz_prob.to_prob_worthy(dmv.nz_prob);
358 dmv.sign_prob = smv.sign_prob.to_prob_worthy(dmv.sign_prob);
359 for (dp, sp) in dmv.raw_probs.iter_mut().zip(smv.raw_probs.iter()) {
360 *dp = sp.to_prob_worthy(*dp);
361 }
362 for (dp, sp) in dmv.tree_probs.iter_mut().zip(smv.tree_probs.iter()) {
363 *dp = sp.to_prob_worthy(*dp);
364 }
365 }
366 for (xmit, mdl) in dst.prob_xmitted.iter_mut().zip(self.mbtype_models.iter()) {
367 Self::generate_prob_xmitted(xmit, mdl);
368 }
369 }
370 for (dmv, smv) in dst.coeff_models.iter_mut().zip(self.coeff_models.iter()) {
371 for (dp, sp) in dmv.dc_value_probs.iter_mut().zip(smv.dc_value_probs.iter()) {
372 *dp = sp.to_prob_worthy(*dp);
373 }
374 for (dp, sp) in dmv.ac_val_probs.iter_mut().zip(smv.ac_val_probs.iter()) {
375 for (dp, sp) in dp.iter_mut().zip(sp.iter()) {
376 for (dp, sp) in dp.iter_mut().zip(sp.iter()) {
377 *dp = sp.to_prob_worthy(*dp);
378 }
379 }
380 }
381 }
382 for (dp, sp) in dst.vp6models.zero_run_probs.iter_mut().zip(self.vp6models.zero_run_probs.iter()) {
383 for (dp, sp) in dp.iter_mut().zip(sp.iter()) {
384 *dp = sp.to_prob_worthy(*dp);
385 }
386 }
387 }
388 /*
389 VPMBType::InterNoMV => 0,
390 VPMBType::Intra => 1,
391 VPMBType::InterMV => 2,
392 VPMBType::InterNearest => 3,
393 VPMBType::InterNear => 4,
394 VPMBType::GoldenNoMV => 5,
395 VPMBType::GoldenMV => 6,
396 VPMBType::InterFourMV => 7,
397 VPMBType::GoldenNearest => 8,
398 VPMBType::GoldenNear => 9,
399 */
400 fn generate_prob_xmitted(probs: &mut [u8; 20], mbtype: &[[usize; 10]; 10]) {
401 let mut sums = [0; 20];
402 let mut total = 0;
403 for (last, row) in mbtype.iter().enumerate() {
404 for (cur, &count) in row.iter().enumerate() {
405 if last == cur {
406 sums[cur * 2 + 1] = count;
407 } else {
408 sums[cur * 2] += count;
409 }
410 total += count;
411 }
412 }
413 if total != 0 {
414 let mut sum = 0;
415 for (dprob, &sprob) in probs.iter_mut().zip(sums.iter()) {
416 if sprob != 0 {
417 *dprob = ((sprob * 256 + total - 1) / total).min(255) as u8;
418 sum += u16::from(*dprob);
419 } else {
420 *dprob = 0;
421 }
422 }
423 while sum > 256 {
424 for prob in probs.iter_mut() {
425 if *prob > 1 {
426 *prob -= 1;
427 sum -= 1;
428 if sum == 256 {
429 break;
430 }
431 }
432 }
433 }
434 } else {
435 *probs = [0; 20];
436 }
437 }
438 }