zbus/connection/handshake/client.rs

1use async_trait::async_trait;
2use tracing::{instrument, trace, warn};
3
4use crate::{conn::socket::ReadHalf, is_flatpak, names::OwnedUniqueName, Message};
5
6use super::{
7    sasl_auth_id, AuthMechanism, Authenticated, BoxedSplit, Command, Common, Error, Handshake,
8    OwnedGuid, Result,
9};
10
11/// A representation of an in-progress handshake, client-side
12///
13/// This struct is an async-compatible representation of the initial handshake that must be
14/// performed before a D-Bus connection can be used.
15#[derive(Debug)]
16pub struct Client {
17    common: Common,
18    server_guid: Option<OwnedGuid>,
19    bus: bool,
20}
21
22impl Client {
23    /// Start a handshake on this client socket
24    pub fn new(
25        socket: BoxedSplit,
26        mechanism: Option<AuthMechanism>,
27        server_guid: Option<OwnedGuid>,
28        bus: bool,
29    ) -> Client {
30        let mechanism = mechanism.unwrap_or_else(|| socket.read().auth_mechanism());
31
32        Client {
33            common: Common::new(socket, mechanism),
34            server_guid,
35            bus,
36        }
37    }
38
39    fn set_guid(&mut self, guid: OwnedGuid) -> Result<()> {
40        match &self.server_guid {
41            Some(server_guid) if *server_guid != guid => {
42                return Err(Error::Handshake(format!(
43                    "Server GUID mismatch: expected {server_guid}, got {guid}",
44                )));
45            }
46            Some(_) => (),
47            None => self.server_guid = Some(guid),
48        }
49
50        Ok(())
51    }
52
53    // The dbus daemon on some platforms requires sending the zero byte as a
54    // separate message with SCM_CREDS.
55    #[instrument(skip(self))]
56    #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
57    async fn send_zero_byte(&mut self) -> Result<()> {
58        let write = self.common.socket_mut().write_mut();
59
60        let written = match write.send_zero_byte().await.map_err(|e| {
61            Error::Handshake(format!("Could not send zero byte with credentials: {e}"))
62        })? {
63            // This likely means that the socket type is unable to send SCM_CREDS.
64            // Let's try to send the 0 byte as a regular message.
65            None => write.sendmsg(&[0], &[]).await?,
66            Some(n) => n,
67        };
68
69        if written != 1 {
70            return Err(Error::Handshake(
71                "Could not send zero byte with credentials".to_string(),
72            ));
73        }
74
75        Ok(())
76    }
77
78    /// Perform the authentication handshake with the server.
79    #[instrument(skip(self))]
80    async fn authenticate(&mut self) -> Result<()> {
81        let mechanism = self.common.mechanism();
82        trace!("Trying {mechanism} mechanism");
83        let auth_cmd = match mechanism {
84            AuthMechanism::Anonymous => Command::Auth(Some(mechanism), Some("zbus".into())),
85            AuthMechanism::External => {
86                Command::Auth(Some(mechanism), Some(sasl_auth_id()?.into_bytes()))
87            }
88        };
89        self.common.write_command(auth_cmd).await?;
90
91        match self.common.read_command().await? {
92            Command::Ok(guid) => {
93                trace!("Received OK from server");
94                self.set_guid(guid)?;
95
96                Ok(())
97            }
98            Command::Rejected(accepted) => {
99                let list = accepted.replace(" ", ", ");
100                Err(Error::Handshake(format!(
101                    "{mechanism} rejected by the server. Accepted mechanisms: [{list}]"
102                )))
103            }
104            Command::Error(e) => Err(Error::Handshake(format!("Received error from server: {e}"))),
105            cmd => Err(Error::Handshake(format!(
106                "Unexpected command from server: {cmd}"
107            ))),
108        }
109    }
110
111    /// Sends out all commands after authentication.
112    #[instrument(skip(self))]
113    async fn send_secondary_commands(&mut self) -> Result<usize> {
114        let mut commands = Vec::with_capacity(4);
115
116        let can_pass_fd = self.common.socket_mut().read_mut().can_pass_unix_fd();
117        if can_pass_fd {
118            // xdg-dbus-proxy can't handle pipelining, hence this special handling.
119            // FIXME: Remove this as soon as flatpak is fixed and fix is available in major distros.
120            // See https://github.com/flatpak/xdg-dbus-proxy/issues/21
121            if is_flatpak() {
122                self.common.write_command(Command::NegotiateUnixFD).await?;
123                match self.common.read_command().await? {
124                    Command::AgreeUnixFD => self.common.set_cap_unix_fd(true),
125                    Command::Error(e) => warn!("UNIX file descriptor passing rejected: {e}"),
126                    cmd => {
127                        return Err(Error::Handshake(format!(
128                            "Unexpected command from server: {cmd}"
129                        )))
130                    }
131                }
132            } else {
133                commands.push(Command::NegotiateUnixFD);
134            }
135        };
136        commands.push(Command::Begin);
137        let hello_method = if self.bus {
138            Some(create_hello_method_call())
139        } else {
140            None
141        };
142
143        self.common
144            .write_commands(&commands, hello_method.as_ref().map(|m| &**m.data()))
145            .await?;
146
147        // Server replies to all commands except `BEGIN`.
148        Ok(commands.len() - 1)
149    }
150
151    #[instrument(skip(self))]
152    async fn receive_secondary_responses(&mut self, expected_n_responses: usize) -> Result<()> {
153        for response in self.common.read_commands(expected_n_responses).await? {
154            match response {
155                Command::Ok(guid) => {
156                    trace!("Received OK from server");
157                    self.set_guid(guid)?;
158                }
159                Command::AgreeUnixFD => self.common.set_cap_unix_fd(true),
160                Command::Error(e) => warn!("UNIX file descriptor passing rejected: {e}"),
161                cmd => {
162                    return Err(Error::Handshake(format!(
163                        "Unexpected command from server: {cmd}"
164                    )))
165                }
166            }
167        }
168
169        Ok(())
170    }
171}
172
173#[async_trait]
174impl Handshake for Client {
175    #[instrument(skip(self))]
176    async fn perform(mut self) -> Result<Authenticated> {
177        trace!("Initializing");
178
179        #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
180        self.send_zero_byte().await?;
181
182        self.authenticate().await?;
183        let expected_n_responses = self.send_secondary_commands().await?;
184
185        if expected_n_responses > 0 {
186            self.receive_secondary_responses(expected_n_responses)
187                .await?;
188        }
189
190        trace!("Handshake done");
191        #[cfg(unix)]
192        let (socket, mut recv_buffer, received_fds, cap_unix_fd, _) = self.common.into_components();
193        #[cfg(not(unix))]
194        let (socket, mut recv_buffer, _, _) = self.common.into_components();
195        let (mut read, write) = socket.take();
196
197        // If we're a bus connection, we need to read the unique name from `Hello` response.
198        let unique_name = if self.bus {
199            let unique_name = receive_hello_response(&mut read, &mut recv_buffer).await?;
200
201            Some(unique_name)
202        } else {
203            None
204        };
205
206        Ok(Authenticated {
207            socket_write: write,
208            socket_read: Some(read),
209            server_guid: self.server_guid.unwrap(),
210            #[cfg(unix)]
211            cap_unix_fd,
212            already_received_bytes: recv_buffer,
213            #[cfg(unix)]
214            already_received_fds: received_fds,
215            unique_name,
216        })
217    }
218}
219
220fn create_hello_method_call() -> Message {
221    Message::method_call("/org/freedesktop/DBus", "Hello")
222        .unwrap()
223        .destination("org.freedesktop.DBus")
224        .unwrap()
225        .interface("org.freedesktop.DBus")
226        .unwrap()
227        .build(&())
228        .unwrap()
229}
230
231async fn receive_hello_response(
232    read: &mut Box<dyn ReadHalf>,
233    recv_buffer: &mut Vec<u8>,
234) -> Result<OwnedUniqueName> {
235    use crate::message::Type;
236
237    let reply = read
238        .receive_message(
239            0,
240            recv_buffer,
241            #[cfg(unix)]
242            &mut vec![],
243        )
244        .await?;
245    match reply.message_type() {
246        Type::MethodReturn => reply.body().deserialize(),
247        Type::Error => Err(Error::from(reply)),
248        m => Err(Error::Handshake(format!("Unexpected message `{m:?}`"))),
249    }
250}