]>
Commit | Line | Data |
---|---|---|
1 | use nihav_core::frame::*; | |
2 | use nihav_codec_support::codecs::{MV, ZERO_MV}; | |
3 | use std::str::FromStr; | |
4 | use super::dsp::{RefMBData, luma_mc, chroma_mc}; | |
5 | ||
/// Available motion vector search strategies.
#[derive(Clone,Copy,PartialEq)]
pub enum MVSearchMode {
    /// No real search — always reports the zero MV.
    Dummy,
    /// Small diamond pattern search.
    Diamond,
    /// Hexagon pattern search (default).
    Hexagon,
    /// Uneven multi-hexagon search (multi-stage, most exhaustive here).
    UMH,
}
13 | ||
impl MVSearchMode {
    /// Lists the mode names intended for user-facing options
    /// (note: "dummy" parses via `FromStr` but is not listed here).
    pub const fn get_possible_modes() -> &'static [&'static str] {
        &["diamond", "hexagon", "umh"]
    }
    /// Instantiates the searcher implementing this mode.
    fn create(self) -> Box<dyn MVSearch+Send> {
        match self {
            MVSearchMode::Dummy   => Box::new(DummySearcher{}),
            MVSearchMode::Diamond => Box::new(DiaSearch::new()),
            MVSearchMode::Hexagon => Box::new(HexSearch::new()),
            MVSearchMode::UMH     => Box::new(UnevenHexSearch::new()),
        }
    }
}
27 | ||
impl Default for MVSearchMode {
    /// Hexagon search is the default strategy.
    fn default() -> Self { MVSearchMode::Hexagon }
}
31 | ||
32 | impl std::fmt::Display for MVSearchMode { | |
33 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { | |
34 | match *self { | |
35 | MVSearchMode::Diamond => write!(f, "diamond"), | |
36 | MVSearchMode::Hexagon => write!(f, "hexagon"), | |
37 | MVSearchMode::UMH => write!(f, "umh"), | |
38 | MVSearchMode::Dummy => write!(f, "dummy"), | |
39 | } | |
40 | } | |
41 | } | |
42 | ||
43 | impl FromStr for MVSearchMode { | |
44 | type Err = (); | |
45 | fn from_str(s: &str) -> Result<Self, Self::Err> { | |
46 | match s { | |
47 | "diamond" => Ok(MVSearchMode::Diamond), | |
48 | "hexagon" => Ok(MVSearchMode::Hexagon), | |
49 | "umh" => Ok(MVSearchMode::UMH), | |
50 | "dummy" => Ok(MVSearchMode::Dummy), | |
51 | _ => Err(()), | |
52 | } | |
53 | } | |
54 | } | |
55 | ||
/// Sentinel for "distortion not computed yet" / worst possible score.
const MAX_DIST: u32 = std::u32::MAX;
/// Distortion threshold below which a macroblock search stops early.
const DIST_THRESH: u32 = 256;
58 | ||
/// Conversion from full-pixel units into the quarter-pel units used by
/// the motion compensation routines.
trait FromPixels {
    fn from_pixels(self) -> Self;
}
62 | ||
impl FromPixels for MV {
    fn from_pixels(self) -> MV {
        // one pixel == four quarter-pel units
        MV { x: self.x * 4, y: self.y * 4 }
    }
}
68 | ||
/// Diamond search pattern: the centre plus eight surrounding points,
/// also reused as the refinement pattern of the UMH search.
const DIA_PATTERN: [MV; 9] = [
    ZERO_MV,
    MV {x: -2, y: 0},
    MV {x: -1, y: 1},
    MV {x: 0, y: 2},
    MV {x: 1, y: 1},
    MV {x: 2, y: 0},
    MV {x: 1, y: -1},
    MV {x: 0, y: -2},
    MV {x: -1, y: -1}
];
80 | ||
/// Hexagon search pattern: the centre plus six surrounding points.
const HEX_PATTERN: [MV; 7] = [
    ZERO_MV,
    MV {x: -2, y: 0},
    MV {x: -1, y: 2},
    MV {x: 1, y: 2},
    MV {x: 2, y: 0},
    MV {x: 1, y: -2},
    MV {x: -1, y: -2}
];
90 | ||
/// One-step refinement offsets: the four direct neighbours of a point.
const REFINEMENT: [MV; 4] = [
    MV {x: -1, y: 0},
    MV {x: 0, y: 1},
    MV {x: 1, y: 0},
    MV {x: 0, y: -1}
];
97 | ||
// Shared skeleton for the pattern-based searchers (diamond / hexagon):
// an optional full-pel stage followed by a sub-pel stage, each iterating
// the pattern until no further improvement, then a final one-step
// refinement. The short form starts from the zero MV with no known
// distortion; the long form can resume from a candidate MV and skip the
// full-pel stage.
//
// `$self` must provide the `point`/`dist`/`steps` arrays and the
// `reset()`/`set_new_point()`/`update()` helpers generated by
// `pattern_search!`; `$sad_func` is the distortion method invoked on
// `$mv_est`. Returns `(best_mv, best_dist)` with the MV in quarter-pel units.
macro_rules! search_template {
    ($self: expr, $mv_est: expr, $cur_blk: expr, $mb_x: expr, $mb_y: expr, $sad_func: ident, $threshold: expr) => ({
        search_template!($self, $mv_est, $cur_blk, $mb_x, $mb_y, $sad_func, $threshold, ZERO_MV, MAX_DIST, true)
    });
    ($self: expr, $mv_est: expr, $cur_blk: expr, $mb_x: expr, $mb_y: expr, $sad_func: ident, $threshold: expr, $start_mv: expr, $best_dist: expr, $fullpel_stage: expr) => ({
        let mut best_dist = $best_dist;
        let mut best_mv = $start_mv;

        let mut min_dist;
        let mut min_idx;

        if $fullpel_stage {
            $self.reset();
            loop {
                // compute distortion for pattern points not carried over
                // from the previous pattern position
                let mut cur_best_dist = best_dist;
                for (dist, &point) in $self.dist.iter_mut().zip($self.point.iter()) {
                    if *dist == MAX_DIST {
                        *dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, point.from_pixels(), cur_best_dist);
                        cur_best_dist = cur_best_dist.min(*dist);
                        if *dist <= $threshold {
                            break;
                        }
                    }
                }
                // locate the best pattern point
                min_dist = $self.dist[0];
                min_idx = 0;
                for (i, &dist) in $self.dist.iter().enumerate().skip(1) {
                    if dist < min_dist {
                        min_dist = dist;
                        min_idx = i;
                        if dist <= $threshold {
                            break;
                        }
                    }
                }
                // stop when good enough, when the centre stays best, when no
                // improvement was made, or when the MV range would be exceeded
                if min_dist <= $threshold || min_idx == 0 || best_dist == min_dist || $self.point[min_idx].x.abs() >= $mv_est.mv_range || $self.point[min_idx].y.abs() >= $mv_est.mv_range {
                    break;
                }
                best_dist = min_dist;
                // re-centre the pattern on the new best point
                $self.update($self.steps[min_idx]);
            }
            best_dist = min_dist;
            best_mv = $self.point[min_idx];
            if best_dist <= $threshold {
                return (best_mv.from_pixels(), best_dist);
            }
            // one-step full-pel refinement around the winner
            for &step in REFINEMENT.iter() {
                let mv = best_mv + step;
                let dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, mv.from_pixels(), MAX_DIST);
                if best_dist > dist {
                    best_dist = dist;
                    best_mv = mv;
                }
            }
            // from here on the MV is kept in quarter-pel units
            best_mv = best_mv.from_pixels();
            if best_dist <= $threshold {
                return (best_mv, best_dist);
            }
        }

        // subpel refinement
        $self.set_new_point(best_mv, best_dist);
        loop {
            let mut cur_best_dist = best_dist;
            for (dist, &point) in $self.dist.iter_mut().zip($self.point.iter()) {
                if *dist == MAX_DIST {
                    *dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, point, cur_best_dist);
                    cur_best_dist = cur_best_dist.min(*dist);
                    if *dist <= $threshold {
                        break;
                    }
                }
            }
            min_dist = $self.dist[0];
            min_idx = 0;
            for (i, &dist) in $self.dist.iter().enumerate().skip(1) {
                if dist < min_dist {
                    min_dist = dist;
                    min_idx = i;
                    if dist <= $threshold {
                        break;
                    }
                }
            }
            // points are now in quarter-pel units, hence the scaled range check
            if min_dist <= $threshold || min_idx == 0 || best_dist == min_dist || $self.point[min_idx].x.abs() >= $mv_est.mv_range * 8 || $self.point[min_idx].y.abs() >= $mv_est.mv_range * 8 {
                break;
            }
            best_dist = min_dist;
            $self.update($self.steps[min_idx]);
        }
        best_dist = min_dist;
        best_mv = $self.point[min_idx];
        if best_dist <= $threshold {
            return (best_mv, best_dist);
        }
        // final one-step sub-pel refinement
        for &step in REFINEMENT.iter() {
            let mv = best_mv + step;
            let dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, mv, MAX_DIST);
            if best_dist > dist {
                best_dist = dist;
                best_mv = mv;
            }
        }
        (best_mv, best_dist)
    });
}
204 | ||
// Generates a pattern-search type (`DiaSearch`, `HexSearch`): a struct
// holding the current pattern points, their cached distortions and the
// static pattern offsets, plus an `MVSearch` implementation built on
// `search_template!`.
macro_rules! pattern_search {
    ($struct_name: ident, $patterns: expr) => {
        pub struct $struct_name {
            // current candidate points (pattern shifted to the search centre)
            point: [MV; $patterns.len()],
            // distortion per point; MAX_DIST marks "not computed yet"
            dist: [u32; $patterns.len()],
            // the immutable pattern offsets used to shift the points
            steps: &'static [MV; $patterns.len()],
        }

        impl $struct_name {
            pub fn new() -> Self {
                Self {
                    point: $patterns,
                    dist: [MAX_DIST; $patterns.len()],
                    steps: &$patterns,
                }
            }
            // Restarts the search from the pattern centred at zero.
            fn reset(&mut self) {
                self.point = $patterns;
                self.dist = [MAX_DIST; $patterns.len()];
            }
            // Centres the pattern on `start`, keeping only the known
            // distortion of the centre point.
            fn set_new_point(&mut self, start: MV, dist: u32) {
                for (dst, &src) in self.point.iter_mut().zip(self.steps.iter()) {
                    *dst = src + start;
                }
                self.dist = [MAX_DIST; $patterns.len()];
                self.dist[0] = dist;
            }
            // Shifts the pattern by `step`, carrying over the distortions
            // of points shared with the previous position so they are not
            // recomputed.
            fn update(&mut self, step: MV) {
                let mut new_point = self.point;
                let mut new_dist = [MAX_DIST; $patterns.len()];

                for point in new_point.iter_mut() {
                    *point += step;
                }

                for (new_point, new_dist) in new_point.iter_mut().zip(new_dist.iter_mut()) {
                    for (&old_point, &old_dist) in self.point.iter().zip(self.dist.iter()) {
                        if *new_point == old_point {
                            *new_dist = old_dist;
                            break;
                        }
                    }
                }
                self.point = new_point;
                self.dist = new_dist;
            }
        }

        impl MVSearch for $struct_name {
            fn search_mb(&mut self, mv_est: &mut MVEstimator, cur_mb: &RefMBData, mb_x: usize, mb_y: usize, _cand_mvs: &[MV]) -> (MV, u32) {
                search_template!(self, mv_est, cur_mb, mb_x, mb_y, sad_mb, DIST_THRESH)
            }
            fn search_blk8(&mut self, mv_est: &mut MVEstimator, ref_blk: &RefMBData, xpos: usize, ypos: usize, _cand_mvs: &[MV]) -> (MV, u32) {
                // quarter of the MB threshold: an 8x8 block has a quarter of the pixels
                search_template!(self, mv_est, ref_blk, xpos, ypos, sad_blk8, DIST_THRESH / 4)
            }
        }
    }
}
263 | ||
// Instantiate the pattern searchers for the diamond and hexagon shapes.
pattern_search!(DiaSearch, DIA_PATTERN);
pattern_search!(HexSearch, HEX_PATTERN);
266 | ||
267 | const LARGE_HEX_PATTERN: [MV; 16] = [ | |
268 | MV { x: -4, y: 0 }, | |
269 | MV { x: -4, y: 1 }, | |
270 | MV { x: -4, y: 2 }, | |
271 | MV { x: -2, y: 3 }, | |
272 | MV { x: 0, y: 4 }, | |
273 | MV { x: 2, y: 3 }, | |
274 | MV { x: 4, y: 2 }, | |
275 | MV { x: 4, y: 1 }, | |
276 | MV { x: 4, y: 0 }, | |
277 | MV { x: 4, y: -1 }, | |
278 | MV { x: 4, y: -2 }, | |
279 | MV { x: -2, y: -3 }, | |
280 | MV { x: 0, y: -4 }, | |
281 | MV { x: -2, y: -3 }, | |
282 | MV { x: -4, y: -2 }, | |
283 | MV { x: -4, y: -1 } | |
284 | ]; | |
285 | ||
/// Cross pattern that is twice as wide as it is tall, used for the
/// unsymmetrical cross step of the UMH search.
const UNSYMM_CROSS: [MV; 4] = [
    MV { x: -2, y: 0 },
    MV { x: 0, y: 1 },
    MV { x: 2, y: 0 },
    MV { x: 0, y: -1 }
];
292 | ||
/// Small fixed-capacity set (up to 16 entries) that keeps insertion order
/// and silently ignores duplicates and inserts beyond capacity.
#[derive(Default)]
struct UniqueSet<T:Copy+Default> {
    list: [T; 16],
    count: usize,
}

