wayland_backend/rs/
socket.rs

//! Wayland socket manipulation
2
3use std::collections::VecDeque;
4use std::io::{ErrorKind, IoSlice, IoSliceMut, Result as IoResult};
5use std::mem::MaybeUninit;
6use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
7use std::os::unix::net::UnixStream;
8use std::slice;
9
10use rustix::io::retry_on_intr;
11use rustix::net::{
12    recvmsg, send, sendmsg, RecvAncillaryBuffer, RecvAncillaryMessage, RecvFlags,
13    SendAncillaryBuffer, SendAncillaryMessage, SendFlags,
14};
15
16use crate::protocol::{ArgumentType, Message};
17
18use super::wire::{parse_message, write_to_buffers, MessageParseError, MessageWriteError};
19
/// Upper bound on the number of file descriptors carried by one socket message
pub const MAX_FDS_OUT: usize = 28;
/// Upper bound on the number of payload bytes carried by one socket message
pub const MAX_BYTES_OUT: usize = 4 * 1024;
24
/*
 * Socket
 */
28
/// A wayland socket
///
/// Thin wrapper around a connected [`UnixStream`]; all reads and writes go
/// through `send_msg` / `rcv_msg`, which are non-blocking.
#[derive(Debug)]
pub struct Socket {
    /// The underlying connected stream
    stream: UnixStream,
}
34
35impl Socket {
36    /// Send a single message to the socket
37    ///
38    /// A single socket message can contain several wayland messages
39    ///
40    /// The `fds` slice should not be longer than `MAX_FDS_OUT`, and the `bytes`
41    /// slice should not be longer than `MAX_BYTES_OUT` otherwise the receiving
42    /// end may lose some data.
43    pub fn send_msg(&self, bytes: &[u8], fds: &[OwnedFd]) -> IoResult<usize> {
44        #[cfg(not(target_os = "macos"))]
45        let flags = SendFlags::DONTWAIT | SendFlags::NOSIGNAL;
46        #[cfg(target_os = "macos")]
47        let flags = SendFlags::DONTWAIT;
48
49        if !fds.is_empty() {
50            let iov = [IoSlice::new(bytes)];
51            let mut cmsg_space =
52                vec![MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(fds.len()))];
53            let mut cmsg_buffer = SendAncillaryBuffer::new(&mut cmsg_space);
54            let fds =
55                unsafe { slice::from_raw_parts(fds.as_ptr() as *const BorrowedFd, fds.len()) };
56            cmsg_buffer.push(SendAncillaryMessage::ScmRights(fds));
57            Ok(retry_on_intr(|| sendmsg(self, &iov, &mut cmsg_buffer, flags))?)
58        } else {
59            Ok(retry_on_intr(|| send(self, bytes, flags))?)
60        }
61    }
62
63    /// Receive a single message from the socket
64    ///
65    /// Return the number of bytes received and the number of Fds received.
66    ///
67    /// Errors with `WouldBlock` is no message is available.
68    ///
69    /// A single socket message can contain several wayland messages.
70    ///
71    /// The `buffer` slice should be at least `MAX_BYTES_OUT` long and the `fds`
72    /// slice `MAX_FDS_OUT` long, otherwise some data of the received message may
73    /// be lost.
74    pub fn rcv_msg(&self, buffer: &mut [u8], fds: &mut VecDeque<OwnedFd>) -> IoResult<usize> {
75        #[cfg(not(target_os = "macos"))]
76        let flags = RecvFlags::DONTWAIT | RecvFlags::CMSG_CLOEXEC;
77        #[cfg(target_os = "macos")]
78        let flags = RecvFlags::DONTWAIT;
79
80        let mut cmsg_space = [MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(MAX_FDS_OUT))];
81        let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
82        let mut iov = [IoSliceMut::new(buffer)];
83        let msg = retry_on_intr(|| recvmsg(&self.stream, &mut iov[..], &mut cmsg_buffer, flags))?;
84
85        let received_fds = cmsg_buffer
86            .drain()
87            .filter_map(|cmsg| match cmsg {
88                RecvAncillaryMessage::ScmRights(fds) => Some(fds),
89                _ => None,
90            })
91            .flatten();
92        fds.extend(received_fds);
93        #[cfg(target_os = "macos")]
94        for fd in fds.iter() {
95            if let Ok(flags) = rustix::io::fcntl_getfd(fd) {
96                let _ = rustix::io::fcntl_setfd(fd, flags | rustix::io::FdFlags::CLOEXEC);
97            }
98        }
99        Ok(msg.bytes)
100    }
101}
102
103impl From<UnixStream> for Socket {
104    fn from(stream: UnixStream) -> Self {
105        // macOS doesn't have MSG_NOSIGNAL, but has SO_NOSIGPIPE instead
106        #[cfg(target_os = "macos")]
107        let _ = rustix::net::sockopt::set_socket_nosigpipe(&stream, true);
108        Self { stream }
109    }
110}
111
112impl AsFd for Socket {
113    fn as_fd(&self) -> BorrowedFd<'_> {
114        self.stream.as_fd()
115    }
116}
117
118impl AsRawFd for Socket {
119    fn as_raw_fd(&self) -> RawFd {
120        self.stream.as_raw_fd()
121    }
122}
123
/*
 * BufferedSocket
 */
