peek_poke_derive/
lib.rs

1// Copyright 2019 The Servo Project Developers. See the COPYRIGHT
2// file at the top-level directory of this distribution and at
3// http://rust-lang.org/COPYRIGHT.
4//
5// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8// option. This file may not be copied, modified, or distributed
9// except according to those terms.
10
11use proc_macro2::{Span, TokenStream};
12use quote::quote;
13use syn::{Ident, Index, TraitBound};
14use synstructure::{decl_derive, Structure, BindStyle, AddBounds};
15
16/// Calculates size type for number of variants (used for enums)
17fn get_discriminant_size_type(len: usize) -> TokenStream {
18    if len <= u8::MAX as usize {
19        quote! { u8 }
20    } else if len <= u16::MAX as usize {
21        quote! { u16 }
22    } else {
23        quote! { u32 }
24    }
25}
26
27fn is_struct(s: &Structure) -> bool {
28    // a single variant with no prefix is 'struct'
29    matches!(s.variants(), [v] if v.prefix.is_none())
30}
31
32fn derive_max_size(s: &Structure) -> TokenStream {
33    let max_size = s.variants().iter().fold(quote!(0), |acc, vi| {
34        let variant_size = vi.bindings().iter().fold(quote!(0), |acc, bi| {
35            // compute size of each variant by summing the sizes of its bindings
36            let ty = &bi.ast().ty;
37            quote!(#acc + <#ty>::max_size())
38        });
39
40        // find the maximum of each variant
41        quote! {
42            max(#acc, #variant_size)
43        }
44    });
45
46    let body = if is_struct(s) {
47        max_size
48    } else {
49        let discriminant_size_type = get_discriminant_size_type(s.variants().len());
50        quote! {
51            #discriminant_size_type ::max_size() + #max_size
52        }
53    };
54
55    quote! {
56        #[inline(always)]
57        fn max_size() -> usize {
58            use std::cmp::max;
59            #body
60        }
61    }
62}
63
64fn derive_peek_from_for_enum(s: &mut Structure) -> TokenStream {
65    assert!(!is_struct(s));
66    s.bind_with(|_| BindStyle::Move);
67
68    let num_variants = s.variants().len();
69    let discriminant_size_type = get_discriminant_size_type(num_variants);
70    let body = s
71        .variants()
72        .iter()
73        .enumerate()
74        .fold(quote!(), |acc, (i, vi)| {
75            let bindings = vi
76                .bindings()
77                .iter()
78                .map(|bi| quote!(#bi))
79                .collect::<Vec<_>>();
80
81            let variant_pat = Index::from(i);
82            let poke_exprs = bindings.iter().fold(quote!(), |acc, bi| {
83                quote! {
84                    #acc
85                    let (#bi, bytes) = peek_poke::peek_from_default(bytes);
86                }
87            });
88            let construct = vi.construct(|_, i| {
89                let bi = &bindings[i];
90                quote!(#bi)
91            });
92
93            quote! {
94                #acc
95                #variant_pat => {
96                    #poke_exprs
97                    *output = #construct;
98                    bytes
99                }
100            }
101        });
102
103    let type_name = s.ast().ident.to_string();
104    let max_tag_value = num_variants - 1;
105
106    quote! {
107        #[inline(always)]
108        unsafe fn peek_from(bytes: *const u8, output: *mut Self) -> *const u8 {
109            let (variant, bytes) = peek_poke::peek_from_default::<#discriminant_size_type>(bytes);
110            match variant {
111                #body
112                out_of_range_tag => {
113                    panic!("WRDL: memory corruption detected while parsing {} - enum tag should be <= {}, but was {}",
114                        #type_name, #max_tag_value, out_of_range_tag);
115                }
116            }
117        }
118    }
119}
120
121fn derive_peek_from_for_struct(s: &mut Structure) -> TokenStream {
122    assert!(is_struct(s));
123
124    s.variants_mut()[0].bind_with(|_| BindStyle::RefMut);
125    let pat = s.variants()[0].pat();
126    let peek_exprs = s.variants()[0].bindings().iter().fold(quote!(), |acc, bi| {
127        let ty = &bi.ast().ty;
128        quote! {
129            #acc
130            let bytes = <#ty>::peek_from(bytes, #bi);
131        }
132    });
133
134    let body = quote! {
135        #pat => {
136            #peek_exprs
137            bytes
138        }
139    };
140
141    quote! {
142        #[inline(always)]
143        unsafe fn peek_from(bytes: *const u8, output: *mut Self) -> *const u8 {
144            match &mut (*output) {
145                #body
146            }
147        }
148    }
149}
150
151fn derive_poke_into(s: &Structure) -> TokenStream {
152    let is_struct = is_struct(s);
153    let discriminant_size_type = get_discriminant_size_type(s.variants().len());
154    let body = s
155        .variants()
156        .iter()
157        .enumerate()
158        .fold(quote!(), |acc, (i, vi)| {
159            let init = if !is_struct {
160                let index = Index::from(i);
161                quote! {
162                    let bytes = #discriminant_size_type::poke_into(&#index, bytes);
163                }
164            } else {
165                quote!()
166            };
167            let variant_pat = vi.pat();
168            let poke_exprs = vi.bindings().iter().fold(init, |acc, bi| {
169                quote! {
170                    #acc
171                    let bytes = #bi.poke_into(bytes);
172                }
173            });
174
175            quote! {
176                #acc
177                #variant_pat => {
178                    #poke_exprs
179                    bytes
180                }
181            }
182        });
183
184    quote! {
185        #[inline(always)]
186        unsafe fn poke_into(&self, bytes: *mut u8) -> *mut u8 {
187            match &*self {
188                #body
189            }
190        }
191    }
192}
193
194fn peek_poke_derive(mut s: Structure) -> TokenStream {
195    s.binding_name(|_, i| Ident::new(&format!("__self_{}", i), Span::call_site()));
196
197    let max_size_fn = derive_max_size(&s);
198    let poke_into_fn = derive_poke_into(&s);
199    let peek_from_fn = if is_struct(&s) {
200        derive_peek_from_for_struct(&mut s)
201    } else {
202        derive_peek_from_for_enum(&mut s)
203    };
204
205    let poke_impl = s.gen_impl(quote! {
206        extern crate peek_poke;
207
208        gen unsafe impl peek_poke::Poke for @Self {
209            #max_size_fn
210            #poke_into_fn
211        }
212    });
213
214    // To implement `fn peek_from` we require that types implement `Default`
215    // trait to create temporary values. This code does the addition all
216    // manually until https://github.com/mystor/synstructure/issues/24 is fixed.
217    let default_trait = syn::parse_str::<TraitBound>("::std::default::Default").unwrap();
218    let peek_trait = syn::parse_str::<TraitBound>("peek_poke::Peek").unwrap();
219
220    let ast = s.ast();
221    let name = &ast.ident;
222    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
223    let mut where_clause = where_clause.cloned();
224    s.add_trait_bounds(&default_trait, &mut where_clause, AddBounds::Generics);
225    s.add_trait_bounds(&peek_trait, &mut where_clause, AddBounds::Generics);
226
227    let peek_impl = quote! {
228        #[allow(non_upper_case_globals)]
229        const _: () = {
230            extern crate peek_poke;
231
232            impl #impl_generics peek_poke::Peek for #name #ty_generics #where_clause {
233                #peek_from_fn
234            }
235        };
236    };
237
238    quote! {
239        #poke_impl
240        #peek_impl
241    }
242}
243
244decl_derive!([PeekPoke] => peek_poke_derive);