base/generic_channel/
callback.rs1use std::fmt;
63use std::marker::PhantomData;
64use std::sync::{Arc, Mutex};
65
66use ipc_channel::ErrorKind;
67use ipc_channel::ipc::IpcSender;
68use ipc_channel::router::ROUTER;
69use malloc_size_of::{MallocSizeOf, MallocSizeOfOps};
70use serde::de::VariantAccess;
71use serde::{Deserialize, Deserializer, Serialize, Serializer};
72use servo_config::opts;
73
74use crate::generic_channel::{SendError, SendResult};
75
76pub type MsgCallback<T> = dyn FnMut(Result<T, ipc_channel::Error>) + Send;
83
84pub struct GenericCallback<T>(GenericCallbackVariants<T>)
90where
91 T: Serialize + Send + 'static;
92
93enum GenericCallbackVariants<T>
94where
95 T: Serialize + Send + 'static,
96{
97 CrossProcess(IpcSender<T>),
98 InProcess(Arc<Mutex<MsgCallback<T>>>),
99}
100
101impl<T> Clone for GenericCallback<T>
102where
103 T: Serialize + Send + 'static,
104{
105 fn clone(&self) -> Self {
106 let variant = match &self.0 {
107 GenericCallbackVariants::CrossProcess(sender) => {
108 GenericCallbackVariants::CrossProcess((*sender).clone())
109 },
110 GenericCallbackVariants::InProcess(callback) => {
111 GenericCallbackVariants::InProcess(callback.clone())
112 },
113 };
114 GenericCallback(variant)
115 }
116}
117
118impl<T> MallocSizeOf for GenericCallback<T>
119where
120 T: Serialize + Send + 'static,
121{
122 fn size_of(&self, _ops: &mut MallocSizeOfOps) -> usize {
123 0
124 }
125}
126
127impl<T> GenericCallback<T>
128where
129 T: for<'de> Deserialize<'de> + Serialize + Send + 'static,
130{
131 pub fn new<F: FnMut(Result<T, ipc_channel::Error>) + Send + 'static>(
135 callback: F,
136 ) -> Result<Self, ipc_channel::Error> {
137 let generic_callback = if opts::get().multiprocess || opts::get().force_ipc {
138 let (ipc_sender, ipc_receiver) = ipc_channel::ipc::channel()?;
139 ROUTER.add_typed_route(ipc_receiver, Box::new(callback));
140 GenericCallback(GenericCallbackVariants::CrossProcess(ipc_sender))
141 } else {
142 let callback = Arc::new(Mutex::new(callback));
143 GenericCallback(GenericCallbackVariants::InProcess(callback))
144 };
145 Ok(generic_callback)
146 }
147
148 pub fn send(&self, value: T) -> SendResult {
154 match &self.0 {
155 GenericCallbackVariants::CrossProcess(sender) => {
156 sender.send(value).map_err(|error| match *error {
157 ErrorKind::Io(_) => SendError::Disconnected,
158 serialization_error => {
159 SendError::SerializationError(serialization_error.to_string())
160 },
161 })
162 },
163 GenericCallbackVariants::InProcess(callback) => {
164 let mut cb = callback.lock().expect("poisoned");
165 (*cb)(Ok(value));
166 Ok(())
167 },
168 }
169 }
170}
171
172impl<T> Serialize for GenericCallback<T>
173where
174 T: Serialize + Send + 'static,
175{
176 fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
177 match &self.0 {
178 GenericCallbackVariants::CrossProcess(sender) => {
179 s.serialize_newtype_variant("GenericCallback", 0, "CrossProcess", sender)
180 },
181 GenericCallbackVariants::InProcess(wrapped_callback) => {
188 if opts::get().multiprocess {
189 return Err(serde::ser::Error::custom(
190 "InProcess callback can't be serialized in multiprocess mode",
191 ));
192 }
193 let cloned_callback = Box::new(wrapped_callback.clone());
197 let sender_clone_addr = Box::leak(cloned_callback) as *mut Arc<_> as usize;
198 s.serialize_newtype_variant("GenericCallback", 1, "InProcess", &sender_clone_addr)
199 },
200 }
201 }
202}
203
/// Serde visitor that rebuilds a [`GenericCallback`] from either of its
/// serialized variants.
struct GenericCallbackVisitor<T> {
    /// Ties the visitor to the message type `T` without storing a value.
    marker: PhantomData<T>,
}
207
208impl<'de, T> serde::de::Visitor<'de> for GenericCallbackVisitor<T>
209where
210 T: Serialize + Deserialize<'de> + Send + 'static,
211{
212 type Value = GenericCallback<T>;
213
214 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
215 formatter.write_str("a GenericCallback variant")
216 }
217
218 fn visit_enum<A>(self, data: A) -> Result<Self::Value, A::Error>
219 where
220 A: serde::de::EnumAccess<'de>,
221 {
222 #[derive(Deserialize)]
223 enum GenericCallbackVariantNames {
224 CrossProcess,
225 InProcess,
226 }
227
228 let (variant_name, variant_data): (GenericCallbackVariantNames, _) = data.variant()?;
229
230 match variant_name {
231 GenericCallbackVariantNames::CrossProcess => variant_data
232 .newtype_variant::<IpcSender<T>>()
233 .map(|sender| GenericCallback(GenericCallbackVariants::CrossProcess(sender))),
234 GenericCallbackVariantNames::InProcess => {
235 if opts::get().multiprocess {
236 return Err(serde::de::Error::custom(
237 "InProcess callback found in multiprocess mode",
238 ));
239 }
240 let addr = variant_data.newtype_variant::<usize>()?;
241 let ptr = addr as *mut Arc<Mutex<_>>;
242 #[allow(unsafe_code)]
248 let callback = unsafe { Box::from_raw(ptr) };
249 Ok(GenericCallback(GenericCallbackVariants::InProcess(
250 *callback,
251 )))
252 },
253 }
254 }
255}
256
257impl<'a, T> Deserialize<'a> for GenericCallback<T>
258where
259 T: Serialize + Deserialize<'a> + Send + 'static,
260{
261 fn deserialize<D>(d: D) -> Result<GenericCallback<T>, D::Error>
262 where
263 D: Deserializer<'a>,
264 {
265 d.deserialize_enum(
266 "GenericCallback",
267 &["CrossProcess", "InProcess"],
268 GenericCallbackVisitor {
269 marker: PhantomData,
270 },
271 )
272 }
273}
274
275impl<T> fmt::Debug for GenericCallback<T>
276where
277 T: Serialize + Send + 'static,
278{
279 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
280 write!(f, "GenericCallback(..)")
281 }
282}
283
#[cfg(test)]
mod single_process_callback_test {
    //! Tests for `GenericCallback` in single-process mode, where `send` invokes
    //! the closure synchronously, so the assertions need no extra synchronization.

    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use crate::generic_channel::GenericCallback;

