1use std::collections::VecDeque;
4use std::io::{ErrorKind, IoSlice, IoSliceMut, Result as IoResult};
5use std::mem::MaybeUninit;
6use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
7use std::os::unix::net::UnixStream;
8use std::slice;
9
10use rustix::io::retry_on_intr;
11use rustix::net::{
12 recvmsg, send, sendmsg, RecvAncillaryBuffer, RecvAncillaryMessage, RecvFlags,
13 SendAncillaryBuffer, SendAncillaryMessage, SendFlags,
14};
15
16use crate::protocol::{ArgumentType, Message};
17
18use super::wire::{parse_message, write_to_buffers, MessageParseError, MessageWriteError};
19
/// Maximum number of fds passed in a single `sendmsg` call; it also sizes the ancillary
/// buffer used when receiving.
pub const MAX_FDS_OUT: usize = 28;
/// Maximum number of bytes passed in a single send call; it is also the capacity of the
/// outgoing buffer of a `BufferedSocket` (the incoming buffer is twice as large).
pub const MAX_BYTES_OUT: usize = 4 * 1024;
24
/// A connected Unix stream socket able to carry file descriptors as ancillary data.
///
/// `send_msg` and `rcv_msg` issue non-blocking (`DONTWAIT`) calls, so a full or empty
/// socket surfaces as a `WouldBlock` error rather than blocking the caller.
#[derive(Debug)]
pub struct Socket {
    // The underlying stream; all I/O goes through its descriptor.
    stream: UnixStream,
}
34
35impl Socket {
36 pub fn send_msg(&self, bytes: &[u8], fds: &[OwnedFd]) -> IoResult<usize> {
44 #[cfg(not(target_os = "macos"))]
45 let flags = SendFlags::DONTWAIT | SendFlags::NOSIGNAL;
46 #[cfg(target_os = "macos")]
47 let flags = SendFlags::DONTWAIT;
48
49 if !fds.is_empty() {
50 let iov = [IoSlice::new(bytes)];
51 let mut cmsg_space =
52 vec![MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(fds.len()))];
53 let mut cmsg_buffer = SendAncillaryBuffer::new(&mut cmsg_space);
54 let fds =
55 unsafe { slice::from_raw_parts(fds.as_ptr() as *const BorrowedFd, fds.len()) };
56 cmsg_buffer.push(SendAncillaryMessage::ScmRights(fds));
57 Ok(retry_on_intr(|| sendmsg(self, &iov, &mut cmsg_buffer, flags))?)
58 } else {
59 Ok(retry_on_intr(|| send(self, bytes, flags))?)
60 }
61 }
62
63 pub fn rcv_msg(&self, buffer: &mut [u8], fds: &mut VecDeque<OwnedFd>) -> IoResult<usize> {
75 #[cfg(not(target_os = "macos"))]
76 let flags = RecvFlags::DONTWAIT | RecvFlags::CMSG_CLOEXEC;
77 #[cfg(target_os = "macos")]
78 let flags = RecvFlags::DONTWAIT;
79
80 let mut cmsg_space = [MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(MAX_FDS_OUT))];
81 let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
82 let mut iov = [IoSliceMut::new(buffer)];
83 let msg = retry_on_intr(|| recvmsg(&self.stream, &mut iov[..], &mut cmsg_buffer, flags))?;
84
85 let received_fds = cmsg_buffer
86 .drain()
87 .filter_map(|cmsg| match cmsg {
88 RecvAncillaryMessage::ScmRights(fds) => Some(fds),
89 _ => None,
90 })
91 .flatten();
92 fds.extend(received_fds);
93 #[cfg(target_os = "macos")]
94 for fd in fds.iter() {
95 if let Ok(flags) = rustix::io::fcntl_getfd(fd) {
96 let _ = rustix::io::fcntl_setfd(fd, flags | rustix::io::FdFlags::CLOEXEC);
97 }
98 }
99 Ok(msg.bytes)
100 }
101}
102
103impl From<UnixStream> for Socket {
104 fn from(stream: UnixStream) -> Self {
105 #[cfg(target_os = "macos")]
107 let _ = rustix::net::sockopt::set_socket_nosigpipe(&stream, true);
108 Self { stream }
109 }
110}
111
112impl AsFd for Socket {
113 fn as_fd(&self) -> BorrowedFd<'_> {
114 self.stream.as_fd()
115 }
116}
117
118impl AsRawFd for Socket {
119 fn as_raw_fd(&self) -> RawFd {
120 self.stream.as_raw_fd()
121 }
122}
123
/// A [`Socket`] with user-space buffers for outgoing and incoming bytes and fds.
///
/// Messages are serialized into the outgoing buffers by `write_message` and sent by
/// `flush`; received data is accumulated by `fill_incoming_buffers` and decoded by
/// `read_one_message`.
#[derive(Debug)]
pub struct BufferedSocket {
    /// The wrapped socket all I/O goes through.
    socket: Socket,
    /// Received bytes that have not yet been parsed into messages.
    in_data: Buffer<u8>,
    /// Received fds waiting to be consumed by message parsing.
    in_fds: VecDeque<OwnedFd>,
    /// Serialized messages waiting to be flushed.
    out_data: Buffer<u8>,
    /// Fds queued by `write_to_buffers`, waiting to be sent by `flush`.
    out_fds: Vec<OwnedFd>,
}
138
139impl BufferedSocket {
140 pub fn new(socket: Socket) -> Self {
142 Self {
143 socket,
144 in_data: Buffer::new(2 * MAX_BYTES_OUT), in_fds: VecDeque::new(), out_data: Buffer::new(MAX_BYTES_OUT),
147 out_fds: Vec::new(),
148 }
149 }
150
151 pub fn flush(&mut self) -> IoResult<()> {
153 let bytes = self.out_data.get_contents();
154 let fds = &self.out_fds;
155 let mut written_bytes = 0;
156 let mut written_fds = 0;
157 let mut ret = Ok(());
158 while written_bytes < bytes.len() {
159 let mut bytes_to_write = &bytes[written_bytes..];
160 let mut fds_to_write = &fds[written_fds..];
161 if fds_to_write.len() > MAX_FDS_OUT {
162 bytes_to_write = &bytes_to_write[..1];
169 fds_to_write = &fds_to_write[..MAX_FDS_OUT];
170 }
171 if bytes_to_write.len() > MAX_BYTES_OUT {
172 bytes_to_write = &bytes_to_write[..MAX_BYTES_OUT];
175 }
176 match self.socket.send_msg(bytes_to_write, fds_to_write) {
177 Ok(0) => {
178 written_fds += fds_to_write.len();
181 break;
182 }
183 Ok(count) => {
184 written_bytes += count;
185 written_fds += fds_to_write.len();
186 }
187 Err(error) => {
188 ret = Err(error);
189 break;
190 }
191 }
192 }
193 self.out_data.offset(written_bytes);
196 self.out_data.move_to_front();
197 self.out_fds.drain(..written_fds);
198 ret
199 }
200
201 fn attempt_write_message(&mut self, msg: &Message<u32, RawFd>) -> IoResult<bool> {
209 match write_to_buffers(msg, self.out_data.get_writable_storage(), &mut self.out_fds) {
210 Ok(bytes_out) => {
211 self.out_data.advance(bytes_out);
212 Ok(true)
213 }
214 Err(MessageWriteError::BufferTooSmall) => Ok(false),
215 Err(MessageWriteError::DupFdFailed(e)) => Err(e),
216 }
217 }
218
219 pub fn write_message(&mut self, msg: &Message<u32, RawFd>) -> IoResult<()> {
226 if !self.attempt_write_message(msg)? {
227 if let Err(e) = self.flush() {
230 if e.kind() != ErrorKind::WouldBlock {
231 return Err(e);
232 }
233 }
234 if !self.attempt_write_message(msg)? {
235 return Err(rustix::io::Errno::TOOBIG.into());
238 }
239 }
240 Ok(())
241 }
242
243 pub fn fill_incoming_buffers(&mut self) -> IoResult<()> {
246 self.in_data.move_to_front();
248 let in_bytes = {
250 let bytes = self.in_data.get_writable_storage();
251 self.socket.rcv_msg(bytes, &mut self.in_fds)?
252 };
253 if in_bytes == 0 {
254 return Err(rustix::io::Errno::PIPE.into());
256 }
257 self.in_data.advance(in_bytes);
259 Ok(())
260 }
261
262 pub fn read_one_message<F>(
268 &mut self,
269 mut signature: F,
270 ) -> Result<Message<u32, OwnedFd>, MessageParseError>
271 where
272 F: FnMut(u32, u16) -> Option<&'static [ArgumentType]>,
273 {
274 let (msg, read_data) = {
275 let data = self.in_data.get_contents();
276 if data.len() < 2 * 4 {
277 return Err(MessageParseError::MissingData);
278 }
279 let object_id = u32::from_ne_bytes([data[0], data[1], data[2], data[3]]);
280 let word_2 = u32::from_ne_bytes([data[4], data[5], data[6], data[7]]);
281 let opcode = (word_2 & 0x0000_FFFF) as u16;
282 if let Some(sig) = signature(object_id, opcode) {
283 match parse_message(data, sig, &mut self.in_fds) {
284 Ok((msg, rest_data)) => (msg, data.len() - rest_data.len()),
285 Err(e) => return Err(e),
286 }
287 } else {
288 return Err(MessageParseError::Malformed);
290 }
291 };
292
293 self.in_data.offset(read_data);
294
295 Ok(msg)
296 }
297}
298
299impl AsRawFd for BufferedSocket {
300 fn as_raw_fd(&self) -> RawFd {
301 self.socket.as_raw_fd()
302 }
303}
304
305impl AsFd for BufferedSocket {
306 fn as_fd(&self) -> BorrowedFd<'_> {
307 self.socket.as_fd()
308 }
309}
310
/// A fixed-capacity buffer whose valid elements are the window `storage[offset..occupied]`.
///
/// Elements are appended by writing into `get_writable_storage()` and then calling
/// `advance`, and consumed by calling `offset`. `move_to_front` compacts the window back
/// to the start of the storage to reclaim consumed space.
///
/// Invariant: `offset <= occupied <= storage.len()`.
#[derive(Debug)]
struct Buffer<T> {
    storage: Vec<T>,
    occupied: usize,
    offset: usize,
}

