diplomat_core/ast/
lifetimes.rs

1use proc_macro2::Span;
2use quote::{quote, ToTokens};
3use serde::{Deserialize, Serialize};
4use std::fmt;
5
6use super::{Docs, Ident, Param, SelfParam, TypeName};
7
8/// A named lifetime, e.g. `'a`.
9///
10/// # Invariants
11///
12/// Cannot be `'static` or `'_`, use [`Lifetime`] to represent those instead.
13#[derive(Clone, Debug, Hash, Eq, PartialEq, Serialize, PartialOrd, Ord)]
14#[serde(transparent)]
15pub struct NamedLifetime(Ident);
16
17impl NamedLifetime {
18    pub fn name(&self) -> &Ident {
19        &self.0
20    }
21}
22
23impl<'de> Deserialize<'de> for NamedLifetime {
24    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
25    where
26        D: serde::Deserializer<'de>,
27    {
28        // Special `Deserialize` impl to ensure invariants.
29        let named = Ident::deserialize(deserializer)?;
30        if named.as_str() == "static" {
31            panic!("cannot be static");
32        }
33        Ok(NamedLifetime(named))
34    }
35}
36
37impl From<&syn::Lifetime> for NamedLifetime {
38    fn from(lt: &syn::Lifetime) -> Self {
39        Lifetime::from(lt).to_named().expect("cannot be static")
40    }
41}
42
43impl From<&NamedLifetime> for NamedLifetime {
44    fn from(this: &NamedLifetime) -> Self {
45        this.clone()
46    }
47}
48
49impl PartialEq<syn::Lifetime> for NamedLifetime {
50    fn eq(&self, other: &syn::Lifetime) -> bool {
51        other.ident == self.0.as_str()
52    }
53}
54
55impl fmt::Display for NamedLifetime {
56    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
57        write!(f, "'{}", self.0)
58    }
59}
60
61impl ToTokens for NamedLifetime {
62    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
63        use proc_macro2::{Punct, Spacing};
64        Punct::new('\'', Spacing::Joint).to_tokens(tokens);
65        self.0.to_tokens(tokens);
66    }
67}
68
69/// A lifetime dependency graph used for tracking which lifetimes outlive,
70/// and are outlived by, other lifetimes.
71///
72/// It is similar to [`syn::LifetimeDef`], except it can also track lifetime
73/// bounds defined in the `where` clause.
74#[derive(Clone, PartialEq, Eq, Hash, Debug)]
75pub struct LifetimeEnv {
76    pub(crate) nodes: Vec<LifetimeNode>,
77}
78
79impl LifetimeEnv {
80    /// Construct an empty [`LifetimeEnv`].
81    ///
82    /// To create one outside of this module, use `LifetimeEnv::from_method_item`
83    /// or `LifetimeEnv::from` on `&syn::Generics`.
84    fn new() -> Self {
85        Self { nodes: vec![] }
86    }
87
88    /// Iterate through the names of the lifetimes in scope.
89    pub fn names(&self) -> impl Iterator<Item = &NamedLifetime> + Clone {
90        self.nodes.iter().map(|node| &node.lifetime)
91    }
92
93    /// Returns a [`LifetimeEnv`] for a method, accounting for lifetimes and bounds
94    /// defined in both the impl block and the method, as well as implicit lifetime
95    /// bounds in the optional `self` param, other param, and optional return type.
96    /// For example, the type `&'a Foo<'b>` implies `'b: 'a`.
97    pub fn from_method_item(
98        method: &syn::ImplItemFn,
99        impl_generics: Option<&syn::Generics>,
100        self_param: Option<&SelfParam>,
101        params: &[Param],
102        return_type: Option<&TypeName>,
103    ) -> Self {
104        let mut this = LifetimeEnv::new();
105        // The impl generics _must_ be loaded into the env first, since the method
106        // generics might use lifetimes defined in the impl, and `extend_generics`
107        // panics if `'a: 'b` where `'b` isn't declared by the time it finishes.
108        if let Some(generics) = impl_generics {
109            this.extend_generics(generics);
110        }
111        this.extend_generics(&method.sig.generics);
112
113        if let Some(self_param) = self_param {
114            this.extend_implicit_lifetime_bounds(&self_param.to_typename(), None);
115        }
116        for param in params {
117            this.extend_implicit_lifetime_bounds(&param.ty, None);
118        }
119        if let Some(return_type) = return_type {
120            this.extend_implicit_lifetime_bounds(return_type, None);
121        }
122
123        this
124    }
125
126    /// Returns a [`LifetimeEnv`] for a struct, accounding for lifetimes and bounds
127    /// defined in the struct generics, as well as implicit lifetime bounds in
128    /// the struct's fields. For example, the field `&'a Foo<'b>` implies `'b: 'a`.
129    pub fn from_struct_item(strct: &syn::ItemStruct, fields: &[(Ident, TypeName, Docs)]) -> Self {
130        let mut this = LifetimeEnv::new();
131        this.extend_generics(&strct.generics);
132        for (_, typ, _) in fields {
133            this.extend_implicit_lifetime_bounds(typ, None);
134        }
135        this
136    }
137
138    /// Traverse a type, adding any implicit lifetime bounds that arise from
139    /// having a reference to an opaque containing a lifetime.
140    /// For example, the type `&'a Foo<'b>` implies `'b: 'a`.
141    fn extend_implicit_lifetime_bounds(
142        &mut self,
143        typ: &TypeName,
144        behind_ref: Option<&NamedLifetime>,
145    ) {
146        match typ {
147            TypeName::Named(path_type) => {
148                if let Some(borrow_lifetime) = behind_ref {
149                    let explicit_longer_than_borrow =
150                        LifetimeTransitivity::longer_than(self, borrow_lifetime);
151                    let mut implicit_longer_than_borrow = vec![];
152
153                    for path_lifetime in path_type.lifetimes.iter() {
154                        if let Lifetime::Named(path_lifetime) = path_lifetime {
155                            if !explicit_longer_than_borrow.contains(&path_lifetime) {
156                                implicit_longer_than_borrow.push(path_lifetime);
157                            }
158                        }
159                    }
160
161                    self.extend_bounds(
162                        implicit_longer_than_borrow
163                            .into_iter()
164                            .map(|path_lifetime| (path_lifetime, Some(borrow_lifetime))),
165                    );
166                }
167            }
168            TypeName::Reference(lifetime, _, typ) => {
169                let behind_ref = if let Lifetime::Named(named) = lifetime {
170                    Some(named)
171                } else {
172                    None
173                };
174                self.extend_implicit_lifetime_bounds(typ, behind_ref);
175            }
176            TypeName::Option(typ) => self.extend_implicit_lifetime_bounds(typ, None),
177            TypeName::Result(ok, err, _) => {
178                self.extend_implicit_lifetime_bounds(ok, None);
179                self.extend_implicit_lifetime_bounds(err, None);
180            }
181            _ => {}
182        }
183    }
184
185    /// Add the lifetimes from generic parameters and where bounds.
186    fn extend_generics(&mut self, generics: &syn::Generics) {
187        let generic_bounds = generics.params.iter().map(|generic| match generic {
188            syn::GenericParam::Type(_) => panic!("generic types are unsupported"),
189            syn::GenericParam::Lifetime(def) => (&def.lifetime, &def.bounds),
190            syn::GenericParam::Const(_) => panic!("const generics are unsupported"),
191        });
192
193        let generic_defs = generic_bounds.clone().map(|(lifetime, _)| lifetime);
194
195        self.extend_lifetimes(generic_defs);
196        self.extend_bounds(generic_bounds);
197
198        if let Some(ref where_clause) = generics.where_clause {
199            self.extend_bounds(where_clause.predicates.iter().map(|pred| match pred {
200                syn::WherePredicate::Type(_) => panic!("trait bounds are unsupported"),
201                syn::WherePredicate::Lifetime(pred) => (&pred.lifetime, &pred.bounds),
202                _ => panic!("Found unknown kind of where predicate"),
203            }));
204        }
205    }
206
207    /// Returns the number of lifetimes in the graph.
208    pub fn len(&self) -> usize {
209        self.nodes.len()
210    }
211
212    /// Returns `true` if the graph contains no lifetimes.
213    pub fn is_empty(&self) -> bool {
214        self.nodes.is_empty()
215    }
216
217    /// `<'a, 'b, 'c>`
218    ///
219    /// Write the existing lifetimes, excluding bounds, as generic parameters.
220    ///
221    /// To include lifetime bounds, use [`LifetimeEnv::lifetime_defs_to_tokens`].
222    pub fn lifetimes_to_tokens(&self) -> proc_macro2::TokenStream {
223        if self.is_empty() {
224            return quote! {};
225        }
226
227        let lifetimes = self.nodes.iter().map(|node| &node.lifetime);
228        quote! { <#(#lifetimes),*> }
229    }
230
231    /// Returns the index of a lifetime in the graph, or `None` if the lifetime
232    /// isn't in the graph.
233    pub(crate) fn id<L>(&self, lifetime: &L) -> Option<usize>
234    where
235        NamedLifetime: PartialEq<L>,
236    {
237        self.nodes
238            .iter()
239            .position(|node| &node.lifetime == lifetime)
240    }
241
242    /// Add isolated lifetimes to the graph.
243    fn extend_lifetimes<'a, L, I>(&mut self, iter: I)
244    where
245        NamedLifetime: PartialEq<L> + From<&'a L>,
246        L: 'a,
247        I: IntoIterator<Item = &'a L>,
248    {
249        for lifetime in iter {
250            if self.id(lifetime).is_some() {
251                panic!(
252                    "lifetime name `{}` declared twice in the same scope",
253                    NamedLifetime::from(lifetime)
254                );
255            }
256
257            self.nodes.push(LifetimeNode {
258                lifetime: lifetime.into(),
259                shorter: vec![],
260                longer: vec![],
261            });
262        }
263    }
264
265    /// Add edges to the lifetime graph.
266    ///
267    /// This method is decoupled from [`LifetimeEnv::extend_lifetimes`] because
268    /// generics can define new lifetimes, while `where` clauses cannot.
269    ///
270    /// # Panics
271    ///
272    /// This method panics if any of the lifetime bounds aren't already defined
273    /// in the graph. This isn't allowed by rustc in the first place, so it should
274    /// only ever occur when deserializing an invalid [`LifetimeEnv`].
275    fn extend_bounds<'a, L, B, I>(&mut self, iter: I)
276    where
277        NamedLifetime: PartialEq<L> + From<&'a L>,
278        L: 'a,
279        B: IntoIterator<Item = &'a L>,
280        I: IntoIterator<Item = (&'a L, B)>,
281    {
282        for (lifetime, bounds) in iter {
283            let long = self.id(lifetime).expect("use of undeclared lifetime, this is a bug: try calling `LifetimeEnv::extend_lifetimes` first");
284            for bound in bounds {
285                let short = self
286                    .id(bound)
287                    .expect("cannot use undeclared lifetime as a bound");
288                self.nodes[short].longer.push(long);
289                self.nodes[long].shorter.push(short);
290            }
291        }
292    }
293}
294
295impl fmt::Display for LifetimeEnv {
296    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
297        self.to_token_stream().fmt(f)
298    }
299}
300
301impl ToTokens for LifetimeEnv {
302    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
303        for node in self.nodes.iter() {
304            let lifetime = &node.lifetime;
305            if node.shorter.is_empty() {
306                tokens.extend(quote! { #lifetime, });
307            } else {
308                let bounds = node.shorter.iter().map(|&id| &self.nodes[id].lifetime);
309                tokens.extend(quote! { #lifetime: #(#bounds)+*, });
310            }
311        }
312    }
313}
314
315/// Serialize a [`LifetimeEnv`] as a map from lifetimes to their bounds.
316impl Serialize for LifetimeEnv {
317    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
318    where
319        S: serde::Serializer,
320    {
321        use serde::ser::SerializeMap;
322        let mut seq = serializer.serialize_map(Some(self.len()))?;
323
324        for node in self.nodes.iter() {
325            /// Helper type for serializing bounds.
326            struct Bounds<'a> {
327                ids: &'a [usize],
328                nodes: &'a [LifetimeNode],
329            }
330
331            impl<'a> Serialize for Bounds<'a> {
332                fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
333                where
334                    S: serde::Serializer,
335                {
336                    use serde::ser::SerializeSeq;
337                    let mut seq = serializer.serialize_seq(Some(self.ids.len()))?;
338                    for &id in self.ids {
339                        seq.serialize_element(&self.nodes[id].lifetime)?;
340                    }
341                    seq.end()
342                }
343            }
344
345            seq.serialize_entry(
346                &node.lifetime,
347                &Bounds {
348                    ids: &node.shorter[..],
349                    nodes: &self.nodes,
350                },
351            )?;
352        }
353        seq.end()
354    }
355}
356
357impl<'de> Deserialize<'de> for LifetimeEnv {
358    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
359    where
360        D: serde::Deserializer<'de>,
361    {
362        use std::collections::BTreeMap;
363
364        let m: BTreeMap<NamedLifetime, Vec<NamedLifetime>> =
365            Deserialize::deserialize(deserializer)?;
366
367        let mut this = LifetimeEnv::new();
368        this.extend_lifetimes(m.keys());
369        this.extend_bounds(m.iter());
370        Ok(this)
371    }
372}
373
374/// A lifetime, along with ptrs to all lifetimes that are explicitly
375/// shorter/longer than it.
376///
377/// This type is internal to [`LifetimeGraph`]- the ptrs are stored as `usize`s,
378/// meaning that they may be invalid if a `LifetimeEdges` is created in one
379/// `LifetimeGraph` and then used in another.
380#[derive(Clone, PartialEq, Eq, Hash, Debug)]
381pub(crate) struct LifetimeNode {
382    /// The name of the lifetime.
383    pub(crate) lifetime: NamedLifetime,
384
385    /// Pointers to all lifetimes that this lives _at least_ as long as.
386    ///
387    /// Note: This doesn't account for transitivity.
388    pub(crate) shorter: Vec<usize>,
389
390    /// Pointers to all lifetimes that live _at least_ as long as this.
391    ///
392    /// Note: This doesn't account for transitivity.
393    pub(crate) longer: Vec<usize>,
394}
395
396/// A lifetime, analogous to [`syn::Lifetime`].
397#[derive(Clone, Debug, Hash, Eq, PartialEq, Serialize, Deserialize)]
398#[non_exhaustive]
399pub enum Lifetime {
400    /// The `'static` lifetime.
401    Static,
402
403    /// A named lifetime, like `'a`.
404    Named(NamedLifetime),
405
406    /// An elided lifetime.
407    Anonymous,
408}
409
410impl Lifetime {
411    /// Returns the inner `NamedLifetime` if the lifetime is the `Named` variant,
412    /// otherwise `None`.
413    pub fn to_named(self) -> Option<NamedLifetime> {
414        if let Lifetime::Named(named) = self {
415            return Some(named);
416        }
417        None
418    }
419
420    /// Returns a reference to the inner `NamedLifetime` if the lifetime is the
421    /// `Named` variant, otherwise `None`.
422    pub fn as_named(&self) -> Option<&NamedLifetime> {
423        if let Lifetime::Named(named) = self {
424            return Some(named);
425        }
426        None
427    }
428}
429
430impl fmt::Display for Lifetime {
431    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
432        match self {
433            Lifetime::Static => "'static".fmt(f),
434            Lifetime::Named(ref named) => named.fmt(f),
435            Lifetime::Anonymous => "'_".fmt(f),
436        }
437    }
438}
439
440impl ToTokens for Lifetime {
441    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
442        match self {
443            Lifetime::Static => syn::Lifetime::new("'static", Span::call_site()).to_tokens(tokens),
444            Lifetime::Named(ref s) => s.to_tokens(tokens),
445            Lifetime::Anonymous => syn::Lifetime::new("'_", Span::call_site()).to_tokens(tokens),
446        };
447    }
448}
449
450impl From<&syn::Lifetime> for Lifetime {
451    fn from(lt: &syn::Lifetime) -> Self {
452        if lt.ident == "static" {
453            Self::Static
454        } else {
455            Self::Named(NamedLifetime((&lt.ident).into()))
456        }
457    }
458}
459
460impl From<&Option<syn::Lifetime>> for Lifetime {
461    fn from(lt: &Option<syn::Lifetime>) -> Self {
462        lt.as_ref().map(Into::into).unwrap_or(Self::Anonymous)
463    }
464}
465
466impl Lifetime {
467    /// Converts the [`Lifetime`] back into an AST node that can be spliced into a program.
468    pub fn to_syn(&self) -> Option<syn::Lifetime> {
469        match *self {
470            Self::Static => Some(syn::Lifetime::new("'static", Span::call_site())),
471            Self::Anonymous => None,
472            Self::Named(ref s) => Some(syn::Lifetime::new(&s.to_string(), Span::call_site())),
473        }
474    }
475}
476
477/// Collect all lifetimes that are either longer_or_shorter
478pub struct LifetimeTransitivity<'env> {
479    env: &'env LifetimeEnv,
480    visited: Vec<bool>,
481    out: Vec<&'env NamedLifetime>,
482    longer_or_shorter: LongerOrShorter,
483}
484
485impl<'env> LifetimeTransitivity<'env> {
486    /// Returns a new [`LifetimeTransitivity`] that finds all longer lifetimes.
487    pub fn longer(env: &'env LifetimeEnv) -> Self {
488        Self::new(env, LongerOrShorter::Longer)
489    }
490
491    /// Returns a new [`LifetimeTransitivity`] that finds all shorter lifetimes.
492    pub fn shorter(env: &'env LifetimeEnv) -> Self {
493        Self::new(env, LongerOrShorter::Shorter)
494    }
495
496    /// Returns all the lifetimes longer than a provided `NamedLifetime`.
497    pub fn longer_than(env: &'env LifetimeEnv, named: &NamedLifetime) -> Vec<&'env NamedLifetime> {
498        let mut this = Self::new(env, LongerOrShorter::Longer);
499        this.visit(named);
500        this.finish()
501    }
502
503    /// Returns all the lifetimes shorter than the provided `NamedLifetime`.
504    pub fn shorter_than(env: &'env LifetimeEnv, named: &NamedLifetime) -> Vec<&'env NamedLifetime> {
505        let mut this = Self::new(env, LongerOrShorter::Shorter);
506        this.visit(named);
507        this.finish()
508    }
509
510    /// Returns a new [`LifetimeTransitivity`].
511    fn new(env: &'env LifetimeEnv, longer_or_shorter: LongerOrShorter) -> Self {
512        LifetimeTransitivity {
513            env,
514            visited: vec![false; env.len()],
515            out: vec![],
516            longer_or_shorter,
517        }
518    }
519
520    /// Visits a lifetime, as well as all the nodes it's transitively longer or
521    /// shorter than, depending on how the `LifetimeTransitivity` was constructed.
522    pub fn visit(&mut self, named: &NamedLifetime) {
523        if let Some(id) = self
524            .env
525            .nodes
526            .iter()
527            .position(|node| node.lifetime == *named)
528        {
529            self.dfs(id);
530        }
531    }
532
533    /// Performs depth-first search through the `LifetimeEnv` created at construction
534    /// for all nodes longer or shorter than the node at the provided index,
535    /// depending on how the `LifetimeTransitivity` was constructed.
536    fn dfs(&mut self, index: usize) {
537        // Note: all of these indexings SHOULD be valid because
538        // `visited.len() == nodes.len()`, and the ids come from
539        // calling `Iterator::position` on `nodes`, which never shrinks.
540        // So we should be able to change these to `get_unchecked`...
541        if !self.visited[index] {
542            self.visited[index] = true;
543
544            let node = &self.env.nodes[index];
545            self.out.push(&node.lifetime);
546            for &edge_index in self.longer_or_shorter.edges(node).iter() {
547                self.dfs(edge_index);
548            }
549        }
550    }
551
552    /// Returns the transitively reachable lifetimes.
553    pub fn finish(self) -> Vec<&'env NamedLifetime> {
554        self.out
555    }
556}
557
/// Chooses which edge direction a [`LifetimeTransitivity`] search follows:
/// towards lifetimes that outlive the start, or towards lifetimes it outlives.
enum LongerOrShorter {
    /// Follow each node's `longer` edges.
    Longer,
    /// Follow each node's `shorter` edges.
    Shorter,
}
564
565impl LongerOrShorter {
566    /// Returns either the indices of the longer or shorter lifetimes, depending
567    /// on `self`.
568    fn edges<'node>(&self, node: &'node LifetimeNode) -> &'node [usize] {
569        match self {
570            LongerOrShorter::Longer => &node.longer[..],
571            LongerOrShorter::Shorter => &node.shorter[..],
572        }
573    }
574}