sea_query_derive/
raw_sql.rs1mod token;
2use token::*;
3
4use proc_macro2::{Span, TokenStream};
5use quote::quote;
6use syn::{
7 Ident, Index, LitStr, Member, Token,
8 parse::{Parse, ParseStream},
9};
10
11struct CallArgs {
12 module: Ident,
13 backend: Ident,
14 method: Ident,
15 sql_holder: Option<Ident>,
16 sql_input: LitStr,
17}
18
19impl Parse for CallArgs {
20 fn parse(input: ParseStream) -> syn::Result<Self> {
21 let module: Ident = input.parse()?;
22 let _colon1: Token![::] = input.parse()?;
23 let backend: Ident = input.parse()?;
24 let _colon2: Token![::] = input.parse()?;
25 let method: Ident = input.parse()?;
26 let _comma: Token![,] = input.parse()?;
27 let sql_holder = if input.peek(Ident) {
28 let ident = input.parse()?;
29 let _assign: Token![=] = input.parse()?;
30 Some(ident)
31 } else {
32 None
33 };
34 let sql_input: LitStr = input.parse()?;
35
36 Ok(CallArgs {
37 module,
38 backend,
39 method,
40 sql_holder,
41 sql_input,
42 })
43 }
44}
45
46pub fn expand(input: proc_macro::TokenStream) -> syn::Result<TokenStream> {
47 let CallArgs {
48 module,
49 backend,
50 method,
51 sql_holder,
52 sql_input,
53 } = syn::parse(input)?;
54
55 let backend = match backend.to_string().as_str() {
56 "postgres" => quote!(sea_query::PostgresQueryBuilder),
57 "mysql" => quote!(sea_query::MysqlQueryBuilder),
58 "sqlite" => quote!(sea_query::SqliteQueryBuilder),
59 _ => quote!(#backend),
60 };
61
62 let mut fragments = Vec::new();
63 let mut params = Vec::new();
64
65 let sql_input = sql_input.value();
66 let tokens = Tokenizer::new(&sql_input);
67 let mut in_brace = false;
68 let mut in_paren = false;
69 let mut dot_count = 0;
70 let mut interpolate = false;
71 let mut nested_eval = false;
72 let mut has_ending_comma = false;
73 let mut fragment = String::new();
74 let mut vars: Vec<Var> = vec![Default::default()];
75
76 #[derive(Default)]
77 struct Var<'a> {
78 parts: Vec<&'a str>, members: Vec<u32>, }
81
82 for token in tokens {
83 match token {
84 Token::Punctuation("{") => {
85 in_brace = true;
86 }
87 Token::Punctuation("}") => {
88 assert!(in_brace, "unmatched closing brace }}");
89
90 for vars in vars.iter_mut() {
91 if interpolate {
92 assert!(vars.members.len() >= 2, "expect 2 numbers around :");
93 let a = vars.members[vars.members.len() - 2];
94 let b = vars.members[vars.members.len() - 1];
95 assert!(a < b, "expect a < b in a:b");
96 vars.members = (a..=b).collect();
97 } else {
98 vars.members.clear();
99 }
100 }
101
102 let top = {
103 let v = Ident::new(vars[0].parts[0], Span::call_site());
104 quote!(#v)
105 };
106 if nested_eval {
107 assert!(has_ending_comma, "..(), must end with comma ,");
108 let group_size: usize = vars
109 .iter()
110 .map(|var| {
111 if var.members.is_empty() {
112 1
113 } else {
114 var.members.len()
115 }
116 })
117 .sum();
118 fragments
119 .push(quote!(.push_tuple_parameter_groups((&#top).len(), #group_size)));
120 }
121 let mut group = Vec::new();
122
123 for vars in vars {
124 let mut var = TokenStream::new();
125 for (i, v) in vars.parts.iter().enumerate() {
126 if i > 0 {
127 var.extend(quote!(.));
128 }
129 if is_ascii_digits(v) {
130 if interpolate {
131 break;
133 }
134 let v = Member::Unnamed(Index {
135 index: v.parse().unwrap(),
136 span: Span::call_site(),
137 });
138 var.extend(quote!(#v));
139 } else {
140 let v = Ident::new(v, Span::call_site());
141 var.extend(quote!(#v));
142 }
143 }
144 if !vars.members.is_empty() {
145 if !nested_eval {
147 let len = vars.members.len();
148 fragments.push(quote!(.push_parameters(#len)));
149 }
150
151 for mul in vars.members.iter() {
152 let mut var = var.clone();
153 let mul = Member::Unnamed(Index {
154 index: *mul,
155 span: Span::call_site(),
156 });
157 var.extend(quote!(#mul));
158 group.push(quote! { query = query.bind(&#var); });
159 }
160 } else if dot_count == 2 && !nested_eval {
161 fragments.push(quote!(.push_parameters((&#var).len())));
163 group.push(quote! {
164 for v in (&#var).iter() {
165 query = query.bind(v);
166 }
167 });
168 } else {
169 if !nested_eval {
170 fragments.push(quote!(.push_parameters(1)));
171 }
172 group.push(quote! { query = query.bind(&#var); });
173 }
174 }
175
176 if nested_eval {
177 params.push(quote! {
178 for #top in (&#top).iter() {
179 #(#group)*
180 }
181 });
182 } else {
183 params.append(&mut group);
184 }
185
186 in_brace = false;
187 in_paren = false;
188 dot_count = 0;
189 interpolate = false;
190 nested_eval = false;
191 has_ending_comma = false;
192 vars = vec![Default::default()];
193 }
194 Token::Unquoted(var) if in_brace => {
195 if !fragment.is_empty() {
196 fragments.push(quote!(.push_fragment(#fragment)));
197 fragment.clear();
198 }
199 vars.last_mut().unwrap().parts.push(var);
200 if is_ascii_digits(var) {
201 vars.last_mut()
202 .unwrap()
203 .members
204 .push(var.parse().expect("index out of range"));
205 }
206 }
207 Token::Punctuation(".") if in_brace => {
208 if vars.last_mut().unwrap().parts.is_empty() {
209 dot_count += 1;
211 }
212 }
213 Token::Punctuation(":") if in_brace => {
214 if !vars.last_mut().unwrap().parts.is_empty() {
215 interpolate = true;
217 }
218 }
219 Token::Punctuation("(") if in_brace => {
220 nested_eval = true;
221 in_paren = true;
222 }
223 Token::Punctuation(")") if in_brace => {
224 assert!(in_paren, "unmatched closing parenthesis )");
225 in_paren = false
226 }
227 Token::Punctuation(",") if in_brace && in_paren && nested_eval => {
228 vars.push(Default::default());
230 }
231 Token::Punctuation(",") if in_brace && !in_paren && nested_eval => {
232 has_ending_comma = true;
233 }
234 Token::Punctuation(",") if in_brace && !in_paren && !nested_eval => {
235 panic!("unknown extra comma ,")
236 }
237 _ => {
238 if !in_brace {
239 fragment.push_str(token.as_str());
240 }
241 }
242 }
243 }
244 if !fragment.is_empty() {
245 fragments.push(quote!(.push_fragment(#fragment)));
246 fragment.clear();
247 }
248
249 let (maybe_let, sql_holder) = if let Some(sql_holder) = sql_holder {
250 (quote!(), sql_holder)
251 } else {
252 (quote!(let), Ident::new("sql", Span::call_site()))
253 };
254
255 let output = quote! {{
256 use sea_query::raw_sql::*;
257 let mut builder = RawSqlQueryBuilder::new(#backend);
258 builder
259 #(#fragments)*;
260
261 #maybe_let #sql_holder = builder.finish();
262 let mut query = #module::#method(&#sql_holder);
263 #(#params)*
264
265 query
266 }};
267
268 Ok(output)
269}
270
/// Returns `true` when `s` is non-empty and every character is an ASCII digit
/// (`0`-`9`).
///
/// The empty string is rejected explicitly: `all` on an empty iterator is
/// vacuously true, which would classify `""` as a numeric index.
fn is_ascii_digits(s: &str) -> bool {
    !s.is_empty() && s.chars().all(|c| c.is_ascii_digit())
}