1use async_broadcast::{InactiveReceiver, Receiver, Sender as Broadcaster, broadcast};
3use enumflags2::BitFlags;
4use event_listener::{Event, EventListener};
5use ordered_stream::{OrderedFuture, OrderedStream, PollResult};
6use std::{
7 collections::HashMap,
8 io::{self, ErrorKind},
9 num::NonZeroU32,
10 pin::Pin,
11 sync::{Arc, OnceLock, Weak},
12 task::{Context, Poll},
13 time::Duration,
14};
15use tracing::{Instrument, debug, info_span, instrument, trace, trace_span, warn};
16use zbus_names::{BusName, ErrorName, InterfaceName, MemberName, OwnedUniqueName, WellKnownName};
17use zvariant::ObjectPath;
18
19use futures_core::Future;
20use futures_lite::StreamExt;
21
22use crate::{
23 DBusError, Error, Executor, MatchRule, MessageStream, ObjectServer, OwnedGuid, OwnedMatchRule,
24 Result, Task,
25 async_lock::{Mutex, Semaphore, SemaphorePermit},
26 fdo::{ConnectionCredentials, ReleaseNameReply, RequestNameFlags, RequestNameReply},
27 is_flatpak,
28 message::{Flags, Message, Type},
29 timeout::timeout,
30};
31
32mod builder;
33pub use builder::Builder;
34
35pub mod socket;
36pub use socket::Socket;
37
38mod socket_reader;
39use socket_reader::SocketReader;
40
41pub(crate) mod handshake;
42pub use handshake::AuthMechanism;
43use handshake::Authenticated;
44
/// Default capacity of subscription broadcast channels: used by `add_match` when no
/// explicit `max_queued` is given, and for the channel registered under the `None` key
/// in `Connection::new`.
const DEFAULT_MAX_QUEUED: usize = 64;
/// Capacity of the broadcast channel that carries method returns and errors to pending
/// method calls (created in `Connection::new`).
const DEFAULT_MAX_METHOD_RETURN_QUEUED: usize = 8;
47
/// Shared state behind a [`Connection`] and all of its clones.
#[derive(Debug)]
pub(crate) struct ConnectionInner {
    /// GUID of the server, taken from the authentication result.
    server_guid: OwnedGuid,
    /// Whether Unix FD passing was negotiated; `send` rejects messages carrying FDs
    /// when this is `false`.
    #[cfg(unix)]
    cap_unix_fd: bool,
    /// Whether this is a bus connection (reported by `is_bus`).
    #[cfg(feature = "p2p")]
    bus_conn: bool,
    /// Unique name of the connection; may be set at most once.
    unique_name: OnceLock<OwnedUniqueName>,
    /// Well-known names requested through this connection, with their status.
    registered_names: Mutex<HashMap<WellKnownName<'static>, NameStatus>>,

    /// Notified by `send` and `close` and handed to the socket reader; listeners are
    /// created by `monitor_activity`.
    activity_event: Arc<Event>,
    /// Write half of the socket, behind an async mutex.
    socket_write: Mutex<Box<dyn socket::WriteHalf>>,

    /// Executor on which the connection's internal tasks are spawned.
    executor: Executor<'static>,

    // Only stored, never read, hence `allow(unused)`. Presumably kept so the socket reader
    // task lives as long as the connection — TODO confirm against `Task`'s drop semantics.
    #[allow(unused)]
    socket_reader_task: OnceLock<Task<()>>,

    /// Inactive receiver of the channel registered under the `None` key of `msg_senders`;
    /// `max_queued`/`set_max_queued` read and change its capacity.
    pub(crate) msg_receiver: InactiveReceiver<Result<Message>>,
    /// Inactive receiver of the channel fed with method returns and errors; method calls
    /// activate a clone of it to wait for their reply.
    pub(crate) method_return_receiver: InactiveReceiver<Result<Message>>,
    /// Broadcast senders keyed by match rule; shared with the socket reader.
    msg_senders: Arc<Mutex<HashMap<Option<OwnedMatchRule>, MsgBroadcaster>>>,

    /// Reference-counted match-rule subscriptions (see `add_match`/`remove_match`).
    subscriptions: Mutex<Subscriptions>,

    /// Lazily created object server.
    object_server: OnceLock<ObjectServer>,
    /// Task dispatching incoming method calls to the object server; started at most once.
    object_server_dispatch_task: OnceLock<Task<()>>,

    /// Notified when this value is dropped; awaited by `graceful_shutdown`.
    drop_event: Event,

    /// Timeout applied by `call_method` while waiting for a reply.
    method_timeout: Option<Duration>,
    /// Peer credentials, cached by `peer_creds`.
    credentials: OnceLock<Arc<ConnectionCredentials>>,
}
84
85impl Drop for ConnectionInner {
86 fn drop(&mut self) {
87 self.drop_event.notify(usize::MAX);
91 }
92}
93
/// Active subscriptions keyed by match rule: the number of subscribers sharing the rule,
/// plus an inactive receiver of the rule's broadcast channel (see `add_match`).
type Subscriptions = HashMap<OwnedMatchRule, (u64, InactiveReceiver<Result<Message>>)>;

