1use alloc::format;
2
3use crate::front::wgsl::error::NumberError;
4use crate::front::wgsl::parse::directive::enable_extension::ImplementedEnableExtension;
5use crate::front::wgsl::parse::lexer::Token;
6use half::f16;
7
/// The value of a numeric literal, tagged with the type chosen by its suffix.
///
/// Literals without a type suffix are stored as the abstract variants; literals
/// with a suffix are stored in the matching concrete variant.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Number {
    /// Integer literal without a type suffix, held as `i64`.
    AbstractInt(i64),
    /// Floating-point literal without a type suffix, held as `f64`.
    AbstractFloat(f64),
    /// Integer literal with the `i` suffix.
    I32(i32),
    /// Integer literal with the `u` suffix.
    U32(u32),
    /// Integer literal with the `li` suffix.
    I64(i64),
    /// Integer literal with the `lu` suffix.
    U64(u64),
    /// Floating-point literal with the `h` suffix.
    F16(f16),
    /// Floating-point literal with the `f` suffix.
    F32(f32),
    /// Floating-point literal with the `lf` suffix.
    F64(f64),
}
30
31impl Number {
32    pub(super) const fn requires_enable_extension(&self) -> Option<ImplementedEnableExtension> {
33        match *self {
34            Number::F16(_) => Some(ImplementedEnableExtension::F16),
35            _ => None,
36        }
37    }
38}
39
40pub(in crate::front::wgsl) fn consume_number(input: &str) -> (Token<'_>, &str) {
41    let (result, rest) = parse(input);
42    (Token::Number(result), rest)
43}
44
/// Type suffix of a decimal literal that has neither a period nor an exponent.
///
/// Such a literal may be given either an integer or a floating-point suffix.
enum Kind {
    /// Integer suffix: `i`, `u`, `li` or `lu`.
    Int(IntKind),
    /// Floating-point suffix: `h`, `f` or `lf`.
    Float(FloatKind),
}
49
/// Concrete integer type selected by an integer literal suffix.
enum IntKind {
    /// `i` suffix.
    I32,
    /// `u` suffix.
    U32,
    /// `li` suffix.
    I64,
    /// `lu` suffix.
    U64,
}
56
/// Concrete floating-point type selected by a float literal suffix.
#[derive(Debug)]
enum FloatKind {
    /// `h` suffix.
    F16,
    /// `f` suffix.
    F32,
    /// `lf` suffix.
    F64,
}
63
/// Parses one numeric literal from the start of `input`.
///
/// Returns the parse result together with the unconsumed remainder of `input`.
/// Malformed syntax yields `Err(NumberError::Invalid)`, with the remainder starting
/// wherever scanning stopped; range/representability errors come from the helper
/// parsers.
///
/// Forms recognized (as implemented below):
/// - Hexadecimal (`0x`/`0X` prefix):
///   - integers with an optional `i`/`u`/`li`/`lu` suffix;
///   - floats with a period and/or a `p`/`P` binary exponent.
///   A float suffix (`h`/`f`/`lf`) is only looked for after an exponent.
/// - Decimal:
///   - floats with a period and/or an `e`/`E` exponent and an optional
///     `h`/`f`/`lf` suffix;
///   - plain digit runs, whose suffix may select an integer or a float type.
///   A plain multi-digit run may not start with `0`.
///
/// Bytes following the literal are not inspected; e.g. an unrecognized suffix letter
/// is simply left in the returned remainder.
fn parse(input: &str) -> (Result<Number, NumberError>, &str) {
    // If the byte patterns match the front of `$bytes` in sequence, advances `$bytes`
    // past them and evaluates to `true`; otherwise leaves `$bytes` alone and is `false`.
    macro_rules! consume {
        ($bytes:ident, $($pattern:pat),*) => {
            match $bytes {
                &[$($pattern),*, ref rest @ ..] => { $bytes = rest; true },
                _ => false,
            }
        };
    }

    // Tries each `patterns => value` arm in order against the front of `$bytes`.
    // On the first match, advances `$bytes` past the matched bytes and evaluates to
    // `Some(value)`; if no arm matches, evaluates to `None` without consuming anything.
    macro_rules! consume_map {
        ($bytes:ident, [$( $($pattern:pat_param),* => $to:expr),* $(,)?]) => {
            match $bytes {
                $( &[ $($pattern),*, ref rest @ ..] => { $bytes = rest; Some($to) }, )*
                _ => None,
            }
        };
    }

    // Consumes a (possibly empty) run of decimal digits; evaluates to the count consumed.
    macro_rules! consume_dec_digits {
        ($bytes:ident) => {{
            let start_len = $bytes.len();
            while let &[b'0'..=b'9', ref rest @ ..] = $bytes {
                $bytes = rest;
            }
            start_len - $bytes.len()
        }};
    }

    // Consumes a (possibly empty) run of hexadecimal digits; evaluates to the count consumed.
    macro_rules! consume_hex_digits {
        ($bytes:ident) => {{
            let start_len = $bytes.len();
            while let &[b'0'..=b'9' | b'a'..=b'f' | b'A'..=b'F', ref rest @ ..] = $bytes {
                $bytes = rest;
            }
            start_len - $bytes.len()
        }};
    }

    // Consumes an optional float type suffix: `h` (f16), `f` (f32) or `lf` (f64).
    macro_rules! consume_float_suffix {
        ($bytes:ident) => {
            consume_map!($bytes, [
                b'h' => FloatKind::F16,
                b'f' => FloatKind::F32,
                b'l', b'f' => FloatKind::F64,
            ])
        };
    }

    // The not-yet-consumed tail of `input` as a `&str`. `$bytes` is always a suffix
    // of `input.as_bytes()` advanced only past ASCII bytes, so this slice starts on a
    // char boundary.
    macro_rules! rest_to_str {
        ($bytes:ident) => {
            &input[input.len() - $bytes.len()..]
        };
    }

    // Remembers a position in `input` so that the text consumed from there up to a
    // later position can be sliced out.
    struct ExtractSubStr<'a>(&'a str);

    impl<'a> ExtractSubStr<'a> {
        // Marks the start. `start` must be a suffix of `input.as_bytes()` (the current,
        // partially consumed byte slice).
        fn start(input: &'a str, start: &'a [u8]) -> Self {
            let start = input.len() - start.len();
            Self(&input[start..])
        }
        // Returns the text between the marked start and `end`, which must be a later
        // (shorter or equal) suffix of the same byte slice.
        fn end(&self, end: &'a [u8]) -> &'a str {
            let end = self.0.len() - end.len();
            &self.0[..end]
        }
    }

    let mut bytes = input.as_bytes();

    // Captures the whole literal from its first byte, including any `0x` prefix.
    let general_extract = ExtractSubStr::start(input, bytes);

    if consume!(bytes, b'0', b'x' | b'X') {
        // Hexadecimal literal. This captures only the digits after the prefix, which is
        // what the integer parser needs.
        let digits_extract = ExtractSubStr::start(input, bytes);

        let consumed = consume_hex_digits!(bytes);

        if consume!(bytes, b'.') {
            let consumed_after_period = consume_hex_digits!(bytes);

            // At least one hex digit is required on one side of the period.
            if consumed + consumed_after_period == 0 {
                return (Err(NumberError::Invalid), rest_to_str!(bytes));
            }

            // Prefix, digits and period, used if no exponent follows.
            let significand = general_extract.end(bytes);

            if consume!(bytes, b'p' | b'P') {
                consume!(bytes, b'+' | b'-');
                let consumed = consume_dec_digits!(bytes);

                // An exponent marker must be followed by at least one decimal digit.
                if consumed == 0 {
                    return (Err(NumberError::Invalid), rest_to_str!(bytes));
                }

                let number = general_extract.end(bytes);

                let kind = consume_float_suffix!(bytes);

                (parse_hex_float(number, kind), rest_to_str!(bytes))
            } else {
                // No exponent: no type suffix is consumed here, so the literal is
                // abstract. (A trailing `f` would already have been taken as a hex digit.)
                (
                    parse_hex_float_missing_exponent(significand, None),
                    rest_to_str!(bytes),
                )
            }
        } else {
            // Without a period, at least one hex digit is required.
            if consumed == 0 {
                return (Err(NumberError::Invalid), rest_to_str!(bytes));
            }

            let significand = general_extract.end(bytes);
            let digits = digits_extract.end(bytes);

            let exp_extract = ExtractSubStr::start(input, bytes);

            if consume!(bytes, b'p' | b'P') {
                consume!(bytes, b'+' | b'-');
                let consumed = consume_dec_digits!(bytes);

                if consumed == 0 {
                    return (Err(NumberError::Invalid), rest_to_str!(bytes));
                }

                // The `p`/`P` marker, optional sign and exponent digits.
                let exponent = exp_extract.end(bytes);

                let kind = consume_float_suffix!(bytes);

                (
                    parse_hex_float_missing_period(significand, exponent, kind),
                    rest_to_str!(bytes),
                )
            } else {
                // Hex integer with an optional integer type suffix.
                let kind = consume_map!(bytes, [
                    b'i' => IntKind::I32,
                    b'u' => IntKind::U32,
                    b'l', b'i' => IntKind::I64,
                    b'l', b'u' => IntKind::U64,
                ]);

                (parse_hex_int(digits, kind), rest_to_str!(bytes))
            }
        }
    } else {
        // Decimal literal. A leading `0` is remembered for the plain-integer check below.
        let is_first_zero = bytes.first() == Some(&b'0');

        let consumed = consume_dec_digits!(bytes);

        if consume!(bytes, b'.') {
            let consumed_after_period = consume_dec_digits!(bytes);

            // At least one digit is required on one side of the period.
            if consumed + consumed_after_period == 0 {
                return (Err(NumberError::Invalid), rest_to_str!(bytes));
            }

            // Optional decimal exponent; if its marker is present it needs digits.
            if consume!(bytes, b'e' | b'E') {
                consume!(bytes, b'+' | b'-');
                let consumed = consume_dec_digits!(bytes);

                if consumed == 0 {
                    return (Err(NumberError::Invalid), rest_to_str!(bytes));
                }
            }

            let number = general_extract.end(bytes);

            let kind = consume_float_suffix!(bytes);

            (parse_dec_float(number, kind), rest_to_str!(bytes))
        } else {
            // Without a period, at least one digit is required.
            if consumed == 0 {
                return (Err(NumberError::Invalid), rest_to_str!(bytes));
            }

            if consume!(bytes, b'e' | b'E') {
                consume!(bytes, b'+' | b'-');
                let consumed = consume_dec_digits!(bytes);

                if consumed == 0 {
                    return (Err(NumberError::Invalid), rest_to_str!(bytes));
                }

                let number = general_extract.end(bytes);

                let kind = consume_float_suffix!(bytes);

                (parse_dec_float(number, kind), rest_to_str!(bytes))
            } else {
                // A plain digit run (no period, no exponent) longer than one digit must
                // not start with zero; this runs before any suffix, so it covers both
                // integer and float suffixes.
                if consumed > 1 && is_first_zero {
                    return (Err(NumberError::Invalid), rest_to_str!(bytes));
                }

                let digits = general_extract.end(bytes);

                // Here the suffix may select either an integer or a float type.
                let kind = consume_map!(bytes, [
                    b'i' => Kind::Int(IntKind::I32),
                    b'u' => Kind::Int(IntKind::U32),
                    b'l', b'i' => Kind::Int(IntKind::I64),
                    b'l', b'u' => Kind::Int(IntKind::U64),
                    b'h' => Kind::Float(FloatKind::F16),
                    b'f' => Kind::Float(FloatKind::F32),
                    b'l', b'f' => Kind::Float(FloatKind::F64),
                ]);

                (parse_dec(digits, kind), rest_to_str!(bytes))
            }
        }
    }
}
311
312fn parse_hex_float_missing_exponent(
313    significand: &str,
315    kind: Option<FloatKind>,
316) -> Result<Number, NumberError> {
317    let hexf_input = format!("{}{}", significand, "p0");
318    parse_hex_float(&hexf_input, kind)
319}
320
321fn parse_hex_float_missing_period(
322    significand: &str,
324    exponent: &str,
326    kind: Option<FloatKind>,
327) -> Result<Number, NumberError> {
328    let hexf_input = format!("{significand}.{exponent}");
329    parse_hex_float(&hexf_input, kind)
330}
331
332fn parse_hex_int(
333    digits: &str,
335    kind: Option<IntKind>,
336) -> Result<Number, NumberError> {
337    parse_int(digits, kind, 16)
338}
339
340fn parse_dec(
341    digits: &str,
343    kind: Option<Kind>,
344) -> Result<Number, NumberError> {
345    match kind {
346        None => parse_int(digits, None, 10),
347        Some(Kind::Int(kind)) => parse_int(digits, Some(kind), 10),
348        Some(Kind::Float(kind)) => parse_dec_float(digits, Some(kind)),
349    }
350}
351
352fn parse_hex_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> {
379    match kind {
380        None => match hexf_parse::parse_hexf64(input, false) {
381            Ok(num) => Ok(Number::AbstractFloat(num)),
382            _ => Err(NumberError::NotRepresentable),
384        },
385        Some(FloatKind::F16) => Err(NumberError::NotRepresentable),
387        Some(FloatKind::F32) => match hexf_parse::parse_hexf32(input, false) {
388            Ok(num) => Ok(Number::F32(num)),
389            _ => Err(NumberError::NotRepresentable),
391        },
392        Some(FloatKind::F64) => match hexf_parse::parse_hexf64(input, false) {
393            Ok(num) => Ok(Number::F64(num)),
394            _ => Err(NumberError::NotRepresentable),
396        },
397    }
398}
399
400fn parse_dec_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> {
403    match kind {
404        None => {
405            let num = input.parse::<f64>().unwrap(); num.is_finite()
407                .then_some(Number::AbstractFloat(num))
408                .ok_or(NumberError::NotRepresentable)
409        }
410        Some(FloatKind::F32) => {
411            let num = input.parse::<f32>().unwrap(); num.is_finite()
413                .then_some(Number::F32(num))
414                .ok_or(NumberError::NotRepresentable)
415        }
416        Some(FloatKind::F64) => {
417            let num = input.parse::<f64>().unwrap(); num.is_finite()
419                .then_some(Number::F64(num))
420                .ok_or(NumberError::NotRepresentable)
421        }
422        Some(FloatKind::F16) => {
423            let num = input.parse::<f16>().unwrap(); num.is_finite()
425                .then_some(Number::F16(num))
426                .ok_or(NumberError::NotRepresentable)
427        }
428    }
429}
430
431fn parse_int(input: &str, kind: Option<IntKind>, radix: u32) -> Result<Number, NumberError> {
432    fn map_err(e: core::num::ParseIntError) -> NumberError {
433        match *e.kind() {
434            core::num::IntErrorKind::PosOverflow | core::num::IntErrorKind::NegOverflow => {
435                NumberError::NotRepresentable
436            }
437            _ => unreachable!(),
438        }
439    }
440    match kind {
441        None => match i64::from_str_radix(input, radix) {
442            Ok(num) => Ok(Number::AbstractInt(num)),
443            Err(e) => Err(map_err(e)),
444        },
445        Some(IntKind::I32) => match i32::from_str_radix(input, radix) {
446            Ok(num) => Ok(Number::I32(num)),
447            Err(e) => Err(map_err(e)),
448        },
449        Some(IntKind::U32) => match u32::from_str_radix(input, radix) {
450            Ok(num) => Ok(Number::U32(num)),
451            Err(e) => Err(map_err(e)),
452        },
453        Some(IntKind::I64) => match i64::from_str_radix(input, radix) {
454            Ok(num) => Ok(Number::I64(num)),
455            Err(e) => Err(map_err(e)),
456        },
457        Some(IntKind::U64) => match u64::from_str_radix(input, radix) {
458            Ok(num) => Ok(Number::U64(num)),
459            Err(e) => Err(map_err(e)),
460        },
461    }
462}