Skip to main content

zbus/connection/socket/unix.rs

1#[cfg(not(feature = "tokio"))]
2use async_io::Async;
3#[cfg(target_os = "linux")]
4use std::os::unix::io::FromRawFd;
5#[cfg(unix)]
6use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, RawFd};
7#[cfg(all(unix, not(feature = "tokio")))]
8use std::os::unix::net::UnixStream;
9#[cfg(not(feature = "tokio"))]
10use std::sync::Arc;
11#[cfg(unix)]
12use std::{
13    future::poll_fn,
14    io::{IoSlice, IoSliceMut},
15    os::fd::OwnedFd,
16    task::Poll,
17};
18#[cfg(all(windows, not(feature = "tokio")))]
19use uds_windows::UnixStream;
20
21#[cfg(unix)]
22use rustix::net::{
23    RecvAncillaryBuffer, RecvAncillaryMessage, RecvFlags, SendAncillaryBuffer,
24    SendAncillaryMessage, SendFlags, recvmsg, sendmsg,
25};
26
27#[cfg(unix)]
28use crate::utils::FDS_MAX;
29
30#[cfg(all(unix, not(feature = "tokio")))]
31#[async_trait::async_trait]
32impl super::ReadHalf for Arc<Async<UnixStream>> {
33    async fn recvmsg(&mut self, buf: &mut [u8]) -> super::RecvmsgResult {
34        poll_fn(|cx| {
35            let (len, fds) = loop {
36                match fd_recvmsg(self.as_fd(), buf) {
37                    Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
38                    Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
39                        match self.poll_readable(cx) {
40                            Poll::Pending => return Poll::Pending,
41                            Poll::Ready(res) => res?,
42                        }
43                    }
44                    v => break v?,
45                }
46            };
47            Poll::Ready(Ok((len, fds)))
48        })
49        .await
50    }
51
52    /// Supports passing file descriptors.
53    fn can_pass_unix_fd(&self) -> bool {
54        true
55    }
56
57    async fn peer_credentials(&mut self) -> std::io::Result<crate::fdo::ConnectionCredentials> {
58        get_unix_peer_creds(self).await
59    }
60}
61
62#[cfg(all(unix, not(feature = "tokio")))]
63#[async_trait::async_trait]
64impl super::WriteHalf for Arc<Async<UnixStream>> {
65    async fn sendmsg(
66        &mut self,
67        buffer: &[u8],
68        #[cfg(unix)] fds: &[BorrowedFd<'_>],
69    ) -> std::io::Result<usize> {
70        poll_fn(|cx| {
71            loop {
72                match fd_sendmsg(
73                    self.as_fd(),
74                    buffer,
75                    #[cfg(unix)]
76                    fds,
77                ) {
78                    Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
79                    Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
80                        match self.poll_writable(cx) {
81                            Poll::Pending => return Poll::Pending,
82                            Poll::Ready(res) => res?,
83                        }
84                    }
85                    v => return Poll::Ready(v),
86                }
87            }
88        })
89        .await
90    }
91
92    async fn close(&mut self) -> std::io::Result<()> {
93        let stream = self.clone();
94        crate::Task::spawn_blocking(
95            move || stream.get_ref().shutdown(std::net::Shutdown::Both),
96            "close socket",
97        )
98        .await?
99    }
100
101    #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
102    async fn send_zero_byte(&mut self) -> std::io::Result<Option<usize>> {
103        send_zero_byte(self).await.map(Some)
104    }
105
106    /// Supports passing file descriptors.
107    fn can_pass_unix_fd(&self) -> bool {
108        true
109    }
110
111    async fn peer_credentials(&mut self) -> std::io::Result<crate::fdo::ConnectionCredentials> {
112        super::ReadHalf::peer_credentials(self).await
113    }
114}
115
#[cfg(all(unix, feature = "tokio"))]
impl super::Socket for tokio::net::UnixStream {
    type ReadHalf = tokio::net::unix::OwnedReadHalf;
    type WriteHalf = tokio::net::unix::OwnedWriteHalf;

