zbus/connection/handshake/
client.rs1use async_trait::async_trait;
2use tracing::{instrument, trace, warn};
3
4use crate::{conn::socket::ReadHalf, is_flatpak, names::OwnedUniqueName, Message};
5
6use super::{
7 sasl_auth_id, AuthMechanism, Authenticated, BoxedSplit, Command, Common, Error, Handshake,
8 OwnedGuid, Result,
9};
10
11#[derive(Debug)]
16pub struct Client {
17 common: Common,
18 server_guid: Option<OwnedGuid>,
19 bus: bool,
20}
21
22impl Client {
23 pub fn new(
25 socket: BoxedSplit,
26 mechanism: Option<AuthMechanism>,
27 server_guid: Option<OwnedGuid>,
28 bus: bool,
29 ) -> Client {
30 let mechanism = mechanism.unwrap_or_else(|| socket.read().auth_mechanism());
31
32 Client {
33 common: Common::new(socket, mechanism),
34 server_guid,
35 bus,
36 }
37 }
38
39 fn set_guid(&mut self, guid: OwnedGuid) -> Result<()> {
40 match &self.server_guid {
41 Some(server_guid) if *server_guid != guid => {
42 return Err(Error::Handshake(format!(
43 "Server GUID mismatch: expected {server_guid}, got {guid}",
44 )));
45 }
46 Some(_) => (),
47 None => self.server_guid = Some(guid),
48 }
49
50 Ok(())
51 }
52
53 #[instrument(skip(self))]
56 #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
57 async fn send_zero_byte(&mut self) -> Result<()> {
58 let write = self.common.socket_mut().write_mut();
59
60 let written = match write.send_zero_byte().await.map_err(|e| {
61 Error::Handshake(format!("Could not send zero byte with credentials: {e}"))
62 })? {
63 None => write.sendmsg(&[0], &[]).await?,
66 Some(n) => n,
67 };
68
69 if written != 1 {
70 return Err(Error::Handshake(
71 "Could not send zero byte with credentials".to_string(),
72 ));
73 }
74
75 Ok(())
76 }
77
78 #[instrument(skip(self))]
80 async fn authenticate(&mut self) -> Result<()> {
81 let mechanism = self.common.mechanism();
82 trace!("Trying {mechanism} mechanism");
83 let auth_cmd = match mechanism {
84 AuthMechanism::Anonymous => Command::Auth(Some(mechanism), Some("zbus".into())),
85 AuthMechanism::External => {
86 Command::Auth(Some(mechanism), Some(sasl_auth_id()?.into_bytes()))
87 }
88 };
89 self.common.write_command(auth_cmd).await?;
90
91 match self.common.read_command().await? {
92 Command::Ok(guid) => {
93 trace!("Received OK from server");
94 self.set_guid(guid)?;
95
96 Ok(())
97 }
98 Command::Rejected(accepted) => {
99 let list = accepted.replace(" ", ", ");
100 Err(Error::Handshake(format!(
101 "{mechanism} rejected by the server. Accepted mechanisms: [{list}]"
102 )))
103 }
104 Command::Error(e) => Err(Error::Handshake(format!("Received error from server: {e}"))),
105 cmd => Err(Error::Handshake(format!(
106 "Unexpected command from server: {cmd}"
107 ))),
108 }
109 }
110
111 #[instrument(skip(self))]
113 async fn send_secondary_commands(&mut self) -> Result<usize> {
114 let mut commands = Vec::with_capacity(4);
115
116 let can_pass_fd = self.common.socket_mut().read_mut().can_pass_unix_fd();
117 if can_pass_fd {
118 if is_flatpak() {
122 self.common.write_command(Command::NegotiateUnixFD).await?;
123 match self.common.read_command().await? {
124 Command::AgreeUnixFD => self.common.set_cap_unix_fd(true),
125 Command::Error(e) => warn!("UNIX file descriptor passing rejected: {e}"),
126 cmd => {
127 return Err(Error::Handshake(format!(
128 "Unexpected command from server: {cmd}"
129 )))
130 }
131 }
132 } else {
133 commands.push(Command::NegotiateUnixFD);
134 }
135 };
136 commands.push(Command::Begin);
137 let hello_method = if self.bus {
138 Some(create_hello_method_call())
139 } else {
140 None
141 };
142
143 self.common
144 .write_commands(&commands, hello_method.as_ref().map(|m| &**m.data()))
145 .await?;
146
147 Ok(commands.len() - 1)
149 }
150
151 #[instrument(skip(self))]
152 async fn receive_secondary_responses(&mut self, expected_n_responses: usize) -> Result<()> {
153 for response in self.common.read_commands(expected_n_responses).await? {
154 match response {
155 Command::Ok(guid) => {
156 trace!("Received OK from server");
157 self.set_guid(guid)?;
158 }
159 Command::AgreeUnixFD => self.common.set_cap_unix_fd(true),
160 Command::Error(e) => warn!("UNIX file descriptor passing rejected: {e}"),
161 cmd => {
162 return Err(Error::Handshake(format!(
163 "Unexpected command from server: {cmd}"
164 )))
165 }
166 }
167 }
168
169 Ok(())
170 }
171}
172
173#[async_trait]
174impl Handshake for Client {
175 #[instrument(skip(self))]
176 async fn perform(mut self) -> Result<Authenticated> {
177 trace!("Initializing");
178
179 #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
180 self.send_zero_byte().await?;
181
182 self.authenticate().await?;
183 let expected_n_responses = self.send_secondary_commands().await?;
184
185 if expected_n_responses > 0 {
186 self.receive_secondary_responses(expected_n_responses)
187 .await?;
188 }
189
190 trace!("Handshake done");
191 #[cfg(unix)]
192 let (socket, mut recv_buffer, received_fds, cap_unix_fd, _) = self.common.into_components();
193 #[cfg(not(unix))]
194 let (socket, mut recv_buffer, _, _) = self.common.into_components();
195 let (mut read, write) = socket.take();
196
197 let unique_name = if self.bus {
199 let unique_name = receive_hello_response(&mut read, &mut recv_buffer).await?;
200
201 Some(unique_name)
202 } else {
203 None
204 };
205
206 Ok(Authenticated {
207 socket_write: write,
208 socket_read: Some(read),
209 server_guid: self.server_guid.unwrap(),
210 #[cfg(unix)]
211 cap_unix_fd,
212 already_received_bytes: recv_buffer,
213 #[cfg(unix)]
214 already_received_fds: received_fds,
215 unique_name,
216 })
217 }
218}
219
220fn create_hello_method_call() -> Message {
221 Message::method_call("/org/freedesktop/DBus", "Hello")
222 .unwrap()
223 .destination("org.freedesktop.DBus")
224 .unwrap()
225 .interface("org.freedesktop.DBus")
226 .unwrap()
227 .build(&())
228 .unwrap()
229}
230
231async fn receive_hello_response(
232 read: &mut Box<dyn ReadHalf>,
233 recv_buffer: &mut Vec<u8>,
234) -> Result<OwnedUniqueName> {
235 use crate::message::Type;
236
237 let reply = read
238 .receive_message(
239 0,
240 recv_buffer,
241 #[cfg(unix)]
242 &mut vec![],
243 )
244 .await?;
245 match reply.message_type() {
246 Type::MethodReturn => reply.body().deserialize(),
247 Type::Error => Err(Error::from(reply)),
248 m => Err(Error::Handshake(format!("Unexpected message `{m:?}`"))),
249 }
250}