/// Sending half of a message broadcast channel.
pub(crate) type MsgBroadcaster = Broadcaster<Result<Message>>;
97
/// A D-Bus connection.
///
/// Cloning is cheap: every clone shares the same `ConnectionInner` through an `Arc`.
#[derive(Clone, Debug)]
#[must_use = "Dropping a `Connection` will close the underlying socket."]
pub struct Connection {
    pub(crate) inner: Arc<ConnectionInner>,
}
216
/// A method call whose reply has not been received yet.
///
/// Resolves to the `MethodReturn` or `Error` message whose reply serial equals `serial`.
#[derive(Debug)]
pub(crate) struct PendingMethodCall {
    /// Stream of method returns/errors; set to `None` once the matching reply was found.
    stream: Option<MessageStream>,
    /// Serial number of the outgoing call, compared with each reply's `reply_serial`.
    serial: NonZeroU32,
}
227
228impl Future for PendingMethodCall {
229 type Output = Result<Message>;
230
231 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
232 self.poll_before(cx, None).map(|ret| {
233 ret.map(|(_, r)| r).unwrap_or_else(|| {
234 Err(crate::Error::InputOutput(
235 io::Error::new(ErrorKind::BrokenPipe, "socket closed").into(),
236 ))
237 })
238 })
239 }
240}
241
242impl OrderedFuture for PendingMethodCall {
243 type Output = Result<Message>;
244 type Ordering = zbus::message::Sequence;
245
246 fn poll_before(
247 self: Pin<&mut Self>,
248 cx: &mut Context<'_>,
249 before: Option<&Self::Ordering>,
250 ) -> Poll<Option<(Self::Ordering, Self::Output)>> {
251 let this = self.get_mut();
252 if let Some(stream) = &mut this.stream {
253 loop {
254 match Pin::new(&mut *stream).poll_next_before(cx, before) {
255 Poll::Ready(PollResult::Item {
256 data: Ok(msg),
257 ordering,
258 }) => {
259 if msg.header().reply_serial() != Some(this.serial) {
260 continue;
261 }
262 let res = match msg.message_type() {
263 Type::Error => Err(msg.into()),
264 Type::MethodReturn => Ok(msg),
265 _ => continue,
266 };
267 this.stream = None;
268 return Poll::Ready(Some((ordering, res)));
269 }
270 Poll::Ready(PollResult::Item {
271 data: Err(e),
272 ordering,
273 }) => {
274 return Poll::Ready(Some((ordering, Err(e))));
275 }
276
277 Poll::Ready(PollResult::NoneBefore) => {
278 return Poll::Ready(None);
279 }
280 Poll::Ready(PollResult::Terminated) => {
281 return Poll::Ready(None);
282 }
283 Poll::Pending => return Poll::Pending,
284 }
285 }
286 }
287 Poll::Ready(None)
288 }
289}
290
291impl Connection {
292 pub async fn send(&self, msg: &Message) -> Result<()> {
294 #[cfg(unix)]
295 if !msg.data().fds().is_empty() && !self.inner.cap_unix_fd {
296 return Err(Error::Unsupported);
297 }
298
299 self.inner.activity_event.notify(usize::MAX);
300 let mut write = self.inner.socket_write.lock().await;
301
302 write.send_message(msg).await
303 }
304
305 pub async fn call_method<'d, 'p, 'i, 'm, D, P, I, M, B>(
312 &self,
313 destination: Option<D>,
314 path: P,
315 interface: Option<I>,
316 method_name: M,
317 body: &B,
318 ) -> Result<Message>
319 where
320 D: TryInto<BusName<'d>>,
321 P: TryInto<ObjectPath<'p>>,
322 I: TryInto<InterfaceName<'i>>,
323 M: TryInto<MemberName<'m>>,
324 D::Error: Into<Error>,
325 P::Error: Into<Error>,
326 I::Error: Into<Error>,
327 M::Error: Into<Error>,
328 B: serde::ser::Serialize + zvariant::DynamicType,
329 {
330 let method = self
331 .call_method_raw(
332 destination,
333 path,
334 interface,
335 method_name,
336 BitFlags::empty(),
337 body,
338 )
339 .await?
340 .expect("no reply");
341
342 if let Some(tout) = self.method_timeout() {
343 timeout(method, tout).await
344 } else {
345 method.await
346 }
347 }
348
349 pub(crate) async fn call_method_raw<'d, 'p, 'i, 'm, D, P, I, M, B>(
360 &self,
361 destination: Option<D>,
362 path: P,
363 interface: Option<I>,
364 method_name: M,
365 flags: BitFlags<Flags>,
366 body: &B,
367 ) -> Result<Option<PendingMethodCall>>
368 where
369 D: TryInto<BusName<'d>>,
370 P: TryInto<ObjectPath<'p>>,
371 I: TryInto<InterfaceName<'i>>,
372 M: TryInto<MemberName<'m>>,
373 D::Error: Into<Error>,
374 P::Error: Into<Error>,
375 I::Error: Into<Error>,
376 M::Error: Into<Error>,
377 B: serde::ser::Serialize + zvariant::DynamicType,
378 {
379 let _permit = acquire_serial_num_semaphore().await;
380
381 let mut builder = Message::method_call(path, method_name)?;
382 if let Some(sender) = self.unique_name() {
383 builder = builder.sender(sender)?
384 }
385 if let Some(destination) = destination {
386 builder = builder.destination(destination)?
387 }
388 if let Some(interface) = interface {
389 builder = builder.interface(interface)?
390 }
391 for flag in flags {
392 builder = builder.with_flags(flag)?;
393 }
394 let msg = builder.build(body)?;
395
396 let msg_receiver = self.inner.method_return_receiver.activate_cloned();
397 let stream = Some(MessageStream::for_subscription_channel(
398 msg_receiver,
399 None,
401 self,
402 ));
403 let serial = msg.primary_header().serial_num();
404 self.send(&msg).await?;
405 if flags.contains(Flags::NoReplyExpected) {
406 Ok(None)
407 } else {
408 Ok(Some(PendingMethodCall { stream, serial }))
409 }
410 }
411
412 pub async fn emit_signal<'d, 'p, 'i, 'm, D, P, I, M, B>(
416 &self,
417 destination: Option<D>,
418 path: P,
419 interface: I,
420 signal_name: M,
421 body: &B,
422 ) -> Result<()>
423 where
424 D: TryInto<BusName<'d>>,
425 P: TryInto<ObjectPath<'p>>,
426 I: TryInto<InterfaceName<'i>>,
427 M: TryInto<MemberName<'m>>,
428 D::Error: Into<Error>,
429 P::Error: Into<Error>,
430 I::Error: Into<Error>,
431 M::Error: Into<Error>,
432 B: serde::ser::Serialize + zvariant::DynamicType,
433 {
434 let _permit = acquire_serial_num_semaphore().await;
435
436 let mut b = Message::signal(path, interface, signal_name)?;
437 if let Some(sender) = self.unique_name() {
438 b = b.sender(sender)?;
439 }
440 if let Some(destination) = destination {
441 b = b.destination(destination)?;
442 }
443 let m = b.build(body)?;
444
445 self.send(&m).await
446 }
447
448 pub async fn reply<B>(&self, call: &zbus::message::Header<'_>, body: &B) -> Result<()>
453 where
454 B: serde::ser::Serialize + zvariant::DynamicType,
455 {
456 let _permit = acquire_serial_num_semaphore().await;
457
458 let mut b = Message::method_return(call)?;
459 if let Some(sender) = self.unique_name() {
460 b = b.sender(sender)?;
461 }
462 let m = b.build(body)?;
463 self.send(&m).await
464 }
465
466 pub async fn reply_error<'e, E, B>(
471 &self,
472 call: &zbus::message::Header<'_>,
473 error_name: E,
474 body: &B,
475 ) -> Result<()>
476 where
477 B: serde::ser::Serialize + zvariant::DynamicType,
478 E: TryInto<ErrorName<'e>>,
479 E::Error: Into<Error>,
480 {
481 let _permit = acquire_serial_num_semaphore().await;
482
483 let mut b = Message::error(call, error_name)?;
484 if let Some(sender) = self.unique_name() {
485 b = b.sender(sender)?;
486 }
487 let m = b.build(body)?;
488 self.send(&m).await
489 }
490
491 pub async fn reply_dbus_error(
496 &self,
497 call: &zbus::message::Header<'_>,
498 err: impl DBusError,
499 ) -> Result<()> {
500 let _permit = acquire_serial_num_semaphore().await;
501
502 let m = err.create_reply(call)?;
503 self.send(&m).await
504 }
505
506 pub async fn request_name<'w, W>(&self, well_known_name: W) -> Result<()>
536 where
537 W: TryInto<WellKnownName<'w>>,
538 W::Error: Into<Error>,
539 {
540 self.request_name_with_flags(well_known_name, BitFlags::default())
541 .await
542 .map(|_| ())
543 }
544
    /// Registers `well_known_name` for this connection and returns the bus's reply.
    ///
    /// `flags` is passed through to the bus's `RequestName` call.
    ///
    /// If the name is already tracked locally, returns `AlreadyOwner` (owned) or `InQueue`
    /// (queued) without contacting the bus. On a peer-to-peer connection the name is only
    /// recorded locally, and `PrimaryOwner` is returned.
    ///
    /// On a bus connection:
    /// - with [`RequestNameFlags::AllowReplacement`], once the name is owned, a task is
    ///   spawned that removes it from the local registry when `NameLost` arrives;
    /// - on an `InQueue` reply, a task is spawned that marks the name as owned when
    ///   `NameAcquired` arrives.
    ///
    /// The local name registry lock is held for the whole call.
    ///
    /// # Errors
    ///
    /// Returns [`Error::NameTaken`] if the bus replies `Exists`. Name conversion,
    /// match-rule subscription and method-call errors are propagated.
    pub async fn request_name_with_flags<'w, W>(
        &self,
        well_known_name: W,
        flags: BitFlags<RequestNameFlags>,
    ) -> Result<RequestNameReply>
    where
        W: TryInto<WellKnownName<'w>>,
        W::Error: Into<Error>,
    {
        let well_known_name = well_known_name.try_into().map_err(Into::into)?;

        if self.is_bus() && self.inner.object_server.get().is_none() {
            warn!(
                "Requesting name `{well_known_name}` before setting up the object server. \
                 Method calls arriving before interfaces are registered may be lost. \
                 Consider using `connection::Builder::serve_at()` and `::name()` instead.",
            );
        }
        // Held until this function returns, including across the bus calls below.
        let mut names = self.inner.registered_names.lock().await;

        // Already tracked locally: answer without contacting the bus.
        match names.get(&well_known_name) {
            Some(NameStatus::Owner(_)) => return Ok(RequestNameReply::AlreadyOwner),
            Some(NameStatus::Queued(_)) => return Ok(RequestNameReply::InQueue),
            None => (),
        }

        // Peer-to-peer: there is no bus to ask, so the name is simply recorded as owned.
        if !self.is_bus() {
            names.insert(well_known_name.to_owned(), NameStatus::Owner(None));

            return Ok(RequestNameReply::PrimaryOwner);
        }

        // Subscribe to `NameAcquired`/`NameLost` for this name *before* calling
        // `RequestName`, so signals sent in response to the call cannot be missed.
        let acquired_match_rule = MatchRule::fdo_signal_builder("NameAcquired")
            .arg(0, well_known_name.as_ref())
            .unwrap()
            .build();
        let mut acquired_stream = self.add_match(acquired_match_rule.into(), None).await?;
        let lost_match_rule = MatchRule::fdo_signal_builder("NameLost")
            .arg(0, well_known_name.as_ref())
            .unwrap()
            .build();
        let mut lost_stream = self.add_match(lost_match_rule.into(), None).await?;
        let reply = self
            .call_method(
                Some("org.freedesktop.DBus"),
                "/org/freedesktop/DBus",
                Some("org.freedesktop.DBus"),
                "RequestName",
                &(well_known_name.clone(), flags),
            )
            .await?
            .body()
            .deserialize::<RequestNameReply>()?;
        let lost_task_name = format!("monitor_name_lost{{name={well_known_name}}}");
        let lost_task_name_span = info_span!("monitor_name_lost", name = %well_known_name);
        // Built here but only spawned once the name is actually owned, either right below
        // or from the `NameAcquired` monitor.
        let name_lost_fut = if flags.contains(RequestNameFlags::AllowReplacement) {
            let weak_conn = WeakConnection::from(self);
            let well_known_name = well_known_name.to_owned();
            Some(
                async move {
                    loop {
                        let signal = lost_stream.next().await;
                        // Stop once the connection is gone.
                        let inner = match weak_conn.upgrade() {
                            Some(conn) => conn.inner.clone(),
                            None => break,
                        };

                        match signal {
                            Some(signal) => match signal {
                                Ok(_) => {
                                    tracing::info!(
                                        "Connection `{}` lost name `{}`",
                                        inner.unique_name.get().unwrap(),
                                        well_known_name
                                    );
                                    // Forget the name locally; this also drops its status.
                                    inner.registered_names.lock().await.remove(&well_known_name);

                                    break;
                                }
                                Err(e) => warn!("Failed to parse `NameLost` signal: {}", e),
                            },
                            None => {
                                trace!("`NameLost` signal stream closed");
                                break;
                            }
                        }
                    }
                }
                .instrument(lost_task_name_span),
            )
        } else {
            None
        };
        let status = match reply {
            RequestNameReply::InQueue => {
                let weak_conn = WeakConnection::from(self);
                let well_known_name = well_known_name.to_owned();
                let task_name = format!("monitor_name_acquired{{name={well_known_name}}}");
                let task_name_span = info_span!("monitor_name_acquired", name = %well_known_name);
                // Waits for `NameAcquired`, then promotes the entry to `Owner` and starts the
                // `NameLost` monitor (if any).
                let task = self.executor().spawn(
                    async move {
                        loop {
                            let signal = acquired_stream.next().await;
                            let inner = match weak_conn.upgrade() {
                                Some(conn) => conn.inner.clone(),
                                None => break,
                            };
                            match signal {
                                Some(signal) => match signal {
                                    Ok(_) => {
                                        let mut names = inner.registered_names.lock().await;
                                        // If the name is no longer registered (e.g. it was
                                        // released meanwhile), this signal is ignored and
                                        // the loop keeps waiting.
                                        if let Some(status) = names.get_mut(&well_known_name) {
                                            let task = name_lost_fut.map(|fut| {
                                                inner.executor.spawn(fut, &lost_task_name)
                                            });
                                            *status = NameStatus::Owner(task);

                                            break;
                                        }
                                    }
                                    Err(e) => warn!("Failed to parse `NameAcquired` signal: {}", e),
                                },
                                None => {
                                    trace!("`NameAcquired` signal stream closed");
                                    break;
                                }
                            }
                        }
                    }
                    .instrument(task_name_span),
                    &task_name,
                );

                NameStatus::Queued(task)
            }
            RequestNameReply::PrimaryOwner | RequestNameReply::AlreadyOwner => {
                let task = name_lost_fut.map(|fut| self.executor().spawn(fut, &lost_task_name));

                NameStatus::Owner(task)
            }
            RequestNameReply::Exists => return Err(Error::NameTaken),
        };

        names.insert(well_known_name.to_owned(), status);

        Ok(reply)
    }
