1use std::collections::VecDeque;
4use std::ffi::CStr;
5use std::os::unix::io::{BorrowedFd, OwnedFd, RawFd};
6
7use crate::protocol::{Argument, ArgumentType, Message};
8
9use smallvec::SmallVec;
10
/// Reasons why [`write_to_buffers`] could not serialize a message.
#[derive(Debug)]
pub enum MessageWriteError {
    /// The output buffer has no room left for the message content.
    BufferTooSmall,
    /// Duplicating one of the message's file descriptors failed; carries the OS error.
    DupFdFailed(std::io::Error),
}

impl std::error::Error for MessageWriteError {}

impl std::fmt::Display for MessageWriteError {
    #[cfg_attr(unstable_coverage, coverage(off))]
    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> Result<(), ::std::fmt::Error> {
        match *self {
            Self::DupFdFailed(ref cause) => write!(
                f,
                "The message contains a file descriptor that could not be dup()-ed ({}).",
                cause
            ),
            Self::BufferTooSmall => {
                f.write_str("The provided buffer is too small to hold message content.")
            }
        }
    }
}
38
/// Reasons why [`parse_message`] could not decode a message.
#[derive(Debug, Clone)]
pub enum MessageParseError {
    /// The signature needs more file descriptors than are currently queued.
    MissingFD,
    /// Not enough bytes are available yet to decode the message.
    MissingData,
    /// The bytes are inconsistent and cannot be decoded into a message.
    Malformed,
}

impl std::error::Error for MessageParseError {}

