]>
Commit | Line | Data |
---|---|---|
4965a5e5 KS |
1 | use nihav_core::frame::*; |
2 | use nihav_codec_support::codecs::{MV, ZERO_MV}; | |
3 | use std::str::FromStr; | |
4 | use super::dsp::{RefMBData, luma_mc, chroma_mc}; | |
5 | ||
/// Motion vector search strategy selectable for the encoder.
///
/// `Debug` is derived so the mode can be reported in logs and used with
/// `assert_eq!` in tests; all other behaviour is unchanged.
#[derive(Clone,Copy,Debug,PartialEq,Default)]
pub enum MVSearchMode {
    /// No search at all — always yields the zero MV (internal fallback).
    Dummy,
    /// Small diamond pattern search.
    Diamond,
    /// Hexagon pattern search (the default).
    #[default]
    Hexagon,
    /// Uneven multi-hexagon (UMH) search.
    UMH,
}
14 | ||
15 | impl MVSearchMode { | |
16 | pub const fn get_possible_modes() -> &'static [&'static str] { | |
17 | &["diamond", "hexagon", "umh"] | |
18 | } | |
19 | fn create(self) -> Box<dyn MVSearch+Send> { | |
20 | match self { | |
21 | MVSearchMode::Dummy => Box::new(DummySearcher{}), | |
22 | MVSearchMode::Diamond => Box::new(DiaSearch::new()), | |
23 | MVSearchMode::Hexagon => Box::new(HexSearch::new()), | |
24 | MVSearchMode::UMH => Box::new(UnevenHexSearch::new()), | |
25 | } | |
26 | } | |
27 | } | |
28 | ||
4965a5e5 KS |
29 | impl std::fmt::Display for MVSearchMode { |
30 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { | |
31 | match *self { | |
32 | MVSearchMode::Diamond => write!(f, "diamond"), | |
33 | MVSearchMode::Hexagon => write!(f, "hexagon"), | |
34 | MVSearchMode::UMH => write!(f, "umh"), | |
35 | MVSearchMode::Dummy => write!(f, "dummy"), | |
36 | } | |
37 | } | |
38 | } | |
39 | ||
40 | impl FromStr for MVSearchMode { | |
41 | type Err = (); | |
42 | fn from_str(s: &str) -> Result<Self, Self::Err> { | |
43 | match s { | |
44 | "diamond" => Ok(MVSearchMode::Diamond), | |
45 | "hexagon" => Ok(MVSearchMode::Hexagon), | |
46 | "umh" => Ok(MVSearchMode::UMH), | |
47 | "dummy" => Ok(MVSearchMode::Dummy), | |
48 | _ => Err(()), | |
49 | } | |
50 | } | |
51 | } | |
52 | ||
/// Sentinel distance value meaning "not evaluated yet" / worst possible score.
const MAX_DIST: u32 = u32::MAX;
/// Early-termination distance threshold for full macroblock search.
const DIST_THRESH: u32 = 256;
55 | ||
/// Conversion from full-pixel units to the quarter-pel units used during search.
trait FromPixels {
    fn from_pixels(self) -> Self;
}

