h264: add multi-threaded decoder
[nihav.git] / nihav-itu / src / codecs / h264 / dispatch.rs
CommitLineData
11d7aef2
KS
1use std::sync::{Arc, Barrier};
2use std::sync::atomic::*;
3use std::thread;
4
5use nihav_core::codecs::{DecoderError, DecoderResult};
6
7use super::{FrameDecoder, PictureInfo, Shareable};
8
/// Availability of a frame's decoded data, as reported by `ThreadDispatcher::check_pos`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FrameDecodingStatus {
    /// The requested macroblock position is already decoded (or the whole frame is).
    Ok,
    /// The frame is still being decoded and has not reached the requested position.
    NotReady,
    /// Decoding of the frame failed.
    Error,
    /// No frame with the requested id is tracked by the dispatcher.
    NotFound,
}
16
17struct FrameState {
18 pinfo: PictureInfo,
19 mb_pos: AtomicUsize,
20 error: AtomicBool,
21 complete: AtomicBool,
22 output: AtomicBool,
23 worker: Option<thread::JoinHandle<DecoderResult<()>>>,
24 result: DecoderResult<()>,
25 num_refs: usize,
26 ref_frames: Vec<u32>,
27}
28
29impl FrameState {
30 fn get_id(&self) -> u32 { self.pinfo.full_id }
31 fn get_user_id(&self) -> u32 { self.pinfo.user_id }
32 fn is_working(&self) -> bool {
33 self.worker.is_some() &&
34 !self.complete.load(Ordering::Relaxed) &&
35 !self.error.load(Ordering::Relaxed)
36 }
37 fn is_output_candidate(&self) -> bool {
38 !self.output.load(Ordering::Relaxed) &&
39 (self.complete.load(Ordering::Relaxed) || self.error.load(Ordering::Relaxed))
40 }
41}
42
43pub struct ThreadDispatcher {
44 fstate: Vec<FrameState>,
45 pub max_threads: usize,
46 cur_threads: usize,
47}
48
49impl ThreadDispatcher {
50 pub fn new() -> Self {
51 Self {
52 fstate: Vec::new(),
53 max_threads: 3,
54 cur_threads: 0,
55 }
56 }
57 pub fn can_decode_more(&self) -> bool {
58 let out_cand = self.fstate.iter().filter(|state| state.is_output_candidate()).count();
59 if out_cand > self.max_threads {
60 return false;
61 }
62 if (self.cur_threads < self.max_threads) || (self.max_threads == 0) {
63 true
64 } else {
65 let real_workers = self.fstate.iter().fold(0usize,
66 |acc, state| acc + (state.is_working() as usize));
67 real_workers < self.max_threads
68 }
69 }
70 fn cleanup(&mut self) {
71 for state in self.fstate.iter_mut() {
72 if state.worker.is_some() && !state.is_working() {
73 let mut ret = None;
74 std::mem::swap(&mut state.worker, &mut ret);
75 if let Some(handle) = ret {
76 state.result = handle.join().unwrap();
77 }
78 self.cur_threads -= 1;
79 }
80 }
81 }
82 fn unref_frame(&mut self, id: u32) {
83 let mut toremove = Vec::new();
84 for state in self.fstate.iter() {
85 if state.num_refs == 0 && state.output.load(Ordering::Relaxed) {
86 toremove.push(state.get_id());
87 }
88 }
89 if let Some(idx) = self.find_by_id(id) {
90 let mut ref_frm = Vec::new();
91 std::mem::swap(&mut ref_frm, &mut self.fstate[idx].ref_frames);
92 for state in self.fstate.iter_mut() {
93 if ref_frm.contains(&state.get_id()) {
94 assert!(state.num_refs >= 2);
95 state.num_refs -= 2;
96 }
97 }
98 if self.fstate[idx].num_refs == 0 && self.fstate[idx].output.load(Ordering::Relaxed) {
99 self.remove_frame(id);
100 }
101 }
102 for &id in toremove.iter() {
103 self.remove_frame(id);
104 }
105 }
106 fn find_by_id(&self, id: u32) -> Option<usize> {
107 self.fstate.iter().position(|x| x.get_id() == id)
108 }
109 fn set_completed(&self, id: u32) {
110 if let Some(idx) = self.find_by_id(id) {
111 self.fstate[idx].complete.store(true, Ordering::Relaxed);
112 }
113 }
114 fn set_error(&self, id: u32) {
115 if let Some(idx) = self.find_by_id(id) {
116 self.fstate[idx].error.store(true, Ordering::Relaxed);
117 }
118 }
119 pub fn update_pos(&self, id: u32, mb_pos: usize) {
120 if let Some(idx) = self.find_by_id(id) {
121 self.fstate[idx].mb_pos.store(mb_pos, Ordering::Relaxed);
122 }
123 }
124 pub fn check_pos(&self, id: u32, mb_pos: usize) -> FrameDecodingStatus {
125 if let Some(idx) = self.find_by_id(id) {
126 let state = &self.fstate[idx];
127 if !state.error.load(Ordering::Relaxed) {
128 if state.complete.load(Ordering::Relaxed) || mb_pos < state.mb_pos.load(Ordering::Relaxed) {
129 FrameDecodingStatus::Ok
130 } else {
131 FrameDecodingStatus::NotReady
132 }
133 } else {
134 FrameDecodingStatus::Error
135 }
136 } else {
137 FrameDecodingStatus::NotFound
138 }
139 }
140 fn remove_frame(&mut self, id: u32) {
141 if let Some(idx) = self.find_by_id(id) {
142 self.fstate.remove(idx);
143 }
144 }
145 /*fn print_state(&self) {
146 print!(" state:");
147 for state in self.fstate.iter() {
148 print!(" s{}b{}r{}{}{}{}", state.get_id(),
149 state.mb_pos.load(Ordering::Relaxed), state.num_refs,
150 if state.error.load(Ordering::Relaxed) { "E" } else {""},
151 if state.complete.load(Ordering::Relaxed) {"C"} else {""},
152 if state.output.load(Ordering::Relaxed) {"O"} else {""});
153 }
154 println!();
155 }*/
156 pub fn has_output(&self) -> bool {
157 for state in self.fstate.iter() {
158 if state.is_output_candidate() {
159 return true;
160 }
161 }
162 false
163 }
164}
165
166pub fn queue_decoding(disp: &mut Shareable<ThreadDispatcher>, mut fdec: FrameDecoder, initial_ref_frames: &[u32], ref_frames: &[u32]) {
167 let barrier = Arc::new(Barrier::new(2));
168 let starter = Arc::clone(&barrier);
169
170 let pinfo = fdec.cur_pic.clone();
171 let pic_id = pinfo.full_id;
172 let shared_disp = Arc::clone(disp);
173 let worker = thread::Builder::new().name("frame ".to_string() + &pic_id.to_string()).spawn(move || {
174 barrier.wait();
175
176 let mut slices = Vec::new();
177 std::mem::swap(&mut slices, &mut fdec.slices);
178 let mut cur_mb = 0;
179 for (hdr, hdr_size, refs, nal) in slices.iter() {
180 if hdr.first_mb_in_slice != cur_mb {
181 if let Ok(rd) = shared_disp.read() {
182 rd.set_error(pic_id);
183 } else {
184 panic!("can't set error");
185 }
186 return Err(DecoderError::InvalidData);
187 }
188 match fdec.decode_slice(hdr, *hdr_size, refs, nal) {
189 Ok(pos) => cur_mb = pos,
190 Err(err) => {
191 if let Ok(rd) = shared_disp.read() {
192 rd.set_error(pic_id);
193 } else {
194 panic!("can't set error");
195 }
196 return Err(err);
197 },
198 };
199 }
200
201 if cur_mb == fdec.num_mbs {
202 if let Ok(rd) = shared_disp.read() {
203 rd.set_completed(pic_id);
204 } else {
205 panic!("can't set status");
206 }
207 }
208
209 DecoderResult::Ok(())
210 }).unwrap();
211 let new_state = FrameState {
212 pinfo,
213 mb_pos: AtomicUsize::new(0),
214 error: AtomicBool::new(false),
215 complete: AtomicBool::new(false),
216 output: AtomicBool::new(false),
217 worker: Some(worker),
218 result: DecoderResult::Err(DecoderError::Bug),
219 num_refs: 0,
220 ref_frames: initial_ref_frames.to_vec(),
221 };
222 if let Ok(ref mut ds) = disp.write() {
223 let new_id = new_state.get_id();
224 if ds.find_by_id(new_id).is_some() {
225 ds.remove_frame(new_id);
226 }
227 ds.cleanup();
228 ds.fstate.push(new_state);
229 for state in ds.fstate.iter_mut() {
230 if ref_frames.contains(&state.get_id()) {
231 state.num_refs += 1;
232 }
233 if initial_ref_frames.contains(&state.get_id()) {
234 state.num_refs += 1;
235 }
236 }
237 ds.cur_threads += 1;
238 starter.wait();
239 } else {
240 panic!("cannot invoke thread dispatcher");
241 }
242}
243
244pub fn wait_for_one(dispatch: &mut Shareable<ThreadDispatcher>) -> Result<PictureInfo, (DecoderError, u32)> {
245 /*if let Ok(ref ds) = dispatch.read() {
246 ds.print_state();
247 }*/
248 let start = std::time::Instant::now();
249 'main_loop: loop {
250 if std::time::Instant::now().duration_since(start) > std::time::Duration::from_millis(20000) { panic!(" too long!"); }
251 if let Ok(ref ds) = dispatch.read() {
252 let mut nw = 0;
253 for state in ds.fstate.iter() {
254 if state.is_working() {
255 nw += 1;
256 }
257 if state.is_output_candidate() {
258 break 'main_loop;
259 }
260 }
261 if nw == 0 {
262 return Err((DecoderError::NoFrame, 0));
263 }
264 } else {
265 panic!("can't peek into status");
266 }
267 thread::yield_now();
268 }
269 if let Ok(ref mut ds) = dispatch.write() {
270 ds.cleanup();
271 let mut found = None;
272 for state in ds.fstate.iter() {
273 if state.is_output_candidate() {
274 state.output.store(true, Ordering::Relaxed);
275 if let DecoderResult::Err(err) = state.result {
276 let id = state.get_id();
277 let user_id = state.get_user_id();
278 ds.unref_frame(id);
279 return Err((err, user_id));
280 } else {
281 found = Some(state.pinfo.clone());
282 break;
283 }
284 }
285 }
286 if let Some(ret) = found {
287 ds.unref_frame(ret.full_id);
288 Ok(ret)
289 } else {
290 unreachable!();
291 }
292 } else {
293 panic!("can't grab status");
294 }
295}
296
297pub fn clear_threads(dispatch: &mut Shareable<ThreadDispatcher>) {
298 /*if let Ok(ref ds) = dispatch.read() {
299 ds.print_state();
300 }*/
301 let mut to_wait = Vec::new();
302 if let Ok(ref mut ds) = dispatch.write() {
303 while let Some(state) = ds.fstate.pop() {
304 if let Some(handle) = state.worker {
305 to_wait.push(handle);
306 }
307 }
308 ds.cur_threads = 0;
309 } else {
310 panic!("can't grab status");
311 }
312 while let Some(handle) = to_wait.pop() {
313 let _ = handle.join();
314 }
315}