zbus/connection/socket/unix.rs

1#[cfg(not(feature = "tokio"))]
2use async_io::Async;
3#[cfg(target_os = "linux")]
4use std::os::unix::io::FromRawFd;
5#[cfg(unix)]
6use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, RawFd};
7#[cfg(all(unix, not(feature = "tokio")))]
8use std::os::unix::net::UnixStream;
9#[cfg(not(feature = "tokio"))]
10use std::sync::Arc;
11#[cfg(unix)]
12use std::{
13    future::poll_fn,
14    io::{IoSlice, IoSliceMut},
15    os::fd::OwnedFd,
16    task::Poll,
17};
18#[cfg(all(windows, not(feature = "tokio")))]
19use uds_windows::UnixStream;
20
21#[cfg(unix)]
22use rustix::net::{
23    RecvAncillaryBuffer, RecvAncillaryMessage, RecvFlags, SendAncillaryBuffer,
24    SendAncillaryMessage, SendFlags, recvmsg, sendmsg,
25};
26
27#[cfg(unix)]
28use crate::utils::FDS_MAX;
29
30#[cfg(all(unix, not(feature = "tokio")))]
31#[async_trait::async_trait]
32impl super::ReadHalf for Arc<Async<UnixStream>> {
33    async fn recvmsg(&mut self, buf: &mut [u8]) -> super::RecvmsgResult {
34        poll_fn(|cx| {
35            let (len, fds) = loop {
36                match fd_recvmsg(self.as_fd(), buf) {
37                    Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
38                    Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
39                        match self.poll_readable(cx) {
40                            Poll::Pending => return Poll::Pending,
41                            Poll::Ready(res) => res?,
42                        }
43                    }
44                    v => break v?,
45                }
46            };
47            Poll::Ready(Ok((len, fds)))
48        })
49        .await
50    }
51
52    /// Supports passing file descriptors.
53    fn can_pass_unix_fd(&self) -> bool {
54        true
55    }
56
57    async fn peer_credentials(&mut self) -> std::io::Result<crate::fdo::ConnectionCredentials> {
58        get_unix_peer_creds(self).await
59    }
60}
61
62#[cfg(all(unix, not(feature = "tokio")))]
63#[async_trait::async_trait]
64impl super::WriteHalf for Arc<Async<UnixStream>> {
65    async fn sendmsg(
66        &mut self,
67        buffer: &[u8],
68        #[cfg(unix)] fds: &[BorrowedFd<'_>],
69    ) -> std::io::Result<usize> {
70        poll_fn(|cx| {
71            loop {
72                match fd_sendmsg(
73                    self.as_fd(),
74                    buffer,
75                    #[cfg(unix)]
76                    fds,
77                ) {
78                    Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
79                    Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
80                        match self.poll_writable(cx) {
81                            Poll::Pending => return Poll::Pending,
82                            Poll::Ready(res) => res?,
83                        }
84                    }
85                    v => return Poll::Ready(v),
86                }
87            }
88        })
89        .await
90    }
91
92    async fn close(&mut self) -> std::io::Result<()> {
93        let stream = self.clone();
94        crate::Task::spawn_blocking(
95            move || stream.get_ref().shutdown(std::net::Shutdown::Both),
96            "close socket",
97        )
98        .await?
99    }
100
101    #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
102    async fn send_zero_byte(&mut self) -> std::io::Result<Option<usize>> {
103        send_zero_byte(self).await.map(Some)
104    }
105
106    /// Supports passing file descriptors.
107    fn can_pass_unix_fd(&self) -> bool {
108        true
109    }
110
111    async fn peer_credentials(&mut self) -> std::io::Result<crate::fdo::ConnectionCredentials> {
112        super::ReadHalf::peer_credentials(self).await
113    }
114}
115
#[cfg(all(unix, feature = "tokio"))]
impl super::Socket for tokio::net::UnixStream {
    type ReadHalf = tokio::net::unix::OwnedReadHalf;
    type WriteHalf = tokio::net::unix::OwnedWriteHalf;

