// strum_macros/macros/strings/from_string.rs

use std::collections::HashSet;

use proc_macro2::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput, Fields};

use crate::helpers::{
    missing_parse_err_attr_error, non_enum_error, occurrence_error, HasInnerVariantProperties,
    HasStrumVariantProperties, HasTypeProperties,
};

10pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
11    let name = &ast.ident;
12    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
13    let variants = match &ast.data {
14        Data::Enum(v) => &v.variants,
15        _ => return Err(non_enum_error()),
16    };
17
18    let type_properties = ast.get_type_properties()?;
19    let strum_module_path = type_properties.crate_module_path();
20
21    // It's an error to provide an err_fn but not an err_ty.
22    if type_properties.parse_err_fn.is_some() && type_properties.parse_err_ty.is_none() {
23        return Err(missing_parse_err_attr_error());
24    }
25
26    let mut default_kw = None;
27    let mut default_match_arm = None;
28
29    let mut phf_exact_match_arms = Vec::new();
30    let mut standard_match_arms = Vec::new();
31    for variant in variants {
32        let ident = &variant.ident;
33        let variant_properties = variant.get_variant_properties()?;
34
35        if variant_properties.disabled.is_some() {
36            continue;
37        }
38
39        if let Some(kw) = variant_properties.default {
40            if let Some(fst_kw) = default_kw {
41                return Err(occurrence_error(fst_kw, kw, "default"));
42            }
43
44            default_kw = Some(kw);
45
46            match &variant.fields {
47                Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
48                    default_match_arm = Some(quote! {
49                        #name::#ident(s.into())
50                    });
51                }
52                Fields::Named(ref f) if f.named.len() == 1 => {
53                    let field_name = f.named.last().unwrap().ident.as_ref().unwrap();
54                    default_match_arm = Some(quote! { #name::#ident { #field_name : s.into() } });
55                }
56                _ => {
57                    return Err(syn::Error::new_spanned(
58                        variant,
59                        "Default only works on newtype structs with a single String field",
60                    ))
61                }
62            }
63
64            continue;
65        }
66
67        let params = match &variant.fields {
68            Fields::Unit => quote! {},
69            Fields::Unnamed(fields) => {
70                if let Some(ref value) = variant_properties.default_with {
71                    let func = proc_macro2::Ident::new(&value.value(), value.span());
72                    let defaults = vec![quote! { #func() }];
73                    quote! { (#(#defaults),*) }
74                } else {
75                    let defaults =
76                        ::core::iter::repeat(quote!(Default::default())).take(fields.unnamed.len());
77                    quote! { (#(#defaults),*) }
78                }
79            }
80            Fields::Named(fields) => {
81                let mut defaults = vec![];
82                for field in &fields.named {
83                    let meta = field.get_variant_inner_properties()?;
84                    let field = field.ident.as_ref().unwrap();
85
86                    if let Some(default_with) = meta.default_with {
87                        let func =
88                            proc_macro2::Ident::new(&default_with.value(), default_with.span());
89                        defaults.push(quote! {
90                            #field: #func()
91                        });
92                    } else {
93                        defaults.push(quote! { #field: Default::default() });
94                    }
95                }
96
97                quote! { {#(#defaults),*} }
98            }
99        };
100
101        let is_ascii_case_insensitive = variant_properties
102            .ascii_case_insensitive
103            .unwrap_or(type_properties.ascii_case_insensitive);
104
105        // If we don't have any custom variants, add the default serialized name.
106        for serialization in variant_properties.get_serializations(type_properties.case_style) {
107            if type_properties.use_phf {
108                phf_exact_match_arms.push(quote! { #serialization => #name::#ident #params, });
109
110                if is_ascii_case_insensitive {
111                    // Store the lowercase and UPPERCASE variants in the phf map to capture
112                    let ser_string = serialization.value();
113
114                    let lower =
115                        syn::LitStr::new(&ser_string.to_ascii_lowercase(), serialization.span());
116                    let upper =
117                        syn::LitStr::new(&ser_string.to_ascii_uppercase(), serialization.span());
118                    phf_exact_match_arms.push(quote! { #lower => #name::#ident #params, });
119                    phf_exact_match_arms.push(quote! { #upper => #name::#ident #params, });
120                    standard_match_arms.push(quote! { s if s.eq_ignore_ascii_case(#serialization) => #name::#ident #params, });
121                }
122            } else if !is_ascii_case_insensitive {
123                standard_match_arms.push(quote! { #serialization => #name::#ident #params, });
124            } else {
125                standard_match_arms.push(quote! { s if s.eq_ignore_ascii_case(#serialization) => #name::#ident #params, });
126            }
127        }
128    }
129
130    // Determine the error type on FromStr and TryFrom based on what the user
131    // has configured and whether there is a default variant.
132    let is_infallible = default_match_arm.is_some();
133    let has_custom_err_ty = type_properties.parse_err_ty.is_some();
134    let err_ty = if let Some(ty) = type_properties.parse_err_ty {
135        quote! { #ty }
136    } else if is_infallible {
137        quote! { ::core::convert::Infallible }
138    } else {
139        quote! { #strum_module_path::ParseError }
140    };
141
142    // Determine the default match arm behavior based on whether the user provided a "default"
143    // or if the user provided a custom error function.
144    let default_match_arm = if let Some(default_match_arm) = default_match_arm {
145        default_match_arm
146    } else if let Some(f) = type_properties.parse_err_fn {
147        quote! { return ::core::result::Result::Err(#f(s)) }
148    } else if has_custom_err_ty {
149        // The user defined a custom error type, but not a custom error function. This is an error
150        // if the method isn't infallible.
151        return Err(missing_parse_err_attr_error());
152    } else {
153        quote! { return ::core::result::Result::Err(#strum_module_path::ParseError::VariantNotFound) }
154    };
155
156    let mut match_expression = if standard_match_arms.is_empty() {
157        default_match_arm
158    } else {
159        quote! {
160            match s {
161                #(#standard_match_arms)*
162                _ => #default_match_arm,
163            }
164        }
165    };
166
167    if !phf_exact_match_arms.is_empty() {
168        match_expression = quote! {
169            use #strum_module_path::_private_phf_reexport_for_macro_if_phf_feature as phf;
170            static PHF: phf::Map<&'static str, #name> = phf::phf_map! {
171                #(#phf_exact_match_arms)*
172            };
173
174            if let Some(value) = PHF.get(s).cloned() {
175                value
176            } else {
177                #match_expression
178            }
179        }
180    }
181
182    let from_impl = if is_infallible && !has_custom_err_ty {
183        quote! {
184            #[allow(clippy::use_self)]
185            #[automatically_derived]
186            impl #impl_generics ::core::convert::From<&str> for #name #ty_generics #where_clause {
187                #[inline]
188                fn from(s: &str) -> #name #ty_generics {
189                    #match_expression
190                }
191            }
192        }
193    } else {
194        quote! {
195            #[allow(clippy::use_self)]
196            #[automatically_derived]
197            impl #impl_generics ::core::convert::TryFrom<&str> for #name #ty_generics #where_clause {
198                type Error = #err_ty;
199
200                #[inline]
201                fn try_from(s: &str) -> ::core::result::Result< #name #ty_generics , <Self as ::core::convert::TryFrom<&str>>::Error> {
202                    Ok({
203                        #match_expression
204                    })
205                }
206            }
207        }
208    };
209
210    let from_str = quote! {
211        #[allow(clippy::use_self)]
212        #[automatically_derived]
213        impl #impl_generics ::core::str::FromStr for #name #ty_generics #where_clause {
214            type Err = #err_ty;
215
216            #[inline]
217            fn from_str(s: &str) -> ::core::result::Result< #name #ty_generics , <Self as ::core::str::FromStr>::Err> {
218                <Self as ::core::convert::TryFrom<&str>>::try_from(s)
219            }
220        }
221    };
222
223    Ok(quote! {
224        #from_str
225        #from_impl
226    })
227}