1use std::collections::VecDeque;
4use std::io::{ErrorKind, IoSlice, IoSliceMut, Result as IoResult};
5use std::mem::MaybeUninit;
6use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
7use std::os::unix::net::UnixStream;
8use std::slice;
9
10use rustix::io::retry_on_intr;
11use rustix::net::{
12 recvmsg, send, sendmsg, RecvAncillaryBuffer, RecvAncillaryMessage, RecvFlags,
13 SendAncillaryBuffer, SendAncillaryMessage, SendFlags,
14};
15
16use crate::protocol::{ArgumentType, Message};
17use crate::rs::DEFAULT_MAX_BUFFER_SIZE;
18
19use super::wire::{parse_message, write_to_buffers, MessageParseError, MessageWriteError};
20
/// Maximum number of file descriptors attached to a single `sendmsg` call.
///
/// `BufferedSocket::flush` splits outgoing fds into chunks of at most this many, and
/// `Socket::rcv_msg` sizes its ancillary-data buffer to accept this many per receive.
pub const MAX_FDS_OUT: usize = 28;
/// Maximum number of payload bytes handed to a single send call by `BufferedSocket::flush`.
pub const MAX_BYTES_OUT: usize = 4096;
25
26#[derive(Debug)]
32pub struct Socket {
33 stream: UnixStream,
34}
35
36impl Socket {
37 pub fn send_msg(&self, bytes: &[u8], fds: &[OwnedFd]) -> IoResult<usize> {
45 #[cfg(not(any(target_os = "macos", target_os = "redox")))]
46 let flags = SendFlags::DONTWAIT | SendFlags::NOSIGNAL;
47 #[cfg(any(target_os = "macos", target_os = "redox"))]
48 let flags = SendFlags::DONTWAIT;
49
50 if !fds.is_empty() {
51 let iov = [IoSlice::new(bytes)];
52 let mut cmsg_space =
53 vec![MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(fds.len()))];
54 let mut cmsg_buffer = SendAncillaryBuffer::new(&mut cmsg_space);
55 let fds =
56 unsafe { slice::from_raw_parts(fds.as_ptr() as *const BorrowedFd, fds.len()) };
57 cmsg_buffer.push(SendAncillaryMessage::ScmRights(fds));
58 Ok(retry_on_intr(|| sendmsg(self, &iov, &mut cmsg_buffer, flags))?)
59 } else {
60 Ok(retry_on_intr(|| send(self, bytes, flags))?)
61 }
62 }
63
64 pub fn rcv_msg(&self, buffer: &mut [u8], fds: &mut VecDeque<OwnedFd>) -> IoResult<usize> {
76 #[cfg(not(any(target_os = "macos", target_os = "redox")))]
77 let flags = RecvFlags::DONTWAIT | RecvFlags::CMSG_CLOEXEC;
78 #[cfg(any(target_os = "macos", target_os = "redox"))]
79 let flags = RecvFlags::DONTWAIT;
80
81 let mut cmsg_space = [MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(MAX_FDS_OUT))];
82 let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
83 let mut iov = [IoSliceMut::new(buffer)];
84 let msg = retry_on_intr(|| recvmsg(&self.stream, &mut iov[..], &mut cmsg_buffer, flags))?;
85
86 let received_fds = cmsg_buffer
87 .drain()
88 .filter_map(|cmsg| match cmsg {
89 RecvAncillaryMessage::ScmRights(fds) => Some(fds),
90 _ => None,
91 })
92 .flatten();
93 fds.extend(received_fds);
94 #[cfg(any(target_os = "macos", target_os = "redox"))]
95 for fd in fds.iter() {
96 if let Ok(flags) = rustix::io::fcntl_getfd(fd) {
97 let _ = rustix::io::fcntl_setfd(fd, flags | rustix::io::FdFlags::CLOEXEC);
98 }
99 }
100 Ok(msg.bytes)
101 }
102}
103
104impl From<UnixStream> for Socket {
105 fn from(stream: UnixStream) -> Self {
106 #[cfg(target_os = "macos")]
108 let _ = rustix::net::sockopt::set_socket_nosigpipe(&stream, true);
109 Self { stream }
110 }
111}
112
113impl AsFd for Socket {
114 fn as_fd(&self) -> BorrowedFd<'_> {
115 self.stream.as_fd()
116 }
117}
118
119impl AsRawFd for Socket {
120 fn as_raw_fd(&self) -> RawFd {
121 self.stream.as_raw_fd()
122 }
123}
124
125#[derive(Debug)]
132pub struct BufferedSocket {
133 socket: Socket,
134 in_data: Buffer<u8>,
135 in_fds: VecDeque<OwnedFd>,
136 out_data: Buffer<u8>,
137 out_fds: Vec<OwnedFd>,
138 max_buffer_size: Option<usize>,
139}
140
141impl BufferedSocket {
142 pub fn new(socket: Socket, max_buffer_size: Option<usize>) -> Self {
144 Self {
145 socket,
146 in_data: Buffer::new(DEFAULT_MAX_BUFFER_SIZE),
147 in_fds: VecDeque::new(),
148 out_data: Buffer::new(DEFAULT_MAX_BUFFER_SIZE),
149 out_fds: Vec::new(),
150 max_buffer_size: max_buffer_size.map(round_max_buffer_size),
151 }
152 }
153
154 pub fn flush(&mut self) -> IoResult<()> {
156 let bytes = self.out_data.get_contents();
157 let fds = &self.out_fds;
158 let mut written_bytes = 0;
159 let mut written_fds = 0;
160 let mut ret = Ok(());
161 while written_bytes < bytes.len() {
162 let mut bytes_to_write = &bytes[written_bytes..];
163 let mut fds_to_write = &fds[written_fds..];
164 if fds_to_write.len() > MAX_FDS_OUT {
165 bytes_to_write = &bytes_to_write[..1];
172 fds_to_write = &fds_to_write[..MAX_FDS_OUT];
173 }
174 if bytes_to_write.len() > MAX_BYTES_OUT {
175 bytes_to_write = &bytes_to_write[..MAX_BYTES_OUT];
178 }
179 match self.socket.send_msg(bytes_to_write, fds_to_write) {
180 Ok(0) => {
181 written_fds += fds_to_write.len();
184 break;
185 }
186 Ok(count) => {
187 written_bytes += count;
188 written_fds += fds_to_write.len();
189 }
190 Err(error) => {
191 ret = Err(error);
192 break;
193 }
194 }
195 }
196 self.out_data.offset(written_bytes);
199 self.out_data.move_to_front();
200 self.out_fds.drain(..written_fds);
201 ret
202 }
203
204 fn attempt_write_message(&mut self, msg: &Message<u32, RawFd>) -> IoResult<bool> {
212 let fds_len = self.out_fds.len();
213 loop {
214 match write_to_buffers(msg, self.out_data.get_writable_storage(), &mut self.out_fds) {
215 Ok(bytes_out) => {
216 self.out_data.advance(bytes_out);
217 return Ok(true);
218 }
219 Err(MessageWriteError::BufferTooSmall) => {
220 self.out_fds.truncate(fds_len);
221 if self.max_buffer_size.is_some_and(|s| s <= self.out_data.storage.len()) {
222 return Ok(false);
223 }
224 self.out_data.increase_capacity();
225 }
226 Err(MessageWriteError::DupFdFailed(e)) => {
227 return Err(e);
228 }
229 }
230 }
231 }
232
233 pub fn write_message(&mut self, msg: &Message<u32, RawFd>) -> IoResult<()> {
240 if !self.attempt_write_message(msg)? {
241 if let Err(e) = self.flush() {
244 if e.kind() != ErrorKind::WouldBlock {
245 return Err(e);
246 }
247 }
248 if !self.attempt_write_message(msg)? {
249 return Err(rustix::io::Errno::TOOBIG.into());
252 }
253 }
254 Ok(())
255 }
256
257 pub fn fill_incoming_buffers(&mut self) -> IoResult<()> {
260 self.in_data.move_to_front();
262 let in_bytes = {
264 let bytes = self.in_data.get_writable_storage();
265 self.socket.rcv_msg(bytes, &mut self.in_fds)?
266 };
267 if in_bytes == 0 {
268 return Err(rustix::io::Errno::PIPE.into());
270 }
271 self.in_data.advance(in_bytes);
273 Ok(())
274 }
275
276 pub fn read_one_message<F>(
282 &mut self,
283 mut signature: F,
284 ) -> Result<Message<u32, OwnedFd>, MessageParseError>
285 where
286 F: FnMut(u32, u16) -> Option<&'static [ArgumentType]>,
287 {
288 let (msg, read_data) = {
289 let data = self.in_data.get_contents();
290 if data.len() < 2 * 4 {
291 return Err(MessageParseError::MissingData);
292 }
293 let object_id = u32::from_ne_bytes([data[0], data[1], data[2], data[3]]);
294 let word_2 = u32::from_ne_bytes([data[4], data[5], data[6], data[7]]);
295 let opcode = (word_2 & 0x0000_FFFF) as u16;
296 if let Some(sig) = signature(object_id, opcode) {
297 match parse_message(data, sig, &mut self.in_fds) {
298 Ok((msg, rest_data)) => (msg, data.len() - rest_data.len()),
299 Err(e) => return Err(e),
300 }
301 } else {
302 return Err(MessageParseError::Malformed);
304 }
305 };
306
307 self.in_data.offset(read_data);
308
309 Ok(msg)
310 }
311
312 pub fn set_max_buffer_size(&mut self, max_buffer_size: Option<usize>) {
313 self.max_buffer_size = max_buffer_size.map(round_max_buffer_size);
315 }
316}
317
318impl AsRawFd for BufferedSocket {
319 fn as_raw_fd(&self) -> RawFd {
320 self.socket.as_raw_fd()
321 }
322}
323
324impl AsFd for BufferedSocket {
325 fn as_fd(&self) -> BorrowedFd<'_> {
326 self.socket.as_fd()
327 }
328}
329
/// A growable contiguous buffer divided into three regions:
/// - `storage[..offset]`: already consumed;
/// - `storage[offset..occupied]`: pending contents;
/// - `storage[occupied..]`: free space for new data.
#[derive(Debug)]
struct Buffer<T: Copy> {
    storage: Vec<T>,
    occupied: usize,
    offset: usize,
}

