strum_macros/macros/
enum_discriminants.rs1use proc_macro2::{Span, TokenStream, TokenTree};
2use quote::{quote, ToTokens};
3use syn::parse_quote;
4use syn::{Data, DeriveInput, Fields};
5
6use crate::helpers::{non_enum_error, strum_discriminants_passthrough_error, HasTypeProperties};
7
/// Names of variant-level attributes that survive onto the generated discriminant
/// variants. A variant attribute whose path is not one of these names is not copied.
const ATTRIBUTES_TO_COPY: &[&str] = &[
    "doc",
    "cfg",
    "allow",
    "deny",
    "strum_discriminants",
];
13
14pub fn enum_discriminants_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
15 let name = &ast.ident;
16 let vis = &ast.vis;
17
18 let variants = match &ast.data {
19 Data::Enum(v) => &v.variants,
20 _ => return Err(non_enum_error()),
21 };
22
23 let type_properties = ast.get_type_properties()?;
25 let strum_module_path = type_properties.crate_module_path();
26
27 let mut derives = type_properties.discriminant_derives;
28 let mut discriminants = Vec::new();
29
30 let mut has_default_variant = false;
34 for variant in variants {
35 let ident = &variant.ident;
36 let mut has_default = false;
37
38 let discriminant = variant
39 .discriminant
40 .as_ref()
41 .map(|(_, expr)| quote!( = #expr));
42
43 let mut attrs = Vec::new();
46 for attr in &variant.attrs {
47 if attr.path().is_ident("default") {
48 has_default = true;
49 has_default_variant = true;
50 }
51
52 if !ATTRIBUTES_TO_COPY
53 .iter()
54 .any(|whitelisted| attr.path().is_ident(whitelisted))
55 {
56 continue;
57 }
58
59 if attr.path().is_ident("strum_discriminants") {
60 let mut ts = attr.meta.require_list()?.to_token_stream().into_iter();
61
62 let _ = ts.next();
64
65 let passthrough_group = ts
66 .next()
67 .ok_or_else(|| strum_discriminants_passthrough_error(attr))?;
68
69 let passthrough_attribute = match passthrough_group {
70 TokenTree::Group(ref group) => group.stream(),
71 _ => {
72 return Err(strum_discriminants_passthrough_error(&passthrough_group));
73 }
74 };
75 if passthrough_attribute.is_empty() {
76 return Err(strum_discriminants_passthrough_error(&passthrough_group));
77 }
78
79 attrs.push(quote! { #[#passthrough_attribute] });
80 continue;
81 }
82
83 attrs.push(attr.to_token_stream());
85 }
86
87 let default_attr = if has_default {
88 quote! { #[default] }
89 } else {
90 quote! {}
91 };
92
93 discriminants.push(quote! { #default_attr #(#attrs)* #ident #discriminant });
94 }
95
96 if has_default_variant {
99 derives.push(parse_quote!(::core::default::Default));
100 }
101
102 let derives = quote! {
103 #[derive(Clone, Copy, Debug, PartialEq, Eq, #(#derives),*)]
104 };
105
106 let default_name = syn::Ident::new(&format!("{}Discriminants", name), Span::call_site());
108
109 let discriminants_name = type_properties.discriminant_name.unwrap_or(default_name);
110 let discriminants_vis = type_properties.discriminant_vis.as_ref().unwrap_or(vis);
111
112 let pass_through_attributes = type_properties.discriminant_others;
114 let has_doc = pass_through_attributes
115 .iter()
116 .any(|meta| meta.path().is_ident("doc"));
117 let mut pass_through_attributes: Vec<_> = pass_through_attributes
118 .into_iter()
119 .map(ToTokens::into_token_stream)
120 .collect();
121 if !has_doc {
122 pass_through_attributes.push(quote! {
123 doc = "Auto-generated discriminant enum variants"
124 });
125 }
126
127 let repr = type_properties.enum_repr.map(|repr| quote!(#[repr(#repr)]));
128
129 let arms = variants
149 .iter()
150 .map(|variant| {
151 let ident = &variant.ident;
152 let params = match &variant.fields {
153 Fields::Unit => quote! {},
154 Fields::Unnamed(_fields) => {
155 quote! { (..) }
156 }
157 Fields::Named(_fields) => {
158 quote! { { .. } }
159 }
160 };
161
162 quote! { #name::#ident #params => #discriminants_name::#ident }
163 })
164 .collect::<Vec<_>>();
165
166 let from_fn_body = if variants.is_empty() {
167 quote! { unreachable!()}
169 } else {
170 quote! { match val { #(#arms),* } }
171 };
172
173 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
174 let impl_from = quote! {
175 #[automatically_derived]
176 impl #impl_generics ::core::convert::From< #name #ty_generics > for #discriminants_name #where_clause {
177 #[inline]
178 fn from(val: #name #ty_generics) -> #discriminants_name {
179 #from_fn_body
180 }
181 }
182 };
183 let impl_from_ref = {
184 let mut generics = ast.generics.clone();
185
186 let lifetime = parse_quote!('_enum);
187 let enum_life = quote! { & #lifetime };
188 generics.params.push(lifetime);
189
190 let (impl_generics, _, _) = generics.split_for_impl();
192
193 quote! {
194 #[automatically_derived]
195 impl #impl_generics ::core::convert::From< #enum_life #name #ty_generics > for #discriminants_name #where_clause {
196 #[inline]
197 fn from(val: #enum_life #name #ty_generics) -> #discriminants_name {
198 #from_fn_body
199 }
200 }
201 }
202 };
203
204 let impl_into_discriminant = match type_properties.discriminant_vis {
206 None | Some(syn::Visibility::Public(..)) => quote! {
208 #[automatically_derived]
209 impl #impl_generics #strum_module_path::IntoDiscriminant for #name #ty_generics #where_clause {
210 type Discriminant = #discriminants_name;
211
212 #[inline]
213 fn discriminant(&self) -> Self::Discriminant {
214 <Self::Discriminant as ::core::convert::From<&Self>>::from(self)
215 }
216 }
217 },
218 _ => quote! {},
223 };
224
225 Ok(quote! {
226 #derives
227 #repr
228 #(#[ #pass_through_attributes ])*
229 #discriminants_vis enum #discriminants_name {
230 #(#discriminants),*
231 }
232
233 #impl_into_discriminant
234 #impl_from
235 #impl_from_ref
236 })
237}