782
783 pub async fn release_name<'w, W>(&self, well_known_name: W) -> Result<bool>
792 where
793 W: TryInto<WellKnownName<'w>>,
794 W::Error: Into<Error>,
795 {
796 let well_known_name: WellKnownName<'w> = well_known_name.try_into().map_err(Into::into)?;
797 let mut names = self.inner.registered_names.lock().await;
798 if names.remove(&well_known_name.to_owned()).is_none() {
800 return Ok(false);
801 };
802
803 if !self.is_bus() {
804 return Ok(true);
805 }
806
807 self.call_method(
808 Some("org.freedesktop.DBus"),
809 "/org/freedesktop/DBus",
810 Some("org.freedesktop.DBus"),
811 "ReleaseName",
812 &well_known_name,
813 )
814 .await?
815 .body()
816 .deserialize::<ReleaseNameReply>()
817 .map(|r| r == ReleaseNameReply::Released)
818 }
819
820 pub fn is_bus(&self) -> bool {
825 #[cfg(feature = "p2p")]
826 {
827 self.inner.bus_conn
828 }
829 #[cfg(not(feature = "p2p"))]
830 {
831 true
832 }
833 }
834
835 pub fn unique_name(&self) -> Option<&OwnedUniqueName> {
840 self.inner.unique_name.get()
841 }
842
843 #[cfg(feature = "bus-impl")]
853 pub fn set_unique_name<U>(&self, unique_name: U) -> Result<()>
854 where
855 U: TryInto<OwnedUniqueName>,
856 U::Error: Into<Error>,
857 {
858 let name = unique_name.try_into().map_err(Into::into)?;
859 self.set_unique_name_(name);
860
861 Ok(())
862 }
863
864 pub fn max_queued(&self) -> usize {
866 self.inner.msg_receiver.capacity()
867 }
868
869 pub fn set_max_queued(&mut self, max: usize) {
871 self.inner.msg_receiver.clone().set_capacity(max);
872 }
873
874 pub fn server_guid(&self) -> &OwnedGuid {
876 &self.inner.server_guid
877 }
878
879 pub fn executor(&self) -> &Executor<'static> {
931 &self.inner.executor
932 }
933
934 pub fn object_server(&self) -> &ObjectServer {
942 self.ensure_object_server(true)
943 }
944
945 pub(crate) fn ensure_object_server(&self, start: bool) -> &ObjectServer {
946 self.inner
947 .object_server
948 .get_or_init(move || self.setup_object_server(start, None))
949 }
950
951 fn setup_object_server(&self, start: bool, started_event: Option<Event>) -> ObjectServer {
952 if start {
953 self.start_object_server(started_event);
954 }
955
956 ObjectServer::new(self)
957 }
958
    /// Spawns, at most once per connection, the task that feeds incoming method calls to
    /// the object server.
    ///
    /// The task subscribes to method calls (restricted to this connection's unique name as
    /// destination, when known). Once subscribed, it notifies `started_event` (if given).
    /// It then dispatches each call to the object server.
    ///
    /// On a peer-to-peer connection, calls addressed to a well-known name that is not
    /// registered locally are skipped (as long as any name is registered).
    ///
    /// The task stops in any of these cases:
    /// - the connection is gone;
    /// - the subscription fails;
    /// - the stream ends;
    /// - the stream yields an error (which is logged).
    #[instrument(skip(self))]
    pub(crate) fn start_object_server(&self, started_event: Option<Event>) {
        self.inner.object_server_dispatch_task.get_or_init(|| {
            trace!("starting ObjectServer task");
            // The task only holds a weak handle, so it does not keep the connection alive.
            let weak_conn = WeakConnection::from(self);

            self.inner.executor.spawn(
                async move {
                    let mut stream = match weak_conn.upgrade() {
                        Some(conn) => {
                            let mut builder = MatchRule::builder().msg_type(Type::MethodCall);
                            if let Some(unique_name) = conn.unique_name() {
                                builder = builder.destination(&**unique_name).expect("unique name");
                            }
                            let rule = builder.build();
                            match conn.add_match(rule.into(), None).await {
                                Ok(stream) => stream,
                                Err(e) => {
                                    debug!("Failed to create message stream: {}", e);

                                    return;
                                }
                            }
                        }
                        None => {
                            trace!("Connection is gone, stopping associated object server task");

                            return;
                        }
                    };
                    if let Some(started_event) = started_event {
                        started_event.notify(1);
                    }

                    trace!("waiting for incoming method call messages..");
                    // An `Err` item is logged and mapped to `None`, which ends the loop.
                    while let Some(msg) = stream.next().await.and_then(|m| {
                        if let Err(e) = &m {
                            debug!("Error while reading from object server stream: {:?}", e);
                        }
                        m.ok()
                    }) {
                        // A strong handle is held only while handling this one message.
                        if let Some(conn) = weak_conn.upgrade() {
                            let hdr = msg.header();
                            if !conn.is_bus() {
                                match hdr.destination() {
                                    Some(BusName::Unique(_)) | None => (),
                                    Some(BusName::WellKnown(dest)) => {
                                        let names = conn.inner.registered_names.lock().await;
                                        // Skip calls to well-known names we did not register,
                                        // unless no names are registered at all.
                                        if !names.is_empty() && !names.contains_key(dest) {
                                            trace!(
                                                "Got a method call for a different destination: {}",
                                                dest
                                            );

                                            continue;
                                        }
                                    }
                                }
                            }
                            let server = conn.object_server();
                            if let Err(e) = server.dispatch_call(&msg, &hdr).await {
                                debug!(
                                    "Error dispatching message. Message: {:?}, error: {:?}",
                                    msg, e
                                );
                            }
                        } else {
                            trace!("Connection is gone, stopping associated object server task");
                            break;
                        }
                    }
                }
                .instrument(info_span!("obj_server_task")),
                "obj_server_task",
            )
        });
    }
