use std::collections::VecDeque;
use std::io::{ErrorKind, IoSlice, IoSliceMut, Result as IoResult};
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
use std::os::unix::net::UnixStream;
use std::slice;
use rustix::io::retry_on_intr;
use rustix::net::{
recvmsg, send, sendmsg, RecvAncillaryBuffer, RecvAncillaryMessage, RecvFlags,
SendAncillaryBuffer, SendAncillaryMessage, SendFlags,
};
use crate::protocol::{ArgumentType, Message};
use super::wire::{parse_message, write_to_buffers, MessageParseError, MessageWriteError};
/// Number of file descriptors the `SCM_RIGHTS` control buffer in `Socket::rcv_msg` is
/// sized for.
///
/// NOTE(review): the name suggests it is also the per-send limit for outgoing fds; nothing
/// in the code shown here enforces that on the sending side — confirm where it is checked.
pub const MAX_FDS_OUT: usize = 28;
/// Capacity, in bytes, of a `BufferedSocket`'s outgoing data buffer.
///
/// The incoming data buffer is allocated at twice this size.
pub const MAX_BYTES_OUT: usize = 4096;
/// A Unix stream socket that can carry file descriptors alongside its byte payload.
///
/// Sends and receives issued through `send_msg` / `rcv_msg` use the `DONTWAIT` flag, so
/// they never block.
#[derive(Debug)]
pub struct Socket {
    /// The wrapped stream all I/O goes through.
    stream: UnixStream,
}
impl Socket {
    /// Sends `bytes` without blocking, attaching `fds` as `SCM_RIGHTS` ancillary data when
    /// the slice is non-empty.
    ///
    /// Returns how many bytes the kernel accepted, which may be fewer than `bytes.len()`.
    /// `EINTR` is retried; every other error (including `WouldBlock`) is returned. On
    /// non-macOS targets `NOSIGNAL` is set; on macOS it is not, and `From<UnixStream>`
    /// enables `SO_NOSIGPIPE` on the socket instead.
    pub fn send_msg(&self, bytes: &[u8], fds: &[OwnedFd]) -> IoResult<usize> {
        #[cfg(not(target_os = "macos"))]
        let flags = SendFlags::DONTWAIT | SendFlags::NOSIGNAL;
        #[cfg(target_os = "macos")]
        let flags = SendFlags::DONTWAIT;

        if !fds.is_empty() {
            let iov = [IoSlice::new(bytes)];
            // Sized exactly for `fds.len()` descriptors, so the push below fits.
            let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmRights(fds.len()))];
            let mut cmsg_buffer = SendAncillaryBuffer::new(&mut cmsg_space);
            // SAFETY: `OwnedFd` and `BorrowedFd` are both `#[repr(transparent)]` wrappers
            // around a raw fd, so `[OwnedFd]` can be viewed as `[BorrowedFd]`. The borrowed
            // view lives only inside this call, while `fds` keeps the descriptors open.
            let fds =
                unsafe { slice::from_raw_parts(fds.as_ptr() as *const BorrowedFd, fds.len()) };
            cmsg_buffer.push(SendAncillaryMessage::ScmRights(fds));
            Ok(retry_on_intr(|| sendmsg(self, &iov, &mut cmsg_buffer, flags))?)
        } else {
            Ok(retry_on_intr(|| send(self, bytes, flags))?)
        }
    }

    /// Receives into `buffer` without blocking and appends any `SCM_RIGHTS` descriptors
    /// to the back of `fds`.
    ///
    /// Returns the number of bytes read; `0` is passed through unchanged for the caller to
    /// interpret. `EINTR` is retried; every other error is returned. The control buffer
    /// has room for `MAX_FDS_OUT` descriptors.
    ///
    /// NOTE(review): control-data truncation (`MSG_CTRUNC`) is not checked here, so fds
    /// beyond that capacity would be lost silently — confirm peers never exceed it.
    pub fn rcv_msg(&self, buffer: &mut [u8], fds: &mut VecDeque<OwnedFd>) -> IoResult<usize> {
        #[cfg(not(target_os = "macos"))]
        let flags = RecvFlags::DONTWAIT | RecvFlags::CMSG_CLOEXEC;
        #[cfg(target_os = "macos")]
        let flags = RecvFlags::DONTWAIT;

        let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmRights(MAX_FDS_OUT))];
        let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
        let mut iov = [IoSliceMut::new(buffer)];
        let msg = retry_on_intr(|| recvmsg(&self.stream, &mut iov[..], &mut cmsg_buffer, flags))?;

        // Index of the first descriptor this call appends. The macOS fix-up below only
        // touches those, not descriptors queued by earlier calls (they already got it).
        #[cfg(target_os = "macos")]
        let first_new = fds.len();

        let received_fds = cmsg_buffer
            .drain()
            .filter_map(|cmsg| match cmsg {
                RecvAncillaryMessage::ScmRights(fds) => Some(fds),
                _ => None,
            })
            .flatten();
        fds.extend(received_fds);

        // `CMSG_CLOEXEC` is not requested on macOS (see `flags` above), so set
        // `FD_CLOEXEC` on the new descriptors after the fact, best effort.
        #[cfg(target_os = "macos")]
        for fd in fds.range(first_new..) {
            if let Ok(flags) = rustix::io::fcntl_getfd(fd) {
                let _ = rustix::io::fcntl_setfd(fd, flags | rustix::io::FdFlags::CLOEXEC);
            }
        }
        Ok(msg.bytes)
    }
}
impl From<UnixStream> for Socket {
fn from(stream: UnixStream) -> Self {
#[cfg(target_os = "macos")]
let _ = rustix::net::sockopt::set_socket_nosigpipe(&stream, true);
Self { stream }
}
}
impl AsFd for Socket {
fn as_fd(&self) -> BorrowedFd<'_> {
self.stream.as_fd()
}
}
impl AsRawFd for Socket {
    /// Returns the raw descriptor number of the wrapped stream.
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&self.stream)
    }
}
/// A [`Socket`] with user-space buffers for data and file descriptors in both directions.
///
/// `write_message` serializes into the outgoing buffers, and `flush` sends them (it is
/// also attempted when the outgoing buffer is full). `fill_incoming_buffers` reads from the
/// socket, and `read_one_message` parses from the incoming buffers.
#[derive(Debug)]
pub struct BufferedSocket {
    /// The socket all I/O is performed on.
    socket: Socket,
    /// Received bytes not yet consumed by `read_one_message`.
    in_data: Buffer<u8>,
    /// Received descriptors not yet handed to a parsed message, in order of receipt.
    in_fds: VecDeque<OwnedFd>,
    /// Serialized bytes waiting to be sent.
    out_data: Buffer<u8>,
    /// Descriptors to attach to the next send; cleared after a successful send.
    out_fds: Vec<OwnedFd>,
}
impl BufferedSocket {
pub fn new(socket: Socket) -> Self {
Self {
socket,
in_data: Buffer::new(2 * MAX_BYTES_OUT), in_fds: VecDeque::new(), out_data: Buffer::new(MAX_BYTES_OUT),
out_fds: Vec::new(),
}
}
pub fn flush(&mut self) -> IoResult<()> {
let written = {
let bytes = self.out_data.get_contents();
if bytes.is_empty() {
return Ok(());
}
self.socket.send_msg(bytes, &self.out_fds)?
};
self.out_data.offset(written);
self.out_data.move_to_front();
self.out_fds.clear();
Ok(())
}
fn attempt_write_message(&mut self, msg: &Message<u32, RawFd>) -> IoResult<bool> {
match write_to_buffers(msg, self.out_data.get_writable_storage(), &mut self.out_fds) {
Ok(bytes_out) => {
self.out_data.advance(bytes_out);
Ok(true)
}
Err(MessageWriteError::BufferTooSmall) => Ok(false),
Err(MessageWriteError::DupFdFailed(e)) => Err(e),
}
}
pub fn write_message(&mut self, msg: &Message<u32, RawFd>) -> IoResult<()> {
if !self.attempt_write_message(msg)? {
if let Err(e) = self.flush() {
if e.kind() != ErrorKind::WouldBlock {
return Err(e);
}
}
if !self.attempt_write_message(msg)? {
return Err(rustix::io::Errno::TOOBIG.into());
}
}
Ok(())
}
pub fn fill_incoming_buffers(&mut self) -> IoResult<()> {
self.in_data.move_to_front();
let in_bytes = {
let bytes = self.in_data.get_writable_storage();
self.socket.rcv_msg(bytes, &mut self.in_fds)?
};
if in_bytes == 0 {
return Err(rustix::io::Errno::PIPE.into());
}
self.in_data.advance(in_bytes);
Ok(())
}
pub fn read_one_message<F>(
&mut self,
mut signature: F,
) -> Result<Message<u32, OwnedFd>, MessageParseError>
where
F: FnMut(u32, u16) -> Option<&'static [ArgumentType]>,
{
let (msg, read_data) = {
let data = self.in_data.get_contents();
if data.len() < 2 * 4 {
return Err(MessageParseError::MissingData);
}
let object_id = u32::from_ne_bytes([data[0], data[1], data[2], data[3]]);
let word_2 = u32::from_ne_bytes([data[4], data[5], data[6], data[7]]);
let opcode = (word_2 & 0x0000_FFFF) as u16;
if let Some(sig) = signature(object_id, opcode) {
match parse_message(data, sig, &mut self.in_fds) {
Ok((msg, rest_data)) => (msg, data.len() - rest_data.len()),
Err(e) => return Err(e),
}
} else {
return Err(MessageParseError::Malformed);
}
};
self.in_data.offset(read_data);
Ok(msg)
}
}
impl AsRawFd for BufferedSocket {
    /// Returns the raw descriptor number of the underlying socket.
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&self.socket)
    }
}
impl AsFd for BufferedSocket {
    /// Borrows the descriptor of the underlying socket.
    fn as_fd(&self) -> BorrowedFd<'_> {
        AsFd::as_fd(&self.socket)
    }
}
/// A fixed-capacity buffer whose live contents are `storage[offset..occupied]`.
///
/// Callers write into `get_writable_storage()` and then `advance` past what they wrote.
/// They read via `get_contents()` and then `offset` past what they consumed. Calling
/// `move_to_front` shifts the live region back to index 0 to reclaim consumed space.
#[derive(Debug)]
struct Buffer<T: Copy> {
    /// Backing storage; its length is the capacity and is never changed.
    storage: Vec<T>,
    /// End (exclusive) of the valid data within `storage`.
    occupied: usize,
    /// Start of the not-yet-consumed data within `storage`.
    offset: usize,
}