impl FromPixels for MV {
    /// Scales both MV components by 4 (full-pel -> quarter-pel).
    fn from_pixels(self) -> MV {
        MV { x: self.x * 4, y: self.y * 4 }
    }
}
65 | ||
/// Diamond search pattern: centre point plus eight points on a radius-2 diamond.
const DIA_PATTERN: [MV; 9] = [
    ZERO_MV,
    MV {x: -2, y: 0},
    MV {x: -1, y: 1},
    MV {x: 0, y: 2},
    MV {x: 1, y: 1},
    MV {x: 2, y: 0},
    MV {x: 1, y: -1},
    MV {x: 0, y: -2},
    MV {x: -1, y: -1}
];
77 | ||
/// Hexagon search pattern: centre point plus six hexagon vertices.
const HEX_PATTERN: [MV; 7] = [
    ZERO_MV,
    MV {x: -2, y: 0},
    MV {x: -1, y: 2},
    MV {x: 1, y: 2},
    MV {x: 2, y: 0},
    MV {x: 1, y: -2},
    MV {x: -1, y: -2}
];
87 | ||
/// One-step cross used for the final refinement around the best point.
const REFINEMENT: [MV; 4] = [
    MV {x: -1, y: 0},
    MV {x: 0, y: 1},
    MV {x: 1, y: 0},
    MV {x: 0, y: -1}
];
94 | ||
// Shared motion-search loop used by the pattern searchers (diamond/hexagon).
//
// Two stages:
//  1. (optional) full-pel stage: score every unevaluated pattern point,
//     re-centre the pattern on the best one until no improvement, then do a
//     one-pel cross refinement; pattern points are stored in full pels and
//     converted with `from_pixels()` before scoring.
//  2. sub-pel stage: same procedure in quarter-pel space around the best
//     vector (no `from_pixels()` conversion, MV range scaled by 8).
// Returns `(best_mv, best_dist)` with the MV in quarter-pel units.
// The 7-argument form starts a fresh search from the zero MV.
macro_rules! search_template {
    ($self: expr, $mv_est: expr, $cur_blk: expr, $mb_x: expr, $mb_y: expr, $sad_func: ident, $threshold: expr) => ({
        search_template!($self, $mv_est, $cur_blk, $mb_x, $mb_y, $sad_func, $threshold, ZERO_MV, MAX_DIST, true)
    });
    ($self: expr, $mv_est: expr, $cur_blk: expr, $mb_x: expr, $mb_y: expr, $sad_func: ident, $threshold: expr, $start_mv: expr, $best_dist: expr, $fullpel_stage: expr) => ({
        let mut best_dist = $best_dist;
        let mut best_mv = $start_mv;

        let mut min_dist;
        let mut min_idx;

        if $fullpel_stage {
            $self.reset();
            loop {
                // score every point not evaluated yet (dist == MAX_DIST marks "new")
                let mut cur_best_dist = best_dist;
                for (dist, &point) in $self.dist.iter_mut().zip($self.point.iter()) {
                    if *dist == MAX_DIST {
                        *dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, point.from_pixels(), cur_best_dist);
                        cur_best_dist = cur_best_dist.min(*dist);
                        if *dist <= $threshold {
                            // good enough — stop scoring early
                            break;
                        }
                    }
                }
                // pick the best-scoring point of the pattern
                min_dist = $self.dist[0];
                min_idx = 0;
                for (i, &dist) in $self.dist.iter().enumerate().skip(1) {
                    if dist < min_dist {
                        min_dist = dist;
                        min_idx = i;
                        if dist <= $threshold {
                            break;
                        }
                    }
                }
                // stop when: threshold hit, centre is best, no improvement,
                // or the best point would leave the allowed MV range
                if min_dist <= $threshold || min_idx == 0 || best_dist == min_dist || $self.point[min_idx].x.abs() >= $mv_est.mv_range || $self.point[min_idx].y.abs() >= $mv_est.mv_range {
                    break;
                }
                best_dist = min_dist;
                // re-centre the pattern on the winning point
                $self.update($self.steps[min_idx]);
            }
            best_dist = min_dist;
            best_mv = $self.point[min_idx];
            if best_dist <= $threshold {
                return (best_mv.from_pixels(), best_dist);
            }
            // one-pel cross refinement around the winner
            for &step in REFINEMENT.iter() {
                let mv = best_mv + step;
                let dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, mv.from_pixels(), MAX_DIST);
                if best_dist > dist {
                    best_dist = dist;
                    best_mv = mv;
                }
            }
            // from here on the MV is kept in quarter-pel units
            best_mv = best_mv.from_pixels();
            if best_dist <= $threshold {
                return (best_mv, best_dist);
            }
        }

        // subpel refinement
        $self.set_new_point(best_mv, best_dist);
        loop {
            let mut cur_best_dist = best_dist;
            for (dist, &point) in $self.dist.iter_mut().zip($self.point.iter()) {
                if *dist == MAX_DIST {
                    *dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, point, cur_best_dist);
                    cur_best_dist = cur_best_dist.min(*dist);
                    if *dist <= $threshold {
                        break;
                    }
                }
            }
            min_dist = $self.dist[0];
            min_idx = 0;
            for (i, &dist) in $self.dist.iter().enumerate().skip(1) {
                if dist < min_dist {
                    min_dist = dist;
                    min_idx = i;
                    if dist <= $threshold {
                        break;
                    }
                }
            }
            // same stop conditions; MV range is in full pels, hence the *8
            // factor for quarter-pel coordinates (range*4 plus refinement slack)
            if min_dist <= $threshold || min_idx == 0 || best_dist == min_dist || $self.point[min_idx].x.abs() >= $mv_est.mv_range * 8 || $self.point[min_idx].y.abs() >= $mv_est.mv_range * 8 {
                break;
            }
            best_dist = min_dist;
            $self.update($self.steps[min_idx]);
        }
        best_dist = min_dist;
        best_mv = $self.point[min_idx];
        if best_dist <= $threshold {
            return (best_mv, best_dist);
        }
        // final quarter-pel cross refinement
        for &step in REFINEMENT.iter() {
            let mv = best_mv + step;
            let dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, mv, MAX_DIST);
            if best_dist > dist {
                best_dist = dist;
                best_mv = mv;
            }
        }
        (best_mv, best_dist)
    });
}
201 | ||
// Generates a pattern-based searcher type (used for diamond and hexagon):
// the struct keeps the current pattern points, their cached scores and the
// immutable step offsets, and implements `MVSearch` via `search_template!`.
macro_rules! pattern_search {
    ($struct_name: ident, $patterns: expr) => {
        pub struct $struct_name {
            point: [MV; $patterns.len()],       // current candidate positions
            dist: [u32; $patterns.len()],       // cached score per position (MAX_DIST = unscored)
            steps: &'static [MV; $patterns.len()], // original pattern offsets
        }

        impl $struct_name {
            pub fn new() -> Self {
                Self {
                    point: $patterns,
                    dist: [MAX_DIST; $patterns.len()],
                    steps: &$patterns,
                }
            }
            // Restores the pattern to its initial position and clears all scores.
            fn reset(&mut self) {
                self.point = $patterns;
                self.dist = [MAX_DIST; $patterns.len()];
            }
            // Centres the pattern on `start`, keeping `dist` for the known centre.
            fn set_new_point(&mut self, start: MV, dist: u32) {
                for (dst, &src) in self.point.iter_mut().zip(self.steps.iter()) {
                    *dst = src + start;
                }
                self.dist = [MAX_DIST; $patterns.len()];
                self.dist[0] = dist;
            }
            // Shifts the whole pattern by `step`, carrying over scores of
            // points that the old and new patterns have in common so they
            // are not recomputed.
            fn update(&mut self, step: MV) {
                let mut new_point = self.point;
                let mut new_dist = [MAX_DIST; $patterns.len()];

                for point in new_point.iter_mut() {
                    *point += step;
                }

                for (new_point, new_dist) in new_point.iter_mut().zip(new_dist.iter_mut()) {
                    for (&old_point, &old_dist) in self.point.iter().zip(self.dist.iter()) {
                        if *new_point == old_point {
                            *new_dist = old_dist;
                            break;
                        }
                    }
                }
                self.point = new_point;
                self.dist = new_dist;
            }
        }

        impl MVSearch for $struct_name {
            fn search_mb(&mut self, mv_est: &mut MVEstimator, cur_mb: &RefMBData, mb_x: usize, mb_y: usize, _cand_mvs: &[MV]) -> (MV, u32) {
                search_template!(self, mv_est, cur_mb, mb_x, mb_y, sad_mb, DIST_THRESH)
            }
            fn search_blk8(&mut self, mv_est: &mut MVEstimator, ref_blk: &RefMBData, xpos: usize, ypos: usize, _cand_mvs: &[MV]) -> (MV, u32) {
                // threshold scaled down by 4 since an 8x8 block covers 1/4 of a MB
                search_template!(self, mv_est, ref_blk, xpos, ypos, sad_blk8, DIST_THRESH / 4)
            }
        }
    }
}
260 | ||
// Instantiate the two pattern-based searchers.
pattern_search!(DiaSearch, DIA_PATTERN);
pattern_search!(HexSearch, HEX_PATTERN);
263 | ||
264 | const LARGE_HEX_PATTERN: [MV; 16] = [ | |
265 | MV { x: -4, y: 0 }, | |
266 | MV { x: -4, y: 1 }, | |
267 | MV { x: -4, y: 2 }, | |
268 | MV { x: -2, y: 3 }, | |
269 | MV { x: 0, y: 4 }, | |
270 | MV { x: 2, y: 3 }, | |
271 | MV { x: 4, y: 2 }, | |
272 | MV { x: 4, y: 1 }, | |
273 | MV { x: 4, y: 0 }, | |
274 | MV { x: 4, y: -1 }, | |
275 | MV { x: 4, y: -2 }, | |
276 | MV { x: -2, y: -3 }, | |
277 | MV { x: 0, y: -4 }, | |
278 | MV { x: -2, y: -3 }, | |
279 | MV { x: -4, y: -2 }, | |
280 | MV { x: -4, y: -1 } | |
281 | ]; | |
282 | ||
/// Asymmetric cross for the UMH cross-search stage: horizontal reach is
/// twice the vertical one (horizontal motion is scaled by 2, vertical by 1).
const UNSYMM_CROSS: [MV; 4] = [
    MV { x: -2, y: 0 },
    MV { x: 0, y: 1 },
    MV { x: 2, y: 0 },
    MV { x: 0, y: -1 }
];
289 | ||
/// Small fixed-capacity (16 entries) set that keeps insertion order and
/// silently drops duplicates as well as insertions past capacity.
#[derive(Default)]
struct UniqueSet<T:Copy+Default> {
    list: [T; 16],
    count: usize,
}