1045
    /// Subscribes to messages matching `rule` and returns an active receiver for them.
    ///
    /// Subscriptions are shared per rule and reference-counted.
    ///
    /// The first subscriber creates the broadcast channel, with capacity `max_queued` or
    /// [`DEFAULT_MAX_QUEUED`]. For signal rules on a bus connection, it also registers the
    /// rule with the bus via `AddMatch`.
    ///
    /// Later subscribers get a new receiver on the existing channel. The channel's capacity
    /// grows if a later subscriber asks for more than it currently has. A rule without a
    /// message type is treated as a signal rule.
    ///
    /// # Errors
    ///
    /// Returns a `BrokenPipe` I/O error if `msg_senders` is empty, and propagates errors
    /// from the `AddMatch` call.
    pub(crate) async fn add_match(
        &self,
        rule: OwnedMatchRule,
        max_queued: Option<usize>,
    ) -> Result<Receiver<Result<Message>>> {
        use std::collections::hash_map::Entry;

        // NOTE(review): an empty `msg_senders` is reported as the socket reader having
        // errored out; presumably the reader clears the map when it fails — confirm in
        // `socket_reader`.
        if self.inner.msg_senders.lock().await.is_empty() {
            return Err(Error::InputOutput(Arc::new(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "Socket reader task has errored out",
            ))));
        }

        // Held for the rest of the function, including across the `AddMatch` call.
        let mut subscriptions = self.inner.subscriptions.lock().await;
        let msg_type = rule.msg_type().unwrap_or(Type::Signal);
        match subscriptions.entry(rule.clone()) {
            Entry::Vacant(e) => {
                let max_queued = max_queued.unwrap_or(DEFAULT_MAX_QUEUED);
                let (sender, mut receiver) = broadcast(max_queued);
                receiver.set_await_active(false);
                // `AddMatch` is only issued for signal rules on bus connections.
                if self.is_bus() && msg_type == Type::Signal {
                    self.call_method(
                        Some("org.freedesktop.DBus"),
                        "/org/freedesktop/DBus",
                        Some("org.freedesktop.DBus"),
                        "AddMatch",
                        &e.key(),
                    )
                    .await?;
                }
                // The stored inactive receiver is what later subscribers activate a clone of
                // (see the `Occupied` arm).
                e.insert((1, receiver.clone().deactivate()));
                self.inner
                    .msg_senders
                    .lock()
                    .await
                    .insert(Some(rule), sender);

                Ok(receiver)
            }
            Entry::Occupied(mut e) => {
                let (num_subscriptions, receiver) = e.get_mut();
                *num_subscriptions += 1;
                // Only ever grow the shared channel's capacity.
                if let Some(max_queued) = max_queued {
                    if max_queued > receiver.capacity() {
                        receiver.set_capacity(max_queued);
                    }
                }

                Ok(receiver.activate_cloned())
            }
        }
    }
