1use std::fmt;
63use std::marker::PhantomData;
64use std::sync::{Arc, Mutex};
65
66use ipc_channel::ipc::IpcSender;
67use ipc_channel::router::ROUTER;
68use malloc_size_of::{MallocSizeOf, MallocSizeOfOps};
69use serde::de::VariantAccess;
70use serde::{Deserialize, Deserializer, Serialize, Serializer};
71use servo_config::opts;
72
73use crate::generic_channel::{
74 GenericReceiver, GenericReceiverVariants, SendError, SendResult, use_ipc,
75};
76
/// Signature of the closure wrapped by a [`GenericCallback`].
///
/// Each invocation receives one value. In-process deliveries made by
/// [`GenericCallback::send`] always pass `Ok`; when the callback is driven by the
/// IPC router (see [`GenericCallback::new`]), a message that fails to deserialize
/// is delivered as `Err`.
pub type MsgCallback<T> = dyn FnMut(Result<T, ipc_channel::IpcError>) + Send;
84
/// A handle that delivers values of type `T` to a callback.
///
/// Depending on `use_ipc()`, the handle either wraps an [`IpcSender`] (values are
/// serialized and sent over an IPC channel) or a shared closure that
/// [`GenericCallback::send`] invokes directly on the sending thread. The handle
/// can be cloned, and can itself be serialized and sent through a channel.
pub struct GenericCallback<T>(GenericCallbackVariants<T>)
where
    T: Serialize + Send + 'static;
93
/// The two ways a [`GenericCallback`] can reach its callback.
enum GenericCallbackVariants<T>
where
    T: Serialize + Send + 'static,
{
    /// Values are sent over this IPC channel; whoever owns the receiving end
    /// (the IPC router for `GenericCallback::new`, or a `GenericReceiver` for
    /// `GenericCallback::new_blocking`) handles them.
    CrossProcess(IpcSender<T>),
    /// The callback lives in this process and `GenericCallback::send` invokes it
    /// directly while holding the mutex.
    InProcess(Arc<Mutex<MsgCallback<T>>>),
}
101
102impl<T> Clone for GenericCallback<T>
103where
104 T: Serialize + Send + 'static,
105{
106 fn clone(&self) -> Self {
107 let variant = match &self.0 {
108 GenericCallbackVariants::CrossProcess(sender) => {
109 GenericCallbackVariants::CrossProcess((*sender).clone())
110 },
111 GenericCallbackVariants::InProcess(callback) => {
112 GenericCallbackVariants::InProcess(callback.clone())
113 },
114 };
115 GenericCallback(variant)
116 }
117}
118
119impl<T> MallocSizeOf for GenericCallback<T>
120where
121 T: Serialize + Send + 'static,
122{
123 fn size_of(&self, _ops: &mut MallocSizeOfOps) -> usize {
124 0
125 }
126}
127
128impl<T> GenericCallback<T>
129where
130 T: for<'de> Deserialize<'de> + Serialize + Send + 'static,
131{
132 pub fn new<F: FnMut(Result<T, ipc_channel::IpcError>) + Send + 'static>(
136 mut callback: F,
137 ) -> Result<Self, ipc_channel::IpcError> {
138 let generic_callback = if use_ipc() {
139 let (ipc_sender, ipc_receiver) = ipc_channel::ipc::channel()?;
140 let new_callback = move |msg: Result<T, ipc_channel::SerDeError>| {
141 callback(msg.map_err(|error| error.into()))
142 };
143 ROUTER.add_typed_route(ipc_receiver, Box::new(new_callback));
144 GenericCallback(GenericCallbackVariants::CrossProcess(ipc_sender))
145 } else {
146 let callback = Arc::new(Mutex::new(callback));
147 GenericCallback(GenericCallbackVariants::InProcess(callback))
148 };
149 Ok(generic_callback)
150 }
151
152 pub fn new_blocking() -> Result<(Self, GenericReceiver<T>), ipc_channel::IpcError> {
154 if use_ipc() {
155 let (sender, receiver) = ipc_channel::ipc::channel()?;
156 let generic_callback = GenericCallback(GenericCallbackVariants::CrossProcess(sender));
157 let receiver = GenericReceiver(GenericReceiverVariants::Ipc(receiver));
158 Ok((generic_callback, receiver))
159 } else {
160 let (sender, receiver) = crossbeam_channel::bounded(1);
161 let callback = Arc::new(Mutex::new(move |msg| {
162 if sender.send(msg).is_err() {
163 log::error!("Error in callback");
164 }
165 }));
166 let generic_callback = GenericCallback(GenericCallbackVariants::InProcess(callback));
167 let receiver = GenericReceiver(GenericReceiverVariants::Crossbeam(receiver));
168 Ok((generic_callback, receiver))
169 }
170 }
171
172 pub fn send(&self, value: T) -> SendResult {
178 match &self.0 {
179 GenericCallbackVariants::CrossProcess(sender) => {
180 sender.send(value).map_err(|error| match error {
181 ipc_channel::IpcError::SerializationError(ser_de_error) => {
182 SendError::SerializationError(ser_de_error.to_string())
183 },
184 ipc_channel::IpcError::Io(_) | ipc_channel::IpcError::Disconnected => {
185 SendError::Disconnected
186 },
187 })
188 },
189 GenericCallbackVariants::InProcess(callback) => {
190 let mut cb = callback.lock().expect("poisoned");
191 (*cb)(Ok(value));
192 Ok(())
193 },
194 }
195 }
196}
197
198impl<T> Serialize for GenericCallback<T>
199where
200 T: Serialize + Send + 'static,
201{
202 fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
203 match &self.0 {
204 GenericCallbackVariants::CrossProcess(sender) => {
205 s.serialize_newtype_variant("GenericCallback", 0, "CrossProcess", sender)
206 },
207 GenericCallbackVariants::InProcess(wrapped_callback) => {
214 if opts::get().multiprocess {
215 return Err(serde::ser::Error::custom(
216 "InProcess callback can't be serialized in multiprocess mode",
217 ));
218 }
219 let cloned_callback = Box::new(wrapped_callback.clone());
223 let sender_clone_addr = Box::leak(cloned_callback) as *mut Arc<_> as usize;
224 s.serialize_newtype_variant("GenericCallback", 1, "InProcess", &sender_clone_addr)
225 },
226 }
227 }
228}
229
/// Serde visitor that rebuilds a [`GenericCallback`] from the enum representation
/// written by its `Serialize` impl.
struct GenericCallbackVisitor<T> {
    /// Ties the visitor to `T` without storing a value of that type.
    marker: PhantomData<T>,
}
233
234impl<'de, T> serde::de::Visitor<'de> for GenericCallbackVisitor<T>
235where
236 T: Serialize + Deserialize<'de> + Send + 'static,
237{
238 type Value = GenericCallback<T>;
239
240 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
241 formatter.write_str("a GenericCallback variant")
242 }
243
244 fn visit_enum<A>(self, data: A) -> Result<Self::Value, A::Error>
245 where
246 A: serde::de::EnumAccess<'de>,
247 {
248 #[derive(Deserialize)]
249 enum GenericCallbackVariantNames {
250 CrossProcess,
251 InProcess,
252 }
253
254 let (variant_name, variant_data): (GenericCallbackVariantNames, _) = data.variant()?;
255
256 match variant_name {
257 GenericCallbackVariantNames::CrossProcess => variant_data
258 .newtype_variant::<IpcSender<T>>()
259 .map(|sender| GenericCallback(GenericCallbackVariants::CrossProcess(sender))),
260 GenericCallbackVariantNames::InProcess => {
261 if use_ipc() {
262 return Err(serde::de::Error::custom(
263 "InProcess callback found in multiprocess mode",
264 ));
265 }
266 let addr = variant_data.newtype_variant::<usize>()?;
267 let ptr = addr as *mut Arc<Mutex<_>>;
268 #[expect(unsafe_code)]
274 let callback = unsafe { Box::from_raw(ptr) };
275 Ok(GenericCallback(GenericCallbackVariants::InProcess(
276 *callback,
277 )))
278 },
279 }
280 }
281}
282
283impl<'a, T> Deserialize<'a> for GenericCallback<T>
284where
285 T: Serialize + Deserialize<'a> + Send + 'static,
286{
287 fn deserialize<D>(d: D) -> Result<GenericCallback<T>, D::Error>
288 where
289 D: Deserializer<'a>,
290 {
291 d.deserialize_enum(
292 "GenericCallback",
293 &["CrossProcess", "InProcess"],
294 GenericCallbackVisitor {
295 marker: PhantomData,
296 },
297 )
298 }
299}
300
301impl<T> fmt::Debug for GenericCallback<T>
302where
303 T: Serialize + Send + 'static,
304{
305 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
306 write!(f, "GenericCallback(..)")
307 }
308}
309
#[cfg(test)]
mod single_process_callback_test {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    use crate::generic_channel::GenericCallback;