    /// Splits the stream into separately owned read and write halves.
    fn split(self) -> super::Split<Self::ReadHalf, Self::WriteHalf> {
        let halves = self.into_split();
        super::Split {
            read: halves.0,
            write: halves.1,
        }
    }
}
127
#[cfg(all(unix, feature = "tokio"))]
#[async_trait::async_trait]
impl super::ReadHalf for tokio::net::unix::OwnedReadHalf {
    /// Receives bytes, plus any `SCM_RIGHTS` file descriptors, from the socket.
    ///
    /// Each attempt goes through `try_io`. `EINTR` is retried right away; `EWOULDBLOCK` waits
    /// for tokio to report the socket readable.
    async fn recvmsg(&mut self, buf: &mut [u8]) -> super::RecvmsgResult {
        let socket: &tokio::net::UnixStream = self.as_ref();
        poll_fn(|cx| loop {
            // tokio's own read APIs cannot be used here: the file descriptors attached to the
            // message have to be received as well, so a custom recvmsg wrapper is called.
            let attempt = socket.try_io(tokio::io::Interest::READABLE, || {
                fd_recvmsg(socket.as_fd(), buf)
            });
            let err = match attempt {
                Ok(received) => return Poll::Ready(Ok(received)),
                Err(err) => err,
            };
            match err.kind() {
                std::io::ErrorKind::Interrupted => {}
                std::io::ErrorKind::WouldBlock => {
                    let Poll::Ready(readiness) = socket.poll_read_ready(cx) else {
                        return Poll::Pending;
                    };
                    readiness?;
                }
                _ => return Poll::Ready(Err(err)),
            }
        })
        .await
    }

    /// File descriptors can be passed over this socket.
    fn can_pass_unix_fd(&self) -> bool {
        true
    }

    /// Looks up the credentials of the peer through the underlying `UnixStream`.
    async fn peer_credentials(&mut self) -> std::io::Result<crate::fdo::ConnectionCredentials> {
        get_unix_peer_creds(self.as_ref()).await
    }
}
163
#[cfg(all(unix, feature = "tokio"))]
#[async_trait::async_trait]
impl super::WriteHalf for tokio::net::unix::OwnedWriteHalf {
    /// Sends `buffer` over the socket, attaching `fds` as `SCM_RIGHTS` ancillary data.
    ///
    /// Each attempt goes through `try_io`. `EINTR` is retried right away; `EWOULDBLOCK` waits
    /// for tokio to report the socket writable.
    async fn sendmsg(
        &mut self,
        buffer: &[u8],
        #[cfg(unix)] fds: &[BorrowedFd<'_>],
    ) -> std::io::Result<usize> {
        let socket: &tokio::net::UnixStream = self.as_ref();
        poll_fn(|cx| loop {
            let attempt = socket.try_io(tokio::io::Interest::WRITABLE, || {
                fd_sendmsg(
                    socket.as_fd(),
                    buffer,
                    #[cfg(unix)]
                    fds,
                )
            });
            let err = match attempt {
                Ok(sent) => return Poll::Ready(Ok(sent)),
                Err(err) => err,
            };
            match err.kind() {
                std::io::ErrorKind::Interrupted => {}
                std::io::ErrorKind::WouldBlock => {
                    let Poll::Ready(readiness) = socket.poll_write_ready(cx) else {
                        return Poll::Pending;
                    };
                    readiness?;
                }
                _ => return Poll::Ready(Err(err)),
            }
        })
        .await
    }

    /// Closes this half by delegating to tokio's `AsyncWriteExt::shutdown`.
    async fn close(&mut self) -> std::io::Result<()> {
        tokio::io::AsyncWriteExt::shutdown(self).await
    }

    /// Sends the single-byte `SCM_CREDS` message used on FreeBSD and DragonFly.
    #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
    async fn send_zero_byte(&mut self) -> std::io::Result<Option<usize>> {
        send_zero_byte(self.as_ref()).await.map(Some)
    }

    /// File descriptors can be passed over this socket.
    fn can_pass_unix_fd(&self) -> bool {
        true
    }

    /// Looks up the credentials of the peer through the underlying `UnixStream`.
    async fn peer_credentials(&mut self) -> std::io::Result<crate::fdo::ConnectionCredentials> {
        get_unix_peer_creds(self.as_ref()).await
    }
}
215
#[cfg(all(windows, not(feature = "tokio")))]
#[async_trait::async_trait]
impl super::ReadHalf for Arc<Async<UnixStream>> {
    /// Reads bytes from the socket into `buf`.
    ///
    /// This impl is compiled only on Windows, where no file descriptors accompany the data. The
    /// result is therefore just the byte count returned by the read.
    async fn recvmsg(&mut self, buf: &mut [u8]) -> super::RecvmsgResult {
        let mut reader: &Async<UnixStream> = self.as_ref();
        futures_lite::AsyncReadExt::read(&mut reader, buf).await
    }