1100
    /// Drops one subscription to `rule`.
    ///
    /// Returns `Ok(false)` if there is no subscription for `rule`, and `Ok(true)` otherwise.
    ///
    /// When the last subscriber goes away:
    /// - the rule is unregistered from the bus via `RemoveMatch` (signal rules on a bus
    ///   connection only);
    /// - the subscription entry and its broadcast sender are removed.
    pub(crate) async fn remove_match(&self, rule: OwnedMatchRule) -> Result<bool> {
        use std::collections::hash_map::Entry;
        let mut subscriptions = self.inner.subscriptions.lock().await;
        // As in `add_match`, a rule without a message type is treated as a signal rule.
        let msg_type = rule.msg_type().unwrap_or(Type::Signal);
        match subscriptions.entry(rule) {
            Entry::Vacant(_) => Ok(false),
            Entry::Occupied(mut e) => {
                let rule = e.key().inner().clone();
                e.get_mut().0 -= 1;
                if e.get().0 == 0 {
                    // NOTE(review): if `RemoveMatch` fails, `?` returns after the count was
                    // already decremented. That leaves a zero-count entry (and its sender in
                    // `msg_senders`) behind. Consider cleaning up local state regardless.
                    if self.is_bus() && msg_type == Type::Signal {
                        self.call_method(
                            Some("org.freedesktop.DBus"),
                            "/org/freedesktop/DBus",
                            Some("org.freedesktop.DBus"),
                            "RemoveMatch",
                            &rule,
                        )
                        .await?;
                    }
                    e.remove();
                    self.inner
                        .msg_senders
                        .lock()
                        .await
                        .remove(&Some(rule.into()));
                }
                Ok(true)
            }
        }
    }
1134
1135 pub(crate) fn queue_remove_match(&self, rule: OwnedMatchRule) {
1136 let conn = self.clone();
1137 let task_name = format!("Remove match `{}`", *rule);
1138 let remove_match =
1139 async move { conn.remove_match(rule).await }.instrument(trace_span!("{}", task_name));
1140 self.inner.executor.spawn(remove_match, &task_name).detach()
1141 }
1142
1143 pub fn method_timeout(&self) -> Option<Duration> {
1145 self.inner.method_timeout
1146 }
1147
1148 pub(crate) async fn new(
1149 auth: Authenticated,
1150 #[allow(unused)] bus_connection: bool,
1151 executor: Executor<'static>,
1152 method_timeout: Option<Duration>,
1153 ) -> Result<Self> {
1154 #[cfg(unix)]
1155 let cap_unix_fd = auth.cap_unix_fd;
1156
1157 macro_rules! create_msg_broadcast_channel {
1158 ($size:expr) => {{
1159 let (msg_sender, msg_receiver) = broadcast($size);
1160 let mut msg_receiver = msg_receiver.deactivate();
1161 msg_receiver.set_await_active(false);
1162
1163 (msg_sender, msg_receiver)
1164 }};
1165 }
1166 let (msg_sender, msg_receiver) = create_msg_broadcast_channel!(DEFAULT_MAX_QUEUED);
1168 let mut msg_senders = HashMap::new();
1169 msg_senders.insert(None, msg_sender);
1170
1171 let (method_return_sender, method_return_receiver) =
1173 create_msg_broadcast_channel!(DEFAULT_MAX_METHOD_RETURN_QUEUED);
1174 let rule = MatchRule::builder()
1175 .msg_type(Type::MethodReturn)
1176 .build()
1177 .into();
1178 msg_senders.insert(Some(rule), method_return_sender.clone());
1179 let rule = MatchRule::builder().msg_type(Type::Error).build().into();
1180 msg_senders.insert(Some(rule), method_return_sender);
1181 let msg_senders = Arc::new(Mutex::new(msg_senders));
1182 let subscriptions = Mutex::new(HashMap::new());
1183
1184 let connection = Self {
1185 inner: Arc::new(ConnectionInner {
1186 activity_event: Arc::new(Event::new()),
1187 socket_write: Mutex::new(auth.socket_write),
1188 server_guid: auth.server_guid,
1189 #[cfg(unix)]
1190 cap_unix_fd,
1191 #[cfg(feature = "p2p")]
1192 bus_conn: bus_connection,
1193 unique_name: OnceLock::new(),
1194 subscriptions,
1195 object_server: OnceLock::new(),
1196 object_server_dispatch_task: OnceLock::new(),
1197 executor,
1198 socket_reader_task: OnceLock::new(),
1199 msg_senders,
1200 msg_receiver,
1201 method_return_receiver,
1202 registered_names: Mutex::new(HashMap::new()),
1203 drop_event: Event::new(),
1204 method_timeout,
1205 credentials: OnceLock::new(),
1206 }),
1207 };
1208
1209 if let Some(unique_name) = auth.unique_name {
1210 connection.set_unique_name_(unique_name);
1211 }
1212
1213 Ok(connection)
1214 }
1215
1216 pub async fn session() -> Result<Self> {
1218 Builder::session()?.build().await
1219 }
1220
1221 pub async fn system() -> Result<Self> {
1223 Builder::system()?.build().await
1224 }
1225
1226 pub fn monitor_activity(&self) -> EventListener {
1230 self.inner.activity_event.listen()
1231 }
1232
1233 pub async fn peer_creds(&self) -> io::Result<&Arc<ConnectionCredentials>> {
1244 let mut socket_write = self.inner.socket_write.lock().await;
1245
1246 if let Some(creds) = self.inner.credentials.get() {
1248 return Ok(creds);
1249 }
1250
1251 self.inner
1252 .credentials
1253 .set(socket_write.peer_credentials().await.map(Arc::new)?)
1254 .expect("credentials cache set more than once");
1255
1256 Ok(self
1257 .inner
1258 .credentials
1259 .get()
1260 .expect("credentials should have been set"))
1261 }
1262
1263 #[deprecated(since = "5.13.0", note = "Use `peer_creds` instead")]
1272 pub async fn peer_credentials(&self) -> io::Result<ConnectionCredentials> {
1273 self.inner
1274 .socket_write
1275 .lock()
1276 .await
1277 .peer_credentials()
1278 .await
1279 }
1280
1281 pub async fn close(self) -> Result<()> {
1285 self.inner.activity_event.notify(usize::MAX);
1286 self.inner
1287 .socket_write
1288 .lock()
1289 .await
1290 .close()
1291 .await
1292 .map_err(Into::into)
1293 }
1294
1295 pub async fn graceful_shutdown(self) {
1339 let listener = self.inner.drop_event.listen();
1340 drop(self);
1341 listener.await;
1342 }
1343
1344 pub(crate) fn init_socket_reader(
1345 &self,
1346 socket_read: Box<dyn socket::ReadHalf>,
1347 already_read: Vec<u8>,
1348 #[cfg(unix)] already_received_fds: Vec<std::os::fd::OwnedFd>,
1349 ) {
1350 let inner = &self.inner;
1351 inner
1352 .socket_reader_task
1353 .set(
1354 SocketReader::new(
1355 socket_read,
1356 inner.msg_senders.clone(),
1357 already_read,
1358 #[cfg(unix)]
1359 already_received_fds,
1360 inner.activity_event.clone(),
1361 )
1362 .spawn(&inner.executor),
1363 )
1364 .expect("Attempted to set `socket_reader_task` twice");
1365 }
1366
1367 fn set_unique_name_(&self, name: OwnedUniqueName) {
1368 self.inner
1369 .unique_name
1370 .set(name)
1371 .expect("unique name already set");
1373 }
1374}
1375
#[cfg(feature = "blocking-api")]
impl From<crate::blocking::Connection> for Connection {
    /// Unwraps the async connection underlying a blocking one.
    fn from(conn: crate::blocking::Connection) -> Self {
        crate::blocking::Connection::into_inner(conn)
    }
}
1382
/// A non-owning handle to a [`Connection`]; it does not keep the connection alive.
///
/// Use `upgrade` to get a strong handle while the connection still exists.
#[derive(Debug, Clone)]
pub(crate) struct WeakConnection {
    inner: Weak<ConnectionInner>,
}
1388
1389impl WeakConnection {
1390 pub fn upgrade(&self) -> Option<Connection> {
1392 self.inner.upgrade().map(|inner| Connection { inner })
1393 }
1394}
1395
1396impl From<&Connection> for WeakConnection {
1397 fn from(conn: &Connection) -> Self {
1398 Self {
1399 inner: Arc::downgrade(&conn.inner),
1400 }
1401 }
1402}
1403
/// Local status of a well-known name requested through this connection.
#[derive(Debug)]
enum NameStatus {
    /// The name is owned. The task (if any) watches for `NameLost`; it is only present
    /// when the name was requested with `AllowReplacement` on a bus connection.
    // The tasks are only stored, never read, hence `allow(unused)`. Presumably they are
    // held so they live as long as the entry — TODO confirm against `Task`'s drop behavior.
    Owner(#[allow(unused)] Option<Task<()>>),
    /// The request is queued. The task waits for `NameAcquired`.
    Queued(#[allow(unused)] Task<()>),
}
1411
1412static SERIAL_NUM_SEMAPHORE: Semaphore = Semaphore::new(1);
1413
1414async fn acquire_serial_num_semaphore() -> Option<SemaphorePermit<'static>> {
1419 if is_flatpak() {
1420 Some(SERIAL_NUM_SEMAPHORE.acquire().await)
1421 } else {
1422 None
1423 }
1424}
1425
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fdo::DBusProxy;
    use ntest::timeout;
    use std::{pin::pin, time::Duration};
    // Shadows the built-in `#[test]` attribute for the tests below.
    use test_log::test;

