diplomat_core/ast/
types.rs

1use proc_macro2::Span;
2use quote::ToTokens;
3use serde::{Deserialize, Serialize};
4use syn::{punctuated::Punctuated, *};
5
6use lazy_static::lazy_static;
7use std::collections::HashMap;
8use std::fmt;
9use std::ops::ControlFlow;
10
11use super::{
12    Attrs, Docs, Enum, Ident, Lifetime, LifetimeEnv, LifetimeTransitivity, Method, NamedLifetime,
13    OpaqueStruct, Path, RustLink, Struct,
14};
15use crate::Env;
16
/// A type declared inside a Diplomat-annotated module.
///
/// Each variant wraps the AST node for one kind of declaration. The accessor
/// methods on this enum (`name`, `methods`, `attrs`, `docs`) forward to the
/// field of the same name on whichever node is held.
#[derive(Clone, Serialize, Debug, Hash, PartialEq, Eq)]
#[non_exhaustive]
pub enum CustomType {
    /// A non-opaque struct whose fields will be visible across the FFI boundary.
    Struct(Struct),
    /// A struct annotated with [`diplomat::opaque`] whose fields are not visible.
    Opaque(OpaqueStruct),
    /// A fieldless enum. Unlike the struct variants it carries no lifetime
    /// environment, so `lifetimes()` returns `None` for it.
    Enum(Enum),
}
28
29impl CustomType {
30    /// Get the name of the custom type, which is unique within a module.
31    pub fn name(&self) -> &Ident {
32        match self {
33            CustomType::Struct(strct) => &strct.name,
34            CustomType::Opaque(strct) => &strct.name,
35            CustomType::Enum(enm) => &enm.name,
36        }
37    }
38
39    /// Get the methods declared in impls of the custom type.
40    pub fn methods(&self) -> &Vec<Method> {
41        match self {
42            CustomType::Struct(strct) => &strct.methods,
43            CustomType::Opaque(strct) => &strct.methods,
44            CustomType::Enum(enm) => &enm.methods,
45        }
46    }
47
48    pub fn attrs(&self) -> &Attrs {
49        match self {
50            CustomType::Struct(strct) => &strct.attrs,
51            CustomType::Opaque(strct) => &strct.attrs,
52            CustomType::Enum(enm) => &enm.attrs,
53        }
54    }
55
56    /// The name of the destructor in C
57    pub fn dtor_name(&self) -> String {
58        let name = self.attrs().abi_rename.apply(self.name().as_str().into());
59        format!("{name}_destroy")
60    }
61
62    /// Get the doc lines of the custom type.
63    pub fn docs(&self) -> &Docs {
64        match self {
65            CustomType::Struct(strct) => &strct.docs,
66            CustomType::Opaque(strct) => &strct.docs,
67            CustomType::Enum(enm) => &enm.docs,
68        }
69    }
70
71    /// Get all rust links on this type and its methods
72    pub fn all_rust_links(&self) -> impl Iterator<Item = &RustLink> + '_ {
73        [self.docs()]
74            .into_iter()
75            .chain(self.methods().iter().map(|m| m.docs()))
76            .flat_map(|d| d.rust_links().iter())
77    }
78
79    pub fn self_path(&self, in_path: &Path) -> Path {
80        in_path.sub_path(self.name().clone())
81    }
82
83    /// Get the lifetimes of the custom type.
84    pub fn lifetimes(&self) -> Option<&LifetimeEnv> {
85        match self {
86            CustomType::Struct(strct) => Some(&strct.lifetimes),
87            CustomType::Opaque(strct) => Some(&strct.lifetimes),
88            CustomType::Enum(_) => None,
89        }
90    }
91}
92
/// A symbol declared in a module, which can either be a pointer to another path,
/// a submodule, or a custom type defined directly inside that module.
///
/// `PathType::resolve_with_path` walks these symbols one path element at a
/// time to locate the [`CustomType`] a path refers to.
#[derive(Clone, Serialize, Debug)]
#[non_exhaustive]
pub enum ModSymbol {
    /// A symbol that is a pointer to another path. During resolution the
    /// alias' path replaces the elements resolved so far.
    Alias(Path),
    /// A symbol that is a submodule; resolution descends into it.
    SubModule(Ident),
    /// A symbol that is a custom type.
    CustomType(CustomType),
}
105
/// A named type that is just a path, optionally carrying lifetime arguments,
/// e.g. `foo::Bar<'a>`.
///
/// Only lifetime generics are representable: converting a `syn::TypePath`
/// whose last segment has any other kind of generic argument panics.
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug)]
#[non_exhaustive]
pub struct PathType {
    /// The path itself, without any generic arguments.
    pub path: Path,
    /// The lifetime arguments attached to the last segment of `path`;
    /// empty when there are none.
    pub lifetimes: Vec<Lifetime>,
}
113
114impl PathType {
115    pub fn to_syn(&self) -> syn::TypePath {
116        let mut path = self.path.to_syn();
117
118        if !self.lifetimes.is_empty() {
119            if let Some(seg) = path.segments.last_mut() {
120                let lifetimes = &self.lifetimes;
121                seg.arguments =
122                    syn::PathArguments::AngleBracketed(syn::parse_quote! { <#(#lifetimes),*> });
123            }
124        }
125
126        syn::TypePath { qself: None, path }
127    }
128
129    pub fn new(path: Path) -> Self {
130        Self {
131            path,
132            lifetimes: vec![],
133        }
134    }
135
136    /// Get the `Self` type from a struct declaration.
137    ///
138    /// Consider the following struct declaration:
139    /// ```
140    /// struct RefList<'a> {
141    ///     data: &'a i32,
142    ///     next: Option<Box<Self>>,
143    /// }
144    /// ```
145    /// When determining what type `Self` is in the `next` field, we would have to call
146    /// this method on the `syn::ItemStruct` that represents this struct declaration.
147    /// This method would then return a `PathType` representing `RefList<'a>`, so we
148    /// know that's what `Self` should refer to.
149    ///
150    /// The reason this function exists though is so when we convert the fields' types
151    /// to `PathType`s, we don't panic. We don't actually need to write the struct's
152    /// field types expanded in the macro, so this function is more for correctness,
153    pub fn extract_self_type(strct: &syn::ItemStruct) -> Self {
154        let self_name = (&strct.ident).into();
155
156        PathType {
157            path: Path {
158                elements: vec![self_name],
159            },
160            lifetimes: strct
161                .generics
162                .lifetimes()
163                .map(|lt_def| (&lt_def.lifetime).into())
164                .collect(),
165        }
166    }
167
168    /// If this is a [`TypeName::Named`], grab the [`CustomType`] it points to from
169    /// the `env`, which contains all [`CustomType`]s across all FFI modules.
170    ///
171    /// Also returns the path the CustomType is in (useful for resolving fields)
172    pub fn resolve_with_path<'a>(&self, in_path: &Path, env: &'a Env) -> (Path, &'a CustomType) {
173        let local_path = &self.path;
174        let mut cur_path = in_path.clone();
175        for (i, elem) in local_path.elements.iter().enumerate() {
176            match elem.as_str() {
177                "crate" => {
178                    // TODO(#34): get the name of enclosing crate from env when we support multiple crates
179                    cur_path = Path::empty()
180                }
181
182                "super" => cur_path = cur_path.get_super(),
183
184                o => match env.get(&cur_path, o) {
185                    Some(ModSymbol::Alias(p)) => {
186                        let mut remaining_elements: Vec<Ident> =
187                            local_path.elements.iter().skip(i + 1).cloned().collect();
188                        let mut new_path = p.elements.clone();
189                        new_path.append(&mut remaining_elements);
190                        return PathType::new(Path { elements: new_path })
191                            .resolve_with_path(&cur_path.clone(), env);
192                    }
193                    Some(ModSymbol::SubModule(name)) => {
194                        cur_path.elements.push(name.clone());
195                    }
196                    Some(ModSymbol::CustomType(t)) => {
197                        if i == local_path.elements.len() - 1 {
198                            return (cur_path, t);
199                        } else {
200                            panic!(
201                                "Unexpected custom type when resolving symbol {} in {}",
202                                o,
203                                cur_path.elements.join("::")
204                            )
205                        }
206                    }
207                    None => panic!(
208                        "Could not resolve symbol {} in {}",
209                        o,
210                        cur_path.elements.join("::")
211                    ),
212                },
213            }
214        }
215
216        panic!(
217            "Path {} does not point to a custom type",
218            in_path.elements.join("::")
219        )
220    }
221
222    /// If this is a [`TypeName::Named`], grab the [`CustomType`] it points to from
223    /// the `env`, which contains all [`CustomType`]s across all FFI modules.
224    ///
225    /// If you need to resolve struct fields later, call [`Self::resolve_with_path()`] instead
226    /// to get the path to resolve the fields in.
227    pub fn resolve<'a>(&self, in_path: &Path, env: &'a Env) -> &'a CustomType {
228        self.resolve_with_path(in_path, env).1
229    }
230}
231
232impl From<&syn::TypePath> for PathType {
233    fn from(other: &syn::TypePath) -> Self {
234        let lifetimes = other
235            .path
236            .segments
237            .last()
238            .and_then(|last| {
239                if let PathArguments::AngleBracketed(angle_generics) = &last.arguments {
240                    Some(
241                        angle_generics
242                            .args
243                            .iter()
244                            .map(|generic_arg| match generic_arg {
245                                GenericArgument::Lifetime(lifetime) => lifetime.into(),
246                                _ => panic!("generic type arguments are unsupported"),
247                            })
248                            .collect(),
249                    )
250                } else {
251                    None
252                }
253            })
254            .unwrap_or_default();
255
256        Self {
257            path: Path::from_syn(&other.path),
258            lifetimes,
259        }
260    }
261}
262
263impl From<Path> for PathType {
264    fn from(other: Path) -> Self {
265        PathType::new(other)
266    }
267}
268
/// Whether something is mutable or not; mirrors the presence or absence of
/// the `mut` token in the `syn` AST.
#[derive(Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug)]
#[allow(clippy::exhaustive_enums)] // there are only two kinds of mutability we care about
pub enum Mutability {
    /// The `mut` token is present.
    Mutable,
    /// The `mut` token is absent.
    Immutable,
}
275
276impl Mutability {
277    pub fn to_syn(&self) -> Option<Token![mut]> {
278        match self {
279            Mutability::Mutable => Some(syn::token::Mut(Span::call_site())),
280            Mutability::Immutable => None,
281        }
282    }
283
284    pub fn from_syn(t: &Option<Token![mut]>) -> Self {
285        match t {
286            Some(_) => Mutability::Mutable,
287            None => Mutability::Immutable,
288        }
289    }
290
291    /// Returns `true` if `&self` is the mutable variant, otherwise `false`.
292    pub fn is_mutable(&self) -> bool {
293        matches!(self, Mutability::Mutable)
294    }
295
296    /// Returns `true` if `&self` is the immutable variant, otherwise `false`.
297    pub fn is_immutable(&self) -> bool {
298        matches!(self, Mutability::Immutable)
299    }
300
301    /// Shorthand ternary operator for choosing a value based on whether
302    /// a `Mutability` is mutable or immutable.
303    ///
304    /// The following pattern (with very slight variations) shows up often in code gen:
305    /// ```ignore
306    /// if mutability.is_mutable() {
307    ///     ""
308    /// } else {
309    ///     "const "
310    /// }
311    /// ```
312    /// This is particularly annoying in `write!(...)` statements, where `cargo fmt`
313    /// expands it to take up 5 lines.
314    ///
315    /// This method offers a 1-line alternative:
316    /// ```ignore
317    /// mutability.if_mut_else("", "const ")
318    /// ```
319    /// For cases where lazy evaluation is desired, consider using a conditional
320    /// or a `match` statement.
321    pub fn if_mut_else<T>(&self, if_mut: T, if_immut: T) -> T {
322        match self {
323            Mutability::Mutable => if_mut,
324            Mutability::Immutable => if_immut,
325        }
326    }
327}
328
/// A local type reference, such as the type of a field, parameter, or return value.
/// Unlike [`CustomType`], which represents a type declaration, [`TypeName`]s can compose
/// types through references and boxing, and can also capture unresolved paths.
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug)]
#[non_exhaustive]
pub enum TypeName {
    /// A built-in Rust scalar primitive.
    Primitive(PrimitiveType),
    /// An unresolved path to a custom type, which can be resolved after all types
    /// are collected with [`TypeName::resolve()`].
    Named(PathType),
    /// An optionally mutable reference to another type.
    Reference(Lifetime, Mutability, Box<TypeName>),
    /// A `Box<T>` type.
    Box(Box<TypeName>),
    /// A `Option<T>` type.
    Option(Box<TypeName>),
    /// A `Result<T, E>` or `diplomat_runtime::DiplomatResult<T, E>` type, holding the
    /// ok type and the error type in that order. If the bool is true, it's `Result`;
    /// if it is false, it's `DiplomatResult`.
    Result(Box<TypeName>, Box<TypeName>, bool),
    /// A `diplomat_runtime::DiplomatWriteable` type.
    Writeable,
    /// A `&DiplomatStr` or `Box<DiplomatStr>` type (or the `DiplomatStr16`/`str`
    /// equivalents, depending on the [`StringEncoding`]).
    /// Owned strings don't have a lifetime.
    StrReference(Option<Lifetime>, StringEncoding),
    /// A `&[T]` or `Box<[T]>` type, where `T` is a primitive.
    /// Owned slices don't have a lifetime or mutability.
    PrimitiveSlice(Option<(Lifetime, Mutability)>, PrimitiveType),
    /// `&[&DiplomatStr]` (or `&[&DiplomatStr16]` / `&[&str]`, depending on the
    /// [`StringEncoding`]).
    StrSlice(StringEncoding),
    /// The `()` type.
    Unit,
    /// The `Self` type, carrying the path type that `Self` stands for.
    SelfType(PathType),
    /// std::cmp::Ordering or core::cmp::Ordering
    ///
    /// The path must be present! A type is only parsed as this variant when it is
    /// written as a path ending in `cmp::Ordering`; a bare `Ordering` is parsed like
    /// any other named AST type. It is converted back to `i8` by [`TypeName::to_syn()`].
    Ordering,
}
366
/// The encoding of a string type, which determines the Rust type it is
/// spelled as.
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug, Copy)]
#[non_exhaustive]
pub enum StringEncoding {
    /// Spelled `DiplomatStr` in Rust.
    UnvalidatedUtf8,
    /// Spelled `DiplomatStr16` in Rust.
    UnvalidatedUtf16,
    /// Spelled `str` in Rust.
    ///
    /// The caller guarantees that they're passing valid UTF-8, under penalty of UB
    Utf8,
}
375
376impl TypeName {
377    /// Converts the [`TypeName`] back into an AST node that can be spliced into a program.
378    pub fn to_syn(&self) -> syn::Type {
379        match self {
380            TypeName::Primitive(name) => {
381                syn::Type::Path(syn::parse_str(PRIMITIVE_TO_STRING.get(name).unwrap()).unwrap())
382            }
383            TypeName::Ordering => syn::Type::Path(syn::parse_str("i8").unwrap()),
384            TypeName::Named(name) | TypeName::SelfType(name) => {
385                // Self also gets expanded instead of turning into `Self` because
386                // this code is used to generate the `extern "C"` functions, which
387                // aren't in an impl block.
388                syn::Type::Path(name.to_syn())
389            }
390            TypeName::Reference(lifetime, mutability, underlying) => {
391                syn::Type::Reference(TypeReference {
392                    and_token: syn::token::And(Span::call_site()),
393                    lifetime: lifetime.to_syn(),
394                    mutability: mutability.to_syn(),
395                    elem: Box::new(underlying.to_syn()),
396                })
397            }
398            TypeName::Box(underlying) => syn::Type::Path(TypePath {
399                qself: None,
400                path: syn::Path {
401                    leading_colon: None,
402                    segments: Punctuated::from_iter(vec![PathSegment {
403                        ident: syn::Ident::new("Box", Span::call_site()),
404                        arguments: PathArguments::AngleBracketed(AngleBracketedGenericArguments {
405                            colon2_token: None,
406                            lt_token: syn::token::Lt(Span::call_site()),
407                            args: Punctuated::from_iter(vec![GenericArgument::Type(
408                                underlying.to_syn(),
409                            )]),
410                            gt_token: syn::token::Gt(Span::call_site()),
411                        }),
412                    }]),
413                },
414            }),
415            TypeName::Option(underlying) => syn::Type::Path(TypePath {
416                qself: None,
417                path: syn::Path {
418                    leading_colon: None,
419                    segments: Punctuated::from_iter(vec![PathSegment {
420                        ident: syn::Ident::new("Option", Span::call_site()),
421                        arguments: PathArguments::AngleBracketed(AngleBracketedGenericArguments {
422                            colon2_token: None,
423                            lt_token: syn::token::Lt(Span::call_site()),
424                            args: Punctuated::from_iter(vec![GenericArgument::Type(
425                                underlying.to_syn(),
426                            )]),
427                            gt_token: syn::token::Gt(Span::call_site()),
428                        }),
429                    }]),
430                },
431            }),
432            TypeName::Result(ok, err, true) => syn::Type::Path(TypePath {
433                qself: None,
434                path: syn::Path {
435                    leading_colon: None,
436                    segments: Punctuated::from_iter(vec![PathSegment {
437                        ident: syn::Ident::new("Result", Span::call_site()),
438                        arguments: PathArguments::AngleBracketed(AngleBracketedGenericArguments {
439                            colon2_token: None,
440                            lt_token: syn::token::Lt(Span::call_site()),
441                            args: Punctuated::from_iter(vec![
442                                GenericArgument::Type(ok.to_syn()),
443                                GenericArgument::Type(err.to_syn()),
444                            ]),
445                            gt_token: syn::token::Gt(Span::call_site()),
446                        }),
447                    }]),
448                },
449            }),
450            TypeName::Result(ok, err, false) => syn::Type::Path(TypePath {
451                qself: None,
452                path: syn::Path {
453                    leading_colon: None,
454                    segments: Punctuated::from_iter(vec![
455                        PathSegment {
456                            ident: syn::Ident::new("diplomat_runtime", Span::call_site()),
457                            arguments: PathArguments::None,
458                        },
459                        PathSegment {
460                            ident: syn::Ident::new("DiplomatResult", Span::call_site()),
461                            arguments: PathArguments::AngleBracketed(
462                                AngleBracketedGenericArguments {
463                                    colon2_token: None,
464                                    lt_token: syn::token::Lt(Span::call_site()),
465                                    args: Punctuated::from_iter(vec![
466                                        GenericArgument::Type(ok.to_syn()),
467                                        GenericArgument::Type(err.to_syn()),
468                                    ]),
469                                    gt_token: syn::token::Gt(Span::call_site()),
470                                },
471                            ),
472                        },
473                    ]),
474                },
475            }),
476            TypeName::Writeable => syn::parse_quote! {
477                diplomat_runtime::DiplomatWriteable
478            },
479            TypeName::StrReference(Some(lifetime), StringEncoding::UnvalidatedUtf8) => {
480                syn::parse_str(&format!(
481                    "{}DiplomatStr",
482                    ReferenceDisplay(lifetime, &Mutability::Immutable)
483                ))
484                .unwrap()
485            }
486            TypeName::StrReference(Some(lifetime), StringEncoding::UnvalidatedUtf16) => {
487                syn::parse_str(&format!(
488                    "{}DiplomatStr16",
489                    ReferenceDisplay(lifetime, &Mutability::Immutable)
490                ))
491                .unwrap()
492            }
493            TypeName::StrReference(Some(lifetime), StringEncoding::Utf8) => syn::parse_str(
494                &format!("{}str", ReferenceDisplay(lifetime, &Mutability::Immutable)),
495            )
496            .unwrap(),
497            TypeName::StrReference(None, StringEncoding::UnvalidatedUtf8) => {
498                syn::parse_str("Box<DiplomatStr>").unwrap()
499            }
500            TypeName::StrReference(None, StringEncoding::UnvalidatedUtf16) => {
501                syn::parse_str("Box<DiplomatStr16>").unwrap()
502            }
503            TypeName::StrReference(None, StringEncoding::Utf8) => {
504                syn::parse_str("Box<str>").unwrap()
505            }
506            TypeName::StrSlice(StringEncoding::UnvalidatedUtf8) => {
507                syn::parse_str("&[&DiplomatStr]").unwrap()
508            }
509            TypeName::StrSlice(StringEncoding::UnvalidatedUtf16) => {
510                syn::parse_str("&[&DiplomatStr16]").unwrap()
511            }
512            TypeName::StrSlice(StringEncoding::Utf8) => syn::parse_str("&[&str]").unwrap(),
513            TypeName::PrimitiveSlice(Some((lifetime, mutability)), name) => {
514                let primitive_name = PRIMITIVE_TO_STRING.get(name).unwrap();
515                let formatted_str = format!(
516                    "{}[{}]",
517                    ReferenceDisplay(lifetime, mutability),
518                    primitive_name
519                );
520                syn::parse_str(&formatted_str).unwrap()
521            }
522            TypeName::PrimitiveSlice(None, name) => syn::parse_str(&format!(
523                "Box<[{}]>",
524                PRIMITIVE_TO_STRING.get(name).unwrap()
525            ))
526            .unwrap(),
527            TypeName::Unit => syn::parse_quote! {
528                ()
529            },
530        }
531    }
532
533    /// Extract a [`TypeName`] from a [`syn::Type`] AST node.
534    /// The following rules are used to infer [`TypeName`] variants:
535    /// - If the type is a path with a single element that is the name of a Rust primitive, returns a [`TypeName::Primitive`]
536    /// - If the type is a path with a single element [`Box`], returns a [`TypeName::Box`] with the type parameter recursively converted
537    /// - If the type is a path with a single element [`Option`], returns a [`TypeName::Option`] with the type parameter recursively converted
538    /// - If the type is a path with a single element `Self` and `self_path_type` is provided, returns a [`TypeName::Named`]
539    /// - If the type is a path with a single element [`Result`], returns a [`TypeName::Result`] with the type parameters recursively converted
540    /// - If the type is a path equal to [`diplomat_runtime::DiplomatResult`], returns a [`TypeName::DiplomatResult`] with the type parameters recursively converted
541    /// - If the type is a path equal to [`diplomat_runtime::DiplomatWriteable`], returns a [`TypeName::Writeable`]
542    /// - If the type is a owned or borrowed string type, returns a [`TypeName::StrReference`]
543    /// - If the type is a owned or borrowed slice of a Rust primitive, returns a [`TypeName::PrimitiveSlice`]
544    /// - If the type is a reference (`&` or `&mut`), returns a [`TypeName::Reference`] with the referenced type recursively converted
545    /// - Otherwise, assume that the reference is to a [`CustomType`] in either the current module or another one, returns a [`TypeName::Named`]
546    pub fn from_syn(ty: &syn::Type, self_path_type: Option<PathType>) -> TypeName {
547        match ty {
548            syn::Type::Reference(r) => {
549                let lifetime = Lifetime::from(&r.lifetime);
550                let mutability = Mutability::from_syn(&r.mutability);
551
552                let name = r.elem.to_token_stream().to_string();
553                if name.starts_with("DiplomatStr") || name == "str" {
554                    if mutability.is_mutable() {
555                        panic!("mutable string references are disallowed");
556                    }
557                    if name == "DiplomatStr" {
558                        return TypeName::StrReference(
559                            Some(lifetime),
560                            StringEncoding::UnvalidatedUtf8,
561                        );
562                    } else if name == "DiplomatStr16" {
563                        return TypeName::StrReference(
564                            Some(lifetime),
565                            StringEncoding::UnvalidatedUtf16,
566                        );
567                    } else if name == "str" {
568                        return TypeName::StrReference(Some(lifetime), StringEncoding::Utf8);
569                    }
570                }
571                if let syn::Type::Slice(slice) = &*r.elem {
572                    if let syn::Type::Path(p) = &*slice.elem {
573                        if let Some(primitive) = p
574                            .path
575                            .get_ident()
576                            .and_then(|i| STRING_TO_PRIMITIVE.get(i.to_string().as_str()))
577                        {
578                            return TypeName::PrimitiveSlice(
579                                Some((lifetime, mutability)),
580                                *primitive,
581                            );
582                        }
583                    }
584                    if let TypeName::StrReference(Some(Lifetime::Anonymous), encoding) =
585                        TypeName::from_syn(&slice.elem, self_path_type.clone())
586                    {
587                        return TypeName::StrSlice(encoding);
588                    }
589                }
590                TypeName::Reference(
591                    lifetime,
592                    mutability,
593                    Box::new(TypeName::from_syn(r.elem.as_ref(), self_path_type)),
594                )
595            }
596            syn::Type::Path(p) => {
597                let p_len = p.path.segments.len();
598                if let Some(primitive) = p
599                    .path
600                    .get_ident()
601                    .and_then(|i| STRING_TO_PRIMITIVE.get(i.to_string().as_str()))
602                {
603                    TypeName::Primitive(*primitive)
604                } else if p_len >= 2
605                    && p.path.segments[p_len - 2].ident == "cmp"
606                    && p.path.segments[p_len - 1].ident == "Ordering"
607                {
608                    TypeName::Ordering
609                } else if p_len == 1 && p.path.segments[0].ident == "Box" {
610                    if let PathArguments::AngleBracketed(type_args) = &p.path.segments[0].arguments
611                    {
612                        if let GenericArgument::Type(syn::Type::Slice(slice)) = &type_args.args[0] {
613                            if let TypeName::Primitive(p) =
614                                TypeName::from_syn(&slice.elem, self_path_type)
615                            {
616                                TypeName::PrimitiveSlice(None, p)
617                            } else {
618                                panic!("Owned slices only support primitives.")
619                            }
620                        } else if let GenericArgument::Type(tpe) = &type_args.args[0] {
621                            if tpe.to_token_stream().to_string() == "DiplomatStr" {
622                                TypeName::StrReference(None, StringEncoding::UnvalidatedUtf8)
623                            } else if tpe.to_token_stream().to_string() == "DiplomatStr16" {
624                                TypeName::StrReference(None, StringEncoding::UnvalidatedUtf16)
625                            } else if tpe.to_token_stream().to_string() == "str" {
626                                TypeName::StrReference(None, StringEncoding::Utf8)
627                            } else {
628                                TypeName::Box(Box::new(TypeName::from_syn(tpe, self_path_type)))
629                            }
630                        } else {
631                            panic!("Expected first type argument for Box to be a type")
632                        }
633                    } else {
634                        panic!("Expected angle brackets for Box type")
635                    }
636                } else if p_len == 1 && p.path.segments[0].ident == "Option" {
637                    if let PathArguments::AngleBracketed(type_args) = &p.path.segments[0].arguments
638                    {
639                        if let GenericArgument::Type(tpe) = &type_args.args[0] {
640                            TypeName::Option(Box::new(TypeName::from_syn(tpe, self_path_type)))
641                        } else {
642                            panic!("Expected first type argument for Option to be a type")
643                        }
644                    } else {
645                        panic!("Expected angle brackets for Option type")
646                    }
647                } else if p_len == 1 && p.path.segments[0].ident == "Self" {
648                    if let Some(self_path_type) = self_path_type {
649                        TypeName::SelfType(self_path_type)
650                    } else {
651                        panic!("Cannot have `Self` type outside of a method");
652                    }
653                } else if p_len == 1 && p.path.segments[0].ident == "Result"
654                    || is_runtime_type(p, "DiplomatResult")
655                {
656                    if let PathArguments::AngleBracketed(type_args) =
657                        &p.path.segments.last().unwrap().arguments
658                    {
659                        if let (GenericArgument::Type(ok), GenericArgument::Type(err)) =
660                            (&type_args.args[0], &type_args.args[1])
661                        {
662                            let ok = TypeName::from_syn(ok, self_path_type.clone());
663                            let err = TypeName::from_syn(err, self_path_type);
664                            TypeName::Result(
665                                Box::new(ok),
666                                Box::new(err),
667                                !is_runtime_type(p, "DiplomatResult"),
668                            )
669                        } else {
670                            panic!("Expected both type arguments for Result to be a type")
671                        }
672                    } else {
673                        panic!("Expected angle brackets for Result type")
674                    }
675                } else if is_runtime_type(p, "DiplomatWriteable") {
676                    TypeName::Writeable
677                } else {
678                    TypeName::Named(PathType::from(p))
679                }
680            }
681            syn::Type::Tuple(tup) => {
682                if tup.elems.is_empty() {
683                    TypeName::Unit
684                } else {
685                    todo!("Tuples are not currently supported")
686                }
687            }
688            other => panic!("Unsupported type: {}", other.to_token_stream()),
689        }
690    }
691
692    /// Returns `true` if `self` is the `TypeName::SelfType` variant, otherwise
693    /// `false`.
694    pub fn is_self(&self) -> bool {
695        matches!(self, TypeName::SelfType(_))
696    }
697
698    /// Recurse down the type tree, visiting all lifetimes.
699    ///
700    /// Using this function, you can collect all the lifetimes into a collection,
701    /// or examine each one without having to make any additional allocations.
702    pub fn visit_lifetimes<'a, F, B>(&'a self, visit: &mut F) -> ControlFlow<B>
703    where
704        F: FnMut(&'a Lifetime, LifetimeOrigin) -> ControlFlow<B>,
705    {
706        match self {
707            TypeName::Named(path_type) | TypeName::SelfType(path_type) => path_type
708                .lifetimes
709                .iter()
710                .try_for_each(|lt| visit(lt, LifetimeOrigin::Named)),
711            TypeName::Reference(lt, _, ty) => {
712                ty.visit_lifetimes(visit)?;
713                visit(lt, LifetimeOrigin::Reference)
714            }
715            TypeName::Box(ty) | TypeName::Option(ty) => ty.visit_lifetimes(visit),
716            TypeName::Result(ok, err, _) => {
717                ok.visit_lifetimes(visit)?;
718                err.visit_lifetimes(visit)
719            }
720            TypeName::StrReference(Some(lt), ..) => visit(lt, LifetimeOrigin::StrReference),
721            TypeName::PrimitiveSlice(Some((lt, _)), ..) => {
722                visit(lt, LifetimeOrigin::PrimitiveSlice)
723            }
724            _ => ControlFlow::Continue(()),
725        }
726    }
727
728    /// Returns `true` if any lifetime satisfies a predicate, otherwise `false`.
729    ///
730    /// This method is short-circuiting, meaning that if the predicate ever succeeds,
731    /// it will return immediately.
732    pub fn any_lifetime<'a, F>(&'a self, mut f: F) -> bool
733    where
734        F: FnMut(&'a Lifetime, LifetimeOrigin) -> bool,
735    {
736        self.visit_lifetimes(&mut |lifetime, origin| {
737            if f(lifetime, origin) {
738                ControlFlow::Break(())
739            } else {
740                ControlFlow::Continue(())
741            }
742        })
743        .is_break()
744    }
745
746    /// Returns `true` if all lifetimes satisfy a predicate, otherwise `false`.
747    ///
748    /// This method is short-circuiting, meaning that if the predicate ever fails,
749    /// it will return immediately.
750    pub fn all_lifetimes<'a, F>(&'a self, mut f: F) -> bool
751    where
752        F: FnMut(&'a Lifetime, LifetimeOrigin) -> bool,
753    {
754        self.visit_lifetimes(&mut |lifetime, origin| {
755            if f(lifetime, origin) {
756                ControlFlow::Continue(())
757            } else {
758                ControlFlow::Break(())
759            }
760        })
761        .is_continue()
762    }
763
764    /// Returns all lifetimes in a [`LifetimeEnv`] that must live at least as
765    /// long as the type.
766    pub fn longer_lifetimes<'env>(
767        &self,
768        lifetime_env: &'env LifetimeEnv,
769    ) -> Vec<&'env NamedLifetime> {
770        self.transitive_lifetime_bounds(LifetimeTransitivity::longer(lifetime_env))
771    }
772
773    /// Returns all lifetimes in a [`LifetimeEnv`] that are outlived by the type.
774    pub fn shorter_lifetimes<'env>(
775        &self,
776        lifetime_env: &'env LifetimeEnv,
777    ) -> Vec<&'env NamedLifetime> {
778        self.transitive_lifetime_bounds(LifetimeTransitivity::shorter(lifetime_env))
779    }
780
781    /// Visits the provided [`LifetimeTransitivity`] value with all `NamedLifetime`s
782    /// in the type tree, and returns the transitively reachable lifetimes.
783    fn transitive_lifetime_bounds<'env>(
784        &self,
785        mut transitivity: LifetimeTransitivity<'env>,
786    ) -> Vec<&'env NamedLifetime> {
787        self.visit_lifetimes(&mut |lifetime, _| -> ControlFlow<()> {
788            if let Lifetime::Named(named) = lifetime {
789                transitivity.visit(named);
790            }
791            ControlFlow::Continue(())
792        });
793        transitivity.finish()
794    }
795
796    pub fn is_zst(&self) -> bool {
797        // check_zst() prevents non-unit types from being ZSTs
798        matches!(*self, TypeName::Unit)
799    }
800
801    pub fn is_pointer(&self) -> bool {
802        matches!(*self, TypeName::Reference(..) | TypeName::Box(_))
803    }
804}
805
/// The kind of type a lifetime was found on while walking a type tree with
/// `TypeName::visit_lifetimes`.
///
/// The value is handed to visitor callbacks by value, so it is `Copy` and can
/// be compared, hashed and debug-printed.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum LifetimeOrigin {
    /// A lifetime parameter of a named type (including the `Self` type).
    Named,
    /// The lifetime of a reference.
    Reference,
    /// The lifetime of a borrowed string.
    StrReference,
    /// The lifetime of a borrowed primitive slice.
    PrimitiveSlice,
}
813
814fn is_runtime_type(p: &TypePath, name: &str) -> bool {
815    (p.path.segments.len() == 1 && p.path.segments[0].ident == name)
816        || (p.path.segments.len() == 2
817            && p.path.segments[0].ident == "diplomat_runtime"
818            && p.path.segments[1].ident == name)
819}
820
821impl fmt::Display for TypeName {
822    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
823        match self {
824            TypeName::Primitive(p) => p.fmt(f),
825            TypeName::Ordering => write!(f, "Ordering"),
826            TypeName::Named(p) | TypeName::SelfType(p) => p.fmt(f),
827            TypeName::Reference(lifetime, mutability, typ) => {
828                write!(f, "{}{typ}", ReferenceDisplay(lifetime, mutability))
829            }
830            TypeName::Box(typ) => write!(f, "Box<{typ}>"),
831            TypeName::Option(typ) => write!(f, "Option<{typ}>"),
832            TypeName::Result(ok, err, _) => {
833                write!(f, "Result<{ok}, {err}>")
834            }
835            TypeName::Writeable => "DiplomatWriteable".fmt(f),
836            TypeName::StrReference(Some(lifetime), StringEncoding::UnvalidatedUtf8) => {
837                write!(
838                    f,
839                    "{}DiplomatStr",
840                    ReferenceDisplay(lifetime, &Mutability::Immutable)
841                )
842            }
843            TypeName::StrReference(Some(lifetime), StringEncoding::UnvalidatedUtf16) => {
844                write!(
845                    f,
846                    "{}DiplomatStr16",
847                    ReferenceDisplay(lifetime, &Mutability::Immutable)
848                )
849            }
850            TypeName::StrReference(Some(lifetime), StringEncoding::Utf8) => {
851                write!(
852                    f,
853                    "{}str",
854                    ReferenceDisplay(lifetime, &Mutability::Immutable)
855                )
856            }
857            TypeName::StrReference(None, StringEncoding::UnvalidatedUtf8) => {
858                write!(f, "Box<DiplomatStr>")
859            }
860            TypeName::StrReference(None, StringEncoding::UnvalidatedUtf16) => {
861                write!(f, "Box<DiplomatStr16>")
862            }
863            TypeName::StrReference(None, StringEncoding::Utf8) => {
864                write!(f, "Box<str>")
865            }
866            TypeName::StrSlice(StringEncoding::UnvalidatedUtf8) => {
867                write!(f, "&[&DiplomatStr]")
868            }
869            TypeName::StrSlice(StringEncoding::UnvalidatedUtf16) => {
870                write!(f, "&[&DiplomatStr16]")
871            }
872            TypeName::StrSlice(StringEncoding::Utf8) => {
873                write!(f, "&[&str]")
874            }
875            TypeName::PrimitiveSlice(Some((lifetime, mutability)), typ) => {
876                write!(f, "{}[{typ}]", ReferenceDisplay(lifetime, mutability))
877            }
878            TypeName::PrimitiveSlice(None, typ) => write!(f, "Box<[{typ}]>"),
879            TypeName::Unit => "()".fmt(f),
880        }
881    }
882}
883
884/// An [`fmt::Display`] type for formatting Rust references.
885///
886/// # Examples
887///
888/// ```ignore
889/// let lifetime = Lifetime::from(&syn::parse_str::<syn::Lifetime>("'a"));
890/// let mutability = Mutability::Mutable;
891/// // ...
892/// let fmt = format!("{}[u8]", ReferenceDisplay(&lifetime, &mutability));
893///
894/// assert_eq!(fmt, "&'a mut [u8]");
895/// ```
896struct ReferenceDisplay<'a>(&'a Lifetime, &'a Mutability);
897
898impl<'a> fmt::Display for ReferenceDisplay<'a> {
899    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
900        match self.0 {
901            Lifetime::Static => "&'static ".fmt(f)?,
902            Lifetime::Named(lifetime) => write!(f, "&{lifetime} ")?,
903            Lifetime::Anonymous => '&'.fmt(f)?,
904        }
905
906        if self.1.is_mutable() {
907            "mut ".fmt(f)?;
908        }
909
910        Ok(())
911    }
912}
913
914impl fmt::Display for PathType {
915    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
916        self.path.fmt(f)?;
917
918        if let Some((first, rest)) = self.lifetimes.split_first() {
919            write!(f, "<{first}")?;
920            for lifetime in rest {
921                write!(f, ", {lifetime}")?;
922            }
923            '>'.fmt(f)?;
924        }
925        Ok(())
926    }
927}
928
929/// A built-in Rust primitive scalar type.
930#[derive(Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Debug)]
931#[allow(non_camel_case_types)]
932#[allow(clippy::exhaustive_enums)] // there are only these (scalar types)
933pub enum PrimitiveType {
934    i8,
935    u8,
936    i16,
937    u16,
938    i32,
939    u32,
940    i64,
941    u64,
942    i128,
943    u128,
944    isize,
945    usize,
946    f32,
947    f64,
948    bool,
949    char,
950    /// a primitive byte that is not meant to be interpreted numerically
951    /// in languages that don't have fine-grained integer types
952    byte,
953}
954
955impl fmt::Display for PrimitiveType {
956    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
957        match self {
958            PrimitiveType::i8 => "i8",
959            PrimitiveType::u8 | PrimitiveType::byte => "u8",
960            PrimitiveType::i16 => "i16",
961            PrimitiveType::u16 => "u16",
962            PrimitiveType::i32 => "i32",
963            PrimitiveType::u32 => "u32",
964            PrimitiveType::i64 => "i64",
965            PrimitiveType::u64 => "u64",
966            PrimitiveType::i128 => "i128",
967            PrimitiveType::u128 => "u128",
968            PrimitiveType::isize => "isize",
969            PrimitiveType::usize => "usize",
970            PrimitiveType::f32 => "f32",
971            PrimitiveType::f64 => "f64",
972            PrimitiveType::bool => "bool",
973            PrimitiveType::char => "char",
974        }
975        .fmt(f)
976    }
977}
978
979lazy_static! {
980    static ref PRIMITIVES_MAPPING: [(&'static str, PrimitiveType); 17] = [
981        ("i8", PrimitiveType::i8),
982        ("u8", PrimitiveType::u8),
983        ("i16", PrimitiveType::i16),
984        ("u16", PrimitiveType::u16),
985        ("i32", PrimitiveType::i32),
986        ("u32", PrimitiveType::u32),
987        ("i64", PrimitiveType::i64),
988        ("u64", PrimitiveType::u64),
989        ("i128", PrimitiveType::i128),
990        ("u128", PrimitiveType::u128),
991        ("isize", PrimitiveType::isize),
992        ("usize", PrimitiveType::usize),
993        ("f32", PrimitiveType::f32),
994        ("f64", PrimitiveType::f64),
995        ("bool", PrimitiveType::bool),
996        ("DiplomatChar", PrimitiveType::char),
997        ("DiplomatByte", PrimitiveType::byte),
998    ];
999    static ref STRING_TO_PRIMITIVE: HashMap<&'static str, PrimitiveType> =
1000        PRIMITIVES_MAPPING.iter().cloned().collect();
1001    static ref PRIMITIVE_TO_STRING: HashMap<PrimitiveType, &'static str> =
1002        PRIMITIVES_MAPPING.iter().map(|t| (t.1, t.0)).collect();
1003}
1004
#[cfg(test)]
mod tests {
    use insta;

    use syn;

    use super::TypeName;

    /// Parses the given tokens as a `syn` type, lowers them with
    /// `TypeName::from_syn` (without a `Self` path), and compares the result
    /// against the stored YAML snapshot.
    ///
    /// This is a macro rather than a helper function so that the snapshot
    /// assertion still expands directly inside each test function.
    /// NOTE(review): snapshot names are presumably derived from the enclosing
    /// test function and the order of assertions, so neither is changed here.
    macro_rules! assert_type_snapshot {
        ($($ty:tt)+) => {
            insta::assert_yaml_snapshot!(TypeName::from_syn(
                &syn::parse_quote! { $($ty)+ },
                None
            ))
        };
    }

    /// Primitive scalar types.
    #[test]
    fn typename_primitives() {
        assert_type_snapshot!(i32);
        assert_type_snapshot!(usize);
        assert_type_snapshot!(bool);
    }

    /// A plain custom type name.
    #[test]
    fn typename_named() {
        assert_type_snapshot!(MyLocalStruct);
    }

    /// Shared and mutable references.
    #[test]
    fn typename_references() {
        assert_type_snapshot!(&i32);
        assert_type_snapshot!(&mut MyLocalStruct);
    }

    /// Boxed primitive and custom types.
    #[test]
    fn typename_boxes() {
        assert_type_snapshot!(Box<i32>);
        assert_type_snapshot!(Box<MyLocalStruct>);
    }

    /// Optional primitive and custom types.
    #[test]
    fn typename_option() {
        assert_type_snapshot!(Option<i32>);
        assert_type_snapshot!(Option<MyLocalStruct>);
    }

    /// Both `DiplomatResult` and the standard `Result`.
    #[test]
    fn typename_result() {
        assert_type_snapshot!(DiplomatResult<MyLocalStruct, i32>);
        assert_type_snapshot!(DiplomatResult<(), MyLocalStruct>);
        assert_type_snapshot!(Result<MyLocalStruct, i32>);
        assert_type_snapshot!(Result<(), MyLocalStruct>);
    }

    /// Lifetime arguments on bare, qualified and nested type paths.
    #[test]
    fn lifetimes() {
        assert_type_snapshot!(Foo<'a, 'b>);
        assert_type_snapshot!(::core::my_type::Foo);
        assert_type_snapshot!(::core::my_type::Foo<'test>);
        assert_type_snapshot!(Option<Ref<'object>>);
        assert_type_snapshot!(Foo<'a, 'b, 'c, 'd>);
        assert_type_snapshot!(very::long::path::to::my::Type<'x, 'y, 'z>);
        assert_type_snapshot!(Result<OkRef<'a, 'b>, ErrRef<'c>>);
    }
}