use proc_macro2::{Span, TokenStream};
use quote::{quote, ToTokens};
use syn::{
parse_quote, punctuated::Punctuated, visit_mut::VisitMut, Block, Lifetime, Receiver,
ReturnType, Signature, TypeReference, WhereClause,
};
use crate::parse::{AsyncItem, RecursionArgs};
impl ToTokens for AsyncItem {
fn to_tokens(&self, tokens: &mut TokenStream) {
self.0.to_tokens(tokens);
}
}
/// Rewrites the wrapped `async fn` in place so that it may call itself recursively.
///
/// The function is marked `#[must_use]`. Its signature is turned into a plain `fn` that
/// returns a pinned, boxed future (see `transform_sig`). Its body is wrapped in
/// `Box::pin(async move { .. })` (see `transform_block`).
pub fn expand(item: &mut AsyncItem, args: &RecursionArgs) {
    let func = &mut item.0;

    func.attrs.push(parse_quote!(#[must_use]));
    transform_sig(&mut func.sig, args);
    transform_block(&mut func.block);
}
/// Replaces `block` with `{ Box::pin(async move <block>) }`, so the function body evaluates
/// to a pinned, boxed future.
///
/// The outer brace token of the original block is kept, which preserves its span for
/// diagnostics.
fn transform_block(block: &mut Block) {
    let original_brace = block.brace_token;

    let mut boxed: Block = parse_quote!({
        Box::pin(async move #block)
    });
    boxed.brace_token = original_brace;

    *block = boxed;
}
/// A lifetime found on a reference-typed function argument.
enum ArgLifetime {
    /// A lifetime the macro invented for an elided reference; it must be added to the
    /// function's generic parameters.
    New(Lifetime),
    /// A lifetime the user wrote explicitly; it is already declared (or is `'static`).
    Existing(Lifetime),
}

impl ArgLifetime {
    /// Consumes the wrapper and returns the underlying lifetime, whichever variant it is.
    pub fn lifetime(self) -> Lifetime {
        match self {
            Self::New(inner) => inner,
            Self::Existing(inner) => inner,
        }
    }
}
/// Visitor over a function's inputs that gives every elided reference lifetime an explicit
/// name and records the lifetimes the generated boxed future must be bounded by.
#[derive(Default)]
struct ReferenceVisitor {
    /// Suffix for the next generated `'life{N}` lifetime name.
    counter: usize,
    /// Lifetimes of reference-typed arguments, in the order they were first seen.
    /// `Existing` entries are deduplicated by lifetime name.
    lifetimes: Vec<ArgLifetime>,
    /// Set when the receiver is a reference (`&self` / `&mut self`).
    self_receiver: bool,
    /// Set when the receiver's lifetime was elided and `'life_self` was introduced for it.
    self_receiver_new_lifetime: bool,
    /// Lifetime of a reference receiver (user-written or newly introduced).
    self_lifetime: Option<Lifetime>,
}

impl ReferenceVisitor {
    /// Returns `true` when a reference's lifetime must be replaced by a named one.
    ///
    /// This applies when the lifetime is absent (`&T`) or is the placeholder `'_` (`&'_ T`).
    /// `'_` means the same as an elided lifetime, but it cannot appear in a where-clause
    /// predicate such as `'_: 'async_recursion` (E0637). It therefore has to be renamed
    /// rather than recorded as an existing lifetime.
    fn is_elided(lifetime: Option<&Lifetime>) -> bool {
        match lifetime {
            None => true,
            Some(lt) => lt.ident == "_",
        }
    }
}

impl VisitMut for ReferenceVisitor {
    fn visit_receiver_mut(&mut self, receiver: &mut Receiver) {
        // Only reference receivers carry a lifetime; by-value `self` is left untouched.
        let lifetime = match &mut receiver.reference {
            Some((_, lifetime)) => lifetime,
            None => return,
        };
        self.self_receiver = true;

        if Self::is_elided(lifetime.as_ref()) {
            // Use 'life_self to avoid collisions with the 'life{N} argument lifetimes.
            *lifetime = Some(parse_quote!('life_self));
            self.self_receiver_new_lifetime = true;
        }

        // Both branches above leave `lifetime` as `Some(_)`.
        self.self_lifetime = lifetime.clone();
    }

    fn visit_type_reference_mut(&mut self, argument: &mut TypeReference) {
        if Self::is_elided(argument.lifetime.as_ref()) {
            // `&T` or `&'_ T`: give the reference a fresh, uniquely numbered lifetime.
            let lt = Lifetime::new(&format!("'life{}", self.counter), Span::call_site());
            self.lifetimes.push(ArgLifetime::New(lt.clone()));
            argument.lifetime = Some(lt);
            self.counter += 1;
        } else if let Some(lt) = &argument.lifetime {
            // `&'a T`: remember the user's lifetime once, however many arguments use it.
            let already_seen = self.lifetimes.iter().any(|seen| {
                matches!(seen, ArgLifetime::Existing(existing) if existing.ident == lt.ident)
            });
            if !already_seen {
                self.lifetimes.push(ArgLifetime::Existing(lt.clone()));
            }
        }
    }
}
fn transform_sig(sig: &mut Signature, args: &RecursionArgs) {
let ret = match &sig.output {
ReturnType::Default => quote!(()),
ReturnType::Type(_, ret) => quote!(#ret),
};
sig.asyncness = None;
let mut v = ReferenceVisitor::default();
for input in &mut sig.inputs {
v.visit_fn_arg_mut(input);
}
let mut requires_lifetime = false;
let mut where_clause_lifetimes = vec![];
let mut where_clause_generics = vec![];
let asr: Lifetime = parse_quote!('async_recursion);
for param in sig.generics.type_params() {
let ident = param.ident.clone();
where_clause_generics.push(ident);
requires_lifetime = true;
}
if !v.lifetimes.is_empty() {
requires_lifetime = true;
for alt in v.lifetimes {
if let ArgLifetime::New(lt) = &alt {
sig.generics.params.push(parse_quote!(#lt));
}
let lt = alt.lifetime();
where_clause_lifetimes.push(lt);
}
}
if v.self_receiver {
if v.self_receiver_new_lifetime {
sig.generics.params.push(parse_quote!('life_self));
}
where_clause_lifetimes.extend(v.self_lifetime);
requires_lifetime = true;
}
let box_lifetime: TokenStream = if requires_lifetime {
sig.generics.params.push(parse_quote!('async_recursion));
quote!(+ #asr)
} else {
quote!()
};
let send_bound: TokenStream = if args.send_bound {
quote!(+ ::core::marker::Send)
} else {
quote!()
};
let sync_bound: TokenStream = if args.sync_bound {
quote!(+ ::core::marker::Sync)
} else {
quote!()
};
let where_clause = sig
.generics
.where_clause
.get_or_insert_with(|| WhereClause {
where_token: Default::default(),
predicates: Punctuated::new(),
});
for generic_ident in where_clause_generics {
where_clause
.predicates
.push(parse_quote!(#generic_ident : #asr));
}
for lifetime in where_clause_lifetimes {
where_clause.predicates.push(parse_quote!(#lifetime : #asr));
}
sig.output = parse_quote! {
-> ::core::pin::Pin<Box<
dyn ::core::future::Future<Output = #ret> #box_lifetime #send_bound #sync_bound>>
};
}