    // Connects to the session bus through the Windows autolaunch address.
    #[cfg(windows)]
    #[test]
    fn connect_autolaunch_session_bus() {
        let addr =
            crate::win32::autolaunch_bus_address().expect("Unable to get session bus address");

        crate::block_on(async { addr.connect().await }).expect("Unable to connect to session bus");
    }

    // Connects to the session bus through the launchd transport on macOS.
    #[cfg(target_os = "macos")]
    #[test]
    fn connect_launchd_session_bus() {
        use crate::address::{Address, Transport, transport::Launchd};
        crate::block_on(async {
            let addr = Address::from(Transport::Launchd(Launchd::new(
                "DBUS_LAUNCHD_SESSION_BUS_SOCKET",
            )));
            addr.connect().await
        })
        .expect("Unable to connect to session bus");
    }

    #[test]
    #[timeout(15000)]
    fn disconnect_on_drop() {
        crate::utils::block_on(test_disconnect_on_drop());
    }

    // Dropping a connection that owns a well-known name must make the bus report that the
    // name no longer has an owner.
    async fn test_disconnect_on_drop() {
        #[derive(Default)]
        struct MyInterface {}

        #[crate::interface(name = "dev.peelz.FooBar.Baz")]
        impl MyInterface {
            fn do_thing(&self) {}
        }
        let name = "dev.peelz.foobar";
        let connection = Builder::session()
            .unwrap()
            .name(name)
            .unwrap()
            .serve_at("/dev/peelz/FooBar", MyInterface::default())
            .unwrap()
            .build()
            .await
            .unwrap();

        // Observe from a second connection: `NameOwnerChanged` filtered on arg0 == name
        // and arg2 (the new owner) == "".
        let connection2 = Connection::session().await.unwrap();
        let dbus = DBusProxy::new(&connection2).await.unwrap();
        let mut stream = dbus
            .receive_name_owner_changed_with_args(&[(0, name), (2, "")])
            .await
            .unwrap();

        drop(connection);

        stream.next().await.unwrap();

        let name_has_owner = dbus.name_has_owner(name.try_into().unwrap()).await.unwrap();
        assert!(!name_has_owner);
    }

    // `graceful_shutdown` must wait for every other handle, including one held by an
    // in-flight method call.
    #[tokio::test(start_paused = true)]
    #[timeout(15000)]
    async fn test_graceful_shutdown() {
        // Part 1: a live clone keeps the shutdown pending.
        let connection = Connection::session().await.unwrap();
        let clone = connection.clone();
        let mut shutdown = pin!(connection.graceful_shutdown());
        // With paused time, this sleep only finishes once the runtime has nothing else to
        // do, so the shutdown had every chance to complete (it must not).
        tokio::select! {
            _ = tokio::time::sleep(Duration::from_secs(u64::MAX)) => {},
            _ = &mut shutdown => {
                panic!("Graceful shutdown unexpectedly completed");
            }
        }

        drop(clone);
        shutdown.await;

        // Part 2: a method call that is still being handled keeps the shutdown pending.
        struct GracefulInterface {
            method_called: Event,
            wait_before_return: Option<EventListener>,
            announce_done: Event,
        }

        #[crate::interface(name = "dev.peelz.TestGracefulShutdown")]
        impl GracefulInterface {
            async fn do_thing(&mut self) {
                self.method_called.notify(1);
                if let Some(listener) = self.wait_before_return.take() {
                    listener.await;
                }
                self.announce_done.notify(1);
            }
        }

        let method_called = Event::new();
        let method_called_listener = method_called.listen();

        let trigger_return = Event::new();
        let wait_before_return = Some(trigger_return.listen());

        let announce_done = Event::new();
        let done_listener = announce_done.listen();

        let interface = GracefulInterface {
            method_called,
            wait_before_return,
            announce_done,
        };

        let name = "dev.peelz.TestGracefulShutdown";
        let obj = "/dev/peelz/TestGracefulShutdown";
        let connection = Builder::session()
            .unwrap()
            .name(name)
            .unwrap()
            .serve_at(obj, interface)
            .unwrap()
            .build()
            .await
            .unwrap();

        let client_conn = Connection::session().await.unwrap();
        tokio::spawn(async move {
            client_conn
                .call_method(Some(name), obj, Some(name), "DoThing", &())
                .await
                .unwrap();
        });

        // The method is now blocked waiting on `trigger_return`.
        method_called_listener.await;

        let mut shutdown = pin!(connection.graceful_shutdown());
        tokio::select! {
            _ = tokio::time::sleep(Duration::from_secs(u64::MAX)) => {},
            _ = &mut shutdown => {
                panic!("Graceful shutdown unexpectedly completed");
            }
        }

        // Let the method return; only then may the shutdown complete.
        trigger_return.notify(1);
        shutdown.await;

        done_listener.await;
    }
}
1594
// Peer-to-peer (no bus daemon) connection tests. Each `*_pipe` / `create_channel_pair`
// helper builds a connected `(server, client)` pair over one transport, and `test_p2p`
// exercises two such pairs at once.
#[cfg(feature = "p2p")]
#[cfg(test)]
mod p2p_tests {
    use event_listener::Event;
    use futures_util::TryStreamExt;
    use ntest::timeout;
    use test_log::test;
    use zvariant::{Endian, NATIVE_ENDIAN};

