1use proc_macro2::{Span, TokenStream};
12use quote::quote;
13use syn::{Ident, Index, TraitBound};
14use synstructure::{decl_derive, Structure, BindStyle, AddBounds};
15
16fn get_discriminant_size_type(len: usize) -> TokenStream {
18 if len <= u8::MAX as usize {
19 quote! { u8 }
20 } else if len <= u16::MAX as usize {
21 quote! { u16 }
22 } else {
23 quote! { u32 }
24 }
25}
26
27fn is_struct(s: &Structure) -> bool {
28 matches!(s.variants(), [v] if v.prefix.is_none())
30}
31
32fn derive_max_size(s: &Structure) -> TokenStream {
33 let max_size = s.variants().iter().fold(quote!(0), |acc, vi| {
34 let variant_size = vi.bindings().iter().fold(quote!(0), |acc, bi| {
35 let ty = &bi.ast().ty;
37 quote!(#acc + <#ty>::max_size())
38 });
39
40 quote! {
42 max(#acc, #variant_size)
43 }
44 });
45
46 let body = if is_struct(s) {
47 max_size
48 } else {
49 let discriminant_size_type = get_discriminant_size_type(s.variants().len());
50 quote! {
51 #discriminant_size_type ::max_size() + #max_size
52 }
53 };
54
55 quote! {
56 #[inline(always)]
57 fn max_size() -> usize {
58 use std::cmp::max;
59 #body
60 }
61 }
62}
63
64fn derive_peek_from_for_enum(s: &mut Structure) -> TokenStream {
65 assert!(!is_struct(s));
66 s.bind_with(|_| BindStyle::Move);
67
68 let num_variants = s.variants().len();
69 let discriminant_size_type = get_discriminant_size_type(num_variants);
70 let body = s
71 .variants()
72 .iter()
73 .enumerate()
74 .fold(quote!(), |acc, (i, vi)| {
75 let bindings = vi
76 .bindings()
77 .iter()
78 .map(|bi| quote!(#bi))
79 .collect::<Vec<_>>();
80
81 let variant_pat = Index::from(i);
82 let poke_exprs = bindings.iter().fold(quote!(), |acc, bi| {
83 quote! {
84 #acc
85 let (#bi, bytes) = peek_poke::peek_from_default(bytes);
86 }
87 });
88 let construct = vi.construct(|_, i| {
89 let bi = &bindings[i];
90 quote!(#bi)
91 });
92
93 quote! {
94 #acc
95 #variant_pat => {
96 #poke_exprs
97 *output = #construct;
98 bytes
99 }
100 }
101 });
102
103 let type_name = s.ast().ident.to_string();
104 let max_tag_value = num_variants - 1;
105
106 quote! {
107 #[inline(always)]
108 unsafe fn peek_from(bytes: *const u8, output: *mut Self) -> *const u8 {
109 let (variant, bytes) = peek_poke::peek_from_default::<#discriminant_size_type>(bytes);
110 match variant {
111 #body
112 out_of_range_tag => {
113 panic!("WRDL: memory corruption detected while parsing {} - enum tag should be <= {}, but was {}",
114 #type_name, #max_tag_value, out_of_range_tag);
115 }
116 }
117 }
118 }
119}
120
121fn derive_peek_from_for_struct(s: &mut Structure) -> TokenStream {
122 assert!(is_struct(s));
123
124 s.variants_mut()[0].bind_with(|_| BindStyle::RefMut);
125 let pat = s.variants()[0].pat();
126 let peek_exprs = s.variants()[0].bindings().iter().fold(quote!(), |acc, bi| {
127 let ty = &bi.ast().ty;
128 quote! {
129 #acc
130 let bytes = <#ty>::peek_from(bytes, #bi);
131 }
132 });
133
134 let body = quote! {
135 #pat => {
136 #peek_exprs
137 bytes
138 }
139 };
140
141 quote! {
142 #[inline(always)]
143 unsafe fn peek_from(bytes: *const u8, output: *mut Self) -> *const u8 {
144 match &mut (*output) {
145 #body
146 }
147 }
148 }
149}
150
151fn derive_poke_into(s: &Structure) -> TokenStream {
152 let is_struct = is_struct(s);
153 let discriminant_size_type = get_discriminant_size_type(s.variants().len());
154 let body = s
155 .variants()
156 .iter()
157 .enumerate()
158 .fold(quote!(), |acc, (i, vi)| {
159 let init = if !is_struct {
160 let index = Index::from(i);
161 quote! {
162 let bytes = #discriminant_size_type::poke_into(&#index, bytes);
163 }
164 } else {
165 quote!()
166 };
167 let variant_pat = vi.pat();
168 let poke_exprs = vi.bindings().iter().fold(init, |acc, bi| {
169 quote! {
170 #acc
171 let bytes = #bi.poke_into(bytes);
172 }
173 });
174
175 quote! {
176 #acc
177 #variant_pat => {
178 #poke_exprs
179 bytes
180 }
181 }
182 });
183
184 quote! {
185 #[inline(always)]
186 unsafe fn poke_into(&self, bytes: *mut u8) -> *mut u8 {
187 match &*self {
188 #body
189 }
190 }
191 }
192}
193
194fn peek_poke_derive(mut s: Structure) -> TokenStream {
195 s.binding_name(|_, i| Ident::new(&format!("__self_{}", i), Span::call_site()));
196
197 let max_size_fn = derive_max_size(&s);
198 let poke_into_fn = derive_poke_into(&s);
199 let peek_from_fn = if is_struct(&s) {
200 derive_peek_from_for_struct(&mut s)
201 } else {
202 derive_peek_from_for_enum(&mut s)
203 };
204
205 let poke_impl = s.gen_impl(quote! {
206 extern crate peek_poke;
207
208 gen unsafe impl peek_poke::Poke for @Self {
209 #max_size_fn
210 #poke_into_fn
211 }
212 });
213
214 let default_trait = syn::parse_str::<TraitBound>("::std::default::Default").unwrap();
218 let peek_trait = syn::parse_str::<TraitBound>("peek_poke::Peek").unwrap();
219
220 let ast = s.ast();
221 let name = &ast.ident;
222 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
223 let mut where_clause = where_clause.cloned();
224 s.add_trait_bounds(&default_trait, &mut where_clause, AddBounds::Generics);
225 s.add_trait_bounds(&peek_trait, &mut where_clause, AddBounds::Generics);
226
227 let peek_impl = quote! {
228 #[allow(non_upper_case_globals)]
229 const _: () = {
230 extern crate peek_poke;
231
232 impl #impl_generics peek_poke::Peek for #name #ty_generics #where_clause {
233 #peek_from_fn
234 }
235 };
236 };
237
238 quote! {
239 #poke_impl
240 #peek_impl
241 }
242}
243
// Registers `peek_poke_derive` as the implementation of `#[derive(PeekPoke)]`.
decl_derive!([PeekPoke] => peek_poke_derive);