zvariant_derive/
value.rs

1use proc_macro2::{Span, TokenStream};
2use quote::{quote, ToTokens};
3use syn::{
4    spanned::Spanned, Attribute, Data, DataEnum, DeriveInput, Error, Fields, Generics, Ident,
5    Lifetime, LifetimeParam, Variant,
6};
7use zvariant_utils::macros;
8
9use crate::utils::*;
10
/// Which zvariant value type a derive generates conversions for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValueType {
    /// `zvariant::Value`: the conversion from the user type into the value is an infallible
    /// `From`.
    Value,
    /// `zvariant::OwnedValue`: the conversion from the user type into the value is a fallible
    /// `TryFrom`.
    OwnedValue,
}
15
16pub fn expand_derive(ast: DeriveInput, value_type: ValueType) -> Result<TokenStream, Error> {
17    let zv = zvariant_path();
18
19    match &ast.data {
20        Data::Struct(ds) => match &ds.fields {
21            Fields::Named(_) | Fields::Unnamed(_) => {
22                let StructAttributes {
23                    signature,
24                    rename_all,
25                    ..
26                } = StructAttributes::parse(&ast.attrs)?;
27                let signature = signature.map(|signature| match signature.as_str() {
28                    "dict" => "a{sv}".to_string(),
29                    _ => signature,
30                });
31
32                impl_struct(
33                    value_type,
34                    ast.ident,
35                    ast.generics,
36                    &ds.fields,
37                    signature,
38                    &zv,
39                    rename_all,
40                )
41            }
42            Fields::Unit => Err(Error::new(ast.span(), "Unit structures not supported")),
43        },
44        Data::Enum(data) => impl_enum(value_type, ast.ident, ast.generics, ast.attrs, data, &zv),
45        _ => Err(Error::new(
46            ast.span(),
47            "only structs and enums are supported",
48        )),
49    }
50}
51
52fn impl_struct(
53    value_type: ValueType,
54    name: Ident,
55    generics: Generics,
56    fields: &Fields,
57    signature: Option<String>,
58    zv: &TokenStream,
59    rename_all: Option<String>,
60) -> Result<TokenStream, Error> {
61    let statc_lifetime = LifetimeParam::new(Lifetime::new("'static", Span::call_site()));
62    let (
63        value_type,
64        value_lifetime,
65        into_value_trait,
66        into_value_method,
67        into_value_error_decl,
68        into_value_ret,
69        into_value_error_transform,
70    ) = match value_type {
71        ValueType::Value => {
72            let mut lifetimes = generics.lifetimes();
73            let value_lifetime = lifetimes
74                .next()
75                .cloned()
76                .unwrap_or_else(|| statc_lifetime.clone());
77            if lifetimes.next().is_some() {
78                return Err(Error::new(
79                    name.span(),
80                    "Type with more than 1 lifetime not supported",
81                ));
82            }
83
84            (
85                quote! { #zv::Value<#value_lifetime> },
86                value_lifetime,
87                quote! { From },
88                quote! { from },
89                quote! {},
90                quote! { Self },
91                quote! {},
92            )
93        }
94        ValueType::OwnedValue => (
95            quote! { #zv::OwnedValue },
96            statc_lifetime,
97            quote! { TryFrom },
98            quote! { try_from },
99            quote! { type Error = #zv::Error; },
100            quote! { #zv::Result<Self> },
101            quote! { .map_err(::std::convert::Into::into) },
102        ),
103    };
104
105    let type_params = generics.type_params().cloned().collect::<Vec<_>>();
106    let (from_value_where_clause, into_value_where_clause) = if !type_params.is_empty() {
107        (
108            Some(quote! {
109                where
110                #(
111                    #type_params: ::std::convert::TryFrom<#zv::Value<#value_lifetime>> + #zv::Type,
112                    <#type_params as ::std::convert::TryFrom<#zv::Value<#value_lifetime>>>::Error: ::std::convert::Into<#zv::Error>
113                ),*
114            }),
115            Some(quote! {
116                where
117                #(
118                    #type_params: ::std::convert::Into<#zv::Value<#value_lifetime>> + #zv::Type
119                ),*
120            }),
121        )
122    } else {
123        (None, None)
124    };
125    let (impl_generics, ty_generics, _) = generics.split_for_impl();
126    match fields {
127        Fields::Named(_) => {
128            let field_names: Vec<_> = fields
129                .iter()
130                .map(|field| field.ident.to_token_stream())
131                .collect();
132            let (from_value_impl, into_value_impl) = match signature {
133                Some(signature) if signature == "a{sv}" => {
134                    // User wants the type to be encoded as a dict.
135                    // FIXME: Not the most efficient implementation.
136                    let (fields_init, entries_init): (TokenStream, TokenStream) = fields
137                        .iter()
138                        .map(|field| {
139                            let FieldAttributes { rename } =
140                                FieldAttributes::parse(&field.attrs).unwrap_or_default();
141                            let field_name = field.ident.to_token_stream();
142                            let key_name = rename_identifier(
143                                field.ident.as_ref().unwrap().to_string(),
144                                field.span(),
145                                rename,
146                                rename_all.as_deref(),
147                            )
148                            .unwrap_or(field_name.to_string());
149                            let convert = if macros::ty_is_option(&field.ty) {
150                                quote! {
151                                    .map(#zv::Value::downcast)
152                                    .transpose()?
153                                }
154                            } else {
155                                quote! {
156                                    .ok_or_else(|| #zv::Error::IncorrectType)?
157                                    .downcast()?
158                                }
159                            };
160
161                            let fields_init = quote! {
162                                #field_name: fields
163                                    .remove(#key_name)
164                                    #convert,
165                            };
166                            let entries_init = if macros::ty_is_option(&field.ty) {
167                                quote! {
168                                    if let Some(v) = s.#field_name {
169                                        fields.insert(
170                                            #key_name,
171                                            #zv::Value::from(v),
172                                        );
173                                    }
174                                }
175                            } else {
176                                quote! {
177                                    fields.insert(
178                                        #key_name,
179                                        #zv::Value::from(s.#field_name),
180                                    );
181                                }
182                            };
183
184                            (fields_init, entries_init)
185                        })
186                        .unzip();
187
188                    (
189                        quote! {
190                            let mut fields = <::std::collections::HashMap::<
191                                ::std::string::String,
192                                #zv::Value,
193                            >>::try_from(value)?;
194
195                            ::std::result::Result::Ok(Self { #fields_init })
196                        },
197                        quote! {
198                            let mut fields = ::std::collections::HashMap::new();
199                            #entries_init
200
201                            <#value_type>::#into_value_method(#zv::Value::from(fields))
202                                #into_value_error_transform
203                        },
204                    )
205                }
206                Some(_) | None => (
207                    quote! {
208                        let mut fields = #zv::Structure::try_from(value)?.into_fields();
209
210                        ::std::result::Result::Ok(Self {
211                            #(
212                                #field_names: fields.remove(0).downcast()?
213                            ),*
214                        })
215                    },
216                    quote! {
217                        <#value_type>::#into_value_method(#zv::StructureBuilder::new()
218                        #(
219                            .add_field(s.#field_names)
220                        )*
221                        .build().unwrap())
222                        #into_value_error_transform
223                    },
224                ),
225            };
226            Ok(quote! {
227                impl #impl_generics ::std::convert::TryFrom<#value_type> for #name #ty_generics
228                    #from_value_where_clause
229                {
230                    type Error = #zv::Error;
231
232                    #[inline]
233                    fn try_from(value: #value_type) -> #zv::Result<Self> {
234                        #from_value_impl
235                    }
236                }
237
238                impl #impl_generics #into_value_trait<#name #ty_generics> for #value_type
239                    #into_value_where_clause
240                {
241                    #into_value_error_decl
242
243                    #[inline]
244                    fn #into_value_method(s: #name #ty_generics) -> #into_value_ret {
245                        #into_value_impl
246                    }
247                }
248            })
249        }
250        Fields::Unnamed(_) if fields.iter().next().is_some() => {
251            // Newtype struct.
252            Ok(quote! {
253                impl #impl_generics ::std::convert::TryFrom<#value_type> for #name #ty_generics
254                    #from_value_where_clause
255                {
256                    type Error = #zv::Error;
257
258                    #[inline]
259                    fn try_from(value: #value_type) -> #zv::Result<Self> {
260                        ::std::convert::TryInto::try_into(value).map(Self)
261                    }
262                }
263
264                impl #impl_generics #into_value_trait<#name #ty_generics> for #value_type
265                    #into_value_where_clause
266                {
267                    #into_value_error_decl
268
269                    #[inline]
270                    fn #into_value_method(s: #name #ty_generics) -> #into_value_ret {
271                        <#value_type>::#into_value_method(s.0) #into_value_error_transform
272                    }
273                }
274            })
275        }
276        Fields::Unnamed(_) => panic!("impl_struct must not be called for tuples"),
277        Fields::Unit => panic!("impl_struct must not be called for unit structures"),
278    }
279}
280
281fn impl_enum(
282    value_type: ValueType,
283    name: Ident,
284    _generics: Generics,
285    attrs: Vec<Attribute>,
286    data: &DataEnum,
287    zv: &TokenStream,
288) -> Result<TokenStream, Error> {
289    let repr: TokenStream = match attrs.iter().find(|attr| attr.path().is_ident("repr")) {
290        Some(repr_attr) => repr_attr.parse_args()?,
291        None => quote! { u32 },
292    };
293    let enum_attrs = EnumAttributes::parse(&attrs)?;
294    let str_enum = enum_attrs
295        .signature
296        .map(|sig| sig == "s")
297        .unwrap_or_default();
298
299    let mut variant_names = vec![];
300    let mut str_values = vec![];
301    for variant in &data.variants {
302        let variant_attrs = VariantAttributes::parse(&variant.attrs)?;
303        // Ensure all variants of the enum are unit type
304        match variant.fields {
305            Fields::Unit => {
306                variant_names.push(&variant.ident);
307                if str_enum {
308                    let str_value = enum_name_for_variant(
309                        variant,
310                        variant_attrs.rename,
311                        enum_attrs.rename_all.as_ref().map(AsRef::as_ref),
312                    )?;
313                    str_values.push(str_value);
314                }
315            }
316            _ => return Err(Error::new(variant.span(), "must be a unit variant")),
317        }
318    }
319
320    let into_val = if str_enum {
321        quote! {
322            match e {
323                #(
324                    #name::#variant_names => #str_values,
325                )*
326            }
327        }
328    } else {
329        quote! { e as #repr }
330    };
331
332    let (value_type, into_value) = match value_type {
333        ValueType::Value => (
334            quote! { #zv::Value<'_> },
335            quote! {
336                impl ::std::convert::From<#name> for #zv::Value<'_> {
337                    #[inline]
338                    fn from(e: #name) -> Self {
339                        <#zv::Value as ::std::convert::From<_>>::from(#into_val)
340                    }
341                }
342            },
343        ),
344        ValueType::OwnedValue => (
345            quote! { #zv::OwnedValue },
346            quote! {
347                impl ::std::convert::TryFrom<#name> for #zv::OwnedValue {
348                    type Error = #zv::Error;
349
350                    #[inline]
351                    fn try_from(e: #name) -> #zv::Result<Self> {
352                        <#zv::OwnedValue as ::std::convert::TryFrom<_>>::try_from(
353                            <#zv::Value as ::std::convert::From<_>>::from(#into_val)
354                        )
355                    }
356                }
357            },
358        ),
359    };
360
361    let from_val = if str_enum {
362        quote! {
363            let v: #zv::Str = ::std::convert::TryInto::try_into(value)?;
364
365            ::std::result::Result::Ok(match v.as_str() {
366                #(
367                    #str_values => #name::#variant_names,
368                )*
369                _ => return ::std::result::Result::Err(#zv::Error::IncorrectType),
370            })
371        }
372    } else {
373        quote! {
374            let v: #repr = ::std::convert::TryInto::try_into(value)?;
375
376            ::std::result::Result::Ok(match v {
377                #(
378                    x if x == #name::#variant_names as #repr => #name::#variant_names
379                 ),*,
380                _ => return ::std::result::Result::Err(#zv::Error::IncorrectType),
381            })
382        }
383    };
384
385    Ok(quote! {
386        impl ::std::convert::TryFrom<#value_type> for #name {
387            type Error = #zv::Error;
388
389            #[inline]
390            fn try_from(value: #value_type) -> #zv::Result<Self> {
391                #from_val
392            }
393        }
394
395        #into_value
396    })
397}
398
399fn enum_name_for_variant(
400    v: &Variant,
401    rename_attr: Option<String>,
402    rename_all_attr: Option<&str>,
403) -> Result<String, Error> {
404    let ident = v.ident.to_string();
405
406    rename_identifier(ident, v.span(), rename_attr, rename_all_attr)
407}