impl<T:Copy+Default+PartialEq> UniqueSet<T> {
    /// Creates an empty set.
    fn new() -> Self { Self::default() }
    /// Drops all stored elements (storage is reused).
    fn clear(&mut self) { self.count = 0; }
    /// Returns the stored elements in insertion order.
    fn get_list(&self) -> &[T] { &self.list[..self.count] }
    /// Inserts `val` unless it is already present or the set is full.
    fn add(&mut self, val: T) {
        if self.count >= self.list.len() || self.list[..self.count].contains(&val) {
            return;
        }
        self.list[self.count] = val;
        self.count += 1;
    }
}
307 | ||
/// Small MV helpers used by the UMH search stages.
trait MVOps {
    // Multiplies both components by `scale`.
    fn scale(self, scale: i16) -> Self;
    // True when both components lie within [-range, range].
    fn is_in_range(self, range: i16) -> bool;
}
312 | ||
313 | impl MVOps for MV { | |
314 | fn scale(self, scale: i16) -> MV { | |
315 | MV { x: self.x * scale, y: self.y * scale } | |
316 | } | |
317 | fn is_in_range(self, range: i16) -> bool { | |
318 | self.x.abs() <= range && self.y.abs() <= range | |
319 | } | |
320 | } | |
321 | ||
// One round of pattern evaluation for UMH: scores every pattern point
// (scaled by `$scale` and offset by `$start`), keeping the best one.
// Out-of-range vectors are skipped (range is in quarter-pel, hence *4).
// Evaluates to `(best_mv, best_dist, improved)`.
macro_rules! single_search_step {
    ($start:expr, $best_dist:expr, $mv_est:expr, $sad_func:ident, $ref_blk:expr, $xpos:expr, $ypos:expr, $pattern:expr, $scale:expr, $dist_thr:expr) => {{
        let mut best_mv = $start;
        let mut best_dist = $best_dist;
        for point in $pattern.iter() {
            let mv = point.scale($scale) + $start;
            if !mv.is_in_range($mv_est.mv_range * 4) {
                continue;
            }
            let dist = $mv_est.$sad_func($ref_blk, $xpos, $ypos, mv, best_dist);
            if dist < best_dist {
                best_mv = mv;
                best_dist = dist;
                if best_dist < $dist_thr {
                    // good enough — stop scanning the rest of the pattern
                    break;
                }
            }
        }
        (best_mv, best_dist, best_mv != $start)
    }}
}
343 | ||
344 | struct UnevenHexSearch { | |
345 | mv_list: UniqueSet<MV>, | |
346 | } | |
347 | ||
348 | impl UnevenHexSearch { | |
349 | fn new() -> Self { | |
350 | Self { | |
351 | mv_list: UniqueSet::new(), | |
352 | } | |
353 | } | |
354 | fn get_cand_mv(&mut self, cand_mvs: &[MV]) -> MV { | |
355 | self.mv_list.clear(); | |
356 | for &mv in cand_mvs.iter() { | |
357 | self.mv_list.add(mv); | |
358 | } | |
359 | match self.mv_list.count { | |
360 | 1 => self.mv_list.list[0], | |
361 | 3 => MV::pred(self.mv_list.list[0], self.mv_list.list[1], self.mv_list.list[2]), | |
362 | _ => { | |
363 | let sum = self.mv_list.get_list().iter().fold((0i32, 0i32), | |
364 | |acc, mv| (acc.0 + i32::from(mv.x), acc.1 + i32::from(mv.y))); | |
365 | MV {x: (sum.0 / (self.mv_list.count as i32)) as i16, | |
366 | y: (sum.1 / (self.mv_list.count as i32)) as i16} | |
367 | }, | |
368 | } | |
369 | } | |
370 | } | |
371 | ||
// Full UMH (uneven multi-hexagon) search, shared between MB and 8x8 block
// variants.  Stages: start-point check, small diamond refinement,
// asymmetric cross search, multi-hexagon grid at shrinking scales, and a
// final hexagon + diamond polish.  Evaluates to `(mv, dist)` and may
// return early from the enclosing function once below `$cutoff`.
macro_rules! umh_search_template {
    ($cand_mv:expr, $cutoff:expr, $mv_est:expr, $sad_func:ident, $ref_blk:expr, $xpos:expr, $ypos:expr) => {{
        let cand_mv = $cand_mv;
        let best_dist = $mv_est.$sad_func($ref_blk, $xpos, $ypos, cand_mv, MAX_DIST);
        if best_dist < $cutoff {
            return (cand_mv, best_dist);
        }

        // step 1 - small refinement search
        let (mut cand_mv, mut best_dist, _) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, DIA_PATTERN, 1, $cutoff);
        if best_dist < $cutoff {
            return (cand_mv, best_dist);
        }

        // step 2 - unsymmetrical cross search
        loop {
            let (mv, dist, changed) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, UNSYMM_CROSS, 4, $cutoff);
            if !changed {
                break;
            }
            cand_mv = mv;
            best_dist = dist;
            if best_dist < $cutoff {
                return (mv, dist);
            }
        }

        // step 3 - multi-hexagon grid search
        let mut scale = 4;
        while scale > 0 {
            let (mv, dist, changed) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, LARGE_HEX_PATTERN, scale, $cutoff);
            if !changed {
                break;
            }
            cand_mv = mv;
            best_dist = dist;
            if best_dist < $cutoff {
                return (mv, dist);
            }
            scale >>= 1;
        }
        // step 4 - final hexagon search
        let (cand_mv, best_dist, _) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, HEX_PATTERN, 1, $cutoff);
        if best_dist > $cutoff {
            // still above the cutoff — polish with a small diamond pass
            let (mv, dist, _) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, DIA_PATTERN, 1, $cutoff);
            (mv, dist)
        } else {
            (cand_mv, best_dist)
        }
    }}
}
423 | ||
impl MVSearch for UnevenHexSearch {
    /// Macroblock search: start from the candidate prediction, then run UMH.
    fn search_mb(&mut self, mv_est: &mut MVEstimator, cur_mb: &RefMBData, mb_x: usize, mb_y: usize, cand_mvs: &[MV]) -> (MV, u32) {
        let cand_mv = self.get_cand_mv(cand_mvs);
        let cutoff = mv_est.cutoff_thr;
        umh_search_template!(cand_mv, cutoff, mv_est, sad_mb, cur_mb, mb_x, mb_y)
    }
    /// Same search for an 8x8 block; cutoff scaled by the block-area ratio.
    fn search_blk8(&mut self, mv_est: &mut MVEstimator, ref_blk: &RefMBData, xpos: usize, ypos: usize, cand_mvs: &[MV]) -> (MV, u32) {
        let cand_mv = self.get_cand_mv(cand_mvs);
        let cutoff = mv_est.cutoff_thr / 4;
        umh_search_template!(cand_mv, cutoff, mv_est, sad_blk8, ref_blk, xpos, ypos)
    }
}
436 | ||
/// Per-search context: the reference picture plus the search limits.
struct MVEstimator<'a> {
    pic: &'a NAVideoBuffer<u8>, // reference picture to motion-compensate from
    mv_range: i16,              // MV component limit in full pels
    cutoff_thr: u32,            // early-exit distortion threshold
}
442 | ||
// Despite the name, this computes the sum of SQUARED differences (SSE),
// not the sum of absolute differences, between two equal-length slices.
macro_rules! sad {
    ($src1:expr, $src2:expr) => {
        $src1.iter().zip($src2.iter()).fold(0u32, |acc, (&a, &b)|
            acc + (((i32::from(a) - i32::from(b)) * (i32::from(a) - i32::from(b))) as u32))
    }
}
449 | ||
impl<'a> MVEstimator<'a> {
    /// Distortion of a whole 16x16 macroblock (luma + both chroma planes)
    /// for candidate `mv`, computed against the motion-compensated prediction.
    /// Aborts early and returns the partial sum once it exceeds `cur_best_dist`.
    fn sad_mb(&self, ref_mb: &RefMBData, mb_x: usize, mb_y: usize, mv: MV, cur_best_dist: u32) -> u32 {
        let mut dst = RefMBData::new();
        luma_mc(&mut dst.y, 16, self.pic, mb_x * 16, mb_y * 16, mv, true);

        let mut dist = 0;
        // luma first, row by row, so the early exit triggers as soon as possible
        for (dline, sline) in dst.y.chunks(16).zip(ref_mb.y.chunks(16)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }
        chroma_mc(&mut dst.u, 8, self.pic, mb_x * 8, mb_y * 8, 1, mv, true);
        dist += sad!(dst.u, ref_mb.u);
        if dist > cur_best_dist {
            return dist;
        }
        chroma_mc(&mut dst.v, 8, self.pic, mb_x * 8, mb_y * 8, 2, mv, true);
        dist += sad!(dst.v, ref_mb.v);

        dist
    }
    /// Same as `sad_mb` but for one 8x8 luma block (with its 4x4 chroma)
    /// at pixel position (`xpos`, `ypos`); the reference data is read from
    /// the corresponding quadrant of `ref_mb`.
    fn sad_blk8(&self, ref_mb: &RefMBData, xpos: usize, ypos: usize, mv: MV, cur_best_dist: u32) -> u32 {
        let mut cur_y = [0; 64];
        let mut cur_u = [0; 16];
        let mut cur_v = [0; 16];

        let mut dist = 0;

        // offset of this 8x8 block inside the 16x16 macroblock luma plane
        let y_off = (xpos & 8) + (ypos & 8) * 16;
        luma_mc(&mut cur_y, 8, self.pic, xpos, ypos, mv, false);
        for (dline, sline) in cur_y.chunks(8).zip(ref_mb.y[y_off..].chunks(16)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }

        // matching 4x4 offset inside the 8x8 chroma planes
        let c_off = (xpos & 8) / 2 + (ypos & 8) * 4;
        chroma_mc(&mut cur_u, 4, self.pic, xpos / 2, ypos / 2, 1, mv, false);
        for (dline, sline) in cur_u.chunks(4).zip(ref_mb.u[c_off..].chunks(8)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }
        chroma_mc(&mut cur_v, 4, self.pic, xpos / 2, ypos / 2, 2, mv, false);
        for (dline, sline) in cur_v.chunks(4).zip(ref_mb.v[c_off..].chunks(8)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }

        dist
    }
}
507 | ||
/// Common interface of all motion searchers; both methods return the best
/// motion vector (quarter-pel) and its distortion score.
trait MVSearch {
    fn search_mb(&mut self, mv_est: &mut MVEstimator, ref_mb: &RefMBData, mb_x: usize, mb_y: usize, cand_mvs: &[MV]) -> (MV, u32);
    fn search_blk8(&mut self, mv_est: &mut MVEstimator, ref_blk: &RefMBData, xpos: usize, ypos: usize, cand_mvs: &[MV]) -> (MV, u32);
}
512 | ||
513 | struct DummySearcher {} | |
514 | ||
515 | impl MVSearch for DummySearcher { | |
516 | fn search_mb(&mut self, _mv_est: &mut MVEstimator, _ref_mb: &RefMBData, _mb_x: usize, _mb_y: usize, _cand_mvs: &[MV]) -> (MV, u32) { | |
517 | (ZERO_MV, std::u32::MAX / 2) | |
518 | } | |
519 | fn search_blk8(&mut self, _mv_est: &mut MVEstimator, _ref_mb: &RefMBData, _xpos: usize, _ypos: usize, _cand_mvs: &[MV]) -> (MV, u32) { | |
520 | (ZERO_MV, std::u32::MAX / 2) | |
521 | } | |
522 | } | |
523 | ||
/// Public motion estimation front-end: owns the searcher implementation
/// and the search parameters.
pub struct MotionEstimator {
    pub range: i16,            // MV component limit in full pels
    pub thresh: u32,           // early-exit distortion threshold
    mode: MVSearchMode,        // currently selected search mode
    srch: Box<dyn MVSearch+Send>, // searcher implementing `mode`
}
530 | ||
531 | impl MotionEstimator { | |
532 | pub fn new() -> Self { | |
533 | let mode = MVSearchMode::default(); | |
534 | Self { | |
535 | range: 64, | |
536 | thresh: 32, | |
537 | mode, | |
538 | srch: mode.create(), | |
539 | } | |
540 | } | |
541 | pub fn get_mode(&self) -> MVSearchMode { self.mode } | |
542 | pub fn set_mode(&mut self, new_mode: MVSearchMode) { | |
543 | if self.mode != new_mode { | |
544 | self.mode = new_mode; | |
545 | self.srch = self.mode.create(); | |
546 | } | |
547 | } | |
548 | pub fn search_mb_p(&mut self, pic: &NAVideoBuffer<u8>, refmb: &RefMBData, mb_x: usize, mb_y: usize, cand_mvs: &[MV]) -> (MV, u32) { | |
549 | let mut mv_est = MVEstimator { | |
550 | mv_range: self.range, | |
551 | cutoff_thr: self.thresh, | |
552 | pic, | |
553 | }; | |
554 | self.srch.search_mb(&mut mv_est, refmb, mb_x, mb_y, cand_mvs) | |
555 | } | |
556 | pub fn search_blk8(&mut self, pic: &NAVideoBuffer<u8>, refmb: &RefMBData, xoff: usize, yoff: usize, cand_mvs: &[MV]) -> (MV, u32) { | |
557 | let mut mv_est = MVEstimator { | |
558 | mv_range: self.range, | |
559 | cutoff_thr: self.thresh, | |
560 | pic, | |
561 | }; | |
562 | self.srch.search_blk8(&mut mv_est, refmb, xoff, yoff, cand_mvs) | |
563 | } | |
564 | } | |
565 | ||
/// Bidirectional (B-frame) macroblock search context holding both reference
/// frames, the macroblock position in pixels, the forward/backward weighting
/// ratios and scratch buffers for the interpolated predictions.
pub struct SearchB<'a> {
    ref_p: &'a NAVideoBuffer<u8>, // forward (previous) reference
    ref_n: &'a NAVideoBuffer<u8>, // backward (next) reference
    xpos: usize,
    ypos: usize,
    ratios: [u32; 2],             // weights for averaging fwd/bwd predictions
    tmp1: RefMBData,              // forward prediction scratch
    tmp2: RefMBData,              // backward prediction scratch
    pred_blk: RefMBData,          // weighted-average prediction
}
576 | ||
impl<'a> SearchB<'a> {
    /// Creates a search context for the macroblock at (`mb_x`, `mb_y`).
    pub fn new(ref_p: &'a NAVideoBuffer<u8>, ref_n: &'a NAVideoBuffer<u8>, mb_x: usize, mb_y: usize, ratios: [u32; 2]) -> Self {
        Self {
            ref_p, ref_n,
            xpos: mb_x * 16,
            ypos: mb_y * 16,
            ratios,
            tmp1: RefMBData::new(),
            tmp2: RefMBData::new(),
            pred_blk: RefMBData::new(),
        }
    }
    /// Jointly refines the forward/backward MV pair starting from `cand_mvs`:
    /// repeated full-pel diamond passes over both vectors until no
    /// improvement, then a final quarter-pel cross refinement.
    pub fn search_mb(&mut self, ref_mb: &RefMBData, cand_mvs: [MV; 2]) -> (MV, MV) {
        let mut best_cand = cand_mvs;
        let mut best_dist = self.interp_b_dist(ref_mb, best_cand, MAX_DIST);

        loop {
            let mut improved = false;
            // try every combination of diamond offsets on both vectors
            for &fmv_add in DIA_PATTERN.iter() {
                for &bmv_add in DIA_PATTERN.iter() {
                    let cand = [best_cand[0] + fmv_add.from_pixels(),
                                best_cand[1] + bmv_add.from_pixels()];
                    let dist = self.interp_b_dist(ref_mb, cand, best_dist);
                    if dist < best_dist {
                        best_dist = dist;
                        best_cand = cand;
                        improved = true;
                    }
                }
            }
            if !improved {
                break;
            }
        }

        // quarter-pel refinement of both vectors with the one-step cross
        for &fmv_add in REFINEMENT.iter() {
            for &bmv_add in REFINEMENT.iter() {
                let cand = [best_cand[0] + fmv_add, best_cand[1] + bmv_add];
                let dist = self.interp_b_dist(ref_mb, cand, best_dist);
                if dist < best_dist {
                    best_dist = dist;
                    best_cand = cand;
                }
            }
        }

        (best_cand[0], best_cand[1])
    }
    /// Distortion of the weighted bidirectional prediction built from
    /// `cand_mv` against `ref_mb`; exits early past `cur_best_dist`.
    fn interp_b_dist(&mut self, ref_mb: &RefMBData, cand_mv: [MV; 2], cur_best_dist: u32) -> u32 {
        let [fmv, bmv] = cand_mv;
        // motion-compensate both references, then average with the ratios
        luma_mc(&mut self.tmp1.y, 16, self.ref_p, self.xpos, self.ypos, fmv, true);
        chroma_mc(&mut self.tmp1.u, 8, self.ref_p, self.xpos / 2, self.ypos / 2, 1, fmv, true);
        chroma_mc(&mut self.tmp1.v, 8, self.ref_p, self.xpos / 2, self.ypos / 2, 2, fmv, true);
        luma_mc(&mut self.tmp2.y, 16, self.ref_n, self.xpos, self.ypos, bmv, true);
        chroma_mc(&mut self.tmp2.u, 8, self.ref_n, self.xpos / 2, self.ypos / 2, 1, bmv, true);
        chroma_mc(&mut self.tmp2.v, 8, self.ref_n, self.xpos / 2, self.ypos / 2, 2, bmv, true);
        self.pred_blk.avg(&self.tmp1, self.ratios[0], &self.tmp2, self.ratios[1]);

        let mut dist = 0;
        for (dline, sline) in self.pred_blk.y.chunks(16).zip(ref_mb.y.chunks(16)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }
        dist += sad!(self.pred_blk.u, ref_mb.u);
        if dist > cur_best_dist {
            return dist;
        }
        dist += sad!(self.pred_blk.v, ref_mb.v);

        dist
    }
}
651 | ||
// 4-point Hadamard butterfly: transforms ($s0..$s3) into ($d0..$d3).
// All sources are read into temporaries first, so destinations may alias
// sources (used for the in-place column pass).
macro_rules! hadamard {
    ($s0:expr, $s1:expr, $s2:expr, $s3:expr, $d0:expr, $d1:expr, $d2:expr, $d3:expr) => {
        let t0 = $s0 + $s1;
        let t1 = $s0 - $s1;
        let t2 = $s2 + $s3;
        let t3 = $s2 - $s3;
        $d0 = t0 + t2;
        $d2 = t0 - t2;
        $d1 = t1 + t3;
        $d3 = t1 - t3;
    }
}
664 | ||
/// Frame complexity estimator operating on 2x-downscaled copies of the
/// input frames (see `downscale`); used for frame-type decisions.
pub struct FrameComplexityEstimate {
    ref_frm: NAVideoBufferRef<u8>, // previous (reference) downscaled frame
    cur_frm: NAVideoBufferRef<u8>, // current downscaled frame
    nxt_frm: NAVideoBufferRef<u8>, // next downscaled frame (B-frame decision)
    width: usize,                  // full-resolution width
    height: usize,                 // full-resolution height
}
672 | ||
impl FrameComplexityEstimate {
    /// Creates the estimator with small placeholder buffers; real buffers
    /// are allocated on the first `resize()` call.
    pub fn new() -> Self {
        let vinfo = NAVideoInfo::new(24, 24, false, YUV420_FORMAT);
        let vt = alloc_video_buffer(vinfo, 4).unwrap();
        let buf = vt.get_vbuf().unwrap();
        Self {
            ref_frm: buf.clone(),
            cur_frm: buf.clone(),
            nxt_frm: buf,
            width: 0,
            height: 0,
        }
    }
    /// (Re)allocates the half-resolution work buffers when the frame size
    /// changes; the reference frame is initialised to mid-grey (0x80).
    pub fn resize(&mut self, width: usize, height: usize) {
        if width != self.width || height != self.height {
            self.width = width;
            self.height = height;

            let vinfo = NAVideoInfo::new(self.width / 2, self.height / 2, false, YUV420_FORMAT);
            let vt = alloc_video_buffer(vinfo, 4).unwrap();
            self.ref_frm = vt.get_vbuf().unwrap();
            let frm = self.ref_frm.get_data_mut().unwrap();
            for el in frm.iter_mut() {
                *el = 0x80;
            }
            let vt = alloc_video_buffer(vinfo, 4).unwrap();
            self.cur_frm = vt.get_vbuf().unwrap();
            let vt = alloc_video_buffer(vinfo, 4).unwrap();
            self.nxt_frm = vt.get_vbuf().unwrap();
        }
    }
    /// Stores a downscaled copy of `frm` as the current frame.
    pub fn set_current(&mut self, frm: &NAVideoBuffer<u8>) {
        Self::downscale(&mut self.cur_frm, frm);
    }
    /// Complexity of the current frame for the given frame type:
    /// intra SATD for I-frames, motion-compensated SATD for P-frames,
    /// zero for any other type.
    pub fn get_complexity(&self, ftype: FrameType) -> u32 {
        match ftype {
            FrameType::I => Self::calculate_i_cplx(&self.cur_frm),
            FrameType::P => Self::calculate_mv_diff(&self.ref_frm, &self.cur_frm),
            _ => 0,
        }
    }
    /// Decides whether `frm1` should be coded as a B-frame given `frm2`
    /// as the following frame (both get downscaled into the work buffers).
    pub fn decide_b_frame(&mut self, frm1: &NAVideoBuffer<u8>, frm2: &NAVideoBuffer<u8>) -> bool {
        Self::downscale(&mut self.cur_frm, frm1);
        Self::downscale(&mut self.nxt_frm, frm2);
        let diff_ref_cur = Self::calculate_mv_diff(&self.ref_frm, &self.cur_frm);
        let diff_cur_nxt = Self::calculate_mv_diff(&self.cur_frm, &self.nxt_frm);

        // simple rule - if complexity ref->cur and cur->next is about the same this should be a B-frame
        let ddiff = diff_ref_cur.max(diff_cur_nxt) - diff_ref_cur.min(diff_cur_nxt);
        if ddiff < 256 {
            true
        } else {
            // otherwise require the smaller complexity to dominate the gap
            // by more than a factor of four (order > 2)
            let mut order = 0;
            while (ddiff << order) < diff_ref_cur.min(diff_cur_nxt) {
                order += 1;
            }
            order > 2
        }
    }
    /// Promotes the current frame to become the new reference.
    pub fn update_ref(&mut self) {
        std::mem::swap(&mut self.ref_frm, &mut self.cur_frm);
    }