    /// Splits the stream into independently owned read and write halves.
    fn split(self) -> super::Split<Self::ReadHalf, Self::WriteHalf> {
        let halves = self.into_split();
        super::Split {
            read: halves.0,
            write: halves.1,
        }
    }
}
127
#[cfg(all(unix, feature = "tokio"))]
#[async_trait::async_trait]
impl super::ReadHalf for tokio::net::unix::OwnedReadHalf {
    /// Receives bytes, plus any file descriptors sent alongside them, from the socket.
    ///
    /// Interrupted attempts are retried. When tokio reports no data yet, the task waits for read
    /// readiness and tries again.
    async fn recvmsg(&mut self, buf: &mut [u8]) -> super::RecvmsgResult {
        let inner = self.as_ref();
        poll_fn(|cx| loop {
            // A custom read routine is required because file descriptors may arrive together
            // with the payload.
            let attempt = inner.try_io(tokio::io::Interest::READABLE, || {
                fd_recvmsg(inner.as_fd(), buf)
            });
            match attempt {
                Err(err) if err.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {
                    // Returns `Poll::Pending` early until the socket is readable.
                    std::task::ready!(inner.poll_read_ready(cx))?;
                }
                done => return Poll::Ready(done),
            }
        })
        .await
    }

    /// Unix domain sockets can carry file descriptors, so this always answers `true`.
    fn can_pass_unix_fd(&self) -> bool {
        true
    }

    /// Looks up the credentials of the process on the other end of the underlying stream.
    async fn peer_credentials(&mut self) -> std::io::Result<crate::fdo::ConnectionCredentials> {
        let stream: &tokio::net::UnixStream = self.as_ref();
        get_unix_peer_creds(stream).await
    }
}
163
#[cfg(all(unix, feature = "tokio"))]
#[async_trait::async_trait]
impl super::WriteHalf for tokio::net::unix::OwnedWriteHalf {
    /// Writes `buffer` to the socket, attaching `fds` as ancillary data.
    ///
    /// Interrupted attempts are retried. While tokio reports the socket as not writable the task
    /// waits for write readiness. Resolves to the number of bytes written.
    async fn sendmsg(
        &mut self,
        buffer: &[u8],
        #[cfg(unix)] fds: &[BorrowedFd<'_>],
    ) -> std::io::Result<usize> {
        let inner = self.as_ref();
        poll_fn(|cx| loop {
            let attempt = inner.try_io(tokio::io::Interest::WRITABLE, || {
                fd_sendmsg(
                    inner.as_fd(),
                    buffer,
                    #[cfg(unix)]
                    fds,
                )
            });
            match attempt {
                Err(err) if err.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {
                    // Returns `Poll::Pending` early until the socket is writable.
                    std::task::ready!(inner.poll_write_ready(cx))?;
                }
                done => return Poll::Ready(done),
            }
        })
        .await
    }

    /// Shuts this half down through tokio's `AsyncWriteExt::shutdown`.
    async fn close(&mut self) -> std::io::Result<()> {
        tokio::io::AsyncWriteExt::shutdown(self).await
    }

    /// Delegates to the module-level `send_zero_byte` helper and reports its byte count as `Some`.
    #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
    async fn send_zero_byte(&mut self) -> std::io::Result<Option<usize>> {
        let sent = send_zero_byte(self.as_ref()).await?;
        Ok(Some(sent))
    }

    /// Unix domain sockets can carry file descriptors, so this always answers `true`.
    fn can_pass_unix_fd(&self) -> bool {
        true
    }

    /// Looks up the credentials of the process on the other end of the underlying stream.
    async fn peer_credentials(&mut self) -> std::io::Result<crate::fdo::ConnectionCredentials> {
        let stream: &tokio::net::UnixStream = self.as_ref();
        get_unix_peer_creds(stream).await
    }
}
215
#[cfg(all(windows, not(feature = "tokio")))]
#[async_trait::async_trait]
impl super::ReadHalf for Arc<Async<UnixStream>> {
    /// Reads bytes from the socket.
    ///
    /// This impl only exists on Windows, where no file descriptors accompany the data. The
    /// result is therefore just the number of bytes read.
    async fn recvmsg(&mut self, buf: &mut [u8]) -> super::RecvmsgResult {
        // The previous `#[cfg(unix)]` branch could never be compiled into this windows-only
        // impl, so the read result is returned directly.
        futures_lite::AsyncReadExt::read(&mut self.as_ref(), buf).await
    }

