1use std::{
2 ffi::CString,
3 os::unix::{
4 io::{OwnedFd, RawFd},
5 net::UnixStream,
6 },
7 sync::{Arc, Mutex, Weak},
8};
9
10use crate::{
11 protocol::{same_interface, Interface, Message, ObjectInfo, ANONYMOUS_INTERFACE},
12 rs::DEFAULT_MAX_BUFFER_SIZE,
13 types::server::{DisconnectReason, GlobalInfo, InvalidId},
14};
15
16use super::{
17 client::ClientStore, registry::Registry, ClientData, ClientId, Credentials, GlobalHandler,
18 GlobalId, InnerClientId, InnerGlobalId, InnerObjectId, ObjectData, ObjectId,
19};
20
21pub(crate) type PendingDestructor<D> = (Arc<dyn ObjectData<D>>, InnerClientId, InnerObjectId);
22
23#[derive(Debug)]
24pub struct State<D: 'static> {
25 pub(crate) clients: ClientStore<D>,
26 pub(crate) registry: Registry<D>,
27 pub(crate) pending_destructors: Vec<PendingDestructor<D>>,
28 pub(crate) poll_fd: OwnedFd,
29 pub(crate) default_max_buffer_size: usize,
30}
31
32impl<D> State<D> {
33 pub(crate) fn new(poll_fd: OwnedFd) -> Self {
34 let debug =
35 matches!(std::env::var_os("WAYLAND_DEBUG"), Some(str) if str == "1" || str == "server");
36 Self {
37 clients: ClientStore::new(debug),
38 registry: Registry::new(),
39 pending_destructors: Vec::new(),
40 poll_fd,
41 default_max_buffer_size: DEFAULT_MAX_BUFFER_SIZE,
42 }
43 }
44
45 pub(crate) fn cleanup<'a>(&mut self) -> impl FnOnce(&super::Handle, &mut D) + 'a {
46 let dead_clients = self.clients.cleanup(&mut self.pending_destructors);
47 self.registry.cleanup(&dead_clients, &self.pending_destructors);
48 let pending_destructors = std::mem::take(&mut self.pending_destructors);
50 move |handle, data| {
51 for (object_data, client_id, object_id) in pending_destructors {
52 object_data.clone().destroyed(
53 handle,
54 data,
55 ClientId { id: client_id },
56 ObjectId { id: object_id },
57 );
58 }
59 std::mem::drop(dead_clients);
60 }
61 }
62
63 pub(crate) fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()> {
64 if let Some(ClientId { id: client }) = client {
65 match self.clients.get_client_mut(client) {
66 Ok(client) => client.flush(),
67 Err(InvalidId) => Ok(()),
68 }
69 } else {
70 for client in self.clients.clients_mut() {
71 let _ = client.flush();
72 }
73 Ok(())
74 }
75 }
76
77 pub fn set_default_max_buffer_size(&mut self, max_buffer_size: usize) {
78 self.default_max_buffer_size = max_buffer_size;
79 }
80
81 fn set_client_max_buffer_size(&mut self, client: InnerClientId, max_buffer_size: usize) {
82 if let Ok(client) = self.clients.get_client_mut(client) {
83 client.socket.set_max_buffer_size(Some(max_buffer_size));
84 }
85 }
86}
87
88#[derive(Clone)]
89pub struct InnerHandle {
90 pub(crate) state: Arc<Mutex<dyn ErasedState + Send>>,
91}
92
93impl std::fmt::Debug for InnerHandle {
94 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 fmt.debug_struct("InnerHandle[rs]").finish_non_exhaustive()
96 }
97}
98
99#[derive(Clone)]
100pub struct WeakInnerHandle {
101 pub(crate) state: Weak<Mutex<dyn ErasedState + Send>>,
102}
103
104impl std::fmt::Debug for WeakInnerHandle {
105 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 fmt.debug_struct("WeakInnerHandle[rs]").finish_non_exhaustive()
107 }
108}
109
110impl WeakInnerHandle {
111 pub fn upgrade(&self) -> Option<InnerHandle> {
112 self.state.upgrade().map(|state| InnerHandle { state })
113 }
114}
115
116impl InnerHandle {
117 pub fn downgrade(&self) -> WeakInnerHandle {
118 WeakInnerHandle { state: Arc::downgrade(&self.state) }
119 }
120
121 pub fn object_info(&self, id: InnerObjectId) -> Result<ObjectInfo, InvalidId> {
122 self.state.lock().unwrap().object_info(id)
123 }
124
125 pub fn insert_client(
126 &self,
127 stream: UnixStream,
128 data: Arc<dyn ClientData>,
129 ) -> std::io::Result<InnerClientId> {
130 self.state.lock().unwrap().insert_client(stream, data)
131 }
132
133 pub fn get_client(&self, id: InnerObjectId) -> Result<ClientId, InvalidId> {
134 self.state.lock().unwrap().get_client(id)
135 }
136
137 pub fn get_client_data(&self, id: InnerClientId) -> Result<Arc<dyn ClientData>, InvalidId> {
138 self.state.lock().unwrap().get_client_data(id)
139 }
140
141 pub fn get_client_credentials(&self, id: InnerClientId) -> Result<Credentials, InvalidId> {
142 self.state.lock().unwrap().get_client_credentials(id)
143 }
144
145 pub fn with_all_clients(&self, mut f: impl FnMut(ClientId)) {
146 self.state.lock().unwrap().with_all_clients(&mut f)
147 }
148
149 pub fn with_all_objects_for(
150 &self,
151 client_id: InnerClientId,
152 mut f: impl FnMut(ObjectId),
153 ) -> Result<(), InvalidId> {
154 self.state.lock().unwrap().with_all_objects_for(client_id, &mut f)
155 }
156
157 pub fn object_for_protocol_id(
158 &self,
159 client_id: InnerClientId,
160 interface: &'static Interface,
161 protocol_id: u32,
162 ) -> Result<ObjectId, InvalidId> {
163 self.state.lock().unwrap().object_for_protocol_id(client_id, interface, protocol_id)
164 }
165
166 pub fn create_object<D: 'static>(
167 &self,
168 client_id: InnerClientId,
169 interface: &'static Interface,
170 version: u32,
171 data: Arc<dyn ObjectData<D>>,
172 ) -> Result<ObjectId, InvalidId> {
173 let mut state = self.state.lock().unwrap();
174 let state = (&mut *state as &mut dyn ErasedState)
175 .downcast_mut::<State<D>>()
176 .expect("Wrong type parameter passed to Handle::create_object().");
177 let client = state.clients.get_client_mut(client_id)?;
178 Ok(ObjectId { id: client.create_object(interface, version, data) })
179 }
180
181 pub fn destroy_object<D: 'static>(&self, id: &ObjectId) -> Result<(), InvalidId> {
182 let mut state = self.state.lock().unwrap();
183 let state = (&mut *state as &mut dyn ErasedState)
184 .downcast_mut::<State<D>>()
185 .expect("Wrong type parameter passed to Handle::destroy_object().");
186 let client = state.clients.get_client_mut(id.id.client_id.clone())?;
187 client.destroy_object(id.id.clone(), &mut state.pending_destructors)
188 }
189
190 pub fn null_id() -> ObjectId {
191 ObjectId {
192 id: InnerObjectId {
193 id: 0,
194 serial: 0,
195 client_id: InnerClientId { id: 0, serial: 0 },
196 interface: &ANONYMOUS_INTERFACE,
197 },
198 }
199 }
200
201 pub fn send_event(&self, msg: Message<ObjectId, RawFd>) -> Result<(), InvalidId> {
202 self.state.lock().unwrap().send_event(msg)
203 }
204
205 pub fn get_object_data<D: 'static>(
206 &self,
207 id: InnerObjectId,
208 ) -> Result<Arc<dyn ObjectData<D>>, InvalidId> {
209 let mut state = self.state.lock().unwrap();
210 let state = (&mut *state as &mut dyn ErasedState)
211 .downcast_mut::<State<D>>()
212 .expect("Wrong type parameter passed to Handle::get_object_data().");
213 state.clients.get_client(id.client_id.clone())?.get_object_data(id)
214 }
215
216 pub fn get_object_data_any(
217 &self,
218 id: InnerObjectId,
219 ) -> Result<Arc<dyn std::any::Any + Send + Sync>, InvalidId> {
220 self.state.lock().unwrap().get_object_data_any(id)
221 }
222
223 pub fn set_object_data<D: 'static>(
224 &self,
225 id: InnerObjectId,
226 data: Arc<dyn ObjectData<D>>,
227 ) -> Result<(), InvalidId> {
228 let mut state = self.state.lock().unwrap();
229 let state = (&mut *state as &mut dyn ErasedState)
230 .downcast_mut::<State<D>>()
231 .expect("Wrong type parameter passed to Handle::set_object_data().");
232 state.clients.get_client_mut(id.client_id.clone())?.set_object_data(id, data)
233 }
234
235 pub fn post_error(&self, object_id: InnerObjectId, error_code: u32, message: CString) {
236 self.state.lock().unwrap().post_error(object_id, error_code, message)
237 }
238
239 pub fn kill_client(&self, client_id: InnerClientId, reason: DisconnectReason) {
240 self.state.lock().unwrap().kill_client(client_id, reason)
241 }
242
243 pub fn create_global<D: 'static>(
244 &self,
245 interface: &'static Interface,
246 version: u32,
247 handler: Arc<dyn GlobalHandler<D>>,
248 ) -> InnerGlobalId {
249 let mut state = self.state.lock().unwrap();
250 let state = (&mut *state as &mut dyn ErasedState)
251 .downcast_mut::<State<D>>()
252 .expect("Wrong type parameter passed to Handle::create_global().");
253 state.registry.create_global(interface, version, handler, &mut state.clients)
254 }
255
256 pub fn disable_global<D: 'static>(&self, id: InnerGlobalId) {
257 let mut state = self.state.lock().unwrap();
258 let state = (&mut *state as &mut dyn ErasedState)
259 .downcast_mut::<State<D>>()
260 .expect("Wrong type parameter passed to Handle::disable_global().");
261
262 state.registry.disable_global(id, &mut state.clients)
263 }
264
265 pub fn remove_global<D: 'static>(&self, id: InnerGlobalId) {
266 let mut state_lock = self.state.lock().unwrap();
267 let state = (&mut *state_lock as &mut dyn ErasedState)
268 .downcast_mut::<State<D>>()
269 .expect("Wrong type parameter passed to Handle::remove_global().");
270
271 let global = state.registry.remove_global(id, &mut state.clients);
272 drop(state_lock);
274 drop(global);
275 }
276
277 pub fn global_info(&self, id: InnerGlobalId) -> Result<GlobalInfo, InvalidId> {
278 self.state.lock().unwrap().global_info(id)
279 }
280
281 #[cfg_attr(not(feature = "libwayland_server_1_22"), allow(dead_code))]
282 pub fn global_name(&self, global: InnerGlobalId, client: InnerClientId) -> Option<u32> {
283 self.state.lock().unwrap().global_name(global, client)
284 }
285
286 pub fn get_global_handler<D: 'static>(
287 &self,
288 id: InnerGlobalId,
289 ) -> Result<Arc<dyn GlobalHandler<D>>, InvalidId> {
290 let mut state = self.state.lock().unwrap();
291 let state = (&mut *state as &mut dyn ErasedState)
292 .downcast_mut::<State<D>>()
293 .expect("Wrong type parameter passed to Handle::get_global_handler().");
294 state.registry.get_handler(id)
295 }
296
297 pub fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()> {
298 self.state.lock().unwrap().flush(client)
299 }
300
301 #[allow(dead_code)]
302 pub fn set_default_max_buffer_size(&self, max_buffer_size: usize) {
303 self.state.lock().unwrap().set_default_max_buffer_size(max_buffer_size)
304 }
305
306 #[allow(dead_code)]
307 pub fn set_client_max_buffer_size(&self, client: InnerClientId, max_buffer_size: usize) {
308 self.state.lock().unwrap().set_client_max_buffer_size(client, max_buffer_size)
309 }
310}
311
312pub(crate) trait ErasedState: downcast_rs::Downcast {
313 fn object_info(&self, id: InnerObjectId) -> Result<ObjectInfo, InvalidId>;
314 fn insert_client(
315 &mut self,
316 stream: UnixStream,
317 data: Arc<dyn ClientData>,
318 ) -> std::io::Result<InnerClientId>;
319 fn get_client(&self, id: InnerObjectId) -> Result<ClientId, InvalidId>;
320 fn get_client_data(&self, id: InnerClientId) -> Result<Arc<dyn ClientData>, InvalidId>;
321 fn get_client_credentials(&self, id: InnerClientId) -> Result<Credentials, InvalidId>;
322 fn with_all_clients(&self, f: &mut dyn FnMut(ClientId));
323 fn with_all_objects_for(
324 &self,
325 client_id: InnerClientId,
326 f: &mut dyn FnMut(ObjectId),
327 ) -> Result<(), InvalidId>;
328 fn object_for_protocol_id(
329 &self,
330 client_id: InnerClientId,
331 interface: &'static Interface,
332 protocol_id: u32,
333 ) -> Result<ObjectId, InvalidId>;
334 fn get_object_data_any(
335 &self,
336 id: InnerObjectId,
337 ) -> Result<Arc<dyn std::any::Any + Send + Sync>, InvalidId>;
338 fn send_event(&mut self, msg: Message<ObjectId, RawFd>) -> Result<(), InvalidId>;
339 fn post_error(&mut self, object_id: InnerObjectId, error_code: u32, message: CString);
340 fn kill_client(&mut self, client_id: InnerClientId, reason: DisconnectReason);
341 fn global_info(&self, id: InnerGlobalId) -> Result<GlobalInfo, InvalidId>;
342 fn global_name(&self, global: InnerGlobalId, client: InnerClientId) -> Option<u32>;
343 fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()>;
344 fn set_default_max_buffer_size(&mut self, max_buffer_size: usize);
345 fn set_client_max_buffer_size(&mut self, client: InnerClientId, max_buffer_size: usize);
346}
347
348downcast_rs::impl_downcast!(ErasedState);
349
350impl<D> ErasedState for State<D> {
351 fn object_info(&self, id: InnerObjectId) -> Result<ObjectInfo, InvalidId> {
352 self.clients.get_client(id.client_id.clone())?.object_info(id)
353 }
354
355 fn insert_client(
356 &mut self,
357 stream: UnixStream,
358 data: Arc<dyn ClientData>,
359 ) -> std::io::Result<InnerClientId> {
360 let id = self.clients.create_client(stream, data, self.default_max_buffer_size);
361 let client = self.clients.get_client(id.clone()).unwrap();
362
363 #[cfg(any(target_os = "linux", target_os = "android", target_os = "redox"))]
365 let ret = {
366 use rustix::event::epoll;
367 epoll::add(
368 &self.poll_fd,
369 client,
370 epoll::EventData::new_u64(id.as_u64()),
371 epoll::EventFlags::IN,
372 )
373 };
374
375 #[cfg(any(
376 target_os = "dragonfly",
377 target_os = "freebsd",
378 target_os = "netbsd",
379 target_os = "openbsd",
380 target_os = "macos"
381 ))]
382 let ret = {
383 use rustix::event::kqueue::*;
384 use std::os::unix::io::{AsFd, AsRawFd};
385
386 let evt = Event::new(
387 EventFilter::Read(client.as_fd().as_raw_fd()),
388 EventFlags::ADD | EventFlags::RECEIPT,
389 id.as_u64() as *mut _,
390 );
391
392 let events: &mut [Event] = &mut [];
393 unsafe { kevent(&self.poll_fd, &[evt], events, None).map(|_| ()) }
394 };
395
396 match ret {
397 Ok(()) => Ok(id),
398 Err(e) => {
399 self.kill_client(id, DisconnectReason::ConnectionClosed);
400 Err(e.into())
401 }
402 }
403 }
404
405 fn get_client(&self, id: InnerObjectId) -> Result<ClientId, InvalidId> {
406 if self.clients.get_client(id.client_id.clone()).is_ok() {
407 Ok(ClientId { id: id.client_id })
408 } else {
409 Err(InvalidId)
410 }
411 }
412
413 fn get_client_data(&self, id: InnerClientId) -> Result<Arc<dyn ClientData>, InvalidId> {
414 let client = self.clients.get_client(id)?;
415 Ok(client.data.clone())
416 }
417
418 fn get_client_credentials(&self, id: InnerClientId) -> Result<Credentials, InvalidId> {
419 let client = self.clients.get_client(id)?;
420 Ok(client.get_credentials())
421 }
422
423 fn with_all_clients(&self, f: &mut dyn FnMut(ClientId)) {
424 for client in self.clients.all_clients_id() {
425 f(client)
426 }
427 }
428
429 fn with_all_objects_for(
430 &self,
431 client_id: InnerClientId,
432 f: &mut dyn FnMut(ObjectId),
433 ) -> Result<(), InvalidId> {
434 let client = self.clients.get_client(client_id)?;
435 for object in client.all_objects() {
436 f(object)
437 }
438 Ok(())
439 }
440
441 fn object_for_protocol_id(
442 &self,
443 client_id: InnerClientId,
444 interface: &'static Interface,
445 protocol_id: u32,
446 ) -> Result<ObjectId, InvalidId> {
447 let client = self.clients.get_client(client_id)?;
448 let object = client.object_for_protocol_id(protocol_id)?;
449 if same_interface(interface, object.interface) {
450 Ok(ObjectId { id: object })
451 } else {
452 Err(InvalidId)
453 }
454 }
455
456 fn get_object_data_any(
457 &self,
458 id: InnerObjectId,
459 ) -> Result<Arc<dyn std::any::Any + Send + Sync>, InvalidId> {
460 self.clients
461 .get_client(id.client_id.clone())?
462 .get_object_data(id)
463 .map(|arc| arc.into_any_arc())
464 }
465
466 fn send_event(&mut self, msg: Message<ObjectId, RawFd>) -> Result<(), InvalidId> {
467 self.clients
468 .get_client_mut(msg.sender_id.id.client_id.clone())?
469 .send_event(msg, Some(&mut self.pending_destructors))
470 }
471
472 fn post_error(&mut self, object_id: InnerObjectId, error_code: u32, message: CString) {
473 if let Ok(client) = self.clients.get_client_mut(object_id.client_id.clone()) {
474 client.post_error(object_id, error_code, message)
475 }
476 }
477
478 fn kill_client(&mut self, client_id: InnerClientId, reason: DisconnectReason) {
479 if let Ok(client) = self.clients.get_client_mut(client_id) {
480 client.kill(reason)
481 }
482 }
483 fn global_info(&self, id: InnerGlobalId) -> Result<GlobalInfo, InvalidId> {
484 self.registry.get_info(id)
485 }
486
487 fn global_name(&self, global_id: InnerGlobalId, client_id: InnerClientId) -> Option<u32> {
488 let client = self.clients.get_client(client_id.clone()).ok()?;
489 let handler = self.registry.get_handler(global_id.clone()).ok()?;
490 let name = global_id.id;
491
492 let can_view =
493 handler.can_view(ClientId { id: client_id }, &client.data, GlobalId { id: global_id });
494
495 if can_view {
496 Some(name)
497 } else {
498 None
499 }
500 }
501
502 fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()> {
503 self.flush(client)
504 }
505
506 fn set_default_max_buffer_size(&mut self, max_buffer_size: usize) {
507 self.set_default_max_buffer_size(max_buffer_size)
508 }
509
510 fn set_client_max_buffer_size(&mut self, client: InnerClientId, max_buffer_size: usize) {
511 self.set_client_max_buffer_size(client, max_buffer_size)
512 }
513}