derive_more_impl/ops/
mod.rs

1//! Implementations of [`ops`]-related derive macros.
2//!
3//! [`ops`]: std::ops
4
5#[cfg(feature = "add")]
6pub(crate) mod add;
7#[cfg(feature = "add_assign")]
8pub(crate) mod add_assign;
9#[cfg(feature = "mul")]
10pub(crate) mod mul;
11#[cfg(feature = "mul_assign")]
12pub(crate) mod mul_assign;
13
14use proc_macro2::TokenStream;
15use quote::{format_ident, quote, ToTokens};
16use syn::parse_quote;
17
18#[cfg(doc)]
19use crate::utils::attr;
20use crate::utils::{
21    pattern_matching::FieldsExt as _, structural_inclusion::TypeExt as _,
22    GenericsSearch, HashSet,
23};
24
25/// Indices of [`syn::Field`]s marked with an [`attr::Skip`].
26type SkippedFields = HashSet<usize>;
27
#[cfg(any(feature = "add_assign", feature = "mul_assign"))]
/// Expansion of a macro for generating a structural trait implementation with a `&mut self` method
/// receiver for an enum or a struct.
struct AssignStructuralExpansion<'i> {
    /// [`syn::Ident`] of the implemented trait.
    ///
    /// [`syn::Ident`]: struct@syn::Ident
    trait_ty: syn::Ident,

    /// [`syn::Ident`] of the implemented method in trait.
    ///
    /// [`syn::Ident`]: struct@syn::Ident
    method_ident: syn::Ident,

    /// [`syn::Ident`] and [`syn::Generics`] of the implementor enum/struct.
    ///
    /// [`syn::Ident`]: struct@syn::Ident
    self_ty: (&'i syn::Ident, &'i syn::Generics),

    /// Variants of the enum/struct to be used in this [`AssignStructuralExpansion`].
    ///
    /// Each entry consists of:
    /// - the variant's [`syn::Ident`] (`None` when `Self` is matched directly, i.e. for a struct);
    /// - the [`syn::Fields`] of the variant/struct;
    /// - the indices of the fields marked with an [`attr::Skip`].
    ///
    /// [`syn::Ident`]: struct@syn::Ident
    variants: Vec<(Option<&'i syn::Ident>, &'i syn::Fields, SkippedFields)>,

    /// Indicator whether this expansion is for an enum.
    is_enum: bool,
}
53
#[cfg(any(feature = "add_assign", feature = "mul_assign"))]
impl AssignStructuralExpansion<'_> {
    /// Generates body of the method implementation for this [`AssignStructuralExpansion`].
    ///
    /// The generated body matches `(self, __rhs)` variant by variant and, in each arm, applies the
    /// trait method to every non-skipped field of `self` with the corresponding field of `__rhs`.
    /// For enums with more than one variant, a trailing `_ => {}` arm leaves `self` untouched when
    /// the variants of the operands differ.
    fn body(&self) -> TokenStream {
        // TODO: Try remove once MSRV is bumped up.
        // Special case: empty enum. This mirrors `StructuralExpansion::body()`: an arm-less `match`
        // on the `(self, __rhs)` tuple is not accepted as exhaustive by older compilers, while an
        // arm-less `match` on the (uninhabited) `Self` value itself is.
        if self.is_enum && self.variants.is_empty() {
            return quote! { match __rhs {} };
        }

        let method_path = {
            let trait_ty = &self.trait_ty;
            let method_ident = &self.method_ident;

            quote! { derive_more::core::ops::#trait_ty::#method_ident }
        };

        let match_arms = self
            .variants
            .iter()
            .map(|(variant, all_fields, skipped_fields)| {
                let variant = variant.map(|variant| quote! { :: #variant });
                let self_pat = all_fields.non_exhaustive_arm_pattern("__self_", skipped_fields);
                let rhs_pat = all_fields.non_exhaustive_arm_pattern("__rhs_", skipped_fields);

                // Skipped fields are not bound by the patterns above, so only the remaining
                // ones get the in-place operation applied.
                let fields_exprs = (0..all_fields.len())
                    .filter(|num| !skipped_fields.contains(num))
                    .map(|num| {
                        let self_val = format_ident!("__self_{num}");
                        let rhs_val = format_ident!("__rhs_{num}");

                        quote! { #method_path(#self_val, #rhs_val); }
                    });

                quote! {
                    (Self #variant #self_pat, Self #variant #rhs_pat) => { #( #fields_exprs )* }
                }
            })
            .collect::<Vec<_>>();

        // Mismatching variants are only possible for enums with several variants; in that case
        // the operation is a no-op.
        let wrong_variant_arm = (self.is_enum && match_arms.len() > 1).then(|| {
            quote! { _ => {} }
        });

        quote! {
            match (self, __rhs) {
                #( #match_arms )*
                #wrong_variant_arm
            }
        }
    }
}
100
#[cfg(any(feature = "add_assign", feature = "mul_assign"))]
impl ToTokens for AssignStructuralExpansion<'_> {
    /// Emits the `impl` of the assigning trait (with a `&mut self` method) for the implementor.
    ///
    /// Every non-skipped field type that mentions the implementor's generics, and does not contain
    /// `Self` or the implementor type itself, gets a `where` bound on the implemented trait.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let Self {
            trait_ty,
            method_ident,
            ..
        } = self;
        let (ident, original_generics) = self.self_ty;

        // The implementor type spelled out with its generic parameters, and the `Self` keyword as
        // a type: field types containing either of them are recursive and get no bound.
        let implementor_ty: syn::Type = {
            let (_, ty_generics, _) = original_generics.split_for_impl();
            parse_quote! { #ident #ty_generics }
        };
        let self_keyword_ty: syn::Type = parse_quote! { Self };

        let search = GenericsSearch::from(original_generics);
        let mut generics = original_generics.clone();
        for (_, fields, skipped) in &self.variants {
            for (idx, field) in fields.iter().enumerate() {
                if skipped.contains(&idx) {
                    continue;
                }
                let field_ty = &field.ty;
                if !search.any_in(field_ty) {
                    continue;
                }
                if field_ty.contains_type_structurally(&self_keyword_ty)
                    || field_ty.contains_type_structurally(&implementor_ty)
                {
                    continue;
                }
                generics.make_where_clause().predicates.push(parse_quote! {
                    #field_ty: derive_more::core::ops:: #trait_ty
                });
            }
        }
        let (impl_generics, _, where_clause) = generics.split_for_impl();

        let body = self.body();

        tokens.extend(quote! {
            #[allow(private_bounds)]
            #[automatically_derived]
            impl #impl_generics derive_more::core::ops:: #trait_ty for #implementor_ty
                 #where_clause
            {
                #[inline]
                #[track_caller]
                fn #method_ident(&mut self, __rhs: Self) {
                    #body
                }
            }
        });
    }
}
148
#[cfg(any(feature = "add", feature = "mul"))]
/// Expansion of a macro for generating a structural trait implementation with a `self` method
/// receiver for an enum or a struct.
struct StructuralExpansion<'i> {
    /// [`syn::Ident`] of the implemented trait.
    ///
    /// [`syn::Ident`]: struct@syn::Ident
    trait_ty: syn::Ident,

    /// [`syn::Ident`] of the implemented method in trait.
    ///
    /// [`syn::Ident`]: struct@syn::Ident
    method_ident: syn::Ident,

    /// [`syn::Ident`] and [`syn::Generics`] of the implementor enum/struct.
    ///
    /// [`syn::Ident`]: struct@syn::Ident
    self_ty: (&'i syn::Ident, &'i syn::Generics),

    /// Variants of the enum/struct to be used in this [`StructuralExpansion`].
    ///
    /// Each entry consists of:
    /// - the variant's [`syn::Ident`] (`None` when `Self` is matched directly, i.e. for a struct);
    /// - the [`syn::Fields`] of the variant/struct;
    /// - the indices of the fields marked with an [`attr::Skip`].
    ///
    /// [`syn::Ident`]: struct@syn::Ident
    variants: Vec<(Option<&'i syn::Ident>, &'i syn::Fields, SkippedFields)>,

    /// Indicator whether this expansion is for an enum.
    is_enum: bool,
}
174
#[cfg(any(feature = "add", feature = "mul"))]
impl StructuralExpansion<'_> {
    /// Builds the body of the generated `self`-consuming method.
    ///
    /// The body matches `(self, __rhs)` variant by variant:
    /// - arms with fields combine both operands field-wise via the trait method (wrapped into
    ///   `Ok(..)` for enums);
    /// - unit arms produce a `BinaryError::Unit` error;
    /// - for enums with several variants, a trailing catch-all arm produces a
    ///   `BinaryError::Mismatch` error.
    fn body(&self) -> TokenStream {
        // TODO: Try remove once MSRV is bumped up.
        // Special case: empty enum.
        if self.is_enum && self.variants.is_empty() {
            return quote! { match self {} };
        }

        let Self {
            trait_ty,
            method_ident,
            ..
        } = self;
        let op_name = method_ident.to_string();
        let op_path: syn::Path = parse_quote! { derive_more::core::ops::#trait_ty::#method_ident };

        let mut arms = Vec::with_capacity(self.variants.len());
        for (variant_ident, fields, skipped) in &self.variants {
            let variant = variant_ident.map(|ident| quote! { :: #ident });
            let lhs_pat = fields.exhaustive_arm_pattern("__self_");
            let rhs_pat = fields.exhaustive_arm_pattern("__rhs_");

            let outcome = match fields {
                syn::Fields::Unit => quote! {
                    derive_more::core::result::Result::Err(derive_more::BinaryError::Unit(
                        derive_more::UnitError::new(#op_name)
                    ))
                },
                _ => {
                    let combined = fields.arm_expr(&op_path, skipped);
                    if self.is_enum {
                        quote! { derive_more::core::result::Result::Ok(Self #variant #combined) }
                    } else {
                        quote! { Self #variant #combined }
                    }
                }
            };

            arms.push(quote! {
                (Self #variant #lhs_pat, Self #variant #rhs_pat) => #outcome,
            });
        }

        // Operands of different variants can only occur in enums with more than one variant.
        let mismatch_arm = if self.is_enum && arms.len() > 1 {
            Some(quote! {
                _ => derive_more::core::result::Result::Err(derive_more::BinaryError::Mismatch(
                    derive_more::WrongVariantError::new(#op_name)
                )),
            })
        } else {
            None
        };

        quote! {
            match (self, __rhs) {
                #( #arms )*
                #mismatch_arm
            }
        }
    }
}
238
#[cfg(any(feature = "add", feature = "mul"))]
impl ToTokens for StructuralExpansion<'_> {
    /// Emits the `impl` of the binary-operator trait (with a `self`-consuming method) for the
    /// implementor.
    ///
    /// Every non-skipped field type that mentions the implementor's generics, and does not contain
    /// `Self` or the implementor type itself, is bounded by the trait with `Output` equal to that
    /// field type.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let Self {
            trait_ty,
            method_ident,
            ..
        } = self;
        let (ident, original_generics) = self.self_ty;

        // The implementor type spelled out with its generic parameters, and the `Self` keyword as
        // a type: field types containing either of them are recursive and get no bound.
        let implementor_ty: syn::Type = {
            let (_, ty_generics, _) = original_generics.split_for_impl();
            parse_quote! { #ident #ty_generics }
        };
        let self_keyword_ty: syn::Type = parse_quote! { Self };

        // Enums may fail on unit or mismatching variants, hence return a `Result`.
        let output_ty: syn::Type = if self.is_enum {
            parse_quote! { derive_more::core::result::Result<#self_keyword_ty, derive_more::BinaryError> }
        } else {
            self_keyword_ty.clone()
        };

        let search = GenericsSearch::from(original_generics);
        let mut generics = original_generics.clone();
        for (_, fields, skipped) in &self.variants {
            for (idx, field) in fields.iter().enumerate() {
                if skipped.contains(&idx) {
                    continue;
                }
                let field_ty = &field.ty;
                if !search.any_in(field_ty) {
                    continue;
                }
                if field_ty.contains_type_structurally(&self_keyword_ty)
                    || field_ty.contains_type_structurally(&implementor_ty)
                {
                    continue;
                }
                generics.make_where_clause().predicates.push(parse_quote! {
                    #field_ty: derive_more::core::ops:: #trait_ty <Output = #field_ty>
                });
            }
        }
        let (impl_generics, _, where_clause) = generics.split_for_impl();

        let body = self.body();

        tokens.extend(quote! {
            #[allow(private_bounds)]
            #[automatically_derived]
            impl #impl_generics derive_more::core::ops:: #trait_ty for #implementor_ty
                 #where_clause
            {
                type Output = #output_ty;

                #[inline]
                #[track_caller]
                fn #method_ident(self, __rhs: Self) -> Self::Output {
                    #body
                }
            }
        });
    }
}
294
#[cfg(any(feature = "add", feature = "mul"))]
/// Extension of [`syn::Fields`] used by a [`StructuralExpansion`].
trait StructuralExpansionFieldsExt {
    /// Generates a resulting expression with these [`syn::Fields`] in a matched arm of a `match`
    /// expression, by applying the specified method.
    ///
    /// The fields of both operands are expected to be bound as `__self_{N}` and `__rhs_{N}`, with
    /// `N` being the field's index.
    /// - Every field not listed in `skipped_indices` becomes
    ///   `method_path(__self_{N}, __rhs_{N})`.
    /// - Every field listed in `skipped_indices` keeps its `__self_{N}` value as is.
    ///
    /// For [`syn::Fields::Unit`] nothing is generated.
    fn arm_expr(
        &self,
        method_path: &syn::Path,
        skipped_indices: &SkippedFields,
    ) -> TokenStream;
}
306
#[cfg(any(feature = "add", feature = "mul"))]
impl StructuralExpansionFieldsExt for syn::Fields {
    fn arm_expr(
        &self,
        method_path: &syn::Path,
        skipped_indices: &SkippedFields,
    ) -> TokenStream {
        // Produces the value of the `num`-th field: either `self`'s value passed through (when the
        // field is skipped) or both operands combined via `method_path`.
        let combine = |num: usize| {
            let lhs = format_ident!("__self_{num}");
            if skipped_indices.contains(&num) {
                quote! { #lhs }
            } else {
                let rhs = format_ident!("__rhs_{num}");
                quote! { #method_path(#lhs, #rhs) }
            }
        };

        match self {
            Self::Unit => quote! {},
            Self::Unnamed(unnamed) => {
                let values = (0..unnamed.unnamed.len()).map(combine);
                quote! {( #( #values , )* )}
            }
            Self::Named(named) => {
                let entries = named.named.iter().enumerate().map(|(num, field)| {
                    let name = &field.ident;
                    let value = combine(num);
                    quote! { #name: #value }
                });
                quote! {{ #( #entries , )* }}
            }
        }
    }
}