    /// Resolves the peer's process ID and the SID of that process's token, on a blocking task.
    async fn peer_credentials(&mut self) -> std::io::Result<crate::fdo::ConnectionCredentials> {
        let stream = Arc::clone(self);
        let lookup = move || {
            use crate::win32::{ProcessToken, unix_stream_get_peer_pid};

            let pid = unix_stream_get_peer_pid(stream.get_ref())? as _;
            // A PID of 0 is mapped to `None`. NOTE(review): presumably this opens the current
            // process's token; confirm against `crate::win32::ProcessToken::open`.
            let token = ProcessToken::open(if pid != 0 { Some(pid as _) } else { None })?;
            let sid = token.sid()?;
            Ok(crate::fdo::ConnectionCredentials::default()
                .set_process_id(pid)
                .set_windows_sid(sid))
        };
        crate::Task::spawn_blocking(lookup, "peer credentials").await?
    }
}
250
#[cfg(all(windows, not(feature = "tokio")))]
#[async_trait::async_trait]
impl super::WriteHalf for Arc<Async<UnixStream>> {
    /// Writes `buf` to the socket. No file descriptors are sent on this platform.
    async fn sendmsg(
        &mut self,
        buf: &[u8],
        #[cfg(unix)] _fds: &[BorrowedFd<'_>],
    ) -> std::io::Result<usize> {
        let mut writer: &Async<UnixStream> = self.as_ref();
        futures_lite::AsyncWriteExt::write(&mut writer, buf).await
    }

    /// Shuts down both directions of the socket from a blocking task.
    async fn close(&mut self) -> std::io::Result<()> {
        let handle = Arc::clone(self);
        let shutdown = move || handle.get_ref().shutdown(std::net::Shutdown::Both);
        crate::Task::spawn_blocking(shutdown, "close socket").await?
    }

