Skip to main content

zerocopy_derive/
util.rs

1// Copyright 2019 The Fuchsia Authors
2//
3// Licensed under a BSD-style license <LICENSE-BSD>, Apache License, Version 2.0
4// <LICENSE-APACHE or https://www.apache.org/licenses/LICENSE-2.0>, or the MIT
5// license <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your option.
6// This file may not be copied, modified, or distributed except according to
7// those terms.
8
9use std::num::NonZeroU32;
10
11use proc_macro2::{Span, TokenStream};
12use quote::{quote, quote_spanned, ToTokens};
13use syn::{
14    parse_quote, spanned::Spanned as _, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Error,
15    Expr, ExprLit, Field, GenericParam, Ident, Index, Lit, LitStr, Meta, Path, Type, Variant,
16    Visibility, WherePredicate,
17};
18
19use crate::repr::{CompoundRepr, EnumRepr, PrimitiveRepr, Repr, Spanned};
20
21pub(crate) struct Ctx {
22    pub(crate) ast: DeriveInput,
23    pub(crate) zerocopy_crate: Path,
24
25    // The value of the last `#[zerocopy(on_error = ...)]` attribute, or `false`
26    // if none is provided.
27    pub(crate) skip_on_error: bool,
28
29    // The span of the last `#[zerocopy(on_error = ...)]` attribute, if any.
30    pub(crate) on_error_span: Option<proc_macro2::Span>,
31}
32
33impl Ctx {
34    /// Attempt to extract a crate path from the provided attributes. Defaults to
35    /// `::zerocopy` if not found.
36    pub(crate) fn try_from_derive_input(ast: DeriveInput) -> Result<Self, Error> {
37        let mut path = parse_quote!(::zerocopy);
38        let mut skip_on_error = false;
39        let mut on_error_span = None;
40
41        for attr in &ast.attrs {
42            if let Meta::List(ref meta_list) = attr.meta {
43                if meta_list.path.is_ident("zerocopy") {
44                    attr.parse_nested_meta(|meta| {
45                        if meta.path.is_ident("crate") {
46                            let expr = meta.value().and_then(|value| value.parse());
47                            if let Ok(Expr::Lit(ExprLit { lit: Lit::Str(lit), .. })) = expr {
48                                if let Ok(path_lit) = lit.parse::<Ident>() {
49                                    path = parse_quote!(::#path_lit);
50                                    return Ok(());
51                                }
52                            }
53
54                            return Err(Error::new(
55                                Span::call_site(),
56                                "`crate` attribute requires a path as the value",
57                            ));
58                        }
59
60                        if meta.path.is_ident("on_error") {
61                            on_error_span = Some(meta.path.span());
62                            let value = meta.value()?;
63                            let s: LitStr = value.parse()?;
64                            match s.value().as_str() {
65                                "skip" => skip_on_error = true,
66                                "fail" => skip_on_error = false,
67                                _ => return Err(Error::new(
68                                    s.span(),
69                                    "unrecognized value for `on_error` attribute from `zerocopy`; expected `skip` or `fail`",
70                                )),
71                            }
72                            return Ok(());
73                        }
74
75                        Err(Error::new(
76                            Span::call_site(),
77                            format!(
78                                "unknown attribute encountered: {}",
79                                meta.path.into_token_stream()
80                            ),
81                        ))
82                    })?;
83                }
84            }
85        }
86
87        Ok(Self { ast, zerocopy_crate: path, skip_on_error, on_error_span })
88    }
89
90    pub(crate) fn with_input(&self, input: &DeriveInput) -> Self {
91        Self {
92            ast: input.clone(),
93            zerocopy_crate: self.zerocopy_crate.clone(),
94            skip_on_error: self.skip_on_error,
95            on_error_span: self.on_error_span,
96        }
97    }
98
99    pub(crate) fn core_path(&self) -> TokenStream {
100        let zerocopy_crate = &self.zerocopy_crate;
101        quote!(#zerocopy_crate::util::macro_util::core_reexport)
102    }
103
104    pub(crate) fn cfg_compile_error(&self) -> TokenStream {
105        // By checking both during the compilation of the proc macro *and* in
106        // the generated code, we ensure that `--cfg
107        // zerocopy_unstable_derive_on_error` need only be passed *either* when
108        // compiling this crate *or* when compiling the user's crate. The former
109        // is preferable, but in some situations (such as when cross-compiling
110        // using `cargo build --target`), it doesn't get propagated to this
111        // crate's build by default.
112        if cfg!(zerocopy_unstable_derive_on_error) {
113            quote!()
114        } else if let Some(span) = self.on_error_span {
115            let core = self.core_path();
116            let error_message = "`on_error` is experimental; pass '--cfg zerocopy_unstable_derive_on_error' to enable";
117            quote::quote_spanned! {span=>
118                #[allow(unused_attributes, unexpected_cfgs)]
119                const _: () = {
120                    #[cfg(not(zerocopy_unstable_derive_on_error))]
121                    #core::compile_error!(#error_message);
122                };
123            }
124        } else {
125            quote!()
126        }
127    }
128
129    pub(crate) fn error_or_skip<E>(&self, error: E) -> Result<TokenStream, E> {
130        if self.skip_on_error {
131            Ok(self.cfg_compile_error())
132        } else {
133            Err(error)
134        }
135    }
136}
137
138pub(crate) trait DataExt {
139    /// Extracts the names and types of all fields. For enums, extracts the
140    /// names and types of fields from each variant. For tuple structs, the
141    /// names are the indices used to index into the struct (ie, `0`, `1`, etc).
142    ///
143    /// FIXME: Extracting field names for enums doesn't really make sense. Types
144    /// makes sense because we don't care about where they live - we just care
145    /// about transitive ownership. But for field names, we'd only use them when
146    /// generating is_bit_valid, which cares about where they live.
147    fn fields(&self) -> Vec<(&Visibility, TokenStream, &Type)>;
148
149    fn variants(&self) -> Vec<(Option<&Variant>, Vec<(&Visibility, TokenStream, &Type)>)>;
150
151    fn tag(&self) -> Option<Ident>;
152}
153
154impl DataExt for Data {
155    fn fields(&self) -> Vec<(&Visibility, TokenStream, &Type)> {
156        match self {
157            Data::Struct(strc) => strc.fields(),
158            Data::Enum(enm) => enm.fields(),
159            Data::Union(un) => un.fields(),
160        }
161    }
162
163    fn variants(&self) -> Vec<(Option<&Variant>, Vec<(&Visibility, TokenStream, &Type)>)> {
164        match self {
165            Data::Struct(strc) => strc.variants(),
166            Data::Enum(enm) => enm.variants(),
167            Data::Union(un) => un.variants(),
168        }
169    }
170
171    fn tag(&self) -> Option<Ident> {
172        match self {
173            Data::Struct(strc) => strc.tag(),
174            Data::Enum(enm) => enm.tag(),
175            Data::Union(un) => un.tag(),
176        }
177    }
178}
179
180impl DataExt for DataStruct {
181    fn fields(&self) -> Vec<(&Visibility, TokenStream, &Type)> {
182        map_fields(&self.fields)
183    }
184
185    fn variants(&self) -> Vec<(Option<&Variant>, Vec<(&Visibility, TokenStream, &Type)>)> {
186        vec![(None, self.fields())]
187    }
188
189    fn tag(&self) -> Option<Ident> {
190        None
191    }
192}
193
194impl DataExt for DataEnum {
195    fn fields(&self) -> Vec<(&Visibility, TokenStream, &Type)> {
196        map_fields(self.variants.iter().flat_map(|var| &var.fields))
197    }
198
199    fn variants(&self) -> Vec<(Option<&Variant>, Vec<(&Visibility, TokenStream, &Type)>)> {
200        self.variants.iter().map(|var| (Some(var), map_fields(&var.fields))).collect()
201    }
202
203    fn tag(&self) -> Option<Ident> {
204        Some(Ident::new("___ZerocopyTag", Span::call_site()))
205    }
206}
207
208impl DataExt for DataUnion {
209    fn fields(&self) -> Vec<(&Visibility, TokenStream, &Type)> {
210        map_fields(&self.fields.named)
211    }
212
213    fn variants(&self) -> Vec<(Option<&Variant>, Vec<(&Visibility, TokenStream, &Type)>)> {
214        vec![(None, self.fields())]
215    }
216
217    fn tag(&self) -> Option<Ident> {
218        None
219    }
220}
221
222fn map_fields<'a>(
223    fields: impl 'a + IntoIterator<Item = &'a Field>,
224) -> Vec<(&'a Visibility, TokenStream, &'a Type)> {
225    fields
226        .into_iter()
227        .enumerate()
228        .map(|(idx, f)| {
229            (
230                &f.vis,
231                f.ident
232                    .as_ref()
233                    .map(ToTokens::to_token_stream)
234                    .unwrap_or_else(|| Index::from(idx).to_token_stream()),
235                &f.ty,
236            )
237        })
238        .collect()
239}
240
/// Renders `t` as a string, dropping a leading raw-identifier prefix (`r#`)
/// if one is present.
pub(crate) fn to_ident_str(t: &impl ToString) -> String {
    let rendered = t.to_string();
    match rendered.strip_prefix("r#") {
        Some(unraw) => unraw.to_owned(),
        None => rendered,
    }
}
249
250/// This enum describes what kind of padding check needs to be generated for the
251/// associated impl.
252pub(crate) enum PaddingCheck {
253    /// Check that the sum of the fields' sizes exactly equals the struct's
254    /// size.
255    Struct,
256    /// Check that a `repr(C)` struct has no padding.
257    ReprCStruct,
258    /// Check that the size of each field exactly equals the union's size.
259    Union,
260    /// Check that every variant of the enum contains no padding.
261    ///
262    /// Because doing so requires a tag enum, this padding check requires an
263    /// additional `TokenStream` which defines the tag enum as `___ZerocopyTag`.
264    Enum { tag_type_definition: TokenStream },
265}
266
267impl PaddingCheck {
268    /// Returns the idents of the trait to use and the macro to call in order to
269    /// validate that a type passes the relevant padding check.
270    pub(crate) fn validator_trait_and_macro_idents(&self) -> (Ident, Ident) {
271        let (trt, mcro) = match self {
272            PaddingCheck::Struct => ("PaddingFree", "struct_padding"),
273            PaddingCheck::ReprCStruct => ("DynamicPaddingFree", "repr_c_struct_has_padding"),
274            PaddingCheck::Union => ("PaddingFree", "union_padding"),
275            PaddingCheck::Enum { .. } => ("PaddingFree", "enum_padding"),
276        };
277
278        let trt = Ident::new(trt, Span::call_site());
279        let mcro = Ident::new(mcro, Span::call_site());
280        (trt, mcro)
281    }
282
283    /// Sometimes performing the padding check requires some additional
284    /// "context" code. For enums, this is the definition of the tag enum.
285    pub(crate) fn validator_macro_context(&self) -> Option<&TokenStream> {
286        match self {
287            PaddingCheck::Struct | PaddingCheck::ReprCStruct | PaddingCheck::Union => None,
288            PaddingCheck::Enum { tag_type_definition } => Some(tag_type_definition),
289        }
290    }
291}
292
293#[derive(Clone)]
294pub(crate) enum Trait {
295    KnownLayout,
296    HasTag,
297    HasField {
298        variant_id: Box<Expr>,
299        field: Box<Type>,
300        field_id: Box<Expr>,
301    },
302    ProjectField {
303        variant_id: Box<Expr>,
304        field: Box<Type>,
305        field_id: Box<Expr>,
306        invariants: Box<Type>,
307    },
308    Immutable,
309    TryFromBytes,
310    FromZeros,
311    FromBytes,
312    IntoBytes,
313    Unaligned,
314    Sized,
315    ByteHash,
316    ByteEq,
317    SplitAt,
318}
319
320impl ToTokens for Trait {
321    fn to_tokens(&self, tokens: &mut TokenStream) {
322        // According to [1], the format of the derived `Debug`` output is not
323        // stable and therefore not guaranteed to represent the variant names.
324        // Indeed with the (unstable) `fmt-debug` compiler flag [2], it can
325        // return only a minimalized output or empty string. To make sure this
326        // code will work in the future and independent of the compiler flag, we
327        // translate the variants to their names manually here.
328        //
329        // [1] https://doc.rust-lang.org/1.81.0/std/fmt/trait.Debug.html#stability
330        // [2] https://doc.rust-lang.org/beta/unstable-book/compiler-flags/fmt-debug.html
331        let s = match self {
332            Trait::HasField { .. } => "HasField",
333            Trait::ProjectField { .. } => "ProjectField",
334            Trait::KnownLayout => "KnownLayout",
335            Trait::HasTag => "HasTag",
336            Trait::Immutable => "Immutable",
337            Trait::TryFromBytes => "TryFromBytes",
338            Trait::FromZeros => "FromZeros",
339            Trait::FromBytes => "FromBytes",
340            Trait::IntoBytes => "IntoBytes",
341            Trait::Unaligned => "Unaligned",
342            Trait::Sized => "Sized",
343            Trait::ByteHash => "ByteHash",
344            Trait::ByteEq => "ByteEq",
345            Trait::SplitAt => "SplitAt",
346        };
347        let ident = Ident::new(s, Span::call_site());
348        let arguments: Option<syn::AngleBracketedGenericArguments> = match self {
349            Trait::HasField { variant_id, field, field_id } => {
350                Some(parse_quote!(<#field, #variant_id, #field_id>))
351            }
352            Trait::ProjectField { variant_id, field, field_id, invariants } => {
353                Some(parse_quote!(<#field, #invariants, #variant_id, #field_id>))
354            }
355            Trait::KnownLayout
356            | Trait::HasTag
357            | Trait::Immutable
358            | Trait::TryFromBytes
359            | Trait::FromZeros
360            | Trait::FromBytes
361            | Trait::IntoBytes
362            | Trait::Unaligned
363            | Trait::Sized
364            | Trait::ByteHash
365            | Trait::ByteEq
366            | Trait::SplitAt => None,
367        };
368        tokens.extend(quote!(#ident #arguments));
369    }
370}
371
372impl Trait {
373    pub(crate) fn crate_path(&self, ctx: &Ctx) -> Path {
374        let zerocopy_crate = &ctx.zerocopy_crate;
375        let core = ctx.core_path();
376        match self {
377            Self::Sized => parse_quote!(#core::marker::#self),
378            _ => parse_quote!(#zerocopy_crate::#self),
379        }
380    }
381}
382
383pub(crate) enum TraitBound {
384    Slf,
385    Other(Trait),
386}
387
388pub(crate) enum FieldBounds<'a> {
389    None,
390    All(&'a [TraitBound]),
391    Trailing(&'a [TraitBound]),
392    Explicit(Vec<WherePredicate>),
393}
394
395impl<'a> FieldBounds<'a> {
396    pub(crate) const ALL_SELF: FieldBounds<'a> = FieldBounds::All(&[TraitBound::Slf]);
397    pub(crate) const TRAILING_SELF: FieldBounds<'a> = FieldBounds::Trailing(&[TraitBound::Slf]);
398}
399
400pub(crate) enum SelfBounds<'a> {
401    None,
402    All(&'a [Trait]),
403}
404
405// FIXME(https://github.com/rust-lang/rust-clippy/issues/12908): This is a false
406// positive. Explicit lifetimes are actually necessary here.
407#[allow(clippy::needless_lifetimes)]
408impl<'a> SelfBounds<'a> {
409    pub(crate) const SIZED: Self = Self::All(&[Trait::Sized]);
410}
411
412/// Normalizes a slice of bounds by replacing [`TraitBound::Slf`] with `slf`.
413pub(crate) fn normalize_bounds<'a>(
414    slf: &'a Trait,
415    bounds: &'a [TraitBound],
416) -> impl 'a + Iterator<Item = Trait> {
417    bounds.iter().map(move |bound| match bound {
418        TraitBound::Slf => slf.clone(),
419        TraitBound::Other(trt) => trt.clone(),
420    })
421}
422
423pub(crate) struct ImplBlockBuilder<'a> {
424    ctx: &'a Ctx,
425    data: &'a dyn DataExt,
426    trt: Trait,
427    field_type_trait_bounds: FieldBounds<'a>,
428    self_type_trait_bounds: SelfBounds<'a>,
429    padding_check: Option<PaddingCheck>,
430    param_extras: Vec<GenericParam>,
431    inner_extras: Option<TokenStream>,
432    outer_extras: Option<TokenStream>,
433}
434
435impl<'a> ImplBlockBuilder<'a> {
436    pub(crate) fn new(
437        ctx: &'a Ctx,
438        data: &'a dyn DataExt,
439        trt: Trait,
440        field_type_trait_bounds: FieldBounds<'a>,
441    ) -> Self {
442        Self {
443            ctx,
444            data,
445            trt,
446            field_type_trait_bounds,
447            self_type_trait_bounds: SelfBounds::None,
448            padding_check: None,
449            param_extras: Vec::new(),
450            inner_extras: None,
451            outer_extras: None,
452        }
453    }
454
455    pub(crate) fn self_type_trait_bounds(mut self, self_type_trait_bounds: SelfBounds<'a>) -> Self {
456        self.self_type_trait_bounds = self_type_trait_bounds;
457        self
458    }
459
460    pub(crate) fn padding_check<P: Into<Option<PaddingCheck>>>(mut self, padding_check: P) -> Self {
461        self.padding_check = padding_check.into();
462        self
463    }
464
465    pub(crate) fn param_extras(mut self, param_extras: Vec<GenericParam>) -> Self {
466        self.param_extras.extend(param_extras);
467        self
468    }
469
470    pub(crate) fn inner_extras(mut self, inner_extras: TokenStream) -> Self {
471        self.inner_extras = Some(inner_extras);
472        self
473    }
474
475    pub(crate) fn outer_extras<T: Into<Option<TokenStream>>>(mut self, outer_extras: T) -> Self {
476        self.outer_extras = outer_extras.into();
477        self
478    }
479
480    pub(crate) fn build(self) -> TokenStream {
481        // In this documentation, we will refer to this hypothetical struct:
482        //
483        //   #[derive(FromBytes)]
484        //   struct Foo<T, I: Iterator>
485        //   where
486        //       T: Copy,
487        //       I: Clone,
488        //       I::Item: Clone,
489        //   {
490        //       a: u8,
491        //       b: T,
492        //       c: I::Item,
493        //   }
494        //
495        // We extract the field types, which in this case are `u8`, `T`, and
496        // `I::Item`. We re-use the existing parameters and where clauses. If
497        // `require_trait_bound == true` (as it is for `FromBytes), we add where
498        // bounds for each field's type:
499        //
500        //   impl<T, I: Iterator> FromBytes for Foo<T, I>
501        //   where
502        //       T: Copy,
503        //       I: Clone,
504        //       I::Item: Clone,
505        //       T: FromBytes,
506        //       I::Item: FromBytes,
507        //   {
508        //   }
509        //
510        // NOTE: It is standard practice to only emit bounds for the type
511        // parameters themselves, not for field types based on those parameters
512        // (e.g., `T` vs `T::Foo`). For a discussion of why this is standard
513        // practice, see https://github.com/rust-lang/rust/issues/26925.
514        //
515        // The reason we diverge from this standard is that doing it that way
516        // for us would be unsound. E.g., consider a type, `T` where `T:
517        // FromBytes` but `T::Foo: !FromBytes`. It would not be sound for us to
518        // accept a type with a `T::Foo` field as `FromBytes` simply because `T:
519        // FromBytes`.
520        //
521        // While there's no getting around this requirement for us, it does have
522        // the pretty serious downside that, when lifetimes are involved, the
523        // trait solver ties itself in knots:
524        //
525        //     #[derive(Unaligned)]
526        //     #[repr(C)]
527        //     struct Dup<'a, 'b> {
528        //         a: PhantomData<&'a u8>,
529        //         b: PhantomData<&'b u8>,
530        //     }
531        //
532        //     error[E0283]: type annotations required: cannot resolve `core::marker::PhantomData<&'a u8>: zerocopy::Unaligned`
533        //      --> src/main.rs:6:10
534        //       |
535        //     6 | #[derive(Unaligned)]
536        //       |          ^^^^^^^^^
537        //       |
538        //       = note: required by `zerocopy::Unaligned`
539
540        let type_ident = &self.ctx.ast.ident;
541        let trait_path = self.trt.crate_path(self.ctx);
542        let fields = self.data.fields();
543        let variants = self.data.variants();
544        let tag = self.data.tag();
545        let zerocopy_crate = &self.ctx.zerocopy_crate;
546
547        fn bound_tt(ty: &Type, traits: impl Iterator<Item = Trait>, ctx: &Ctx) -> WherePredicate {
548            let traits = traits.map(|t| t.crate_path(ctx));
549            parse_quote!(#ty: #(#traits)+*)
550        }
551        let field_type_bounds: Vec<_> = match (self.field_type_trait_bounds, &fields[..]) {
552            (FieldBounds::All(traits), _) => fields
553                .iter()
554                .map(|(_vis, _name, ty)| {
555                    bound_tt(ty, normalize_bounds(&self.trt, traits), self.ctx)
556                })
557                .collect(),
558            (FieldBounds::None, _) | (FieldBounds::Trailing(..), []) => vec![],
559            (FieldBounds::Trailing(traits), [.., last]) => {
560                vec![bound_tt(last.2, normalize_bounds(&self.trt, traits), self.ctx)]
561            }
562            (FieldBounds::Explicit(bounds), _) => bounds,
563        };
564
565        let padding_check_bound = self
566            .padding_check
567            .map(|check| {
568                // Parse the repr for `align` and `packed` modifiers. Note that
569                // `Repr::<PrimitiveRepr, NonZeroU32>` is more permissive than
570                // what Rust supports for structs, enums, or unions, and thus
571                // reliably extracts these modifiers for any kind of type.
572                let repr =
573                    Repr::<PrimitiveRepr, NonZeroU32>::from_attrs(&self.ctx.ast.attrs).unwrap();
574                let core = self.ctx.core_path();
575                let option = quote! { #core::option::Option };
576                let nonzero = quote! { #core::num::NonZeroUsize };
577                let none = quote! { #option::None::<#nonzero> };
578                let repr_align =
579                    repr.get_align().map(|spanned| {
580                        let n = spanned.t.get();
581                        quote_spanned! { spanned.span => (#nonzero::new(#n as usize)) }
582                    }).unwrap_or(quote! { (#none) });
583                let repr_packed =
584                    repr.get_packed().map(|packed| {
585                        let n = packed.get();
586                        quote! { (#nonzero::new(#n as usize)) }
587                    }).unwrap_or(quote! { (#none) });
588                let variant_types = variants.iter().map(|(_, fields)| {
589                    let types = fields.iter().map(|(_vis, _name, ty)| ty);
590                    quote!([#((#types)),*])
591                });
592                let validator_context = check.validator_macro_context();
593                let (trt, validator_macro) = check.validator_trait_and_macro_idents();
594                let t = tag.iter();
595                parse_quote! {
596                    (): #zerocopy_crate::util::macro_util::#trt<
597                        Self,
598                        {
599                            #validator_context
600                            #zerocopy_crate::#validator_macro!(Self, #repr_align, #repr_packed, #(#t,)* #(#variant_types),*)
601                        }
602                    >
603                }
604            });
605
606        let self_bounds: Option<WherePredicate> = match self.self_type_trait_bounds {
607            SelfBounds::None => None,
608            SelfBounds::All(traits) => {
609                Some(bound_tt(&parse_quote!(Self), traits.iter().cloned(), self.ctx))
610            }
611        };
612
613        let bounds = self
614            .ctx
615            .ast
616            .generics
617            .where_clause
618            .as_ref()
619            .map(|where_clause| where_clause.predicates.iter())
620            .into_iter()
621            .flatten()
622            .chain(field_type_bounds.iter())
623            .chain(padding_check_bound.iter())
624            .chain(self_bounds.iter());
625
626        // The parameters with trait bounds, but without type defaults.
627        let mut params: Vec<_> = self
628            .ctx
629            .ast
630            .generics
631            .params
632            .clone()
633            .into_iter()
634            .map(|mut param| {
635                match &mut param {
636                    GenericParam::Type(ty) => ty.default = None,
637                    GenericParam::Const(cnst) => cnst.default = None,
638                    GenericParam::Lifetime(_) => {}
639                }
640                parse_quote!(#param)
641            })
642            .chain(self.param_extras)
643            .collect();
644
645        // For MSRV purposes, ensure that lifetimes precede types precede const
646        // generics.
647        params.sort_by_cached_key(|param| match param {
648            GenericParam::Lifetime(_) => 0,
649            GenericParam::Type(_) => 1,
650            GenericParam::Const(_) => 2,
651        });
652
653        // The identifiers of the parameters without trait bounds or type
654        // defaults.
655        let param_idents = self.ctx.ast.generics.params.iter().map(|param| match param {
656            GenericParam::Type(ty) => {
657                let ident = &ty.ident;
658                quote!(#ident)
659            }
660            GenericParam::Lifetime(l) => {
661                let ident = &l.lifetime;
662                quote!(#ident)
663            }
664            GenericParam::Const(cnst) => {
665                let ident = &cnst.ident;
666                quote!({#ident})
667            }
668        });
669
670        let inner_extras = self.inner_extras;
671        let allow_trivial_bounds =
672            if self.ctx.skip_on_error { quote!(#[allow(trivial_bounds)]) } else { quote!() };
673        let impl_tokens = quote! {
674            #allow_trivial_bounds
675            unsafe impl < #(#params),* > #trait_path for #type_ident < #(#param_idents),* >
676            where
677                #(#bounds,)*
678            {
679                fn only_derive_is_allowed_to_implement_this_trait() {}
680
681                #inner_extras
682            }
683        };
684
685        let outer_extras = self.outer_extras.filter(|e| !e.is_empty());
686        let cfg_compile_error = self.ctx.cfg_compile_error();
687        const_block([Some(cfg_compile_error), Some(impl_tokens), outer_extras])
688    }
689}
690
// Polyfill for `bool::then_some`, which is newer than our MSRV.
//
// On toolchains new enough to have the inherent method, `b.then_some(...)`
// resolves to that method rather than to this trait. The trait then looks
// unused, hence the `#[allow(unused)]`.
//
// FIXME(#67): Remove this once our MSRV is >= 1.62.
#[allow(unused)]
trait BoolExt {
    /// Returns `Some(t)` if `self` is `true`, and `None` otherwise.
    fn then_some<T>(self, t: T) -> Option<T>;
}

impl BoolExt for bool {
    fn then_some<T>(self, t: T) -> Option<T> {
        match self {
            true => Some(t),
            false => None,
        }
    }
}
712
713pub(crate) fn const_block(items: impl IntoIterator<Item = Option<TokenStream>>) -> TokenStream {
714    let items = items.into_iter().flatten();
715    quote! {
716        #[allow(
717            // FIXME(#553): Add a test that generates a warning when
718            // `#[allow(deprecated)]` isn't present.
719            deprecated,
720            // Required on some rustc versions due to a lint that is only
721            // triggered when `derive(KnownLayout)` is applied to `repr(C)`
722            // structs that are generated by macros. See #2177 for details.
723            private_bounds,
724            non_local_definitions,
725            non_camel_case_types,
726            non_upper_case_globals,
727            non_snake_case,
728            non_ascii_idents,
729            clippy::missing_inline_in_public_items,
730        )]
731        #[deny(ambiguous_associated_items)]
732        // While there are not currently any warnings that this suppresses
733        // (that we're aware of), it's good future-proofing hygiene.
734        #[automatically_derived]
735        const _: () = {
736            #(#items)*
737        };
738    }
739}
740pub(crate) fn generate_tag_enum(ctx: &Ctx, repr: &EnumRepr, data: &DataEnum) -> TokenStream {
741    let zerocopy_crate = &ctx.zerocopy_crate;
742    let variants = data.variants.iter().map(|v| {
743        let ident = &v.ident;
744        if let Some((eq, discriminant)) = &v.discriminant {
745            quote! { #ident #eq #discriminant }
746        } else {
747            quote! { #ident }
748        }
749    });
750
751    // Don't include any `repr(align)` when generating the tag enum, as that
752    // could add padding after the tag but before any variants, which is not the
753    // correct behavior.
754    let repr = match repr {
755        EnumRepr::Transparent(span) => quote::quote_spanned! { *span => #[repr(transparent)] },
756        EnumRepr::Compound(c, _) => quote! { #c },
757    };
758
759    quote! {
760        #repr
761        #[allow(dead_code)]
762        pub enum ___ZerocopyTag {
763            #(#variants,)*
764        }
765
766        // SAFETY: `___ZerocopyTag` has no fields, and so it does not permit
767        // interior mutation.
768        unsafe impl #zerocopy_crate::Immutable for ___ZerocopyTag {
769            fn only_derive_is_allowed_to_implement_this_trait() {}
770        }
771    }
772}
773pub(crate) fn enum_size_from_repr(repr: &EnumRepr) -> Result<usize, Error> {
774    use CompoundRepr::*;
775    use PrimitiveRepr::*;
776    use Repr::*;
777    match repr {
778        Transparent(span)
779        | Compound(
780            Spanned {
781                t: C | Rust | Primitive(U32 | I32 | U64 | I64 | U128 | I128 | Usize | Isize),
782                span,
783            },
784            _,
785        ) => Err(Error::new(
786            *span,
787            "`FromBytes` only supported on enums with `#[repr(...)]` attributes `u8`, `i8`, `u16`, or `i16`",
788        )),
789        Compound(Spanned { t: Primitive(U8 | I8), span: _ }, _align) => Ok(8),
790        Compound(Spanned { t: Primitive(U16 | I16), span: _ }, _align) => Ok(16),
791    }
792}
793
#[cfg(test)]
pub(crate) mod testutil {
    use proc_macro2::TokenStream;
    use syn::visit::{self, Visit};

    /// Asserts that generated code accesses associated items only through
    /// fully-qualified paths.
    ///
    /// Any path with more than one segment whose first segment is `Self`
    /// (e.g. `Self::Item`) is rejected. Such accesses must be written as
    /// `<Self as Trait>::Item` instead.
    ///
    /// # Panics
    ///
    /// Panics if `ts` does not parse as a `syn::File`, or if an offending path
    /// is found.
    pub(crate) fn check_hygiene(ts: TokenStream) {
        struct AmbiguousItemVisitor;

        impl<'ast> Visit<'ast> for AmbiguousItemVisitor {
            fn visit_path(&mut self, path: &'ast syn::Path) {
                let starts_with_self =
                    matches!(path.segments.first(), Some(first) if first.ident == "Self");
                if starts_with_self && path.segments.len() > 1 {
                    panic!(
                        "Found ambiguous path `{}` in generated output. \
                         All associated item access must be fully qualified (e.g., `<Self as Trait>::Item`) \
                         to prevent hygiene issues.",
                        quote::quote!(#path)
                    );
                }
                visit::visit_path(self, path);
            }
        }

        let file = syn::parse2::<syn::File>(ts).expect("failed to parse generated output as File");
        AmbiguousItemVisitor.visit_file(&file);
    }

    #[test]
    fn test_check_hygiene_success() {
        check_hygiene(quote::quote! {
            fn foo() {
                let _ = <Self as Trait>::Item;
            }
        });
    }

    #[test]
    #[should_panic(expected = "Found ambiguous path `Self :: Ambiguous`")]
    fn test_check_hygiene_failure() {
        check_hygiene(quote::quote! {
            fn foo() {
                let _ = Self::Ambiguous;
            }
        });
    }
}