impl std::fmt::Display for MessageParseError {
    #[cfg_attr(unstable_coverage, coverage(off))]
    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> Result<(), ::std::fmt::Error> {
        let description = match self {
            Self::MissingFD => "The message references a FD but the buffer FD is empty.",
            Self::MissingData => "More data is needed to deserialize the message",
            Self::Malformed => "The message is malformed and cannot be parsed",
        };
        f.write_str(description)
    }
}
64
65pub fn write_to_buffers(
71 msg: &Message<u32, RawFd>,
72 payload: &mut [u8],
73 fds: &mut Vec<OwnedFd>,
74) -> Result<usize, MessageWriteError> {
75 let orig_payload_len = payload.len();
76 fn write_buf(u: u32, payload: &mut [u8]) -> Result<&mut [u8], MessageWriteError> {
78 if payload.len() >= 4 {
79 let (head, tail) = payload.split_at_mut(4);
80 head.copy_from_slice(&u.to_ne_bytes());
81 Ok(tail)
82 } else {
83 Err(MessageWriteError::BufferTooSmall)
84 }
85 }
86
87 fn write_array_to_payload<'a>(
89 array: &[u8],
90 payload: &'a mut [u8],
91 ) -> Result<&'a mut [u8], MessageWriteError> {
92 let payload = write_buf(array.len() as u32, payload)?;
94
95 let len = next_multiple_of(array.len(), 4);
97
98 if payload.len() < len {
99 return Err(MessageWriteError::BufferTooSmall);
100 }
101
102 let (buffer_slice, rest) = payload.split_at_mut(len);
103 buffer_slice[..array.len()].copy_from_slice(array);
104 Ok(rest)
105 }
106
107 let free_size = payload.len();
108 if free_size < 2 * 4 {
109 return Err(MessageWriteError::BufferTooSmall);
110 }
111
112 let (header, mut payload) = payload.split_at_mut(2 * 4);
113
114 for arg in &msg.args {
116 payload = match *arg {
117 Argument::Int(i) => write_buf(i as u32, payload)?,
118 Argument::Uint(u) => write_buf(u, payload)?,
119 Argument::Fixed(f) => write_buf(f as u32, payload)?,
120 Argument::Str(Some(ref s)) => write_array_to_payload(s.as_bytes_with_nul(), payload)?,
121 Argument::Str(None) => write_array_to_payload(&[], payload)?,
122 Argument::Object(o) => write_buf(o, payload)?,
123 Argument::NewId(n) => write_buf(n, payload)?,
124 Argument::Array(ref a) => write_array_to_payload(a, payload)?,
125 Argument::Fd(fd) => {
126 let dup_fd = unsafe { BorrowedFd::borrow_raw(fd) }
127 .try_clone_to_owned()
128 .map_err(MessageWriteError::DupFdFailed)?;
129 fds.push(dup_fd);
130 payload
131 }
132 };
133 }
134
135 let wrote_size = free_size - payload.len();
136 header[..4].copy_from_slice(&msg.sender_id.to_ne_bytes());
137 header[4..]
138 .copy_from_slice(&(((wrote_size as u32) << 16) | u32::from(msg.opcode)).to_ne_bytes());
139 Ok(orig_payload_len - payload.len())
140}
141
142#[allow(clippy::type_complexity)]
150pub fn parse_message<'a>(
151 raw: &'a [u8],
152 signature: &[ArgumentType],
153 fds: &mut VecDeque<OwnedFd>,
154) -> Result<(Message<u32, OwnedFd>, &'a [u8]), MessageParseError> {
155 fn read_array_from_payload(
157 array_len: usize,
158 payload: &[u8],
159 ) -> Result<(&[u8], &[u8]), MessageParseError> {
160 let len = next_multiple_of(array_len, 4);
161 if len > payload.len() {
162 return Err(MessageParseError::MissingData);
163 }
164 Ok((&payload[..array_len], &payload[len..]))
165 }
166
167 if raw.len() < 2 * 4 {
168 return Err(MessageParseError::MissingData);
169 }
170
171 let sender_id = u32::from_ne_bytes([raw[0], raw[1], raw[2], raw[3]]);
172 let word_2 = u32::from_ne_bytes([raw[4], raw[5], raw[6], raw[7]]);
173 let opcode = (word_2 & 0x0000_FFFF) as u16;
174 let len = (word_2 >> 16) as usize;
175
176 if len < 2 * 4 {
177 return Err(MessageParseError::Malformed);
178 } else if len > raw.len() {
179 return Err(MessageParseError::MissingData);
180 }
181
182 let fd_len = signature.iter().filter(|x| matches!(x, ArgumentType::Fd)).count();
183 if fd_len > fds.len() {
184 return Err(MessageParseError::MissingFD);
185 }
186
187 let (mut payload, rest) = raw.split_at(len);
188 payload = &payload[2 * 4..];
189
190 let arguments = signature
191 .iter()
192 .map(|argtype| {
193 if let ArgumentType::Fd = *argtype {
194 if let Some(front) = fds.pop_front() {
196 Ok(Argument::Fd(front))
197 } else {
198 Err(MessageParseError::MissingFD)
199 }
200 } else if payload.len() >= 4 {
201 let (front, mut tail) = payload.split_at(4);
202 let front = u32::from_ne_bytes(front.try_into().unwrap());
203 let arg = match *argtype {
204 ArgumentType::Int => Ok(Argument::Int(front as i32)),
205 ArgumentType::Uint => Ok(Argument::Uint(front)),
206 ArgumentType::Fixed => Ok(Argument::Fixed(front as i32)),
207 ArgumentType::Str(_) => {
208 read_array_from_payload(front as usize, tail).and_then(|(v, rest)| {
209 tail = rest;
210 if !v.is_empty() {
211 match CStr::from_bytes_with_nul(v) {
212 Ok(s) => Ok(Argument::Str(Some(Box::new(s.into())))),
213 Err(_) => Err(MessageParseError::Malformed),
214 }
215 } else {
216 Ok(Argument::Str(None))
217 }
218 })
219 }
220 ArgumentType::Object(_) => Ok(Argument::Object(front)),
221 ArgumentType::NewId => Ok(Argument::NewId(front)),
222 ArgumentType::Array => {
223 read_array_from_payload(front as usize, tail).map(|(v, rest)| {
224 tail = rest;
225 Argument::Array(Box::new(v.into()))
226 })
227 }
228 ArgumentType::Fd => unreachable!(),
229 };
230 payload = tail;
231 arg
232 } else {
233 Err(MessageParseError::MissingData)
234 }
235 })
236 .collect::<Result<SmallVec<_>, MessageParseError>>()?;
237
238 let msg = Message { sender_id, opcode, args: arguments };
239 Ok((msg, rest))
240}
241
/// Rounds `lhs` up to the nearest multiple of `rhs` (`lhs` itself if already aligned).
///
/// # Panics
///
/// Panics if `rhs` is zero. Overflows (panicking in debug builds) if the
/// rounded value does not fit in a `usize`.
fn next_multiple_of(lhs: usize, rhs: usize) -> usize {
    let remainder = lhs % rhs;
    if remainder == 0 {
        lhs
    } else {
        // Subtract first so that results close to usize::MAX still fit.
        lhs + (rhs - remainder)
    }
}
249
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::AllowNull;
    use smallvec::smallvec;
    use std::{ffi::CString, os::unix::io::IntoRawFd};

    /// Round-trips a message holding every non-fd argument kind through
    /// `write_to_buffers` and `parse_message` and checks nothing is lost.
    #[test]
    fn into_from_raw_cycle() {
        let greeting = CString::new(&b"I like trains!"[..]).unwrap();
        let original = Message {
            sender_id: 42,
            opcode: 7,
            args: smallvec![
                Argument::Uint(3),
                Argument::Fixed(-89),
                Argument::Str(Some(Box::new(greeting))),
                Argument::Array(vec![1, 2, 3, 4, 5, 6, 7, 8, 9].into()),
                Argument::Object(88),
                Argument::NewId(56),
                Argument::Int(-25),
            ],
        };
        let signature = [
            ArgumentType::Uint,
            ArgumentType::Fixed,
            ArgumentType::Str(AllowNull::No),
            ArgumentType::Array,
            ArgumentType::Object(AllowNull::No),
            ArgumentType::NewId,
            ArgumentType::Int,
        ];

        let mut wire = vec![0; 1024];
        let mut sent_fds = Vec::new();
        write_to_buffers(&original, &mut wire[..], &mut sent_fds).unwrap();

        let mut received_fds: VecDeque<_> = sent_fds.into_iter().collect();
        let (decoded, _) = parse_message(&wire[..], &signature, &mut received_fds).unwrap();
        assert_eq!(decoded.map_fd(IntoRawFd::into_raw_fd), original);
    }
}