zbus/connection/handshake/common.rs

1use tracing::{instrument, trace};
2
3use super::{AuthMechanism, BoxedSplit, Command};
4use crate::{Error, Result};
5
6// Common code for the client and server side of the handshake.
7#[derive(Debug)]
8pub(super) struct Common {
9    socket: BoxedSplit,
10    recv_buffer: Vec<u8>,
11    #[cfg(unix)]
12    received_fds: Vec<std::os::fd::OwnedFd>,
13    cap_unix_fd: bool,
14    mechanism: AuthMechanism,
15    first_command: bool,
16}
17
18impl Common {
19    /// Start a handshake on this client socket
20    pub fn new(socket: BoxedSplit, mechanism: AuthMechanism) -> Self {
21        Self {
22            socket,
23            recv_buffer: Vec::new(),
24            #[cfg(unix)]
25            received_fds: Vec::new(),
26            cap_unix_fd: false,
27            mechanism,
28            first_command: true,
29        }
30    }
31
32    #[cfg(all(unix, feature = "p2p"))]
33    pub fn socket(&self) -> &BoxedSplit {
34        &self.socket
35    }
36
37    pub fn socket_mut(&mut self) -> &mut BoxedSplit {
38        &mut self.socket
39    }
40
41    pub fn set_cap_unix_fd(&mut self, cap_unix_fd: bool) {
42        self.cap_unix_fd = cap_unix_fd;
43    }
44
45    pub fn mechanism(&self) -> AuthMechanism {
46        self.mechanism
47    }
48
49    pub fn into_components(self) -> IntoComponentsReturn {
50        (
51            self.socket,
52            self.recv_buffer,
53            #[cfg(unix)]
54            self.received_fds,
55            self.cap_unix_fd,
56            self.mechanism,
57        )
58    }
59
60    #[instrument(skip(self))]
61    pub async fn write_command(&mut self, command: Command) -> Result<()> {
62        self.write_commands(&[command], None).await
63    }
64
65    #[instrument(skip(self))]
66    pub async fn write_commands(
67        &mut self,
68        commands: &[Command],
69        extra_bytes: Option<&[u8]>,
70    ) -> Result<()> {
71        let mut send_buffer =
72            commands
73                .iter()
74                .map(Vec::<u8>::from)
75                .fold(vec![], |mut acc, mut c| {
76                    if self.first_command {
77                        // The first command is sent by the client so we can assume it's the client.
78                        self.first_command = false;
79                        // leading 0 is sent separately for `freebsd` and `dragonfly`.
80                        #[cfg(not(any(target_os = "freebsd", target_os = "dragonfly")))]
81                        acc.push(b'\0');
82                    }
83                    acc.append(&mut c);
84                    acc.extend_from_slice(b"\r\n");
85                    acc
86                });
87        if let Some(extra_bytes) = extra_bytes {
88            send_buffer.extend_from_slice(extra_bytes);
89        }
90        while !send_buffer.is_empty() {
91            let written = self
92                .socket
93                .write_mut()
94                .sendmsg(
95                    &send_buffer,
96                    #[cfg(unix)]
97                    &[],
98                )
99                .await?;
100            send_buffer.drain(..written);
101        }
102        trace!("Wrote all commands");
103        Ok(())
104    }
105
106    #[instrument(skip(self))]
107    pub async fn read_command(&mut self) -> Result<Command> {
108        self.read_commands(1)
109            .await
110            .map(|cmds| cmds.into_iter().next().unwrap())
111    }
112
113    #[instrument(skip(self))]
114    pub async fn read_commands(&mut self, n_commands: usize) -> Result<Vec<Command>> {
115        let mut commands = Vec::with_capacity(n_commands);
116        let mut n_received_commands = 0;
117        'outer: loop {
118            while let Some(lf_index) = self.recv_buffer.iter().position(|b| *b == b'\n') {
119                if self.recv_buffer[lf_index - 1] != b'\r' {
120                    return Err(Error::Handshake("Invalid line ending in handshake".into()));
121                }
122
123                #[allow(unused_mut)]
124                let mut start_index = 0;
125                if self.first_command {
126                    // The first command is sent by the client so we can assume it's the server.
127                    self.first_command = false;
128                    if self.recv_buffer[0] != b'\0' {
129                        return Err(Error::Handshake(
130                            "First client byte is not NUL!".to_string(),
131                        ));
132                    }
133
134                    start_index = 1;
135                };
136
137                let line_bytes = self.recv_buffer.drain(..=lf_index);
138                let line = std::str::from_utf8(&line_bytes.as_slice()[start_index..])
139                    .map_err(|e| Error::Handshake(e.to_string()))?;
140
141                trace!("Reading {line}");
142                commands.push(line.parse()?);
143                n_received_commands += 1;
144
145                if n_received_commands == n_commands {
146                    break 'outer;
147                }
148            }
149
150            let mut buf = vec![0; 1024];
151            let res = self.socket.read_mut().recvmsg(&mut buf).await?;
152            let read = {
153                #[cfg(unix)]
154                {
155                    let (read, fds) = res;
156                    if !fds.is_empty() {
157                        // Most likely belonging to the messages already received.
158                        self.received_fds.extend(fds);
159                    }
160                    read
161                }
162                #[cfg(not(unix))]
163                {
164                    res
165                }
166            };
167            if read == 0 {
168                return Err(Error::Handshake("Unexpected EOF during handshake".into()));
169            }
170            self.recv_buffer.extend(&buf[..read]);
171        }
172
173        Ok(commands)
174    }
175}
176
177#[cfg(unix)]
178type IntoComponentsReturn = (
179    BoxedSplit,
180    Vec<u8>,
181    Vec<std::os::fd::OwnedFd>,
182    bool,
183    AuthMechanism,
184);
185#[cfg(not(unix))]
186type IntoComponentsReturn = (BoxedSplit, Vec<u8>, bool, AuthMechanism);