    use super::{Builder, Connection, socket};
    use crate::{Guid, Message, MessageStream, Result, conn::AuthMechanism};

    /// Runs a method call and a signal across two p2p connection pairs.
    ///
    /// A background task forwards every message received on `server1` out through
    /// `client2`, and every message received on `client2` out through `server1`. As a
    /// result, `client1` effectively talks to `server2` via the forwarder.
    ///
    /// The client sends a `Test` method call whose body is `64u64`, encoded in the
    /// non-native byte order. The server answers with an `ASignalForYou` signal and then a
    /// method return carrying `"yay"`. The client checks the order of the two messages,
    /// that the reply uses the same byte order as the call, and the reply body.
    async fn test_p2p(
        server1: Connection,
        client1: Connection,
        server2: Connection,
        client2: Connection,
    ) -> Result<()> {
        // Forward server1 -> client2.
        let forward1 = {
            let stream = MessageStream::from(server1.clone());
            let sink = client2.clone();

            stream.try_for_each(move |msg| {
                let sink = sink.clone();
                async move { sink.send(&msg).await }
            })
        };
        // Forward client2 -> server1.
        let forward2 = {
            let stream = MessageStream::from(client2.clone());
            let sink = server1.clone();

            stream.try_for_each(move |msg| {
                let sink = sink.clone();
                async move { sink.send(&msg).await }
            })
        };
        // Keep the task handle bound for the whole test rather than discarding it.
        let _forward_task = client1.executor().spawn(
            async move { futures_util::try_join!(forward1, forward2) },
            "forward_task",
        );

        let server_ready = Event::new();
        let server_ready_listener = server_ready.listen();
        let client_done = Event::new();
        let client_done_listener = client_done.listen();

        let server_future = async move {
            // The stream is created before readiness is signalled, so the client's call
            // cannot be sent before the server is listening for it.
            let mut stream = MessageStream::from(&server2);
            server_ready.notify(1);
            // Skip anything that is not the expected method call.
            let method = loop {
                let m = stream.try_next().await?.unwrap();
                if m.to_string() == "Method call Test" {
                    assert_eq!(m.body().deserialize::<u64>().unwrap(), 64);
                    break m;
                }
            };

            // The signal is emitted before the reply; the client asserts this order.
            server2
                .emit_signal(None::<()>, "/", "org.zbus.p2p", "ASignalForYou", &())
                .await?;
            server2.reply(&method.header(), &"yay").await?;
            // Don't finish (and drop server2) until the client has seen the signal.
            client_done_listener.await;

            Ok(())
        };

        let client_future = async move {
            let mut stream = MessageStream::from(&client1);
            server_ready_listener.await;
            // Deliberately use the opposite of the native byte order.
            let endian = match NATIVE_ENDIAN {
                Endian::Little => Endian::Big,
                Endian::Big => Endian::Little,
            };
            let method = Message::method_call("/", "Test")?
                .interface("org.zbus.p2p")?
                .endian(endian)
                .build(&64u64)?;
            client1.send(&method).await?;
            let m = stream.try_next().await?.unwrap();
            client_done.notify(1);
            assert_eq!(m.to_string(), "Signal ASignalForYou");
            let reply = stream.try_next().await?.unwrap();
            assert_eq!(reply.to_string(), "Method return");
            assert_eq!(Endian::from(reply.primary_header().endian_sig()), endian);

            reply.body().deserialize::<String>()
        };

        let (val, _) = futures_util::try_join!(client_future, server_future)?;
        assert_eq!(val, "yay");

        Ok(())
    }

    #[test]
    #[timeout(15000)]
    fn tcp_p2p() {
        crate::utils::block_on(test_tcp_p2p()).unwrap();
    }

    /// Runs [`test_p2p`] over two loopback TCP connection pairs.
    async fn test_tcp_p2p() -> Result<()> {
        let (server1, client1) = tcp_p2p_pipe().await?;
        let (server2, client2) = tcp_p2p_pipe().await?;

        test_p2p(server1, client1, server2, client2).await
    }

    /// Builds a connected `(server, client)` pair over a loopback TCP socket.
    ///
    /// The server side is configured with `AuthMechanism::Anonymous`. I/O and builder
    /// errors are propagated instead of panicking.
    async fn tcp_p2p_pipe() -> Result<(Connection, Connection)> {
        let guid = Guid::generate();

        #[cfg(not(feature = "tokio"))]
        let (server_conn_builder, client_conn_builder) = {
            // Port 0: let the OS choose a free port.
            let listener = std::net::TcpListener::bind("127.0.0.1:0")?;
            let addr = listener.local_addr()?;
            let p1 = std::net::TcpStream::connect(addr)?;
            let (p0, _) = listener.accept()?;

            (
                Builder::tcp_stream(p0)
                    .server(guid)?
                    .p2p()
                    .auth_mechanism(AuthMechanism::Anonymous),
                Builder::tcp_stream(p1).p2p(),
            )
        };

        #[cfg(feature = "tokio")]
        let (server_conn_builder, client_conn_builder) = {
            // Port 0: let the OS choose a free port.
            let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
            let addr = listener.local_addr()?;
            let p1 = tokio::net::TcpStream::connect(addr).await?;
            let (p0, _) = listener.accept().await?;

            (
                Builder::tcp_stream(p0)
                    .server(guid)?
                    .p2p()
                    .auth_mechanism(AuthMechanism::Anonymous),
                Builder::tcp_stream(p1).p2p(),
            )
        };

        // Both sides must be driven together so the handshake can complete.
        futures_util::try_join!(server_conn_builder.build(), client_conn_builder.build())
    }

    #[cfg(unix)]
    #[test]
    #[timeout(15000)]
    fn unix_p2p() {
        crate::utils::block_on(test_unix_p2p()).unwrap();
    }

