zerocopy_derive/enum.rs
1// Copyright 2024 The Fuchsia Authors
2//
3// Licensed under a BSD-style license <LICENSE-BSD>, Apache License, Version 2.0
4// <LICENSE-APACHE or https://www.apache.org/licenses/LICENSE-2.0>, or the MIT
5// license <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your option.
6// This file may not be copied, modified, or distributed except according to
7// those terms.
8
9use proc_macro2::{Span, TokenStream};
10use quote::quote;
11use syn::{parse_quote, DataEnum, Error, Fields, Generics, Ident, Path};
12
13use crate::{derive_try_from_bytes_inner, repr::EnumRepr, Trait};
14
15/// Generates a tag enum for the given enum. This generates an enum with the
16/// same non-align `repr`s, variants, and corresponding discriminants, but none
17/// of the fields.
18pub(crate) fn generate_tag_enum(repr: &EnumRepr, data: &DataEnum) -> TokenStream {
19 let variants = data.variants.iter().map(|v| {
20 let ident = &v.ident;
21 if let Some((eq, discriminant)) = &v.discriminant {
22 quote! { #ident #eq #discriminant }
23 } else {
24 quote! { #ident }
25 }
26 });
27
28 // Don't include any `repr(align)` when generating the tag enum, as that
29 // could add padding after the tag but before any variants, which is not the
30 // correct behavior.
31 let repr = match repr {
32 EnumRepr::Transparent(span) => quote::quote_spanned! { *span => #[repr(transparent)] },
33 EnumRepr::Compound(c, _) => quote! { #c },
34 };
35
36 quote! {
37 #repr
38 #[allow(dead_code, non_camel_case_types)]
39 enum ___ZerocopyTag {
40 #(#variants,)*
41 }
42 }
43}
44
45fn tag_ident(variant_ident: &Ident) -> Ident {
46 Ident::new(&format!("___ZEROCOPY_TAG_{}", variant_ident), variant_ident.span())
47}
48
49/// Generates a constant for the tag associated with each variant of the enum.
50/// When we match on the enum's tag, each arm matches one of these constants. We
51/// have to use constants here because:
52///
53/// - The type that we're matching on is not the type of the tag, it's an
54/// integer of the same size as the tag type and with the same bit patterns.
55/// - We can't read the enum tag as an enum because the bytes may not represent
56/// a valid variant.
57/// - Patterns do not currently support const expressions, so we have to assign
58/// these constants to names rather than use them inline in the `match`
59/// statement.
60fn generate_tag_consts(data: &DataEnum) -> TokenStream {
61 let tags = data.variants.iter().map(|v| {
62 let variant_ident = &v.ident;
63 let tag_ident = tag_ident(variant_ident);
64
65 quote! {
66 // This casts the enum variant to its discriminant, and then
67 // converts the discriminant to the target integral type via a
68 // numeric cast [1].
69 //
70 // Because these are the same size, this is defined to be a no-op
71 // and therefore is a lossless conversion [2].
72 //
73 // [1]: https://doc.rust-lang.org/stable/reference/expressions/operator-expr.html#enum-cast
74 // [2]: https://doc.rust-lang.org/stable/reference/expressions/operator-expr.html#numeric-cast
75 #[allow(non_upper_case_globals)]
76 const #tag_ident: ___ZerocopyTagPrimitive =
77 ___ZerocopyTag::#variant_ident as ___ZerocopyTagPrimitive;
78 }
79 });
80
81 quote! {
82 #(#tags)*
83 }
84}
85
86fn variant_struct_ident(variant_ident: &Ident) -> Ident {
87 Ident::new(&format!("___ZerocopyVariantStruct_{}", variant_ident), variant_ident.span())
88}
89
90/// Generates variant structs for the given enum variant.
91///
92/// These are structs associated with each variant of an enum. They are
93/// `repr(C)` tuple structs with the same fields as the variant after a
94/// `MaybeUninit<___ZerocopyInnerTag>`.
95///
96/// In order to unify the generated types for `repr(C)` and `repr(int)` enums,
97/// we use a "fused" representation with fields for both an inner tag and an
98/// outer tag. Depending on the repr, we will set one of these tags to the tag
99/// type and the other to `()`. This lets us generate the same code but put the
100/// tags in different locations.
101fn generate_variant_structs(
102 enum_name: &Ident,
103 generics: &Generics,
104 data: &DataEnum,
105 zerocopy_crate: &Path,
106) -> TokenStream {
107 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
108
109 // All variant structs have a `PhantomData<MyEnum<...>>` field because we
110 // don't know which generic parameters each variant will use, and unused
111 // generic parameters are a compile error.
112 let phantom_ty = quote! {
113 core_reexport::marker::PhantomData<#enum_name #ty_generics>
114 };
115
116 let variant_structs = data.variants.iter().filter_map(|variant| {
117 // We don't generate variant structs for unit variants because we only
118 // need to check the tag. This helps cut down our generated code a bit.
119 if matches!(variant.fields, Fields::Unit) {
120 return None;
121 }
122
123 let variant_struct_ident = variant_struct_ident(&variant.ident);
124 let field_types = variant.fields.iter().map(|f| &f.ty);
125
126 let variant_struct = parse_quote! {
127 #[repr(C)]
128 #[allow(non_snake_case)]
129 struct #variant_struct_ident #impl_generics (
130 core_reexport::mem::MaybeUninit<___ZerocopyInnerTag>,
131 #(#field_types,)*
132 #phantom_ty,
133 ) #where_clause;
134 };
135
136 // We do this rather than emitting `#[derive(::zerocopy::TryFromBytes)]`
137 // because that is not hygienic, and this is also more performant.
138 let try_from_bytes_impl =
139 derive_try_from_bytes_inner(&variant_struct, Trait::TryFromBytes, zerocopy_crate)
140 .expect("derive_try_from_bytes_inner should not fail on synthesized type");
141
142 Some(quote! {
143 #variant_struct
144 #try_from_bytes_impl
145 })
146 });
147
148 quote! {
149 #(#variant_structs)*
150 }
151}
152
153fn generate_variants_union(generics: &Generics, data: &DataEnum) -> TokenStream {
154 let (_, ty_generics, _) = generics.split_for_impl();
155
156 let fields = data.variants.iter().filter_map(|variant| {
157 // We don't generate variant structs for unit variants because we only
158 // need to check the tag. This helps cut down our generated code a bit.
159 if matches!(variant.fields, Fields::Unit) {
160 return None;
161 }
162
163 // Field names are prefixed with `__field_` to prevent name collision with
164 // the `__nonempty` field.
165 let field_name = Ident::new(&format!("__field_{}", &variant.ident), variant.ident.span());
166 let variant_struct_ident = variant_struct_ident(&variant.ident);
167
168 Some(quote! {
169 #field_name: core_reexport::mem::ManuallyDrop<
170 #variant_struct_ident #ty_generics
171 >,
172 })
173 });
174
175 quote! {
176 #[repr(C)]
177 #[allow(non_snake_case)]
178 union ___ZerocopyVariants #generics {
179 #(#fields)*
180 // Enums can have variants with no fields, but unions must
181 // have at least one field. So we just add a trailing unit
182 // to ensure that this union always has at least one field.
183 // Because this union is `repr(C)`, this unit type does not
184 // affect the layout.
185 __nonempty: (),
186 }
187 }
188}
189
/// Generates an implementation of `is_bit_valid` for an arbitrary enum.
///
/// The general process is:
///
/// 1. Generate a tag enum. This is an enum with the same repr, variants, and
///    corresponding discriminants as the original enum, but without any fields
///    on the variants. This gives us access to an enum where the variants have
///    the same discriminants as the one we're writing `is_bit_valid` for.
/// 2. Make constants from the variants of the tag enum. We need these because
///    we can't put const exprs in match arms.
/// 3. Generate variant structs. These are structs which have the same fields as
///    each variant of the enum, and are `#[repr(C)]` with an optional "inner
///    tag".
/// 4. Generate a variants union, with one field for each variant struct type.
/// 5. And finally, our raw enum is a `#[repr(C)]` struct of an "outer tag" and
///    the variants union.
///
/// See these reference links for fully-worked example decompositions.
///
/// - `repr(C)`: <https://doc.rust-lang.org/reference/type-layout.html#reprc-enums-with-fields>
/// - `repr(int)`: <https://doc.rust-lang.org/reference/type-layout.html#primitive-representation-of-enums-with-fields>
/// - `repr(C, int)`: <https://doc.rust-lang.org/reference/type-layout.html#combining-primitive-representations-of-enums-with-fields-and-reprc>
///
/// # Parameters
///
/// - `enum_ident`: name of the enum being derived on; forwarded to the variant
///   struct generator for its `PhantomData` field.
/// - `repr`: the enum's parsed `repr`; selects where the tag lives and is
///   copied onto the tag enum.
/// - `generics`: the enum's generics, reused on every generated helper type.
/// - `data`: the enum's variants.
/// - `zerocopy_crate`: path used to name the zerocopy crate in generated code.
///
/// # Errors
///
/// Returns an error spanned at the macro call site when `repr` reports neither
/// `is_c()` nor `is_primitive()`.
pub(crate) fn derive_is_bit_valid(
    enum_ident: &Ident,
    repr: &EnumRepr,
    generics: &Generics,
    data: &DataEnum,
    zerocopy_crate: &Path,
) -> Result<TokenStream, Error> {
    let trait_path = Trait::TryFromBytes.crate_path(zerocopy_crate);
    let tag_enum = generate_tag_enum(repr, data);
    let tag_consts = generate_tag_consts(data);

    // Exactly one of the two tag slots holds the real tag type; the other is
    // `()`. `is_c()` is tested first, so any repr for which it returns true
    // uses the outer tag.
    // NOTE(review): presumably `repr(C, int)` satisfies `is_c()` — confirm
    // against `EnumRepr`.
    let (outer_tag_type, inner_tag_type) = if repr.is_c() {
        (quote! { ___ZerocopyTag }, quote! { () })
    } else if repr.is_primitive() {
        (quote! { () }, quote! { ___ZerocopyTag })
    } else {
        return Err(Error::new(
            Span::call_site(),
            "must have #[repr(C)] or #[repr(Int)] attribute in order to guarantee this type's memory layout",
        ));
    };

    let variant_structs = generate_variant_structs(enum_ident, generics, data, zerocopy_crate);
    let variants_union = generate_variants_union(generics, data);

    let (_, ty_generics, _) = generics.split_for_impl();

    // One `match` arm per variant, keyed on that variant's tag constant.
    let match_arms = data.variants.iter().map(|variant| {
        let tag_ident = tag_ident(&variant.ident);
        let variant_struct_ident = variant_struct_ident(&variant.ident);

        if matches!(variant.fields, Fields::Unit) {
            // Unit variants don't need any further validation beyond checking
            // the tag.
            quote! {
                #tag_ident => true
            }
        } else {
            quote! {
                #tag_ident => {
                    // SAFETY:
                    // - This cast is from a `repr(C)` union which has a field
                    //   of type `variant_struct_ident` to that variant struct
                    //   type itself. This addresses a subset of the bytes
                    //   addressed by `variants`.
                    // - The returned pointer is cast from `p`, and so has the
                    //   same provenance as `p`.
                    // - We checked that the tag of the enum matched the
                    //   constant for this variant, so this cast preserves
                    //   types and locations of all fields. Therefore, any
                    //   `UnsafeCell`s will have the same location as in the
                    //   original type.
                    let variant = unsafe {
                        variants.cast_unsized_unchecked(
                            |p: #zerocopy_crate::pointer::PtrInner<'_, ___ZerocopyVariants #ty_generics>| {
                                p.cast_sized::<#variant_struct_ident #ty_generics>()
                            }
                        )
                    };
                    // SAFETY: `cast_unsized_unchecked` removes the
                    // initialization invariant from `p`, so we re-assert that
                    // all of the bytes are initialized.
                    let variant = unsafe { variant.assume_initialized() };
                    // Delegate field validation to the variant struct's own
                    // `is_bit_valid`.
                    <
                        #variant_struct_ident #ty_generics as #trait_path
                    >::is_bit_valid(variant)
                }
            }
        }
    });

    Ok(quote! {
        // SAFETY: We use `is_bit_valid` to validate that the bit pattern of the
        // enum's tag corresponds to one of the enum's discriminants. Then, we
        // check the bit validity of each field of the corresponding variant.
        // Thus, this is a sound implementation of `is_bit_valid`.
        fn is_bit_valid<___ZerocopyAliasing>(
            mut candidate: #zerocopy_crate::Maybe<'_, Self, ___ZerocopyAliasing>,
        ) -> #zerocopy_crate::util::macro_util::core_reexport::primitive::bool
        where
            ___ZerocopyAliasing: #zerocopy_crate::pointer::invariant::Reference,
        {
            // The helper items below refer to `core_reexport::...` paths, so
            // this import must precede them.
            use #zerocopy_crate::util::macro_util::core_reexport;

            #tag_enum

            // Integer type selected by the size of the tag enum; the tag is
            // read and matched as this type.
            type ___ZerocopyTagPrimitive = #zerocopy_crate::util::macro_util::SizeToTag<
                { core_reexport::mem::size_of::<___ZerocopyTag>() },
            >;

            #tag_consts

            type ___ZerocopyOuterTag = #outer_tag_type;
            type ___ZerocopyInnerTag = #inner_tag_type;

            #variant_structs

            #variants_union

            // Layout-equivalent stand-in for the enum: the outer tag followed
            // by the union of all variant structs.
            #[repr(C)]
            struct ___ZerocopyRawEnum #generics {
                tag: ___ZerocopyOuterTag,
                variants: ___ZerocopyVariants #ty_generics,
            }

            let tag = {
                // SAFETY:
                // - The provided cast addresses a subset of the bytes addressed
                //   by `candidate` because it addresses the starting tag of the
                //   enum.
                // - Because the pointer is cast from `candidate`, it has the
                //   same provenance as it.
                // - There are no `UnsafeCell`s in the tag because it is a
                //   primitive integer.
                let tag_ptr = unsafe {
                    candidate.reborrow().cast_unsized_unchecked(|p: #zerocopy_crate::pointer::PtrInner<'_, Self>| {
                        p.cast_sized::<___ZerocopyTagPrimitive>()
                    })
                };
                // SAFETY: `tag_ptr` is casted from `candidate`, whose referent
                // is `Initialized`. Since we have not written uninitialized
                // bytes into the referent, `tag_ptr` is also `Initialized`.
                let tag_ptr = unsafe { tag_ptr.assume_initialized() };
                tag_ptr.recall_validity::<_, (_, (_, _))>().read_unaligned::<#zerocopy_crate::BecauseImmutable>()
            };

            // SAFETY:
            // - The raw enum has the same fields in the same locations as the
            //   input enum, and may have a lower alignment. This guarantees
            //   that it addresses a subset of the bytes addressed by
            //   `candidate`.
            // - The returned pointer is cast from `p`, and so has the same
            //   provenance as `p`.
            // - The raw enum has the same types at the same locations as the
            //   original enum, and so preserves the locations of any
            //   `UnsafeCell`s.
            let raw_enum = unsafe {
                candidate.cast_unsized_unchecked(|p: #zerocopy_crate::pointer::PtrInner<'_, Self>| {
                    p.cast_sized::<___ZerocopyRawEnum #ty_generics>()
                })
            };
            // SAFETY: `cast_unsized_unchecked` removes the initialization
            // invariant from `p`, so we re-assert that all of the bytes are
            // initialized.
            let raw_enum = unsafe { raw_enum.assume_initialized() };
            // SAFETY:
            // - This projection returns a subfield of `this` using
            //   `addr_of_mut!`.
            // - Because the subfield pointer is derived from `this`, it has the
            //   same provenance.
            // - The locations of `UnsafeCell`s in the subfield match the
            //   locations of `UnsafeCell`s in `this`. This is because the
            //   subfield pointer just points to a smaller portion of the
            //   overall struct.
            let variants = unsafe {
                use #zerocopy_crate::pointer::PtrInner;
                raw_enum.cast_unsized_unchecked(|p: PtrInner<'_, ___ZerocopyRawEnum #ty_generics>| {
                    let p = p.as_non_null().as_ptr();
                    let ptr = core_reexport::ptr::addr_of_mut!((*p).variants);
                    // SAFETY: `ptr` is a projection into `p`, which is
                    // `NonNull`, and guaranteed not to wrap around the address
                    // space. Thus, `ptr` cannot be null.
                    let ptr = unsafe { core_reexport::ptr::NonNull::new_unchecked(ptr) };
                    unsafe { PtrInner::new(ptr) }
                })
            };

            // Dispatch on the tag value read above: a known tag validates that
            // variant's fields (or is immediately valid for a unit variant),
            // and any other bit pattern is not a valid instance of the enum.
            #[allow(non_upper_case_globals)]
            match tag {
                #(#match_arms,)*
                _ => false,
            }
        }
    })
}