impl<T:Copy+Default+PartialEq> UniqueSet<T> {
    /// Creates an empty set.
    fn new() -> Self { Self::default() }
    /// Drops all stored values.
    fn clear(&mut self) { self.count = 0; }
    /// Returns the stored values in insertion order.
    fn get_list(&self) -> &[T] { &self.list[..self.count] }
    /// Inserts `val` unless it is already present or the set is full.
    fn add(&mut self, val: T) {
        if self.list[..self.count].iter().any(|&el| el == val) {
            return;
        }
        if let Some(slot) = self.list.get_mut(self.count) {
            *slot = val;
            self.count += 1;
        }
    }
}
310 | ||
/// Helper operations on motion vectors used by the UMH search.
trait MVOps {
    /// Multiplies both components by `scale`.
    fn scale(self, scale: i16) -> Self;
    /// Tells whether both components lie within `[-range, range]`.
    fn is_in_range(self, range: i16) -> bool;
}
315 | ||
impl MVOps for MV {
    fn scale(self, scale: i16) -> MV {
        MV { x: self.x * scale, y: self.y * scale }
    }
    fn is_in_range(self, range: i16) -> bool {
        self.x.abs() <= range && self.y.abs() <= range
    }
}
324 | ||
// Evaluates all points of `$pattern` (scaled by `$scale`) around `$start`
// and returns `(best_mv, best_dist, improved)` where `improved` reports
// whether any point beat the incoming distortion. Points outside the MV
// range are skipped and the scan stops early once the distortion drops
// below `$dist_thr`.
macro_rules! single_search_step {
    ($start:expr, $best_dist:expr, $mv_est:expr, $sad_func:ident, $ref_blk:expr, $xpos:expr, $ypos:expr, $pattern:expr, $scale:expr, $dist_thr:expr) => {{
        let mut best_mv = $start;
        let mut best_dist = $best_dist;
        for point in $pattern.iter() {
            let mv = point.scale($scale) + $start;
            // MVs are quarter-pel while mv_range is in full pixels
            if !mv.is_in_range($mv_est.mv_range * 4) {
                continue;
            }
            let dist = $mv_est.$sad_func($ref_blk, $xpos, $ypos, mv, best_dist);
            if dist < best_dist {
                best_mv = mv;
                best_dist = dist;
                if best_dist < $dist_thr {
                    break;
                }
            }
        }
        (best_mv, best_dist, best_mv != $start)
    }}
}
346 | ||
/// State of the uneven multi-hexagon (UMH) motion search.
struct UnevenHexSearch {
    // scratch set used to deduplicate the candidate MVs
    mv_list: UniqueSet<MV>,
}
350 | ||
351 | impl UnevenHexSearch { | |
352 | fn new() -> Self { | |
353 | Self { | |
354 | mv_list: UniqueSet::new(), | |
355 | } | |
356 | } | |
357 | fn get_cand_mv(&mut self, cand_mvs: &[MV]) -> MV { | |
358 | self.mv_list.clear(); | |
359 | for &mv in cand_mvs.iter() { | |
360 | self.mv_list.add(mv); | |
361 | } | |
362 | match self.mv_list.count { | |
363 | 1 => self.mv_list.list[0], | |
364 | 3 => MV::pred(self.mv_list.list[0], self.mv_list.list[1], self.mv_list.list[2]), | |
365 | _ => { | |
366 | let sum = self.mv_list.get_list().iter().fold((0i32, 0i32), | |
367 | |acc, mv| (acc.0 + i32::from(mv.x), acc.1 + i32::from(mv.y))); | |
368 | MV {x: (sum.0 / (self.mv_list.count as i32)) as i16, | |
369 | y: (sum.1 / (self.mv_list.count as i32)) as i16} | |
370 | }, | |
371 | } | |
372 | } | |
373 | } | |
374 | ||
// Core of the uneven multi-hexagon search shared between the 16x16 and
// 8x8 entry points. Starting from the predicted candidate MV it performs:
//   1. a small diamond refinement,
//   2. an iterated unsymmetrical cross search,
//   3. a multi-scale large-hexagon grid search (scale 4 down to 1),
//   4. a final hexagon step plus a diamond refinement.
// Every stage returns early once the distortion drops below `$cutoff`.
macro_rules! umh_search_template {
    ($cand_mv:expr, $cutoff:expr, $mv_est:expr, $sad_func:ident, $ref_blk:expr, $xpos:expr, $ypos:expr) => {{
        let cand_mv = $cand_mv;
        let best_dist = $mv_est.$sad_func($ref_blk, $xpos, $ypos, cand_mv, MAX_DIST);
        if best_dist < $cutoff {
            return (cand_mv, best_dist);
        }

        // step 1 - small refinement search
        let (mut cand_mv, mut best_dist, _) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, DIA_PATTERN, 1, $cutoff);
        if best_dist < $cutoff {
            return (cand_mv, best_dist);
        }

        // step 2 - unsymmetrical cross search
        loop {
            let (mv, dist, changed) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, UNSYMM_CROSS, 4, $cutoff);
            if !changed {
                break;
            }
            cand_mv = mv;
            best_dist = dist;
            if best_dist < $cutoff {
                return (mv, dist);
            }
        }

        // step 3 - multi-hexagon grid search
        let mut scale = 4;
        while scale > 0 {
            let (mv, dist, changed) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, LARGE_HEX_PATTERN, scale, $cutoff);
            if !changed {
                break;
            }
            cand_mv = mv;
            best_dist = dist;
            if best_dist < $cutoff {
                return (mv, dist);
            }
            scale >>= 1;
        }
        // step 4 - final hexagon search
        let (cand_mv, best_dist, _) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, HEX_PATTERN, 1, $cutoff);
        if best_dist > $cutoff {
            // still not good enough - one more diamond refinement
            let (mv, dist, _) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, DIA_PATTERN, 1, $cutoff);
            (mv, dist)
        } else {
            (cand_mv, best_dist)
        }
    }}
}
426 | ||
impl MVSearch for UnevenHexSearch {
    fn search_mb(&mut self, mv_est: &mut MVEstimator, cur_mb: &RefMBData, mb_x: usize, mb_y: usize, cand_mvs: &[MV]) -> (MV, u32) {
        // start from the predictor derived from the candidate list
        let cand_mv = self.get_cand_mv(cand_mvs);
        let cutoff = mv_est.cutoff_thr;
        umh_search_template!(cand_mv, cutoff, mv_est, sad_mb, cur_mb, mb_x, mb_y)
    }
    fn search_blk8(&mut self, mv_est: &mut MVEstimator, ref_blk: &RefMBData, xpos: usize, ypos: usize, cand_mvs: &[MV]) -> (MV, u32) {
        let cand_mv = self.get_cand_mv(cand_mvs);
        // quarter of the MB cutoff: an 8x8 block has a quarter of the pixels
        let cutoff = mv_est.cutoff_thr / 4;
        umh_search_template!(cand_mv, cutoff, mv_est, sad_blk8, ref_blk, xpos, ypos)
    }
}
439 | ||
/// Distortion estimator binding the reference picture to the search limits.
struct MVEstimator<'a> {
    /// Reference frame used for motion compensation.
    pic: &'a NAVideoBuffer<u8>,
    /// Maximum MV component magnitude, in full pixels.
    mv_range: i16,
    /// Early-termination threshold handed to the searchers.
    cutoff_thr: u32,
}
445 | ||
// Despite the name this computes the sum of *squared* differences between
// two pixel slices — the distortion metric used throughout this module.
macro_rules! sad {
    ($src1:expr, $src2:expr) => {
        $src1.iter().zip($src2.iter()).fold(0u32, |acc, (&a, &b)|
            acc + (((i32::from(a) - i32::from(b)) * (i32::from(a) - i32::from(b))) as u32))
    }
}
452 | ||
impl<'a> MVEstimator<'a> {
    /// Motion-compensates a full 16x16 macroblock (luma + both chroma
    /// planes) at `mv` and returns its squared-difference distortion
    /// against `ref_mb`, bailing out as soon as it exceeds `cur_best_dist`.
    fn sad_mb(&self, ref_mb: &RefMBData, mb_x: usize, mb_y: usize, mv: MV, cur_best_dist: u32) -> u32 {
        let mut dst = RefMBData::new();
        luma_mc(&mut dst.y, 16, self.pic, mb_x * 16, mb_y * 16, mv, true);

        let mut dist = 0;
        for (dline, sline) in dst.y.chunks(16).zip(ref_mb.y.chunks(16)) {
            dist += sad!(dline, sline);
            // early exit once the running distortion beats the best known one
            if dist > cur_best_dist {
                return dist;
            }
        }
        // chroma is compensated lazily — only when luma did not already lose
        chroma_mc(&mut dst.u, 8, self.pic, mb_x * 8, mb_y * 8, 1, mv, true);
        dist += sad!(dst.u, ref_mb.u);
        if dist > cur_best_dist {
            return dist;
        }
        chroma_mc(&mut dst.v, 8, self.pic, mb_x * 8, mb_y * 8, 2, mv, true);
        dist += sad!(dst.v, ref_mb.v);

        dist
    }
    /// Same as `sad_mb()` but for a single 8x8 luma block (with its 4x4
    /// chroma blocks); `xpos`/`ypos` are in pixels and select which
    /// quadrant of `ref_mb` is compared.
    fn sad_blk8(&self, ref_mb: &RefMBData, xpos: usize, ypos: usize, mv: MV, cur_best_dist: u32) -> u32 {
        let mut cur_y = [0; 64];
        let mut cur_u = [0; 16];
        let mut cur_v = [0; 16];

        let mut dist = 0;

        // offset of this 8x8 block inside the 16x16 reference macroblock
        let y_off = (xpos & 8) + (ypos & 8) * 16;
        luma_mc(&mut cur_y, 8, self.pic, xpos, ypos, mv, false);
        for (dline, sline) in cur_y.chunks(8).zip(ref_mb.y[y_off..].chunks(16)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }

        // matching 4x4 offset inside the 8x8 chroma planes
        let c_off = (xpos & 8) / 2 + (ypos & 8) * 4;
        chroma_mc(&mut cur_u, 4, self.pic, xpos / 2, ypos / 2, 1, mv, false);
        for (dline, sline) in cur_u.chunks(4).zip(ref_mb.u[c_off..].chunks(8)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }
        chroma_mc(&mut cur_v, 4, self.pic, xpos / 2, ypos / 2, 2, mv, false);
        for (dline, sline) in cur_v.chunks(4).zip(ref_mb.v[c_off..].chunks(8)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }

        dist
    }
}
510 | ||
/// Common interface of all motion searchers; both methods return the best
/// MV (in quarter-pel units) together with its distortion.
trait MVSearch {
    /// Searches the best MV for a whole 16x16 macroblock.
    fn search_mb(&mut self, mv_est: &mut MVEstimator, ref_mb: &RefMBData, mb_x: usize, mb_y: usize, cand_mvs: &[MV]) -> (MV, u32);
    /// Searches the best MV for one 8x8 block at the given pixel position.
    fn search_blk8(&mut self, mv_est: &mut MVEstimator, ref_blk: &RefMBData, xpos: usize, ypos: usize, cand_mvs: &[MV]) -> (MV, u32);
}
515 | ||
/// Fallback searcher that performs no search at all.
struct DummySearcher {}

