// sea_query_derive/iden/write_arm.rs

1use std::convert::TryFrom;
2use std::marker::PhantomData;
3
4use heck::ToSnakeCase;
5use proc_macro2::{Span, TokenStream};
6use quote::{ToTokens, TokenStreamExt, quote};
7use syn::{Error, Fields, FieldsNamed, Ident, Variant};
8
9use super::{
10    DeriveIden, DeriveIdenStatic, attr::IdenAttr, error::ErrorMsg, find_attr, is_static_iden,
11};
12
13pub(crate) trait WriteArm {
14    fn variant(variant: TokenStream, name: TokenStream) -> TokenStream;
15    fn flattened(variant: TokenStream, name: &Ident) -> TokenStream;
16}
17
18impl WriteArm for DeriveIden {
19    fn variant(variant: TokenStream, name: TokenStream) -> TokenStream {
20        quote! { Self::#variant => #name }
21    }
22
23    fn flattened(variant: TokenStream, name: &Ident) -> TokenStream {
24        quote! { Self::#variant => #name.unquoted() }
25    }
26}
27
28impl WriteArm for DeriveIdenStatic {
29    fn variant(variant: TokenStream, name: TokenStream) -> TokenStream {
30        quote! { Self::#variant => #name }
31    }
32
33    fn flattened(variant: TokenStream, name: &Ident) -> TokenStream {
34        quote! { Self::#variant => #name.as_str() }
35    }
36}
37
/// One enum variant, prepared for code generation by an iden derive.
///
/// `T` is the arm-rendering strategy (the impls on this type require
/// `T: WriteArm`); it is carried only at the type level via `PhantomData`.
pub(crate) struct IdenVariant<'a, T> {
    // The variant's name.
    ident: &'a Ident,
    // The variant's fields (named, unnamed, or unit).
    fields: &'a Fields,
    // Name used instead of the snake_cased variant name when the variant is `Table`.
    table_name: &'a str,
    // Parsed iden attribute on the variant, if one was present.
    attr: Option<IdenAttr>,
    _p: PhantomData<T>,
}
45
46impl<'a, T> TryFrom<(&'a str, &'a Variant)> for IdenVariant<'a, T>
47where
48    T: WriteArm,
49{
50    type Error = Error;
51
52    fn try_from((table_name, value): (&'a str, &'a Variant)) -> Result<Self, Self::Error> {
53        let Variant {
54            ident,
55            fields,
56            attrs,
57            ..
58        } = value;
59        let attr = find_attr(attrs).map(IdenAttr::try_from).transpose()?;
60
61        Self::new(ident, fields, table_name, attr)
62    }
63}
64
65impl<T> ToTokens for IdenVariant<'_, T>
66where
67    T: WriteArm,
68{
69    fn to_tokens(&self, tokens: &mut TokenStream) {
70        match self.fields {
71            Fields::Named(named) => self.to_tokens_from_named(named, tokens),
72            Fields::Unnamed(_) => self.to_tokens_from_unnamed(tokens),
73            Fields::Unit => self.to_tokens_from_unit(tokens),
74        }
75    }
76}
77
78impl<'a, T> IdenVariant<'a, T>
79where
80    T: WriteArm,
81{
82    fn new(
83        ident: &'a Ident,
84        fields: &'a Fields,
85        table_name: &'a str,
86        attr: Option<IdenAttr>,
87    ) -> syn::Result<Self> {
88        let unsupported_error = Err(Error::new_spanned(
89            fields,
90            ErrorMsg::UnsupportedFlattenTarget,
91        ));
92        // sanity check to not have flatten on a unit variant, or variants with more than 1 field
93        if attr == Some(IdenAttr::Flatten) {
94            match fields {
95                Fields::Named(n) => {
96                    if n.named.len() != 1 {
97                        return unsupported_error;
98                    }
99                }
100                Fields::Unnamed(u) => {
101                    if u.unnamed.len() != 1 {
102                        return unsupported_error;
103                    }
104                }
105                Fields::Unit => return unsupported_error,
106            }
107        }
108
109        Ok(Self {
110            ident,
111            fields,
112            table_name,
113            attr,
114            _p: PhantomData::<T>,
115        })
116    }
117
118    fn to_tokens_from_named(&self, named: &FieldsNamed, tokens: &mut TokenStream) {
119        let ident = self.ident;
120
121        let match_arm = if self.attr == Some(IdenAttr::Flatten) {
122            // indexing is safe because len is guaranteed to be 1 from the constructor.
123            let field = &named.named[0];
124            // Unwrapping the ident is also safe because a named field always has an ident.
125            let capture = field.ident.as_ref().unwrap();
126            let variant = quote! { #ident{#capture} };
127            T::flattened(variant, capture)
128        } else {
129            let variant = quote! { #ident{..} };
130            self.write_variant_name(variant)
131        };
132
133        tokens.append_all(match_arm)
134    }
135
136    fn to_tokens_from_unnamed(&self, tokens: &mut TokenStream) {
137        let ident = self.ident;
138
139        let match_arm = if self.attr == Some(IdenAttr::Flatten) {
140            // The case where unnamed fields length is not 1 is handled by new
141            let capture = Delegated.into();
142            let variant = quote! { #ident(#capture) };
143            T::flattened(variant, &capture)
144        } else {
145            let variant = quote! { #ident(..) };
146            self.write_variant_name(variant)
147        };
148
149        tokens.append_all(match_arm)
150    }
151
152    fn to_tokens_from_unit(&self, tokens: &mut TokenStream) {
153        let ident = self.ident;
154        let variant = quote! { #ident };
155
156        tokens.append_all(self.write_variant_name(variant))
157    }
158
159    fn table_or_snake_case(&self) -> String {
160        if self.ident == "Table" {
161            self.table_name.to_owned()
162        } else {
163            self.ident.to_string().to_snake_case()
164        }
165    }
166
167    fn write_variant_name(&self, variant: TokenStream) -> TokenStream {
168        let name = self
169            .attr
170            .as_ref()
171            .map(|a| match a {
172                IdenAttr::Rename(name) => quote! { #name },
173                IdenAttr::Method(method) => quote! { self.#method() },
174                IdenAttr::Flatten => unreachable!(),
175            })
176            .unwrap_or_else(|| {
177                let name = self.table_or_snake_case();
178                quote! { #name }
179            });
180
181        T::variant(variant, name)
182    }
183
184    pub(crate) fn is_static_iden(&self) -> bool {
185        let name: String = match &self.attr {
186            Some(a) => match a {
187                IdenAttr::Rename(name) => name.to_owned(),
188                IdenAttr::Method(_) => return false,
189                IdenAttr::Flatten => return false,
190            },
191            None => self.table_or_snake_case(),
192        };
193
194        is_static_iden(&name)
195    }
196}
197
198struct Delegated;
199
200impl From<Delegated> for Ident {
201    fn from(_: Delegated) -> Self {
202        Ident::new("delegated", Span::call_site())
203    }
204}