bytemuck_derive/
lib.rs

1//! Derive macros for [bytemuck](https://docs.rs/bytemuck) traits.
2
3extern crate proc_macro;
4
5mod traits;
6
7use proc_macro2::TokenStream;
8use quote::quote;
9use syn::{parse_macro_input, DeriveInput, Result};
10
11use crate::traits::{
12  bytemuck_crate_name, AnyBitPattern, CheckedBitPattern, Contiguous, Derivable,
13  NoUninit, Pod, TransparentWrapper, Zeroable,
14};
15
16/// Derive the `Pod` trait for a struct
17///
/// The macro ensures that the struct follows all the safety requirements
19/// for the `Pod` trait.
20///
21/// The following constraints need to be satisfied for the macro to succeed
22///
23/// - All fields in the struct must implement `Pod`
24/// - The struct must be `#[repr(C)]` or `#[repr(transparent)]`
25/// - The struct must not contain any padding bytes
26/// - The struct contains no generic parameters, if it is not
27///   `#[repr(transparent)]`
28///
29/// ## Examples
30///
31/// ```rust
32/// # use std::marker::PhantomData;
33/// # use bytemuck_derive::{Pod, Zeroable};
34/// #[derive(Copy, Clone, Pod, Zeroable)]
35/// #[repr(C)]
36/// struct Test {
37///   a: u16,
38///   b: u16,
39/// }
40///
41/// #[derive(Copy, Clone, Pod, Zeroable)]
42/// #[repr(transparent)]
43/// struct Generic<A, B> {
44///   a: A,
45///   b: PhantomData<B>,
46/// }
47/// ```
48///
49/// If the struct is generic, it must be `#[repr(transparent)]` also.
50///
51/// ```compile_fail
52/// # use bytemuck::{Pod, Zeroable};
53/// # use std::marker::PhantomData;
54/// #[derive(Copy, Clone, Pod, Zeroable)]
55/// #[repr(C)] // must be `#[repr(transparent)]`
56/// struct Generic<A> {
57///   a: A,
58/// }
59/// ```
60///
61/// If the struct is generic and `#[repr(transparent)]`, then it is only `Pod`
62/// when all of its generics are `Pod`, not just its fields.
63///
64/// ```
65/// # use bytemuck::{Pod, Zeroable};
66/// # use std::marker::PhantomData;
67/// #[derive(Copy, Clone, Pod, Zeroable)]
68/// #[repr(transparent)]
69/// struct Generic<A, B> {
70///   a: A,
71///   b: PhantomData<B>,
72/// }
73///
74/// let _: u32 = bytemuck::cast(Generic { a: 4u32, b: PhantomData::<u32> });
75/// ```
76///
77/// ```compile_fail
78/// # use bytemuck::{Pod, Zeroable};
79/// # use std::marker::PhantomData;
80/// # #[derive(Copy, Clone, Pod, Zeroable)]
81/// # #[repr(transparent)]
82/// # struct Generic<A, B> {
83/// #   a: A,
84/// #   b: PhantomData<B>,
85/// # }
86/// struct NotPod;
87///
88/// let _: u32 = bytemuck::cast(Generic { a: 4u32, b: PhantomData::<NotPod> });
89/// ```
90#[proc_macro_derive(Pod, attributes(bytemuck))]
91pub fn derive_pod(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
92  let expanded =
93    derive_marker_trait::<Pod>(parse_macro_input!(input as DeriveInput));
94
95  proc_macro::TokenStream::from(expanded)
96}
97
98/// Derive the `AnyBitPattern` trait for a struct
99///
/// The macro ensures that the struct follows all the safety requirements
101/// for the `AnyBitPattern` trait.
102///
103/// The following constraints need to be satisfied for the macro to succeed
104///
/// - All fields in the struct must implement `AnyBitPattern`
106#[proc_macro_derive(AnyBitPattern, attributes(bytemuck))]
107pub fn derive_anybitpattern(
108  input: proc_macro::TokenStream,
109) -> proc_macro::TokenStream {
110  let expanded = derive_marker_trait::<AnyBitPattern>(parse_macro_input!(
111    input as DeriveInput
112  ));
113
114  proc_macro::TokenStream::from(expanded)
115}
116
117/// Derive the `Zeroable` trait for a type.
118///
/// The macro ensures that the type follows all the safety requirements
120/// for the `Zeroable` trait.
121///
122/// The following constraints need to be satisfied for the macro to succeed on a
123/// struct:
124///
125/// - All fields in the struct must implement `Zeroable`
126///
127/// The following constraints need to be satisfied for the macro to succeed on
128/// an enum:
129///
130/// - The enum has an explicit `#[repr(Int)]`, `#[repr(C)]`, or `#[repr(C,
131///   Int)]`.
132/// - The enum has a variant with discriminant 0 (explicitly or implicitly).
133/// - All fields in the variant with discriminant 0 (if any) must implement
134///   `Zeroable`
135///
136/// The macro always succeeds on unions.
137///
138/// ## Example
139///
140/// ```rust
141/// # use bytemuck_derive::{Zeroable};
142/// #[derive(Copy, Clone, Zeroable)]
143/// #[repr(C)]
144/// struct Test {
145///   a: u16,
146///   b: u16,
147/// }
148/// ```
149/// ```rust
150/// # use bytemuck_derive::{Zeroable};
151/// #[derive(Copy, Clone, Zeroable)]
152/// #[repr(i32)]
153/// enum Values {
154///   A = 0,
155///   B = 1,
156///   C = 2,
157/// }
158/// #[derive(Clone, Zeroable)]
159/// #[repr(C)]
160/// enum Implicit {
161///   A(bool, u8, char),
162///   B(String),
163///   C(std::num::NonZeroU8),
164/// }
165/// ```
166///
167/// # Custom bounds
168///
169/// Custom bounds for the derived `Zeroable` impl can be given using the
170/// `#[zeroable(bound = "")]` helper attribute.
171///
172/// Using this attribute additionally opts-in to "perfect derive" semantics,
173/// where instead of adding bounds for each generic type parameter, bounds are
174/// added for each field's type.
175///
176/// ## Examples
177///
178/// ```rust
179/// # use bytemuck::Zeroable;
180/// # use std::marker::PhantomData;
181/// #[derive(Clone, Zeroable)]
182/// #[zeroable(bound = "")]
183/// struct AlwaysZeroable<T> {
184///   a: PhantomData<T>,
185/// }
186///
187/// AlwaysZeroable::<std::num::NonZeroU8>::zeroed();
188/// ```
189/// ```rust
190/// # use bytemuck::{Zeroable};
191/// #[derive(Copy, Clone, Zeroable)]
192/// #[repr(u8)]
193/// #[zeroable(bound = "")]
194/// enum MyOption<T> {
195///   None,
196///   Some(T),
197/// }
198///
199/// assert!(matches!(MyOption::<std::num::NonZeroU8>::zeroed(), MyOption::None));
200/// ```
201///
202/// ```rust,compile_fail
203/// # use bytemuck::Zeroable;
204/// # use std::marker::PhantomData;
205/// #[derive(Clone, Zeroable)]
206/// #[zeroable(bound = "T: Copy")]
207/// struct ZeroableWhenTIsCopy<T> {
208///   a: PhantomData<T>,
209/// }
210///
211/// ZeroableWhenTIsCopy::<String>::zeroed();
212/// ```
213///
214/// The restriction that all fields must be Zeroable is still applied, and this
215/// is enforced using the mentioned "perfect derive" semantics.
216///
217/// ```rust
218/// # use bytemuck::Zeroable;
219/// #[derive(Clone, Zeroable)]
220/// #[zeroable(bound = "")]
221/// struct ZeroableWhenTIsZeroable<T> {
222///   a: T,
223/// }
224/// ZeroableWhenTIsZeroable::<u32>::zeroed();
225/// ```
226///
227/// ```rust,compile_fail
228/// # use bytemuck::Zeroable;
229/// # #[derive(Clone, Zeroable)]
230/// # #[zeroable(bound = "")]
231/// # struct ZeroableWhenTIsZeroable<T> {
232/// #   a: T,
233/// # }
234/// ZeroableWhenTIsZeroable::<String>::zeroed();
235/// ```
236#[proc_macro_derive(Zeroable, attributes(bytemuck, zeroable))]
237pub fn derive_zeroable(
238  input: proc_macro::TokenStream,
239) -> proc_macro::TokenStream {
240  let expanded =
241    derive_marker_trait::<Zeroable>(parse_macro_input!(input as DeriveInput));
242
243  proc_macro::TokenStream::from(expanded)
244}
245
246/// Derive the `NoUninit` trait for a struct or enum
247///
/// The macro ensures that the type follows all the safety requirements
249/// for the `NoUninit` trait.
250///
251/// The following constraints need to be satisfied for the macro to succeed
252/// (the rest of the constraints are guaranteed by the `NoUninit` subtrait
253/// bounds, i.e. the type must be `Sized + Copy + 'static`):
254///
255/// If applied to a struct:
256/// - All fields in the struct must implement `NoUninit`
257/// - The struct must be `#[repr(C)]` or `#[repr(transparent)]`
258/// - The struct must not contain any padding bytes
259/// - The struct must contain no generic parameters
260///
261/// If applied to an enum:
262/// - The enum must be explicit `#[repr(Int)]`, `#[repr(C)]`, or both
263/// - If the enum has fields:
264///   - All fields must implement `NoUninit`
265///   - All variants must not contain any padding bytes
///   - All variants must be of the same size
267///   - There must be no padding bytes between the discriminant and any of the
268///     variant fields
269/// - The enum must contain no generic parameters
270#[proc_macro_derive(NoUninit, attributes(bytemuck))]
271pub fn derive_no_uninit(
272  input: proc_macro::TokenStream,
273) -> proc_macro::TokenStream {
274  let expanded =
275    derive_marker_trait::<NoUninit>(parse_macro_input!(input as DeriveInput));
276
277  proc_macro::TokenStream::from(expanded)
278}
279
280/// Derive the `CheckedBitPattern` trait for a struct or enum.
281///
/// The macro ensures that the type follows all the safety requirements
283/// for the `CheckedBitPattern` trait and derives the required `Bits` type
284/// definition and `is_valid_bit_pattern` method for the type automatically.
285///
286/// The following constraints need to be satisfied for the macro to succeed:
287///
288/// If applied to a struct:
289/// - All fields must implement `CheckedBitPattern`
290/// - The struct must be `#[repr(C)]` or `#[repr(transparent)]`
291/// - The struct must contain no generic parameters
292///
293/// If applied to an enum:
294/// - The enum must be explicit `#[repr(Int)]`
295/// - All fields in variants must implement `CheckedBitPattern`
296/// - The enum must contain no generic parameters
297#[proc_macro_derive(CheckedBitPattern)]
298pub fn derive_maybe_pod(
299  input: proc_macro::TokenStream,
300) -> proc_macro::TokenStream {
301  let expanded = derive_marker_trait::<CheckedBitPattern>(parse_macro_input!(
302    input as DeriveInput
303  ));
304
305  proc_macro::TokenStream::from(expanded)
306}
307
308/// Derive the `TransparentWrapper` trait for a struct
309///
/// The macro ensures that the struct follows all the safety requirements
311/// for the `TransparentWrapper` trait.
312///
313/// The following constraints need to be satisfied for the macro to succeed
314///
315/// - The struct must be `#[repr(transparent)]`
316/// - The struct must contain the `Wrapped` type
317/// - Any ZST fields must be [`Zeroable`][derive@Zeroable].
318///
319/// If the struct only contains a single field, the `Wrapped` type will
/// automatically be determined. If there is more than one field in the struct,
321/// you need to specify the `Wrapped` type using `#[transparent(T)]`. Due to
322/// technical limitations, the type in the `#[transparent(Type)]` needs to be
323/// the exact same token sequence as the corresponding type in the struct
324/// definition.
325///
326/// ## Examples
327///
328/// ```rust
329/// # use bytemuck_derive::TransparentWrapper;
330/// # use std::marker::PhantomData;
331/// #[derive(Copy, Clone, TransparentWrapper)]
332/// #[repr(transparent)]
333/// #[transparent(u16)]
334/// struct Test<T> {
335///   inner: u16,
336///   extra: PhantomData<T>,
337/// }
338/// ```
339///
340/// If the struct contains more than one field, the `Wrapped` type must be
341/// explicitly specified.
342///
343/// ```rust,compile_fail
344/// # use bytemuck_derive::TransparentWrapper;
345/// # use std::marker::PhantomData;
346/// #[derive(Copy, Clone, TransparentWrapper)]
347/// #[repr(transparent)]
348/// // missing `#[transparent(u16)]`
349/// struct Test<T> {
350///   inner: u16,
351///   extra: PhantomData<T>,
352/// }
353/// ```
354///
355/// Any ZST fields must be `Zeroable`.
356///
357/// ```rust,compile_fail
358/// # use bytemuck_derive::TransparentWrapper;
359/// # use std::marker::PhantomData;
360/// struct NonTransparentSafeZST;
361///
362/// #[derive(TransparentWrapper)]
363/// #[repr(transparent)]
364/// #[transparent(u16)]
365/// struct Test<T> {
366///   inner: u16,
367///   extra: PhantomData<T>,
368///   another_extra: NonTransparentSafeZST, // not `Zeroable`
369/// }
370/// ```
371#[proc_macro_derive(TransparentWrapper, attributes(bytemuck, transparent))]
372pub fn derive_transparent(
373  input: proc_macro::TokenStream,
374) -> proc_macro::TokenStream {
375  let expanded = derive_marker_trait::<TransparentWrapper>(parse_macro_input!(
376    input as DeriveInput
377  ));
378
379  proc_macro::TokenStream::from(expanded)
380}
381
382/// Derive the `Contiguous` trait for an enum
383///
/// The macro ensures that the enum follows all the safety requirements
385/// for the `Contiguous` trait.
386///
387/// The following constraints need to be satisfied for the macro to succeed
388///
389/// - The enum must be `#[repr(Int)]`
390/// - The enum must be fieldless
391/// - The enum discriminants must form a contiguous range
392///
393/// ## Example
394///
395/// ```rust
396/// # use bytemuck_derive::{Contiguous};
397///
398/// #[derive(Copy, Clone, Contiguous)]
399/// #[repr(u8)]
400/// enum Test {
401///   A = 0,
402///   B = 1,
403///   C = 2,
404/// }
405/// ```
406#[proc_macro_derive(Contiguous)]
407pub fn derive_contiguous(
408  input: proc_macro::TokenStream,
409) -> proc_macro::TokenStream {
410  let expanded =
411    derive_marker_trait::<Contiguous>(parse_macro_input!(input as DeriveInput));
412
413  proc_macro::TokenStream::from(expanded)
414}
415
/// Derive the `PartialEq` and `Eq` traits for a type
417///
418/// The macro implements `PartialEq` and `Eq` by casting both sides of the
419/// comparison to a byte slice and then compares those.
420///
421/// ## Warning
422///
423/// Since this implements a byte wise comparison, the behavior of floating point
424/// numbers does not match their usual comparison behavior. Additionally other
425/// custom comparison behaviors of the individual fields are also ignored. This
426/// also does not implement `StructuralPartialEq` / `StructuralEq` like
427/// `PartialEq` / `Eq` would. This means you can't pattern match on the values.
428///
429/// ## Examples
430///
431/// ```rust
432/// # use bytemuck_derive::{ByteEq, NoUninit};
433/// #[derive(Copy, Clone, NoUninit, ByteEq)]
434/// #[repr(C)]
435/// struct Test {
436///   a: u32,
437///   b: char,
438///   c: f32,
439/// }
440/// ```
441///
442/// ```rust
443/// # use bytemuck_derive::ByteEq;
444/// # use bytemuck::NoUninit;
445/// #[derive(Copy, Clone, ByteEq)]
446/// #[repr(C)]
447/// struct Test<const N: usize> {
448///   a: [u32; N],
449/// }
450/// unsafe impl<const N: usize> NoUninit for Test<N> {}
451/// ```
452#[proc_macro_derive(ByteEq)]
453pub fn derive_byte_eq(
454  input: proc_macro::TokenStream,
455) -> proc_macro::TokenStream {
456  let input = parse_macro_input!(input as DeriveInput);
457  let crate_name = bytemuck_crate_name(&input);
458  let ident = input.ident;
459  let (impl_generics, ty_generics, where_clause) =
460    input.generics.split_for_impl();
461
462  proc_macro::TokenStream::from(quote! {
463    impl #impl_generics ::core::cmp::PartialEq for #ident #ty_generics #where_clause {
464      #[inline]
465      #[must_use]
466      fn eq(&self, other: &Self) -> bool {
467        #crate_name::bytes_of(self) == #crate_name::bytes_of(other)
468      }
469    }
470    impl #impl_generics ::core::cmp::Eq for #ident #ty_generics #where_clause { }
471  })
472}
473
474/// Derive the `Hash` trait for a type
475///
476/// The macro implements `Hash` by casting the value to a byte slice and hashing
477/// that.
478///
479/// ## Warning
480///
481/// The hash does not match the standard library's `Hash` derive.
482///
483/// ## Examples
484///
485/// ```rust
486/// # use bytemuck_derive::{ByteHash, NoUninit};
487/// #[derive(Copy, Clone, NoUninit, ByteHash)]
488/// #[repr(C)]
489/// struct Test {
490///   a: u32,
491///   b: char,
492///   c: f32,
493/// }
494/// ```
495///
496/// ```rust
497/// # use bytemuck_derive::ByteHash;
498/// # use bytemuck::NoUninit;
499/// #[derive(Copy, Clone, ByteHash)]
500/// #[repr(C)]
501/// struct Test<const N: usize> {
502///   a: [u32; N],
503/// }
504/// unsafe impl<const N: usize> NoUninit for Test<N> {}
505/// ```
506#[proc_macro_derive(ByteHash)]
507pub fn derive_byte_hash(
508  input: proc_macro::TokenStream,
509) -> proc_macro::TokenStream {
510  let input = parse_macro_input!(input as DeriveInput);
511  let crate_name = bytemuck_crate_name(&input);
512  let ident = input.ident;
513  let (impl_generics, ty_generics, where_clause) =
514    input.generics.split_for_impl();
515
516  proc_macro::TokenStream::from(quote! {
517    impl #impl_generics ::core::hash::Hash for #ident #ty_generics #where_clause {
518      #[inline]
519      fn hash<H: ::core::hash::Hasher>(&self, state: &mut H) {
520        ::core::hash::Hash::hash_slice(#crate_name::bytes_of(self), state)
521      }
522
523      #[inline]
524      fn hash_slice<H: ::core::hash::Hasher>(data: &[Self], state: &mut H) {
525        ::core::hash::Hash::hash_slice(#crate_name::cast_slice::<_, u8>(data), state)
526      }
527    }
528  })
529}
530
531/// Basic wrapper for error handling
532fn derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream {
533  derive_marker_trait_inner::<Trait>(input)
534    .unwrap_or_else(|err| err.into_compile_error())
535}
536
537/// Find `#[name(key = "value")]` helper attributes on the struct, and return
538/// their `"value"`s parsed with `parser`.
539///
540/// Returns an error if any attributes with the given `name` do not match the
541/// expected format. Returns `Ok([])` if no attributes with `name` are found.
542fn find_and_parse_helper_attributes<P: syn::parse::Parser + Copy>(
543  attributes: &[syn::Attribute], name: &str, key: &str, parser: P,
544  example_value: &str, invalid_value_msg: &str,
545) -> Result<Vec<P::Output>> {
546  let invalid_format_msg =
547    format!("{name} attribute must be `{name}({key} = \"{example_value}\")`",);
548  let values_to_check = attributes.iter().filter_map(|attr| match &attr.meta {
549    // If a `Path` matches our `name`, return an error, else ignore it.
550    // e.g. `#[zeroable]`
551    syn::Meta::Path(path) => path
552      .is_ident(name)
553      .then(|| Err(syn::Error::new_spanned(path, &invalid_format_msg))),
554    // If a `NameValue` matches our `name`, return an error, else ignore it.
555    // e.g. `#[zeroable = "hello"]`
556    syn::Meta::NameValue(namevalue) => {
557      namevalue.path.is_ident(name).then(|| {
558        Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg))
559      })
560    }
561    // If a `List` matches our `name`, match its contents to our format, else
562    // ignore it. If its contents match our format, return the value, else
563    // return an error.
564    syn::Meta::List(list) => list.path.is_ident(name).then(|| {
565      let namevalue: syn::MetaNameValue = syn::parse2(list.tokens.clone())
566        .map_err(|_| {
567          syn::Error::new_spanned(&list.tokens, &invalid_format_msg)
568        })?;
569      if namevalue.path.is_ident(key) {
570        match namevalue.value {
571          syn::Expr::Lit(syn::ExprLit {
572            lit: syn::Lit::Str(strlit), ..
573          }) => Ok(strlit),
574          _ => {
575            Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg))
576          }
577        }
578      } else {
579        Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg))
580      }
581    }),
582  });
583  // Parse each value found with the given parser, and return them if no errors
584  // occur.
585  values_to_check
586    .map(|lit| {
587      let lit = lit?;
588      lit.parse_with(parser).map_err(|err| {
589        syn::Error::new_spanned(&lit, format!("{invalid_value_msg}: {err}"))
590      })
591    })
592    .collect()
593}
594
595fn derive_marker_trait_inner<Trait: Derivable>(
596  mut input: DeriveInput,
597) -> Result<TokenStream> {
598  let crate_name = bytemuck_crate_name(&input);
599  let trait_ = Trait::ident(&input, &crate_name)?;
600  // If this trait allows explicit bounds, and any explicit bounds were given,
601  // then use those explicit bounds. Else, apply the default bounds (bound
602  // each generic type on this trait).
603  if let Some(name) = Trait::explicit_bounds_attribute_name() {
604    // See if any explicit bounds were given in attributes.
605    let explicit_bounds = find_and_parse_helper_attributes(
606      &input.attrs,
607      name,
608      "bound",
609      <syn::punctuated::Punctuated<syn::WherePredicate, syn::Token![,]>>::parse_terminated,
610      "Type: Trait",
611      "invalid where predicate",
612    )?;
613
614    if !explicit_bounds.is_empty() {
615      // Explicit bounds were given.
616      // Enforce explicitly given bounds, and emit "perfect derive" (i.e. add
617      // bounds for each field's type).
618      let explicit_bounds = explicit_bounds
619        .into_iter()
620        .flatten()
621        .collect::<Vec<syn::WherePredicate>>();
622
623      let fields = match (Trait::perfect_derive_fields(&input), &input.data) {
624        (Some(fields), _) => fields,
625        (None, syn::Data::Struct(syn::DataStruct { fields, .. })) => {
626          fields.clone()
627        }
628        (None, syn::Data::Union(_)) => {
629          return Err(syn::Error::new_spanned(
630            trait_,
631            &"perfect derive is not supported for unions",
632          ));
633        }
634        (None, syn::Data::Enum(_)) => {
635          return Err(syn::Error::new_spanned(
636            trait_,
637            &"perfect derive is not supported for enums",
638          ));
639        }
640      };
641
642      let predicates = &mut input.generics.make_where_clause().predicates;
643
644      predicates.extend(explicit_bounds);
645
646      for field in fields {
647        let ty = field.ty;
648        predicates.push(syn::parse_quote!(
649          #ty: #trait_
650        ));
651      }
652    } else {
653      // No explicit bounds were given.
654      // Enforce trait bound on all type generics.
655      add_trait_marker(&mut input.generics, &trait_);
656    }
657  } else {
658    // This trait does not allow explicit bounds.
659    // Enforce trait bound on all type generics.
660    add_trait_marker(&mut input.generics, &trait_);
661  }
662
663  let name = &input.ident;
664
665  let (impl_generics, ty_generics, where_clause) =
666    input.generics.split_for_impl();
667
668  Trait::check_attributes(&input.data, &input.attrs)?;
669  let asserts = Trait::asserts(&input, &crate_name)?;
670  let (trait_impl_extras, trait_impl) = Trait::trait_impl(&input, &crate_name)?;
671
672  let implies_trait = if let Some(implies_trait) =
673    Trait::implies_trait(&crate_name)
674  {
675    quote!(unsafe impl #impl_generics #implies_trait for #name #ty_generics #where_clause {})
676  } else {
677    quote!()
678  };
679
680  let where_clause =
681    if Trait::requires_where_clause() { where_clause } else { None };
682
683  Ok(quote! {
684    #asserts
685
686    #trait_impl_extras
687
688    unsafe impl #impl_generics #trait_ for #name #ty_generics #where_clause {
689      #trait_impl
690    }
691
692    #implies_trait
693  })
694}
695
696/// Add a trait marker to the generics if it is not already present
697fn add_trait_marker(generics: &mut syn::Generics, trait_name: &syn::Path) {
698  // Get each generic type parameter.
699  let type_params = generics
700    .type_params()
701    .map(|param| &param.ident)
702    .map(|param| {
703      syn::parse_quote!(
704        #param: #trait_name
705      )
706    })
707    .collect::<Vec<syn::WherePredicate>>();
708
709  generics.make_where_clause().predicates.extend(type_params);
710}