127
128/// An adapter around a raw Socket that directly handles buffering and
129/// conversion from/to wayland messages
130#[derive(Debug)]
131pub struct BufferedSocket {
132    socket: Socket,
133    in_data: Buffer<u8>,
134    in_fds: VecDeque<OwnedFd>,
135    out_data: Buffer<u8>,
136    out_fds: Vec<OwnedFd>,
137}
138
139impl BufferedSocket {
140    /// Wrap a Socket into a Buffered Socket
141    pub fn new(socket: Socket) -> Self {
142        Self {
143            socket,
144            in_data: Buffer::new(2 * MAX_BYTES_OUT), // Incoming buffers are twice as big in order to be
145            in_fds: VecDeque::new(),                 // able to store leftover data if needed
146            out_data: Buffer::new(MAX_BYTES_OUT),
147            out_fds: Vec::new(),
148        }
149    }
150
151    /// Flush the contents of the outgoing buffer into the socket
152    pub fn flush(&mut self) -> IoResult<()> {
153        let bytes = self.out_data.get_contents();
154        let fds = &self.out_fds;
155        let mut written_bytes = 0;
156        let mut written_fds = 0;
157        let mut ret = Ok(());
158        while written_bytes < bytes.len() {
159            let mut bytes_to_write = &bytes[written_bytes..];
160            let mut fds_to_write = &fds[written_fds..];
161            if fds_to_write.len() > MAX_FDS_OUT {
162                // While we need to send more than MAX_FDS_OUT fds,
163                // send them with separate send_msg calls in MAX_FDS_OUT sized chunks
164                // together with 1 byte of normal data.
165                // This achieves the same that libwayland does in wl_connection_flush
166                // and ensures that all file descriptors are sent
167                // before we run out of bytes of normal data to send.
168                bytes_to_write = &bytes_to_write[..1];
169                fds_to_write = &fds_to_write[..MAX_FDS_OUT];
170            }
171            if bytes_to_write.len() > MAX_BYTES_OUT {
172                // Also ensure the MAX_BYTES_OUT limit, this stays redundant as long
173                // as self.out_data has a fixed MAX_BYTES_OUT size and cannot grow.
174                bytes_to_write = &bytes_to_write[..MAX_BYTES_OUT];
175            }
176            match self.socket.send_msg(bytes_to_write, fds_to_write) {
177                Ok(0) => {
178                    // This branch should be unreachable because
179                    // a non-zero sized send or sendmsg should never return 0
180                    written_fds += fds_to_write.len();
181                    break;
182                }
183                Ok(count) => {
184                    written_bytes += count;
185                    written_fds += fds_to_write.len();
186                }
187                Err(error) => {
188                    ret = Err(error);
189                    break;
190                }
191            }
192        }
193        // Either the flush is done or got an error, so clean up and
194        // remove the data and the fds which were sent form the outgoing buffers
195        self.out_data.offset(written_bytes);
196        self.out_data.move_to_front();
197        self.out_fds.drain(..written_fds);
198        ret
199    }
200
201    // internal method
202    //
203    // attempts to write a message in the internal out buffers,
204    // returns true if successful
205    //
206    // if false is returned, it means there is not enough space
207    // in the buffer
208    fn attempt_write_message(&mut self, msg: &Message<u32, RawFd>) -> IoResult<bool> {
209        match write_to_buffers(msg, self.out_data.get_writable_storage(), &mut self.out_fds) {
210            Ok(bytes_out) => {
211                self.out_data.advance(bytes_out);
212                Ok(true)
213            }
214            Err(MessageWriteError::BufferTooSmall) => Ok(false),
215            Err(MessageWriteError::DupFdFailed(e)) => Err(e),
216        }
217    }
218
219    /// Write a message to the outgoing buffer
220    ///
221    /// This method may flush the internal buffer if necessary (if it is full).
222    ///
223    /// If the message is too big to fit in the buffer, the error `Error::Sys(E2BIG)`
224    /// will be returned.
225    pub fn write_message(&mut self, msg: &Message<u32, RawFd>) -> IoResult<()> {
226        if !self.attempt_write_message(msg)? {
227            // the attempt failed, there is not enough space in the buffer
228            // we need to flush it
229            if let Err(e) = self.flush() {
230                if e.kind() != ErrorKind::WouldBlock {
231                    return Err(e);
232                }
233            }
234            if !self.attempt_write_message(msg)? {
235                // If this fails again, this means the message is too big
236                // to be transmitted at all
237                return Err(rustix::io::Errno::TOOBIG.into());
238            }
239        }
240        Ok(())
241    }
242
243    /// Try to fill the incoming buffers of this socket, to prepare
244    /// a new round of parsing.
245    pub fn fill_incoming_buffers(&mut self) -> IoResult<()> {
246        // reorganize the buffers
247        self.in_data.move_to_front();
248        // receive a message
249        let in_bytes = {
250            let bytes = self.in_data.get_writable_storage();
251            self.socket.rcv_msg(bytes, &mut self.in_fds)?
252        };
253        if in_bytes == 0 {
254            // the other end of the socket was closed
255            return Err(rustix::io::Errno::PIPE.into());
256        }
257        // advance the storage
258        self.in_data.advance(in_bytes);
259        Ok(())
260    }
261
262    /// Read and deserialize a single message from the incoming buffers socket
263    ///
264    /// This method requires one closure that given an object id and an opcode,
265    /// must provide the signature of the associated request/event, in the form of
266    /// a `&'static [ArgumentType]`.
267    pub fn read_one_message<F>(
268        &mut self,
269        mut signature: F,
270    ) -> Result<Message<u32, OwnedFd>, MessageParseError>
271    where
272        F: FnMut(u32, u16) -> Option<&'static [ArgumentType]>,
273    {
274        let (msg, read_data) = {
275            let data = self.in_data.get_contents();
276            if data.len() < 2 * 4 {
277                return Err(MessageParseError::MissingData);
278            }
279            let object_id = u32::from_ne_bytes([data[0], data[1], data[2], data[3]]);
280            let word_2 = u32::from_ne_bytes([data[4], data[5], data[6], data[7]]);
281            let opcode = (word_2 & 0x0000_FFFF) as u16;
282            if let Some(sig) = signature(object_id, opcode) {
283                match parse_message(data, sig, &mut self.in_fds) {
284                    Ok((msg, rest_data)) => (msg, data.len() - rest_data.len()),
285                    Err(e) => return Err(e),
286                }
287            } else {
288                // no signature found ?
289                return Err(MessageParseError::Malformed);
290            }
291        };
292
293        self.in_data.offset(read_data);
294
295        Ok(msg)
296    }
297}
298
299impl AsRawFd for BufferedSocket {
300    fn as_raw_fd(&self) -> RawFd {
301        self.socket.as_raw_fd()
302    }
303}
304
305impl AsFd for BufferedSocket {
306    fn as_fd(&self) -> BorrowedFd<'_> {
307        self.socket.as_fd()
308    }
309}
310
/*
 * Buffer
 */
