Skip to main content

zerocopy_derive/derive/
from_bytes.rs

1use proc_macro2::{Span, TokenStream};
2use syn::{
3    parse_quote, Data, DataEnum, DataStruct, DataUnion, Error, Expr, ExprLit, ExprUnary, Lit, UnOp,
4    WherePredicate,
5};
6
7use crate::{
8    derive::try_from_bytes::derive_try_from_bytes,
9    repr::{CompoundRepr, EnumRepr, Repr, Spanned},
10    util::{enum_size_from_repr, Ctx, FieldBounds, ImplBlockBuilder, Trait, TraitBound},
11};
12/// Returns `Ok(index)` if variant `index` of the enum has a discriminant of
13/// zero. If `Err(bool)` is returned, the boolean is true if the enum has
14/// unknown discriminants (e.g. discriminants set to const expressions which we
15/// can't evaluate in a proc macro). If the enum has unknown discriminants, then
16/// it might have a zero variant that we just can't detect.
17pub(crate) fn find_zero_variant(enm: &DataEnum) -> Result<usize, bool> {
18    // Discriminants can be anywhere in the range [i128::MIN, u128::MAX] because
19    // the discriminant type may be signed or unsigned. Since we only care about
20    // tracking the discriminant when it's less than or equal to zero, we can
21    // avoid u128 -> i128 conversions and bounds checking by making the "next
22    // discriminant" value implicitly negative.
23    // Technically 64 bits is enough, but 128 is better for future compatibility
24    // with https://github.com/rust-lang/rust/issues/56071
25    let mut next_negative_discriminant = Some(0);
26
27    // Sometimes we encounter explicit discriminants that we can't know the
28    // value of (e.g. a constant expression that requires evaluation). These
29    // could evaluate to zero or a negative number, but we can't assume that
30    // they do (no false positives allowed!). So we treat them like strictly-
31    // positive values that can't result in any zero variants, and track whether
32    // we've encountered any unknown discriminants.
33    let mut has_unknown_discriminants = false;
34
35    for (i, v) in enm.variants.iter().enumerate() {
36        match v.discriminant.as_ref() {
37            // Implicit discriminant
38            None => {
39                match next_negative_discriminant.as_mut() {
40                    Some(0) => return Ok(i),
41                    // n is nonzero so subtraction is always safe
42                    Some(n) => *n -= 1,
43                    None => (),
44                }
45            }
46            // Explicit positive discriminant
47            Some((_, Expr::Lit(ExprLit { lit: Lit::Int(int), .. }))) => {
48                match int.base10_parse::<u128>().ok() {
49                    Some(0) => return Ok(i),
50                    Some(_) => next_negative_discriminant = None,
51                    None => {
52                        // Numbers should never fail to parse, but just in case:
53                        has_unknown_discriminants = true;
54                        next_negative_discriminant = None;
55                    }
56                }
57            }
58            // Explicit negative discriminant
59            Some((_, Expr::Unary(ExprUnary { op: UnOp::Neg(_), expr, .. }))) => match &**expr {
60                Expr::Lit(ExprLit { lit: Lit::Int(int), .. }) => {
61                    match int.base10_parse::<u128>().ok() {
62                        Some(0) => return Ok(i),
63                        // x is nonzero so subtraction is always safe
64                        Some(x) => next_negative_discriminant = Some(x - 1),
65                        None => {
66                            // Numbers should never fail to parse, but just in
67                            // case:
68                            has_unknown_discriminants = true;
69                            next_negative_discriminant = None;
70                        }
71                    }
72                }
73                // Unknown negative discriminant (e.g. const repr)
74                _ => {
75                    has_unknown_discriminants = true;
76                    next_negative_discriminant = None;
77                }
78            },
79            // Unknown discriminant (e.g. const expr)
80            _ => {
81                has_unknown_discriminants = true;
82                next_negative_discriminant = None;
83            }
84        }
85    }
86
87    Err(has_unknown_discriminants)
88}
89pub(crate) fn derive_from_zeros(ctx: &Ctx, top_level: Trait) -> Result<TokenStream, Error> {
90    let try_from_bytes = derive_try_from_bytes(ctx, top_level)?;
91    let from_zeros = match &ctx.ast.data {
92        Data::Struct(strct) => derive_from_zeros_struct(ctx, strct),
93        Data::Enum(enm) => derive_from_zeros_enum(ctx, enm)?,
94        Data::Union(unn) => derive_from_zeros_union(ctx, unn),
95    };
96    Ok(IntoIterator::into_iter([try_from_bytes, from_zeros]).collect())
97}
98pub(crate) fn derive_from_bytes(ctx: &Ctx, top_level: Trait) -> Result<TokenStream, Error> {
99    let from_zeros = derive_from_zeros(ctx, top_level)?;
100    let from_bytes = match &ctx.ast.data {
101        Data::Struct(strct) => derive_from_bytes_struct(ctx, strct),
102        Data::Enum(enm) => derive_from_bytes_enum(ctx, enm)?,
103        Data::Union(unn) => derive_from_bytes_union(ctx, unn),
104    };
105
106    Ok(IntoIterator::into_iter([from_zeros, from_bytes]).collect())
107}
108fn derive_from_zeros_struct(ctx: &Ctx, strct: &DataStruct) -> TokenStream {
109    ImplBlockBuilder::new(ctx, strct, Trait::FromZeros, FieldBounds::ALL_SELF).build()
110}
111fn derive_from_zeros_enum(ctx: &Ctx, enm: &DataEnum) -> Result<TokenStream, Error> {
112    let repr = EnumRepr::from_attrs(&ctx.ast.attrs)?;
113
114    // We don't actually care what the repr is; we just care that it's one of
115    // the allowed ones.
116    match repr {
117        Repr::Compound(Spanned { t: CompoundRepr::C | CompoundRepr::Primitive(_), span: _ }, _) => {
118        }
119        Repr::Transparent(_) | Repr::Compound(Spanned { t: CompoundRepr::Rust, span: _ }, _) => {
120            return ctx.error_or_skip(
121                Error::new(
122                    Span::call_site(),
123                    "must have #[repr(C)] or #[repr(Int)] attribute in order to guarantee this type's memory layout",
124                ),
125            );
126        }
127    }
128
129    let zero_variant = match find_zero_variant(enm) {
130        Ok(index) => enm.variants.iter().nth(index).unwrap(),
131        // Has unknown variants
132        Err(true) => {
133            return ctx.error_or_skip(Error::new_spanned(
134                &ctx.ast,
135                "FromZeros only supported on enums with a variant that has a discriminant of `0`\n\
136                help: This enum has discriminants which are not literal integers. One of those may \
137                define or imply which variant has a discriminant of zero. Use a literal integer to \
138                define or imply the variant with a discriminant of zero.",
139            ));
140        }
141        // Does not have unknown variants
142        Err(false) => {
143            return ctx.error_or_skip(Error::new_spanned(
144                &ctx.ast,
145                "FromZeros only supported on enums with a variant that has a discriminant of `0`",
146            ));
147        }
148    };
149
150    let zerocopy_crate = &ctx.zerocopy_crate;
151    let explicit_bounds = zero_variant
152        .fields
153        .iter()
154        .map(|field| {
155            let ty = &field.ty;
156            parse_quote! { #ty: #zerocopy_crate::FromZeros }
157        })
158        .collect::<Vec<WherePredicate>>();
159
160    Ok(ImplBlockBuilder::new(ctx, enm, Trait::FromZeros, FieldBounds::Explicit(explicit_bounds))
161        .build())
162}
163fn derive_from_zeros_union(ctx: &Ctx, unn: &DataUnion) -> TokenStream {
164    let field_type_trait_bounds = FieldBounds::All(&[TraitBound::Slf]);
165    ImplBlockBuilder::new(ctx, unn, Trait::FromZeros, field_type_trait_bounds).build()
166}
167fn derive_from_bytes_struct(ctx: &Ctx, strct: &DataStruct) -> TokenStream {
168    ImplBlockBuilder::new(ctx, strct, Trait::FromBytes, FieldBounds::ALL_SELF).build()
169}
170fn derive_from_bytes_enum(ctx: &Ctx, enm: &DataEnum) -> Result<TokenStream, Error> {
171    let repr = EnumRepr::from_attrs(&ctx.ast.attrs)?;
172
173    let variants_required = 1usize << enum_size_from_repr(&repr)?;
174    if enm.variants.len() != variants_required {
175        return ctx.error_or_skip(Error::new_spanned(
176            &ctx.ast,
177            format!(
178                "FromBytes only supported on {} enum with {} variants",
179                repr.repr_type_name(),
180                variants_required
181            ),
182        ));
183    }
184
185    Ok(ImplBlockBuilder::new(ctx, enm, Trait::FromBytes, FieldBounds::ALL_SELF).build())
186}
187fn derive_from_bytes_union(ctx: &Ctx, unn: &DataUnion) -> TokenStream {
188    let field_type_trait_bounds = FieldBounds::All(&[TraitBound::Slf]);
189    ImplBlockBuilder::new(ctx, unn, Trait::FromBytes, field_type_trait_bounds).build()
190}