impl<T: Copy + Default> Buffer<T> {
    /// Creates an empty buffer backed by `size` default-initialized slots.
    fn new(size: usize) -> Self {
        let mut storage = Vec::with_capacity(size);
        storage.resize(size, T::default());
        Buffer { storage, occupied: 0, offset: 0 }
    }

    /// Marks the next `bytes` slots of free space as filled with contents.
    fn advance(&mut self, bytes: usize) {
        self.occupied = self.occupied + bytes;
    }

    /// Marks the first `bytes` slots of the pending contents as consumed.
    fn offset(&mut self, bytes: usize) {
        self.offset = self.offset + bytes;
    }

    /// Doubles the length of the backing storage, keeping existing contents in place.
    fn increase_capacity(&mut self) {
        let doubled = 2 * self.storage.len();
        self.storage.resize(doubled, T::default());
    }

    /// Empties the buffer without shrinking its storage.
    // Allowed to be unused: no caller is visible in this module.
    #[allow(unused)]
    fn clear(&mut self) {
        self.offset = 0;
        self.occupied = 0;
    }

    /// Returns the pending (written but not yet consumed) contents.
    fn get_contents(&self) -> &[T] {
        let Self { storage, occupied, offset } = self;
        &storage[*offset..*occupied]
    }

    /// Returns the free space after the pending contents, for writing new data into.
    fn get_writable_storage(&mut self) -> &mut [T] {
        let start = self.occupied;
        &mut self.storage[start..]
    }

