128633f01e8760a2cf8e5c993bc636e01db0e7ee
[nihav.git] / nihav-duck / src / codecs / vpenc / motion_est.rs
1 use nihav_codec_support::codecs::{MV, ZERO_MV};
2
3 use std::str::FromStr;
4
/// Motion vector search strategy.
///
/// Each variant has a short textual name, accepted by the `FromStr` impl and
/// produced by the `Display` impl: `"full"`, `"dia"` and `"hex"`.
#[derive(Debug,Clone,Copy,PartialEq,Eq)]
pub enum MVSearchMode {
    /// Selected by the name `"full"`.
    Full,
    /// Selected by the name `"dia"`.
    Diamond,
    /// Selected by the name `"hex"`.
    Hexagon,
}

/// The default search mode is [`MVSearchMode::Hexagon`].
impl Default for MVSearchMode {
    fn default() -> Self { MVSearchMode::Hexagon }
}
15
/// Error returned when a string does not name a known [`MVSearchMode`].
///
/// Derives `Debug` so that `parse::<MVSearchMode>().unwrap()` compiles, and
/// implements `Display` + `std::error::Error` so it can be propagated with `?`
/// into boxed errors.
#[derive(Debug,Clone,Copy,PartialEq,Eq,Default)]
pub struct ParseError{}

impl std::fmt::Display for ParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "unknown motion search mode")
    }
}

