zbus/connection/handshake/client.rs

1use async_trait::async_trait;
2use tracing::{instrument, trace, warn};
3
4use crate::{Message, conn::socket::ReadHalf, is_flatpak, names::OwnedUniqueName};
5
6use super::{
7    AuthMechanism, Authenticated, BoxedSplit, Command, Common, Error, Handshake, OwnedGuid, Result,
8    sasl_auth_id,
9};
10
11/// A representation of an in-progress handshake, client-side
12///
13/// This struct is an async-compatible representation of the initial handshake that must be
14/// performed before a D-Bus connection can be used.
15#[derive(Debug)]
16pub struct Client {
17    common: Common,
18    server_guid: Option<OwnedGuid>,
19    bus: bool,
20    user_id: Result<String>,
21}
22
23impl Client {
24    /// Start a handshake on this client socket
25    pub fn new(
26        socket: BoxedSplit,
27        mechanism: Option<AuthMechanism>,
28        server_guid: Option<OwnedGuid>,
29        bus: bool,
30        user_id: Option<u32>,
31    ) -> Client {
32        let mechanism = mechanism.unwrap_or_else(|| socket.read().auth_mechanism());
33
34        Client {
35            common: Common::new(socket, mechanism),
36            server_guid,
37            bus,
38            user_id: match user_id {
39                Some(value) => Ok(value.to_string()),
40                None => sasl_auth_id(),
41            },
42        }
43    }
44
45    fn set_guid(&mut self, guid: OwnedGuid) -> Result<()> {
46        match &self.server_guid {
47            Some(server_guid) if *server_guid != guid => {
48                return Err(Error::Handshake(format!(
49                    "Server GUID mismatch: expected {server_guid}, got {guid}",
50                )));
51            }
52            Some(_) => (),
53            None => self.server_guid = Some(guid),
54        }
55
56        Ok(())
57    }
58
59    // The dbus daemon on some platforms requires sending the zero byte as a
60    // separate message with SCM_CREDS.
61    #[instrument(skip(self), level = "trace")]
62    #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
63    async fn send_zero_byte(&mut self) -> Result<()> {
64        let write = self.common.socket_mut().write_mut();
65
66        let written = match write.send_zero_byte().await.map_err(|e| {
67            Error::Handshake(format!("Could not send zero byte with credentials: {e}"))
68        })? {
69            // This likely means that the socket type is unable to send SCM_CREDS.
70            // Let's try to send the 0 byte as a regular message.
71            None => write.sendmsg(&[0], &[]).await?,
72            Some(n) => n,
73        };
74
75        if written != 1 {
76            return Err(Error::Handshake(
77                "Could not send zero byte with credentials".to_string(),
78            ));
79        }
80
81        Ok(())
82    }
83
84    /// Perform the authentication handshake with the server.
85    #[instrument(skip(self), level = "trace")]
86    async fn authenticate(&mut self) -> Result<()> {
87        let mechanism = self.common.mechanism();
88        trace!("Trying {mechanism} mechanism");
89        let user_id = self.user_id.clone();
90        let auth_cmd = match mechanism {
91            AuthMechanism::Anonymous => Command::Auth(Some(mechanism), Some("zbus".into())),
92            AuthMechanism::External => Command::Auth(Some(mechanism), Some(user_id?.into_bytes())),
93        };
94        self.common.write_command(auth_cmd).await?;
95
96        match self.common.read_command().await? {
97            Command::Ok(guid) => {
98                trace!("Received OK from server");
99                self.set_guid(guid)?;
100
101                Ok(())
102            }
103            Command::Rejected(accepted) => {
104                let list = accepted.replace(" ", ", ");
105                Err(Error::Handshake(format!(
106                    "{mechanism} rejected by the server. Accepted mechanisms: [{list}]"
107                )))
108            }
109            Command::Error(e) => Err(Error::Handshake(format!("Received error from server: {e}"))),
110            cmd => Err(Error::Handshake(format!(
111                "Unexpected command from server: {cmd}"
112            ))),
113        }
114    }
115
116    /// Sends out all commands after authentication.
117    #[instrument(skip(self), level = "trace")]
118    async fn send_secondary_commands(&mut self) -> Result<usize> {
119        let mut commands = Vec::with_capacity(4);
120
121        let can_pass_fd = self.common.socket_mut().read_mut().can_pass_unix_fd();
122        if can_pass_fd {
123            // xdg-dbus-proxy can't handle pipelining, hence this special handling.
124            // FIXME: Remove this as soon as flatpak is fixed and fix is available in major distros.
125            // See https://github.com/flatpak/xdg-dbus-proxy/issues/21
126            if is_flatpak() {
127                self.common.write_command(Command::NegotiateUnixFD).await?;
128                match self.common.read_command().await? {
129                    Command::AgreeUnixFD => self.common.set_cap_unix_fd(true),
130                    Command::Error(e) => warn!("UNIX file descriptor passing rejected: {e}"),
131                    cmd => {
132                        return Err(Error::Handshake(format!(
133                            "Unexpected command from server: {cmd}"
134                        )));
135                    }
136                }
137            } else {
138                commands.push(Command::NegotiateUnixFD);
139            }
140        };
141        commands.push(Command::Begin);
142        let hello_method = if self.bus {
143            Some(create_hello_method_call())
144        } else {
145            None
146        };
147
148        self.common
149            .write_commands(&commands, hello_method.as_ref().map(|m| &**m.data()))
150            .await?;
151
152        // Server replies to all commands except `BEGIN`.
153        Ok(commands.len() - 1)
154    }
155
156    #[instrument(skip(self), level = "trace")]
157    async fn receive_secondary_responses(&mut self, expected_n_responses: usize) -> Result<()> {
158        for response in self.common.read_commands(expected_n_responses).await? {
159            match response {
160                Command::Ok(guid) => {
161                    trace!("Received OK from server");
162                    self.set_guid(guid)?;
163                }
164                Command::AgreeUnixFD => self.common.set_cap_unix_fd(true),
165                Command::Error(e) => warn!("UNIX file descriptor passing rejected: {e}"),
166                cmd => {
167                    return Err(Error::Handshake(format!(
168                        "Unexpected command from server: {cmd}"
169                    )));
170                }
171            }
172        }
173
174        Ok(())
175    }
176}
177
178#[async_trait]
179impl Handshake for Client {
180    #[instrument(skip(self), level = "trace")]
181    async fn perform(mut self) -> Result<Authenticated> {
182        trace!("Initializing");
183
184        #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
185        self.send_zero_byte().await?;
186
187        self.authenticate().await?;
188        let expected_n_responses = self.send_secondary_commands().await?;
189
190        if expected_n_responses > 0 {
191            self.receive_secondary_responses(expected_n_responses)
192                .await?;
193        }
194
195        trace!("Handshake done");
196        #[cfg(unix)]
197        let (socket, mut recv_buffer, received_fds, cap_unix_fd, _) = self.common.into_components();
198        #[cfg(not(unix))]
199        let (socket, mut recv_buffer, _, _) = self.common.into_components();
200        let (mut read, write) = socket.take();
201
202        // If we're a bus connection, we need to read the unique name from `Hello` response.
203        let unique_name = if self.bus {
204            let unique_name = receive_hello_response(&mut read, &mut recv_buffer).await?;
205
206            Some(unique_name)
207        } else {
208            None
209        };
210
211        Ok(Authenticated {
212            socket_write: write,
213            socket_read: Some(read),
214            server_guid: self.server_guid.unwrap(),
215            #[cfg(unix)]
216            cap_unix_fd,
217            already_received_bytes: recv_buffer,
218            #[cfg(unix)]
219            already_received_fds: received_fds,
220            unique_name,
221        })
222    }
223}
224
225fn create_hello_method_call() -> Message {
226    Message::method_call("/org/freedesktop/DBus", "Hello")
227        .unwrap()
228        .destination("org.freedesktop.DBus")
229        .unwrap()
230        .interface("org.freedesktop.DBus")
231        .unwrap()
232        .build(&())
233        .unwrap()
234}
235
236async fn receive_hello_response(
237    read: &mut Box<dyn ReadHalf>,
238    recv_buffer: &mut Vec<u8>,
239) -> Result<OwnedUniqueName> {
240    use crate::message::Type;
241
242    let reply = read
243        .receive_message(
244            0,
245            recv_buffer,
246            #[cfg(unix)]
247            &mut vec![],
248        )
249        .await?;
250    match reply.message_type() {
251        Type::MethodReturn => reply.body().deserialize(),
252        Type::Error => Err(Error::from(reply)),
253        m => Err(Error::Handshake(format!("Unexpected message `{m:?}`"))),
254    }
255}