impl<T: Copy + Default> Buffer<T> {
    /// Creates a buffer of `size` default-valued elements with no live contents.
    fn new(size: usize) -> Self {
        let storage = vec![T::default(); size];
        Buffer { storage, occupied: 0, offset: 0 }
    }

    /// Marks `bytes` more elements, just written via `get_writable_storage`, as valid.
    fn advance(&mut self, bytes: usize) {
        self.occupied = self.occupied + bytes;
    }

    /// Consumes `bytes` elements from the front of the live contents.
    fn offset(&mut self, bytes: usize) {
        self.offset = self.offset + bytes;
    }

    /// Discards all contents.
    // The dead-code lint is silenced because nothing in this module calls `clear`.
    #[allow(unused)]
    fn clear(&mut self) {
        self.offset = 0;
        self.occupied = 0;
    }

    /// Returns the valid, not-yet-consumed elements.
    fn get_contents(&self) -> &[T] {
        let (start, end) = (self.offset, self.occupied);
        &self.storage[start..end]
    }

    /// Returns the free space after the valid data.
    fn get_writable_storage(&mut self) -> &mut [T] {
        let start = self.occupied;
        &mut self.storage[start..]
    }

    /// Moves the live contents to the start of `storage` so the free tail is maximal.
    fn move_to_front(&mut self) {
        let live = self.offset..self.occupied;
        if !live.is_empty() {
            self.storage.copy_within(live, 0);
        }
        self.occupied -= self.offset;
        self.offset = 0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{AllowNull, Argument, ArgumentType, Message};
    use std::ffi::CString;
    use std::os::unix::io::IntoRawFd;
    use smallvec::smallvec;

    /// Returns `true` when `a` and `b` report the same device and inode via `fstat`, i.e.
    /// they refer to the same file even if the descriptor numbers differ.
    fn same_file(a: BorrowedFd, b: BorrowedFd) -> bool {
        let stat1 = rustix::fs::fstat(a).unwrap();
        let stat2 = rustix::fs::fstat(b).unwrap();
        stat1.st_dev == stat2.st_dev && stat1.st_ino == stat2.st_ino
    }

    /// Asserts that two messages have the same sender, opcode and arguments.
    ///
    /// Fd arguments are compared with `same_file` instead of by number, since a
    /// descriptor passed over the socket may arrive under a different number.
    fn assert_eq_msgs<Fd: AsRawFd + std::fmt::Debug>(
        msg1: &Message<u32, Fd>,
        msg2: &Message<u32, Fd>,
    ) {
        assert_eq!(msg1.sender_id, msg2.sender_id);
        assert_eq!(msg1.opcode, msg2.opcode);
        assert_eq!(msg1.args.len(), msg2.args.len());
        for (arg1, arg2) in msg1.args.iter().zip(msg2.args.iter()) {
            if let (Argument::Fd(fd1), Argument::Fd(fd2)) = (arg1, arg2) {
                // SAFETY: the raw fds belong to messages borrowed for this whole call. The
                // tests below keep them open: they are either the process's standard fds
                // or were leaked from `OwnedFd` via `into_raw_fd`.
                let fd1 = unsafe { BorrowedFd::borrow_raw(fd1.as_raw_fd()) };
                let fd2 = unsafe { BorrowedFd::borrow_raw(fd2.as_raw_fd()) };
                assert!(same_file(fd1, fd2));
            } else {
                assert_eq!(arg1, arg2);
            }
        }
    }

    /// Round-trips a message with no fd arguments through a connected socket pair.
    #[test]
    fn write_read_cycle() {
        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: smallvec![
                Argument::Uint(3),
                Argument::Fixed(-89),
                Argument::Str(Some(Box::new(CString::new(&b"I like trains!"[..]).unwrap()))),
                Argument::Array(vec![1, 2, 3, 4, 5, 6, 7, 8, 9].into()),
                Argument::Object(88),
                Argument::NewId(56),
                Argument::Int(-25),
            ],
        };
        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        let mut client = BufferedSocket::new(Socket::from(client));
        let mut server = BufferedSocket::new(Socket::from(server));
        client.write_message(&msg).unwrap();
        client.flush().unwrap();
        static SIGNATURE: &[ArgumentType] = &[
            ArgumentType::Uint,
            ArgumentType::Fixed,
            ArgumentType::Str(AllowNull::No),
            ArgumentType::Array,
            ArgumentType::Object(AllowNull::No),
            ArgumentType::NewId,
            ArgumentType::Int,
        ];
        server.fill_incoming_buffers().unwrap();
        let ret_msg =
            server
                .read_one_message(|sender_id, opcode| {
                    if sender_id == 42 && opcode == 7 {
                        Some(SIGNATURE)
                    } else {
                        None
                    }
                })
                .unwrap();
        assert_eq_msgs(&msg.map_fd(|fd| fd.as_raw_fd()), &ret_msg.map_fd(IntoRawFd::into_raw_fd));
    }