    /// Runs [`test_p2p`] over two Unix socket-pair connection pairs.
    #[cfg(unix)]
    async fn test_unix_p2p() -> Result<()> {
        let (server1, client1) = unix_p2p_pipe().await?;
        let (server2, client2) = unix_p2p_pipe().await?;

        test_p2p(server1, client1, server2, client2).await
    }

    /// Builds a connected `(server, client)` pair over a Unix socket pair.
    ///
    /// The returned order matches the other pair helpers: the server-mode connection
    /// comes first. The server uses the builder's default auth mechanism.
    #[cfg(unix)]
    async fn unix_p2p_pipe() -> Result<(Connection, Connection)> {
        #[cfg(not(feature = "tokio"))]
        use std::os::unix::net::UnixStream;
        #[cfg(feature = "tokio")]
        use tokio::net::UnixStream;

        let guid = Guid::generate();

        let (p0, p1) = UnixStream::pair()?;
        let server_conn_builder = Builder::unix_stream(p0).server(guid)?.p2p();
        let client_conn_builder = Builder::unix_stream(p1).p2p();

        futures_util::try_join!(server_conn_builder.build(), client_conn_builder.build())
    }

    #[cfg(any(
        all(feature = "vsock", not(feature = "tokio")),
        feature = "tokio-vsock"
    ))]
    #[test]
    #[timeout(15000)]
    fn vsock_connect() {
        let _ = crate::utils::block_on(test_vsock_connect()).unwrap();
    }

    /// Connects a client to a vsock listener through a `vsock:` address string.
    ///
    /// The server side is built directly from the accepted stream with
    /// `AuthMechanism::Anonymous`.
    #[cfg(any(
        all(feature = "vsock", not(feature = "tokio")),
        feature = "tokio-vsock"
    ))]
    async fn test_vsock_connect() -> Result<(Connection, Connection)> {
        #[cfg(feature = "tokio-vsock")]
        use futures_util::StreamExt;

        let guid = Guid::generate();

        #[cfg(all(feature = "vsock", not(feature = "tokio")))]
        let listener = vsock::VsockListener::bind_with_cid_port(vsock::VMADDR_CID_LOCAL, u32::MAX)?;
        // NOTE(review): CID 1 presumably mirrors VMADDR_CID_LOCAL used above; TODO confirm.
        #[cfg(feature = "tokio-vsock")]
        let listener = tokio_vsock::VsockListener::bind(tokio_vsock::VsockAddr::new(1, u32::MAX))?;

        let addr = listener.local_addr()?;
        let addr = format!("vsock:cid={},port={},guid={guid}", addr.cid(), addr.port());

        let server = async {
            #[cfg(all(feature = "vsock", not(feature = "tokio")))]
            let server =
                crate::Task::spawn_blocking(move || listener.incoming().next(), "").await?;
            #[cfg(feature = "tokio-vsock")]
            let server = listener.incoming().next().await;
            Builder::vsock_stream(server.unwrap()?)
                .server(guid)?
                .p2p()
                .auth_mechanism(AuthMechanism::Anonymous)
                .build()
                .await
        };

        let client = crate::connection::Builder::address(addr.as_str())?
            .p2p()
            .build();

        futures_util::try_join!(server, client)
    }

    #[cfg(any(
        all(feature = "vsock", not(feature = "tokio")),
        feature = "tokio-vsock"
    ))]
    #[test]
    #[timeout(15000)]
    fn vsock_p2p() {
        crate::utils::block_on(test_vsock_p2p()).unwrap();
    }

    /// Runs [`test_p2p`] over two vsock connection pairs.
    #[cfg(any(
        all(feature = "vsock", not(feature = "tokio")),
        feature = "tokio-vsock"
    ))]
    async fn test_vsock_p2p() -> Result<()> {
        let (server1, client1) = vsock_p2p_pipe().await?;
        let (server2, client2) = vsock_p2p_pipe().await?;

        test_p2p(server1, client1, server2, client2).await
    }

    /// Builds a connected `(server, client)` pair over a local vsock socket (blocking
    /// `vsock` crate). The server side uses `AuthMechanism::Anonymous`.
    #[cfg(all(feature = "vsock", not(feature = "tokio")))]
    async fn vsock_p2p_pipe() -> Result<(Connection, Connection)> {
        let guid = Guid::generate();

        let listener = vsock::VsockListener::bind_with_cid_port(vsock::VMADDR_CID_LOCAL, u32::MAX)?;
        let addr = listener.local_addr()?;
        let client = vsock::VsockStream::connect(&addr)?;
        let server = listener
            .incoming()
            .next()
            .expect("listener incoming iterator ended unexpectedly")?;

        futures_util::try_join!(
            Builder::vsock_stream(server)
                .server(guid)?
                .p2p()
                .auth_mechanism(AuthMechanism::Anonymous)
                .build(),
            Builder::vsock_stream(client).p2p().build(),
        )
    }

    /// Builds a connected `(server, client)` pair over a local vsock socket
    /// (`tokio-vsock`). The server side uses `AuthMechanism::Anonymous`.
    #[cfg(feature = "tokio-vsock")]
    async fn vsock_p2p_pipe() -> Result<(Connection, Connection)> {
        use futures_util::StreamExt;
        use tokio_vsock::VsockAddr;

        let guid = Guid::generate();

        // NOTE(review): CID 1 presumably mirrors VMADDR_CID_LOCAL used by the non-tokio
        // helper; TODO confirm.
        let listener = tokio_vsock::VsockListener::bind(VsockAddr::new(1, u32::MAX))?;
        let addr = listener.local_addr()?;
        let client = tokio_vsock::VsockStream::connect(addr).await?;
        let server = listener
            .incoming()
            .next()
            .await
            .expect("listener incoming stream ended unexpectedly")?;

        futures_util::try_join!(
            Builder::vsock_stream(server)
                .server(guid)?
                .p2p()
                .auth_mechanism(AuthMechanism::Anonymous)
                .build(),
            Builder::vsock_stream(client).p2p().build(),
        )
    }

    #[test]
    #[timeout(15000)]
    fn channel_pair() {
        crate::utils::block_on(test_channel_pair()).unwrap();
    }

    /// Runs [`test_p2p`] over two in-process channel connection pairs.
    async fn test_channel_pair() -> Result<()> {
        let (server1, client1) = create_channel_pair().await?;
        let (server2, client2) = create_channel_pair().await?;

        test_p2p(server1, client1, server2, client2).await
    }

    /// Builds a pair of p2p connections over an in-process `socket::Channel` pair.
    ///
    /// Both ends are built with `Builder::authenticated_socket` and the same GUID.
    /// Builder errors are propagated instead of panicking.
    async fn create_channel_pair() -> Result<(Connection, Connection)> {
        let (a, b) = socket::Channel::pair();

        let guid = Guid::generate();
        let conn1 = Builder::authenticated_socket(a, guid.clone())?
            .p2p()
            .build()
            .await?;
        let conn2 = Builder::authenticated_socket(b, guid)?
            .p2p()
            .build()
            .await?;

        Ok((conn1, conn2))
    }
}