1use std::fmt;
63use std::marker::PhantomData;
64use std::sync::{Arc, Mutex};
65
66use ipc_channel::ErrorKind;
67use ipc_channel::ipc::IpcSender;
68use ipc_channel::router::ROUTER;
69use malloc_size_of::{MallocSizeOf, MallocSizeOfOps};
70use serde::de::VariantAccess;
71use serde::{Deserialize, Deserializer, Serialize, Serializer};
72use servo_config::opts;
73
74use crate::generic_channel::{GenericReceiver, GenericReceiverVariants, SendError, SendResult};
75
/// Signature of the closure wrapped by a [`GenericCallback`].
///
/// The closure receives one `Result` per delivered message. The in-process path always passes
/// `Ok(value)`. `Err` values can only come from the IPC route, where the closure is handed to
/// `ipc_channel`'s router together with the receiving end of the channel.
pub type MsgCallback<T> = dyn FnMut(Result<T, ipc_channel::Error>) + Send;
83
/// A handle through which values of type `T` are delivered to a closure (or receiver) owned by
/// the process that created the handle.
///
/// The handle can be cloned, serialized and deserialized, so it can be passed to other threads
/// or, in IPC mode, to other processes. Values are delivered with [`GenericCallback::send`].
pub struct GenericCallback<T>(GenericCallbackVariants<T>)
where
    T: Serialize + Send + 'static;
