rustls/msgs/macros.rs

/// Defines a `#[non_exhaustive]` enum over the integer type named by `#[repr(..)]`.
///
/// Each `Name => value` entry becomes a unit variant.  Every other value of
/// the integer type is kept in a generated `Unknown(..)` variant, so converting
/// from the integer never fails.  Entries listed after `!Debug:` are ordinary
/// variants, except that their `Debug` output is `EnumName(0x..)` rather than
/// the variant name.
///
/// The `#[repr(..)]` attribute is consumed by the macro to learn the integer
/// type; it is not emitted on the generated enum.
///
/// Besides the enum itself, this generates:
/// - `to_array` (big-endian bytes);
/// - `as_str` (the variant name, or `None` for `Unknown`);
/// - a `Codec` impl;
/// - `From` conversions in both directions between the enum and its integer type.
macro_rules! enum_builder {
    (
        $(#[doc = $comment:literal])*
        #[repr($uint:ty)]
        $enum_vis:vis enum $enum_name:ident
        {
          $( $enum_var:ident => $enum_val:literal),* $(,)?
          $( !Debug:
            $( $enum_var_nd:ident => $enum_val_nd:literal),* $(,)?
          )?
        }
    ) => {
        $(#[doc = $comment])*
        #[non_exhaustive]
        #[derive(PartialEq, Eq, Clone, Copy)]
        $enum_vis enum $enum_name {
            $( $enum_var, )*
            $( $( $enum_var_nd, )* )?
            Unknown($uint),
        }

        impl $enum_name {
            /// Returns the big-endian bytes of this value's integer form.
            // NOTE(allow) generated irrespective if there are callers
            #[allow(dead_code)]
            $enum_vis fn to_array(self) -> [u8; core::mem::size_of::<$uint>()] {
                let raw: $uint = self.into();
                raw.to_be_bytes()
            }

            /// Returns the variant name, or `None` for `Unknown`.
            // NOTE(allow) generated irrespective if there are callers
            #[allow(dead_code)]
            $enum_vis fn as_str(&self) -> Option<&'static str> {
                let name = match self {
                    $( Self::$enum_var => stringify!($enum_var), )*
                    $( $( Self::$enum_var_nd => stringify!($enum_var_nd), )* )?
                    Self::Unknown(_) => return None,
                };
                Some(name)
            }
        }

        impl Codec<'_> for $enum_name {
            // NOTE(allow) fully qualified Vec is only needed in no-std mode
            #[allow(unused_qualifications)]
            fn encode(&self, bytes: &mut alloc::vec::Vec<u8>) {
                let raw: $uint = (*self).into();
                raw.encode(bytes);
            }

            // Any failure to read the underlying integer is reported as
            // missing data for this enum type.
            fn read(r: &mut Reader<'_>) -> Result<Self, crate::error::InvalidMessage> {
                <$uint>::read(r)
                    .map(Self::from)
                    .map_err(|_| crate::error::InvalidMessage::MissingData(stringify!($enum_name)))
            }
        }

        impl From<$uint> for $enum_name {
            fn from(raw: $uint) -> Self {
                match raw {
                    $( $enum_val => Self::$enum_var, )*
                    $( $( $enum_val_nd => Self::$enum_var_nd, )* )?
                    other => Self::Unknown(other),
                }
            }
        }

        impl From<$enum_name> for $uint {
            fn from(value: $enum_name) -> Self {
                match value {
                    $( $enum_name::$enum_var => $enum_val, )*
                    $( $( $enum_name::$enum_var_nd => $enum_val_nd, )* )?
                    $enum_name::Unknown(raw) => raw,
                }
            }
        }

        impl core::fmt::Debug for $enum_name {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                match *self {
                    $( Self::$enum_var => f.write_str(stringify!($enum_var)), )*
                    // `!Debug` variants and `Unknown` print as `EnumName(0x..)`.
                    other => write!(f, "{}(0x{:x?})", stringify!($enum_name), <$uint>::from(other)),
                }
            }
        }
    };
}
86
/// A macro which defines a structure containing TLS extensions.
///
/// The contents are defined by two blocks, which are merged to
/// give the struct's items.  The second block, introduced with
/// `+ { ... }`, is optional.
///
/// The first block defines the items read into by decoding,
/// and used for encoding.
///
/// The type of each item in the first block _must_ be an `Option`.
/// This records the presence of that extension.
///
/// Each item in the first block is prefixed with a match arm,
/// which must match an `ExtensionType` variant.  This maps
/// the item to its extension type.
///
/// Items in the second block are not encoded or decoded into.
/// They therefore must have a reasonable `Default` value.
///
/// Every item type must implement `Debug` and `Clone`, because the
/// struct derives `Clone` and the generated `Debug` impl prints each item.
/// Second-block items must also implement `Default`.
///
/// The generated inherent impl declares the lifetime `'a`.  A struct
/// lifetime parameter, if given, must therefore be named `'a`.
macro_rules! extension_struct {
    (
        $(#[doc = $comment:literal])*
        $struct_vis:vis struct $struct_name:ident$(<$struct_lt:lifetime>)*
        {
          $(
            $(#[$item_attr:meta])*
            $item_id:path => $item_vis:vis $item_slot:ident : Option<$item_ty:ty>,
          )+
        } $( + {
          $(
            $(#[$meta_attr:meta])*
            $meta_vis:vis $meta_slot:ident : $meta_ty:ty,
          )+
        })*
    ) => {
        $(#[doc = $comment])*
        #[non_exhaustive]
        #[derive(Clone, Default)]
        $struct_vis struct $struct_name$(<$struct_lt>)* {
            $(
              $(#[$item_attr])*
              $item_vis $item_slot: Option<$item_ty>,
            )+
            $($(
              $(#[$meta_attr])*
              $meta_vis $meta_slot: $meta_ty,
            )+)*
        }

        impl<'a> $struct_name$(<$struct_lt>)* {
            /// Reads one extension type, length and body from `r`.
            ///
            /// A handled extension is decoded into its slot, and its body
            /// must then be fully consumed.
            ///
            /// An unhandled extension is one for which
            /// `read_extension_body()` returns `false`.  Its body is skipped
            /// and its type is passed to `unknown`.  An error returned by
            /// `unknown` aborts the read.
            ///
            /// Returns the extension type that was read.
            fn read_one(
                &mut self,
                r: &mut Reader<'a>,
                mut unknown: impl FnMut(ExtensionType) -> Result<(), InvalidMessage>,
            ) -> Result<ExtensionType, InvalidMessage> {
                let typ = ExtensionType::read(r)?;
                let len = usize::from(u16::read(r)?);
                let mut ext_body = r.sub(len)?;
                if self.read_extension_body(typ, &mut ext_body)? {
                    // Reject trailing bytes after a decoded extension body.
                    ext_body.expect_empty(stringify!($struct_name))?;
                } else {
                    unknown(typ)?;
                }
                Ok(typ)
            }

            /// Reads one extension body for an extension named by `typ`.
            ///
            /// Returns `true` if handled, `false` otherwise.  Fails with
            /// `InvalidMessage::DuplicateExtension` if the slot is already
            /// filled, or with any error from decoding the body.
            ///
            /// `r` is fully consumed if `typ` is unhandled.
            fn read_extension_body(
                &mut self,
                typ: ExtensionType,
                r: &mut Reader<'a>,
            ) -> Result<bool, InvalidMessage> {
                match typ {
                    $(
                        $item_id => Self::read_once(r, $item_id, &mut self.$item_slot)?,
                    )*

                    // read and ignore unhandled extensions
                    _ => {
                        r.rest();
                        return Ok(false);
                    }
                }

                Ok(true)
            }

            /// Decodes `r` as `T` into `out`.
            ///
            /// Fails with `InvalidMessage::DuplicateExtension(id)` instead,
            /// leaving `out` untouched, if `out` is already `Some`.
            fn read_once<T>(r: &mut Reader<'a>, id: ExtensionType, out: &mut Option<T>) -> Result<(), InvalidMessage>
            where T: Codec<'a>,
            {
                if out.is_some() {
                    return Err(InvalidMessage::DuplicateExtension(u16::from(id)));
                }

                *out = Some(T::read(r)?);
                Ok(())
            }

            /// Encodes one extension (type, u16 length, body) for `typ` into `output`.
            ///
            /// Adds nothing to `output` if `typ` is absent from this
            /// struct, either because it is `None` or unhandled by
            /// this struct.
            fn encode_one(
                &self,
                typ: ExtensionType,
                output: &mut Vec<u8>,
            ) {
                match typ {
                    $(
                        $item_id => if let Some(item) = &self.$item_slot {
                            typ.encode(output);
                            // NOTE(review): presumably the temporary
                            // `LengthPrefixedBuffer` back-fills the length
                            // prefix when dropped at the end of this
                            // statement -- confirm against its `Drop` impl.
                            item.encode(LengthPrefixedBuffer::new(ListLength::U16, output).buf);
                        },
                    )*
                    _ => {},
                }
            }

            /// Returns the extension types whose items are `Some`, in
            /// declaration order.
            // NOTE(allow) generated irrespective if there are callers
            #[allow(dead_code)]
            pub(crate) fn collect_used(&self) -> Vec<ExtensionType> {
                let mut r = Vec::with_capacity(Self::ALL_EXTENSIONS.len());

                $(
                    if self.$item_slot.is_some() {
                        r.push($item_id);
                    }
                )*

                r
            }

            /// Clones the value of the extension identified by `typ` from `source` to `self`.
            ///
            /// Does nothing if `typ` is not an extension handled by this object.
            // NOTE(allow) generated irrespective if there are callers
            #[allow(dead_code)]
            pub(crate) fn clone_one(
                &mut self,
                source: &Self,
                typ: ExtensionType,
            ) {
                match typ {
                    $(
                        $item_id => self.$item_slot = source.$item_slot.clone(),
                    )*
                    _ => {},
                }
            }

            /// Removes the extension identified by `typ` from `self`.
            ///
            /// Does nothing if `typ` is not an extension handled by this object.
            // NOTE(allow) generated irrespective if there are callers
            #[allow(dead_code)]
            pub(crate) fn clear(&mut self, typ: ExtensionType) {
                match typ {
                    $(
                        $item_id => self.$item_slot = None,
                    )*
                    _ => {},
                }
            }

            /// Returns true if all present extensions are named in `allowed`.
            // NOTE(allow) generated irrespective if there are callers
            #[allow(dead_code)]
            pub(crate) fn only_contains(&self, allowed: &[ExtensionType]) -> bool {
                $(
                    if self.$item_slot.is_some() && !allowed.contains(&$item_id) {
                        return false;
                    }
                )*

                true
            }

            /// Returns true if any extension named in `exts` is present.
            // NOTE(allow) generated irrespective if there are callers
            #[allow(dead_code)]
            pub(crate) fn contains_any(&self, exts: &[ExtensionType]) -> bool {
                exts.iter().any(|e| self.contains(*e))
            }

            /// Returns true if the item for `e` is present.
            ///
            /// Returns false for extension types this struct does not handle.
            fn contains(&self, e: ExtensionType) -> bool {
                match e {
                    $(
                        $item_id => self.$item_slot.is_some(),
                    )*
                    _ => false,
                }
            }

            /// Every `ExtensionType` this structure may encode/decode.
            const ALL_EXTENSIONS: &'static [ExtensionType] = &[
                $($item_id,)*
            ];
        }

        // Reuses the struct's own lifetime parameter (if any), rather than
        // declaring an unused `'a` for lifetime-less structs.
        impl$(<$struct_lt>)* core::fmt::Debug for $struct_name$(<$struct_lt>)* {
            /// Prints present (`Some`) extension items and every second-block
            /// item; absent extensions are omitted.
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                let mut ds = f.debug_struct(stringify!($struct_name));
                $(
                    if let Some(ext) = &self.$item_slot {
                        ds.field(stringify!($item_slot), ext);
                    }
                )*
                $($(
                    ds.field(stringify!($meta_slot), &self.$meta_slot);
                )+)*
                ds.finish_non_exhaustive()
            }
        }
    }
}