impl std::error::Error for ParseError {}
17
18 impl FromStr for MVSearchMode {
19 type Err = ParseError;
20
21 #[allow(clippy::single_match)]
22 fn from_str(s: &str) -> Result<Self, Self::Err> {
23 match s {
24 "full" => Ok(MVSearchMode::Full),
25 "dia" => Ok(MVSearchMode::Diamond),
26 "hex" => Ok(MVSearchMode::Hexagon),
27 _ => Err(ParseError{}),
28 }
29 }
30 }
31
32 impl std::fmt::Display for MVSearchMode {
33 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
34 match *self {
35 MVSearchMode::Full => write!(f, "full"),
36 MVSearchMode::Diamond => write!(f, "dia"),
37 MVSearchMode::Hexagon => write!(f, "hex"),
38 }
39 }
40 }
41
42 trait FromPixels {
43 fn from_pixels(self) -> Self;
44 }
45
46 impl FromPixels for MV {
47 fn from_pixels(self) -> MV {
48 MV { x: self.x * 4, y: self.y * 4 }
49 }
50 }
51
52 pub const DIA_PATTERN: [MV; 9] = [
53 ZERO_MV,
54 MV {x: -2, y: 0},
55 MV {x: -1, y: 1},
56 MV {x: 0, y: 2},
57 MV {x: 1, y: 1},
58 MV {x: 2, y: 0},
59 MV {x: 1, y: -1},
60 MV {x: 0, y: -2},
61 MV {x: -1, y: -1}
62 ];
63
64 pub const HEX_PATTERN: [MV; 7] = [
65 ZERO_MV,
66 MV {x: -2, y: 0},
67 MV {x: -1, y: 2},
68 MV {x: 1, y: 2},
69 MV {x: 2, y: 0},
70 MV {x: 1, y: -2},
71 MV {x: -1, y: -2}
72 ];
73
74 pub const REFINEMENT: [MV; 4] = [
75 MV {x: -1, y: 0},
76 MV {x: 0, y: 1},
77 MV {x: 1, y: 0},
78 MV {x: 0, y: -1}
79 ];
80
/// Expands to a pattern-based motion vector search that evaluates to
/// `(best_mv, best_dist)`.
///
/// Stages:
/// 1. Whole-pixel pattern search. Every candidate in `$self.point` whose
///    matching `$self.dist` slot still holds `MAX_DIST` is scored with
///    `$mv_est.$sad_func(...)`, after scaling via `from_pixels()`. If a
///    non-centre candidate wins, `$self.update($self.steps[min_idx])` is
///    called (presumably re-centring the pattern) and the round repeats. The
///    loop stops when:
///    - index 0 wins (presumably the pattern centre, cf. `ZERO_MV` at index
///      0 of the patterns);
///    - the minimum equals the previous round's best;
///    - a distance `<= DIST_THRESH` is found; or
///    - the winner's |x| or |y| reaches `$mv_est.mv_range`.
/// 2. One pass over `REFINEMENT` (four axis neighbours) around the winner.
/// 3. Stages 1 and 2 are repeated around the scaled result, starting from
///    `$self.set_new_point(...)`. This time points are used unscaled and the
///    range limit is `mv_range * 4` (presumably quarter-pel precision).
///
/// Whenever a distance `<= DIST_THRESH` is reached between stages, the
/// expansion executes `return`. That returns from the *enclosing function*,
/// so the macro may only be expanded inside a function returning
/// `(MV, <type of MAX_DIST>)`.
///
/// Requirements at the expansion site:
/// - `MAX_DIST`, `DIST_THRESH` and `REFINEMENT` must be in scope.
/// - `.from_pixels()` must be callable on `MV` (see `FromPixels`).
/// - `MV + MV` must be defined.
/// - `$self` must provide `reset()`, `update(step)` and
///   `set_new_point(mv, dist)`, plus indexable `dist`, `point` and `steps`.
/// - `$mv_est` must provide a `mv_range` field and the method named by
///   `$sad_func`, called as `(cur_blk, mb_x, mb_y, mv, limit)`.
///
/// NOTE(review): if stage 1 or 3 stops because the winner reached the range
/// limit, that winner is still taken as the result. The `REFINEMENT` passes
/// also do not re-check `mv_range`. Confirm this is intended.
#[macro_export]
macro_rules! search_template {
    ($self: expr, $mv_est: expr, $cur_blk: expr, $mb_x: expr, $mb_y: expr, $sad_func: ident) => ({
        let mut best_dist = MAX_DIST;
        let mut best_mv;

        // Index and distance of the best candidate in the latest round.
        let mut min_dist;
        let mut min_idx;

        // Stage 1: whole-pixel pattern search.
        $self.reset();
        loop {
            // Best distance seen so far; passed as the last SAD argument
            // (presumably an early-termination bound; confirm in $sad_func).
            let mut cur_best_dist = best_dist;
            for (dist, &point) in $self.dist.iter_mut().zip($self.point.iter()) {
                // Only candidates not yet scored (slot still MAX_DIST) are evaluated.
                if *dist == MAX_DIST {
                    *dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, point.from_pixels(), cur_best_dist);
                    cur_best_dist = cur_best_dist.min(*dist);
                    if *dist <= DIST_THRESH {
                        break;
                    }
                }
            }
            // Pick the minimum; ties keep the lower index, so index 0 wins ties.
            min_dist = $self.dist[0];
            min_idx = 0;
            for (i, &dist) in $self.dist.iter().enumerate().skip(1) {
                if dist < min_dist {
                    min_dist = dist;
                    min_idx = i;
                    if dist <= DIST_THRESH {
                        break;
                    }
                }
            }
            // Stop on: good-enough match, index 0 winning, no improvement,
            // or the winner reaching the MV range limit.
            if min_dist <= DIST_THRESH || min_idx == 0 || best_dist == min_dist || $self.point[min_idx].x.abs() >= $mv_est.mv_range || $self.point[min_idx].y.abs() >= $mv_est.mv_range {
                break;
            }
            best_dist = min_dist;
            $self.update($self.steps[min_idx]);
        }
        best_dist = min_dist;
        best_mv = $self.point[min_idx];
        if best_dist <= DIST_THRESH {
            return (best_mv.from_pixels(), best_dist);
        }
        // Stage 2: probe the four axis neighbours (still in whole-pixel units).
        for &step in REFINEMENT.iter() {
            let mv = best_mv + step;
            let dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, mv.from_pixels(), MAX_DIST);
            if best_dist > dist {
                best_dist = dist;
                best_mv = mv;
            }
        }
        // From here on the vector is kept in the scaled (4x) units.
        best_mv = best_mv.from_pixels();
        if best_dist <= DIST_THRESH {
            return (best_mv, best_dist);
        }

        // subpel refinement
        // Stage 3: same search in scaled units; points are passed unscaled
        // and the range limit is mv_range * 4.
        $self.set_new_point(best_mv, best_dist);
        loop {
            let mut cur_best_dist = best_dist;
            for (dist, &point) in $self.dist.iter_mut().zip($self.point.iter()) {
                if *dist == MAX_DIST {
                    *dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, point, cur_best_dist);
                    cur_best_dist = cur_best_dist.min(*dist);
                    if *dist <= DIST_THRESH {
                        break;
                    }
                }
            }
            min_dist = $self.dist[0];
            min_idx = 0;
            for (i, &dist) in $self.dist.iter().enumerate().skip(1) {
                if dist < min_dist {
                    min_dist = dist;
                    min_idx = i;
                    if dist <= DIST_THRESH {
                        break;
                    }
                }
            }
            if min_dist <= DIST_THRESH || min_idx == 0 || best_dist == min_dist || $self.point[min_idx].x.abs() >= $mv_est.mv_range * 4 || $self.point[min_idx].y.abs() >= $mv_est.mv_range * 4 {
                break;
            }
            best_dist = min_dist;
            $self.update($self.steps[min_idx]);
        }
        best_dist = min_dist;
        best_mv = $self.point[min_idx];
        if best_dist <= DIST_THRESH {
            return (best_mv, best_dist);
        }
        // Final one-unit (scaled) refinement around the stage-3 winner.
        for &step in REFINEMENT.iter() {
            let mv = best_mv + step;
            let dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, mv, MAX_DIST);
            if best_dist > dist {
                best_dist = dist;
                best_mv = mv;
            }
        }
        (best_mv, best_dist)
    })
}