zbus/connection/handshake/
client.rs1use async_trait::async_trait;
2use tracing::{instrument, trace, warn};
3
4use crate::{Message, conn::socket::ReadHalf, is_flatpak, names::OwnedUniqueName};
5
6use super::{
7 AuthMechanism, Authenticated, BoxedSplit, Command, Common, Error, Handshake, OwnedGuid, Result,
8 sasl_auth_id,
9};
10
11#[derive(Debug)]
16pub struct Client {
17 common: Common,
18 server_guid: Option<OwnedGuid>,
19 bus: bool,
20 user_id: Result<String>,
21}
22
23impl Client {
24 pub fn new(
26 socket: BoxedSplit,
27 mechanism: Option<AuthMechanism>,
28 server_guid: Option<OwnedGuid>,
29 bus: bool,
30 user_id: Option<u32>,
31 ) -> Client {
32 let mechanism = mechanism.unwrap_or_else(|| socket.read().auth_mechanism());
33
34 Client {
35 common: Common::new(socket, mechanism),
36 server_guid,
37 bus,
38 user_id: match user_id {
39 Some(value) => Ok(value.to_string()),
40 None => sasl_auth_id(),
41 },
42 }
43 }
44
45 fn set_guid(&mut self, guid: OwnedGuid) -> Result<()> {
46 match &self.server_guid {
47 Some(server_guid) if *server_guid != guid => {
48 return Err(Error::Handshake(format!(
49 "Server GUID mismatch: expected {server_guid}, got {guid}",
50 )));
51 }
52 Some(_) => (),
53 None => self.server_guid = Some(guid),
54 }
55
56 Ok(())
57 }
58
59 #[instrument(skip(self), level = "trace")]
62 #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
63 async fn send_zero_byte(&mut self) -> Result<()> {
64 let write = self.common.socket_mut().write_mut();
65
66 let written = match write.send_zero_byte().await.map_err(|e| {
67 Error::Handshake(format!("Could not send zero byte with credentials: {e}"))
68 })? {
69 None => write.sendmsg(&[0], &[]).await?,
72 Some(n) => n,
73 };
74
75 if written != 1 {
76 return Err(Error::Handshake(
77 "Could not send zero byte with credentials".to_string(),
78 ));
79 }
80
81 Ok(())
82 }
83
84 #[instrument(skip(self), level = "trace")]
86 async fn authenticate(&mut self) -> Result<()> {
87 let mechanism = self.common.mechanism();
88 trace!("Trying {mechanism} mechanism");
89 let user_id = self.user_id.clone();
90 let auth_cmd = match mechanism {
91 AuthMechanism::Anonymous => Command::Auth(Some(mechanism), Some("zbus".into())),
92 AuthMechanism::External => Command::Auth(Some(mechanism), Some(user_id?.into_bytes())),
93 };
94 self.common.write_command(auth_cmd).await?;
95
96 match self.common.read_command().await? {
97 Command::Ok(guid) => {
98 trace!("Received OK from server");
99 self.set_guid(guid)?;
100
101 Ok(())
102 }
103 Command::Rejected(accepted) => {
104 let list = accepted.replace(" ", ", ");
105 Err(Error::Handshake(format!(
106 "{mechanism} rejected by the server. Accepted mechanisms: [{list}]"
107 )))
108 }
109 Command::Error(e) => Err(Error::Handshake(format!("Received error from server: {e}"))),
110 cmd => Err(Error::Handshake(format!(
111 "Unexpected command from server: {cmd}"
112 ))),
113 }
114 }
115
116 #[instrument(skip(self), level = "trace")]
118 async fn send_secondary_commands(&mut self) -> Result<usize> {
119 let mut commands = Vec::with_capacity(4);
120
121 let can_pass_fd = self.common.socket_mut().read_mut().can_pass_unix_fd();
122 if can_pass_fd {
123 if is_flatpak() {
127 self.common.write_command(Command::NegotiateUnixFD).await?;
128 match self.common.read_command().await? {
129 Command::AgreeUnixFD => self.common.set_cap_unix_fd(true),
130 Command::Error(e) => warn!("UNIX file descriptor passing rejected: {e}"),
131 cmd => {
132 return Err(Error::Handshake(format!(
133 "Unexpected command from server: {cmd}"
134 )));
135 }
136 }
137 } else {
138 commands.push(Command::NegotiateUnixFD);
139 }
140 };
141 commands.push(Command::Begin);
142 let hello_method = if self.bus {
143 Some(create_hello_method_call())
144 } else {
145 None
146 };
147
148 self.common
149 .write_commands(&commands, hello_method.as_ref().map(|m| &**m.data()))
150 .await?;
151
152 Ok(commands.len() - 1)
154 }
155
156 #[instrument(skip(self), level = "trace")]
157 async fn receive_secondary_responses(&mut self, expected_n_responses: usize) -> Result<()> {
158 for response in self.common.read_commands(expected_n_responses).await? {
159 match response {
160 Command::Ok(guid) => {
161 trace!("Received OK from server");
162 self.set_guid(guid)?;
163 }
164 Command::AgreeUnixFD => self.common.set_cap_unix_fd(true),
165 Command::Error(e) => warn!("UNIX file descriptor passing rejected: {e}"),
166 cmd => {
167 return Err(Error::Handshake(format!(
168 "Unexpected command from server: {cmd}"
169 )));
170 }
171 }
172 }
173
174 Ok(())
175 }
176}
177
178#[async_trait]
179impl Handshake for Client {
180 #[instrument(skip(self), level = "trace")]
181 async fn perform(mut self) -> Result<Authenticated> {
182 trace!("Initializing");
183
184 #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
185 self.send_zero_byte().await?;
186
187 self.authenticate().await?;
188 let expected_n_responses = self.send_secondary_commands().await?;
189
190 if expected_n_responses > 0 {
191 self.receive_secondary_responses(expected_n_responses)
192 .await?;
193 }
194
195 trace!("Handshake done");
196 #[cfg(unix)]
197 let (socket, mut recv_buffer, received_fds, cap_unix_fd, _) = self.common.into_components();
198 #[cfg(not(unix))]
199 let (socket, mut recv_buffer, _, _) = self.common.into_components();
200 let (mut read, write) = socket.take();
201
202 let unique_name = if self.bus {
204 let unique_name = receive_hello_response(&mut read, &mut recv_buffer).await?;
205
206 Some(unique_name)
207 } else {
208 None
209 };
210
211 Ok(Authenticated {
212 socket_write: write,
213 socket_read: Some(read),
214 server_guid: self.server_guid.unwrap(),
215 #[cfg(unix)]
216 cap_unix_fd,
217 already_received_bytes: recv_buffer,
218 #[cfg(unix)]
219 already_received_fds: received_fds,
220 unique_name,
221 })
222 }
223}
224
225fn create_hello_method_call() -> Message {
226 Message::method_call("/org/freedesktop/DBus", "Hello")
227 .unwrap()
228 .destination("org.freedesktop.DBus")
229 .unwrap()
230 .interface("org.freedesktop.DBus")
231 .unwrap()
232 .build(&())
233 .unwrap()
234}
235
236async fn receive_hello_response(
237 read: &mut Box<dyn ReadHalf>,
238 recv_buffer: &mut Vec<u8>,
239) -> Result<OwnedUniqueName> {
240 use crate::message::Type;
241
242 let reply = read
243 .receive_message(
244 0,
245 recv_buffer,
246 #[cfg(unix)]
247 &mut vec![],
248 )
249 .await?;
250 match reply.message_type() {
251 Type::MethodReturn => reply.body().deserialize(),
252 Type::Error => Err(Error::from(reply)),
253 m => Err(Error::Handshake(format!("Unexpected message `{m:?}`"))),
254 }
255}