92
/// The two transports a [`GenericCallback`] can use.
enum GenericCallbackVariants<T>
where
    T: Serialize + Send + 'static,
{
    /// Values are serialized and sent over an IPC channel.
    ///
    /// The receiving half is either routed to the closure through the global `ROUTER` (see
    /// `GenericCallback::new`) or returned to the caller (see `GenericCallback::new_blocking`).
    CrossProcess(IpcSender<T>),
    /// The closure is shared directly and invoked synchronously by `GenericCallback::send`.
    InProcess(Arc<Mutex<MsgCallback<T>>>),
}
100
101impl<T> Clone for GenericCallback<T>
102where
103 T: Serialize + Send + 'static,
104{
105 fn clone(&self) -> Self {
106 let variant = match &self.0 {
107 GenericCallbackVariants::CrossProcess(sender) => {
108 GenericCallbackVariants::CrossProcess((*sender).clone())
109 },
110 GenericCallbackVariants::InProcess(callback) => {
111 GenericCallbackVariants::InProcess(callback.clone())
112 },
113 };
114 GenericCallback(variant)
115 }
116}
117
118impl<T> MallocSizeOf for GenericCallback<T>
119where
120 T: Serialize + Send + 'static,
121{
122 fn size_of(&self, _ops: &mut MallocSizeOfOps) -> usize {
123 0
124 }
125}
126
127impl<T> GenericCallback<T>
128where
129 T: for<'de> Deserialize<'de> + Serialize + Send + 'static,
130{
131 pub fn new<F: FnMut(Result<T, ipc_channel::Error>) + Send + 'static>(
135 callback: F,
136 ) -> Result<Self, ipc_channel::Error> {
137 let generic_callback = if opts::get().multiprocess || opts::get().force_ipc {
138 let (ipc_sender, ipc_receiver) = ipc_channel::ipc::channel()?;
139 ROUTER.add_typed_route(ipc_receiver, Box::new(callback));
140 GenericCallback(GenericCallbackVariants::CrossProcess(ipc_sender))
141 } else {
142 let callback = Arc::new(Mutex::new(callback));
143 GenericCallback(GenericCallbackVariants::InProcess(callback))
144 };
145 Ok(generic_callback)
146 }
147
148 pub fn new_blocking() -> Result<(Self, GenericReceiver<T>), ipc_channel::Error> {
150 if opts::get().multiprocess || opts::get().force_ipc {
151 let (sender, receiver) = ipc_channel::ipc::channel()?;
152 let generic_callback = GenericCallback(GenericCallbackVariants::CrossProcess(sender));
153 let receiver = GenericReceiver(GenericReceiverVariants::Ipc(receiver));
154 Ok((generic_callback, receiver))
155 } else {
156 let (sender, receiver) = crossbeam_channel::bounded(1);
157 let callback = Arc::new(Mutex::new(move |msg| {
158 if sender.send(msg).is_err() {
159 log::error!("Error in callback");
160 }
161 }));
162 let generic_callback = GenericCallback(GenericCallbackVariants::InProcess(callback));
163 let receiver = GenericReceiver(GenericReceiverVariants::Crossbeam(receiver));
164 Ok((generic_callback, receiver))
165 }
166 }
167
168 pub fn send(&self, value: T) -> SendResult {
174 match &self.0 {
175 GenericCallbackVariants::CrossProcess(sender) => {
176 sender.send(value).map_err(|error| match *error {
177 ErrorKind::Io(_) => SendError::Disconnected,
178 serialization_error => {
179 SendError::SerializationError(serialization_error.to_string())
180 },
181 })
182 },
183 GenericCallbackVariants::InProcess(callback) => {
184 let mut cb = callback.lock().expect("poisoned");
185 (*cb)(Ok(value));
186 Ok(())
187 },
188 }
189 }
190}
191
192impl<T> Serialize for GenericCallback<T>
193where
194 T: Serialize + Send + 'static,
195{
196 fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
197 match &self.0 {
198 GenericCallbackVariants::CrossProcess(sender) => {
199 s.serialize_newtype_variant("GenericCallback", 0, "CrossProcess", sender)
200 },
201 GenericCallbackVariants::InProcess(wrapped_callback) => {
208 if opts::get().multiprocess {
209 return Err(serde::ser::Error::custom(
210 "InProcess callback can't be serialized in multiprocess mode",
211 ));
212 }
213 let cloned_callback = Box::new(wrapped_callback.clone());
217 let sender_clone_addr = Box::leak(cloned_callback) as *mut Arc<_> as usize;
218 s.serialize_newtype_variant("GenericCallback", 1, "InProcess", &sender_clone_addr)
219 },
220 }
221 }
222}
223
/// serde visitor that rebuilds a [`GenericCallback`] from its serialized enum form.
struct GenericCallbackVisitor<T> {
    // Only ties the visitor to the payload type `T`; no `T` is stored.
    marker: PhantomData<T>,
}
227
228impl<'de, T> serde::de::Visitor<'de> for GenericCallbackVisitor<T>
229where
230 T: Serialize + Deserialize<'de> + Send + 'static,
231{
232 type Value = GenericCallback<T>;
233
234 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
235 formatter.write_str("a GenericCallback variant")
236 }
237
238 fn visit_enum<A>(self, data: A) -> Result<Self::Value, A::Error>
239 where
240 A: serde::de::EnumAccess<'de>,
241 {
242 #[derive(Deserialize)]
243 enum GenericCallbackVariantNames {
244 CrossProcess,
245 InProcess,
246 }
247
248 let (variant_name, variant_data): (GenericCallbackVariantNames, _) = data.variant()?;
249
250 match variant_name {
251 GenericCallbackVariantNames::CrossProcess => variant_data
252 .newtype_variant::<IpcSender<T>>()
253 .map(|sender| GenericCallback(GenericCallbackVariants::CrossProcess(sender))),
254 GenericCallbackVariantNames::InProcess => {
255 if opts::get().multiprocess {
256 return Err(serde::de::Error::custom(
257 "InProcess callback found in multiprocess mode",
258 ));
259 }
260 let addr = variant_data.newtype_variant::<usize>()?;
261 let ptr = addr as *mut Arc<Mutex<_>>;
262 #[expect(unsafe_code)]
268 let callback = unsafe { Box::from_raw(ptr) };
269 Ok(GenericCallback(GenericCallbackVariants::InProcess(
270 *callback,
271 )))
272 },
273 }
274 }
275}
276
277impl<'a, T> Deserialize<'a> for GenericCallback<T>
278where
279 T: Serialize + Deserialize<'a> + Send + 'static,
280{
281 fn deserialize<D>(d: D) -> Result<GenericCallback<T>, D::Error>
282 where
283 D: Deserializer<'a>,
284 {
285 d.deserialize_enum(
286 "GenericCallback",
287 &["CrossProcess", "InProcess"],
288 GenericCallbackVisitor {
289 marker: PhantomData,
290 },
291 )
292 }
293}
294
295impl<T> fmt::Debug for GenericCallback<T>
296where
297 T: Serialize + Send + 'static,
298{
299 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
300 write!(f, "GenericCallback(..)")
301 }
302}
303
#[cfg(test)]
mod single_process_callback_test {
    //! Tests for `GenericCallback` under the `opts` the test binary starts with. As the module
    //! name suggests, these are presumably the single-process defaults.

    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use crate::generic_channel::GenericCallback;