impl<T: Copy + Default> Buffer<T> {
    /// Creates an empty buffer with room for `size` elements.
    fn new(size: usize) -> Self {
        Self { storage: vec![T::default(); size], occupied: 0, offset: 0 }
    }

    /// Marks the next `bytes` elements of the writable storage as filled.
    fn advance(&mut self, bytes: usize) {
        // Catch invariant violations where they happen rather than at a later slice.
        debug_assert!(
            self.occupied + bytes <= self.storage.len(),
            "Buffer::advance past the end of the storage"
        );
        self.occupied += bytes;
    }

    /// Marks the first `bytes` elements of the contents as consumed.
    fn offset(&mut self, bytes: usize) {
        debug_assert!(
            self.offset + bytes <= self.occupied,
            "Buffer::offset past the end of the contents"
        );
        self.offset += bytes;
    }

    /// Empties the buffer without touching the storage.
    // Not called by the code visible in this file; `allow(unused)` silences the lint.
    #[allow(unused)]
    fn clear(&mut self) {
        self.occupied = 0;
        self.offset = 0;
    }

    /// Returns the unconsumed elements, `storage[offset..occupied]`.
    fn get_contents(&self) -> &[T] {
        &self.storage[(self.offset)..(self.occupied)]
    }

    /// Returns the free space after the occupied region, to be filled then `advance`d.
    fn get_writable_storage(&mut self) -> &mut [T] {
        &mut self.storage[(self.occupied)..]
    }