    /// Returns the peer's process ID and Windows SID, looked up on a blocking task.
    ///
    /// A reported PID of 0 makes the token lookup pass `None` to `ProcessToken::open` instead
    /// of a PID.
    async fn peer_credentials(&mut self) -> std::io::Result<crate::fdo::ConnectionCredentials> {
        let stream = self.clone();
        crate::Task::spawn_blocking(
            move || {
                use crate::win32::{ProcessToken, unix_stream_get_peer_pid};

                let pid = unix_stream_get_peer_pid(stream.get_ref())? as _;
                let sid = ProcessToken::open(if pid != 0 { Some(pid as _) } else { None })
                    .and_then(|process_token| process_token.sid())?;
                Ok(crate::fdo::ConnectionCredentials::default()
                    .set_process_id(pid)
                    .set_windows_sid(sid))
            },
            "peer credentials",
        )
        .await?
    }
}
250
#[cfg(all(windows, not(feature = "tokio")))]
#[async_trait::async_trait]
impl super::WriteHalf for Arc<Async<UnixStream>> {
    /// Writes `buf` to the socket; Windows has no descriptor passing, so this is a plain write.
    async fn sendmsg(
        &mut self,
        buf: &[u8],
        #[cfg(unix)] _fds: &[BorrowedFd<'_>],
    ) -> std::io::Result<usize> {
        let mut writer = self.as_ref();
        futures_lite::AsyncWriteExt::write(&mut writer, buf).await
    }

    /// Shuts down both directions of the socket, running the call on a blocking task.
    async fn close(&mut self) -> std::io::Result<()> {
        let stream = Arc::clone(self);
        let shutdown = move || stream.get_ref().shutdown(std::net::Shutdown::Both);
        crate::Task::spawn_blocking(shutdown, "close socket").await?
    }

