wayland_backend/rs/
socket.rs

1//! Wayland socket manipulation
2
3use std::collections::VecDeque;
4use std::io::{ErrorKind, IoSlice, IoSliceMut, Result as IoResult};
5use std::mem::MaybeUninit;
6use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
7use std::os::unix::net::UnixStream;
8use std::slice;
9
10use rustix::io::retry_on_intr;
11use rustix::net::{
12    recvmsg, send, sendmsg, RecvAncillaryBuffer, RecvAncillaryMessage, RecvFlags,
13    SendAncillaryBuffer, SendAncillaryMessage, SendFlags,
14};
15
16use crate::protocol::{ArgumentType, Message};
17use crate::rs::DEFAULT_MAX_BUFFER_SIZE;
18
19use super::wire::{parse_message, write_to_buffers, MessageParseError, MessageWriteError};
20
/// Upper bound on how many file descriptors are attached to one socket message
pub const MAX_FDS_OUT: usize = 28;
/// Upper bound on how many bytes of data are carried by one socket message
pub const MAX_BYTES_OUT: usize = 4096;
25
26/*
27 * Socket
28 */
29
/// A wayland socket: a thin wrapper over a connected Unix stream socket.
#[derive(Debug)]
pub struct Socket {
    /// The wrapped stream; `send_msg`/`rcv_msg` pass non-blocking flags on every call.
    stream: UnixStream,
}
35
36impl Socket {
37    /// Send a single message to the socket
38    ///
39    /// A single socket message can contain several wayland messages
40    ///
41    /// The `fds` slice should not be longer than `MAX_FDS_OUT`, and the `bytes`
42    /// slice should not be longer than `MAX_BYTES_OUT` otherwise the receiving
43    /// end may lose some data.
44    pub fn send_msg(&self, bytes: &[u8], fds: &[OwnedFd]) -> IoResult<usize> {
45        #[cfg(not(any(target_os = "macos", target_os = "redox")))]
46        let flags = SendFlags::DONTWAIT | SendFlags::NOSIGNAL;
47        #[cfg(any(target_os = "macos", target_os = "redox"))]
48        let flags = SendFlags::DONTWAIT;
49
50        if !fds.is_empty() {
51            let iov = [IoSlice::new(bytes)];
52            let mut cmsg_space =
53                vec![MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(fds.len()))];
54            let mut cmsg_buffer = SendAncillaryBuffer::new(&mut cmsg_space);
55            let fds =
56                unsafe { slice::from_raw_parts(fds.as_ptr() as *const BorrowedFd, fds.len()) };
57            cmsg_buffer.push(SendAncillaryMessage::ScmRights(fds));
58            Ok(retry_on_intr(|| sendmsg(self, &iov, &mut cmsg_buffer, flags))?)
59        } else {
60            Ok(retry_on_intr(|| send(self, bytes, flags))?)
61        }
62    }
63
64    /// Receive a single message from the socket
65    ///
66    /// Return the number of bytes received and the number of Fds received.
67    ///
68    /// Errors with `WouldBlock` is no message is available.
69    ///
70    /// A single socket message can contain several wayland messages.
71    ///
72    /// The `buffer` slice should be at least `MAX_BYTES_OUT` long and the `fds`
73    /// slice `MAX_FDS_OUT` long, otherwise some data of the received message may
74    /// be lost.
75    pub fn rcv_msg(&self, buffer: &mut [u8], fds: &mut VecDeque<OwnedFd>) -> IoResult<usize> {
76        #[cfg(not(any(target_os = "macos", target_os = "redox")))]
77        let flags = RecvFlags::DONTWAIT | RecvFlags::CMSG_CLOEXEC;
78        #[cfg(any(target_os = "macos", target_os = "redox"))]
79        let flags = RecvFlags::DONTWAIT;
80
81        let mut cmsg_space = [MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(MAX_FDS_OUT))];
82        let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
83        let mut iov = [IoSliceMut::new(buffer)];
84        let msg = retry_on_intr(|| recvmsg(&self.stream, &mut iov[..], &mut cmsg_buffer, flags))?;
85
86        let received_fds = cmsg_buffer
87            .drain()
88            .filter_map(|cmsg| match cmsg {
89                RecvAncillaryMessage::ScmRights(fds) => Some(fds),
90                _ => None,
91            })
92            .flatten();
93        fds.extend(received_fds);
94        #[cfg(any(target_os = "macos", target_os = "redox"))]
95        for fd in fds.iter() {
96            if let Ok(flags) = rustix::io::fcntl_getfd(fd) {
97                let _ = rustix::io::fcntl_setfd(fd, flags | rustix::io::FdFlags::CLOEXEC);
98            }
99        }
100        Ok(msg.bytes)
101    }
102}
103
104impl From<UnixStream> for Socket {
105    fn from(stream: UnixStream) -> Self {
106        // macOS doesn't have MSG_NOSIGNAL, but has SO_NOSIGPIPE instead
107        #[cfg(target_os = "macos")]
108        let _ = rustix::net::sockopt::set_socket_nosigpipe(&stream, true);
109        Self { stream }
110    }
111}
112
113impl AsFd for Socket {
114    fn as_fd(&self) -> BorrowedFd<'_> {
115        self.stream.as_fd()
116    }
117}
118
119impl AsRawFd for Socket {
120    fn as_raw_fd(&self) -> RawFd {
121        self.stream.as_raw_fd()
122    }
123}
124
125/*
126 * BufferedSocket
127 */
128
129/// An adapter around a raw Socket that directly handles buffering and
130/// conversion from/to wayland messages
131#[derive(Debug)]
132pub struct BufferedSocket {
133    socket: Socket,
134    in_data: Buffer<u8>,
135    in_fds: VecDeque<OwnedFd>,
136    out_data: Buffer<u8>,
137    out_fds: Vec<OwnedFd>,
138    max_buffer_size: Option<usize>,
139}
140
141impl BufferedSocket {
142    /// Wrap a Socket into a Buffered Socket
143    pub fn new(socket: Socket, max_buffer_size: Option<usize>) -> Self {
144        Self {
145            socket,
146            in_data: Buffer::new(DEFAULT_MAX_BUFFER_SIZE),
147            in_fds: VecDeque::new(),
148            out_data: Buffer::new(DEFAULT_MAX_BUFFER_SIZE),
149            out_fds: Vec::new(),
150            max_buffer_size: max_buffer_size.map(round_max_buffer_size),
151        }
152    }
153
154    /// Flush the contents of the outgoing buffer into the socket
155    pub fn flush(&mut self) -> IoResult<()> {
156        let bytes = self.out_data.get_contents();
157        let fds = &self.out_fds;
158        let mut written_bytes = 0;
159        let mut written_fds = 0;
160        let mut ret = Ok(());
161        while written_bytes < bytes.len() {
162            let mut bytes_to_write = &bytes[written_bytes..];
163            let mut fds_to_write = &fds[written_fds..];
164            if fds_to_write.len() > MAX_FDS_OUT {
165                // While we need to send more than MAX_FDS_OUT fds,
166                // send them with separate send_msg calls in MAX_FDS_OUT sized chunks
167                // together with 1 byte of normal data.
168                // This achieves the same that libwayland does in wl_connection_flush
169                // and ensures that all file descriptors are sent
170                // before we run out of bytes of normal data to send.
171                bytes_to_write = &bytes_to_write[..1];
172                fds_to_write = &fds_to_write[..MAX_FDS_OUT];
173            }
174            if bytes_to_write.len() > MAX_BYTES_OUT {
175                // Also ensure the MAX_BYTES_OUT limit, this stays redundant as long
176                // as self.out_data has a fixed MAX_BYTES_OUT size and cannot grow.
177                bytes_to_write = &bytes_to_write[..MAX_BYTES_OUT];
178            }
179            match self.socket.send_msg(bytes_to_write, fds_to_write) {
180                Ok(0) => {
181                    // This branch should be unreachable because
182                    // a non-zero sized send or sendmsg should never return 0
183                    written_fds += fds_to_write.len();
184                    break;
185                }
186                Ok(count) => {
187                    written_bytes += count;
188                    written_fds += fds_to_write.len();
189                }
190                Err(error) => {
191                    ret = Err(error);
192                    break;
193                }
194            }
195        }
196        // Either the flush is done or got an error, so clean up and
197        // remove the data and the fds which were sent form the outgoing buffers
198        self.out_data.offset(written_bytes);
199        self.out_data.move_to_front();
200        self.out_fds.drain(..written_fds);
201        ret
202    }
203
204    // internal method
205    //
206    // attempts to write a message in the internal out buffers,
207    // returns true if successful
208    //
209    // if false is returned, it means there is not enough space
210    // in the buffer
211    fn attempt_write_message(&mut self, msg: &Message<u32, RawFd>) -> IoResult<bool> {
212        let fds_len = self.out_fds.len();
213        loop {
214            match write_to_buffers(msg, self.out_data.get_writable_storage(), &mut self.out_fds) {
215                Ok(bytes_out) => {
216                    self.out_data.advance(bytes_out);
217                    return Ok(true);
218                }
219                Err(MessageWriteError::BufferTooSmall) => {
220                    self.out_fds.truncate(fds_len);
221                    if self.max_buffer_size.is_some_and(|s| s <= self.out_data.storage.len()) {
222                        return Ok(false);
223                    }
224                    self.out_data.increase_capacity();
225                }
226                Err(MessageWriteError::DupFdFailed(e)) => {
227                    return Err(e);
228                }
229            }
230        }
231    }
232
233    /// Write a message to the outgoing buffer
234    ///
235    /// This method may flush the internal buffer if necessary (if it is full).
236    ///
237    /// If the message is too big to fit in the buffer, the error `Error::Sys(E2BIG)`
238    /// will be returned.
239    pub fn write_message(&mut self, msg: &Message<u32, RawFd>) -> IoResult<()> {
240        if !self.attempt_write_message(msg)? {
241            // the attempt failed, there is not enough space in the buffer
242            // we need to flush it
243            if let Err(e) = self.flush() {
244                if e.kind() != ErrorKind::WouldBlock {
245                    return Err(e);
246                }
247            }
248            if !self.attempt_write_message(msg)? {
249                // If this fails again, this means the message is too big
250                // to be transmitted at all
251                return Err(rustix::io::Errno::TOOBIG.into());
252            }
253        }
254        Ok(())
255    }
256
257    /// Try to fill the incoming buffers of this socket, to prepare
258    /// a new round of parsing.
259    pub fn fill_incoming_buffers(&mut self) -> IoResult<()> {
260        // reorganize the buffers
261        self.in_data.move_to_front();
262        // receive a message
263        let in_bytes = {
264            let bytes = self.in_data.get_writable_storage();
265            self.socket.rcv_msg(bytes, &mut self.in_fds)?
266        };
267        if in_bytes == 0 {
268            // the other end of the socket was closed
269            return Err(rustix::io::Errno::PIPE.into());
270        }
271        // advance the storage
272        self.in_data.advance(in_bytes);
273        Ok(())
274    }
275
276    /// Read and deserialize a single message from the incoming buffers socket
277    ///
278    /// This method requires one closure that given an object id and an opcode,
279    /// must provide the signature of the associated request/event, in the form of
280    /// a `&'static [ArgumentType]`.
281    pub fn read_one_message<F>(
282        &mut self,
283        mut signature: F,
284    ) -> Result<Message<u32, OwnedFd>, MessageParseError>
285    where
286        F: FnMut(u32, u16) -> Option<&'static [ArgumentType]>,
287    {
288        let (msg, read_data) = {
289            let data = self.in_data.get_contents();
290            if data.len() < 2 * 4 {
291                return Err(MessageParseError::MissingData);
292            }
293            let object_id = u32::from_ne_bytes([data[0], data[1], data[2], data[3]]);
294            let word_2 = u32::from_ne_bytes([data[4], data[5], data[6], data[7]]);
295            let opcode = (word_2 & 0x0000_FFFF) as u16;
296            if let Some(sig) = signature(object_id, opcode) {
297                match parse_message(data, sig, &mut self.in_fds) {
298                    Ok((msg, rest_data)) => (msg, data.len() - rest_data.len()),
299                    Err(e) => return Err(e),
300                }
301            } else {
302                // no signature found ?
303                return Err(MessageParseError::Malformed);
304            }
305        };
306
307        self.in_data.offset(read_data);
308
309        Ok(msg)
310    }
311
312    pub fn set_max_buffer_size(&mut self, max_buffer_size: Option<usize>) {
313        // TODO: what if it decreases?
314        self.max_buffer_size = max_buffer_size.map(round_max_buffer_size);
315    }
316}
317
318impl AsRawFd for BufferedSocket {
319    fn as_raw_fd(&self) -> RawFd {
320        self.socket.as_raw_fd()
321    }
322}
323
324impl AsFd for BufferedSocket {
325    fn as_fd(&self) -> BorrowedFd<'_> {
326        self.socket.as_fd()
327    }
328}
329
330/*
331 * Buffer
332 */
/// A growable buffer with a read cursor (`offset`) and a write cursor (`occupied`).
///
/// `storage[offset..occupied]` holds data that was written but not yet consumed,
/// and `storage[occupied..]` is free space available for writing.
#[derive(Debug)]
struct Buffer<T: Copy> {
    /// Backing storage; its length is the current capacity.
    storage: Vec<T>,
    /// End of the written region (write cursor).
    occupied: usize,
    /// Start of the unread region (read cursor); callers are expected to keep it
    /// `<= occupied`.
    offset: usize,
}

