// wayland_backend/rs/server_impl/common_poll.rs

1use std::{
2    os::unix::io::{AsRawFd, BorrowedFd, OwnedFd},
3    sync::{Arc, Mutex},
4};
5
6use super::{
7    handle::State, ClientId, Data, GlobalHandler, GlobalId, Handle, InnerClientId, InnerGlobalId,
8    InnerHandle, InnerObjectId, ObjectId,
9};
10use crate::{
11    core_interfaces::{WL_DISPLAY_INTERFACE, WL_REGISTRY_INTERFACE},
12    protocol::{same_interface, Argument, Message},
13    rs::map::Object,
14    types::server::InitError,
15};
16
17#[cfg(any(target_os = "linux", target_os = "android"))]
18use rustix::event::{epoll, Timespec};
19
20#[cfg(any(
21    target_os = "dragonfly",
22    target_os = "freebsd",
23    target_os = "netbsd",
24    target_os = "openbsd",
25    target_os = "macos"
26))]
27use rustix::event::kqueue::*;
28use smallvec::SmallVec;
29
30#[derive(Debug)]
31pub struct InnerBackend<D: 'static> {
32    state: Arc<Mutex<State<D>>>,
33}
34
impl<D> InnerBackend<D> {
    /// Creates a backend together with a fresh OS poll instance.
    ///
    /// On Linux/Android the poll instance is an `epoll` instance created with
    /// `CLOEXEC`. On DragonFly, FreeBSD, NetBSD, OpenBSD and macOS it is a `kqueue`.
    ///
    /// # Errors
    ///
    /// Returns [`InitError::Io`] if the poll instance cannot be created.
    pub fn new() -> Result<Self, InitError> {
        #[cfg(any(target_os = "linux", target_os = "android"))]
        let poll_fd = epoll::create(epoll::CreateFlags::CLOEXEC)
            .map_err(Into::into)
            .map_err(InitError::Io)?;

        #[cfg(any(
            target_os = "dragonfly",
            target_os = "freebsd",
            target_os = "netbsd",
            target_os = "openbsd",
            target_os = "macos"
        ))]
        let poll_fd = kqueue().map_err(Into::into).map_err(InitError::Io)?;

        Ok(Self { state: Arc::new(Mutex::new(State::new(poll_fd))) })
    }

    /// Flushes outgoing messages by delegating to `State::flush`.
    ///
    /// `Some(client)` targets one client. `None` presumably flushes every client
    /// (TODO confirm against `State::flush`).
    ///
    /// # Panics
    ///
    /// Panics if the state mutex is poisoned.
    pub fn flush(&self, client: Option<ClientId>) -> std::io::Result<()> {
        self.state.lock().unwrap().flush(client)
    }

    /// Returns a [`Handle`] sharing this backend's state (the inner `Arc` is cloned).
    pub fn handle(&self) -> Handle {
        Handle { handle: InnerHandle { state: self.state.clone() as Arc<_> } }
    }

    /// Returns the file descriptor of the backend's epoll/kqueue instance.
    ///
    /// # Panics
    ///
    /// Panics if the state mutex is poisoned.
    pub fn poll_fd(&self) -> BorrowedFd<'_> {
        let raw_fd = self.state.lock().unwrap().poll_fd.as_raw_fd();
        // This allows the lifetime of the BorrowedFd to be tied to &self rather than the lock guard,
        // which is the real safety concern
        // SAFETY (review note): `raw_fd` is owned by `State::poll_fd`, and `&self` keeps the
        // `Arc<Mutex<State<D>>>` alive for the returned lifetime. This relies on `State::poll_fd`
        // never being closed or replaced while the backend exists — confirm in `State`.
        unsafe { BorrowedFd::borrow_raw(raw_fd) }
    }

    /// Dispatches the pending requests of a single client, then runs the state cleanup.
    ///
    /// The cleanup closure runs even when dispatching fails. The dispatch result is
    /// returned unchanged afterwards.
    ///
    /// # Errors
    ///
    /// Propagates the error from `dispatch_events_for`. This includes `WouldBlock`
    /// when no request could be read from the client.
    pub fn dispatch_client(
        &self,
        data: &mut D,
        client_id: InnerClientId,
    ) -> std::io::Result<usize> {
        let ret = self.dispatch_events_for(data, client_id);
        let cleanup = self.state.lock().unwrap().cleanup();
        cleanup(&self.handle(), data);
        ret
    }

    /// Dispatches requests from every client that epoll reports as ready.
    ///
    /// The poll uses a zero timeout, so it never blocks. Polling repeats until a wait
    /// returns no events, and the state cleanup runs after each batch. Errors from
    /// individual clients are ignored.
    ///
    /// Returns the total number of requests dispatched across all batches.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `epoll::wait`.
    #[cfg(any(target_os = "linux", target_os = "android"))]
    pub fn dispatch_all_clients(&self, data: &mut D) -> std::io::Result<usize> {
        use std::os::unix::io::AsFd;

        let poll_fd = self.poll_fd();
        let mut dispatched = 0;
        let mut events = Vec::<epoll::Event>::with_capacity(32);
        loop {
            // `spare_capacity` lets epoll write into the Vec's unused capacity and sets its length.
            let buffer = rustix::buffer::spare_capacity(&mut events);
            epoll::wait(poll_fd.as_fd(), buffer, Some(&Timespec::default()))?;

            if events.is_empty() {
                break;
            }

            for event in events.drain(..) {
                // The event's user data carries the client id. Per-client errors are
                // ignored (best-effort), so the remaining ready clients are still dispatched.
                let id = InnerClientId::from_u64(event.data.u64());
                if let Ok(count) = self.dispatch_events_for(data, id) {
                    dispatched += count;
                }
            }
            let cleanup = self.state.lock().unwrap().cleanup();
            cleanup(&self.handle(), data);
        }

        Ok(dispatched)
    }

    /// Dispatches requests from every client that kqueue reports as ready.
    ///
    /// The poll uses a zero timeout, so it never blocks. Polling repeats until a wait
    /// returns no events, and the state cleanup runs after each batch. Errors from
    /// individual clients are ignored.
    ///
    /// Returns the total number of requests dispatched across all batches.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `kevent`.
    #[cfg(any(
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "netbsd",
        target_os = "openbsd",
        target_os = "macos"
    ))]
    pub fn dispatch_all_clients(&self, data: &mut D) -> std::io::Result<usize> {
        use std::time::Duration;

        let poll_fd = self.poll_fd();
        let mut dispatched = 0;
        let mut events = Vec::<Event>::with_capacity(32);
        loop {
            let buffer = rustix::buffer::spare_capacity(&mut events);
            // SAFETY (review note): the change list is empty, so this call only reads
            // pending events and registers nothing new.
            let nevents = unsafe { kevent(&poll_fd, &[], buffer, Some(Duration::ZERO))? };

            if nevents == 0 {
                break;
            }

            for event in events.drain(..) {
                // The event's udata carries the client id. Per-client errors are
                // ignored (best-effort), so the remaining ready clients are still dispatched.
                let id = InnerClientId::from_u64(event.udata() as u64);
                if let Ok(count) = self.dispatch_events_for(data, id) {
                    dispatched += count;
                }
            }
            let cleanup = self.state.lock().unwrap().cleanup();
            cleanup(&self.handle(), data);
        }

        Ok(dispatched)
    }

    /// Reads and dispatches requests from one client until `next_request` reports `WouldBlock`.
    ///
    /// Requests are handled as follows:
    /// - `wl_display` requests are handled in place.
    /// - `wl_registry` requests that produce a bind become [`DispatchAction::Bind`].
    /// - All other requests become [`DispatchAction::Request`].
    ///
    /// Both action kinds are executed with the state mutex released, so user
    /// callbacks (`request`, `destroyed`, `bind`) never run while it is held.
    ///
    /// Returns the number of messages read. Every successfully read message counts,
    /// including ones handled in place or skipped because `process_request` returned `None`.
    ///
    /// # Errors
    ///
    /// - `WouldBlock` if not a single request could be read.
    /// - `InvalidInput` ("Invalid client ID") if `client_id` is not found.
    /// - Any other read error. Before it is returned, the client's fd is removed from
    ///   the poll instance; if that removal fails, its error is returned instead.
    ///
    /// # Panics
    ///
    /// Panics if:
    /// - the state mutex is poisoned;
    /// - a callback that creates an object returns no object data while the client is
    ///   still alive;
    /// - a callback returns object data without creating an object;
    /// - storing the new object's data into the client's object map fails.
    pub(crate) fn dispatch_events_for(
        &self,
        data: &mut D,
        client_id: InnerClientId,
    ) -> std::io::Result<usize> {
        let mut dispatched = 0;
        let handle = self.handle();
        let mut state = self.state.lock().unwrap();
        loop {
            // Phase 1 (lock held): read one request and decide what has to be done with it.
            let action = {
                let state = &mut *state;
                if let Ok(client) = state.clients.get_client_mut(client_id.clone()) {
                    let (message, object) = match client.next_request() {
                        Ok(v) => v,
                        Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                            // Nothing more to read: success if this call handled at least
                            // one request, otherwise surface the WouldBlock to the caller.
                            if dispatched > 0 {
                                break;
                            } else {
                                return Err(e);
                            }
                        }
                        Err(e) => {
                            // Fatal read error: stop polling this client's fd before
                            // returning the error.
                            #[cfg(any(target_os = "linux", target_os = "android"))]
                            {
                                epoll::delete(&state.poll_fd, client)?;
                            }

                            #[cfg(any(
                                target_os = "dragonfly",
                                target_os = "freebsd",
                                target_os = "netbsd",
                                target_os = "openbsd",
                                target_os = "macos"
                            ))]
                            {
                                use rustix::event::kqueue::*;
                                use std::os::unix::io::{AsFd, AsRawFd};

                                // EV_DELETE on the client's read filter removes its registration.
                                let evt = Event::new(
                                    EventFilter::Read(client.as_fd().as_raw_fd()),
                                    EventFlags::DELETE,
                                    client_id.as_u64() as *mut _,
                                );

                                let events: &mut [Event] = &mut [];
                                // SAFETY (review note): this change only deletes an existing
                                // registration, and no events are read back.
                                unsafe {
                                    kevent(&state.poll_fd, &[evt], events, None).map(|_| ())?;
                                }
                            }
                            return Err(e);
                        }
                    };
                    // Counted as soon as it is read, whether or not it reaches a callback.
                    dispatched += 1;
                    if same_interface(object.interface, &WL_DISPLAY_INTERFACE) {
                        client.handle_display_request(message, &mut state.registry);
                        continue;
                    } else if same_interface(object.interface, &WL_REGISTRY_INTERFACE) {
                        if let Some((client, global, object, handler)) =
                            client.handle_registry_request(message, &mut state.registry)
                        {
                            DispatchAction::Bind { client, global, object, handler }
                        } else {
                            continue;
                        }
                    } else {
                        let object_id = InnerObjectId {
                            id: message.sender_id,
                            serial: object.data.serial,
                            interface: object.interface,
                            client_id: client.id.clone(),
                        };
                        let opcode = message.opcode;
                        let (arguments, is_destructor, created_id) =
                            match client.process_request(&object, message) {
                                Some(args) => args,
                                None => continue,
                            };
                        // Return everything needed to invoke the callback once `client`
                        // (and therefore the state lock) is no longer borrowed.
                        DispatchAction::Request {
                            object,
                            object_id,
                            opcode,
                            arguments,
                            is_destructor,
                            created_id,
                        }
                    }
                } else {
                    // NOTE(review): this lookup runs on every iteration. If the client
                    // disappears after some requests were dispatched, the count accumulated
                    // so far is discarded in favour of this error.
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::InvalidInput,
                        "Invalid client ID",
                    ));
                }
            };
            // Phase 2 (lock released around user code): run the callback, then re-lock
            // to record its outcome.
            match action {
                DispatchAction::Request {
                    object,
                    object_id,
                    opcode,
                    arguments,
                    is_destructor,
                    created_id,
                } => {
                    // temporarily unlock the state Mutex while this request is dispatched
                    std::mem::drop(state);
                    let ret = object.data.user_data.clone().request(
                        &handle.clone(),
                        data,
                        ClientId { id: client_id.clone() },
                        Message {
                            sender_id: ObjectId { id: object_id.clone() },
                            opcode,
                            args: arguments,
                        },
                    );
                    if is_destructor {
                        object.data.user_data.clone().destroyed(
                            &handle.clone(),
                            data,
                            ClientId { id: client_id.clone() },
                            ObjectId { id: object_id.clone() },
                        );
                    }
                    // acquire the lock again and continue
                    state = self.state.lock().unwrap();
                    if is_destructor {
                        if let Ok(client) = state.clients.get_client_mut(client_id.clone()) {
                            client.send_delete_id(object_id);
                        }
                    }
                    // Reconcile the created object (if any) with the data returned by the callback.
                    match (created_id, ret) {
                        (Some(child_id), Some(child_data)) => {
                            // Store the returned data on the newly created object.
                            if let Ok(client) = state.clients.get_client_mut(client_id.clone()) {
                                client
                                    .map
                                    .with(child_id.id, |obj| obj.data.user_data = child_data)
                                    .unwrap();
                            }
                        }
                        (None, None) => {}
                        (Some(child_id), None) => {
                            // Allow the callback to not return any data if the client is already dead (typically
                            // if the callback provoked a protocol error)
                            if let Ok(client) = state.clients.get_client(client_id.clone()) {
                                if !client.killed {
                                    panic!(
                                        "Callback creating object {child_id} did not provide any object data."
                                    );
                                }
                            }
                        }
                        (None, Some(_)) => {
                            panic!("An object data was returned from a callback not creating any object");
                        }
                    }
                    // dropping the object calls destructors from which users could call into wayland-backend again.
                    // so lets release and relock the state again, to avoid a deadlock
                    std::mem::drop(state);
                    std::mem::drop(object);
                    state = self.state.lock().unwrap();
                }
                DispatchAction::Bind { object, client, global, handler } => {
                    // temporarily unlock the state Mutex while this request is dispatched
                    std::mem::drop(state);
                    let child_data = handler.bind(
                        &handle.clone(),
                        data,
                        ClientId { id: client.clone() },
                        GlobalId { id: global },
                        ObjectId { id: object.clone() },
                    );
                    // acquire the lock again and continue
                    state = self.state.lock().unwrap();
                    // Attach the data returned by `bind` to the freshly bound object.
                    if let Ok(client) = state.clients.get_client_mut(client.clone()) {
                        client.map.with(object.id, |obj| obj.data.user_data = child_data).unwrap();
                    }
                }
            }
        }
        Ok(dispatched)
    }
}
326
327enum DispatchAction<D: 'static> {
328    Request {
329        object: Object<Data<D>>,
330        object_id: InnerObjectId,
331        opcode: u16,
332        arguments: SmallVec<[Argument<ObjectId, OwnedFd>; 4]>,
333        is_destructor: bool,
334        created_id: Option<InnerObjectId>,
335    },
336    Bind {
337        object: InnerObjectId,
338        client: InnerClientId,
339        global: InnerGlobalId,
340        handler: Arc<dyn GlobalHandler<D>>,
341    },
342}