    // Converts macroblock coordinates plus an MV into pixel coordinates.
    fn add_mv(mb_x: usize, mb_y: usize, mv: MV) -> (usize, usize) {
        (((mb_x * 16) as isize + (mv.x as isize)) as usize,
         ((mb_y * 16) as isize + (mv.y as isize)) as usize)
    }
    // Intra complexity: sum of predicted-and-transformed SATD over all
    // 4x4 blocks of the luma plane.
    fn calculate_i_cplx(frm: &NAVideoBuffer<u8>) -> u32 {
        let (w, h) = frm.get_dimensions(0);
        let src = frm.get_data();
        let stride = frm.get_stride(0);
        let mut sum = 0;
        let mut offset = 0;
        for y in (0..h).step_by(4) {
            for x in (0..w).step_by(4) {
                sum += Self::satd_i(src, offset + x, stride, x > 0, y > 0);
            }
            offset += stride * 4;
        }
        sum
    }
    // Inter complexity: motion-compensated SATD summed over all 16x16
    // macroblocks.
    fn calculate_mv_diff(ref_frm: &NAVideoBuffer<u8>, cur_frm: &NAVideoBuffer<u8>) -> u32 {
        let (w, h) = ref_frm.get_dimensions(0);
        let mut sum = 0;
        for mb_y in 0..(h / 16) {
            for mb_x in 0..(w / 16) {
                sum += Self::satd_mb_diff(ref_frm, cur_frm, mb_x, mb_y);
            }
        }
        sum
    }
    // SATD between a macroblock of `cur_frm` and its best full-pel match
    // in `ref_frm`, accumulated over the sixteen 4x4 sub-blocks.
    fn satd_mb_diff(ref_frm: &NAVideoBuffer<u8>, cur_frm: &NAVideoBuffer<u8>, mb_x: usize, mb_y: usize) -> u32 {
        let mv = Self::search_mv(ref_frm, cur_frm, mb_x, mb_y);
        let mut sum = 0;
        let src0 = ref_frm.get_data();
        let src1 = cur_frm.get_data();
        let stride = ref_frm.get_stride(0);
        let (src_x, src_y) = Self::add_mv(mb_x, mb_y, mv);
        for y in (0..16).step_by(4) {
            for x in (0..16).step_by(4) {
                sum += Self::satd(&src0[src_x + x + (src_y + y) * stride..],
                                  &src1[mb_x * 16 + x + (mb_y * 16 + y) * stride..],
                                  stride);
            }
        }
        sum
    }
    // Simple full-pel diamond search at step sizes 4, 2 and 1, refining
    // until no candidate improves the score at that scale.
    fn search_mv(ref_frm: &NAVideoBuffer<u8>, cur_frm: &NAVideoBuffer<u8>, mb_x: usize, mb_y: usize) -> MV {
        let stride = ref_frm.get_stride(0);
        let (w, h) = ref_frm.get_dimensions(0);
        let (v_edge, h_edge) = (w - 16, h - 16);
        let ref_src = ref_frm.get_data();
        let cur_src = cur_frm.get_data();
        let cur_src = &cur_src[mb_x * 16 + mb_y * 16 * stride..];

        let mut best_mv = ZERO_MV;
        let mut best_dist = Self::sad(cur_src, ref_src, mb_x, mb_y, stride, best_mv);
        if best_dist == 0 {
            return best_mv;
        }

        for step in (0..=2).rev() {
            let mut changed = true;
            while changed {
                changed = false;
                // skip the centre point (index 0 of DIA_PATTERN)
                for &mv in DIA_PATTERN[1..].iter() {
                    let cand_mv = best_mv + mv.scale(1 << step);
                    let (cx, cy) = Self::add_mv(mb_x, mb_y, cand_mv);
                    if cx > v_edge || cy > h_edge {
                        continue;
                    }
                    let cand_dist = Self::sad(cur_src, ref_src, mb_x, mb_y, stride, cand_mv);
                    if cand_dist < best_dist {
                        best_dist = cand_dist;
                        best_mv = cand_mv;
                        if best_dist == 0 {
                            return best_mv;
                        }
                        changed = true;
                    }
                }
            }
        }
        best_mv
    }
    // Despite the name this sums SQUARED differences over a 16x16 block.
    fn sad(cur_src: &[u8], src: &[u8], mb_x: usize, mb_y: usize, stride: usize, mv: MV) -> u32 {
        let (src_x, src_y) = Self::add_mv(mb_x, mb_y, mv);
        let mut sum = 0;
        for (line1, line2) in cur_src.chunks(stride).zip(src[src_x + src_y * stride..].chunks(stride)).take(16) {
            sum += line1[..16].iter().zip(line2[..16].iter()).fold(0u32,
                |acc, (&a, &b)| acc + u32::from(a.max(b) - a.min(b)) * u32::from(a.max(b) - a.min(b)));
        }
        sum
    }
    // Intra SATD of one 4x4 block: predicts each pixel from its neighbours
    // (gradient prediction when both left and top rows exist, plain
    // left/top difference otherwise, or bias 128 with no neighbours),
    // then applies a 4x4 Hadamard transform and sums absolute values.
    fn satd_i(src: &[u8], mut offset: usize, stride: usize, has_left: bool, has_top: bool) -> u32 {
        let mut diffs = [0; 16];
        match (has_left, has_top) {
            (true, true) => {
                for row in diffs.chunks_exact_mut(4) {
                    let mut left = i16::from(src[offset - 1]);
                    let mut tl = i16::from(src[offset - stride - 1]);
                    for (x, dst) in row.iter_mut().enumerate() {
                        let cur = i16::from(src[offset + x]);
                        let top = i16::from(src[offset + x - stride]);

                        // gradient predictor: top + left + tl minus the
                        // extremes leaves the median of the three
                        *dst = cur - (top + left + tl - top.min(left).min(tl) - top.max(left).max(tl));

                        left = cur;
                        tl = top;
                    }

                    offset += stride;
                }
            },
            (true, false) => {
                for (dst, (left, cur)) in diffs.chunks_exact_mut(4).zip(
                        src[offset - 1..].chunks(stride).zip(src[offset..].chunks(stride))) {
                    for (dst, (&left, &cur)) in dst.iter_mut().zip(left.iter().zip(cur.iter())) {
                        *dst = i16::from(cur) - i16::from(left);
                    }
                }
            },
            (false, true) => {
                for (dst, (top, cur)) in diffs.chunks_exact_mut(4).zip(
                        src[offset - stride..].chunks(stride).zip(src[offset..].chunks(stride))) {
                    for (dst, (&top, &cur)) in dst.iter_mut().zip(top.iter().zip(cur.iter())) {
                        *dst = i16::from(cur) - i16::from(top);
                    }
                }
            },
            (false, false) => {
                for (dst, src) in diffs.chunks_exact_mut(4).zip(src[offset..].chunks(stride)) {
                    for (dst, &src) in dst.iter_mut().zip(src.iter()) {
                        *dst = i16::from(src) - 128;
                    }
                }
            },
        };
        // rows then columns of the separable 4x4 Hadamard transform
        for row in diffs.chunks_exact_mut(4) {
            hadamard!(row[0], row[1], row[2], row[3], row[0], row[1], row[2], row[3]);
        }
        for i in 0..4 {
            hadamard!(diffs[i], diffs[i + 4], diffs[i + 8], diffs[i + 12],
                      diffs[i], diffs[i + 4], diffs[i + 8], diffs[i + 12]);
        }
        diffs.iter().fold(0u32, |acc, x| acc + (x.unsigned_abs() as u32))
    }
    // SATD of the difference between two 4x4 blocks taken from planes with
    // the same stride.
    fn satd(src0: &[u8], src1: &[u8], stride: usize) -> u32 {
        let mut diffs = [0; 16];
        for (dst, (src0, src1)) in diffs.chunks_exact_mut(4).zip(
                src0.chunks(stride).zip(src1.chunks(stride))) {
            hadamard!(i16::from(src0[0]) - i16::from(src1[0]),
                      i16::from(src0[1]) - i16::from(src1[1]),
                      i16::from(src0[2]) - i16::from(src1[2]),
                      i16::from(src0[3]) - i16::from(src1[3]),
                      dst[0], dst[1], dst[2], dst[3]);
        }
        for i in 0..4 {
            hadamard!(diffs[i], diffs[i + 4], diffs[i + 8], diffs[i + 12],
                      diffs[i], diffs[i + 4], diffs[i + 8], diffs[i + 12]);
        }
        diffs.iter().fold(0u32, |acc, x| acc + (x.unsigned_abs() as u32))
    }
    // 2x downscale with a 2x2 box filter (rounded average of four pixels),
    // applied independently to all three planes.
    fn downscale(dst: &mut NAVideoBuffer<u8>, src: &NAVideoBuffer<u8>) {
        let dst = NASimpleVideoFrame::from_video_buf(dst).unwrap();
        let sdata = src.get_data();
        for plane in 0..3 {
            let cur_w = dst.width[plane];
            let cur_h = dst.height[plane];
            let doff = dst.offset[plane];
            let soff = src.get_offset(plane);
            let dstride = dst.stride[plane];
            let sstride = src.get_stride(plane);
            for (dline, sstrip) in dst.data[doff..].chunks_exact_mut(dstride).zip(
                    sdata[soff..].chunks_exact(sstride * 2)).take(cur_h) {
                let (line0, line1) = sstrip.split_at(sstride);
                for (dst, (src0, src1)) in dline.iter_mut().zip(
                        line0.chunks_exact(2).zip(line1.chunks_exact(2))).take(cur_w) {
                    *dst = ((u16::from(src0[0]) + u16::from(src0[1]) +
                             u16::from(src1[0]) + u16::from(src1[1]) + 2) >> 2) as u8;
                }
            }
        }
    }
}
917 | } |