/// Fixed-size linear buffer with a write cursor and a read cursor
///
/// `storage[offset..occupied]` holds the data that was written but not yet
/// consumed, and `storage[occupied..]` is the space still available for
/// writing. The accessors rely on `offset <= occupied <= storage.len()`.
#[derive(Debug)]
struct Buffer<T: Copy> {
    /// Backing storage; its length is fixed at construction
    storage: Vec<T>,
    /// Write cursor: index one past the last written element
    occupied: usize,
    /// Read cursor: index of the first element not yet consumed
    offset: usize,
}

impl<T: Copy + Default> Buffer<T> {
    /// Create an empty buffer able to hold `size` elements
    fn new(size: usize) -> Self {
        Self { storage: vec![T::default(); size], occupied: 0, offset: 0 }
    }

    /// Advance the internal counter of occupied space
    ///
    /// To be called after `bytes` elements were written at the start of the
    /// slice returned by `get_writable_storage`.
    fn advance(&mut self, bytes: usize) {
        self.occupied += bytes;
        debug_assert!(
            self.occupied <= self.storage.len(),
            "Buffer::advance moved the write cursor past the end of the storage"
        );
    }

    /// Advance the read offset of current occupied space
    ///
    /// To be called after `bytes` elements at the start of the slice returned
    /// by `get_contents` were consumed.
    fn offset(&mut self, bytes: usize) {
        self.offset += bytes;
        debug_assert!(
            self.offset <= self.occupied,
            "Buffer::offset moved the read cursor past the written data"
        );
    }

