zbus/connection/socket_reader.rs

1use std::{collections::HashMap, sync::Arc};
2
3use event_listener::Event;
4use tracing::{debug, instrument, trace};
5
6use crate::{
7    Executor, Message, OwnedMatchRule, Task, async_lock::Mutex, connection::MsgBroadcaster,
8};
9
10use super::socket::ReadHalf;
11
12#[derive(Debug)]
13pub(crate) struct SocketReader {
14    socket: Box<dyn ReadHalf>,
15    senders: Arc<Mutex<HashMap<Option<OwnedMatchRule>, MsgBroadcaster>>>,
16    already_received_bytes: Vec<u8>,
17    #[cfg(unix)]
18    already_received_fds: Vec<std::os::fd::OwnedFd>,
19    prev_seq: u64,
20    activity_event: Arc<Event>,
21}
22
23impl SocketReader {
24    pub fn new(
25        socket: Box<dyn ReadHalf>,
26        senders: Arc<Mutex<HashMap<Option<OwnedMatchRule>, MsgBroadcaster>>>,
27        already_received_bytes: Vec<u8>,
28        #[cfg(unix)] already_received_fds: Vec<std::os::fd::OwnedFd>,
29        activity_event: Arc<Event>,
30    ) -> Self {
31        Self {
32            socket,
33            senders,
34            already_received_bytes,
35            #[cfg(unix)]
36            already_received_fds,
37            prev_seq: 0,
38            activity_event,
39        }
40    }
41
42    pub fn spawn(self, executor: &Executor<'_>) -> Task<()> {
43        executor.spawn(self.receive_msg(), "socket reader")
44    }
45
46    // Keep receiving messages and put them on the queue.
47    #[instrument(name = "socket reader", skip(self), level = "trace")]
48    async fn receive_msg(mut self) {
49        loop {
50            trace!("Waiting for message on the socket..");
51            let msg = self.read_socket().await;
52            match &msg {
53                Ok(msg) => trace!("Message received on the socket: {:?}", msg),
54                Err(e) => trace!("Error reading from the socket: {:?}", e),
55            };
56
57            let mut senders = self.senders.lock().await;
58            for (rule, sender) in &*senders {
59                if let Ok(msg) = &msg {
60                    if let Some(rule) = rule.as_ref() {
61                        match rule.matches(msg) {
62                            Ok(true) => (),
63                            Ok(false) => continue,
64                            Err(e) => {
65                                debug!("Error matching message against rule: {:?}", e);
66
67                                continue;
68                            }
69                        }
70                    }
71                }
72
73                if let Err(e) = sender.broadcast_direct(msg.clone()).await {
74                    // An error would be due to either of these:
75                    //
76                    // 1. the channel is closed.
77                    // 2. No active receivers.
78                    //
79                    // In either case, just log it unless this is the channel for the generic
80                    // unfiltered stream, where the channel is not created on-demand.
81                    if rule.is_some() {
82                        trace!(
83                            "Error broadcasting message to stream for `{:?}`: {:?}",
84                            rule, e
85                        );
86                    }
87                }
88            }
89            trace!("Broadcasted to all streams: {:?}", msg);
90
91            if msg.is_err() {
92                senders.clear();
93                trace!("Socket reading task stopped");
94
95                return;
96            }
97        }
98    }
99
100    #[instrument(skip(self), level = "trace")]
101    async fn read_socket(&mut self) -> crate::Result<Message> {
102        self.activity_event.notify(usize::MAX);
103        let seq = self.prev_seq + 1;
104        let msg = self
105            .socket
106            .receive_message(
107                seq,
108                &mut self.already_received_bytes,
109                #[cfg(unix)]
110                &mut self.already_received_fds,
111            )
112            .await?;
113        self.prev_seq = seq;
114
115        Ok(msg)
116    }
117}