    /// Round-trips a message whose arguments are file descriptors, exercising the
    /// `SCM_RIGHTS` path of `send_msg` / `rcv_msg`.
    #[test]
    fn write_read_cycle_fd() {
        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: smallvec![
                // Raw fds 1 and 0 (normally the process's stdout and stdin).
                Argument::Fd(1),
                Argument::Fd(0),
            ],
        };
        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        let mut client = BufferedSocket::new(Socket::from(client));
        let mut server = BufferedSocket::new(Socket::from(server));
        client.write_message(&msg).unwrap();
        client.flush().unwrap();
        static SIGNATURE: &[ArgumentType] = &[ArgumentType::Fd, ArgumentType::Fd];
        server.fill_incoming_buffers().unwrap();
        let ret_msg =
            server
                .read_one_message(|sender_id, opcode| {
                    if sender_id == 42 && opcode == 7 {
                        Some(SIGNATURE)
                    } else {
                        None
                    }
                })
                .unwrap();
        assert_eq_msgs(&msg.map_fd(|fd| fd.as_raw_fd()), &ret_msg.map_fd(IntoRawFd::into_raw_fd));
    }

    /// Sends several messages in one flush and checks they are all parsed back in order,
    /// with each fd argument matched to the right message.
    #[test]
    fn write_read_cycle_multiple() {
        let messages = vec![
            Message {
                sender_id: 42,
                opcode: 0,
                args: smallvec![
                    Argument::Int(42),
                    Argument::Str(Some(Box::new(CString::new(&b"I like trains"[..]).unwrap()))),
                ],
            },
            Message {
                sender_id: 42,
                opcode: 1,
                args: smallvec![
                    // Raw fds 1 and 0 (normally stdout and stdin).
                    Argument::Fd(1),
                    Argument::Fd(0),
                ],
            },
            Message {
                sender_id: 42,
                opcode: 2,
                args: smallvec![
                    Argument::Uint(3),
                    // Raw fd 2 (normally stderr).
                    Argument::Fd(2),
                ],
            },
        ];
        // Indexed by opcode.
        static SIGNATURES: &[&[ArgumentType]] = &[
            &[ArgumentType::Int, ArgumentType::Str(AllowNull::No)],
            &[ArgumentType::Fd, ArgumentType::Fd],
            &[ArgumentType::Uint, ArgumentType::Fd],
        ];
        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        let mut client = BufferedSocket::new(Socket::from(client));
        let mut server = BufferedSocket::new(Socket::from(server));
        for msg in &messages {
            client.write_message(msg).unwrap();
        }
        client.flush().unwrap();
        server.fill_incoming_buffers().unwrap();
        let mut recv_msgs = Vec::new();
        // Parse until `read_one_message` errors; once all buffered bytes are consumed it
        // reports `MissingData` (fewer than 8 header bytes left).
        while let Ok(message) = server.read_one_message(|sender_id, opcode| {
            if sender_id == 42 {
                Some(SIGNATURES[opcode as usize])
            } else {
                None
            }
        }) {
            recv_msgs.push(message);
        }
        assert_eq!(recv_msgs.len(), 3);
        for (msg1, msg2) in messages.into_iter().zip(recv_msgs.into_iter()) {
            assert_eq_msgs(&msg1.map_fd(|fd| fd.as_raw_fd()), &msg2.map_fd(IntoRawFd::into_raw_fd));
        }
    }

    /// Regression-style check for a string argument whose length ("wl_shell", 8 bytes) is
    /// a multiple of 4, followed by another argument.
    #[test]
    fn parse_with_string_len_multiple_of_4() {
        let msg = Message {
            sender_id: 2,
            opcode: 0,
            args: smallvec![
                Argument::Uint(18),
                Argument::Str(Some(Box::new(CString::new(&b"wl_shell"[..]).unwrap()))),
                Argument::Uint(1),
            ],
        };
        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        let mut client = BufferedSocket::new(Socket::from(client));
        let mut server = BufferedSocket::new(Socket::from(server));
        client.write_message(&msg).unwrap();
        client.flush().unwrap();
        static SIGNATURE: &[ArgumentType] =
            &[ArgumentType::Uint, ArgumentType::Str(AllowNull::No), ArgumentType::Uint];
        server.fill_incoming_buffers().unwrap();
        let ret_msg =
            server
                .read_one_message(|sender_id, opcode| {
                    if sender_id == 2 && opcode == 0 {
                        Some(SIGNATURE)
                    } else {
                        None
                    }
                })
                .unwrap();
        assert_eq_msgs(&msg.map_fd(|fd| fd.as_raw_fd()), &ret_msg.map_fd(IntoRawFd::into_raw_fd));
    }
}