    /// Moves the pending contents to the start of the storage, reclaiming consumed space.
    fn move_to_front(&mut self) {
        let pending = self.offset..self.occupied;
        if !pending.is_empty() {
            self.storage.copy_within(pending, 0);
        }
        self.occupied -= self.offset;
        self.offset = 0;
    }
}
390
391fn round_max_buffer_size(max_buffer_size: usize) -> usize {
392 max_buffer_size.checked_next_power_of_two().unwrap_or(usize::MAX).max(DEFAULT_MAX_BUFFER_SIZE)
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{AllowNull, Argument, ArgumentType, Message};

    use std::ffi::CString;
    use std::os::unix::io::IntoRawFd;

    use smallvec::smallvec;

    /// Returns `true` if `a` and `b` refer to the same file, judged by the device and inode
    /// numbers from `fstat`. Panics if either `fstat` call fails.
    fn same_file(a: BorrowedFd, b: BorrowedFd) -> bool {
        let stat1 = rustix::fs::fstat(a).unwrap();
        let stat2 = rustix::fs::fstat(b).unwrap();
        stat1.st_dev == stat2.st_dev && stat1.st_ino == stat2.st_ino
    }

    /// Asserts that two messages match.
    ///
    /// Sender, opcode and argument count must be equal. `Fd` arguments are compared by the
    /// file they refer to (via `same_file`), not by descriptor number, since received fds
    /// may be numbered differently from the sent ones. All other arguments are compared
    /// with `==`.
    fn assert_eq_msgs<Fd: AsRawFd + std::fmt::Debug>(
        msg1: &Message<u32, Fd>,
        msg2: &Message<u32, Fd>,
    ) {
        assert_eq!(msg1.sender_id, msg2.sender_id);
        assert_eq!(msg1.opcode, msg2.opcode);
        assert_eq!(msg1.args.len(), msg2.args.len());
        for (arg1, arg2) in msg1.args.iter().zip(msg2.args.iter()) {
            if let (Argument::Fd(fd1), Argument::Fd(fd2)) = (arg1, arg2) {
                // SAFETY: assumes the descriptors stay open for the duration of this call.
                // The tests below pass stdio fds and received fds leaked via
                // `into_raw_fd`, neither of which is closed here.
                let fd1 = unsafe { BorrowedFd::borrow_raw(fd1.as_raw_fd()) };
                let fd2 = unsafe { BorrowedFd::borrow_raw(fd2.as_raw_fd()) };
                assert!(same_file(fd1, fd2));
            } else {
                assert_eq!(arg1, arg2);
            }
        }
    }

    /// Round-trips a message without fd arguments through a connected socket pair.
    #[test]
    fn write_read_cycle() {
        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: smallvec![
                Argument::Uint(3),
                Argument::Fixed(-89),
                Argument::Str(Some(Box::new(CString::new(&b"I like trains!"[..]).unwrap()))),
                Argument::Array(vec![1, 2, 3, 4, 5, 6, 7, 8, 9].into()),
                Argument::Object(88),
                Argument::NewId(56),
                Argument::Int(-25),
            ],
        };

        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        let mut client = BufferedSocket::new(Socket::from(client), Some(DEFAULT_MAX_BUFFER_SIZE));
        let mut server = BufferedSocket::new(Socket::from(server), Some(DEFAULT_MAX_BUFFER_SIZE));

        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        static SIGNATURE: &[ArgumentType] = &[
            ArgumentType::Uint,
            ArgumentType::Fixed,
            ArgumentType::Str(AllowNull::No),
            ArgumentType::Array,
            ArgumentType::Object(AllowNull::No),
            ArgumentType::NewId,
            ArgumentType::Int,
        ];

        server.fill_incoming_buffers().unwrap();

        let ret_msg =
            server
                .read_one_message(|sender_id, opcode| {
                    if sender_id == 42 && opcode == 7 {
                        Some(SIGNATURE)
                    } else {
                        None
                    }
                })
                .unwrap();

        assert_eq_msgs(&msg.map_fd(|fd| fd.as_raw_fd()), &ret_msg.map_fd(IntoRawFd::into_raw_fd));
    }

    /// Round-trips a message carrying two fds (the process's fds 1 and 0).
    #[test]
    fn write_read_cycle_fd() {
        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: smallvec![
                Argument::Fd(1), // stdout
                Argument::Fd(0), // stdin
            ],
        };

        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        let mut client = BufferedSocket::new(Socket::from(client), Some(DEFAULT_MAX_BUFFER_SIZE));
        let mut server = BufferedSocket::new(Socket::from(server), Some(DEFAULT_MAX_BUFFER_SIZE));

        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        static SIGNATURE: &[ArgumentType] = &[ArgumentType::Fd, ArgumentType::Fd];

        server.fill_incoming_buffers().unwrap();

        let ret_msg =
            server
                .read_one_message(|sender_id, opcode| {
                    if sender_id == 42 && opcode == 7 {
                        Some(SIGNATURE)
                    } else {
                        None
                    }
                })
                .unwrap();
        assert_eq_msgs(&msg.map_fd(|fd| fd.as_raw_fd()), &ret_msg.map_fd(IntoRawFd::into_raw_fd));
    }

    /// Buffers several messages (some with fds), flushes once, then decodes all of them
    /// from a single receive. The opcode indexes into `SIGNATURES`.
    #[test]
    fn write_read_cycle_multiple() {
        let messages = vec![
            Message {
                sender_id: 42,
                opcode: 0,
                args: smallvec![
                    Argument::Int(42),
                    Argument::Str(Some(Box::new(CString::new(&b"I like trains"[..]).unwrap()))),
                ],
            },
            Message {
                sender_id: 42,
                opcode: 1,
                args: smallvec![
                    Argument::Fd(1), // stdout
                    Argument::Fd(0), // stdin
                ],
            },
            Message {
                sender_id: 42,
                opcode: 2,
                args: smallvec![
                    Argument::Uint(3),
                    Argument::Fd(2), // stderr
                ],
            },
        ];

        static SIGNATURES: &[&[ArgumentType]] = &[
            &[ArgumentType::Int, ArgumentType::Str(AllowNull::No)],
            &[ArgumentType::Fd, ArgumentType::Fd],
            &[ArgumentType::Uint, ArgumentType::Fd],
        ];

        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        let mut client = BufferedSocket::new(Socket::from(client), Some(DEFAULT_MAX_BUFFER_SIZE));
        let mut server = BufferedSocket::new(Socket::from(server), Some(DEFAULT_MAX_BUFFER_SIZE));

        for msg in &messages {
            client.write_message(msg).unwrap();
        }
        client.flush().unwrap();

        server.fill_incoming_buffers().unwrap();

        // Read until `read_one_message` reports an error (no further complete message).
        let mut recv_msgs = Vec::new();
        while let Ok(message) = server.read_one_message(|sender_id, opcode| {
            if sender_id == 42 {
                Some(SIGNATURES[opcode as usize])
            } else {
                None
            }
        }) {
            recv_msgs.push(message);
        }
        assert_eq!(recv_msgs.len(), 3);
        for (msg1, msg2) in messages.into_iter().zip(recv_msgs.into_iter()) {
            assert_eq_msgs(&msg1.map_fd(|fd| fd.as_raw_fd()), &msg2.map_fd(IntoRawFd::into_raw_fd));
        }
    }

    /// Round-trips a message whose string argument ("wl_shell", 8 bytes before the NUL
    /// terminator) has a length that is a multiple of 4.
    #[test]
    fn parse_with_string_len_multiple_of_4() {
        let msg = Message {
            sender_id: 2,
            opcode: 0,
            args: smallvec![
                Argument::Uint(18),
                Argument::Str(Some(Box::new(CString::new(&b"wl_shell"[..]).unwrap()))),
                Argument::Uint(1),
            ],
        };

        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        let mut client = BufferedSocket::new(Socket::from(client), Some(DEFAULT_MAX_BUFFER_SIZE));
        let mut server = BufferedSocket::new(Socket::from(server), Some(DEFAULT_MAX_BUFFER_SIZE));

        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        static SIGNATURE: &[ArgumentType] =
            &[ArgumentType::Uint, ArgumentType::Str(AllowNull::No), ArgumentType::Uint];

        server.fill_incoming_buffers().unwrap();

        let ret_msg =
            server
                .read_one_message(|sender_id, opcode| {
                    if sender_id == 2 && opcode == 0 {
                        Some(SIGNATURE)
                    } else {
                        None
                    }
                })
                .unwrap();

        assert_eq_msgs(&msg.map_fd(|fd| fd.as_raw_fd()), &ret_msg.map_fd(IntoRawFd::into_raw_fd));
    }
}