derive_more_impl/ops/add.rs

1//! Implementation of [`ops::Add`]-like derive macros.
2
3#[cfg(doc)]
4use std::ops;
5
6use proc_macro2::TokenStream;
7use quote::{format_ident, ToTokens as _};
8use syn::spanned::Spanned as _;
9
10use super::{SkippedFields, StructuralExpansion};
11use crate::utils::attr::{self, ParseMultiple as _};
12
13/// Expands an [`ops::Add`]-like derive macro.
14///
15/// Available macros:
16/// - [`Add`](ops::Add)
17/// - [`BitAnd`](ops::BitAnd)
18/// - [`BitOr`](ops::BitOr)
19/// - [`BitXor`](ops::BitXor)
20/// - [`Sub`](ops::Sub)
21pub fn expand(input: &syn::DeriveInput, trait_name: &str) -> syn::Result<TokenStream> {
22    let trait_name = normalize_trait_name(trait_name);
23    let attr_name = format_ident!("{}", trait_name_to_attribute_name(trait_name));
24
25    let mut variants = vec![];
26    match &input.data {
27        syn::Data::Struct(data) => {
28            if let Some(skip) = attr::Skip::parse_attrs(&input.attrs, &attr_name)? {
29                return Err(syn::Error::new(
30                    skip.span,
31                    format!(
32                        "`#[{attr_name}({})]` attribute can be placed only on struct fields",
33                        skip.item.name(),
34                    ),
35                ));
36            } else if matches!(data.fields, syn::Fields::Unit) {
37                return Err(syn::Error::new(
38                    data.struct_token.span(),
39                    format!("`{trait_name}` cannot be derived for unit structs"),
40                ));
41            }
42            let mut skipped_fields = SkippedFields::default();
43            for (n, field) in data.fields.iter().enumerate() {
44                if attr::Skip::parse_attrs(&field.attrs, &attr_name)?.is_some() {
45                    _ = skipped_fields.insert(n);
46                }
47            }
48            if data.fields.len() == skipped_fields.len() {
49                return Err(syn::Error::new(
50                    data.struct_token.span(),
51                    format!(
52                        "`{trait_name}` cannot be derived for structs with all the fields being \
53                         skipped",
54                    ),
55                ));
56            }
57            variants.push((None, &data.fields, skipped_fields));
58        }
59        syn::Data::Enum(data) => {
60            if let Some(skip) = attr::Skip::parse_attrs(&input.attrs, &attr_name)? {
61                return Err(syn::Error::new(
62                    skip.span,
63                    format!(
64                        "`#[{attr_name}({})]` attribute can be placed only on enum fields",
65                        skip.item.name(),
66                    ),
67                ));
68            }
69            for variant in &data.variants {
70                if let Some(skip) = attr::Skip::parse_attrs(&variant.attrs, &attr_name)?
71                {
72                    return Err(syn::Error::new(
73                        skip.span,
74                        format!(
75                            "`#[{attr_name}({})]` attribute can be placed only on variant fields",
76                            skip.item.name(),
77                        ),
78                    ));
79                }
80                let mut skipped_fields = SkippedFields::default();
81                for (n, field) in variant.fields.iter().enumerate() {
82                    if attr::Skip::parse_attrs(&field.attrs, &attr_name)?.is_some() {
83                        _ = skipped_fields.insert(n);
84                    }
85                }
86                if !matches!(variant.fields, syn::Fields::Unit)
87                    && variant.fields.len() == skipped_fields.len()
88                {
89                    return Err(syn::Error::new(
90                        variant.span(),
91                        format!(
92                            "`{trait_name}` cannot be derived for enum with all the fields being \
93                             skipped in its variants",
94                        ),
95                    ));
96                }
97                variants.push((Some(&variant.ident), &variant.fields, skipped_fields));
98            }
99        }
100        syn::Data::Union(data) => {
101            return Err(syn::Error::new(
102                data.union_token.span(),
103                format!("`{trait_name}` cannot be derived for unions"),
104            ));
105        }
106    }
107
108    Ok(StructuralExpansion {
109        trait_ty: format_ident!("{trait_name}"),
110        method_ident: format_ident!("{}", trait_name_to_method_name(trait_name)),
111        self_ty: (&input.ident, &input.generics),
112        variants,
113        is_enum: matches!(input.data, syn::Data::Enum(_)),
114    }
115    .into_token_stream())
116}
117
/// Supported [`ops::Add`]-like traits, each paired with its method name.
///
/// Kept as a single table so the list of traits cannot drift between the lookups below.
const SUPPORTED_TRAITS: [(&str, &str); 5] = [
    ("Add", "add"),
    ("BitAnd", "bitand"),
    ("BitOr", "bitor"),
    ("BitXor", "bitxor"),
    ("Sub", "sub"),
];

/// Looks up the [`SUPPORTED_TRAITS`] entry for the provided trait `name`.
///
/// # Panics
///
/// If `name` is not a supported trait name. Callers are expected to pass only the names of the
/// derive macros handled by this module, so hitting this panic indicates an internal bug.
fn supported_trait(name: &str) -> (&'static str, &'static str) {
    SUPPORTED_TRAITS
        .iter()
        .copied()
        .find(|&(trait_name, _)| trait_name == name)
        .unwrap_or_else(|| unreachable!("unsupported `ops::Add`-like trait: `{name}`"))
}

/// Matches the provided derive macro `name` to appropriate actual trait name.
///
/// The returned name is `'static`, so it is not tied to the lifetime of the borrowed `name`.
///
/// # Panics
///
/// If `name` is not a supported trait name.
fn normalize_trait_name(name: &str) -> &'static str {
    supported_trait(name).0
}

/// Matches the provided [`ops::Add`]-like trait `name` to its attribute's name.
///
/// The helper attribute is named after the trait's method (e.g. `add` for `Add`).
///
/// # Panics
///
/// If `name` is not a supported trait name.
fn trait_name_to_attribute_name(name: &str) -> &'static str {
    trait_name_to_method_name(name)
}

/// Matches the provided [`ops::Add`]-like trait `name` to its method name.
///
/// # Panics
///
/// If `name` is not a supported trait name.
fn trait_name_to_method_name(name: &str) -> &'static str {
    supported_trait(name).1
}