    /// Moves the unconsumed elements to the start of the storage, freeing the space
    /// taken by consumed ones.
    fn move_to_front(&mut self) {
        if self.occupied > self.offset {
            self.storage.copy_within((self.offset)..(self.occupied), 0)
        }
        self.occupied -= self.offset;
        self.offset = 0;
    }
}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{AllowNull, Argument, ArgumentType, Message};

    use std::ffi::CString;
    use std::os::unix::io::IntoRawFd;

    use smallvec::smallvec;

    /// Creates a connected `(client, server)` pair of buffered sockets.
    fn connected_pair() -> (BufferedSocket, BufferedSocket) {
        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        (BufferedSocket::new(Socket::from(client)), BufferedSocket::new(Socket::from(server)))
    }

    /// Reads one message from `socket`, answering the signature lookup with `signature`
    /// for the expected `(sender, opcode)` pair and `None` for anything else.
    fn read_expected(
        socket: &mut BufferedSocket,
        expected_sender: u32,
        expected_opcode: u16,
        signature: &'static [ArgumentType],
    ) -> Message<u32, OwnedFd> {
        socket
            .read_one_message(move |sender_id, opcode| {
                if sender_id == expected_sender && opcode == expected_opcode {
                    Some(signature)
                } else {
                    None
                }
            })
            .unwrap()
    }

    /// Whether two descriptors refer to the same file (same device and inode).
    fn same_file(a: BorrowedFd, b: BorrowedFd) -> bool {
        let (lhs, rhs) = (rustix::fs::fstat(a).unwrap(), rustix::fs::fstat(b).unwrap());
        (lhs.st_dev, lhs.st_ino) == (rhs.st_dev, rhs.st_ino)
    }

    /// Asserts that two messages match, comparing fd arguments by the file they refer to
    /// rather than by descriptor number.
    fn assert_eq_msgs<Fd: AsRawFd + std::fmt::Debug>(
        msg1: &Message<u32, Fd>,
        msg2: &Message<u32, Fd>,
    ) {
        assert_eq!(msg1.sender_id, msg2.sender_id);
        assert_eq!(msg1.opcode, msg2.opcode);
        assert_eq!(msg1.args.len(), msg2.args.len());
        for (lhs, rhs) in msg1.args.iter().zip(msg2.args.iter()) {
            match (lhs, rhs) {
                (Argument::Fd(fd1), Argument::Fd(fd2)) => {
                    // SAFETY: both descriptors are held open by the messages for the whole
                    // comparison.
                    let fd1 = unsafe { BorrowedFd::borrow_raw(fd1.as_raw_fd()) };
                    let fd2 = unsafe { BorrowedFd::borrow_raw(fd2.as_raw_fd()) };
                    assert!(same_file(fd1, fd2));
                }
                _ => assert_eq!(lhs, rhs),
            }
        }
    }

    #[test]
    fn write_read_cycle() {
        static SIGNATURE: &[ArgumentType] = &[
            ArgumentType::Uint,
            ArgumentType::Fixed,
            ArgumentType::Str(AllowNull::No),
            ArgumentType::Array,
            ArgumentType::Object(AllowNull::No),
            ArgumentType::NewId,
            ArgumentType::Int,
        ];

        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: smallvec![
                Argument::Uint(3),
                Argument::Fixed(-89),
                Argument::Str(Some(Box::new(CString::new(&b"I like trains!"[..]).unwrap()))),
                Argument::Array(vec![1, 2, 3, 4, 5, 6, 7, 8, 9].into()),
                Argument::Object(88),
                Argument::NewId(56),
                Argument::Int(-25),
            ],
        };

        let (mut client, mut server) = connected_pair();
        client.write_message(&msg).unwrap();
        client.flush().unwrap();
        server.fill_incoming_buffers().unwrap();

        let received = read_expected(&mut server, 42, 7, SIGNATURE);
        assert_eq_msgs(&msg.map_fd(|fd| fd.as_raw_fd()), &received.map_fd(IntoRawFd::into_raw_fd));
    }

    #[test]
    fn write_read_cycle_fd() {
        static SIGNATURE: &[ArgumentType] = &[ArgumentType::Fd, ArgumentType::Fd];

        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: smallvec![Argument::Fd(1), Argument::Fd(0)],
        };

        let (mut client, mut server) = connected_pair();
        client.write_message(&msg).unwrap();
        client.flush().unwrap();
        server.fill_incoming_buffers().unwrap();

        let received = read_expected(&mut server, 42, 7, SIGNATURE);
        assert_eq_msgs(&msg.map_fd(|fd| fd.as_raw_fd()), &received.map_fd(IntoRawFd::into_raw_fd));
    }

    #[test]
    fn write_read_cycle_multiple() {
        static SIGNATURES: &[&[ArgumentType]] = &[
            &[ArgumentType::Int, ArgumentType::Str(AllowNull::No)],
            &[ArgumentType::Fd, ArgumentType::Fd],
            &[ArgumentType::Uint, ArgumentType::Fd],
        ];

        let messages = vec![
            Message {
                sender_id: 42,
                opcode: 0,
                args: smallvec![
                    Argument::Int(42),
                    Argument::Str(Some(Box::new(CString::new(&b"I like trains"[..]).unwrap()))),
                ],
            },
            Message {
                sender_id: 42,
                opcode: 1,
                args: smallvec![Argument::Fd(1), Argument::Fd(0)],
            },
            Message {
                sender_id: 42,
                opcode: 2,
                args: smallvec![Argument::Uint(3), Argument::Fd(2)],
            },
        ];

        let (mut client, mut server) = connected_pair();
        messages.iter().try_for_each(|msg| client.write_message(msg)).unwrap();
        client.flush().unwrap();
        server.fill_incoming_buffers().unwrap();

        // Keep reading until the first error (normally `MissingData` once drained).
        let recv_msgs: Vec<_> = std::iter::from_fn(|| {
            server
                .read_one_message(|sender_id, opcode| {
                    (sender_id == 42).then(|| SIGNATURES[opcode as usize])
                })
                .ok()
        })
        .collect();

        assert_eq!(recv_msgs.len(), 3);
        for (sent, received) in messages.into_iter().zip(recv_msgs) {
            assert_eq_msgs(
                &sent.map_fd(|fd| fd.as_raw_fd()),
                &received.map_fd(IntoRawFd::into_raw_fd),
            );
        }
    }

    #[test]
    fn parse_with_string_len_multiple_of_4() {
        static SIGNATURE: &[ArgumentType] =
            &[ArgumentType::Uint, ArgumentType::Str(AllowNull::No), ArgumentType::Uint];

        let msg = Message {
            sender_id: 2,
            opcode: 0,
            args: smallvec![
                Argument::Uint(18),
                Argument::Str(Some(Box::new(CString::new(&b"wl_shell"[..]).unwrap()))),
                Argument::Uint(1),
            ],
        };

        let (mut client, mut server) = connected_pair();
        client.write_message(&msg).unwrap();
        client.flush().unwrap();
        server.fill_incoming_buffers().unwrap();

        let received = read_expected(&mut server, 2, 0, SIGNATURE);
        assert_eq_msgs(&msg.map_fd(|fd| fd.as_raw_fd()), &received.map_fd(IntoRawFd::into_raw_fd));
    }
}