    /// Delegates to the read half's credential lookup, since both halves share one socket.
    async fn peer_credentials(&mut self) -> std::io::Result<crate::fdo::ConnectionCredentials> {
        super::ReadHalf::peer_credentials(self).await
    }
}
275
276#[cfg(unix)]
277fn fd_recvmsg(fd: BorrowedFd<'_>, buffer: &mut [u8]) -> std::io::Result<(usize, Vec<OwnedFd>)> {
278    use std::mem::MaybeUninit;
279
280    let mut iov = [IoSliceMut::new(buffer)];
281    let mut cmsg_buffer = [MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(FDS_MAX))];
282    let mut ancillary = RecvAncillaryBuffer::new(&mut cmsg_buffer);
283
284    let msg = recvmsg(fd, &mut iov, &mut ancillary, RecvFlags::empty())?;
285    if msg.bytes == 0 {
286        return Err(std::io::Error::new(
287            std::io::ErrorKind::BrokenPipe,
288            "failed to read from socket",
289        ));
290    }
291    let mut fds = vec![];
292    for msg in ancillary.drain() {
293        match msg {
294            RecvAncillaryMessage::ScmRights(iter) => {
295                fds.extend(iter);
296            }
297            #[cfg(any(target_os = "linux", target_os = "android"))]
298            RecvAncillaryMessage::ScmCredentials(_) => {
299                // On Linux, credentials might be received. This shouldn't normally happen
300                // in our use case since we don't request them, but ignore if present.
301                continue;
302            }
303            _ => {
304                return Err(std::io::Error::new(
305                    std::io::ErrorKind::InvalidData,
306                    "unexpected CMSG kind",
307                ));
308            }
309        }
310    }
311    Ok((msg.bytes, fds))
312}
313
314#[cfg(unix)]
315fn fd_sendmsg(fd: BorrowedFd<'_>, buffer: &[u8], fds: &[BorrowedFd<'_>]) -> std::io::Result<usize> {
316    use std::mem::MaybeUninit;
317
318    let iov = [IoSlice::new(buffer)];
319    let mut cmsg_buffer = [MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(FDS_MAX))];
320    let mut ancillary = SendAncillaryBuffer::new(&mut cmsg_buffer);
321
322    if !fds.is_empty() && !ancillary.push(SendAncillaryMessage::ScmRights(fds)) {
323        return Err(std::io::Error::new(
324            std::io::ErrorKind::InvalidInput,
325            "too many file descriptors",
326        ));
327    }
328
329    #[cfg(not(any(
330        target_os = "macos",
331        target_os = "ios",
332        target_os = "tvos",
333        target_os = "visionos",
334        target_os = "watchos",
335        target_os = "redox"
336    )))]
337    let flags = SendFlags::NOSIGNAL;
338    #[cfg(any(
339        target_os = "macos",
340        target_os = "ios",
341        target_os = "tvos",
342        target_os = "visionos",
343        target_os = "watchos",
344        target_os = "redox"
345    ))]
346    let flags = SendFlags::empty();
347
348    let sent = sendmsg(fd, &iov, &mut ancillary, flags)?;
349    if sent == 0 {
350        // can it really happen?
351        return Err(std::io::Error::new(
352            std::io::ErrorKind::WriteZero,
353            "failed to write to buffer",
354        ));
355    }
356
357    Ok(sent)
358}
359
360#[cfg(unix)]
361async fn get_unix_peer_creds(fd: &impl AsFd) -> std::io::Result<crate::fdo::ConnectionCredentials> {
362    let fd = fd.as_fd().as_raw_fd();
363    // FIXME: Is it likely enough for sending of 1 byte to block, to justify a task (possibly
364    // launching a thread in turn)?
365    crate::Task::spawn_blocking(move || get_unix_peer_creds_blocking(fd), "peer credentials")
366        .await?
367}
368
369#[cfg(unix)]
370fn get_unix_peer_creds_blocking(fd: RawFd) -> std::io::Result<crate::fdo::ConnectionCredentials> {
371    // TODO: get this BorrowedFd directly from get_unix_peer_creds(), but this requires a
372    // 'static lifetime due to the Task.
373    let fd = unsafe { BorrowedFd::borrow_raw(fd) };
374    let mut creds = crate::fdo::ConnectionCredentials::default();
375
376    #[cfg(any(target_os = "android", target_os = "linux"))]
377    {
378        use rustix::net::sockopt::socket_peercred;
379        use tracing::debug;
380
381        let ucred = socket_peercred(fd)?;
382        let uid = ucred.uid.as_raw();
383        let gid = ucred.gid.as_raw();
384        let pid = ucred.pid.as_raw_nonzero().get() as u32;
385
386        creds = creds.set_unix_user_id(uid).set_process_id(pid);
387
388        // The dbus spec requires groups to be either absent or complete (primary +
389        // secondary groups).
390
391        // FIXME: rustix does not and [will not] provide `getpwuid_r` and `getgrouplist` so we're
392        // left with no choice but to use libc directly. We could consider using `sysinfo` crate
393        // though.
394        //
395        // [will not]: https://docs.rs/rustix/latest/rustix/not_implemented/higher_level/index.html
396        let mut passwd: libc::passwd = unsafe { std::mem::zeroed() };
397        let mut buf = vec![0u8; 16384];
398        let mut result: *mut libc::passwd = std::ptr::null_mut();
399
400        unsafe {
401            libc::getpwuid_r(
402                uid,
403                &mut passwd,
404                buf.as_mut_ptr() as *mut libc::c_char,
405                buf.len(),
406                &mut result,
407            );
408        }
409
410        if !result.is_null() {
411            let username = unsafe { std::ffi::CStr::from_ptr((*result).pw_name) };
412
413            // Get supplementary groups.
414            let mut ngroups = 64i32;
415            let mut groups = vec![0u32; ngroups as usize];
416
417            unsafe {
418                if libc::getgrouplist(
419                    username.as_ptr(),
420                    gid,
421                    groups.as_mut_ptr() as *mut libc::gid_t,
422                    &mut ngroups,
423                ) >= 0
424                {
425                    groups.truncate(ngroups as usize);
426                    // The spec also requires the groups to be numerically sorted.
427                    groups.sort();
428                    for group in groups {
429                        creds = creds.add_unix_group_id(group);
430                    }
431                } else {
432                    debug!("Group lookup failed for user {:?}", username);
433                }
434            }
435        }
436
437        #[cfg(target_os = "linux")]
438        {
439            // FIXME: Replace with rustix API when it provides SO_PEERPIDFD sockopt:
440            // https://github.com/bytecodealliance/rustix/pull/1474
441            use libc::{c_int, socklen_t};
442            use std::mem::{MaybeUninit, size_of};
443
444            let mut pidfd = MaybeUninit::<c_int>::zeroed();
445            let mut len = size_of::<c_int>() as socklen_t;
446
447            let ret = unsafe {
448                libc::getsockopt(
449                    fd.as_raw_fd(),
450                    libc::SOL_SOCKET,
451                    libc::SO_PEERPIDFD,
452                    pidfd.as_mut_ptr().cast(),
453                    &mut len,
454                )
455            };
456
457            if ret == 0 {
458                let pidfd = unsafe { pidfd.assume_init() };
459                creds = creds
460                    .set_process_fd(unsafe { std::os::fd::OwnedFd::from_raw_fd(pidfd).into() });
461            } else if ret < 0 {
462                let err = std::io::Error::last_os_error();
463                // ENOPROTOOPT means the kernel doesn't support this feature.
464                if err.raw_os_error() != Some(libc::ENOPROTOOPT) {
465                    return Err(err);
466                }
467            }
468        }
469    }
470
471    #[cfg(any(
472        target_os = "macos",
473        target_os = "ios",
474        target_os = "freebsd",
475        target_os = "dragonfly",
476        target_os = "openbsd",
477        target_os = "netbsd"
478    ))]
479    {
480        // FIXME: Replace with rustix API when it provides the require API:
481        // https://github.com/bytecodealliance/rustix/issues/1533
482        let mut uid: libc::uid_t = 0;
483        let mut gid: libc::gid_t = 0;
484
485        let ret = unsafe { libc::getpeereid(fd.as_raw_fd(), &mut uid, &mut gid) };
486        if ret != 0 {
487            return Err(std::io::Error::last_os_error());
488        }
489
490        creds = creds.set_unix_user_id(uid);
491
492        // FIXME: Handle pid fetching too
493    }
494
495    Ok(creds)
496}
497
/// Sends a single NUL byte accompanied by an `SCM_CREDS` control message, on a blocking task.
///
/// Returns the number of bytes sent.
#[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
async fn send_zero_byte(fd: &impl AsFd) -> std::io::Result<usize> {
    // The blocking task gets its own duplicate of the descriptor. If this future is dropped
    // mid-flight and the caller then closes the stream, the task still sends on the same socket
    // rather than on a closed or reused fd number.
    let fd = fd.as_fd().try_clone_to_owned()?;
    crate::Task::spawn_blocking(
        move || send_zero_byte_blocking(fd.as_raw_fd()),
        "send zero byte",
    )
    .await?
}
504
/// Synchronously sends one NUL byte on `fd` together with an empty `SCM_CREDS` control message.
///
/// Returns the byte count reported by `sendmsg`, or the OS error if it fails.
#[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
fn send_zero_byte_blocking(fd: RawFd) -> std::io::Result<usize> {
    // FIXME: Replace with rustix API when it provides SCM_CREDS support for BSD.
    // Until then libc is used directly, because rustix cannot send SCM_CREDS on BSD.

    // The payload is the NUL terminator of an empty C string: exactly one byte.
    let mut payload = libc::iovec {
        iov_base: c"".as_ptr() as *mut libc::c_void,
        iov_len: 1,
    };

    // SCM_CREDS on BSD doesn't actually send data in the control message. It tells the kernel
    // to attach credentials for the receiver, so room for a bare header is enough.
    let control_len = unsafe { libc::CMSG_SPACE(0) } as usize;
    let mut control = vec![0u8; control_len];

    // SAFETY: `msghdr` is a plain C struct for which all-zero bytes is a valid value.
    let mut header: libc::msghdr = unsafe { std::mem::zeroed() };
    header.msg_iov = &mut payload;
    header.msg_iovlen = 1;
    header.msg_control = control.as_mut_ptr().cast();
    header.msg_controllen = control_len as _;

    // SAFETY: `msg_control` points at `control` (`control_len` bytes), so the first header
    // returned by CMSG_FIRSTHDR, if any, lies within that allocation.
    let first = unsafe { libc::CMSG_FIRSTHDR(&header) };
    if let Some(cmsg) = unsafe { first.as_mut() } {
        cmsg.cmsg_level = libc::SOL_SOCKET;
        cmsg.cmsg_type = libc::SCM_CREDS;
        cmsg.cmsg_len = unsafe { libc::CMSG_LEN(0) } as _;
    }

    // SAFETY: `header` and everything it points to (`payload`, `control`) outlive the call.
    match unsafe { libc::sendmsg(fd, &header, 0) } {
        sent if sent < 0 => Err(std::io::Error::last_os_error()),
        sent => Ok(sent as usize),
    }
}
544}