derive_more_impl/ops/
add.rs1#[cfg(doc)]
4use std::ops;
5
6use proc_macro2::TokenStream;
7use quote::{format_ident, ToTokens as _};
8use syn::spanned::Spanned as _;
9
10use super::{SkippedFields, StructuralExpansion};
11use crate::utils::attr::{self, ParseMultiple as _};
12
13pub fn expand(input: &syn::DeriveInput, trait_name: &str) -> syn::Result<TokenStream> {
22 let trait_name = normalize_trait_name(trait_name);
23 let attr_name = format_ident!("{}", trait_name_to_attribute_name(trait_name));
24
25 let mut variants = vec![];
26 match &input.data {
27 syn::Data::Struct(data) => {
28 if let Some(skip) = attr::Skip::parse_attrs(&input.attrs, &attr_name)? {
29 return Err(syn::Error::new(
30 skip.span,
31 format!(
32 "`#[{attr_name}({})]` attribute can be placed only on struct fields",
33 skip.item.name(),
34 ),
35 ));
36 } else if matches!(data.fields, syn::Fields::Unit) {
37 return Err(syn::Error::new(
38 data.struct_token.span(),
39 format!("`{trait_name}` cannot be derived for unit structs"),
40 ));
41 }
42 let mut skipped_fields = SkippedFields::default();
43 for (n, field) in data.fields.iter().enumerate() {
44 if attr::Skip::parse_attrs(&field.attrs, &attr_name)?.is_some() {
45 _ = skipped_fields.insert(n);
46 }
47 }
48 if data.fields.len() == skipped_fields.len() {
49 return Err(syn::Error::new(
50 data.struct_token.span(),
51 format!(
52 "`{trait_name}` cannot be derived for structs with all the fields being \
53 skipped",
54 ),
55 ));
56 }
57 variants.push((None, &data.fields, skipped_fields));
58 }
59 syn::Data::Enum(data) => {
60 if let Some(skip) = attr::Skip::parse_attrs(&input.attrs, &attr_name)? {
61 return Err(syn::Error::new(
62 skip.span,
63 format!(
64 "`#[{attr_name}({})]` attribute can be placed only on enum fields",
65 skip.item.name(),
66 ),
67 ));
68 }
69 for variant in &data.variants {
70 if let Some(skip) = attr::Skip::parse_attrs(&variant.attrs, &attr_name)?
71 {
72 return Err(syn::Error::new(
73 skip.span,
74 format!(
75 "`#[{attr_name}({})]` attribute can be placed only on variant fields",
76 skip.item.name(),
77 ),
78 ));
79 }
80 let mut skipped_fields = SkippedFields::default();
81 for (n, field) in variant.fields.iter().enumerate() {
82 if attr::Skip::parse_attrs(&field.attrs, &attr_name)?.is_some() {
83 _ = skipped_fields.insert(n);
84 }
85 }
86 if !matches!(variant.fields, syn::Fields::Unit)
87 && variant.fields.len() == skipped_fields.len()
88 {
89 return Err(syn::Error::new(
90 variant.span(),
91 format!(
92 "`{trait_name}` cannot be derived for enum with all the fields being \
93 skipped in its variants",
94 ),
95 ));
96 }
97 variants.push((Some(&variant.ident), &variant.fields, skipped_fields));
98 }
99 }
100 syn::Data::Union(data) => {
101 return Err(syn::Error::new(
102 data.union_token.span(),
103 format!("`{trait_name}` cannot be derived for unions"),
104 ));
105 }
106 }
107
108 Ok(StructuralExpansion {
109 trait_ty: format_ident!("{trait_name}"),
110 method_ident: format_ident!("{}", trait_name_to_method_name(trait_name)),
111 self_ty: (&input.ident, &input.generics),
112 variants,
113 is_enum: matches!(input.data, syn::Data::Enum(_)),
114 }
115 .into_token_stream())
116}
117
/// Matches the provided derive macro `name` to the actual trait name it expands.
///
/// The result is a `&'static str` with the same contents as `name`. This lets it outlive
/// the borrowed input.
///
/// # Panics
///
/// If `name` is not one of `Add`, `BitAnd`, `BitOr`, `BitXor` or `Sub`. The panic message
/// includes the offending name.
fn normalize_trait_name(name: &str) -> &'static str {
    match name {
        "Add" => "Add",
        "BitAnd" => "BitAnd",
        "BitOr" => "BitOr",
        "BitXor" => "BitXor",
        "Sub" => "Sub",
        // NOTE(review): only reachable if a derive entry point is wired up with an
        // unsupported trait name, i.e. an internal bug rather than a user error.
        _ => unreachable!("`{}` is not an `ops::Add`-like trait", name),
    }
}
129
/// Matches the provided trait `name` to the name of the attribute that configures its derive
/// (e.g. marking fields to skip).
///
/// The attribute is named after the trait's method, so this delegates to
/// [`trait_name_to_method_name`].
///
/// # Panics
///
/// Same as [`trait_name_to_method_name`].
fn trait_name_to_attribute_name(name: &str) -> &'static str {
    trait_name_to_method_name(name)
}

/// Matches the provided trait `name` to the name of its operator method
/// (e.g. `add` for [`ops::Add`]).
///
/// # Panics
///
/// If `name` is not one of `Add`, `BitAnd`, `BitOr`, `BitXor` or `Sub`. The panic message
/// includes the offending name.
fn trait_name_to_method_name(name: &str) -> &'static str {
    match name {
        "Add" => "add",
        "BitAnd" => "bitand",
        "BitOr" => "bitor",
        "BitXor" => "bitxor",
        "Sub" => "sub",
        // `expand` only passes names already accepted by `normalize_trait_name`, so hitting
        // this arm indicates an internal bug.
        _ => unreachable!("`{}` is not an `ops::Add`-like trait", name),
    }
}