gstreamer/subclass/
task_pool.rs

1// Take a look at the license at the top of the repository in the LICENSE file.
2
3use std::{
4    hash::{Hash, Hasher},
5    ptr,
6    sync::{Arc, Mutex},
7};
8
9use glib::{ffi::gpointer, prelude::*, subclass::prelude::*, translate::*};
10
11use super::prelude::*;
12use crate::{TaskHandle, TaskPool, ffi};
13
14pub trait TaskPoolImpl: GstObjectImpl + ObjectSubclass<Type: IsA<TaskPool>> {
15    // rustdoc-stripper-ignore-next
16    /// Handle to be returned from the `push` function to allow the caller to wait for the task's
17    /// completion.
18    ///
19    /// If unneeded, you can specify `()` or [`Infallible`](std::convert::Infallible) for a handle
20    /// that does nothing on `join` or drop.
21    type Handle: TaskHandle;
22
23    // rustdoc-stripper-ignore-next
24    /// Prepare the task pool to accept tasks.
25    ///
26    /// This defaults to doing nothing.
27    fn prepare(&self) -> Result<(), glib::Error> {
28        Ok(())
29    }
30
31    // rustdoc-stripper-ignore-next
32    /// Clean up, rejecting further tasks and waiting for all accepted tasks to be stopped.
33    ///
34    /// This is mainly used internally to ensure proper cleanup of internal data structures in test
35    /// suites.
36    fn cleanup(&self) {}
37
38    // rustdoc-stripper-ignore-next
39    /// Deliver a task to the pool.
40    ///
41    /// If returning `Ok`, you need to call the `func` eventually.
42    ///
43    /// If returning `Err`, the `func` must be dropped without calling it.
44    fn push(&self, func: TaskPoolFunction) -> Result<Option<Self::Handle>, glib::Error>;
45}
46
47unsafe impl<T: TaskPoolImpl> IsSubclassable<T> for TaskPool {
48    fn class_init(klass: &mut glib::Class<Self>) {
49        Self::parent_class_init::<T>(klass);
50        let klass = klass.as_mut();
51        klass.prepare = Some(task_pool_prepare::<T>);
52        klass.cleanup = Some(task_pool_cleanup::<T>);
53        klass.push = Some(task_pool_push::<T>);
54        klass.join = Some(task_pool_join::<T>);
55
56        #[cfg(feature = "v1_20")]
57        {
58            klass.dispose_handle = Some(task_pool_dispose_handle::<T>);
59        }
60    }
61}
62
63unsafe extern "C" fn task_pool_prepare<T: TaskPoolImpl>(
64    ptr: *mut ffi::GstTaskPool,
65    error: *mut *mut glib::ffi::GError,
66) {
67    unsafe {
68        let instance = &*(ptr as *mut T::Instance);
69        let imp = instance.imp();
70
71        match imp.prepare() {
72            Ok(()) => {}
73            Err(err) => {
74                if !error.is_null() {
75                    *error = err.into_glib_ptr();
76                }
77            }
78        }
79    }
80}
81
82unsafe extern "C" fn task_pool_cleanup<T: TaskPoolImpl>(ptr: *mut ffi::GstTaskPool) {
83    unsafe {
84        let instance = &*(ptr as *mut T::Instance);
85        let imp = instance.imp();
86
87        imp.cleanup();
88    }
89}
90
91unsafe extern "C" fn task_pool_push<T: TaskPoolImpl>(
92    ptr: *mut ffi::GstTaskPool,
93    func: ffi::GstTaskPoolFunction,
94    user_data: gpointer,
95    error: *mut *mut glib::ffi::GError,
96) -> gpointer {
97    unsafe {
98        let instance = &*(ptr as *mut T::Instance);
99        let imp = instance.imp();
100
101        let func = TaskPoolFunction::new(func.expect("Tried to push null func"), user_data);
102
103        match imp.push(func.clone()) {
104            Ok(None) => ptr::null_mut(),
105            Ok(Some(handle)) => Box::into_raw(Box::new(handle)) as gpointer,
106            Err(err) => {
107                func.prevent_call();
108                if !error.is_null() {
109                    *error = err.into_glib_ptr();
110                }
111                ptr::null_mut()
112            }
113        }
114    }
115}
116
117unsafe extern "C" fn task_pool_join<T: TaskPoolImpl>(ptr: *mut ffi::GstTaskPool, id: gpointer) {
118    unsafe {
119        if id.is_null() {
120            let wrap: Borrowed<TaskPool> = from_glib_borrow(ptr);
121            crate::warning!(
122                crate::CAT_RUST,
123                obj = wrap.as_ref(),
124                "Tried to join null handle"
125            );
126            return;
127        }
128
129        let handle = Box::from_raw(id as *mut T::Handle);
130        handle.join();
131    }
132}
133
// Trampoline for the `dispose_handle` vfunc.
//
// Releases the handle behind a non-null `id` through its `Drop` implementation, without calling
// `join`. A null `id` is only reported as a warning.
#[cfg(feature = "v1_20")]
#[cfg_attr(docsrs, doc(cfg(feature = "v1_20")))]
unsafe extern "C" fn task_pool_dispose_handle<T: TaskPoolImpl>(
    ptr: *mut ffi::GstTaskPool,
    id: gpointer,
) {
    if id.is_null() {
        // SAFETY: `ptr` is the task pool instance this vfunc was invoked on.
        let pool: Borrowed<TaskPool> = unsafe { from_glib_borrow(ptr) };
        crate::warning!(
            crate::CAT_RUST,
            obj = pool.as_ref(),
            "Tried to dispose null handle"
        );
        return;
    }

    // SAFETY: assumes `id` was returned by `task_pool_push` on this pool (whose non-null results
    // are leaked `Box<T::Handle>`s) and has not been joined or disposed yet.
    drop(unsafe { Box::from_raw(id as *mut T::Handle) });
}
155
156// rustdoc-stripper-ignore-next
157/// Function the task pool should execute, provided to [`push`](TaskPoolImpl::push).
158#[derive(Debug)]
159pub struct TaskPoolFunction(Arc<Mutex<Option<TaskPoolFunctionInner>>>);
160
161// `Arc<Mutex<Option<…>>>` is required so that we can enforce that the function
162// has not been called and will never be called after `push` returns `Err`.
163
164#[derive(Debug)]
165struct TaskPoolFunctionInner {
166    func: unsafe extern "C" fn(gpointer),
167    user_data: gpointer,
168    warn_on_drop: bool,
169}
170
171unsafe impl Send for TaskPoolFunctionInner {}
172
173impl TaskPoolFunction {
174    fn new(func: unsafe extern "C" fn(gpointer), user_data: gpointer) -> Self {
175        let inner = TaskPoolFunctionInner {
176            func,
177            user_data,
178            warn_on_drop: true,
179        };
180        Self(Arc::new(Mutex::new(Some(inner))))
181    }
182
183    #[inline]
184    fn clone(&self) -> Self {
185        Self(self.0.clone())
186    }
187
188    // rustdoc-stripper-ignore-next
189    /// Consume and execute the function.
190    pub fn call(self) {
191        let mut inner = self
192            .0
193            .lock()
194            .unwrap()
195            .take()
196            .expect("TaskPoolFunction has already been dropped");
197        inner.warn_on_drop = false;
198        unsafe { (inner.func)(inner.user_data) }
199    }
200
201    fn prevent_call(self) {
202        let mut inner = self
203            .0
204            .lock()
205            .unwrap()
206            .take()
207            .expect("TaskPoolFunction has already been called");
208        inner.warn_on_drop = false;
209        drop(inner);
210    }
211
212    #[inline]
213    fn as_ptr(&self) -> *const Mutex<Option<TaskPoolFunctionInner>> {
214        Arc::as_ptr(&self.0)
215    }
216}
217
218impl Drop for TaskPoolFunctionInner {
219    fn drop(&mut self) {
220        if self.warn_on_drop {
221            crate::warning!(crate::CAT_RUST, "Leaked task function");
222        }
223    }
224}
225
226impl PartialEq for TaskPoolFunction {
227    fn eq(&self, other: &Self) -> bool {
228        self.as_ptr() == other.as_ptr()
229    }
230}
231
232impl Eq for TaskPoolFunction {}
233
234impl PartialOrd for TaskPoolFunction {
235    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
236        Some(self.cmp(other))
237    }
238}
239
240impl Ord for TaskPoolFunction {
241    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
242        self.as_ptr().cmp(&other.as_ptr())
243    }
244}
245
246impl Hash for TaskPoolFunction {
247    fn hash<H: Hasher>(&self, state: &mut H) {
248        self.as_ptr().hash(state)
249    }
250}
251
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            atomic,
            mpsc::{TryRecvError, channel},
        },
        thread,
    };

    use super::*;
    use crate::prelude::*;

    pub mod imp {
        use super::*;

        // Test pool that records whether `prepare` and `cleanup` ran, and runs every pushed task
        // on its own freshly spawned thread.
        #[derive(Default)]
        pub struct TestPool {
            pub(super) prepared: atomic::AtomicBool,
            pub(super) cleaned_up: atomic::AtomicBool,
        }

        #[glib::object_subclass]
        impl ObjectSubclass for TestPool {
            const NAME: &'static str = "TestPool";
            type Type = super::TestPool;
            type ParentType = TaskPool;
        }

        impl ObjectImpl for TestPool {}

        impl GstObjectImpl for TestPool {}

        impl TaskPoolImpl for TestPool {
            type Handle = TestHandle;

            fn prepare(&self) -> Result<(), glib::Error> {
                self.prepared.store(true, atomic::Ordering::SeqCst);
                Ok(())
            }

            fn cleanup(&self) {
                self.cleaned_up.store(true, atomic::Ordering::SeqCst);
            }

            fn push(&self, func: TaskPoolFunction) -> Result<Option<Self::Handle>, glib::Error> {
                let worker = thread::spawn(move || func.call());
                Ok(Some(TestHandle(worker)))
            }
        }

        // Handle wrapping the worker thread; joining it waits for the task to finish.
        pub struct TestHandle(thread::JoinHandle<()>);

        impl TaskHandle for TestHandle {
            fn join(self) {
                let Self(worker) = self;
                worker.join().unwrap();
            }
        }
    }

    glib::wrapper! {
        pub struct TestPool(ObjectSubclass<imp::TestPool>) @extends TaskPool, crate::Object;
    }

    unsafe impl Send for TestPool {}
    unsafe impl Sync for TestPool {}

    impl TestPool {
        pub fn new() -> Self {
            Self::default()
        }
    }

    impl Default for TestPool {
        fn default() -> Self {
            glib::Object::new()
        }
    }

    #[test]
    fn test_simple_subclass() {
        crate::init().unwrap();

        let pool = TestPool::new();
        pool.prepare().unwrap();

        let (tx, rx) = channel();

        let handle = pool
            .push(move || {
                tx.send(()).unwrap();
            })
            .unwrap()
            .unwrap();

        // The task runs on a worker thread and reports back through the channel.
        assert_eq!(rx.recv(), Ok(()));

        // After joining, the sender must have been dropped together with the task closure.
        handle.join();
        assert_eq!(rx.try_recv(), Err(TryRecvError::Disconnected));

        pool.cleanup();

        let state = pool.imp();
        assert!(state.prepared.load(atomic::Ordering::SeqCst));
        assert!(state.cleaned_up.load(atomic::Ordering::SeqCst));
    }
}