zerocopy_derive/
enum.rs

1// Copyright 2024 The Fuchsia Authors
2//
3// Licensed under a BSD-style license <LICENSE-BSD>, Apache License, Version 2.0
4// <LICENSE-APACHE or https://www.apache.org/licenses/LICENSE-2.0>, or the MIT
5// license <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your option.
6// This file may not be copied, modified, or distributed except according to
7// those terms.
8
9use proc_macro2::{Span, TokenStream};
10use quote::quote;
11use syn::{parse_quote, DataEnum, Error, Fields, Generics, Ident, Path};
12
13use crate::{derive_try_from_bytes_inner, repr::EnumRepr, Trait};
14
15/// Generates a tag enum for the given enum. This generates an enum with the
16/// same non-align `repr`s, variants, and corresponding discriminants, but none
17/// of the fields.
18pub(crate) fn generate_tag_enum(repr: &EnumRepr, data: &DataEnum) -> TokenStream {
19    let variants = data.variants.iter().map(|v| {
20        let ident = &v.ident;
21        if let Some((eq, discriminant)) = &v.discriminant {
22            quote! { #ident #eq #discriminant }
23        } else {
24            quote! { #ident }
25        }
26    });
27
28    // Don't include any `repr(align)` when generating the tag enum, as that
29    // could add padding after the tag but before any variants, which is not the
30    // correct behavior.
31    let repr = match repr {
32        EnumRepr::Transparent(span) => quote::quote_spanned! { *span => #[repr(transparent)] },
33        EnumRepr::Compound(c, _) => quote! { #c },
34    };
35
36    quote! {
37        #repr
38        #[allow(dead_code, non_camel_case_types)]
39        enum ___ZerocopyTag {
40            #(#variants,)*
41        }
42    }
43}
44
45fn tag_ident(variant_ident: &Ident) -> Ident {
46    let variant_ident_str = crate::ext::to_ident_str(variant_ident);
47    Ident::new(&format!("___ZEROCOPY_TAG_{}", variant_ident_str), variant_ident.span())
48}
49
50/// Generates a constant for the tag associated with each variant of the enum.
51/// When we match on the enum's tag, each arm matches one of these constants. We
52/// have to use constants here because:
53///
54/// - The type that we're matching on is not the type of the tag, it's an
55///   integer of the same size as the tag type and with the same bit patterns.
56/// - We can't read the enum tag as an enum because the bytes may not represent
57///   a valid variant.
58/// - Patterns do not currently support const expressions, so we have to assign
59///   these constants to names rather than use them inline in the `match`
60///   statement.
61fn generate_tag_consts(data: &DataEnum) -> TokenStream {
62    let tags = data.variants.iter().map(|v| {
63        let variant_ident = &v.ident;
64        let tag_ident = tag_ident(variant_ident);
65
66        quote! {
67            // This casts the enum variant to its discriminant, and then
68            // converts the discriminant to the target integral type via a
69            // numeric cast [1].
70            //
71            // Because these are the same size, this is defined to be a no-op
72            // and therefore is a lossless conversion [2].
73            //
74            // [1] Per https://doc.rust-lang.org/1.81.0/reference/expressions/operator-expr.html#enum-cast:
75            //
76            //   Casts an enum to its discriminant.
77            //
78            // [2] Per https://doc.rust-lang.org/1.81.0/reference/expressions/operator-expr.html#numeric-cast:
79            //
80            //   Casting between two integers of the same size (e.g. i32 -> u32)
81            //   is a no-op.
82            #[allow(non_upper_case_globals)]
83            const #tag_ident: ___ZerocopyTagPrimitive =
84                ___ZerocopyTag::#variant_ident as ___ZerocopyTagPrimitive;
85        }
86    });
87
88    quote! {
89        #(#tags)*
90    }
91}
92
93fn variant_struct_ident(variant_ident: &Ident) -> Ident {
94    let variant_ident_str = crate::ext::to_ident_str(variant_ident);
95    Ident::new(&format!("___ZerocopyVariantStruct_{}", variant_ident_str), variant_ident.span())
96}
97
98/// Generates variant structs for the given enum variant.
99///
100/// These are structs associated with each variant of an enum. They are
101/// `repr(C)` tuple structs with the same fields as the variant after a
102/// `MaybeUninit<___ZerocopyInnerTag>`.
103///
104/// In order to unify the generated types for `repr(C)` and `repr(int)` enums,
105/// we use a "fused" representation with fields for both an inner tag and an
106/// outer tag. Depending on the repr, we will set one of these tags to the tag
107/// type and the other to `()`. This lets us generate the same code but put the
108/// tags in different locations.
109fn generate_variant_structs(
110    enum_name: &Ident,
111    generics: &Generics,
112    data: &DataEnum,
113    zerocopy_crate: &Path,
114) -> TokenStream {
115    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
116
117    // All variant structs have a `PhantomData<MyEnum<...>>` field because we
118    // don't know which generic parameters each variant will use, and unused
119    // generic parameters are a compile error.
120    let phantom_ty = quote! {
121        core_reexport::marker::PhantomData<#enum_name #ty_generics>
122    };
123
124    let variant_structs = data.variants.iter().filter_map(|variant| {
125        // We don't generate variant structs for unit variants because we only
126        // need to check the tag. This helps cut down our generated code a bit.
127        if matches!(variant.fields, Fields::Unit) {
128            return None;
129        }
130
131        let variant_struct_ident = variant_struct_ident(&variant.ident);
132        let field_types = variant.fields.iter().map(|f| &f.ty);
133
134        let variant_struct = parse_quote! {
135            #[repr(C)]
136            #[allow(non_snake_case)]
137            struct #variant_struct_ident #impl_generics (
138                core_reexport::mem::MaybeUninit<___ZerocopyInnerTag>,
139                #(#field_types,)*
140                #phantom_ty,
141            ) #where_clause;
142        };
143
144        // We do this rather than emitting `#[derive(::zerocopy::TryFromBytes)]`
145        // because that is not hygienic, and this is also more performant.
146        let try_from_bytes_impl =
147            derive_try_from_bytes_inner(&variant_struct, Trait::TryFromBytes, zerocopy_crate)
148                .expect("derive_try_from_bytes_inner should not fail on synthesized type");
149
150        Some(quote! {
151            #variant_struct
152            #try_from_bytes_impl
153        })
154    });
155
156    quote! {
157        #(#variant_structs)*
158    }
159}
160
161fn generate_variants_union(generics: &Generics, data: &DataEnum) -> TokenStream {
162    let (_, ty_generics, _) = generics.split_for_impl();
163
164    let fields = data.variants.iter().filter_map(|variant| {
165        // We don't generate variant structs for unit variants because we only
166        // need to check the tag. This helps cut down our generated code a bit.
167        if matches!(variant.fields, Fields::Unit) {
168            return None;
169        }
170
171        // Field names are prefixed with `__field_` to prevent name collision
172        // with the `__nonempty` field.
173        let field_name_str = crate::ext::to_ident_str(&variant.ident);
174        let field_name = Ident::new(&format!("__field_{}", field_name_str), variant.ident.span());
175        let variant_struct_ident = variant_struct_ident(&variant.ident);
176
177        Some(quote! {
178            #field_name: core_reexport::mem::ManuallyDrop<
179                #variant_struct_ident #ty_generics
180            >,
181        })
182    });
183
184    quote! {
185        #[repr(C)]
186        #[allow(non_snake_case)]
187        union ___ZerocopyVariants #generics {
188            #(#fields)*
189            // Enums can have variants with no fields, but unions must
190            // have at least one field. So we just add a trailing unit
191            // to ensure that this union always has at least one field.
192            // Because this union is `repr(C)`, this unit type does not
193            // affect the layout.
194            __nonempty: (),
195        }
196    }
197}
198
199/// Generates an implementation of `is_bit_valid` for an arbitrary enum.
200///
201/// The general process is:
202///
203/// 1. Generate a tag enum. This is an enum with the same repr, variants, and
204///    corresponding discriminants as the original enum, but without any fields
205///    on the variants. This gives us access to an enum where the variants have
206///    the same discriminants as the one we're writing `is_bit_valid` for.
207/// 2. Make constants from the variants of the tag enum. We need these because
208///    we can't put const exprs in match arms.
209/// 3. Generate variant structs. These are structs which have the same fields as
210///    each variant of the enum, and are `#[repr(C)]` with an optional "inner
211///    tag".
212/// 4. Generate a variants union, with one field for each variant struct type.
213/// 5. And finally, our raw enum is a `#[repr(C)]` struct of an "outer tag" and
214///    the variants union.
215///
216/// See these reference links for fully-worked example decompositions.
217///
218/// - `repr(C)`: <https://doc.rust-lang.org/reference/type-layout.html#reprc-enums-with-fields>
219/// - `repr(int)`: <https://doc.rust-lang.org/reference/type-layout.html#primitive-representation-of-enums-with-fields>
220/// - `repr(C, int)`: <https://doc.rust-lang.org/reference/type-layout.html#combining-primitive-representations-of-enums-with-fields-and-reprc>
221pub(crate) fn derive_is_bit_valid(
222    enum_ident: &Ident,
223    repr: &EnumRepr,
224    generics: &Generics,
225    data: &DataEnum,
226    zerocopy_crate: &Path,
227) -> Result<TokenStream, Error> {
228    let trait_path = Trait::TryFromBytes.crate_path(zerocopy_crate);
229    let tag_enum = generate_tag_enum(repr, data);
230    let tag_consts = generate_tag_consts(data);
231
232    let (outer_tag_type, inner_tag_type) = if repr.is_c() {
233        (quote! { ___ZerocopyTag }, quote! { () })
234    } else if repr.is_primitive() {
235        (quote! { () }, quote! { ___ZerocopyTag })
236    } else {
237        return Err(Error::new(
238            Span::call_site(),
239            "must have #[repr(C)] or #[repr(Int)] attribute in order to guarantee this type's memory layout",
240        ));
241    };
242
243    let variant_structs = generate_variant_structs(enum_ident, generics, data, zerocopy_crate);
244    let variants_union = generate_variants_union(generics, data);
245
246    let (_, ty_generics, _) = generics.split_for_impl();
247
248    let match_arms = data.variants.iter().map(|variant| {
249        let tag_ident = tag_ident(&variant.ident);
250        let variant_struct_ident = variant_struct_ident(&variant.ident);
251
252        if matches!(variant.fields, Fields::Unit) {
253            // Unit variants don't need any further validation beyond checking
254            // the tag.
255            quote! {
256                #tag_ident => true
257            }
258        } else {
259            quote! {
260                #tag_ident => {
261                    // SAFETY:
262                    // - This cast is from a `repr(C)` union which has a field
263                    //   of type `variant_struct_ident` to that variant struct
264                    //   type itself. This addresses a subset of the bytes
265                    //   addressed by `variants`.
266                    // - The returned pointer is cast from `p`, and so has the
267                    //   same provenance as `p`.
268                    // - We checked that the tag of the enum matched the
269                    //   constant for this variant, so this cast preserves
270                    //   types and locations of all fields. Therefore, any
271                    //   `UnsafeCell`s will have the same location as in the
272                    //   original type.
273                    let variant = unsafe {
274                        variants.cast_unsized_unchecked(
275                            |p: #zerocopy_crate::pointer::PtrInner<'_, ___ZerocopyVariants #ty_generics>| {
276                                p.cast_sized::<#variant_struct_ident #ty_generics>()
277                            }
278                        )
279                    };
280                    // SAFETY: `cast_unsized_unchecked` removes the
281                    // initialization invariant from `p`, so we re-assert that
282                    // all of the bytes are initialized.
283                    let variant = unsafe { variant.assume_initialized() };
284                    <
285                        #variant_struct_ident #ty_generics as #trait_path
286                    >::is_bit_valid(variant)
287                }
288            }
289        }
290    });
291
292    Ok(quote! {
293        // SAFETY: We use `is_bit_valid` to validate that the bit pattern of the
294        // enum's tag corresponds to one of the enum's discriminants. Then, we
295        // check the bit validity of each field of the corresponding variant.
296        // Thus, this is a sound implementation of `is_bit_valid`.
297        fn is_bit_valid<___ZerocopyAliasing>(
298            mut candidate: #zerocopy_crate::Maybe<'_, Self, ___ZerocopyAliasing>,
299        ) -> #zerocopy_crate::util::macro_util::core_reexport::primitive::bool
300        where
301            ___ZerocopyAliasing: #zerocopy_crate::pointer::invariant::Reference,
302        {
303            use #zerocopy_crate::util::macro_util::core_reexport;
304
305            #tag_enum
306
307            type ___ZerocopyTagPrimitive = #zerocopy_crate::util::macro_util::SizeToTag<
308                { core_reexport::mem::size_of::<___ZerocopyTag>() },
309            >;
310
311            #tag_consts
312
313            type ___ZerocopyOuterTag = #outer_tag_type;
314            type ___ZerocopyInnerTag = #inner_tag_type;
315
316            #variant_structs
317
318            #variants_union
319
320            #[repr(C)]
321            struct ___ZerocopyRawEnum #generics {
322                tag: ___ZerocopyOuterTag,
323                variants: ___ZerocopyVariants #ty_generics,
324            }
325
326            let tag = {
327                // SAFETY:
328                // - The provided cast addresses a subset of the bytes addressed
329                //   by `candidate` because it addresses the starting tag of the
330                //   enum.
331                // - Because the pointer is cast from `candidate`, it has the
332                //   same provenance as it.
333                // - There are no `UnsafeCell`s in the tag because it is a
334                //   primitive integer.
335                let tag_ptr = unsafe {
336                    candidate.reborrow().cast_unsized_unchecked(|p: #zerocopy_crate::pointer::PtrInner<'_, Self>| {
337                        p.cast_sized::<___ZerocopyTagPrimitive>()
338                    })
339                };
340                // SAFETY: `tag_ptr` is casted from `candidate`, whose referent
341                // is `Initialized`. Since we have not written uninitialized
342                // bytes into the referent, `tag_ptr` is also `Initialized`.
343                let tag_ptr = unsafe { tag_ptr.assume_initialized() };
344                tag_ptr.recall_validity::<_, (_, (_, _))>().read_unaligned::<#zerocopy_crate::BecauseImmutable>()
345            };
346
347            // SAFETY:
348            // - The raw enum has the same fields in the same locations as the
349            //   input enum, and may have a lower alignment. This guarantees
350            //   that it addresses a subset of the bytes addressed by
351            //   `candidate`.
352            // - The returned pointer is cast from `p`, and so has the same
353            //   provenance as `p`.
354            // - The raw enum has the same types at the same locations as the
355            //   original enum, and so preserves the locations of any
356            //   `UnsafeCell`s.
357            let raw_enum = unsafe {
358                candidate.cast_unsized_unchecked(|p: #zerocopy_crate::pointer::PtrInner<'_, Self>| {
359                    p.cast_sized::<___ZerocopyRawEnum #ty_generics>()
360                })
361            };
362            // SAFETY: `cast_unsized_unchecked` removes the initialization
363            // invariant from `p`, so we re-assert that all of the bytes are
364            // initialized.
365            let raw_enum = unsafe { raw_enum.assume_initialized() };
366            // SAFETY:
367            // - This projection returns a subfield of `this` using
368            //   `addr_of_mut!`.
369            // - Because the subfield pointer is derived from `this`, it has the
370            //   same provenance.
371            // - The locations of `UnsafeCell`s in the subfield match the
372            //   locations of `UnsafeCell`s in `this`. This is because the
373            //   subfield pointer just points to a smaller portion of the
374            //   overall struct.
375            let variants = unsafe {
376                use #zerocopy_crate::pointer::PtrInner;
377                raw_enum.cast_unsized_unchecked(|p: PtrInner<'_, ___ZerocopyRawEnum #ty_generics>| {
378                    let p = p.as_non_null().as_ptr();
379                    let ptr = core_reexport::ptr::addr_of_mut!((*p).variants);
380                    // SAFETY: `ptr` is a projection into `p`, which is
381                    // `NonNull`, and guaranteed not to wrap around the address
382                    // space. Thus, `ptr` cannot be null.
383                    let ptr = unsafe { core_reexport::ptr::NonNull::new_unchecked(ptr) };
384                    unsafe { PtrInner::new(ptr) }
385                })
386            };
387
388            #[allow(non_upper_case_globals)]
389            match tag {
390                #(#match_arms,)*
391                _ => false,
392            }
393        }
394    })
395}