impl MVSearch for DummySearcher {
    fn search_mb(&mut self, _mv_est: &mut MVEstimator, _ref_mb: &RefMBData, _mb_x: usize, _mb_y: usize, _cand_mvs: &[MV]) -> (MV, u32) {
        // zero MV with a huge (but not maximal) distortion
        (ZERO_MV, std::u32::MAX / 2)
    }
    fn search_blk8(&mut self, _mv_est: &mut MVEstimator, _ref_mb: &RefMBData, _xpos: usize, _ypos: usize, _cand_mvs: &[MV]) -> (MV, u32) {
        (ZERO_MV, std::u32::MAX / 2)
    }
}
526 | ||
/// High-level motion estimator driving the selected search strategy.
pub struct MotionEstimator {
    /// Maximum MV component magnitude, in full pixels.
    pub range: i16,
    /// Distortion cutoff handed to the searchers.
    pub thresh: u32,
    // currently selected mode and the searcher instantiated for it
    mode: MVSearchMode,
    srch: Box<dyn MVSearch+Send>,
}
533 | ||
534 | impl MotionEstimator { | |
535 | pub fn new() -> Self { | |
536 | let mode = MVSearchMode::default(); | |
537 | Self { | |
538 | range: 64, | |
539 | thresh: 32, | |
540 | mode, | |
541 | srch: mode.create(), | |
542 | } | |
543 | } | |
544 | pub fn get_mode(&self) -> MVSearchMode { self.mode } | |
545 | pub fn set_mode(&mut self, new_mode: MVSearchMode) { | |
546 | if self.mode != new_mode { | |
547 | self.mode = new_mode; | |
548 | self.srch = self.mode.create(); | |
549 | } | |
550 | } | |
551 | pub fn search_mb_p(&mut self, pic: &NAVideoBuffer<u8>, refmb: &RefMBData, mb_x: usize, mb_y: usize, cand_mvs: &[MV]) -> (MV, u32) { | |
552 | let mut mv_est = MVEstimator { | |
553 | mv_range: self.range, | |
554 | cutoff_thr: self.thresh, | |
555 | pic, | |
556 | }; | |
557 | self.srch.search_mb(&mut mv_est, refmb, mb_x, mb_y, cand_mvs) | |
558 | } | |
559 | pub fn search_blk8(&mut self, pic: &NAVideoBuffer<u8>, refmb: &RefMBData, xoff: usize, yoff: usize, cand_mvs: &[MV]) -> (MV, u32) { | |
560 | let mut mv_est = MVEstimator { | |
561 | mv_range: self.range, | |
562 | cutoff_thr: self.thresh, | |
563 | pic, | |
564 | }; | |
565 | self.srch.search_blk8(&mut mv_est, refmb, xoff, yoff, cand_mvs) | |
566 | } | |
567 | } | |
568 | ||
/// Bidirectional (B-frame) motion vector pair refiner.
pub struct SearchB<'a> {
    // forward (previous) and backward (next) reference frames
    ref_p: &'a NAVideoBuffer<u8>,
    ref_n: &'a NAVideoBuffer<u8>,
    // macroblock position in pixels
    xpos: usize,
    ypos: usize,
    // weights used when averaging the two predictions
    ratios: [u32; 2],
    // scratch buffers: forward prediction, backward prediction, weighted mix
    tmp1: RefMBData,
    tmp2: RefMBData,
    pred_blk: RefMBData,
}
579 | ||
impl<'a> SearchB<'a> {
    /// Creates a refiner for the macroblock at (`mb_x`, `mb_y`) using the
    /// given forward/backward references and averaging weights.
    pub fn new(ref_p: &'a NAVideoBuffer<u8>, ref_n: &'a NAVideoBuffer<u8>, mb_x: usize, mb_y: usize, ratios: [u32; 2]) -> Self {
        Self {
            ref_p, ref_n,
            xpos: mb_x * 16,
            ypos: mb_y * 16,
            ratios,
            tmp1: RefMBData::new(),
            tmp2: RefMBData::new(),
            pred_blk: RefMBData::new(),
        }
    }
    /// Jointly refines the forward/backward MV pair: first an iterated
    /// diamond search over both MVs simultaneously (offsets in full
    /// pixels), then a final sub-pel refinement step for each.
    pub fn search_mb(&mut self, ref_mb: &RefMBData, cand_mvs: [MV; 2]) -> (MV, MV) {
        let mut best_cand = cand_mvs;
        let mut best_dist = self.interp_b_dist(ref_mb, best_cand, MAX_DIST);

        loop {
            let mut improved = false;
            // try every combination of diamond offsets for the two MVs
            for &fmv_add in DIA_PATTERN.iter() {
                for &bmv_add in DIA_PATTERN.iter() {
                    let cand = [best_cand[0] + fmv_add.from_pixels(),
                                best_cand[1] + bmv_add.from_pixels()];
                    let dist = self.interp_b_dist(ref_mb, cand, best_dist);
                    if dist < best_dist {
                        best_dist = dist;
                        best_cand = cand;
                        improved = true;
                    }
                }
            }
            if !improved {
                break;
            }
        }

        // one-step (quarter-pel) refinement of the winning pair
        for &fmv_add in REFINEMENT.iter() {
            for &bmv_add in REFINEMENT.iter() {
                let cand = [best_cand[0] + fmv_add, best_cand[1] + bmv_add];
                let dist = self.interp_b_dist(ref_mb, cand, best_dist);
                if dist < best_dist {
                    best_dist = dist;
                    best_cand = cand;
                }
            }
        }

        (best_cand[0], best_cand[1])
    }
    /// Builds the weighted bi-directional prediction for the MV pair and
    /// returns its squared-difference distortion against `ref_mb`,
    /// stopping early once it exceeds `cur_best_dist`.
    fn interp_b_dist(&mut self, ref_mb: &RefMBData, cand_mv: [MV; 2], cur_best_dist: u32) -> u32 {
        let [fmv, bmv] = cand_mv;
        luma_mc(&mut self.tmp1.y, 16, self.ref_p, self.xpos, self.ypos, fmv, true);
        chroma_mc(&mut self.tmp1.u, 8, self.ref_p, self.xpos / 2, self.ypos / 2, 1, fmv, true);
        chroma_mc(&mut self.tmp1.v, 8, self.ref_p, self.xpos / 2, self.ypos / 2, 2, fmv, true);
        luma_mc(&mut self.tmp2.y, 16, self.ref_n, self.xpos, self.ypos, bmv, true);
        chroma_mc(&mut self.tmp2.u, 8, self.ref_n, self.xpos / 2, self.ypos / 2, 1, bmv, true);
        chroma_mc(&mut self.tmp2.v, 8, self.ref_n, self.xpos / 2, self.ypos / 2, 2, bmv, true);
        // mix the two predictions with the configured weights
        self.pred_blk.avg(&self.tmp1, self.ratios[0], &self.tmp2, self.ratios[1]);

        let mut dist = 0;
        for (dline, sline) in self.pred_blk.y.chunks(16).zip(ref_mb.y.chunks(16)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }
        dist += sad!(self.pred_blk.u, ref_mb.u);
        if dist > cur_best_dist {
            return dist;
        }
        dist += sad!(self.pred_blk.v, ref_mb.v);

        dist
    }
}
654 | ||
// 4-point butterfly used to build the 4x4 Hadamard transform below.
// Sources are read into temporaries first, so the destination expressions
// may alias the sources (in-place transform).
macro_rules! hadamard {
    ($s0:expr, $s1:expr, $s2:expr, $s3:expr, $d0:expr, $d1:expr, $d2:expr, $d3:expr) => {
        let t0 = $s0 + $s1;
        let t1 = $s0 - $s1;
        let t2 = $s2 + $s3;
        let t3 = $s2 - $s3;
        $d0 = t0 + t2;
        $d2 = t0 - t2;
        $d1 = t1 + t3;
        $d3 = t1 - t3;
    }
}
667 | ||
/// Frame complexity estimator working on half-resolution copies of the
/// input frames; used for frame-type (I/P/B) decisions.
pub struct FrameComplexityEstimate {
    // downscaled reference, current and next frames
    ref_frm: NAVideoBufferRef<u8>,
    cur_frm: NAVideoBufferRef<u8>,
    nxt_frm: NAVideoBufferRef<u8>,
    // full-resolution dimensions the buffers were allocated for
    width: usize,
    height: usize,
}
675 | ||
impl FrameComplexityEstimate {
    /// Creates the estimator with small placeholder buffers; call
    /// `resize()` before use to allocate buffers for the real dimensions.
    pub fn new() -> Self {
        let vinfo = NAVideoInfo::new(24, 24, false, YUV420_FORMAT);
        let vt = alloc_video_buffer(vinfo, 4).unwrap();
        let buf = vt.get_vbuf().unwrap();
        Self {
            ref_frm: buf.clone(),
            cur_frm: buf.clone(),
            nxt_frm: buf,
            width: 0,
            height: 0,
        }
    }
    /// (Re)allocates the half-resolution work buffers when the frame
    /// dimensions change; the reference frame is filled with mid-grey.
    pub fn resize(&mut self, width: usize, height: usize) {
        if width != self.width || height != self.height {
            self.width = width;
            self.height = height;

            let vinfo = NAVideoInfo::new(self.width / 2, self.height / 2, false, YUV420_FORMAT);
            let vt = alloc_video_buffer(vinfo, 4).unwrap();
            self.ref_frm = vt.get_vbuf().unwrap();
            let frm = self.ref_frm.get_data_mut().unwrap();
            // neutral grey so the first P-frame estimate is well-defined
            for el in frm.iter_mut() {
                *el = 0x80;
            }
            let vt = alloc_video_buffer(vinfo, 4).unwrap();
            self.cur_frm = vt.get_vbuf().unwrap();
            let vt = alloc_video_buffer(vinfo, 4).unwrap();
            self.nxt_frm = vt.get_vbuf().unwrap();
        }
    }
    /// Stores a downscaled copy of `frm` as the current frame.
    pub fn set_current(&mut self, frm: &NAVideoBuffer<u8>) {
        Self::downscale(&mut self.cur_frm, frm);
    }
    /// Returns the complexity estimate of the current frame for the given
    /// frame type (0 for types other than I and P).
    pub fn get_complexity(&self, ftype: FrameType) -> u32 {
        match ftype {
            FrameType::I => Self::calculate_i_cplx(&self.cur_frm),
            FrameType::P => Self::calculate_mv_diff(&self.ref_frm, &self.cur_frm),
            _ => 0,
        }
    }
    /// Decides whether `frm1` should be coded as a B-frame given the next
    /// frame `frm2` (both are downscaled into the internal buffers).
    pub fn decide_b_frame(&mut self, frm1: &NAVideoBuffer<u8>, frm2: &NAVideoBuffer<u8>) -> bool {
        Self::downscale(&mut self.cur_frm, frm1);
        Self::downscale(&mut self.nxt_frm, frm2);
        let diff_ref_cur = Self::calculate_mv_diff(&self.ref_frm, &self.cur_frm);
        let diff_cur_nxt = Self::calculate_mv_diff(&self.cur_frm, &self.nxt_frm);

        // simple rule - if complexity ref->cur and cur->next is about the same this should be a B-frame
        let ddiff = diff_ref_cur.max(diff_cur_nxt) - diff_ref_cur.min(diff_cur_nxt);
        if ddiff < 256 {
            true
        } else {
            // otherwise require the relative difference to be small
            // (less than 1/8th of the smaller complexity)
            let mut order = 0;
            while (ddiff << order) < diff_ref_cur.min(diff_cur_nxt) {
                order += 1;
            }
            order > 2
        }
    }
    /// Promotes the current frame to be the new reference.
    pub fn update_ref(&mut self) {
        std::mem::swap(&mut self.ref_frm, &mut self.cur_frm);
    }

