base/generic_channel/
lazy_callback.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
4
5//! # Lazy Callbacks
6//!
7//! When constructing callbacks we sometimes have a large distance between where the channel for the callback
8//! is created and where the initial callback will be created. Refactoring of this code is sometimes not possible.
//! Here we provide [LazyCallback]. Use [lazy_callback()] to create a [LazyCallback] together with a [CallbackSetter].
//! The [LazyCallback] works like a [GenericCallback] and can be used to execute callbacks in the receiver process.
//! The [CallbackSetter] has a single consuming method, [CallbackSetter::set_callback], which sets the callback that the
//! [LazyCallback] will then execute on messages sent to it.
13//!
//! In single-process mode this is achieved by giving the [LazyCallback] a back channel over which the
//! [CallbackSetter] delivers the [GenericCallback]; the first send waits on that channel.
//! Hence, this is slightly less efficient than a plain [GenericCallback].
16
17use std::cell::{OnceCell, RefCell};
18use std::fmt;
19use std::marker::PhantomData;
20
21use ipc_channel::ipc::{IpcReceiver, IpcSender};
22use ipc_channel::router::ROUTER;
23use malloc_size_of::{MallocSizeOf as MallocSizeOfTrait, MallocSizeOfOps};
24use malloc_size_of_derive::MallocSizeOf;
25use serde::de::VariantAccess;
26use serde::{Deserialize, Deserializer, Serialize, Serializer};
27use servo_config::opts;
28
29use crate::generic_channel::{GenericCallback, SendError, SendResult, use_ipc};
30
/// Sending half returned by [`lazy_callback`].
///
/// Values passed to [`LazyCallback::send`] are delivered to the callback that is installed
/// later through the paired [`CallbackSetter`]. Whether this is backed by an in-process
/// channel or an IPC channel is decided when the pair is created.
#[derive(MallocSizeOf)]
pub struct LazyCallback<T: Serialize + for<'de> Deserialize<'de> + Send + 'static>(
    LazyCallbackVariants<T>,
);
36
/// Backing implementation of [`LazyCallback`], chosen by [`lazy_callback`].
enum LazyCallbackVariants<T>
where
    T: Serialize + Send + 'static,
{
    /// Single-process mode: the [`GenericCallback`] arrives from the [`CallbackSetter`] over
    /// `callback_receiver` and is cached in `callback` by the first `send`.
    InProcess {
        /// Receives the callback; taken out (left as `None`) by the first `send` that finds
        /// `callback` still empty.
        callback_receiver: RefCell<Option<crossbeam_channel::Receiver<GenericCallback<T>>>>,
        /// The callback once it has been received.
        callback: OnceCell<GenericCallback<T>>,
    },
    /// Multi-process mode: values are sent over IPC to the receiver owned by the
    /// [`CallbackSetter`].
    Ipc(IpcSender<T>),
}
47
48impl<T> MallocSizeOfTrait for LazyCallbackVariants<T>
49where
50    T: Serialize + Send + 'static,
51{
52    fn size_of(&self, ops: &mut MallocSizeOfOps) -> usize {
53        match self {
54            LazyCallbackVariants::InProcess {
55                callback_receiver,
56                callback,
57            } => callback_receiver.size_of(ops) + callback.size_of(ops),
58            LazyCallbackVariants::Ipc(_) => 0,
59        }
60    }
61}
62
63impl<T> LazyCallback<T>
64where
65    T: Serialize + for<'de> Deserialize<'de> + Send + 'static,
66{
67    /// Send messages to the callback. This might block until the callback is set via the 'CallbackSetter'
68    pub fn send(&self, value: T) -> SendResult {
69        match &self.0 {
70            LazyCallbackVariants::InProcess {
71                callback_receiver,
72                callback,
73            } => {
74                if let Some(cb) = callback.get() {
75                    cb.send(value)
76                } else {
77                    // Init callback
78                    if let Ok(cb) = callback_receiver.borrow_mut().take().unwrap().recv() {
79                        let _ = callback.set(cb);
80                        callback.get().unwrap().send(value)
81                    } else {
82                        log::error!("Could not get callback. Callback_receiver already dropped");
83                        SendResult::Err(SendError::Disconnected)
84                    }
85                }
86            },
87            LazyCallbackVariants::Ipc(ipc_sender) => {
88                ipc_sender.send(value).map_err(|error| match error {
89                    ipc_channel::IpcError::SerializationError(ser_de_error) => {
90                        SendError::SerializationError(ser_de_error.to_string())
91                    },
92                    ipc_channel::IpcError::Io(_) | ipc_channel::IpcError::Disconnected => {
93                        SendError::Disconnected
94                    },
95                })
96            },
97        }
98    }
99}
100
/// Receiving half returned by [`lazy_callback`].
///
/// Call [`CallbackSetter::set_callback`] (once; it consumes the setter) to install the
/// callback that the paired [`LazyCallback`] will invoke for every value sent to it.
pub struct CallbackSetter<T: Serialize + Send + 'static>(CallbackSetterVariants<T>);
102
103impl<T: Serialize + Send> fmt::Debug for CallbackSetter<T> {
104    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
105        f.debug_tuple("CallbackSetter").finish()
106    }
107}
108
/// Serializes a [`CallbackSetter`] as a newtype variant of enum `CallbackSetter`:
/// `Ipc` (index 0) carries the [`IpcReceiver`]; `InProcess` (index 1) carries the address of a
/// leaked, boxed clone of the crossbeam sender.
///
/// The `InProcess` form is refused when `use_ipc()` is true. The boxed sender is reclaimed by
/// the matching `Deserialize` impl via `Box::from_raw`; if the serialized value is never
/// deserialized, that allocation is leaked.
impl<T> Serialize for CallbackSetter<T>
where
    T: Serialize + Send + 'static,
{
    fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
        match &self.0 {
            // Note: despite its name, `sender` here is the `IpcReceiver` end.
            CallbackSetterVariants::Ipc(sender) => {
                s.serialize_newtype_variant("CallbackSetter", 0, "Ipc", sender)
            },
            // The only reason we need / want serialization in single-process mode is to support
            // sending CallbackSetters over existing IPC channels. This allows us to
            // incrementally port IPC channels to the GenericChannel, without needing to follow a
            // top-to-bottom approach.
            // Long-term we can remove this branch in the code again and replace it with
            // unreachable, since likely all IPC channels would be GenericChannels.
            CallbackSetterVariants::InProcess(wrapped_callback) => {
                if use_ipc() {
                    return Err(serde::ser::Error::custom(
                        "InProcess callback setter can't be serialized in multiprocess mode",
                    ));
                }
                // `serialize` only gives us `&self`, so we clone the crossbeam `Sender` to get
                // an owned value we can leak.
                // We additionally need to Box it to get a thin pointer that fits in a `usize`.
                let cloned_callback = Box::new(wrapped_callback.clone());
                let sender_clone_addr = Box::leak(cloned_callback) as *mut _ as usize;
                s.serialize_newtype_variant("CallbackSetter", 1, "InProcess", &sender_clone_addr)
            },
        }
    }
}
140
/// Serde visitor that rebuilds a [`CallbackSetter`] from either of its serialized variants.
struct LazyCallbackSetterVisitor<T> {
    // Ties the visitor to the message type `T` without storing one.
    marker: PhantomData<T>,
}
144
145impl<'de, T> serde::de::Visitor<'de> for LazyCallbackSetterVisitor<T>
146where
147    T: Serialize + Deserialize<'de> + Send + 'static,
148{
149    type Value = CallbackSetter<T>;
150
151    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
152        formatter.write_str("a GenericCallback variant")
153    }
154
155    fn visit_enum<A>(self, data: A) -> Result<Self::Value, A::Error>
156    where
157        A: serde::de::EnumAccess<'de>,
158    {
159        #[derive(Deserialize)]
160        enum LazyCallbackSetterVariantNames {
161            Ipc,
162            InProcess,
163        }
164
165        let (variant_name, variant_data): (LazyCallbackSetterVariantNames, _) = data.variant()?;
166
167        match variant_name {
168            LazyCallbackSetterVariantNames::Ipc => variant_data
169                .newtype_variant::<IpcReceiver<T>>()
170                .map(|receiver| CallbackSetter(CallbackSetterVariants::Ipc(receiver))),
171            LazyCallbackSetterVariantNames::InProcess => {
172                if use_ipc() {
173                    return Err(serde::de::Error::custom(
174                        "InProcess callback found in multiprocess mode",
175                    ));
176                }
177                let addr = variant_data.newtype_variant::<usize>()?;
178                let ptr = addr as *mut _;
179                // SAFETY: We know we are in the same address space as the sender, so we can safely
180                // reconstruct the Box, that we previously leaked with `into_raw` during
181                // serialization.
182                // Attention: Code reviewers should carefully compare the deserialization here
183                // with the serialization above.
184                #[expect(unsafe_code)]
185                let callback = unsafe { Box::from_raw(ptr) };
186                Ok(CallbackSetter(CallbackSetterVariants::InProcess(*callback)))
187            },
188        }
189    }
190}
191
192impl<'a, T> Deserialize<'a> for CallbackSetter<T>
193where
194    T: Serialize + Deserialize<'a> + Send + 'static,
195{
196    fn deserialize<D>(d: D) -> Result<CallbackSetter<T>, D::Error>
197    where
198        D: Deserializer<'a>,
199    {
200        d.deserialize_enum(
201            "GenericCallback",
202            &["CrossProcess", "InProcess"],
203            LazyCallbackSetterVisitor {
204                marker: PhantomData,
205            },
206        )
207    }
208}
209
/// Backing implementation of [`CallbackSetter`], chosen by [`lazy_callback`].
enum CallbackSetterVariants<T>
where
    T: Serialize + Send + 'static,
{
    /// Single-process mode: hands the [`GenericCallback`] to the paired [`LazyCallback`].
    InProcess(crossbeam_channel::Sender<GenericCallback<T>>),
    /// Multi-process mode: receives the values sent by the paired [`LazyCallback`].
    Ipc(IpcReceiver<T>),
}
217
218impl<T> CallbackSetter<T>
219where
220    T: Serialize + for<'de> Deserialize<'de> + Send + 'static,
221{
222    /// This sets the callback.
223    pub fn set_callback<F: FnMut(Result<T, ipc_channel::IpcError>) + Send + 'static>(
224        self,
225        mut callback: F,
226    ) {
227        match self.0 {
228            CallbackSetterVariants::InProcess(sender) => {
229                let callback = GenericCallback::new(callback).expect("Could not create callback");
230                if sender.send(callback).is_err() {
231                    log::error!("Could not send callback, sender was already dropped");
232                }
233            },
234            CallbackSetterVariants::Ipc(ipc_receiver) => {
235                let new_callback = move |msg: Result<T, ipc_channel::SerDeError>| {
236                    callback(msg.map_err(|error| error.into()))
237                };
238                ROUTER.add_typed_route(ipc_receiver, Box::new(new_callback));
239            },
240        }
241    }
242}
243
244/// This function should never be exported.
245fn lazy_callback_inprocess<T>() -> (LazyCallback<T>, CallbackSetter<T>)
246where
247    T: Serialize + for<'de> Deserialize<'de> + Send + 'static,
248{
249    let (callback_sender, callback_receiver) = crossbeam_channel::bounded(1);
250    let lazycallback = LazyCallback(LazyCallbackVariants::InProcess {
251        callback_receiver: RefCell::new(Some(callback_receiver)),
252        callback: OnceCell::new(),
253    });
254
255    let callback_setter = CallbackSetter(CallbackSetterVariants::InProcess(callback_sender));
256
257    (lazycallback, callback_setter)
258}
259
260/// This function should never be exported.
261fn lazy_callback_ipc<T>() -> (LazyCallback<T>, CallbackSetter<T>)
262where
263    T: Serialize + for<'de> Deserialize<'de> + Send + 'static,
264{
265    let (sender, receiver) = ipc_channel::ipc::channel().expect("Could not create channel");
266    let callback = LazyCallback(LazyCallbackVariants::Ipc(sender));
267    let callback_setter = CallbackSetter(CallbackSetterVariants::Ipc(receiver));
268    (callback, callback_setter)
269}
270
271/// A LazyCallback is a Callback that will be initialized at a later date.
272/// We return the 'LazyCallback' which is a GenericCallback.
273/// We also return a 'CallbackSetter' where the callback can be set at a later date.
274pub fn lazy_callback<T>() -> (LazyCallback<T>, CallbackSetter<T>)
275where
276    T: Serialize + for<'de> Deserialize<'de> + Send + 'static,
277{
278    if opts::get().multiprocess || opts::get().force_ipc {
279        lazy_callback_ipc()
280    } else {
281        lazy_callback_inprocess()
282    }
283}
284
#[cfg(test)]
mod single_process_callback_test {
    use crate::generic_channel::lazy_callback::{lazy_callback_inprocess, lazy_callback_ipc};
    use crate::generic_channel::{CallbackSetter, LazyCallback};