    /// Creates a callback that stores each received value in the returned counter.
    ///
    /// The callback unwraps its input, so a delivery error fails the test.
    fn recording_callback() -> (Arc<AtomicUsize>, GenericCallback<usize>) {
        let number = Arc::new(AtomicUsize::new(0));
        let number_clone = Arc::clone(&number);
        let callback = move |msg: Result<usize, ipc_channel::IpcError>| {
            number_clone.store(msg.unwrap(), Ordering::SeqCst)
        };
        (number, GenericCallback::new(callback).unwrap())
    }

    #[test]
    fn generic_callback() {
        let (number, generic_callback) = recording_callback();
        std::thread::scope(|s| {
            // Unwrap inside the thread so a failed `send` is reported directly;
            // the scope re-raises panics from threads that were not joined.
            s.spawn(move || generic_callback.send(42).unwrap());
        });
        assert_eq!(number.load(Ordering::SeqCst), 42);
    }

    #[test]
    fn generic_callback_via_generic_sender() {
        let (number, generic_callback) = recording_callback();
        let (tx, rx) = crate::generic_channel::channel().unwrap();

        tx.send(generic_callback).unwrap();
        std::thread::scope(|s| {
            s.spawn(move || {
                let callback = rx.recv().unwrap();
                callback.send(42).unwrap();
            });
        });
        assert_eq!(number.load(Ordering::SeqCst), 42);
    }

    #[test]
    fn generic_callback_via_ipc_sender() {
        let (number, generic_callback) = recording_callback();
        let (tx, rx) = ipc_channel::ipc::channel().unwrap();

        tx.send(generic_callback).unwrap();
        std::thread::scope(|s| {
            s.spawn(move || {
                let callback = rx.recv().unwrap();
                callback.send(42).unwrap();
            });
        });
        assert_eq!(number.load(Ordering::SeqCst), 42);
    }

    #[test]
    fn generic_callback_blocking() {
        let (callback, receiver) = GenericCallback::new_blocking().unwrap();
        let sender_thread = std::thread::spawn(move || {
            // Delay the send briefly so that `recv` below actually has to block.
            std::thread::sleep(Duration::from_millis(50));
            callback.send(42)
        });
        let received = receiver.recv();
        // Join before inspecting `received`: if the sender thread failed, report
        // that failure rather than the resulting disconnection on the receiver.
        sender_thread
            .join()
            .expect("sender thread panicked")
            .expect("sending through the blocking callback failed");
        assert_eq!(received.unwrap(), 42);
    }
}