1use std::num::NonZeroU32;
10
11use proc_macro2::{Span, TokenStream};
12use quote::{quote, quote_spanned, ToTokens};
13use syn::{
14 parse_quote, spanned::Spanned as _, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Error,
15 Expr, ExprLit, Field, GenericParam, Ident, Index, Lit, LitStr, Meta, Path, Type, Variant,
16 Visibility, WherePredicate,
17};
18
19use crate::repr::{CompoundRepr, EnumRepr, PrimitiveRepr, Repr, Spanned};
20
21pub(crate) struct Ctx {
22 pub(crate) ast: DeriveInput,
23 pub(crate) zerocopy_crate: Path,
24
25 pub(crate) skip_on_error: bool,
28
29 pub(crate) on_error_span: Option<proc_macro2::Span>,
31}
32
33impl Ctx {
34 pub(crate) fn try_from_derive_input(ast: DeriveInput) -> Result<Self, Error> {
37 let mut path = parse_quote!(::zerocopy);
38 let mut skip_on_error = false;
39 let mut on_error_span = None;
40
41 for attr in &ast.attrs {
42 if let Meta::List(ref meta_list) = attr.meta {
43 if meta_list.path.is_ident("zerocopy") {
44 attr.parse_nested_meta(|meta| {
45 if meta.path.is_ident("crate") {
46 let expr = meta.value().and_then(|value| value.parse());
47 if let Ok(Expr::Lit(ExprLit { lit: Lit::Str(lit), .. })) = expr {
48 if let Ok(path_lit) = lit.parse::<Ident>() {
49 path = parse_quote!(::#path_lit);
50 return Ok(());
51 }
52 }
53
54 return Err(Error::new(
55 Span::call_site(),
56 "`crate` attribute requires a path as the value",
57 ));
58 }
59
60 if meta.path.is_ident("on_error") {
61 on_error_span = Some(meta.path.span());
62 let value = meta.value()?;
63 let s: LitStr = value.parse()?;
64 match s.value().as_str() {
65 "skip" => skip_on_error = true,
66 "fail" => skip_on_error = false,
67 _ => return Err(Error::new(
68 s.span(),
69 "unrecognized value for `on_error` attribute from `zerocopy`; expected `skip` or `fail`",
70 )),
71 }
72 return Ok(());
73 }
74
75 Err(Error::new(
76 Span::call_site(),
77 format!(
78 "unknown attribute encountered: {}",
79 meta.path.into_token_stream()
80 ),
81 ))
82 })?;
83 }
84 }
85 }
86
87 Ok(Self { ast, zerocopy_crate: path, skip_on_error, on_error_span })
88 }
89
90 pub(crate) fn with_input(&self, input: &DeriveInput) -> Self {
91 Self {
92 ast: input.clone(),
93 zerocopy_crate: self.zerocopy_crate.clone(),
94 skip_on_error: self.skip_on_error,
95 on_error_span: self.on_error_span,
96 }
97 }
98
99 pub(crate) fn core_path(&self) -> TokenStream {
100 let zerocopy_crate = &self.zerocopy_crate;
101 quote!(#zerocopy_crate::util::macro_util::core_reexport)
102 }
103
104 pub(crate) fn cfg_compile_error(&self) -> TokenStream {
105 if cfg!(zerocopy_unstable_derive_on_error) {
113 quote!()
114 } else if let Some(span) = self.on_error_span {
115 let core = self.core_path();
116 let error_message = "`on_error` is experimental; pass '--cfg zerocopy_unstable_derive_on_error' to enable";
117 quote::quote_spanned! {span=>
118 #[allow(unused_attributes, unexpected_cfgs)]
119 const _: () = {
120 #[cfg(not(zerocopy_unstable_derive_on_error))]
121 #core::compile_error!(#error_message);
122 };
123 }
124 } else {
125 quote!()
126 }
127 }
128
129 pub(crate) fn error_or_skip<E>(&self, error: E) -> Result<TokenStream, E> {
130 if self.skip_on_error {
131 Ok(self.cfg_compile_error())
132 } else {
133 Err(error)
134 }
135 }
136}
137
138pub(crate) trait DataExt {
139 fn fields(&self) -> Vec<(&Visibility, TokenStream, &Type)>;
148
149 fn variants(&self) -> Vec<(Option<&Variant>, Vec<(&Visibility, TokenStream, &Type)>)>;
150
151 fn tag(&self) -> Option<Ident>;
152}
153
154impl DataExt for Data {
155 fn fields(&self) -> Vec<(&Visibility, TokenStream, &Type)> {
156 match self {
157 Data::Struct(strc) => strc.fields(),
158 Data::Enum(enm) => enm.fields(),
159 Data::Union(un) => un.fields(),
160 }
161 }
162
163 fn variants(&self) -> Vec<(Option<&Variant>, Vec<(&Visibility, TokenStream, &Type)>)> {
164 match self {
165 Data::Struct(strc) => strc.variants(),
166 Data::Enum(enm) => enm.variants(),
167 Data::Union(un) => un.variants(),
168 }
169 }
170
171 fn tag(&self) -> Option<Ident> {
172 match self {
173 Data::Struct(strc) => strc.tag(),
174 Data::Enum(enm) => enm.tag(),
175 Data::Union(un) => un.tag(),
176 }
177 }
178}
179
180impl DataExt for DataStruct {
181 fn fields(&self) -> Vec<(&Visibility, TokenStream, &Type)> {
182 map_fields(&self.fields)
183 }
184
185 fn variants(&self) -> Vec<(Option<&Variant>, Vec<(&Visibility, TokenStream, &Type)>)> {
186 vec![(None, self.fields())]
187 }
188
189 fn tag(&self) -> Option<Ident> {
190 None
191 }
192}
193
194impl DataExt for DataEnum {
195 fn fields(&self) -> Vec<(&Visibility, TokenStream, &Type)> {
196 map_fields(self.variants.iter().flat_map(|var| &var.fields))
197 }
198
199 fn variants(&self) -> Vec<(Option<&Variant>, Vec<(&Visibility, TokenStream, &Type)>)> {
200 self.variants.iter().map(|var| (Some(var), map_fields(&var.fields))).collect()
201 }
202
203 fn tag(&self) -> Option<Ident> {
204 Some(Ident::new("___ZerocopyTag", Span::call_site()))
205 }
206}
207
208impl DataExt for DataUnion {
209 fn fields(&self) -> Vec<(&Visibility, TokenStream, &Type)> {
210 map_fields(&self.fields.named)
211 }
212
213 fn variants(&self) -> Vec<(Option<&Variant>, Vec<(&Visibility, TokenStream, &Type)>)> {
214 vec![(None, self.fields())]
215 }
216
217 fn tag(&self) -> Option<Ident> {
218 None
219 }
220}
221
222fn map_fields<'a>(
223 fields: impl 'a + IntoIterator<Item = &'a Field>,
224) -> Vec<(&'a Visibility, TokenStream, &'a Type)> {
225 fields
226 .into_iter()
227 .enumerate()
228 .map(|(idx, f)| {
229 (
230 &f.vis,
231 f.ident
232 .as_ref()
233 .map(ToTokens::to_token_stream)
234 .unwrap_or_else(|| Index::from(idx).to_token_stream()),
235 &f.ty,
236 )
237 })
238 .collect()
239}
240
/// Renders `t` with `to_string` and removes a leading raw-identifier marker
/// (`r#`), if present. For example, `r#type` becomes `type`.
pub(crate) fn to_ident_str(t: &impl ToString) -> String {
    let rendered = t.to_string();
    match rendered.strip_prefix("r#") {
        Some(unraw) => unraw.to_owned(),
        None => rendered,
    }
}
249
250pub(crate) enum PaddingCheck {
253 Struct,
256 ReprCStruct,
258 Union,
260 Enum { tag_type_definition: TokenStream },
265}
266
267impl PaddingCheck {
268 pub(crate) fn validator_trait_and_macro_idents(&self) -> (Ident, Ident) {
271 let (trt, mcro) = match self {
272 PaddingCheck::Struct => ("PaddingFree", "struct_padding"),
273 PaddingCheck::ReprCStruct => ("DynamicPaddingFree", "repr_c_struct_has_padding"),
274 PaddingCheck::Union => ("PaddingFree", "union_padding"),
275 PaddingCheck::Enum { .. } => ("PaddingFree", "enum_padding"),
276 };
277
278 let trt = Ident::new(trt, Span::call_site());
279 let mcro = Ident::new(mcro, Span::call_site());
280 (trt, mcro)
281 }
282
283 pub(crate) fn validator_macro_context(&self) -> Option<&TokenStream> {
286 match self {
287 PaddingCheck::Struct | PaddingCheck::ReprCStruct | PaddingCheck::Union => None,
288 PaddingCheck::Enum { tag_type_definition } => Some(tag_type_definition),
289 }
290 }
291}
292
293#[derive(Clone)]
294pub(crate) enum Trait {
295 KnownLayout,
296 HasTag,
297 HasField {
298 variant_id: Box<Expr>,
299 field: Box<Type>,
300 field_id: Box<Expr>,
301 },
302 ProjectField {
303 variant_id: Box<Expr>,
304 field: Box<Type>,
305 field_id: Box<Expr>,
306 invariants: Box<Type>,
307 },
308 Immutable,
309 TryFromBytes,
310 FromZeros,
311 FromBytes,
312 IntoBytes,
313 Unaligned,
314 Sized,
315 ByteHash,
316 ByteEq,
317 SplitAt,
318}
319
320impl ToTokens for Trait {
321 fn to_tokens(&self, tokens: &mut TokenStream) {
322 let s = match self {
332 Trait::HasField { .. } => "HasField",
333 Trait::ProjectField { .. } => "ProjectField",
334 Trait::KnownLayout => "KnownLayout",
335 Trait::HasTag => "HasTag",
336 Trait::Immutable => "Immutable",
337 Trait::TryFromBytes => "TryFromBytes",
338 Trait::FromZeros => "FromZeros",
339 Trait::FromBytes => "FromBytes",
340 Trait::IntoBytes => "IntoBytes",
341 Trait::Unaligned => "Unaligned",
342 Trait::Sized => "Sized",
343 Trait::ByteHash => "ByteHash",
344 Trait::ByteEq => "ByteEq",
345 Trait::SplitAt => "SplitAt",
346 };
347 let ident = Ident::new(s, Span::call_site());
348 let arguments: Option<syn::AngleBracketedGenericArguments> = match self {
349 Trait::HasField { variant_id, field, field_id } => {
350 Some(parse_quote!(<#field, #variant_id, #field_id>))
351 }
352 Trait::ProjectField { variant_id, field, field_id, invariants } => {
353 Some(parse_quote!(<#field, #invariants, #variant_id, #field_id>))
354 }
355 Trait::KnownLayout
356 | Trait::HasTag
357 | Trait::Immutable
358 | Trait::TryFromBytes
359 | Trait::FromZeros
360 | Trait::FromBytes
361 | Trait::IntoBytes
362 | Trait::Unaligned
363 | Trait::Sized
364 | Trait::ByteHash
365 | Trait::ByteEq
366 | Trait::SplitAt => None,
367 };
368 tokens.extend(quote!(#ident #arguments));
369 }
370}
371
372impl Trait {
373 pub(crate) fn crate_path(&self, ctx: &Ctx) -> Path {
374 let zerocopy_crate = &ctx.zerocopy_crate;
375 let core = ctx.core_path();
376 match self {
377 Self::Sized => parse_quote!(#core::marker::#self),
378 _ => parse_quote!(#zerocopy_crate::#self),
379 }
380 }
381}
382
383pub(crate) enum TraitBound {
384 Slf,
385 Other(Trait),
386}
387
388pub(crate) enum FieldBounds<'a> {
389 None,
390 All(&'a [TraitBound]),
391 Trailing(&'a [TraitBound]),
392 Explicit(Vec<WherePredicate>),
393}
394
395impl<'a> FieldBounds<'a> {
396 pub(crate) const ALL_SELF: FieldBounds<'a> = FieldBounds::All(&[TraitBound::Slf]);
397 pub(crate) const TRAILING_SELF: FieldBounds<'a> = FieldBounds::Trailing(&[TraitBound::Slf]);
398}
399
400pub(crate) enum SelfBounds<'a> {
401 None,
402 All(&'a [Trait]),
403}
404
405#[allow(clippy::needless_lifetimes)]
408impl<'a> SelfBounds<'a> {
409 pub(crate) const SIZED: Self = Self::All(&[Trait::Sized]);
410}
411
412pub(crate) fn normalize_bounds<'a>(
414 slf: &'a Trait,
415 bounds: &'a [TraitBound],
416) -> impl 'a + Iterator<Item = Trait> {
417 bounds.iter().map(move |bound| match bound {
418 TraitBound::Slf => slf.clone(),
419 TraitBound::Other(trt) => trt.clone(),
420 })
421}
422
423pub(crate) struct ImplBlockBuilder<'a> {
424 ctx: &'a Ctx,
425 data: &'a dyn DataExt,
426 trt: Trait,
427 field_type_trait_bounds: FieldBounds<'a>,
428 self_type_trait_bounds: SelfBounds<'a>,
429 padding_check: Option<PaddingCheck>,
430 param_extras: Vec<GenericParam>,
431 inner_extras: Option<TokenStream>,
432 outer_extras: Option<TokenStream>,
433}
434
impl<'a> ImplBlockBuilder<'a> {
    /// Creates a builder for an `unsafe impl` of `trt` on the item described
    /// by `ctx`, with field types bounded per `field_type_trait_bounds`.
    ///
    /// All optional settings start empty: no `Self` bounds, no padding check,
    /// no extra generic params and no inner/outer extras.
    pub(crate) fn new(
        ctx: &'a Ctx,
        data: &'a dyn DataExt,
        trt: Trait,
        field_type_trait_bounds: FieldBounds<'a>,
    ) -> Self {
        Self {
            ctx,
            data,
            trt,
            field_type_trait_bounds,
            self_type_trait_bounds: SelfBounds::None,
            padding_check: None,
            param_extras: Vec::new(),
            inner_extras: None,
            outer_extras: None,
        }
    }

    /// Replaces the bounds placed on `Self` in the generated `where` clause.
    pub(crate) fn self_type_trait_bounds(mut self, self_type_trait_bounds: SelfBounds<'a>) -> Self {
        self.self_type_trait_bounds = self_type_trait_bounds;
        self
    }

    /// Sets the padding check, or clears it when given `None`. The check is
    /// rendered as an extra `where` predicate.
    pub(crate) fn padding_check<P: Into<Option<PaddingCheck>>>(mut self, padding_check: P) -> Self {
        self.padding_check = padding_check.into();
        self
    }

    /// Appends extra generic parameters to the `impl` header.
    ///
    /// Unlike the other setters, repeated calls accumulate rather than
    /// overwrite.
    pub(crate) fn param_extras(mut self, param_extras: Vec<GenericParam>) -> Self {
        self.param_extras.extend(param_extras);
        self
    }

    /// Sets tokens emitted inside the `impl` body, after the
    /// `only_derive_is_allowed_to_implement_this_trait` method.
    pub(crate) fn inner_extras(mut self, inner_extras: TokenStream) -> Self {
        self.inner_extras = Some(inner_extras);
        self
    }

    /// Sets tokens emitted alongside the `impl` in the enclosing
    /// `const_block`, or clears them. An empty token stream is dropped
    /// by `build`.
    pub(crate) fn outer_extras<T: Into<Option<TokenStream>>>(mut self, outer_extras: T) -> Self {
        self.outer_extras = outer_extras.into();
        self
    }

    /// Renders the impl.
    ///
    /// The output is `const_block([cfg_compile_error, impl, outer_extras])`,
    /// where the impl is:
    ///
    /// ```text
    /// unsafe impl<params> TraitPath for Type<param_idents>
    /// where <item's own predicates>, <field bounds>, <padding bound>, <Self bounds>,
    /// { fn only_derive_is_allowed_to_implement_this_trait() {} <inner_extras> }
    /// ```
    ///
    /// `#[allow(trivial_bounds)]` is added when `on_error = "skip"` is active.
    ///
    /// # Panics
    ///
    /// Panics if a padding check is requested and the item's attributes fail
    /// to parse as `Repr::<PrimitiveRepr, NonZeroU32>`.
    pub(crate) fn build(self) -> TokenStream {
        let type_ident = &self.ctx.ast.ident;
        let trait_path = self.trt.crate_path(self.ctx);
        let fields = self.data.fields();
        let variants = self.data.variants();
        let tag = self.data.tag();
        let zerocopy_crate = &self.ctx.zerocopy_crate;

        // Builds the predicate `ty: Trait1 + Trait2 + ...`, with every trait
        // written as its fully-qualified path.
        fn bound_tt(ty: &Type, traits: impl Iterator<Item = Trait>, ctx: &Ctx) -> WherePredicate {
            let traits = traits.map(|t| t.crate_path(ctx));
            parse_quote!(#ty: #(#traits)+*)
        }
        // Field-type predicates. `Trailing` bounds only the last field of the
        // flattened field list and yields nothing when there are no fields.
        let field_type_bounds: Vec<_> = match (self.field_type_trait_bounds, &fields[..]) {
            (FieldBounds::All(traits), _) => fields
                .iter()
                .map(|(_vis, _name, ty)| {
                    bound_tt(ty, normalize_bounds(&self.trt, traits), self.ctx)
                })
                .collect(),
            (FieldBounds::None, _) | (FieldBounds::Trailing(..), []) => vec![],
            (FieldBounds::Trailing(traits), [.., last]) => {
                vec![bound_tt(last.2, normalize_bounds(&self.trt, traits), self.ctx)]
            }
            (FieldBounds::Explicit(bounds), _) => bounds,
        };

        // Padding check rendered as a predicate on `()`:
        //   (): <zerocopy>::util::macro_util::<Trait><Self, { <context> <macro>!(...) }>
        // The macro receives `Self`, the `repr(align)` and `repr(packed)` values
        // (as `Option<NonZeroUsize>` expressions), the tag ident if any, and one
        // `[field types...]` list per variant.
        let padding_check_bound = self
            .padding_check
            .map(|check| {
                // NOTE(review): `unwrap` assumes the repr attributes were already
                // validated before a padding check is requested — confirm at call sites.
                let repr =
                    Repr::<PrimitiveRepr, NonZeroU32>::from_attrs(&self.ctx.ast.attrs).unwrap();
                let core = self.ctx.core_path();
                let option = quote! { #core::option::Option };
                let nonzero = quote! { #core::num::NonZeroUsize };
                let none = quote! { #option::None::<#nonzero> };
                // `repr(align(N))` keeps the attribute's span so errors point at it.
                let repr_align =
                    repr.get_align().map(|spanned| {
                        let n = spanned.t.get();
                        quote_spanned! { spanned.span => (#nonzero::new(#n as usize)) }
                    }).unwrap_or(quote! { (#none) });
                let repr_packed =
                    repr.get_packed().map(|packed| {
                        let n = packed.get();
                        quote! { (#nonzero::new(#n as usize)) }
                    }).unwrap_or(quote! { (#none) });
                let variant_types = variants.iter().map(|(_, fields)| {
                    let types = fields.iter().map(|(_vis, _name, ty)| ty);
                    quote!([#((#types)),*])
                });
                let validator_context = check.validator_macro_context();
                let (trt, validator_macro) = check.validator_trait_and_macro_idents();
                // `#(#t,)*` emits `tag,` when a tag exists and nothing otherwise.
                let t = tag.iter();
                parse_quote! {
                    (): #zerocopy_crate::util::macro_util::#trt<
                        Self,
                        {
                            #validator_context
                            #zerocopy_crate::#validator_macro!(Self, #repr_align, #repr_packed, #(#t,)* #(#variant_types),*)
                        }
                    >
                }
            });

        let self_bounds: Option<WherePredicate> = match self.self_type_trait_bounds {
            SelfBounds::None => None,
            SelfBounds::All(traits) => {
                Some(bound_tt(&parse_quote!(Self), traits.iter().cloned(), self.ctx))
            }
        };

        // Final `where` clause order: the item's own predicates, then field
        // bounds, the padding bound and `Self` bounds.
        let bounds = self
            .ctx
            .ast
            .generics
            .where_clause
            .as_ref()
            .map(|where_clause| where_clause.predicates.iter())
            .into_iter()
            .flatten()
            .chain(field_type_bounds.iter())
            .chain(padding_check_bound.iter())
            .chain(self_bounds.iter());

        // The item's generic params with defaults stripped (defaults are not
        // allowed on impl params), followed by any extra params.
        let mut params: Vec<_> = self
            .ctx
            .ast
            .generics
            .params
            .clone()
            .into_iter()
            .map(|mut param| {
                match &mut param {
                    GenericParam::Type(ty) => ty.default = None,
                    GenericParam::Const(cnst) => cnst.default = None,
                    GenericParam::Lifetime(_) => {}
                }
                parse_quote!(#param)
            })
            .chain(self.param_extras)
            .collect();

        // Lifetimes first, then types, then consts. The sort is stable, so
        // declaration order is kept within each kind. Presumably this is
        // needed because extras appended at the end may include lifetimes,
        // which must precede other params.
        params.sort_by_cached_key(|param| match param {
            GenericParam::Lifetime(_) => 0,
            GenericParam::Type(_) => 1,
            GenericParam::Const(_) => 2,
        });

        // Generic arguments applied to the implementing type: only the item's
        // own params (not the extras), in declaration order. Const params are
        // wrapped in braces as const-argument blocks.
        let param_idents = self.ctx.ast.generics.params.iter().map(|param| match param {
            GenericParam::Type(ty) => {
                let ident = &ty.ident;
                quote!(#ident)
            }
            GenericParam::Lifetime(l) => {
                let ident = &l.lifetime;
                quote!(#ident)
            }
            GenericParam::Const(cnst) => {
                let ident = &cnst.ident;
                quote!({#ident})
            }
        });

        let inner_extras = self.inner_extras;
        let allow_trivial_bounds =
            if self.ctx.skip_on_error { quote!(#[allow(trivial_bounds)]) } else { quote!() };
        let impl_tokens = quote! {
            #allow_trivial_bounds
            unsafe impl < #(#params),* > #trait_path for #type_ident < #(#param_idents),* >
            where
                #(#bounds,)*
            {
                fn only_derive_is_allowed_to_implement_this_trait() {}

                #inner_extras
            }
        };

        // Empty outer extras are dropped rather than emitted as nothing.
        let outer_extras = self.outer_extras.filter(|e| !e.is_empty());
        let cfg_compile_error = self.ctx.cfg_compile_error();
        const_block([Some(cfg_compile_error), Some(impl_tokens), outer_extras])
    }
}
690
/// Trait-based `then_some` for `bool`.
///
/// Where `bool` has an inherent `then_some`, method-call syntax resolves to
/// the inherent method instead. Presumably this is why `unused` is allowed.
#[allow(unused)]
trait BoolExt {
    /// Returns `Some(t)` when `self` is `true`, and `None` otherwise.
    fn then_some<T>(self, t: T) -> Option<T>;
}

impl BoolExt for bool {
    fn then_some<T>(self, t: T) -> Option<T> {
        match self {
            true => Some(t),
            false => None,
        }
    }
}
712
713pub(crate) fn const_block(items: impl IntoIterator<Item = Option<TokenStream>>) -> TokenStream {
714 let items = items.into_iter().flatten();
715 quote! {
716 #[allow(
717 deprecated,
720 private_bounds,
724 non_local_definitions,
725 non_camel_case_types,
726 non_upper_case_globals,
727 non_snake_case,
728 non_ascii_idents,
729 clippy::missing_inline_in_public_items,
730 )]
731 #[deny(ambiguous_associated_items)]
732 #[automatically_derived]
735 const _: () = {
736 #(#items)*
737 };
738 }
739}
740pub(crate) fn generate_tag_enum(ctx: &Ctx, repr: &EnumRepr, data: &DataEnum) -> TokenStream {
741 let zerocopy_crate = &ctx.zerocopy_crate;
742 let variants = data.variants.iter().map(|v| {
743 let ident = &v.ident;
744 if let Some((eq, discriminant)) = &v.discriminant {
745 quote! { #ident #eq #discriminant }
746 } else {
747 quote! { #ident }
748 }
749 });
750
751 let repr = match repr {
755 EnumRepr::Transparent(span) => quote::quote_spanned! { *span => #[repr(transparent)] },
756 EnumRepr::Compound(c, _) => quote! { #c },
757 };
758
759 quote! {
760 #repr
761 #[allow(dead_code)]
762 pub enum ___ZerocopyTag {
763 #(#variants,)*
764 }
765
766 unsafe impl #zerocopy_crate::Immutable for ___ZerocopyTag {
769 fn only_derive_is_allowed_to_implement_this_trait() {}
770 }
771 }
772}
773pub(crate) fn enum_size_from_repr(repr: &EnumRepr) -> Result<usize, Error> {
774 use CompoundRepr::*;
775 use PrimitiveRepr::*;
776 use Repr::*;
777 match repr {
778 Transparent(span)
779 | Compound(
780 Spanned {
781 t: C | Rust | Primitive(U32 | I32 | U64 | I64 | U128 | I128 | Usize | Isize),
782 span,
783 },
784 _,
785 ) => Err(Error::new(
786 *span,
787 "`FromBytes` only supported on enums with `#[repr(...)]` attributes `u8`, `i8`, `u16`, or `i16`",
788 )),
789 Compound(Spanned { t: Primitive(U8 | I8), span: _ }, _align) => Ok(8),
790 Compound(Spanned { t: Primitive(U16 | I16), span: _ }, _align) => Ok(16),
791 }
792}
793
#[cfg(test)]
pub(crate) mod testutil {
    use proc_macro2::TokenStream;
    use syn::visit::{self, Visit};

    /// Parses `ts` as a Rust file and rejects unqualified associated-item
    /// paths.
    ///
    /// A path is rejected when its first segment is `Self` and it has more
    /// segments (e.g. `Self::Item`). Such accesses must be written fully
    /// qualified (`<Self as Trait>::Item`).
    ///
    /// NOTE: `syn` keeps macro-invocation bodies as unparsed tokens, so paths
    /// inside them are presumably not inspected.
    ///
    /// # Panics
    ///
    /// Panics if `ts` is not a valid file, or if a rejected path is found.
    pub(crate) fn check_hygiene(ts: TokenStream) {
        /// Visitor that panics on `Self::...` paths.
        struct SelfPathChecker;

        impl<'ast> Visit<'ast> for SelfPathChecker {
            fn visit_path(&mut self, path: &'ast syn::Path) {
                let starts_with_self =
                    matches!(path.segments.first(), Some(seg) if seg.ident == "Self");
                if starts_with_self && path.segments.len() > 1 {
                    panic!(
                        "Found ambiguous path `{}` in generated output. \
                         All associated item access must be fully qualified (e.g., `<Self as Trait>::Item`) \
                         to prevent hygiene issues.",
                        quote::quote!(#path)
                    );
                }
                visit::visit_path(self, path);
            }
        }

        let parsed: syn::File =
            syn::parse2(ts).expect("failed to parse generated output as File");
        SelfPathChecker.visit_file(&parsed);
    }

    #[test]
    fn test_check_hygiene_success() {
        check_hygiene(quote::quote! {
            fn foo() {
                let _ = <Self as Trait>::Item;
            }
        });
    }

    #[test]
    #[should_panic(expected = "Found ambiguous path `Self :: Ambiguous`")]
    fn test_check_hygiene_failure() {
        check_hygiene(quote::quote! {
            fn foo() {
                let _ = Self::Ambiguous;
            }
        });
    }
}