    // Translates a macroblock position plus MV into pixel coordinates.
    fn add_mv(mb_x: usize, mb_y: usize, mv: MV) -> (usize, usize) {
        (((mb_x * 16) as isize + (mv.x as isize)) as usize,
         ((mb_y * 16) as isize + (mv.y as isize)) as usize)
    }
    // Intra complexity: sum of 4x4 SATDs of the prediction residual over
    // the whole luma plane.
    fn calculate_i_cplx(frm: &NAVideoBuffer<u8>) -> u32 {
        let (w, h) = frm.get_dimensions(0);
        let src = frm.get_data();
        let stride = frm.get_stride(0);
        let mut sum = 0;
        let mut offset = 0;
        for y in (0..h).step_by(4) {
            for x in (0..w).step_by(4) {
                sum += Self::satd_i(src, offset + x, stride, x > 0, y > 0);
            }
            offset += stride * 4;
        }
        sum
    }
    // Inter complexity: sum of motion-compensated SATD over all macroblocks.
    fn calculate_mv_diff(ref_frm: &NAVideoBuffer<u8>, cur_frm: &NAVideoBuffer<u8>) -> u32 {
        let (w, h) = ref_frm.get_dimensions(0);
        let mut sum = 0;
        for mb_y in 0..(h / 16) {
            for mb_x in 0..(w / 16) {
                sum += Self::satd_mb_diff(ref_frm, cur_frm, mb_x, mb_y);
            }
        }
        sum
    }
    // SATD between a macroblock and its best motion-compensated match.
    fn satd_mb_diff(ref_frm: &NAVideoBuffer<u8>, cur_frm: &NAVideoBuffer<u8>, mb_x: usize, mb_y: usize) -> u32 {
        let mv = Self::search_mv(ref_frm, cur_frm, mb_x, mb_y);
        let mut sum = 0;
        let src0 = ref_frm.get_data();
        let src1 = cur_frm.get_data();
        let stride = ref_frm.get_stride(0);
        let (src_x, src_y) = Self::add_mv(mb_x, mb_y, mv);
        // accumulate 4x4 SATDs over the 16x16 block
        for y in (0..16).step_by(4) {
            for x in (0..16).step_by(4) {
                sum += Self::satd(&src0[src_x + x + (src_y + y) * stride..],
                                  &src1[mb_x * 16 + x + (mb_y * 16 + y) * stride..],
                                  stride);
            }
        }
        sum
    }
    // Crude full-pel diamond search (step sizes 4, 2, 1) used only for the
    // complexity estimate; candidates outside the frame are skipped.
    fn search_mv(ref_frm: &NAVideoBuffer<u8>, cur_frm: &NAVideoBuffer<u8>, mb_x: usize, mb_y: usize) -> MV {
        let stride = ref_frm.get_stride(0);
        let (w, h) = ref_frm.get_dimensions(0);
        let (v_edge, h_edge) = (w - 16, h - 16);
        let ref_src = ref_frm.get_data();
        let cur_src = cur_frm.get_data();
        let cur_src = &cur_src[mb_x * 16 + mb_y * 16 * stride..];

        let mut best_mv = ZERO_MV;
        let mut best_dist = Self::sad(cur_src, ref_src, mb_x, mb_y, stride, best_mv);
        if best_dist == 0 {
            return best_mv;
        }

        for step in (0..=2).rev() {
            let mut changed = true;
            while changed {
                changed = false;
                // skip the pattern centre - it is the current best point
                for &mv in DIA_PATTERN[1..].iter() {
                    let cand_mv = best_mv + mv.scale(1 << step);
                    let (cx, cy) = Self::add_mv(mb_x, mb_y, cand_mv);
                    if cx > v_edge || cy > h_edge {
                        continue;
                    }
                    let cand_dist = Self::sad(cur_src, ref_src, mb_x, mb_y, stride, cand_mv);
                    if cand_dist < best_dist {
                        best_dist = cand_dist;
                        best_mv = cand_mv;
                        if best_dist == 0 {
                            return best_mv;
                        }
                        changed = true;
                    }
                }
            }
        }
        best_mv
    }
    // Squared-difference metric over a 16x16 block at `mv` displacement.
    fn sad(cur_src: &[u8], src: &[u8], mb_x: usize, mb_y: usize, stride: usize, mv: MV) -> u32 {
        let (src_x, src_y) = Self::add_mv(mb_x, mb_y, mv);
        let mut sum = 0;
        for (line1, line2) in cur_src.chunks(stride).zip(src[src_x + src_y * stride..].chunks(stride)).take(16) {
            sum += line1[..16].iter().zip(line2[..16].iter()).fold(0u32,
                |acc, (&a, &b)| acc + u32::from(a.max(b) - a.min(b)) * u32::from(a.max(b) - a.min(b)));
        }
        sum
    }
    // SATD of the intra-prediction residual of one 4x4 block: the
    // predictor per pixel is the median-like combination of left, top and
    // top-left neighbours when both edges exist, a plain left/top neighbour
    // when only one edge exists, and the constant 128 otherwise.
    fn satd_i(src: &[u8], mut offset: usize, stride: usize, has_left: bool, has_top: bool) -> u32 {
        let mut diffs = [0; 16];
        match (has_left, has_top) {
            (true, true) => {
                for row in diffs.chunks_exact_mut(4) {
                    let mut left = i16::from(src[offset - 1]);
                    let mut tl = i16::from(src[offset - stride - 1]);
                    for (x, dst) in row.iter_mut().enumerate() {
                        let cur = i16::from(src[offset + x]);
                        let top = i16::from(src[offset + x - stride]);

                        // top + left + tl minus the min and the max leaves the median
                        *dst = cur - (top + left + tl - top.min(left).min(tl) - top.max(left).max(tl));

                        left = cur;
                        tl = top;
                    }

                    offset += stride;
                }
            },
            (true, false) => {
                for (dst, (left, cur)) in diffs.chunks_exact_mut(4).zip(
                        src[offset - 1..].chunks(stride).zip(src[offset..].chunks(stride))) {
                    for (dst, (&left, &cur)) in dst.iter_mut().zip(left.iter().zip(cur.iter())) {
                        *dst = i16::from(cur) - i16::from(left);
                    }
                }
            },
            (false, true) => {
                for (dst, (top, cur)) in diffs.chunks_exact_mut(4).zip(
                        src[offset - stride..].chunks(stride).zip(src[offset..].chunks(stride))) {
                    for (dst, (&top, &cur)) in dst.iter_mut().zip(top.iter().zip(cur.iter())) {
                        *dst = i16::from(cur) - i16::from(top);
                    }
                }
            },
            (false, false) => {
                for (dst, src) in diffs.chunks_exact_mut(4).zip(src[offset..].chunks(stride)) {
                    for (dst, &src) in dst.iter_mut().zip(src.iter()) {
                        *dst = i16::from(src) - 128;
                    }
                }
            },
        };
        // 4x4 Hadamard transform: rows, then columns
        for row in diffs.chunks_exact_mut(4) {
            hadamard!(row[0], row[1], row[2], row[3], row[0], row[1], row[2], row[3]);
        }
        for i in 0..4 {
            hadamard!(diffs[i], diffs[i + 4], diffs[i + 8], diffs[i + 12],
                      diffs[i], diffs[i + 4], diffs[i + 8], diffs[i + 12]);
        }
        diffs.iter().fold(0u32, |acc, x| acc + (x.abs() as u32))
    }
    // SATD between two 4x4 blocks (sum of absolute Hadamard-transformed
    // differences).
    fn satd(src0: &[u8], src1: &[u8], stride: usize) -> u32 {
        let mut diffs = [0; 16];
        for (dst, (src0, src1)) in diffs.chunks_exact_mut(4).zip(
                src0.chunks(stride).zip(src1.chunks(stride))) {
            hadamard!(i16::from(src0[0]) - i16::from(src1[0]),
                      i16::from(src0[1]) - i16::from(src1[1]),
                      i16::from(src0[2]) - i16::from(src1[2]),
                      i16::from(src0[3]) - i16::from(src1[3]),
                      dst[0], dst[1], dst[2], dst[3]);
        }
        for i in 0..4 {
            hadamard!(diffs[i], diffs[i + 4], diffs[i + 8], diffs[i + 12],
                      diffs[i], diffs[i + 4], diffs[i + 8], diffs[i + 12]);
        }
        diffs.iter().fold(0u32, |acc, x| acc + (x.abs() as u32))
    }
    // Halves `src` in both dimensions into `dst` by averaging each 2x2
    // pixel group (with rounding), for all three planes.
    fn downscale(dst: &mut NAVideoBuffer<u8>, src: &NAVideoBuffer<u8>) {
        let dst = NASimpleVideoFrame::from_video_buf(dst).unwrap();
        let sdata = src.get_data();
        for plane in 0..3 {
            let cur_w = dst.width[plane];
            let cur_h = dst.height[plane];
            let doff = dst.offset[plane];
            let soff = src.get_offset(plane);
            let dstride = dst.stride[plane];
            let sstride = src.get_stride(plane);
            for (dline, sstrip) in dst.data[doff..].chunks_exact_mut(dstride).zip(
                    sdata[soff..].chunks_exact(sstride * 2)).take(cur_h) {
                let (line0, line1) = sstrip.split_at(sstride);
                for (dst, (src0, src1)) in dline.iter_mut().zip(
                        line0.chunks_exact(2).zip(line1.chunks_exact(2))).take(cur_w) {
                    *dst = ((u16::from(src0[0]) + u16::from(src0[1]) +
                             u16::from(src1[0]) + u16::from(src1[1]) + 2) >> 2) as u8;
                }
            }
        }
    }
}