    /// Clears the contents of the buffer
    ///
    /// This only sets the cursors back to zero, allowing previous content to
    /// be overwritten.
    #[allow(unused)]
    fn clear(&mut self) {
        self.occupied = 0;
        self.offset = 0;
    }

    /// Get the current contents of the occupied space of the buffer
    fn get_contents(&self) -> &[T] {
        &self.storage[(self.offset)..(self.occupied)]
    }

    /// Get mutable access to the unoccupied space of the buffer
    fn get_writable_storage(&mut self) -> &mut [T] {
        &mut self.storage[(self.occupied)..]
    }

    /// Move the unread contents of the buffer to the front, to ensure
    /// maximal write space availability
    fn move_to_front(&mut self) {
        if self.offset == 0 {
            // Nothing was consumed: the data already sits at the front, so
            // there is nothing to copy and no cursor to adjust.
            return;
        }
        if self.occupied > self.offset {
            self.storage.copy_within((self.offset)..(self.occupied), 0)
        }
        self.occupied -= self.offset;
        self.offset = 0;
    }
}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{AllowNull, Argument, ArgumentType, Message};

    use std::ffi::CString;
    use std::os::unix::io::IntoRawFd;

    use smallvec::smallvec;

    // whether both descriptors refer to the same file (same device and inode)
    fn same_file(a: BorrowedFd, b: BorrowedFd) -> bool {
        let stat_a = rustix::fs::fstat(a).unwrap();
        let stat_b = rustix::fs::fstat(b).unwrap();
        (stat_a.st_dev, stat_a.st_ino) == (stat_b.st_dev, stat_b.st_ino)
    }

    // check if two messages are equal
    //
    // if arguments contain FDs, check that the fds point to
    // the same file, rather than being the same number.
    fn assert_eq_msgs<Fd: AsRawFd + std::fmt::Debug>(
        msg1: &Message<u32, Fd>,
        msg2: &Message<u32, Fd>,
    ) {
        assert_eq!(msg1.sender_id, msg2.sender_id);
        assert_eq!(msg1.opcode, msg2.opcode);
        assert_eq!(msg1.args.len(), msg2.args.len());
        for pair in msg1.args.iter().zip(msg2.args.iter()) {
            match pair {
                (Argument::Fd(fd1), Argument::Fd(fd2)) => {
                    let fd1 = unsafe { BorrowedFd::borrow_raw(fd1.as_raw_fd()) };
                    let fd2 = unsafe { BorrowedFd::borrow_raw(fd2.as_raw_fd()) };
                    assert!(same_file(fd1, fd2));
                }
                (arg1, arg2) => assert_eq!(arg1, arg2),
            }
        }
    }

    // create a connected (client, server) pair of buffered sockets
    fn connected_pair() -> (BufferedSocket, BufferedSocket) {
        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        (BufferedSocket::new(Socket::from(client)), BufferedSocket::new(Socket::from(server)))
    }

    // write `msg` on one end, read it back on the other end using `signature`,
    // and check that the message that comes out matches the one that went in
    fn assert_round_trip(msg: Message<u32, RawFd>, signature: &'static [ArgumentType]) {
        let (mut client, mut server) = connected_pair();

        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        server.fill_incoming_buffers().unwrap();

        let (expected_sender, expected_opcode) = (msg.sender_id, msg.opcode);
        let ret_msg = server
            .read_one_message(|sender_id, opcode| {
                if sender_id == expected_sender && opcode == expected_opcode {
                    Some(signature)
                } else {
                    None
                }
            })
            .unwrap();

        assert_eq_msgs(&msg.map_fd(|fd| fd.as_raw_fd()), &ret_msg.map_fd(IntoRawFd::into_raw_fd));
    }

    #[test]
    fn write_read_cycle() {
        static SIGNATURE: &[ArgumentType] = &[
            ArgumentType::Uint,
            ArgumentType::Fixed,
            ArgumentType::Str(AllowNull::No),
            ArgumentType::Array,
            ArgumentType::Object(AllowNull::No),
            ArgumentType::NewId,
            ArgumentType::Int,
        ];

        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: smallvec![
                Argument::Uint(3),
                Argument::Fixed(-89),
                Argument::Str(Some(Box::new(CString::new(&b"I like trains!"[..]).unwrap()))),
                Argument::Array(vec![1, 2, 3, 4, 5, 6, 7, 8, 9].into()),
                Argument::Object(88),
                Argument::NewId(56),
                Argument::Int(-25),
            ],
        };

        assert_round_trip(msg, SIGNATURE);
    }

    #[test]
    fn write_read_cycle_fd() {
        static SIGNATURE: &[ArgumentType] = &[ArgumentType::Fd, ArgumentType::Fd];

        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: smallvec![
                Argument::Fd(1), // stdout
                Argument::Fd(0), // stdin
            ],
        };

        assert_round_trip(msg, SIGNATURE);
    }

    #[test]
    fn write_read_cycle_multiple() {
        static SIGNATURES: &[&[ArgumentType]] = &[
            &[ArgumentType::Int, ArgumentType::Str(AllowNull::No)],
            &[ArgumentType::Fd, ArgumentType::Fd],
            &[ArgumentType::Uint, ArgumentType::Fd],
        ];

        let messages = vec![
            Message {
                sender_id: 42,
                opcode: 0,
                args: smallvec![
                    Argument::Int(42),
                    Argument::Str(Some(Box::new(CString::new(&b"I like trains"[..]).unwrap()))),
                ],
            },
            Message {
                sender_id: 42,
                opcode: 1,
                args: smallvec![
                    Argument::Fd(1), // stdout
                    Argument::Fd(0), // stdin
                ],
            },
            Message {
                sender_id: 42,
                opcode: 2,
                args: smallvec![
                    Argument::Uint(3),
                    Argument::Fd(2), // stderr
                ],
            },
        ];

        let (mut client, mut server) = connected_pair();

        for msg in &messages {
            client.write_message(msg).unwrap();
        }
        client.flush().unwrap();

        server.fill_incoming_buffers().unwrap();

        // keep reading until the first parse error, which marks the end of the buffered data
        let recv_msgs: Vec<_> = std::iter::from_fn(|| {
            server
                .read_one_message(|sender_id, opcode| {
                    if sender_id == 42 {
                        Some(SIGNATURES[opcode as usize])
                    } else {
                        None
                    }
                })
                .ok()
        })
        .collect();

        assert_eq!(recv_msgs.len(), 3);
        for (msg1, msg2) in messages.into_iter().zip(recv_msgs) {
            assert_eq_msgs(&msg1.map_fd(|fd| fd.as_raw_fd()), &msg2.map_fd(IntoRawFd::into_raw_fd));
        }
    }

    #[test]
    fn parse_with_string_len_multiple_of_4() {
        static SIGNATURE: &[ArgumentType] =
            &[ArgumentType::Uint, ArgumentType::Str(AllowNull::No), ArgumentType::Uint];

        let msg = Message {
            sender_id: 2,
            opcode: 0,
            args: smallvec![
                Argument::Uint(18),
                Argument::Str(Some(Box::new(CString::new(&b"wl_shell"[..]).unwrap()))),
                Argument::Uint(1),
            ],
        };

        assert_round_trip(msg, SIGNATURE);
    }
}