impl<T: Copy + Default> Buffer<T> {
    /// Create an empty buffer backed by `size` default-initialized slots.
    fn new(size: usize) -> Self {
        Self { storage: vec![T::default(); size], occupied: 0, offset: 0 }
    }

    /// Advance the internal counter of occupied space
    fn advance(&mut self, bytes: usize) {
        self.occupied += bytes;
    }

    /// Advance the read offset of current occupied space
    fn offset(&mut self, bytes: usize) {
        self.offset += bytes;
    }

    /// Grow the backing storage by doubling its length.
    ///
    /// Empty storage grows to one slot. Doubling zero would leave it empty, and a
    /// caller that retries a write after each growth would never make progress.
    fn increase_capacity(&mut self) {
        let new_len = self.storage.len().saturating_mul(2).max(1);
        self.storage.resize_with(new_len, Default::default);
    }

    /// Clears the contents of the buffer
    ///
    /// This only sets the counter of occupied space back to zero,
    /// allowing previous content to be overwritten.
    #[allow(unused)]
    fn clear(&mut self) {
        self.occupied = 0;
        self.offset = 0;
    }

    /// Get the current contents of the occupied space of the buffer
    fn get_contents(&self) -> &[T] {
        &self.storage[(self.offset)..(self.occupied)]
    }