    /// Reuses the read half's implementation; both halves share the same socket.
    async fn peer_credentials(&mut self) -> std::io::Result<crate::fdo::ConnectionCredentials> {
        <Self as super::ReadHalf>::peer_credentials(self).await
    }
}
275
276#[cfg(unix)]
277fn fd_recvmsg(fd: BorrowedFd<'_>, buffer: &mut [u8]) -> std::io::Result<(usize, Vec<OwnedFd>)> {
278    use std::mem::MaybeUninit;
279
280    let mut iov = [IoSliceMut::new(buffer)];
281    let mut cmsg_buffer = [MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(FDS_MAX))];
282    let mut ancillary = RecvAncillaryBuffer::new(&mut cmsg_buffer);
283
284    let msg = recvmsg(fd, &mut iov, &mut ancillary, RecvFlags::empty())?;
285    if msg.bytes == 0 {
286        return Err(std::io::Error::new(
287            std::io::ErrorKind::BrokenPipe,
288            "failed to read from socket",
289        ));
290    }
291    let mut fds = vec![];
292    for msg in ancillary.drain() {
293        match msg {
294            RecvAncillaryMessage::ScmRights(iter) => {
295                fds.extend(iter);
296            }
297            #[cfg(any(target_os = "linux", target_os = "android"))]
298            RecvAncillaryMessage::ScmCredentials(_) => {
299                // On Linux, credentials might be received. This shouldn't normally happen
300                // in our use case since we don't request them, but ignore if present.
301                continue;
302            }
303            _ => {
304                return Err(std::io::Error::new(
305                    std::io::ErrorKind::InvalidData,
306                    "unexpected CMSG kind",
307                ));
308            }
309        }
310    }
311    Ok((msg.bytes, fds))
312}
313
314#[cfg(unix)]
315fn fd_sendmsg(fd: BorrowedFd<'_>, buffer: &[u8], fds: &[BorrowedFd<'_>]) -> std::io::Result<usize> {
316    use std::mem::MaybeUninit;
317
318    let iov = [IoSlice::new(buffer)];
319    let mut cmsg_buffer = [MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(FDS_MAX))];
320    let mut ancillary = SendAncillaryBuffer::new(&mut cmsg_buffer);
321
322    if !fds.is_empty() && !ancillary.push(SendAncillaryMessage::ScmRights(fds)) {
323        return Err(std::io::Error::new(
324            std::io::ErrorKind::InvalidInput,
325            "too many file descriptors",
326        ));
327    }
328
329    #[cfg(not(any(target_os = "macos", target_os = "redox")))]
330    let flags = SendFlags::NOSIGNAL;
331    #[cfg(any(target_os = "macos", target_os = "redox"))]
332    let flags = SendFlags::empty();
333
334    let sent = sendmsg(fd, &iov, &mut ancillary, flags)?;
335    if sent == 0 {
336        // can it really happen?
337        return Err(std::io::Error::new(
338            std::io::ErrorKind::WriteZero,
339            "failed to write to buffer",
340        ));
341    }
342
343    Ok(sent)
344}
345
346#[cfg(unix)]
347async fn get_unix_peer_creds(fd: &impl AsFd) -> std::io::Result<crate::fdo::ConnectionCredentials> {
348    let fd = fd.as_fd().as_raw_fd();
349    // FIXME: Is it likely enough for sending of 1 byte to block, to justify a task (possibly
350    // launching a thread in turn)?
351    crate::Task::spawn_blocking(move || get_unix_peer_creds_blocking(fd), "peer credentials")
352        .await?
353}
354
355#[cfg(unix)]
356fn get_unix_peer_creds_blocking(fd: RawFd) -> std::io::Result<crate::fdo::ConnectionCredentials> {
357    // TODO: get this BorrowedFd directly from get_unix_peer_creds(), but this requires a
358    // 'static lifetime due to the Task.
359    let fd = unsafe { BorrowedFd::borrow_raw(fd) };
360    let mut creds = crate::fdo::ConnectionCredentials::default();
361
362    #[cfg(any(target_os = "android", target_os = "linux"))]
363    {
364        use rustix::net::sockopt::socket_peercred;
365        use tracing::debug;
366
367        let ucred = socket_peercred(fd)?;
368        let uid = ucred.uid.as_raw();
369        let gid = ucred.gid.as_raw();
370        let pid = ucred.pid.as_raw_nonzero().get() as u32;
371
372        creds = creds.set_unix_user_id(uid).set_process_id(pid);
373
374        // The dbus spec requires groups to be either absent or complete (primary +
375        // secondary groups).
376
377        // FIXME: rustix does not and [will not] provide `getpwuid_r` and `getgrouplist` so we're
378        // left with no choice but to use libc directly. We could consider using `sysinfo` crate
379        // though.
380        //
381        // [will not]: https://docs.rs/rustix/latest/rustix/not_implemented/higher_level/index.html
382        let mut passwd: libc::passwd = unsafe { std::mem::zeroed() };
383        let mut buf = vec![0u8; 16384];
384        let mut result: *mut libc::passwd = std::ptr::null_mut();
385
386        unsafe {
387            libc::getpwuid_r(
388                uid,
389                &mut passwd,
390                buf.as_mut_ptr() as *mut libc::c_char,
391                buf.len(),
392                &mut result,
393            );
394        }
395
396        if !result.is_null() {
397            let username = unsafe { std::ffi::CStr::from_ptr((*result).pw_name) };
398
399            // Get supplementary groups.
400            let mut ngroups = 64i32;
401            let mut groups = vec![0u32; ngroups as usize];
402
403            unsafe {
404                if libc::getgrouplist(
405                    username.as_ptr(),
406                    gid,
407                    groups.as_mut_ptr() as *mut libc::gid_t,
408                    &mut ngroups,
409                ) >= 0
410                {
411                    groups.truncate(ngroups as usize);
412                    // The spec also requires the groups to be numerically sorted.
413                    groups.sort();
414                    for group in groups {
415                        creds = creds.add_unix_group_id(group);
416                    }
417                } else {
418                    debug!("Group lookup failed for user {:?}", username);
419                }
420            }
421        }
422
423        #[cfg(target_os = "linux")]
424        {
425            // FIXME: Replace with rustix API when it provides SO_PEERPIDFD sockopt:
426            // https://github.com/bytecodealliance/rustix/pull/1474
427            use libc::{c_int, socklen_t};
428            use std::mem::{MaybeUninit, size_of};
429
430            let mut pidfd = MaybeUninit::<c_int>::zeroed();
431            let mut len = size_of::<c_int>() as socklen_t;
432
433            let ret = unsafe {
434                libc::getsockopt(
435                    fd.as_raw_fd(),
436                    libc::SOL_SOCKET,
437                    libc::SO_PEERPIDFD,
438                    pidfd.as_mut_ptr().cast(),
439                    &mut len,
440                )
441            };
442
443            if ret == 0 {
444                let pidfd = unsafe { pidfd.assume_init() };
445                creds = creds
446                    .set_process_fd(unsafe { std::os::fd::OwnedFd::from_raw_fd(pidfd).into() });
447            } else if ret < 0 {
448                let err = std::io::Error::last_os_error();
449                // ENOPROTOOPT means the kernel doesn't support this feature.
450                if err.raw_os_error() != Some(libc::ENOPROTOOPT) {
451                    return Err(err);
452                }
453            }
454        }
455    }
456
457    #[cfg(any(
458        target_os = "macos",
459        target_os = "ios",
460        target_os = "freebsd",
461        target_os = "dragonfly",
462        target_os = "openbsd",
463        target_os = "netbsd"
464    ))]
465    {
466        // FIXME: Replace with rustix API when it provides the require API:
467        // https://github.com/bytecodealliance/rustix/issues/1533
468        let mut uid: libc::uid_t = 0;
469        let mut gid: libc::gid_t = 0;
470
471        let ret = unsafe { libc::getpeereid(fd.as_raw_fd(), &mut uid, &mut gid) };
472        if ret != 0 {
473            return Err(std::io::Error::last_os_error());
474        }
475
476        creds = creds.set_unix_user_id(uid);
477
478        // FIXME: Handle pid fetching too
479    }
480
481    Ok(creds)
482}
483
/// Sends the single zero byte that carries an `SCM_CREDS` control message. The work runs on a
/// blocking task; see `send_zero_byte_blocking`.
///
/// Returns the number of bytes sent.
#[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
async fn send_zero_byte(fd: &impl AsFd) -> std::io::Result<usize> {
    // Move an owned duplicate of the socket into the `'static` blocking task. That way the
    // descriptor cannot be closed, and its number reused, while the task is still using it,
    // even if this future is dropped.
    let socket = fd.as_fd().try_clone_to_owned()?;
    crate::Task::spawn_blocking(
        move || send_zero_byte_blocking(socket.as_raw_fd()),
        "send zero byte",
    )
    .await?
}
490
/// Blocking worker behind `send_zero_byte`: sends a single zero byte on the socket `fd` with an
/// `SCM_CREDS` control message attached.
///
/// `fd` must stay open for the duration of the call. Returns the number of bytes sent, or the
/// OS error reported by `sendmsg`.
#[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
fn send_zero_byte_blocking(fd: RawFd) -> std::io::Result<usize> {
    // FIXME: Replace with rustix API when it provides SCM_CREDS support for BSD.
    // For now, use libc directly since rustix doesn't support sending SCM_CREDS on BSD.
    use std::mem::{MaybeUninit, align_of, size_of};

    // `c""` consists only of its NUL terminator, so this iovec covers exactly one zero byte.
    let mut iov = libc::iovec {
        iov_base: c"".as_ptr() as *mut libc::c_void,
        iov_len: 1,
    };

    // SAFETY: `msghdr` consists of integers and raw pointers, for which all-zero is valid.
    let mut msg: libc::msghdr = unsafe { MaybeUninit::zeroed().assume_init() };
    msg.msg_iov = &mut iov;
    msg.msg_iovlen = 1;

    // SCM_CREDS on BSD doesn't actually send data in the control message.
    // Instead, it tells the kernel to attach credentials when receiving.
    // We just need to allocate space for the cmsg header with no data.
    let cmsg_space = unsafe { libc::CMSG_SPACE(0) as usize };
    // The header is written through a `*mut cmsghdr` below, so the buffer must be aligned for
    // `cmsghdr`. A `Vec<u8>` only guarantees 1-byte alignment; `u64` storage is sufficient.
    const _: () = assert!(align_of::<libc::cmsghdr>() <= align_of::<u64>());
    let mut cmsg_buf = vec![0u64; cmsg_space.div_ceil(size_of::<u64>())];

    msg.msg_control = cmsg_buf.as_mut_ptr() as *mut libc::c_void;
    msg.msg_controllen = cmsg_space as _;

    let cmsg = unsafe { libc::CMSG_FIRSTHDR(&msg) };
    if !cmsg.is_null() {
        // SAFETY: `cmsg` points into `cmsg_buf`, which is aligned for `cmsghdr` and at least
        // `CMSG_SPACE(0)` bytes long.
        unsafe {
            (*cmsg).cmsg_level = libc::SOL_SOCKET;
            (*cmsg).cmsg_type = libc::SCM_CREDS;
            (*cmsg).cmsg_len = libc::CMSG_LEN(0) as _;
        }
    }

    // SAFETY: `msg` only references `iov`, the static NUL byte and `cmsg_buf`, all still alive.
    let ret = unsafe { libc::sendmsg(fd, &msg, 0) };
    if ret < 0 {
        Err(std::io::Error::last_os_error())
    } else {
        Ok(ret as usize)
    }
}