Skip to main content

zbus_macros/
error.rs

1use proc_macro2::TokenStream;
2use quote::{ToTokens, quote};
3use syn::{Data, DeriveInput, Error, Fields, Ident, Variant, spanned::Spanned};
4use zvariant_utils::def_attrs;
5
// Attribute definitions consumed by the error derive below. The `def_attrs!` macro (from
// `zvariant_utils`) produces the `StructAttributes` and `VariantAttributes` types whose `parse`
// functions are called in `expand_derive`.
def_attrs! {
    // NOTE(review): presumably selects the `zbus` attribute namespace (`#[zbus(...)]`) — confirm
    // against the `def_attrs!` documentation.
    crate zbus;

    // Options given on the enum itself.
    pub StructAttributes("struct") {
        // Prefix joined with each variant name to form the error name; `expand_derive` falls
        // back to "org.freedesktop.DBus" when it is absent.
        prefix str,
        // Whether a `Display` impl is generated; `expand_derive` treats a missing value as `true`.
        impl_display bool,
        // Passed to `parse_crate_path` to locate the zbus crate in generated code.
        crate_path str
    };

    // Options given on individual enum variants.
    pub VariantAttributes("enum variant") {
        // Overrides the variant identifier as the last component of the error name.
        name str,
        // Marks the variant that wraps a `zbus::Error`.
        error none
    };
}
20
21use crate::utils::*;
22
23pub fn expand_derive(input: DeriveInput) -> Result<TokenStream, Error> {
24    let StructAttributes {
25        prefix,
26        impl_display,
27        crate_path: crate_attr,
28    } = StructAttributes::parse(&input.attrs)?;
29    let crate_path = parse_crate_path(crate_attr.as_deref())?;
30    let prefix = prefix.unwrap_or_else(|| "org.freedesktop.DBus".to_string());
31    let generate_display = impl_display.unwrap_or(true);
32
33    let (_vis, name, _generics, data) = match input.data {
34        Data::Enum(data) => (input.vis, input.ident, input.generics, data),
35        _ => return Err(Error::new(input.span(), "only enums supported")),
36    };
37
38    let zbus = zbus_path(crate_path.as_ref());
39    let mut replies = quote! {};
40    let mut error_names = quote! {};
41    let mut error_descriptions = quote! {};
42    let mut error_converts = quote! {};
43
44    let mut zbus_error_variant = None;
45
46    for variant in data.variants {
47        let VariantAttributes { name, error } = VariantAttributes::parse(&variant.attrs)?;
48        let ident = &variant.ident;
49        let name = name.unwrap_or_else(|| ident.to_string());
50
51        let fqn = if !error {
52            format!("{prefix}.{name}")
53        } else {
54            // The ZBus error variant will always be a hardcoded string.
55            String::from("org.freedesktop.zbus.Error")
56        };
57
58        let error_name = quote! {
59            #zbus::names::ErrorName::from_static_str_unchecked(#fqn)
60        };
61        let e = match variant.fields {
62            Fields::Unit => quote! {
63                Self::#ident => #error_name,
64            },
65            Fields::Unnamed(_) => quote! {
66                Self::#ident(..) => #error_name,
67            },
68            Fields::Named(_) => quote! {
69                Self::#ident { .. } => #error_name,
70            },
71        };
72        error_names.extend(e);
73
74        if error {
75            if zbus_error_variant.is_some() {
76                panic!("More than 1 `#[zbus(error)]` variant found");
77            }
78
79            zbus_error_variant = Some(quote! { #ident });
80        }
81
82        // FIXME: this will error if the first field is not a string as per the dbus spec, but we
83        // may support other cases?
84        let e = match &variant.fields {
85            Fields::Unit => quote! {
86                Self::#ident => None,
87            },
88            Fields::Unnamed(_) => {
89                if error {
90                    quote! {
91                        Self::#ident(e) => e.description(),
92                    }
93                } else {
94                    quote! {
95                        Self::#ident(desc, ..) => Some(&desc),
96                    }
97                }
98            }
99            Fields::Named(n) => {
100                let f = &n
101                    .named
102                    .first()
103                    .ok_or_else(|| Error::new(n.span(), "expected at least one field"))?
104                    .ident;
105                quote! {
106                    Self::#ident { #f, } => Some(#f),
107                }
108            }
109        };
110        error_descriptions.extend(e);
111
112        // The conversion for #[zbus(error)] variant is handled separately/explicitly.
113        if !error {
114            // FIXME: deserialize msg to error field instead, to support variable args
115            let e = match &variant.fields {
116                Fields::Unit => quote! {
117                    #fqn => Self::#ident,
118                },
119                Fields::Unnamed(_) => quote! {
120                    #fqn => Self::#ident(
121                        desc.map(::std::string::String::from).unwrap_or_default(),
122                    ),
123                },
124                Fields::Named(n) => {
125                    let f = &n
126                        .named
127                        .first()
128                        .ok_or_else(|| Error::new(n.span(), "expected at least one field"))?
129                        .ident;
130                    quote! {
131                        #fqn => Self::#ident {
132                            #f: desc.map(::std::string::String::from).unwrap_or_default(),
133                        },
134                    }
135                }
136            };
137            error_converts.extend(e);
138        }
139
140        let r = gen_reply_for_variant(&variant, error, &zbus)?;
141        replies.extend(r);
142    }
143
144    let from_zbus_error_impl = zbus_error_variant
145        .map(|ident| {
146            quote! {
147                impl ::std::convert::From<#zbus::Error> for #name {
148                    fn from(value: #zbus::Error) -> #name {
149                        match &value {
150                            #zbus::Error::MethodError(name, desc, _) => {
151                                let desc = desc.as_deref();
152                                match name.as_str() {
153                                    #error_converts
154                                    _ => Self::#ident(value),
155                                }
156                            }
157                            #zbus::Error::FDO(e) => {
158                                let e = ::std::convert::AsRef::as_ref(e);
159                                let name = #zbus::DBusError::name(e);
160                                let desc = #zbus::DBusError::description(e);
161                                match name.as_str() {
162                                    #error_converts
163                                    _ => Self::#ident(value),
164                                }
165                            }
166                            _ => Self::#ident(value),
167                        }
168                    }
169                }
170            }
171        })
172        .unwrap_or_default();
173
174    let display_impl = if generate_display {
175        quote! {
176            impl ::std::fmt::Display for #name {
177                fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
178                    let name = #zbus::DBusError::name(self);
179                    let description = #zbus::DBusError::description(self).unwrap_or("no description");
180                    ::std::write!(f, "{}: {}", name, description)
181                }
182            }
183        }
184    } else {
185        quote! {}
186    };
187
188    Ok(quote! {
189        impl #zbus::DBusError for #name {
190            fn name(&self) -> #zbus::names::ErrorName {
191                match self {
192                    #error_names
193                }
194            }
195
196            fn description(&self) -> Option<&str> {
197                match self {
198                    #error_descriptions
199                }
200            }
201
202            fn create_reply(&self, call: &#zbus::message::Header) -> #zbus::Result<#zbus::message::Message> {
203                let name = self.name();
204                match self {
205                    #replies
206                }
207            }
208        }
209
210        #display_impl
211
212        impl ::std::error::Error for #name {}
213
214        #from_zbus_error_impl
215    })
216}
217
218fn gen_reply_for_variant(
219    variant: &Variant,
220    zbus_error_variant: bool,
221    zbus: &TokenStream,
222) -> Result<TokenStream, Error> {
223    let ident = &variant.ident;
224    match &variant.fields {
225        Fields::Unit => Ok(quote! {
226            Self::#ident => #zbus::message::Message::error(call, name)?.build(&()),
227        }),
228        Fields::Unnamed(f) => {
229            // Name the unnamed fields as the number of the field with an 'f' in front.
230            let in_fields = (0..f.unnamed.len())
231                .map(|n| Ident::new(&format!("f{n}"), ident.span()).to_token_stream())
232                .collect::<Vec<_>>();
233            let out_fields = if zbus_error_variant {
234                let error_field = in_fields.first().ok_or_else(|| {
235                    Error::new(
236                        ident.span(),
237                        "expected at least one field for #[zbus(error)] variant",
238                    )
239                })?;
240                vec![quote! {
241                    match #error_field {
242                        #zbus::Error::MethodError(name, desc, _) => {
243                            ::std::clone::Clone::clone(desc)
244                        }
245                        _ => None,
246                    }
247                    .unwrap_or_else(|| ::std::string::ToString::to_string(#error_field))
248                }]
249            } else {
250                // FIXME: Workaround for https://github.com/rust-lang/rust-clippy/issues/10577
251                #[allow(clippy::redundant_clone)]
252                in_fields.clone()
253            };
254
255            Ok(quote! {
256                Self::#ident(#(#in_fields),*) => #zbus::message::Message::error(call, name)?.build(&(#(#out_fields),*)),
257            })
258        }
259        Fields::Named(f) => {
260            let fields = f.named.iter().map(|v| v.ident.as_ref()).collect::<Vec<_>>();
261            Ok(quote! {
262                Self::#ident { #(#fields),* } => #zbus::message::Message::error(call, name)?.build(&(#(#fields),*)),
263            })
264        }
265    }
266}