    /// Sends `true` through `callback` on one thread while another thread installs the
    /// callback after a one-second delay, so the send most likely starts before the
    /// callback exists. Checks that the value is still delivered to the callback.
    fn test_lazy_callback(callback: LazyCallback<bool>, callback_setter: CallbackSetter<bool>) {
        let sending_thread = std::thread::spawn(move || {
            callback.send(true).expect("Could not send");
        });

        let (sender, receiver) = crossbeam_channel::bounded(1);
        let setting_thread = std::thread::spawn(move || {
            std::thread::sleep(std::time::Duration::from_secs(1));
            callback_setter.set_callback(move |value| {
                sender.send(value).expect("Could not send");
            });
        });

        sending_thread.join().expect("error joining thread");
        setting_thread.join().expect("error joining thread");
        let delivered = receiver
            .recv()
            .expect("callback was never invoked")
            .expect("callback received an error");
        assert!(delivered);
    }

    #[test]
    fn lazy_callback_simple_inprocess() {
        let (callback, callback_setter) = lazy_callback_inprocess();
        test_lazy_callback(callback, callback_setter);
    }

    #[test]
    fn lazy_callback_simple_ipc() {
        let (callback, callback_setter) = lazy_callback_ipc();
        test_lazy_callback(callback, callback_setter);
    }
}