    /// Sending from another thread runs the closure, and the send itself succeeds.
    #[test]
    fn generic_callback() {
        let number = Arc::new(AtomicUsize::new(0));
        let number_clone = number.clone();
        let callback = move |msg: Result<usize, ipc_channel::Error>| {
            number_clone.store(msg.unwrap(), Ordering::SeqCst)
        };
        let generic_callback = GenericCallback::new(callback).unwrap();
        std::thread::scope(|s| {
            // Unwrap so a failed send fails the test here instead of being silently dropped.
            s.spawn(move || generic_callback.send(42).unwrap());
        });
        assert_eq!(number.load(Ordering::SeqCst), 42);
    }

    /// A callback passed through a `GenericSender` still reaches the original closure.
    #[test]
    fn generic_callback_via_generic_sender() {
        let number = Arc::new(AtomicUsize::new(0));
        let number_clone = number.clone();
        let callback = move |msg: Result<usize, ipc_channel::Error>| {
            number_clone.store(msg.unwrap(), Ordering::SeqCst)
        };
        let generic_callback = GenericCallback::new(callback).unwrap();
        let (tx, rx) = crate::generic_channel::channel().unwrap();

        tx.send(generic_callback).unwrap();
        std::thread::scope(|s| {
            s.spawn(move || {
                let callback = rx.recv().unwrap();
                callback.send(42).unwrap();
            });
        });
        assert_eq!(number.load(Ordering::SeqCst), 42);
    }

    /// A callback passed through a raw `IpcSender` still reaches the original closure.
    #[test]
    fn generic_callback_via_ipc_sender() {
        let number = Arc::new(AtomicUsize::new(0));
        let number_clone = number.clone();
        let callback = move |msg: Result<usize, ipc_channel::Error>| {
            number_clone.store(msg.unwrap(), Ordering::SeqCst)
        };
        let generic_callback = GenericCallback::new(callback).unwrap();
        let (tx, rx) = ipc_channel::ipc::channel().unwrap();

        tx.send(generic_callback).unwrap();
        std::thread::scope(|s| {
            s.spawn(move || {
                let callback = rx.recv().unwrap();
                callback.send(42).unwrap();
            });
        });
        assert_eq!(number.load(Ordering::SeqCst), 42);
    }

    /// `recv` on the blocking receiver waits for a value sent later from another thread.
    #[test]
    fn generic_callback_blocking() {
        let (callback, receiver) = GenericCallback::new_blocking().unwrap();
        let sender_thread = std::thread::spawn(move || {
            // Delay the send so that `recv` below actually has to block.
            std::thread::sleep(std::time::Duration::from_millis(100));
            callback.send(42)
        });
        let received = receiver.recv();
        // Join before checking the value so a failure on the sending side is reported
        // directly rather than as an opaque receive error.
        sender_thread
            .join()
            .expect("sender thread panicked")
            .expect("sending through the callback failed");
        assert_eq!(received.unwrap(), 42);
    }
}