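//! Motion estimation for the encoder: selectable block-matching searches
//! (diamond, hexagon, uneven multi-hexagon and a dummy fallback), SAD/SATD
//! distortion helpers, bidirectional refinement for B-frames and a simple
//! frame complexity estimator used for frame type decisions.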
use nihav_core::frame::*;
use nihav_codec_support::codecs::{MV, ZERO_MV};
use std::str::FromStr;
use super::dsp::{RefMBData, luma_mc, chroma_mc};

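/// Motion vector search strategy; `Hexagon` is the default.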
#[derive(Clone,Copy,PartialEq,Default)]
pub enum MVSearchMode {
    Dummy,
    Diamond,
    #[default]
    Hexagon,
    UMH,
}

impl MVSearchMode {
    pub const fn get_possible_modes() -> &'static [&'static str] {
        &["diamond", "hexagon", "umh"]
    }
    fn create(self) -> Box<dyn MVSearch+Send> {
        match self {
            MVSearchMode::Dummy => Box::new(DummySearcher{}),
            MVSearchMode::Diamond => Box::new(DiaSearch::new()),
            MVSearchMode::Hexagon => Box::new(HexSearch::new()),
            MVSearchMode::UMH => Box::new(UnevenHexSearch::new()),
        }
    }
}

impl std::fmt::Display for MVSearchMode {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match *self {
            MVSearchMode::Diamond => write!(f, "diamond"),
            MVSearchMode::Hexagon => write!(f, "hexagon"),
            MVSearchMode::UMH => write!(f, "umh"),
            MVSearchMode::Dummy => write!(f, "dummy"),
        }
    }
}

impl FromStr for MVSearchMode {
    type Err = ();
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "diamond" => Ok(MVSearchMode::Diamond),
            "hexagon" => Ok(MVSearchMode::Hexagon),
            "umh" => Ok(MVSearchMode::UMH),
            "dummy" => Ok(MVSearchMode::Dummy),
            _ => Err(()),
        }
    }
}

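// Sentinel distance for "not evaluated yet" candidates and the early-exit
// threshold used by the pattern searches (divided by 4 for 8x8 blocks).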
const MAX_DIST: u32 = u32::MAX;
const DIST_THRESH: u32 = 256;

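// Conversion from full-pel search positions to the quarter-pel units
// expected by the motion compensation routines.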
#[allow(clippy::wrong_self_convention)]
trait FromPixels {
    fn from_pixels(self) -> Self;
}

impl FromPixels for MV {
    fn from_pixels(self) -> MV {
        MV { x: self.x * 4, y: self.y * 4 }
    }
}

const DIA_PATTERN: [MV; 9] = [
    ZERO_MV,
    MV {x: -2, y: 0},
    MV {x: -1, y: 1},
    MV {x: 0, y: 2},
    MV {x: 1, y: 1},
    MV {x: 2, y: 0},
    MV {x: 1, y: -1},
    MV {x: 0, y: -2},
    MV {x: -1, y: -1}
];

const HEX_PATTERN: [MV; 7] = [
    ZERO_MV,
    MV {x: -2, y: 0},
    MV {x: -1, y: 2},
    MV {x: 1, y: 2},
    MV {x: 2, y: 0},
    MV {x: 1, y: -2},
    MV {x: -1, y: -2}
];

const REFINEMENT: [MV; 4] = [
    MV {x: -1, y: 0},
    MV {x: 0, y: 1},
    MV {x: 1, y: 0},
    MV {x: 0, y: -1}
];

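// Shared body for the pattern searches: an optional full-pel stage walks the
// pattern until no candidate improves, then the best point is refined with
// single-step offsets and the same walk is repeated at sub-pel resolution.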
macro_rules! search_template {
    ($self: expr, $mv_est: expr, $cur_blk: expr, $mb_x: expr, $mb_y: expr, $sad_func: ident, $threshold: expr) => ({
        search_template!($self, $mv_est, $cur_blk, $mb_x, $mb_y, $sad_func, $threshold, ZERO_MV, MAX_DIST, true)
    });
    ($self: expr, $mv_est: expr, $cur_blk: expr, $mb_x: expr, $mb_y: expr, $sad_func: ident, $threshold: expr, $start_mv: expr, $best_dist: expr, $fullpel_stage: expr) => ({
        let mut best_dist = $best_dist;
        let mut best_mv = $start_mv;

        let mut min_dist;
        let mut min_idx;

        if $fullpel_stage {
            $self.reset();
            loop {
                let mut cur_best_dist = best_dist;
                for (dist, &point) in $self.dist.iter_mut().zip($self.point.iter()) {
                    if *dist == MAX_DIST {
                        *dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, point.from_pixels(), cur_best_dist);
                        cur_best_dist = cur_best_dist.min(*dist);
                        if *dist <= $threshold {
                            break;
                        }
                    }
                }
                min_dist = $self.dist[0];
                min_idx = 0;
                for (i, &dist) in $self.dist.iter().enumerate().skip(1) {
                    if dist < min_dist {
                        min_dist = dist;
                        min_idx = i;
                        if dist <= $threshold {
                            break;
                        }
                    }
                }
                if min_dist <= $threshold || min_idx == 0 || best_dist == min_dist || $self.point[min_idx].x.abs() >= $mv_est.mv_range || $self.point[min_idx].y.abs() >= $mv_est.mv_range {
                    break;
                }
                best_dist = min_dist;
                $self.update($self.steps[min_idx]);
            }
            best_dist = min_dist;
            best_mv = $self.point[min_idx];
            if best_dist <= $threshold {
                return (best_mv.from_pixels(), best_dist);
            }
            for &step in REFINEMENT.iter() {
                let mv = best_mv + step;
                let dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, mv.from_pixels(), MAX_DIST);
                if best_dist > dist {
                    best_dist = dist;
                    best_mv = mv;
                }
            }
            best_mv = best_mv.from_pixels();
            if best_dist <= $threshold {
                return (best_mv, best_dist);
            }
        }

        // subpel refinement
        $self.set_new_point(best_mv, best_dist);
        loop {
            let mut cur_best_dist = best_dist;
            for (dist, &point) in $self.dist.iter_mut().zip($self.point.iter()) {
                if *dist == MAX_DIST {
                    *dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, point, cur_best_dist);
                    cur_best_dist = cur_best_dist.min(*dist);
                    if *dist <= $threshold {
                        break;
                    }
                }
            }
            min_dist = $self.dist[0];
            min_idx = 0;
            for (i, &dist) in $self.dist.iter().enumerate().skip(1) {
                if dist < min_dist {
                    min_dist = dist;
                    min_idx = i;
                    if dist <= $threshold {
                        break;
                    }
                }
            }
            if min_dist <= $threshold || min_idx == 0 || best_dist == min_dist || $self.point[min_idx].x.abs() >= $mv_est.mv_range * 8 || $self.point[min_idx].y.abs() >= $mv_est.mv_range * 8 {
                break;
            }
            best_dist = min_dist;
            $self.update($self.steps[min_idx]);
        }
        best_dist = min_dist;
        best_mv = $self.point[min_idx];
        if best_dist <= $threshold {
            return (best_mv, best_dist);
        }
        for &step in REFINEMENT.iter() {
            let mv = best_mv + step;
            let dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, mv, MAX_DIST);
            if best_dist > dist {
                best_dist = dist;
                best_mv = mv;
            }
        }
        (best_mv, best_dist)
    });
}

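// Generates a searcher type (used below for DiaSearch and HexSearch) that
// keeps the current pattern points with their cached distortions and moves
// the pattern around the best point found so far.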
macro_rules! pattern_search {
    ($struct_name: ident, $patterns: expr) => {
        pub struct $struct_name {
            point: [MV; $patterns.len()],
            dist: [u32; $patterns.len()],
            steps: &'static [MV; $patterns.len()],
        }

        impl $struct_name {
            pub fn new() -> Self {
                Self {
                    point: $patterns,
                    dist: [MAX_DIST; $patterns.len()],
                    steps: &$patterns,
                }
            }
            fn reset(&mut self) {
                self.point = $patterns;
                self.dist = [MAX_DIST; $patterns.len()];
            }
            fn set_new_point(&mut self, start: MV, dist: u32) {
                for (dst, &src) in self.point.iter_mut().zip(self.steps.iter()) {
                    *dst = src + start;
                }
                self.dist = [MAX_DIST; $patterns.len()];
                self.dist[0] = dist;
            }
            fn update(&mut self, step: MV) {
                let mut new_point = self.point;
                let mut new_dist = [MAX_DIST; $patterns.len()];

                for point in new_point.iter_mut() {
                    *point += step;
                }

                for (new_point, new_dist) in new_point.iter_mut().zip(new_dist.iter_mut()) {
                    for (&old_point, &old_dist) in self.point.iter().zip(self.dist.iter()) {
                        if *new_point == old_point {
                            *new_dist = old_dist;
                            break;
                        }
                    }
                }
                self.point = new_point;
                self.dist = new_dist;
            }
        }

        impl MVSearch for $struct_name {
            fn search_mb(&mut self, mv_est: &mut MVEstimator, cur_mb: &RefMBData, mb_x: usize, mb_y: usize, _cand_mvs: &[MV]) -> (MV, u32) {
                search_template!(self, mv_est, cur_mb, mb_x, mb_y, sad_mb, DIST_THRESH)
            }
            fn search_blk8(&mut self, mv_est: &mut MVEstimator, ref_blk: &RefMBData, xpos: usize, ypos: usize, _cand_mvs: &[MV]) -> (MV, u32) {
                search_template!(self, mv_est, ref_blk, xpos, ypos, sad_blk8, DIST_THRESH / 4)
            }
        }
    }
}

pattern_search!(DiaSearch, DIA_PATTERN);
pattern_search!(HexSearch, HEX_PATTERN);

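// Patterns used by the uneven multi-hexagon (UMH) search below.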
const LARGE_HEX_PATTERN: [MV; 16] = [
    MV { x: -4, y: 0 },
    MV { x: -4, y: 1 },
    MV { x: -4, y: 2 },
    MV { x: -2, y: 3 },
    MV { x: 0, y: 4 },
    MV { x: 2, y: 3 },
    MV { x: 4, y: 2 },
    MV { x: 4, y: 1 },
    MV { x: 4, y: 0 },
    MV { x: 4, y: -1 },
    MV { x: 4, y: -2 },
    MV { x: 2, y: -3 },
    MV { x: 0, y: -4 },
    MV { x: -2, y: -3 },
    MV { x: -4, y: -2 },
    MV { x: -4, y: -1 }
];

const UNSYMM_CROSS: [MV; 4] = [
    MV { x: -2, y: 0 },
    MV { x: 0, y: 1 },
    MV { x: 2, y: 0 },
    MV { x: 0, y: -1 }
];

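/// Small fixed-capacity set used to deduplicate candidate motion vectors.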
#[derive(Default)]
struct UniqueSet<T:Copy+Default> {
    list: [T; 16],
    count: usize,
}

impl<T:Copy+Default+PartialEq> UniqueSet<T> {
    fn new() -> Self { Self::default() }
    fn clear(&mut self) { self.count = 0; }
    fn get_list(&self) -> &[T] { &self.list[..self.count] }
    fn add(&mut self, val: T) {
        if self.count < self.list.len() && !self.get_list().contains(&val) {
            self.list[self.count] = val;
            self.count += 1;
        }
    }
}

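// MV helpers used by the searches below: scaling a pattern step and checking
// that a candidate stays within the allowed motion range.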
#[allow(clippy::wrong_self_convention)]
trait MVOps {
    fn scale(self, scale: i16) -> Self;
    fn is_in_range(self, range: i16) -> bool;
}

impl MVOps for MV {
    fn scale(self, scale: i16) -> MV {
        MV { x: self.x * scale, y: self.y * scale }
    }
    fn is_in_range(self, range: i16) -> bool {
        self.x.abs() <= range && self.y.abs() <= range
    }
}

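// One pass over a pattern around `$start`: returns the best vector, its
// distortion and whether anything better than the starting point was found.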
macro_rules! single_search_step {
    ($start:expr, $best_dist:expr, $mv_est:expr, $sad_func:ident, $ref_blk:expr, $xpos:expr, $ypos:expr, $pattern:expr, $scale:expr, $dist_thr:expr) => {{
        let mut best_mv = $start;
        let mut best_dist = $best_dist;
        for point in $pattern.iter() {
            let mv = point.scale($scale) + $start;
            if !mv.is_in_range($mv_est.mv_range * 4) {
                continue;
            }
            let dist = $mv_est.$sad_func($ref_blk, $xpos, $ypos, mv, best_dist);
            if dist < best_dist {
                best_mv = mv;
                best_dist = dist;
                if best_dist < $dist_thr {
                    break;
                }
            }
        }
        (best_mv, best_dist, best_mv != $start)
    }}
}

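/// Uneven multi-hexagon (UMH) search: starts from a prediction derived from
/// the candidate vectors and then refines it in several stages.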
struct UnevenHexSearch {
    mv_list: UniqueSet<MV>,
}

impl UnevenHexSearch {
    fn new() -> Self {
        Self {
            mv_list: UniqueSet::new(),
        }
    }
    fn get_cand_mv(&mut self, cand_mvs: &[MV]) -> MV {
        self.mv_list.clear();
        for &mv in cand_mvs.iter() {
            self.mv_list.add(mv);
        }
        match self.mv_list.count {
            1 => self.mv_list.list[0],
            3 => MV::pred(self.mv_list.list[0], self.mv_list.list[1], self.mv_list.list[2]),
            _ => {
                let sum = self.mv_list.get_list().iter().fold((0i32, 0i32),
                        |acc, mv| (acc.0 + i32::from(mv.x), acc.1 + i32::from(mv.y)));
                MV {x: (sum.0 / (self.mv_list.count as i32)) as i16,
                    y: (sum.1 / (self.mv_list.count as i32)) as i16}
            },
        }
    }
}

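// The UMH refinement stages shared by the 16x16 and 8x8 searches: small
// diamond, unsymmetrical cross, multi-scale large hexagon grid, then a final
// hexagon plus diamond polish, with early exits once the cutoff is reached.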
macro_rules! umh_search_template {
    ($cand_mv:expr, $cutoff:expr, $mv_est:expr, $sad_func:ident, $ref_blk:expr, $xpos:expr, $ypos:expr) => {{
        let cand_mv = $cand_mv;
        let best_dist = $mv_est.$sad_func($ref_blk, $xpos, $ypos, cand_mv, MAX_DIST);
        if best_dist < $cutoff {
            return (cand_mv, best_dist);
        }

        // step 1 - small refinement search
        let (mut cand_mv, mut best_dist, _) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, DIA_PATTERN, 1, $cutoff);
        if best_dist < $cutoff {
            return (cand_mv, best_dist);
        }

        // step 2 - unsymmetrical cross search
        loop {
            let (mv, dist, changed) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, UNSYMM_CROSS, 4, $cutoff);
            if !changed {
                break;
            }
            cand_mv = mv;
            best_dist = dist;
            if best_dist < $cutoff {
                return (mv, dist);
            }
        }

        // step 3 - multi-hexagon grid search
        let mut scale = 4;
        while scale > 0 {
            let (mv, dist, changed) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, LARGE_HEX_PATTERN, scale, $cutoff);
            if !changed {
                break;
            }
            cand_mv = mv;
            best_dist = dist;
            if best_dist < $cutoff {
                return (mv, dist);
            }
            scale >>= 1;
        }
        // step 4 - final hexagon search
        let (cand_mv, best_dist, _) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, HEX_PATTERN, 1, $cutoff);
        if best_dist > $cutoff {
            let (mv, dist, _) = single_search_step!(cand_mv, best_dist, $mv_est, $sad_func, $ref_blk, $xpos, $ypos, DIA_PATTERN, 1, $cutoff);
            (mv, dist)
        } else {
            (cand_mv, best_dist)
        }
    }}
}

impl MVSearch for UnevenHexSearch {
    fn search_mb(&mut self, mv_est: &mut MVEstimator, cur_mb: &RefMBData, mb_x: usize, mb_y: usize, cand_mvs: &[MV]) -> (MV, u32) {
        let cand_mv = self.get_cand_mv(cand_mvs);
        let cutoff = mv_est.cutoff_thr;
        umh_search_template!(cand_mv, cutoff, mv_est, sad_mb, cur_mb, mb_x, mb_y)
    }
    fn search_blk8(&mut self, mv_est: &mut MVEstimator, ref_blk: &RefMBData, xpos: usize, ypos: usize, cand_mvs: &[MV]) -> (MV, u32) {
        let cand_mv = self.get_cand_mv(cand_mvs);
        let cutoff = mv_est.cutoff_thr / 4;
        umh_search_template!(cand_mv, cutoff, mv_est, sad_blk8, ref_blk, xpos, ypos)
    }
}

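/// Per-call context for distortion calculation against one reference picture.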
struct MVEstimator<'a> {
    pic: &'a NAVideoBuffer<u8>,
    mv_range: i16,
    cutoff_thr: u32,
}

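// Note: despite the name this accumulates squared differences, so the
// distortion compared against the thresholds above is a sum of squares.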
macro_rules! sad {
    ($src1:expr, $src2:expr) => {
        $src1.iter().zip($src2.iter()).fold(0u32, |acc, (&a, &b)|
                acc + (((i32::from(a) - i32::from(b)) * (i32::from(a) - i32::from(b))) as u32))
    }
}

impl<'a> MVEstimator<'a> {
    fn sad_mb(&self, ref_mb: &RefMBData, mb_x: usize, mb_y: usize, mv: MV, cur_best_dist: u32) -> u32 {
        let mut dst = RefMBData::new();
        luma_mc(&mut dst.y, 16, self.pic, mb_x * 16, mb_y * 16, mv, true);

        let mut dist = 0;
        for (dline, sline) in dst.y.chunks(16).zip(ref_mb.y.chunks(16)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }
        chroma_mc(&mut dst.u, 8, self.pic, mb_x * 8, mb_y * 8, 1, mv, true);
        dist += sad!(dst.u, ref_mb.u);
        if dist > cur_best_dist {
            return dist;
        }
        chroma_mc(&mut dst.v, 8, self.pic, mb_x * 8, mb_y * 8, 2, mv, true);
        dist += sad!(dst.v, ref_mb.v);

        dist
    }
    fn sad_blk8(&self, ref_mb: &RefMBData, xpos: usize, ypos: usize, mv: MV, cur_best_dist: u32) -> u32 {
        let mut cur_y = [0; 64];
        let mut cur_u = [0; 16];
        let mut cur_v = [0; 16];

        let mut dist = 0;

        let y_off = (xpos & 8) + (ypos & 8) * 16;
        luma_mc(&mut cur_y, 8, self.pic, xpos, ypos, mv, false);
        for (dline, sline) in cur_y.chunks(8).zip(ref_mb.y[y_off..].chunks(16)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }

        let c_off = (xpos & 8) / 2 + (ypos & 8) * 4;
        chroma_mc(&mut cur_u, 4, self.pic, xpos / 2, ypos / 2, 1, mv, false);
        for (dline, sline) in cur_u.chunks(4).zip(ref_mb.u[c_off..].chunks(8)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }
        chroma_mc(&mut cur_v, 4, self.pic, xpos / 2, ypos / 2, 2, mv, false);
        for (dline, sline) in cur_v.chunks(4).zip(ref_mb.v[c_off..].chunks(8)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }

        dist
    }
}

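/// Common interface for the macroblock and 8x8 block motion searches.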
trait MVSearch {
    fn search_mb(&mut self, mv_est: &mut MVEstimator, ref_mb: &RefMBData, mb_x: usize, mb_y: usize, cand_mvs: &[MV]) -> (MV, u32);
    fn search_blk8(&mut self, mv_est: &mut MVEstimator, ref_blk: &RefMBData, xpos: usize, ypos: usize, cand_mvs: &[MV]) -> (MV, u32);
}

struct DummySearcher {}

impl MVSearch for DummySearcher {
    fn search_mb(&mut self, _mv_est: &mut MVEstimator, _ref_mb: &RefMBData, _mb_x: usize, _mb_y: usize, _cand_mvs: &[MV]) -> (MV, u32) {
        (ZERO_MV, u32::MAX / 2)
    }
    fn search_blk8(&mut self, _mv_est: &mut MVEstimator, _ref_mb: &RefMBData, _xpos: usize, _ypos: usize, _cand_mvs: &[MV]) -> (MV, u32) {
        (ZERO_MV, u32::MAX / 2)
    }
}

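/// Public entry point for P-frame motion estimation: owns the currently
/// selected search and exposes per-macroblock and per-8x8-block searches.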
pub struct MotionEstimator {
    pub range: i16,
    pub thresh: u32,
    mode: MVSearchMode,
    srch: Box<dyn MVSearch+Send>,
}

impl MotionEstimator {
    pub fn new() -> Self {
        let mode = MVSearchMode::default();
        Self {
            range: 64,
            thresh: 32,
            mode,
            srch: mode.create(),
        }
    }
    pub fn get_mode(&self) -> MVSearchMode { self.mode }
    pub fn set_mode(&mut self, new_mode: MVSearchMode) {
        if self.mode != new_mode {
            self.mode = new_mode;
            self.srch = self.mode.create();
        }
    }
    pub fn search_mb_p(&mut self, pic: &NAVideoBuffer<u8>, refmb: &RefMBData, mb_x: usize, mb_y: usize, cand_mvs: &[MV]) -> (MV, u32) {
        let mut mv_est = MVEstimator {
            mv_range: self.range,
            cutoff_thr: self.thresh,
            pic,
        };
        self.srch.search_mb(&mut mv_est, refmb, mb_x, mb_y, cand_mvs)
    }
    pub fn search_blk8(&mut self, pic: &NAVideoBuffer<u8>, refmb: &RefMBData, xoff: usize, yoff: usize, cand_mvs: &[MV]) -> (MV, u32) {
        let mut mv_est = MVEstimator {
            mv_range: self.range,
            cutoff_thr: self.thresh,
            pic,
        };
        self.srch.search_blk8(&mut mv_est, refmb, xoff, yoff, cand_mvs)
    }
}

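/// Bidirectional motion search for B-frame macroblocks: jointly refines the
/// forward and backward vectors against a weighted average prediction.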
pub struct SearchB<'a> {
    ref_p: &'a NAVideoBuffer<u8>,
    ref_n: &'a NAVideoBuffer<u8>,
    xpos: usize,
    ypos: usize,
    ratios: [u32; 2],
    tmp1: RefMBData,
    tmp2: RefMBData,
    pred_blk: RefMBData,
}

impl<'a> SearchB<'a> {
    pub fn new(ref_p: &'a NAVideoBuffer<u8>, ref_n: &'a NAVideoBuffer<u8>, mb_x: usize, mb_y: usize, ratios: [u32; 2]) -> Self {
        Self {
            ref_p, ref_n,
            xpos: mb_x * 16,
            ypos: mb_y * 16,
            ratios,
            tmp1: RefMBData::new(),
            tmp2: RefMBData::new(),
            pred_blk: RefMBData::new(),
        }
    }
    pub fn search_mb(&mut self, ref_mb: &RefMBData, cand_mvs: [MV; 2]) -> (MV, MV) {
        let mut best_cand = cand_mvs;
        let mut best_dist = self.interp_b_dist(ref_mb, best_cand, MAX_DIST);

        loop {
            let mut improved = false;
            for &fmv_add in DIA_PATTERN.iter() {
                for &bmv_add in DIA_PATTERN.iter() {
                    let cand = [best_cand[0] + fmv_add.from_pixels(),
                                best_cand[1] + bmv_add.from_pixels()];
                    let dist = self.interp_b_dist(ref_mb, cand, best_dist);
                    if dist < best_dist {
                        best_dist = dist;
                        best_cand = cand;
                        improved = true;
                    }
                }
            }
            if !improved {
                break;
            }
        }

        for &fmv_add in REFINEMENT.iter() {
            for &bmv_add in REFINEMENT.iter() {
                let cand = [best_cand[0] + fmv_add, best_cand[1] + bmv_add];
                let dist = self.interp_b_dist(ref_mb, cand, best_dist);
                if dist < best_dist {
                    best_dist = dist;
                    best_cand = cand;
                }
            }
        }

        (best_cand[0], best_cand[1])
    }
    fn interp_b_dist(&mut self, ref_mb: &RefMBData, cand_mv: [MV; 2], cur_best_dist: u32) -> u32 {
        let [fmv, bmv] = cand_mv;
        luma_mc(&mut self.tmp1.y, 16, self.ref_p, self.xpos, self.ypos, fmv, true);
        chroma_mc(&mut self.tmp1.u, 8, self.ref_p, self.xpos / 2, self.ypos / 2, 1, fmv, true);
        chroma_mc(&mut self.tmp1.v, 8, self.ref_p, self.xpos / 2, self.ypos / 2, 2, fmv, true);
        luma_mc(&mut self.tmp2.y, 16, self.ref_n, self.xpos, self.ypos, bmv, true);
        chroma_mc(&mut self.tmp2.u, 8, self.ref_n, self.xpos / 2, self.ypos / 2, 1, bmv, true);
        chroma_mc(&mut self.tmp2.v, 8, self.ref_n, self.xpos / 2, self.ypos / 2, 2, bmv, true);
        self.pred_blk.avg(&self.tmp1, self.ratios[0], &self.tmp2, self.ratios[1]);

        let mut dist = 0;
        for (dline, sline) in self.pred_blk.y.chunks(16).zip(ref_mb.y.chunks(16)) {
            dist += sad!(dline, sline);
            if dist > cur_best_dist {
                return dist;
            }
        }
        dist += sad!(self.pred_blk.u, ref_mb.u);
        if dist > cur_best_dist {
            return dist;
        }
        dist += sad!(self.pred_blk.v, ref_mb.v);

        dist
    }
}

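// 4-point butterfly used to build the 4x4 Hadamard transform for SATD.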
macro_rules! hadamard {
    ($s0:expr, $s1:expr, $s2:expr, $s3:expr, $d0:expr, $d1:expr, $d2:expr, $d3:expr) => {
        let t0 = $s0 + $s1;
        let t1 = $s0 - $s1;
        let t2 = $s2 + $s3;
        let t3 = $s2 - $s3;
        $d0 = t0 + t2;
        $d2 = t0 - t2;
        $d1 = t1 + t3;
        $d3 = t1 - t3;
    }
}

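/// Estimates frame complexity on half-resolution copies of the input frames;
/// used for deciding frame types and B-frame placement.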
pub struct FrameComplexityEstimate {
    ref_frm: NAVideoBufferRef<u8>,
    cur_frm: NAVideoBufferRef<u8>,
    nxt_frm: NAVideoBufferRef<u8>,
    width: usize,
    height: usize,
}

impl FrameComplexityEstimate {
    pub fn new() -> Self {
        let vinfo = NAVideoInfo::new(24, 24, false, YUV420_FORMAT);
        let vt = alloc_video_buffer(vinfo, 4).unwrap();
        let buf = vt.get_vbuf().unwrap();
        Self {
            ref_frm: buf.clone(),
            cur_frm: buf.clone(),
            nxt_frm: buf,
            width: 0,
            height: 0,
        }
    }
    pub fn resize(&mut self, width: usize, height: usize) {
        if width != self.width || height != self.height {
            self.width = width;
            self.height = height;

            let vinfo = NAVideoInfo::new(self.width / 2, self.height / 2, false, YUV420_FORMAT);
            let vt = alloc_video_buffer(vinfo, 4).unwrap();
            self.ref_frm = vt.get_vbuf().unwrap();
            let frm = self.ref_frm.get_data_mut().unwrap();
            for el in frm.iter_mut() {
                *el = 0x80;
            }
            let vt = alloc_video_buffer(vinfo, 4).unwrap();
            self.cur_frm = vt.get_vbuf().unwrap();
            let vt = alloc_video_buffer(vinfo, 4).unwrap();
            self.nxt_frm = vt.get_vbuf().unwrap();
        }
    }
    pub fn set_current(&mut self, frm: &NAVideoBuffer<u8>) {
        Self::downscale(&mut self.cur_frm, frm);
    }
    pub fn get_complexity(&self, ftype: FrameType) -> u32 {
        match ftype {
            FrameType::I => Self::calculate_i_cplx(&self.cur_frm),
            FrameType::P => Self::calculate_mv_diff(&self.ref_frm, &self.cur_frm),
            _ => 0,
        }
    }
    pub fn decide_b_frame(&mut self, frm1: &NAVideoBuffer<u8>, frm2: &NAVideoBuffer<u8>) -> bool {
        Self::downscale(&mut self.cur_frm, frm1);
        Self::downscale(&mut self.nxt_frm, frm2);
        let diff_ref_cur = Self::calculate_mv_diff(&self.ref_frm, &self.cur_frm);
        let diff_cur_nxt = Self::calculate_mv_diff(&self.cur_frm, &self.nxt_frm);

        // simple rule - if the ref->cur and cur->next complexities are about the same, this should be a B-frame
        let ddiff = diff_ref_cur.max(diff_cur_nxt) - diff_ref_cur.min(diff_cur_nxt);
        if ddiff < 256 {
            true
        } else {
            let mut order = 0;
            while (ddiff << order) < diff_ref_cur.min(diff_cur_nxt) {
                order += 1;
            }
            order > 2
        }
    }
    pub fn update_ref(&mut self) {
        std::mem::swap(&mut self.ref_frm, &mut self.cur_frm);
    }

    fn add_mv(mb_x: usize, mb_y: usize, mv: MV) -> (usize, usize) {
        (((mb_x * 16) as isize + (mv.x as isize)) as usize,
         ((mb_y * 16) as isize + (mv.y as isize)) as usize)
    }
    fn calculate_i_cplx(frm: &NAVideoBuffer<u8>) -> u32 {
        let (w, h) = frm.get_dimensions(0);
        let src = frm.get_data();
        let stride = frm.get_stride(0);
        let mut sum = 0;
        let mut offset = 0;
        for y in (0..h).step_by(4) {
            for x in (0..w).step_by(4) {
                sum += Self::satd_i(src, offset + x, stride, x > 0, y > 0);
            }
            offset += stride * 4;
        }
        sum
    }
    fn calculate_mv_diff(ref_frm: &NAVideoBuffer<u8>, cur_frm: &NAVideoBuffer<u8>) -> u32 {
        let (w, h) = ref_frm.get_dimensions(0);
        let mut sum = 0;
        for mb_y in 0..(h / 16) {
            for mb_x in 0..(w / 16) {
                sum += Self::satd_mb_diff(ref_frm, cur_frm, mb_x, mb_y);
            }
        }
        sum
    }
    fn satd_mb_diff(ref_frm: &NAVideoBuffer<u8>, cur_frm: &NAVideoBuffer<u8>, mb_x: usize, mb_y: usize) -> u32 {
        let mv = Self::search_mv(ref_frm, cur_frm, mb_x, mb_y);
        let mut sum = 0;
        let src0 = ref_frm.get_data();
        let src1 = cur_frm.get_data();
        let stride = ref_frm.get_stride(0);
        let (src_x, src_y) = Self::add_mv(mb_x, mb_y, mv);
        for y in (0..16).step_by(4) {
            for x in (0..16).step_by(4) {
                sum += Self::satd(&src0[src_x + x + (src_y + y) * stride..],
                                  &src1[mb_x * 16 + x + (mb_y * 16 + y) * stride..],
                                  stride);
            }
        }
        sum
    }
    fn search_mv(ref_frm: &NAVideoBuffer<u8>, cur_frm: &NAVideoBuffer<u8>, mb_x: usize, mb_y: usize) -> MV {
        let stride = ref_frm.get_stride(0);
        let (w, h) = ref_frm.get_dimensions(0);
        let (v_edge, h_edge) = (w - 16, h - 16);
        let ref_src = ref_frm.get_data();
        let cur_src = cur_frm.get_data();
        let cur_src = &cur_src[mb_x * 16 + mb_y * 16 * stride..];

        let mut best_mv = ZERO_MV;
        let mut best_dist = Self::sad(cur_src, ref_src, mb_x, mb_y, stride, best_mv);
        if best_dist == 0 {
            return best_mv;
        }

        for step in (0..=2).rev() {
            let mut changed = true;
            while changed {
                changed = false;
                for &mv in DIA_PATTERN[1..].iter() {
                    let cand_mv = best_mv + mv.scale(1 << step);
                    let (cx, cy) = Self::add_mv(mb_x, mb_y, cand_mv);
                    if cx > v_edge || cy > h_edge {
                        continue;
                    }
                    let cand_dist = Self::sad(cur_src, ref_src, mb_x, mb_y, stride, cand_mv);
                    if cand_dist < best_dist {
                        best_dist = cand_dist;
                        best_mv = cand_mv;
                        if best_dist == 0 {
                            return best_mv;
                        }
                        changed = true;
                    }
                }
            }
        }
        best_mv
    }
    fn sad(cur_src: &[u8], src: &[u8], mb_x: usize, mb_y: usize, stride: usize, mv: MV) -> u32 {
        let (src_x, src_y) = Self::add_mv(mb_x, mb_y, mv);
        let mut sum = 0;
        for (line1, line2) in cur_src.chunks(stride).zip(src[src_x + src_y * stride..].chunks(stride)).take(16) {
            sum += line1[..16].iter().zip(line2[..16].iter()).fold(0u32,
                    |acc, (&a, &b)| acc + u32::from(a.max(b) - a.min(b)) * u32::from(a.max(b) - a.min(b)));
        }
        sum
    }
    fn satd_i(src: &[u8], mut offset: usize, stride: usize, has_left: bool, has_top: bool) -> u32 {
        let mut diffs = [0; 16];
        match (has_left, has_top) {
            (true, true) => {
                for row in diffs.chunks_exact_mut(4) {
                    let mut left = i16::from(src[offset - 1]);
                    let mut tl = i16::from(src[offset - stride - 1]);
                    for (x, dst) in row.iter_mut().enumerate() {
                        let cur = i16::from(src[offset + x]);
                        let top = i16::from(src[offset + x - stride]);

                        *dst = cur - (top + left + tl - top.min(left).min(tl) - top.max(left).max(tl));

                        left = cur;
                        tl = top;
                    }

                    offset += stride;
                }
            },
            (true, false) => {
                for (dst, (left, cur)) in diffs.chunks_exact_mut(4).zip(
                        src[offset - 1..].chunks(stride).zip(src[offset..].chunks(stride))) {
                    for (dst, (&left, &cur)) in dst.iter_mut().zip(left.iter().zip(cur.iter())) {
                        *dst = i16::from(cur) - i16::from(left);
                    }
                }
            },
            (false, true) => {
                for (dst, (top, cur)) in diffs.chunks_exact_mut(4).zip(
                        src[offset - stride..].chunks(stride).zip(src[offset..].chunks(stride))) {
                    for (dst, (&top, &cur)) in dst.iter_mut().zip(top.iter().zip(cur.iter())) {
                        *dst = i16::from(cur) - i16::from(top);
                    }
                }
            },
            (false, false) => {
                for (dst, src) in diffs.chunks_exact_mut(4).zip(src[offset..].chunks(stride)) {
                    for (dst, &src) in dst.iter_mut().zip(src.iter()) {
                        *dst = i16::from(src) - 128;
                    }
                }
            },
        };
        for row in diffs.chunks_exact_mut(4) {
            hadamard!(row[0], row[1], row[2], row[3], row[0], row[1], row[2], row[3]);
        }
        for i in 0..4 {
            hadamard!(diffs[i], diffs[i + 4], diffs[i + 8], diffs[i + 12],
                      diffs[i], diffs[i + 4], diffs[i + 8], diffs[i + 12]);
        }
        diffs.iter().fold(0u32, |acc, x| acc + (x.unsigned_abs() as u32))
    }
    fn satd(src0: &[u8], src1: &[u8], stride: usize) -> u32 {
        let mut diffs = [0; 16];
        for (dst, (src0, src1)) in diffs.chunks_exact_mut(4).zip(
                src0.chunks(stride).zip(src1.chunks(stride))) {
            hadamard!(i16::from(src0[0]) - i16::from(src1[0]),
                      i16::from(src0[1]) - i16::from(src1[1]),
                      i16::from(src0[2]) - i16::from(src1[2]),
                      i16::from(src0[3]) - i16::from(src1[3]),
                      dst[0], dst[1], dst[2], dst[3]);
        }
        for i in 0..4 {
            hadamard!(diffs[i], diffs[i + 4], diffs[i + 8], diffs[i + 12],
                      diffs[i], diffs[i + 4], diffs[i + 8], diffs[i + 12]);
        }
        diffs.iter().fold(0u32, |acc, x| acc + (x.unsigned_abs() as u32))
    }
    fn downscale(dst: &mut NAVideoBuffer<u8>, src: &NAVideoBuffer<u8>) {
        let dst = NASimpleVideoFrame::from_video_buf(dst).unwrap();
        let sdata = src.get_data();
        for plane in 0..3 {
            let cur_w = dst.width[plane];
            let cur_h = dst.height[plane];
            let doff = dst.offset[plane];
            let soff = src.get_offset(plane);
            let dstride = dst.stride[plane];
            let sstride = src.get_stride(plane);
            for (dline, sstrip) in dst.data[doff..].chunks_exact_mut(dstride).zip(
                    sdata[soff..].chunks_exact(sstride * 2)).take(cur_h) {
                let (line0, line1) = sstrip.split_at(sstride);
                for (dst, (src0, src1)) in dline.iter_mut().zip(
                        line0.chunks_exact(2).zip(line1.chunks_exact(2))).take(cur_w) {
                    *dst = ((u16::from(src0[0]) + u16::from(src0[1]) +
                             u16::from(src1[0]) + u16::from(src1[1]) + 2) >> 2) as u8;
                }
            }
        }
    }
}
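
// A few sanity checks for the self-contained helpers above; this is only a
// small sketch covering search-mode parsing, candidate deduplication and MV
// scaling, since the actual searches need real frame data to run against.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn search_mode_parse_display_round_trip() {
        // every advertised mode name should parse and print back to itself
        for name in MVSearchMode::get_possible_modes() {
            let mode: MVSearchMode = name.parse().unwrap();
            assert!(mode.to_string() == *name);
        }
        assert!("no-such-mode".parse::<MVSearchMode>().is_err());
    }

    #[test]
    fn unique_set_ignores_duplicates() {
        let mut set = UniqueSet::new();
        set.add(ZERO_MV);
        set.add(ZERO_MV);
        set.add(MV { x: 1, y: -1 });
        // the duplicate zero vector is stored only once
        assert!(set.get_list().len() == 2);
    }

    #[test]
    fn mv_scale_and_range_check() {
        let mv = MV { x: 1, y: -2 }.scale(4);
        assert!(mv == MV { x: 4, y: -8 });
        assert!(mv.is_in_range(8));
        assert!(!mv.is_in_range(7));
    }
}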