    /// Get mutable access to the unoccupied space of the buffer
    fn get_writable_storage(&mut self) -> &mut [T] {
        &mut self.storage[(self.occupied)..]
    }

    /// Move the unread contents of the buffer to the front, to ensure
    /// maximal write space availability
    fn move_to_front(&mut self) {
        // With offset == 0 the unread data already starts at the front, so the
        // copy would be a no-op.
        if self.offset > 0 && self.occupied > self.offset {
            self.storage.copy_within((self.offset)..(self.occupied), 0);
        }
        self.occupied -= self.offset;
        self.offset = 0;
    }
}
390
391fn round_max_buffer_size(max_buffer_size: usize) -> usize {
392    max_buffer_size.checked_next_power_of_two().unwrap_or(usize::MAX).max(DEFAULT_MAX_BUFFER_SIZE)
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{AllowNull, Argument, ArgumentType, Message};

    use std::ffi::CString;
    use std::os::unix::io::IntoRawFd;

    use smallvec::smallvec;

    /// Build a connected `(client, server)` pair of buffered sockets, both
    /// capped at `DEFAULT_MAX_BUFFER_SIZE`.
    fn buffered_pair() -> (BufferedSocket, BufferedSocket) {
        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        (
            BufferedSocket::new(Socket::from(client), Some(DEFAULT_MAX_BUFFER_SIZE)),
            BufferedSocket::new(Socket::from(server), Some(DEFAULT_MAX_BUFFER_SIZE)),
        )
    }

    /// True if both descriptors refer to the same file (same device and inode).
    fn same_file(a: BorrowedFd, b: BorrowedFd) -> bool {
        let stat1 = rustix::fs::fstat(a).unwrap();
        let stat2 = rustix::fs::fstat(b).unwrap();
        stat1.st_dev == stat2.st_dev && stat1.st_ino == stat2.st_ino
    }

    // check if two messages are equal
    //
    // if arguments contain FDs, check that the fd point to
    // the same file, rather than are the same number.
    fn assert_eq_msgs<Fd: AsRawFd + std::fmt::Debug>(
        msg1: &Message<u32, Fd>,
        msg2: &Message<u32, Fd>,
    ) {
        assert_eq!(msg1.sender_id, msg2.sender_id);
        assert_eq!(msg1.opcode, msg2.opcode);
        assert_eq!(msg1.args.len(), msg2.args.len());
        for (arg1, arg2) in msg1.args.iter().zip(msg2.args.iter()) {
            if let (Argument::Fd(fd1), Argument::Fd(fd2)) = (arg1, arg2) {
                // SAFETY: both fds are owned by the messages, which outlive these borrows.
                let fd1 = unsafe { BorrowedFd::borrow_raw(fd1.as_raw_fd()) };
                let fd2 = unsafe { BorrowedFd::borrow_raw(fd2.as_raw_fd()) };
                assert!(same_file(fd1, fd2));
            } else {
                assert_eq!(arg1, arg2);
            }
        }
    }

    #[test]
    fn write_read_cycle() {
        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: smallvec![
                Argument::Uint(3),
                Argument::Fixed(-89),
                Argument::Str(Some(Box::new(CString::new(&b"I like trains!"[..]).unwrap()))),
                Argument::Array(vec![1, 2, 3, 4, 5, 6, 7, 8, 9].into()),
                Argument::Object(88),
                Argument::NewId(56),
                Argument::Int(-25),
            ],
        };

        let (mut client, mut server) = buffered_pair();

        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        static SIGNATURE: &[ArgumentType] = &[
            ArgumentType::Uint,
            ArgumentType::Fixed,
            ArgumentType::Str(AllowNull::No),
            ArgumentType::Array,
            ArgumentType::Object(AllowNull::No),
            ArgumentType::NewId,
            ArgumentType::Int,
        ];

        server.fill_incoming_buffers().unwrap();

        let ret_msg =
            server
                .read_one_message(|sender_id, opcode| {
                    if sender_id == 42 && opcode == 7 {
                        Some(SIGNATURE)
                    } else {
                        None
                    }
                })
                .unwrap();

        assert_eq_msgs(&msg.map_fd(|fd| fd.as_raw_fd()), &ret_msg.map_fd(IntoRawFd::into_raw_fd));
    }

    #[test]
    fn write_read_cycle_fd() {
        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: smallvec![
                Argument::Fd(1), // stdout
                Argument::Fd(0), // stdin
            ],
        };

        let (mut client, mut server) = buffered_pair();

        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        static SIGNATURE: &[ArgumentType] = &[ArgumentType::Fd, ArgumentType::Fd];

        server.fill_incoming_buffers().unwrap();

        let ret_msg =
            server
                .read_one_message(|sender_id, opcode| {
                    if sender_id == 42 && opcode == 7 {
                        Some(SIGNATURE)
                    } else {
                        None
                    }
                })
                .unwrap();
        assert_eq_msgs(&msg.map_fd(|fd| fd.as_raw_fd()), &ret_msg.map_fd(IntoRawFd::into_raw_fd));
    }

    #[test]
    fn write_read_cycle_multiple() {
        let messages = vec![
            Message {
                sender_id: 42,
                opcode: 0,
                args: smallvec![
                    Argument::Int(42),
                    Argument::Str(Some(Box::new(CString::new(&b"I like trains"[..]).unwrap()))),
                ],
            },
            Message {
                sender_id: 42,
                opcode: 1,
                args: smallvec![
                    Argument::Fd(1), // stdout
                    Argument::Fd(0), // stdin
                ],
            },
            Message {
                sender_id: 42,
                opcode: 2,
                args: smallvec![
                    Argument::Uint(3),
                    Argument::Fd(2), // stderr
                ],
            },
        ];

        // Indexed by opcode.
        static SIGNATURES: &[&[ArgumentType]] = &[
            &[ArgumentType::Int, ArgumentType::Str(AllowNull::No)],
            &[ArgumentType::Fd, ArgumentType::Fd],
            &[ArgumentType::Uint, ArgumentType::Fd],
        ];

        let (mut client, mut server) = buffered_pair();

        for msg in &messages {
            client.write_message(msg).unwrap();
        }
        client.flush().unwrap();

        server.fill_incoming_buffers().unwrap();

        let mut recv_msgs = Vec::new();
        while let Ok(message) = server.read_one_message(|sender_id, opcode| {
            if sender_id == 42 {
                Some(SIGNATURES[opcode as usize])
            } else {
                None
            }
        }) {
            recv_msgs.push(message);
        }
        assert_eq!(recv_msgs.len(), 3);
        for (msg1, msg2) in messages.into_iter().zip(recv_msgs.into_iter()) {
            assert_eq_msgs(&msg1.map_fd(|fd| fd.as_raw_fd()), &msg2.map_fd(IntoRawFd::into_raw_fd));
        }
    }

    #[test]
    fn parse_with_string_len_multiple_of_4() {
        let msg = Message {
            sender_id: 2,
            opcode: 0,
            args: smallvec![
                Argument::Uint(18),
                Argument::Str(Some(Box::new(CString::new(&b"wl_shell"[..]).unwrap()))),
                Argument::Uint(1),
            ],
        };

        let (mut client, mut server) = buffered_pair();

        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        static SIGNATURE: &[ArgumentType] =
            &[ArgumentType::Uint, ArgumentType::Str(AllowNull::No), ArgumentType::Uint];

        server.fill_incoming_buffers().unwrap();

        let ret_msg =
            server
                .read_one_message(|sender_id, opcode| {
                    if sender_id == 2 && opcode == 0 {
                        Some(SIGNATURE)
                    } else {
                        None
                    }
                })
                .unwrap();

        assert_eq_msgs(&msg.map_fd(|fd| fd.as_raw_fd()), &ret_msg.map_fd(IntoRawFd::into_raw_fd));
    }
}