wayland_backend/rs/server_impl/handle.rs

1use std::{
2    ffi::CString,
3    os::unix::{
4        io::{OwnedFd, RawFd},
5        net::UnixStream,
6    },
7    sync::{Arc, Mutex, Weak},
8};
9
10use crate::{
11    protocol::{same_interface, Interface, Message, ObjectInfo, ANONYMOUS_INTERFACE},
12    rs::DEFAULT_MAX_BUFFER_SIZE,
13    types::server::{DisconnectReason, GlobalInfo, InvalidId},
14};
15
16use super::{
17    client::ClientStore, registry::Registry, ClientData, ClientId, Credentials, GlobalHandler,
18    GlobalId, InnerClientId, InnerGlobalId, InnerObjectId, ObjectData, ObjectId,
19};
20
21pub(crate) type PendingDestructor<D> = (Arc<dyn ObjectData<D>>, InnerClientId, InnerObjectId);
22
23#[derive(Debug)]
24pub struct State<D: 'static> {
25    pub(crate) clients: ClientStore<D>,
26    pub(crate) registry: Registry<D>,
27    pub(crate) pending_destructors: Vec<PendingDestructor<D>>,
28    pub(crate) poll_fd: OwnedFd,
29    pub(crate) default_max_buffer_size: usize,
30}
31
32impl<D> State<D> {
33    pub(crate) fn new(poll_fd: OwnedFd) -> Self {
34        let debug =
35            matches!(std::env::var_os("WAYLAND_DEBUG"), Some(str) if str == "1" || str == "server");
36        Self {
37            clients: ClientStore::new(debug),
38            registry: Registry::new(),
39            pending_destructors: Vec::new(),
40            poll_fd,
41            default_max_buffer_size: DEFAULT_MAX_BUFFER_SIZE,
42        }
43    }
44
45    pub(crate) fn cleanup<'a>(&mut self) -> impl FnOnce(&super::Handle, &mut D) + 'a {
46        let dead_clients = self.clients.cleanup(&mut self.pending_destructors);
47        self.registry.cleanup(&dead_clients, &self.pending_destructors);
48        // return a closure that will do the cleanup once invoked
49        let pending_destructors = std::mem::take(&mut self.pending_destructors);
50        move |handle, data| {
51            for (object_data, client_id, object_id) in pending_destructors {
52                object_data.clone().destroyed(
53                    handle,
54                    data,
55                    ClientId { id: client_id },
56                    ObjectId { id: object_id },
57                );
58            }
59            std::mem::drop(dead_clients);
60        }
61    }
62
63    pub(crate) fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()> {
64        if let Some(ClientId { id: client }) = client {
65            match self.clients.get_client_mut(client) {
66                Ok(client) => client.flush(),
67                Err(InvalidId) => Ok(()),
68            }
69        } else {
70            for client in self.clients.clients_mut() {
71                let _ = client.flush();
72            }
73            Ok(())
74        }
75    }
76
77    pub fn set_default_max_buffer_size(&mut self, max_buffer_size: usize) {
78        self.default_max_buffer_size = max_buffer_size;
79    }
80
81    fn set_client_max_buffer_size(&mut self, client: InnerClientId, max_buffer_size: usize) {
82        if let Ok(client) = self.clients.get_client_mut(client) {
83            client.socket.set_max_buffer_size(Some(max_buffer_size));
84        }
85    }
86}
87
88#[derive(Clone)]
89pub struct InnerHandle {
90    pub(crate) state: Arc<Mutex<dyn ErasedState + Send>>,
91}
92
93impl std::fmt::Debug for InnerHandle {
94    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        fmt.debug_struct("InnerHandle[rs]").finish_non_exhaustive()
96    }
97}
98
99#[derive(Clone)]
100pub struct WeakInnerHandle {
101    pub(crate) state: Weak<Mutex<dyn ErasedState + Send>>,
102}
103
104impl std::fmt::Debug for WeakInnerHandle {
105    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106        fmt.debug_struct("WeakInnerHandle[rs]").finish_non_exhaustive()
107    }
108}
109
110impl WeakInnerHandle {
111    pub fn upgrade(&self) -> Option<InnerHandle> {
112        self.state.upgrade().map(|state| InnerHandle { state })
113    }
114}
115
116impl InnerHandle {
117    pub fn downgrade(&self) -> WeakInnerHandle {
118        WeakInnerHandle { state: Arc::downgrade(&self.state) }
119    }
120
121    pub fn object_info(&self, id: InnerObjectId) -> Result<ObjectInfo, InvalidId> {
122        self.state.lock().unwrap().object_info(id)
123    }
124
125    pub fn insert_client(
126        &self,
127        stream: UnixStream,
128        data: Arc<dyn ClientData>,
129    ) -> std::io::Result<InnerClientId> {
130        self.state.lock().unwrap().insert_client(stream, data)
131    }
132
133    pub fn get_client(&self, id: InnerObjectId) -> Result<ClientId, InvalidId> {
134        self.state.lock().unwrap().get_client(id)
135    }
136
137    pub fn get_client_data(&self, id: InnerClientId) -> Result<Arc<dyn ClientData>, InvalidId> {
138        self.state.lock().unwrap().get_client_data(id)
139    }
140
141    pub fn get_client_credentials(&self, id: InnerClientId) -> Result<Credentials, InvalidId> {
142        self.state.lock().unwrap().get_client_credentials(id)
143    }
144
145    pub fn with_all_clients(&self, mut f: impl FnMut(ClientId)) {
146        self.state.lock().unwrap().with_all_clients(&mut f)
147    }
148
149    pub fn with_all_objects_for(
150        &self,
151        client_id: InnerClientId,
152        mut f: impl FnMut(ObjectId),
153    ) -> Result<(), InvalidId> {
154        self.state.lock().unwrap().with_all_objects_for(client_id, &mut f)
155    }
156
157    pub fn object_for_protocol_id(
158        &self,
159        client_id: InnerClientId,
160        interface: &'static Interface,
161        protocol_id: u32,
162    ) -> Result<ObjectId, InvalidId> {
163        self.state.lock().unwrap().object_for_protocol_id(client_id, interface, protocol_id)
164    }
165
166    pub fn create_object<D: 'static>(
167        &self,
168        client_id: InnerClientId,
169        interface: &'static Interface,
170        version: u32,
171        data: Arc<dyn ObjectData<D>>,
172    ) -> Result<ObjectId, InvalidId> {
173        let mut state = self.state.lock().unwrap();
174        let state = (&mut *state as &mut dyn ErasedState)
175            .downcast_mut::<State<D>>()
176            .expect("Wrong type parameter passed to Handle::create_object().");
177        let client = state.clients.get_client_mut(client_id)?;
178        Ok(ObjectId { id: client.create_object(interface, version, data) })
179    }
180
181    pub fn destroy_object<D: 'static>(&self, id: &ObjectId) -> Result<(), InvalidId> {
182        let mut state = self.state.lock().unwrap();
183        let state = (&mut *state as &mut dyn ErasedState)
184            .downcast_mut::<State<D>>()
185            .expect("Wrong type parameter passed to Handle::destroy_object().");
186        let client = state.clients.get_client_mut(id.id.client_id.clone())?;
187        client.destroy_object(id.id.clone(), &mut state.pending_destructors)
188    }
189
190    pub fn null_id() -> ObjectId {
191        ObjectId {
192            id: InnerObjectId {
193                id: 0,
194                serial: 0,
195                client_id: InnerClientId { id: 0, serial: 0 },
196                interface: &ANONYMOUS_INTERFACE,
197            },
198        }
199    }
200
201    pub fn send_event(&self, msg: Message<ObjectId, RawFd>) -> Result<(), InvalidId> {
202        self.state.lock().unwrap().send_event(msg)
203    }
204
205    pub fn get_object_data<D: 'static>(
206        &self,
207        id: InnerObjectId,
208    ) -> Result<Arc<dyn ObjectData<D>>, InvalidId> {
209        let mut state = self.state.lock().unwrap();
210        let state = (&mut *state as &mut dyn ErasedState)
211            .downcast_mut::<State<D>>()
212            .expect("Wrong type parameter passed to Handle::get_object_data().");
213        state.clients.get_client(id.client_id.clone())?.get_object_data(id)
214    }
215
216    pub fn get_object_data_any(
217        &self,
218        id: InnerObjectId,
219    ) -> Result<Arc<dyn std::any::Any + Send + Sync>, InvalidId> {
220        self.state.lock().unwrap().get_object_data_any(id)
221    }
222
223    pub fn set_object_data<D: 'static>(
224        &self,
225        id: InnerObjectId,
226        data: Arc<dyn ObjectData<D>>,
227    ) -> Result<(), InvalidId> {
228        let mut state = self.state.lock().unwrap();
229        let state = (&mut *state as &mut dyn ErasedState)
230            .downcast_mut::<State<D>>()
231            .expect("Wrong type parameter passed to Handle::set_object_data().");
232        state.clients.get_client_mut(id.client_id.clone())?.set_object_data(id, data)
233    }
234
235    pub fn post_error(&self, object_id: InnerObjectId, error_code: u32, message: CString) {
236        self.state.lock().unwrap().post_error(object_id, error_code, message)
237    }
238
239    pub fn kill_client(&self, client_id: InnerClientId, reason: DisconnectReason) {
240        self.state.lock().unwrap().kill_client(client_id, reason)
241    }
242
243    pub fn create_global<D: 'static>(
244        &self,
245        interface: &'static Interface,
246        version: u32,
247        handler: Arc<dyn GlobalHandler<D>>,
248    ) -> InnerGlobalId {
249        let mut state = self.state.lock().unwrap();
250        let state = (&mut *state as &mut dyn ErasedState)
251            .downcast_mut::<State<D>>()
252            .expect("Wrong type parameter passed to Handle::create_global().");
253        state.registry.create_global(interface, version, handler, &mut state.clients)
254    }
255
256    pub fn disable_global<D: 'static>(&self, id: InnerGlobalId) {
257        let mut state = self.state.lock().unwrap();
258        let state = (&mut *state as &mut dyn ErasedState)
259            .downcast_mut::<State<D>>()
260            .expect("Wrong type parameter passed to Handle::disable_global().");
261
262        state.registry.disable_global(id, &mut state.clients)
263    }
264
265    pub fn remove_global<D: 'static>(&self, id: InnerGlobalId) {
266        let mut state_lock = self.state.lock().unwrap();
267        let state = (&mut *state_lock as &mut dyn ErasedState)
268            .downcast_mut::<State<D>>()
269            .expect("Wrong type parameter passed to Handle::remove_global().");
270
271        let global = state.registry.remove_global(id, &mut state.clients);
272        // Don't free global user-data until lock is released
273        drop(state_lock);
274        drop(global);
275    }
276
277    pub fn global_info(&self, id: InnerGlobalId) -> Result<GlobalInfo, InvalidId> {
278        self.state.lock().unwrap().global_info(id)
279    }
280
281    #[cfg_attr(not(feature = "libwayland_server_1_22"), allow(dead_code))]
282    pub fn global_name(&self, global: InnerGlobalId, client: InnerClientId) -> Option<u32> {
283        self.state.lock().unwrap().global_name(global, client)
284    }
285
286    pub fn get_global_handler<D: 'static>(
287        &self,
288        id: InnerGlobalId,
289    ) -> Result<Arc<dyn GlobalHandler<D>>, InvalidId> {
290        let mut state = self.state.lock().unwrap();
291        let state = (&mut *state as &mut dyn ErasedState)
292            .downcast_mut::<State<D>>()
293            .expect("Wrong type parameter passed to Handle::get_global_handler().");
294        state.registry.get_handler(id)
295    }
296
297    pub fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()> {
298        self.state.lock().unwrap().flush(client)
299    }
300
301    #[allow(dead_code)]
302    pub fn set_default_max_buffer_size(&self, max_buffer_size: usize) {
303        self.state.lock().unwrap().set_default_max_buffer_size(max_buffer_size)
304    }
305
306    #[allow(dead_code)]
307    pub fn set_client_max_buffer_size(&self, client: InnerClientId, max_buffer_size: usize) {
308        self.state.lock().unwrap().set_client_max_buffer_size(client, max_buffer_size)
309    }
310}
311
312pub(crate) trait ErasedState: downcast_rs::Downcast {
313    fn object_info(&self, id: InnerObjectId) -> Result<ObjectInfo, InvalidId>;
314    fn insert_client(
315        &mut self,
316        stream: UnixStream,
317        data: Arc<dyn ClientData>,
318    ) -> std::io::Result<InnerClientId>;
319    fn get_client(&self, id: InnerObjectId) -> Result<ClientId, InvalidId>;
320    fn get_client_data(&self, id: InnerClientId) -> Result<Arc<dyn ClientData>, InvalidId>;
321    fn get_client_credentials(&self, id: InnerClientId) -> Result<Credentials, InvalidId>;
322    fn with_all_clients(&self, f: &mut dyn FnMut(ClientId));
323    fn with_all_objects_for(
324        &self,
325        client_id: InnerClientId,
326        f: &mut dyn FnMut(ObjectId),
327    ) -> Result<(), InvalidId>;
328    fn object_for_protocol_id(
329        &self,
330        client_id: InnerClientId,
331        interface: &'static Interface,
332        protocol_id: u32,
333    ) -> Result<ObjectId, InvalidId>;
334    fn get_object_data_any(
335        &self,
336        id: InnerObjectId,
337    ) -> Result<Arc<dyn std::any::Any + Send + Sync>, InvalidId>;
338    fn send_event(&mut self, msg: Message<ObjectId, RawFd>) -> Result<(), InvalidId>;
339    fn post_error(&mut self, object_id: InnerObjectId, error_code: u32, message: CString);
340    fn kill_client(&mut self, client_id: InnerClientId, reason: DisconnectReason);
341    fn global_info(&self, id: InnerGlobalId) -> Result<GlobalInfo, InvalidId>;
342    fn global_name(&self, global: InnerGlobalId, client: InnerClientId) -> Option<u32>;
343    fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()>;
344    fn set_default_max_buffer_size(&mut self, max_buffer_size: usize);
345    fn set_client_max_buffer_size(&mut self, client: InnerClientId, max_buffer_size: usize);
346}
347
348downcast_rs::impl_downcast!(ErasedState);
349
350impl<D> ErasedState for State<D> {
351    fn object_info(&self, id: InnerObjectId) -> Result<ObjectInfo, InvalidId> {
352        self.clients.get_client(id.client_id.clone())?.object_info(id)
353    }
354
355    fn insert_client(
356        &mut self,
357        stream: UnixStream,
358        data: Arc<dyn ClientData>,
359    ) -> std::io::Result<InnerClientId> {
360        let id = self.clients.create_client(stream, data, self.default_max_buffer_size);
361        let client = self.clients.get_client(id.clone()).unwrap();
362
363        // register the client to the internal epoll
364        #[cfg(any(target_os = "linux", target_os = "android", target_os = "redox"))]
365        let ret = {
366            use rustix::event::epoll;
367            epoll::add(
368                &self.poll_fd,
369                client,
370                epoll::EventData::new_u64(id.as_u64()),
371                epoll::EventFlags::IN,
372            )
373        };
374
375        #[cfg(any(
376            target_os = "dragonfly",
377            target_os = "freebsd",
378            target_os = "netbsd",
379            target_os = "openbsd",
380            target_os = "macos"
381        ))]
382        let ret = {
383            use rustix::event::kqueue::*;
384            use std::os::unix::io::{AsFd, AsRawFd};
385
386            let evt = Event::new(
387                EventFilter::Read(client.as_fd().as_raw_fd()),
388                EventFlags::ADD | EventFlags::RECEIPT,
389                id.as_u64() as *mut _,
390            );
391
392            let events: &mut [Event] = &mut [];
393            unsafe { kevent(&self.poll_fd, &[evt], events, None).map(|_| ()) }
394        };
395
396        match ret {
397            Ok(()) => Ok(id),
398            Err(e) => {
399                self.kill_client(id, DisconnectReason::ConnectionClosed);
400                Err(e.into())
401            }
402        }
403    }
404
405    fn get_client(&self, id: InnerObjectId) -> Result<ClientId, InvalidId> {
406        if self.clients.get_client(id.client_id.clone()).is_ok() {
407            Ok(ClientId { id: id.client_id })
408        } else {
409            Err(InvalidId)
410        }
411    }
412
413    fn get_client_data(&self, id: InnerClientId) -> Result<Arc<dyn ClientData>, InvalidId> {
414        let client = self.clients.get_client(id)?;
415        Ok(client.data.clone())
416    }
417
418    fn get_client_credentials(&self, id: InnerClientId) -> Result<Credentials, InvalidId> {
419        let client = self.clients.get_client(id)?;
420        Ok(client.get_credentials())
421    }
422
423    fn with_all_clients(&self, f: &mut dyn FnMut(ClientId)) {
424        for client in self.clients.all_clients_id() {
425            f(client)
426        }
427    }
428
429    fn with_all_objects_for(
430        &self,
431        client_id: InnerClientId,
432        f: &mut dyn FnMut(ObjectId),
433    ) -> Result<(), InvalidId> {
434        let client = self.clients.get_client(client_id)?;
435        for object in client.all_objects() {
436            f(object)
437        }
438        Ok(())
439    }
440
441    fn object_for_protocol_id(
442        &self,
443        client_id: InnerClientId,
444        interface: &'static Interface,
445        protocol_id: u32,
446    ) -> Result<ObjectId, InvalidId> {
447        let client = self.clients.get_client(client_id)?;
448        let object = client.object_for_protocol_id(protocol_id)?;
449        if same_interface(interface, object.interface) {
450            Ok(ObjectId { id: object })
451        } else {
452            Err(InvalidId)
453        }
454    }
455
456    fn get_object_data_any(
457        &self,
458        id: InnerObjectId,
459    ) -> Result<Arc<dyn std::any::Any + Send + Sync>, InvalidId> {
460        self.clients
461            .get_client(id.client_id.clone())?
462            .get_object_data(id)
463            .map(|arc| arc.into_any_arc())
464    }
465
466    fn send_event(&mut self, msg: Message<ObjectId, RawFd>) -> Result<(), InvalidId> {
467        self.clients
468            .get_client_mut(msg.sender_id.id.client_id.clone())?
469            .send_event(msg, Some(&mut self.pending_destructors))
470    }
471
472    fn post_error(&mut self, object_id: InnerObjectId, error_code: u32, message: CString) {
473        if let Ok(client) = self.clients.get_client_mut(object_id.client_id.clone()) {
474            client.post_error(object_id, error_code, message)
475        }
476    }
477
478    fn kill_client(&mut self, client_id: InnerClientId, reason: DisconnectReason) {
479        if let Ok(client) = self.clients.get_client_mut(client_id) {
480            client.kill(reason)
481        }
482    }
483    fn global_info(&self, id: InnerGlobalId) -> Result<GlobalInfo, InvalidId> {
484        self.registry.get_info(id)
485    }
486
487    fn global_name(&self, global_id: InnerGlobalId, client_id: InnerClientId) -> Option<u32> {
488        let client = self.clients.get_client(client_id.clone()).ok()?;
489        let handler = self.registry.get_handler(global_id.clone()).ok()?;
490        let name = global_id.id;
491
492        let can_view =
493            handler.can_view(ClientId { id: client_id }, &client.data, GlobalId { id: global_id });
494
495        if can_view {
496            Some(name)
497        } else {
498            None
499        }
500    }
501
502    fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()> {
503        self.flush(client)
504    }
505
506    fn set_default_max_buffer_size(&mut self, max_buffer_size: usize) {
507        self.set_default_max_buffer_size(max_buffer_size)
508    }
509
510    fn set_client_max_buffer_size(&mut self, client: InnerClientId, max_buffer_size: usize) {
511        self.set_client_max_buffer_size(client, max_buffer_size)
512    }
513}