1use proc_macro2::{Span, TokenStream};
2use syn::{
3 parse_quote, Data, DataEnum, DataStruct, DataUnion, Error, Expr, ExprLit, ExprUnary, Lit, UnOp,
4 WherePredicate,
5};
6
7use crate::{
8 derive::try_from_bytes::derive_try_from_bytes,
9 repr::{CompoundRepr, EnumRepr, Repr, Spanned},
10 util::{enum_size_from_repr, Ctx, FieldBounds, ImplBlockBuilder, Trait, TraitBound},
11};
12pub(crate) fn find_zero_variant(enm: &DataEnum) -> Result<usize, bool> {
18 let mut next_negative_discriminant = Some(0);
26
27 let mut has_unknown_discriminants = false;
34
35 for (i, v) in enm.variants.iter().enumerate() {
36 match v.discriminant.as_ref() {
37 None => {
39 match next_negative_discriminant.as_mut() {
40 Some(0) => return Ok(i),
41 Some(n) => *n -= 1,
43 None => (),
44 }
45 }
46 Some((_, Expr::Lit(ExprLit { lit: Lit::Int(int), .. }))) => {
48 match int.base10_parse::<u128>().ok() {
49 Some(0) => return Ok(i),
50 Some(_) => next_negative_discriminant = None,
51 None => {
52 has_unknown_discriminants = true;
54 next_negative_discriminant = None;
55 }
56 }
57 }
58 Some((_, Expr::Unary(ExprUnary { op: UnOp::Neg(_), expr, .. }))) => match &**expr {
60 Expr::Lit(ExprLit { lit: Lit::Int(int), .. }) => {
61 match int.base10_parse::<u128>().ok() {
62 Some(0) => return Ok(i),
63 Some(x) => next_negative_discriminant = Some(x - 1),
65 None => {
66 has_unknown_discriminants = true;
69 next_negative_discriminant = None;
70 }
71 }
72 }
73 _ => {
75 has_unknown_discriminants = true;
76 next_negative_discriminant = None;
77 }
78 },
79 _ => {
81 has_unknown_discriminants = true;
82 next_negative_discriminant = None;
83 }
84 }
85 }
86
87 Err(has_unknown_discriminants)
88}
89pub(crate) fn derive_from_zeros(ctx: &Ctx, top_level: Trait) -> Result<TokenStream, Error> {
90 let try_from_bytes = derive_try_from_bytes(ctx, top_level)?;
91 let from_zeros = match &ctx.ast.data {
92 Data::Struct(strct) => derive_from_zeros_struct(ctx, strct),
93 Data::Enum(enm) => derive_from_zeros_enum(ctx, enm)?,
94 Data::Union(unn) => derive_from_zeros_union(ctx, unn),
95 };
96 Ok(IntoIterator::into_iter([try_from_bytes, from_zeros]).collect())
97}
98pub(crate) fn derive_from_bytes(ctx: &Ctx, top_level: Trait) -> Result<TokenStream, Error> {
99 let from_zeros = derive_from_zeros(ctx, top_level)?;
100 let from_bytes = match &ctx.ast.data {
101 Data::Struct(strct) => derive_from_bytes_struct(ctx, strct),
102 Data::Enum(enm) => derive_from_bytes_enum(ctx, enm)?,
103 Data::Union(unn) => derive_from_bytes_union(ctx, unn),
104 };
105
106 Ok(IntoIterator::into_iter([from_zeros, from_bytes]).collect())
107}
108fn derive_from_zeros_struct(ctx: &Ctx, strct: &DataStruct) -> TokenStream {
109 ImplBlockBuilder::new(ctx, strct, Trait::FromZeros, FieldBounds::ALL_SELF).build()
110}
111fn derive_from_zeros_enum(ctx: &Ctx, enm: &DataEnum) -> Result<TokenStream, Error> {
112 let repr = EnumRepr::from_attrs(&ctx.ast.attrs)?;
113
114 match repr {
117 Repr::Compound(Spanned { t: CompoundRepr::C | CompoundRepr::Primitive(_), span: _ }, _) => {
118 }
119 Repr::Transparent(_) | Repr::Compound(Spanned { t: CompoundRepr::Rust, span: _ }, _) => {
120 return ctx.error_or_skip(
121 Error::new(
122 Span::call_site(),
123 "must have #[repr(C)] or #[repr(Int)] attribute in order to guarantee this type's memory layout",
124 ),
125 );
126 }
127 }
128
129 let zero_variant = match find_zero_variant(enm) {
130 Ok(index) => enm.variants.iter().nth(index).unwrap(),
131 Err(true) => {
133 return ctx.error_or_skip(Error::new_spanned(
134 &ctx.ast,
135 "FromZeros only supported on enums with a variant that has a discriminant of `0`\n\
136 help: This enum has discriminants which are not literal integers. One of those may \
137 define or imply which variant has a discriminant of zero. Use a literal integer to \
138 define or imply the variant with a discriminant of zero.",
139 ));
140 }
141 Err(false) => {
143 return ctx.error_or_skip(Error::new_spanned(
144 &ctx.ast,
145 "FromZeros only supported on enums with a variant that has a discriminant of `0`",
146 ));
147 }
148 };
149
150 let zerocopy_crate = &ctx.zerocopy_crate;
151 let explicit_bounds = zero_variant
152 .fields
153 .iter()
154 .map(|field| {
155 let ty = &field.ty;
156 parse_quote! { #ty: #zerocopy_crate::FromZeros }
157 })
158 .collect::<Vec<WherePredicate>>();
159
160 Ok(ImplBlockBuilder::new(ctx, enm, Trait::FromZeros, FieldBounds::Explicit(explicit_bounds))
161 .build())
162}
163fn derive_from_zeros_union(ctx: &Ctx, unn: &DataUnion) -> TokenStream {
164 let field_type_trait_bounds = FieldBounds::All(&[TraitBound::Slf]);
165 ImplBlockBuilder::new(ctx, unn, Trait::FromZeros, field_type_trait_bounds).build()
166}
167fn derive_from_bytes_struct(ctx: &Ctx, strct: &DataStruct) -> TokenStream {
168 ImplBlockBuilder::new(ctx, strct, Trait::FromBytes, FieldBounds::ALL_SELF).build()
169}
170fn derive_from_bytes_enum(ctx: &Ctx, enm: &DataEnum) -> Result<TokenStream, Error> {
171 let repr = EnumRepr::from_attrs(&ctx.ast.attrs)?;
172
173 let variants_required = 1usize << enum_size_from_repr(&repr)?;
174 if enm.variants.len() != variants_required {
175 return ctx.error_or_skip(Error::new_spanned(
176 &ctx.ast,
177 format!(
178 "FromBytes only supported on {} enum with {} variants",
179 repr.repr_type_name(),
180 variants_required
181 ),
182 ));
183 }
184
185 Ok(ImplBlockBuilder::new(ctx, enm, Trait::FromBytes, FieldBounds::ALL_SELF).build())
186}
187fn derive_from_bytes_union(ctx: &Ctx, unn: &DataUnion) -> TokenStream {
188 let field_type_trait_bounds = FieldBounds::All(&[TraitBound::Slf]);
189 ImplBlockBuilder::new(ctx, unn, Trait::FromBytes, field_type_trait_bounds).build()
190}