strum_macros/macros/
enum_table.rs1use proc_macro2::{Span, TokenStream};
2use quote::{format_ident, quote};
3use syn::{spanned::Spanned, Data, DeriveInput, Fields};
4
5use crate::helpers::{non_enum_error, snakify, HasStrumVariantProperties};
6
7pub fn enum_table_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
8 let name = &ast.ident;
9 let gen = &ast.generics;
10 let vis = &ast.vis;
11 let mut doc_comment = format!("A map over the variants of `{}`", name);
12
13 if gen.lifetimes().count() > 0 {
14 return Err(syn::Error::new(
15 Span::call_site(),
16 "`EnumTable` doesn't support enums with lifetimes.",
17 ));
18 }
19
20 let variants = match &ast.data {
21 Data::Enum(v) => &v.variants,
22 _ => return Err(non_enum_error()),
23 };
24
25 let table_name = format_ident!("{}Table", name);
26
27 let mut pascal_idents = Vec::new();
29 let mut snake_idents = Vec::new();
31 let mut get_matches = Vec::new();
33 let mut get_matches_mut = Vec::new();
35 let mut set_matches = Vec::new();
37 let mut closure_fields = Vec::new();
39 let mut transform_fields = Vec::new();
41
42 let mut disabled_variants = Vec::new();
44 let mut disabled_matches = Vec::new();
46
47 for variant in variants {
48 if variant.get_variant_properties()?.disabled.is_some() {
50 let disabled_ident = &variant.ident;
51 let panic_message = format!(
52 "Can't use `{}` with `{}` - variant is disabled for Strum features",
53 disabled_ident, table_name
54 );
55 disabled_variants.push(disabled_ident);
56 disabled_matches.push(quote!(#name::#disabled_ident => panic!(#panic_message),));
57 continue;
58 }
59
60 if !matches!(variant.fields, Fields::Unit) {
62 return Err(syn::Error::new(
63 variant.fields.span(),
64 "`EnumTable` doesn't support enums with non-unit variants",
65 ));
66 };
67
68 let pascal_case = &variant.ident;
69 let snake_case = format_ident!("_{}", snakify(&pascal_case.to_string()));
70
71 get_matches.push(quote! {#name::#pascal_case => &self.#snake_case,});
72 get_matches_mut.push(quote! {#name::#pascal_case => &mut self.#snake_case,});
73 set_matches.push(quote! {#name::#pascal_case => self.#snake_case = new_value,});
74 closure_fields.push(quote! {#snake_case: func(#name::#pascal_case),});
75 transform_fields.push(quote! {#snake_case: func(#name::#pascal_case, &self.#snake_case),});
76 pascal_idents.push(pascal_case);
77 snake_idents.push(snake_case);
78 }
79
80 if pascal_idents.is_empty() {
82 return Err(syn::Error::new(
83 variants.span(),
84 "`EnumTable` requires at least one non-disabled variant",
85 ));
86 }
87
88 if !disabled_variants.is_empty() {
90 doc_comment.push_str(&format!(
91 "\n# Panics\nIndexing `{}` with any of the following variants will cause a panic:",
92 table_name
93 ));
94 for variant in disabled_variants {
95 doc_comment.push_str(&format!("\n\n- `{}::{}`", name, variant));
96 }
97 }
98
99 let doc_new = format!(
100 "Create a new {} with a value for each variant of {}",
101 table_name, name
102 );
103 let doc_closure = format!(
104 "Create a new {} by running a function on each variant of `{}`",
105 table_name, name
106 );
107 let doc_transform = format!("Create a new `{}` by running a function on each variant of `{}` and the corresponding value in the current `{0}`", table_name, name);
108 let doc_filled = format!(
109 "Create a new `{}` with the same value in each field.",
110 table_name
111 );
112 let doc_option_all = format!("Converts `{}<Option<T>>` into `Option<{0}<T>>`. Returns `Some` if all fields are `Some`, otherwise returns `None`.", table_name);
113 let doc_result_all_ok = format!("Converts `{}<Result<T, E>>` into `Result<{0}<T>, E>`. Returns `Ok` if all fields are `Ok`, otherwise returns `Err`.", table_name);
114
115 Ok(quote! {
116 #[doc = #doc_comment]
117 #[allow(
118 missing_copy_implementations,
119 )]
120 #[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
121 #vis struct #table_name<T> {
122 #(#snake_idents: T,)*
123 }
124
125 #[automatically_derived]
126 impl<T: Clone> #table_name<T> {
127 #[doc = #doc_filled]
128 #vis fn filled(value: T) -> #table_name<T> {
129 #table_name {
130 #(#snake_idents: value.clone(),)*
131 }
132 }
133 }
134
135 #[automatically_derived]
136 impl<T> #table_name<T> {
137 #[doc = #doc_new]
138 #[inline]
139 #vis fn new(
140 #(#snake_idents: T,)*
141 ) -> #table_name<T> {
142 #table_name {
143 #(#snake_idents,)*
144 }
145 }
146
147 #[doc = #doc_closure]
148 #[inline]
149 #vis fn from_closure<F: FnMut(#name)->T>(mut func: F) -> #table_name<T> {
150 #table_name {
151 #(#closure_fields)*
152 }
153 }
154
155 #[doc = #doc_transform]
156 #[inline]
157 #vis fn transform<U, F: FnMut(#name, &T)->U>(&self, mut func: F) -> #table_name<U> {
158 #table_name {
159 #(#transform_fields)*
160 }
161 }
162
163 }
164
165 #[automatically_derived]
166 impl<T> ::core::ops::Index<#name> for #table_name<T> {
167 type Output = T;
168
169 #[inline]
170 fn index(&self, idx: #name) -> &T {
171 match idx {
172 #(#get_matches)*
173 #(#disabled_matches)*
174 }
175 }
176 }
177
178 #[automatically_derived]
179 impl<T> ::core::ops::IndexMut<#name> for #table_name<T> {
180 #[inline]
181 fn index_mut(&mut self, idx: #name) -> &mut T {
182 match idx {
183 #(#get_matches_mut)*
184 #(#disabled_matches)*
185 }
186 }
187 }
188
189 #[automatically_derived]
190 impl<T> #table_name<::core::option::Option<T>> {
191 #[doc = #doc_option_all]
192 #[inline]
193 #vis fn all(self) -> ::core::option::Option<#table_name<T>> {
194 if let #table_name {
195 #(#snake_idents: ::core::option::Option::Some(#snake_idents),)*
196 } = self {
197 ::core::option::Option::Some(#table_name {
198 #(#snake_idents,)*
199 })
200 } else {
201 ::core::option::Option::None
202 }
203 }
204 }
205
206 #[automatically_derived]
207 impl<T, E> #table_name<::core::result::Result<T, E>> {
208 #[doc = #doc_result_all_ok]
209 #[inline]
210 #vis fn all_ok(self) -> ::core::result::Result<#table_name<T>, E> {
211 ::core::result::Result::Ok(#table_name {
212 #(#snake_idents: self.#snake_idents?,)*
213 })
214 }
215 }
216 })
217}