    #[test]
    fn generic_callback() {
        let number = Arc::new(AtomicUsize::new(0));
        let number_clone = number.clone();
        let callback = move |msg: Result<usize, ipc_channel::Error>| {
            number_clone.store(msg.unwrap(), Ordering::SeqCst)
        };
        let generic_callback = GenericCallback::new(callback).unwrap();
        std::thread::scope(|s| {
            // Check the send result instead of silently discarding it.
            s.spawn(move || generic_callback.send(42).unwrap());
        });
        assert_eq!(number.load(Ordering::SeqCst), 42);
    }

    #[test]
    fn cloned_generic_callback_shares_closure() {
        let number = Arc::new(AtomicUsize::new(0));
        let number_clone = number.clone();
        let callback = move |msg: Result<usize, ipc_channel::Error>| {
            number_clone.fetch_add(msg.unwrap(), Ordering::SeqCst);
        };
        let generic_callback = GenericCallback::new(callback).unwrap();
        let cloned_callback = generic_callback.clone();
        generic_callback.send(1).unwrap();
        cloned_callback.send(2).unwrap();
        assert_eq!(number.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn generic_callback_via_generic_sender() {
        let number = Arc::new(AtomicUsize::new(0));
        let number_clone = number.clone();
        let callback = move |msg: Result<usize, ipc_channel::Error>| {
            number_clone.store(msg.unwrap(), Ordering::SeqCst)
        };
        let generic_callback = GenericCallback::new(callback).unwrap();
        let (tx, rx) = crate::generic_channel::channel().unwrap();

        tx.send(generic_callback).unwrap();
        std::thread::scope(|s| {
            s.spawn(move || {
                let callback = rx.recv().unwrap();
                callback.send(42).unwrap();
            });
        });
        assert_eq!(number.load(Ordering::SeqCst), 42);
    }

    #[test]
    fn generic_callback_via_ipc_sender() {
        let number = Arc::new(AtomicUsize::new(0));
        let number_clone = number.clone();
        let callback = move |msg: Result<usize, ipc_channel::Error>| {
            number_clone.store(msg.unwrap(), Ordering::SeqCst)
        };
        let generic_callback = GenericCallback::new(callback).unwrap();
        let (tx, rx) = ipc_channel::ipc::channel().unwrap();

        tx.send(generic_callback).unwrap();
        std::thread::scope(|s| {
            s.spawn(move || {
                let callback = rx.recv().unwrap();
                callback.send(42).unwrap();
            });
        });
        assert_eq!(number.load(Ordering::SeqCst), 42);
    }
}