sea_query_derive/
raw_sql.rs

1mod token;
2use token::*;
3
4use proc_macro2::{Span, TokenStream};
5use quote::quote;
6use syn::{
7    Ident, Index, LitStr, Member, Token,
8    parse::{Parse, ParseStream},
9};
10
11struct CallArgs {
12    module: Ident,
13    backend: Ident,
14    method: Ident,
15    sql_holder: Option<Ident>,
16    sql_input: LitStr,
17}
18
19impl Parse for CallArgs {
20    fn parse(input: ParseStream) -> syn::Result<Self> {
21        let module: Ident = input.parse()?;
22        let _colon1: Token![::] = input.parse()?;
23        let backend: Ident = input.parse()?;
24        let _colon2: Token![::] = input.parse()?;
25        let method: Ident = input.parse()?;
26        let _comma: Token![,] = input.parse()?;
27        let sql_holder = if input.peek(Ident) {
28            let ident = input.parse()?;
29            let _assign: Token![=] = input.parse()?;
30            Some(ident)
31        } else {
32            None
33        };
34        let sql_input: LitStr = input.parse()?;
35
36        Ok(CallArgs {
37            module,
38            backend,
39            method,
40            sql_holder,
41            sql_input,
42        })
43    }
44}
45
46pub fn expand(input: proc_macro::TokenStream) -> syn::Result<TokenStream> {
47    let CallArgs {
48        module,
49        backend,
50        method,
51        sql_holder,
52        sql_input,
53    } = syn::parse(input)?;
54
55    let backend = match backend.to_string().as_str() {
56        "postgres" => quote!(sea_query::PostgresQueryBuilder),
57        "mysql" => quote!(sea_query::MysqlQueryBuilder),
58        "sqlite" => quote!(sea_query::SqliteQueryBuilder),
59        _ => quote!(#backend),
60    };
61
62    let mut fragments = Vec::new();
63    let mut params = Vec::new();
64
65    let sql_input = sql_input.value();
66    let tokens = Tokenizer::new(&sql_input);
67    let mut in_brace = false;
68    let mut in_paren = false;
69    let mut dot_count = 0;
70    let mut interpolate = false;
71    let mut nested_eval = false;
72    let mut has_ending_comma = false;
73    let mut fragment = String::new();
74    let mut vars: Vec<Var> = vec![Default::default()];
75
76    #[derive(Default)]
77    struct Var<'a> {
78        parts: Vec<&'a str>, // named parts, i.e. a.b.c
79        members: Vec<u32>,   // tuple members
80    }
81
82    for token in tokens {
83        match token {
84            Token::Punctuation("{") => {
85                in_brace = true;
86            }
87            Token::Punctuation("}") => {
88                assert!(in_brace, "unmatched closing brace }}");
89
90                for vars in vars.iter_mut() {
91                    if interpolate {
92                        assert!(vars.members.len() >= 2, "expect 2 numbers around :");
93                        let a = vars.members[vars.members.len() - 2];
94                        let b = vars.members[vars.members.len() - 1];
95                        assert!(a < b, "expect a < b in a:b");
96                        vars.members = (a..=b).collect();
97                    } else {
98                        vars.members.clear();
99                    }
100                }
101
102                let top = {
103                    let v = Ident::new(vars[0].parts[0], Span::call_site());
104                    quote!(#v)
105                };
106                if nested_eval {
107                    assert!(has_ending_comma, "..(), must end with comma ,");
108                    let group_size: usize = vars
109                        .iter()
110                        .map(|var| {
111                            if var.members.is_empty() {
112                                1
113                            } else {
114                                var.members.len()
115                            }
116                        })
117                        .sum();
118                    fragments
119                        .push(quote!(.push_tuple_parameter_groups((&#top).len(), #group_size)));
120                }
121                let mut group = Vec::new();
122
123                for vars in vars {
124                    let mut var = TokenStream::new();
125                    for (i, v) in vars.parts.iter().enumerate() {
126                        if i > 0 {
127                            var.extend(quote!(.));
128                        }
129                        if is_ascii_digits(v) {
130                            if interpolate {
131                                // skip .x
132                                break;
133                            }
134                            let v = Member::Unnamed(Index {
135                                index: v.parse().unwrap(),
136                                span: Span::call_site(),
137                            });
138                            var.extend(quote!(#v));
139                        } else {
140                            let v = Ident::new(v, Span::call_site());
141                            var.extend(quote!(#v));
142                        }
143                    }
144                    if !vars.members.is_empty() {
145                        // there is a range operator `a:b`
146                        if !nested_eval {
147                            let len = vars.members.len();
148                            fragments.push(quote!(.push_parameters(#len)));
149                        }
150
151                        for mul in vars.members.iter() {
152                            let mut var = var.clone();
153                            let mul = Member::Unnamed(Index {
154                                index: *mul,
155                                span: Span::call_site(),
156                            });
157                            var.extend(quote!(#mul));
158                            group.push(quote! { query = query.bind(&#var); });
159                        }
160                    } else if dot_count == 2 && !nested_eval {
161                        // non nested spread `..a`
162                        fragments.push(quote!(.push_parameters((&#var).len())));
163                        group.push(quote! {
164                            for v in (&#var).iter() {
165                                query = query.bind(v);
166                            }
167                        });
168                    } else {
169                        if !nested_eval {
170                            fragments.push(quote!(.push_parameters(1)));
171                        }
172                        group.push(quote! { query = query.bind(&#var); });
173                    }
174                }
175
176                if nested_eval {
177                    params.push(quote! {
178                        for #top in (&#top).iter() {
179                            #(#group)*
180                        }
181                    });
182                } else {
183                    params.append(&mut group);
184                }
185
186                in_brace = false;
187                in_paren = false;
188                dot_count = 0;
189                interpolate = false;
190                nested_eval = false;
191                has_ending_comma = false;
192                vars = vec![Default::default()];
193            }
194            Token::Unquoted(var) if in_brace => {
195                if !fragment.is_empty() {
196                    fragments.push(quote!(.push_fragment(#fragment)));
197                    fragment.clear();
198                }
199                vars.last_mut().unwrap().parts.push(var);
200                if is_ascii_digits(var) {
201                    vars.last_mut()
202                        .unwrap()
203                        .members
204                        .push(var.parse().expect("index out of range"));
205                }
206            }
207            Token::Punctuation(".") if in_brace => {
208                if vars.last_mut().unwrap().parts.is_empty() {
209                    // prefix ..
210                    dot_count += 1;
211                }
212            }
213            Token::Punctuation(":") if in_brace => {
214                if !vars.last_mut().unwrap().parts.is_empty() {
215                    // postfix :
216                    interpolate = true;
217                }
218            }
219            Token::Punctuation("(") if in_brace => {
220                nested_eval = true;
221                in_paren = true;
222            }
223            Token::Punctuation(")") if in_brace => {
224                assert!(in_paren, "unmatched closing parenthesis )");
225                in_paren = false
226            }
227            Token::Punctuation(",") if in_brace && in_paren && nested_eval => {
228                // push a new variable
229                vars.push(Default::default());
230            }
231            Token::Punctuation(",") if in_brace && !in_paren && nested_eval => {
232                has_ending_comma = true;
233            }
234            Token::Punctuation(",") if in_brace && !in_paren && !nested_eval => {
235                panic!("unknown extra comma ,")
236            }
237            _ => {
238                if !in_brace {
239                    fragment.push_str(token.as_str());
240                }
241            }
242        }
243    }
244    if !fragment.is_empty() {
245        fragments.push(quote!(.push_fragment(#fragment)));
246        fragment.clear();
247    }
248
249    let (maybe_let, sql_holder) = if let Some(sql_holder) = sql_holder {
250        (quote!(), sql_holder)
251    } else {
252        (quote!(let), Ident::new("sql", Span::call_site()))
253    };
254
255    let output = quote! {{
256        use sea_query::raw_sql::*;
257        let mut builder = RawSqlQueryBuilder::new(#backend);
258        builder
259            #(#fragments)*;
260
261        #maybe_let #sql_holder = builder.finish();
262        let mut query = #module::#method(&#sql_holder);
263        #(#params)*
264
265        query
266    }};
267
268    Ok(output)
269}
270
/// Returns `true` when `s` is a non-empty run of ASCII digits (`0`-`9`), i.e. a
/// tuple index such as the `0` in `a.0`.
///
/// The empty string is rejected: it is not an index, and treating it as one
/// would send callers into a failing `parse()`.
fn is_ascii_digits(s: &str) -> bool {
    !s.is_empty() && s.bytes().all(|b| b.is_ascii_digit())
}