use std::str::FromStr;
-#[derive(Debug,Clone,Copy,PartialEq)]
+#[derive(Debug,Clone,Copy,PartialEq,Default)]
+#[allow(dead_code)]
pub enum MVSearchMode {
Full,
+ SEA,
Diamond,
+ #[default]
Hexagon,
-}
-
-impl Default for MVSearchMode {
- fn default() -> Self { MVSearchMode::Hexagon }
+ EPZS,
}
pub struct ParseError{}
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"full" => Ok(MVSearchMode::Full),
+ "sea" => Ok(MVSearchMode::SEA),
"dia" => Ok(MVSearchMode::Diamond),
"hex" => Ok(MVSearchMode::Hexagon),
+ "epzs" => Ok(MVSearchMode::EPZS),
_ => Err(ParseError{}),
}
}
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match *self {
MVSearchMode::Full => write!(f, "full"),
+ MVSearchMode::SEA => write!(f, "sea"),
MVSearchMode::Diamond => write!(f, "dia"),
MVSearchMode::Hexagon => write!(f, "hex"),
+ MVSearchMode::EPZS => write!(f, "epzs"),
}
}
}
#[macro_export]
macro_rules! search_template {
- ($self: expr, $mv_est: expr, $cur_blk: expr, $mb_x: expr, $mb_y: expr, $sad_func: ident) => ({
- let mut best_dist = MAX_DIST;
- let mut best_mv;
+ ($self: expr, $mv_est: expr, $cur_blk: expr, $mb_x: expr, $mb_y: expr, $sad_func: ident, $threshold: expr) => ({
+ search_template!($self, $mv_est, $cur_blk, $mb_x, $mb_y, $sad_func, $threshold, ZERO_MV, MAX_DIST, true)
+ });
+ ($self: expr, $mv_est: expr, $cur_blk: expr, $mb_x: expr, $mb_y: expr, $sad_func: ident, $threshold: expr, $start_mv: expr, $best_dist: expr, $fullpel_stage: expr) => ({
+ let mut best_dist = $best_dist;
+ let mut best_mv = $start_mv;
let mut min_dist;
let mut min_idx;
- $self.reset();
- loop {
- let mut cur_best_dist = best_dist;
- for (dist, &point) in $self.dist.iter_mut().zip($self.point.iter()) {
- if *dist == MAX_DIST {
- *dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, point.from_pixels(), cur_best_dist);
- cur_best_dist = cur_best_dist.min(*dist);
- if *dist <= DIST_THRESH {
- break;
+ if $fullpel_stage {
+ $self.reset();
+ loop {
+ let mut cur_best_dist = best_dist;
+ for (dist, &point) in $self.dist.iter_mut().zip($self.point.iter()) {
+ if *dist == MAX_DIST {
+ *dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, point.from_pixels(), cur_best_dist);
+ cur_best_dist = cur_best_dist.min(*dist);
+ if *dist <= $threshold {
+ break;
+ }
}
}
- }
- min_dist = $self.dist[0];
- min_idx = 0;
- for (i, &dist) in $self.dist.iter().enumerate().skip(1) {
- if dist < min_dist {
- min_dist = dist;
- min_idx = i;
- if dist <= DIST_THRESH {
- break;
+ min_dist = $self.dist[0];
+ min_idx = 0;
+ for (i, &dist) in $self.dist.iter().enumerate().skip(1) {
+ if dist < min_dist {
+ min_dist = dist;
+ min_idx = i;
+ if dist <= $threshold {
+ break;
+ }
}
}
- }
- if min_dist <= DIST_THRESH || min_idx == 0 || best_dist == min_dist || $self.point[min_idx].x.abs() >= $mv_est.mv_range || $self.point[min_idx].y.abs() >= $mv_est.mv_range {
- break;
+ if min_dist <= $threshold || min_idx == 0 || best_dist == min_dist || $self.point[min_idx].x.abs() >= $mv_est.mv_range || $self.point[min_idx].y.abs() >= $mv_est.mv_range {
+ break;
+ }
+ best_dist = min_dist;
+ $self.update($self.steps[min_idx]);
}
best_dist = min_dist;
- $self.update($self.steps[min_idx]);
- }
- best_dist = min_dist;
- best_mv = $self.point[min_idx];
- if best_dist <= DIST_THRESH {
- return (best_mv.from_pixels(), best_dist);
- }
- for &step in REFINEMENT.iter() {
- let mv = best_mv + step;
- let dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, mv.from_pixels(), MAX_DIST);
- if best_dist > dist {
- best_dist = dist;
- best_mv = mv;
+ best_mv = $self.point[min_idx];
+ if best_dist <= $threshold {
+ return (best_mv.from_pixels(), best_dist);
+ }
+ for &step in REFINEMENT.iter() {
+ let mv = best_mv + step;
+ let dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, mv.from_pixels(), MAX_DIST);
+ if best_dist > dist {
+ best_dist = dist;
+ best_mv = mv;
+ }
+ }
+ best_mv = best_mv.from_pixels();
+ if best_dist <= $threshold {
+ return (best_mv, best_dist);
}
- }
- best_mv = best_mv.from_pixels();
- if best_dist <= DIST_THRESH {
- return (best_mv, best_dist);
}
// subpel refinement
if *dist == MAX_DIST {
*dist = $mv_est.$sad_func($cur_blk, $mb_x, $mb_y, point, cur_best_dist);
cur_best_dist = cur_best_dist.min(*dist);
- if *dist <= DIST_THRESH {
+ if *dist <= $threshold {
break;
}
}
if dist < min_dist {
min_dist = dist;
min_idx = i;
- if dist <= DIST_THRESH {
+ if dist <= $threshold {
break;
}
}
}
- if min_dist <= DIST_THRESH || min_idx == 0 || best_dist == min_dist || $self.point[min_idx].x.abs() >= $mv_est.mv_range * 4 || $self.point[min_idx].y.abs() >= $mv_est.mv_range * 4 {
+ if min_dist <= $threshold || min_idx == 0 || best_dist == min_dist || $self.point[min_idx].x.abs() >= $mv_est.mv_range * 8 || $self.point[min_idx].y.abs() >= $mv_est.mv_range * 8 {
break;
}
best_dist = min_dist;
}
best_dist = min_dist;
best_mv = $self.point[min_idx];
- if best_dist <= DIST_THRESH {
+ if best_dist <= $threshold {
return (best_mv, best_dist);
}
for &step in REFINEMENT.iter() {
}
}
(best_mv, best_dist)
- })
+ });
}