script/xpath/
parser.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
4
5use nom::branch::alt;
6use nom::bytes::complete::{tag, take_while1};
7use nom::character::complete::{alpha1, alphanumeric1, char, digit1, multispace0};
8use nom::combinator::{map, opt, recognize, value};
9use nom::error::{Error as NomError, ErrorKind as NomErrorKind, ParseError as NomParseError};
10use nom::multi::{many0, separated_list0};
11use nom::sequence::{delimited, pair, preceded};
12use nom::{Finish, IResult, Parser};
13
14pub(crate) fn parse(input: &str) -> Result<Expr, OwnedParserError> {
15    let (_, ast) = expr(input).finish().map_err(OwnedParserError::from)?;
16    Ok(ast)
17}
18
/// A node of the expression tree built by the parsers in this file.
///
/// Binary variants hold their operands left-to-right; the parsers fold
/// repeated operators so that chains associate to the left.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum Expr {
    /// `lhs or rhs`
    Or(Box<Expr>, Box<Expr>),
    /// `lhs and rhs`
    And(Box<Expr>, Box<Expr>),
    /// `lhs = rhs` or `lhs != rhs`
    Equality(Box<Expr>, EqualityOp, Box<Expr>),
    /// `lhs < rhs`, `lhs > rhs`, `lhs <= rhs` or `lhs >= rhs`
    Relational(Box<Expr>, RelationalOp, Box<Expr>),
    /// `lhs + rhs` or `lhs - rhs`
    Additive(Box<Expr>, AdditiveOp, Box<Expr>),
    /// `lhs * rhs`, `lhs div rhs` or `lhs mod rhs`
    Multiplicative(Box<Expr>, MultiplicativeOp, Box<Expr>),
    /// A prefix operator applied to an operand (currently only `-`)
    Unary(UnaryOp, Box<Expr>),
    /// `lhs | rhs`
    Union(Box<Expr>, Box<Expr>),
    /// A location path or any other step sequence
    Path(PathExpr),
}
31
/// Operator of an [`Expr::Equality`] node.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum EqualityOp {
    /// `=`
    Eq,
    /// `!=`
    NotEq,
}
37
/// Operator of an [`Expr::Relational`] node.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum RelationalOp {
    /// `<`
    Lt,
    /// `>`
    Gt,
    /// `<=`
    LtEq,
    /// `>=`
    GtEq,
}
45
/// Operator of an [`Expr::Additive`] node.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum AdditiveOp {
    /// `+`
    Add,
    /// `-`
    Sub,
}
51
/// Operator of an [`Expr::Multiplicative`] node.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum MultiplicativeOp {
    /// `*`
    Mul,
    /// `div`
    Div,
    /// `mod`
    Mod,
}
58
/// Operator of an [`Expr::Unary`] node.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum UnaryOp {
    /// Prefix `-`
    Minus,
}
63
/// A sequence of steps, optionally anchored by a leading `/` or `//`.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) struct PathExpr {
    /// Set when the path was written with a leading `/` or `//`.
    pub(crate) is_absolute: bool,
    /// Set when the path was written with a leading `//`.
    pub(crate) is_descendant: bool,
    /// The steps of the path in source order; may be empty for a bare `/`.
    pub(crate) steps: Vec<StepExpr>,
}
70
/// Zero or more `[...]` predicates attached to a step, in source order.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) struct PredicateListExpr {
    /// The individual predicates; empty when the step has none.
    pub(crate) predicates: Vec<PredicateExpr>,
}
75
/// A single `[expr]` predicate.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) struct PredicateExpr {
    /// The expression written between the brackets.
    pub(crate) expr: Expr,
}
80
/// A primary expression followed by zero or more predicates.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) struct FilterExpr {
    /// The expression being filtered.
    pub(crate) primary: PrimaryExpr,
    /// Predicates that follow the primary expression.
    pub(crate) predicates: PredicateListExpr,
}
86
/// One step of a [`PathExpr`].
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum StepExpr {
    /// A primary expression with optional predicates.
    Filter(FilterExpr),
    /// An axis, a node test and optional predicates.
    Axis(AxisStep),
}
92
/// A step of the form `axis::node-test[predicate]...` (or an abbreviation of it).
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) struct AxisStep {
    /// The axis named explicitly or implied by the abbreviated syntax.
    pub(crate) axis: Axis,
    /// The test applied to nodes on that axis.
    pub(crate) node_test: NodeTest,
    /// Predicates that follow the node test.
    pub(crate) predicates: PredicateListExpr,
}
99
/// The axes recognised by the parser.
///
/// The first eight are produced by `forward_axis`, the remaining five by
/// `reverse_axis`; some are additionally produced by abbreviated steps.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum Axis {
    /// `child::` (also the default for an unprefixed step)
    Child,
    /// `descendant::`
    Descendant,
    /// `attribute::` (also `@`)
    Attribute,
    /// `self::`
    Self_,
    /// `descendant-or-self::` (also inserted for `//`)
    DescendantOrSelf,
    /// `following-sibling::`
    FollowingSibling,
    /// `following::`
    Following,
    /// `namespace::`
    Namespace,
    /// `parent::` (also `..`)
    Parent,
    /// `ancestor::`
    Ancestor,
    /// `preceding-sibling::`
    PrecedingSibling,
    /// `preceding::`
    Preceding,
    /// `ancestor-or-self::`
    AncestorOrSelf,
}
116
/// The test part of an [`AxisStep`].
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum NodeTest {
    /// A qualified name; `prefix:*` is represented with a local part of `*`.
    Name(QName),
    /// A bare `*`.
    Wildcard,
    /// A kind test such as `node()` or `text()`.
    Kind(KindTest),
}
123
/// A possibly prefixed name, written `prefix:local_part` or `local_part`.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) struct QName {
    /// The part before the colon, if a prefix was written.
    pub(crate) prefix: Option<String>,
    /// The part after the colon, or the whole name when there is no prefix.
    pub(crate) local_part: String,
}
129
130impl std::fmt::Display for QName {
131    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132        match &self.prefix {
133            Some(prefix) => write!(f, "{}:{}", prefix, self.local_part),
134            None => write!(f, "{}", self.local_part),
135        }
136    }
137}
138
/// Node tests that select by node kind rather than by name.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum KindTest {
    /// `processing-instruction()`, carrying the string-literal argument if one was given.
    PI(Option<String>),
    /// `comment()`
    Comment,
    /// `text()`
    Text,
    /// `node()`
    Node,
}
146
/// The expression at the head of a [`FilterExpr`].
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum PrimaryExpr {
    /// A numeric or string literal.
    Literal(Literal),
    /// A variable reference, written `$name`.
    Variable(QName),
    /// An expression wrapped in parentheses.
    Parenthesized(Box<Expr>),
    /// The context item, written `.`.
    ContextItem,
    /// We only support the built-in core functions
    Function(CoreFunction),
}
156
/// A literal value appearing in an expression.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum Literal {
    /// A number.
    Numeric(NumericLiteral),
    /// The contents of a quoted string, without the quotes.
    String(String),
}
162
/// A numeric literal, kept in the form in which it was written.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum NumericLiteral {
    /// A literal written without a decimal point.
    Integer(u64),
    /// A literal written with a decimal point.
    Decimal(f64),
}
168
/// A call to one of the built-in core functions, together with its arguments.
///
/// In the DOM we do not support custom functions, so we can enumerate the usable ones.
/// Each variant's doc comment shows the function's signature; `?` marks an
/// optional argument, stored as an `Option`.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum CoreFunction {
    // Node Set Functions
    /// last()
    Last,
    /// position()
    Position,
    /// count(node-set)
    Count(Box<Expr>),
    /// id(object)
    Id(Box<Expr>),
    /// local-name(node-set?)
    LocalName(Option<Box<Expr>>),
    /// namespace-uri(node-set?)
    NamespaceUri(Option<Box<Expr>>),
    /// name(node-set?)
    Name(Option<Box<Expr>>),

    // String Functions
    /// string(object?)
    String(Option<Box<Expr>>),
    /// concat(string, string, ...)
    Concat(Vec<Expr>),
    /// starts-with(string, string)
    StartsWith(Box<Expr>, Box<Expr>),
    /// contains(string, string)
    Contains(Box<Expr>, Box<Expr>),
    /// substring-before(string, string)
    SubstringBefore(Box<Expr>, Box<Expr>),
    /// substring-after(string, string)
    SubstringAfter(Box<Expr>, Box<Expr>),
    /// substring(string, number, number?)
    Substring(Box<Expr>, Box<Expr>, Option<Box<Expr>>),
    /// string-length(string?)
    StringLength(Option<Box<Expr>>),
    /// normalize-space(string?)
    NormalizeSpace(Option<Box<Expr>>),
    /// translate(string, string, string)
    Translate(Box<Expr>, Box<Expr>, Box<Expr>),

    // Number Functions
    /// number(object?)
    Number(Option<Box<Expr>>),
    /// sum(node-set)
    Sum(Box<Expr>),
    /// floor(number)
    Floor(Box<Expr>),
    /// ceiling(number)
    Ceiling(Box<Expr>),
    /// round(number)
    Round(Box<Expr>),

    // Boolean Functions
    /// boolean(object)
    Boolean(Box<Expr>),
    /// not(boolean)
    Not(Box<Expr>),
    /// true()
    True,
    /// false()
    False,
    /// lang(string)
    Lang(Box<Expr>),
}
234
235impl CoreFunction {
236    pub(crate) fn name(&self) -> &'static str {
237        match self {
238            CoreFunction::Last => "last",
239            CoreFunction::Position => "position",
240            CoreFunction::Count(_) => "count",
241            CoreFunction::Id(_) => "id",
242            CoreFunction::LocalName(_) => "local-name",
243            CoreFunction::NamespaceUri(_) => "namespace-uri",
244            CoreFunction::Name(_) => "name",
245            CoreFunction::String(_) => "string",
246            CoreFunction::Concat(_) => "concat",
247            CoreFunction::StartsWith(_, _) => "starts-with",
248            CoreFunction::Contains(_, _) => "contains",
249            CoreFunction::SubstringBefore(_, _) => "substring-before",
250            CoreFunction::SubstringAfter(_, _) => "substring-after",
251            CoreFunction::Substring(_, _, _) => "substring",
252            CoreFunction::StringLength(_) => "string-length",
253            CoreFunction::NormalizeSpace(_) => "normalize-space",
254            CoreFunction::Translate(_, _, _) => "translate",
255            CoreFunction::Number(_) => "number",
256            CoreFunction::Sum(_) => "sum",
257            CoreFunction::Floor(_) => "floor",
258            CoreFunction::Ceiling(_) => "ceiling",
259            CoreFunction::Round(_) => "round",
260            CoreFunction::Boolean(_) => "boolean",
261            CoreFunction::Not(_) => "not",
262            CoreFunction::True => "true",
263            CoreFunction::False => "false",
264            CoreFunction::Lang(_) => "lang",
265        }
266    }
267
268    pub(crate) fn min_args(&self) -> usize {
269        match self {
270            // No args
271            CoreFunction::Last |
272            CoreFunction::Position |
273            CoreFunction::True |
274            CoreFunction::False => 0,
275
276            // Optional single arg
277            CoreFunction::LocalName(_) |
278            CoreFunction::NamespaceUri(_) |
279            CoreFunction::Name(_) |
280            CoreFunction::String(_) |
281            CoreFunction::StringLength(_) |
282            CoreFunction::NormalizeSpace(_) |
283            CoreFunction::Number(_) => 0,
284
285            // Required single arg
286            CoreFunction::Count(_) |
287            CoreFunction::Id(_) |
288            CoreFunction::Sum(_) |
289            CoreFunction::Floor(_) |
290            CoreFunction::Ceiling(_) |
291            CoreFunction::Round(_) |
292            CoreFunction::Boolean(_) |
293            CoreFunction::Not(_) |
294            CoreFunction::Lang(_) => 1,
295
296            // Required two args
297            CoreFunction::StartsWith(_, _) |
298            CoreFunction::Contains(_, _) |
299            CoreFunction::SubstringBefore(_, _) |
300            CoreFunction::SubstringAfter(_, _) => 2,
301
302            // Special cases
303            CoreFunction::Concat(_) => 2,          // Minimum 2 args
304            CoreFunction::Substring(_, _, _) => 2, // 2 or 3 args
305            CoreFunction::Translate(_, _, _) => 3, // Exactly 3 args
306        }
307    }
308
309    pub(crate) fn max_args(&self) -> Option<usize> {
310        match self {
311            // No args
312            CoreFunction::Last |
313            CoreFunction::Position |
314            CoreFunction::True |
315            CoreFunction::False => Some(0),
316
317            // Optional single arg (0 or 1)
318            CoreFunction::LocalName(_) |
319            CoreFunction::NamespaceUri(_) |
320            CoreFunction::Name(_) |
321            CoreFunction::String(_) |
322            CoreFunction::StringLength(_) |
323            CoreFunction::NormalizeSpace(_) |
324            CoreFunction::Number(_) => Some(1),
325
326            // Exactly one arg
327            CoreFunction::Count(_) |
328            CoreFunction::Id(_) |
329            CoreFunction::Sum(_) |
330            CoreFunction::Floor(_) |
331            CoreFunction::Ceiling(_) |
332            CoreFunction::Round(_) |
333            CoreFunction::Boolean(_) |
334            CoreFunction::Not(_) |
335            CoreFunction::Lang(_) => Some(1),
336
337            // Exactly two args
338            CoreFunction::StartsWith(_, _) |
339            CoreFunction::Contains(_, _) |
340            CoreFunction::SubstringBefore(_, _) |
341            CoreFunction::SubstringAfter(_, _) => Some(2),
342
343            // Special cases
344            CoreFunction::Concat(_) => None, // Unlimited args
345            CoreFunction::Substring(_, _, _) => Some(3), // 2 or 3 args
346            CoreFunction::Translate(_, _, _) => Some(3), // Exactly 3 args
347        }
348    }
349
350    /// Returns true if the number of arguments is valid for this function
351    pub(crate) fn is_valid_arity(&self, num_args: usize) -> bool {
352        let min = self.min_args();
353        let max = self.max_args();
354
355        num_args >= min && max.is_none_or(|max| num_args <= max)
356    }
357}
358
/// A parse error that owns its data, so it can outlive the input string.
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct OwnedParserError {
    /// Copy of the input slice reported by the underlying nom error.
    input: String,
    /// The nom error kind reported for that position.
    kind: NomErrorKind,
}
364
365impl<'a> From<NomError<&'a str>> for OwnedParserError {
366    fn from(err: NomError<&'a str>) -> Self {
367        OwnedParserError {
368            input: err.input.to_string(),
369            kind: err.code,
370        }
371    }
372}
373
374impl std::fmt::Display for OwnedParserError {
375    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
376        write!(f, "error {:?} at: {}", self.kind, self.input)
377    }
378}
379
// Marker impl: the default `Error` methods are sufficient, `Debug` and
// `Display` are provided above.
impl std::error::Error for OwnedParserError {}
381
/// Top-level parser
///
/// Parses one expression from the start of `input` by delegating to
/// `expr_single`, returning the unparsed remainder alongside the AST.
fn expr(input: &str) -> IResult<&str, Expr> {
    expr_single(input)
}
386
/// Parses a single expression; this is the entry into the operator
/// precedence chain, whose loosest-binding level is `or_expr`.
fn expr_single(input: &str) -> IResult<&str, Expr> {
    or_expr(input)
}
390
391fn or_expr(input: &str) -> IResult<&str, Expr> {
392    let (input, first) = and_expr(input)?;
393    let (input, rest) = many0(preceded(ws(tag("or")), and_expr)).parse(input)?;
394
395    Ok((
396        input,
397        rest.into_iter()
398            .fold(first, |acc, expr| Expr::Or(Box::new(acc), Box::new(expr))),
399    ))
400}
401
402fn and_expr(input: &str) -> IResult<&str, Expr> {
403    let (input, first) = equality_expr(input)?;
404    let (input, rest) = many0(preceded(ws(tag("and")), equality_expr)).parse(input)?;
405
406    Ok((
407        input,
408        rest.into_iter()
409            .fold(first, |acc, expr| Expr::And(Box::new(acc), Box::new(expr))),
410    ))
411}
412
413fn equality_expr(input: &str) -> IResult<&str, Expr> {
414    let (input, first) = relational_expr(input)?;
415    let (input, rest) = many0((
416        ws(alt((
417            map(tag("="), |_| EqualityOp::Eq),
418            map(tag("!="), |_| EqualityOp::NotEq),
419        ))),
420        relational_expr,
421    ))
422    .parse(input)?;
423
424    Ok((
425        input,
426        rest.into_iter().fold(first, |acc, (op, expr)| {
427            Expr::Equality(Box::new(acc), op, Box::new(expr))
428        }),
429    ))
430}
431
432fn relational_expr(input: &str) -> IResult<&str, Expr> {
433    let (input, first) = additive_expr(input)?;
434    let (input, rest) = many0((
435        ws(alt((
436            map(tag("<="), |_| RelationalOp::LtEq),
437            map(tag(">="), |_| RelationalOp::GtEq),
438            map(tag("<"), |_| RelationalOp::Lt),
439            map(tag(">"), |_| RelationalOp::Gt),
440        ))),
441        additive_expr,
442    ))
443    .parse(input)?;
444
445    Ok((
446        input,
447        rest.into_iter().fold(first, |acc, (op, expr)| {
448            Expr::Relational(Box::new(acc), op, Box::new(expr))
449        }),
450    ))
451}
452
453fn additive_expr(input: &str) -> IResult<&str, Expr> {
454    let (input, first) = multiplicative_expr(input)?;
455    let (input, rest) = many0((
456        ws(alt((
457            map(tag("+"), |_| AdditiveOp::Add),
458            map(tag("-"), |_| AdditiveOp::Sub),
459        ))),
460        multiplicative_expr,
461    ))
462    .parse(input)?;
463
464    Ok((
465        input,
466        rest.into_iter().fold(first, |acc, (op, expr)| {
467            Expr::Additive(Box::new(acc), op, Box::new(expr))
468        }),
469    ))
470}
471
472fn multiplicative_expr(input: &str) -> IResult<&str, Expr> {
473    let (input, first) = unary_expr(input)?;
474    let (input, rest) = many0((
475        ws(alt((
476            map(tag("*"), |_| MultiplicativeOp::Mul),
477            map(tag("div"), |_| MultiplicativeOp::Div),
478            map(tag("mod"), |_| MultiplicativeOp::Mod),
479        ))),
480        unary_expr,
481    ))
482    .parse(input)?;
483
484    Ok((
485        input,
486        rest.into_iter().fold(first, |acc, (op, expr)| {
487            Expr::Multiplicative(Box::new(acc), op, Box::new(expr))
488        }),
489    ))
490}
491
492fn unary_expr(input: &str) -> IResult<&str, Expr> {
493    let (input, minus_count) = many0(ws(char('-'))).parse(input)?;
494    let (input, expr) = union_expr(input)?;
495
496    Ok((
497        input,
498        (0..minus_count.len()).fold(expr, |acc, _| Expr::Unary(UnaryOp::Minus, Box::new(acc))),
499    ))
500}
501
502fn union_expr(input: &str) -> IResult<&str, Expr> {
503    let (input, first) = path_expr(input)?;
504    let (input, rest) = many0(preceded(ws(char('|')), path_expr)).parse(input)?;
505
506    Ok((
507        input,
508        rest.into_iter().fold(first, |acc, expr| {
509            Expr::Union(Box::new(acc), Box::new(expr))
510        }),
511    ))
512}
513
514fn path_expr(input: &str) -> IResult<&str, Expr> {
515    alt((
516        // "//" RelativePathExpr
517        map(
518            pair(tag("//"), move |i| relative_path_expr(true, i)),
519            |(_, rel_path)| {
520                Expr::Path(PathExpr {
521                    is_absolute: true,
522                    is_descendant: true,
523                    steps: match rel_path {
524                        Expr::Path(p) => p.steps,
525                        _ => unreachable!(),
526                    },
527                })
528            },
529        ),
530        // "/" RelativePathExpr?
531        map(
532            pair(char('/'), opt(move |i| relative_path_expr(false, i))),
533            |(_, rel_path)| {
534                Expr::Path(PathExpr {
535                    is_absolute: true,
536                    is_descendant: false,
537                    steps: rel_path
538                        .map(|p| match p {
539                            Expr::Path(p) => p.steps,
540                            _ => unreachable!(),
541                        })
542                        .unwrap_or_default(),
543                })
544            },
545        ),
546        // RelativePathExpr
547        move |i| relative_path_expr(false, i),
548    ))
549    .parse(input)
550}
551
552fn relative_path_expr(is_descendant: bool, input: &str) -> IResult<&str, Expr> {
553    let (input, first) = step_expr(is_descendant, input)?;
554    let (input, steps) = many0(pair(
555        ws(alt((value(true, tag("//")), value(false, char('/'))))),
556        move |i| step_expr(is_descendant, i),
557    ))
558    .parse(input)?;
559
560    let mut all_steps = vec![first];
561    for (is_descendant, step) in steps {
562        if is_descendant {
563            // Insert an implicit descendant-or-self::node() step
564            all_steps.push(StepExpr::Axis(AxisStep {
565                axis: Axis::DescendantOrSelf,
566                node_test: NodeTest::Kind(KindTest::Node),
567                predicates: PredicateListExpr { predicates: vec![] },
568            }));
569        }
570        all_steps.push(step);
571    }
572
573    Ok((
574        input,
575        Expr::Path(PathExpr {
576            is_absolute: false,
577            is_descendant: false,
578            steps: all_steps,
579        }),
580    ))
581}
582
/// Parses one step of a path: a filter expression if one matches, otherwise
/// an axis step. `is_descendant` is only forwarded to `axis_step`.
fn step_expr(is_descendant: bool, input: &str) -> IResult<&str, StepExpr> {
    alt((
        // Tried first, so names followed by `(` are offered to the
        // function-call parser before being considered as node tests.
        map(filter_expr, StepExpr::Filter),
        map(|i| axis_step(is_descendant, i), StepExpr::Axis),
    ))
    .parse(input)
}
590
591fn axis_step(is_descendant: bool, input: &str) -> IResult<&str, AxisStep> {
592    let (input, (step, predicates)) = pair(
593        alt((move |i| forward_step(is_descendant, i), reverse_step)),
594        predicate_list,
595    )
596    .parse(input)?;
597
598    let (axis, node_test) = step;
599    Ok((
600        input,
601        AxisStep {
602            axis,
603            node_test,
604            predicates,
605        },
606    ))
607}
608
/// Parses a forward step: either an explicit `axis::` followed by a node
/// test, or the abbreviated form handled by `abbrev_forward_step`.
fn forward_step(is_descendant: bool, input: &str) -> IResult<&str, (Axis, NodeTest)> {
    alt((pair(forward_axis, node_test), move |i| {
        abbrev_forward_step(is_descendant, i)
    }))
    .parse(input)
}
615
616fn forward_axis(input: &str) -> IResult<&str, Axis> {
617    let (input, axis) = alt((
618        value(Axis::Child, tag("child::")),
619        value(Axis::Descendant, tag("descendant::")),
620        value(Axis::Attribute, tag("attribute::")),
621        value(Axis::Self_, tag("self::")),
622        value(Axis::DescendantOrSelf, tag("descendant-or-self::")),
623        value(Axis::FollowingSibling, tag("following-sibling::")),
624        value(Axis::Following, tag("following::")),
625        value(Axis::Namespace, tag("namespace::")),
626    ))
627    .parse(input)?;
628
629    Ok((input, axis))
630}
631
632fn abbrev_forward_step(is_descendant: bool, input: &str) -> IResult<&str, (Axis, NodeTest)> {
633    let (input, attr) = opt(char('@')).parse(input)?;
634    let (input, test) = node_test(input)?;
635
636    Ok((
637        input,
638        (
639            if attr.is_some() {
640                Axis::Attribute
641            } else if is_descendant {
642                Axis::DescendantOrSelf
643            } else {
644                Axis::Child
645            },
646            test,
647        ),
648    ))
649}
650
/// Parses a reverse step: either an explicit reverse `axis::` followed by a
/// node test, or the abbreviation `..`.
fn reverse_step(input: &str) -> IResult<&str, (Axis, NodeTest)> {
    alt((
        // ReverseAxis NodeTest
        pair(reverse_axis, node_test),
        // AbbrevReverseStep
        abbrev_reverse_step,
    ))
    .parse(input)
}
660
/// Parses one of the explicit reverse axis specifiers, including its
/// trailing `::`.
fn reverse_axis(input: &str) -> IResult<&str, Axis> {
    alt((
        value(Axis::Parent, tag("parent::")),
        value(Axis::Ancestor, tag("ancestor::")),
        value(Axis::PrecedingSibling, tag("preceding-sibling::")),
        value(Axis::Preceding, tag("preceding::")),
        value(Axis::AncestorOrSelf, tag("ancestor-or-self::")),
    ))
    .parse(input)
}
671
672fn abbrev_reverse_step(input: &str) -> IResult<&str, (Axis, NodeTest)> {
673    map(tag(".."), |_| {
674        (Axis::Parent, NodeTest::Kind(KindTest::Node))
675    })
676    .parse(input)
677}
678
679fn node_test(input: &str) -> IResult<&str, NodeTest> {
680    alt((
681        map(kind_test, NodeTest::Kind),
682        map(name_test, |name| match name {
683            NameTest::Wildcard => NodeTest::Wildcard,
684            NameTest::QName(qname) => NodeTest::Name(qname),
685        }),
686    ))
687    .parse(input)
688}
689
/// Parser-internal result of `name_test`, converted to a [`NodeTest`] by
/// `node_test`.
#[derive(Clone, Debug, PartialEq)]
enum NameTest {
    /// A qualified name, or `prefix:*` with a local part of `*`.
    QName(QName),
    /// A bare `*`.
    Wildcard,
}
695
696fn name_test(input: &str) -> IResult<&str, NameTest> {
697    alt((
698        // NCName ":" "*"
699        map((ncname, char(':'), char('*')), |(prefix, _, _)| {
700            NameTest::QName(QName {
701                prefix: Some(prefix.to_string()),
702                local_part: "*".to_string(),
703            })
704        }),
705        // "*"
706        value(NameTest::Wildcard, char('*')),
707        // QName
708        map(qname, NameTest::QName),
709    ))
710    .parse(input)
711}
712
713fn filter_expr(input: &str) -> IResult<&str, FilterExpr> {
714    let (input, primary) = primary_expr(input)?;
715    let (input, predicates) = predicate_list(input)?;
716
717    Ok((
718        input,
719        FilterExpr {
720            primary,
721            predicates,
722        },
723    ))
724}
725
726fn predicate_list(input: &str) -> IResult<&str, PredicateListExpr> {
727    let (input, predicates) = many0(predicate).parse(input)?;
728
729    Ok((input, PredicateListExpr { predicates }))
730}
731
732fn predicate(input: &str) -> IResult<&str, PredicateExpr> {
733    let (input, expr) = delimited(ws(char('[')), expr, ws(char(']'))).parse(input)?;
734    Ok((input, PredicateExpr { expr }))
735}
736
737fn primary_expr(input: &str) -> IResult<&str, PrimaryExpr> {
738    alt((
739        literal,
740        var_ref,
741        map(parenthesized_expr, |e| {
742            PrimaryExpr::Parenthesized(Box::new(e))
743        }),
744        context_item_expr,
745        function_call,
746    ))
747    .parse(input)
748}
749
750fn literal(input: &str) -> IResult<&str, PrimaryExpr> {
751    map(alt((numeric_literal, string_literal)), |lit| {
752        PrimaryExpr::Literal(lit)
753    })
754    .parse(input)
755}
756
/// Parses a numeric literal. The decimal form is tried first so that the
/// integer parser does not stop at the decimal point of a number like `1.5`.
fn numeric_literal(input: &str) -> IResult<&str, Literal> {
    alt((decimal_literal, integer_literal)).parse(input)
}
760
761fn var_ref(input: &str) -> IResult<&str, PrimaryExpr> {
762    let (input, _) = char('$').parse(input)?;
763    let (input, name) = qname(input)?;
764    Ok((input, PrimaryExpr::Variable(name)))
765}
766
/// Parses an expression enclosed in `(` and `)`, with whitespace allowed
/// around both parentheses, and returns the inner expression.
fn parenthesized_expr(input: &str) -> IResult<&str, Expr> {
    delimited(ws(char('(')), expr, ws(char(')'))).parse(input)
}
770
771fn context_item_expr(input: &str) -> IResult<&str, PrimaryExpr> {
772    map(char('.'), |_| PrimaryExpr::ContextItem).parse(input)
773}
774
775fn function_call(input: &str) -> IResult<&str, PrimaryExpr> {
776    let (input, name) = qname(input)?;
777    let (input, args) = delimited(
778        ws(char('(')),
779        separated_list0(ws(char(',')), expr_single),
780        ws(char(')')),
781    )
782    .parse(input)?;
783
784    // Helper to create error
785    let arity_error = || nom::Err::Error(NomError::new(input, NomErrorKind::Verify));
786
787    let core_fn = match name.local_part.as_str() {
788        // Node Set Functions
789        "last" => CoreFunction::Last,
790        "position" => CoreFunction::Position,
791        "count" => CoreFunction::Count(Box::new(args.into_iter().next().ok_or_else(arity_error)?)),
792        "id" => CoreFunction::Id(Box::new(args.into_iter().next().ok_or_else(arity_error)?)),
793        "local-name" => CoreFunction::LocalName(args.into_iter().next().map(Box::new)),
794        "namespace-uri" => CoreFunction::NamespaceUri(args.into_iter().next().map(Box::new)),
795        "name" => CoreFunction::Name(args.into_iter().next().map(Box::new)),
796
797        // String Functions
798        "string" => CoreFunction::String(args.into_iter().next().map(Box::new)),
799        "concat" => CoreFunction::Concat(args.into_iter().collect()),
800        "starts-with" => {
801            let mut args = args.into_iter();
802            CoreFunction::StartsWith(
803                Box::new(args.next().ok_or_else(arity_error)?),
804                Box::new(args.next().ok_or_else(arity_error)?),
805            )
806        },
807        "contains" => {
808            let mut args = args.into_iter();
809            CoreFunction::Contains(
810                Box::new(args.next().ok_or_else(arity_error)?),
811                Box::new(args.next().ok_or_else(arity_error)?),
812            )
813        },
814        "substring-before" => {
815            let mut args = args.into_iter();
816            CoreFunction::SubstringBefore(
817                Box::new(args.next().ok_or_else(arity_error)?),
818                Box::new(args.next().ok_or_else(arity_error)?),
819            )
820        },
821        "substring-after" => {
822            let mut args = args.into_iter();
823            CoreFunction::SubstringAfter(
824                Box::new(args.next().ok_or_else(arity_error)?),
825                Box::new(args.next().ok_or_else(arity_error)?),
826            )
827        },
828        "substring" => {
829            let mut args = args.into_iter();
830            CoreFunction::Substring(
831                Box::new(args.next().ok_or_else(arity_error)?),
832                Box::new(args.next().ok_or_else(arity_error)?),
833                args.next().map(Box::new),
834            )
835        },
836        "string-length" => CoreFunction::StringLength(args.into_iter().next().map(Box::new)),
837        "normalize-space" => CoreFunction::NormalizeSpace(args.into_iter().next().map(Box::new)),
838        "translate" => {
839            let mut args = args.into_iter();
840            CoreFunction::Translate(
841                Box::new(args.next().ok_or_else(arity_error)?),
842                Box::new(args.next().ok_or_else(arity_error)?),
843                Box::new(args.next().ok_or_else(arity_error)?),
844            )
845        },
846
847        // Number Functions
848        "number" => CoreFunction::Number(args.into_iter().next().map(Box::new)),
849        "sum" => CoreFunction::Sum(Box::new(args.into_iter().next().ok_or_else(arity_error)?)),
850        "floor" => CoreFunction::Floor(Box::new(args.into_iter().next().ok_or_else(arity_error)?)),
851        "ceiling" => {
852            CoreFunction::Ceiling(Box::new(args.into_iter().next().ok_or_else(arity_error)?))
853        },
854        "round" => CoreFunction::Round(Box::new(args.into_iter().next().ok_or_else(arity_error)?)),
855
856        // Boolean Functions
857        "boolean" => {
858            CoreFunction::Boolean(Box::new(args.into_iter().next().ok_or_else(arity_error)?))
859        },
860        "not" => CoreFunction::Not(Box::new(args.into_iter().next().ok_or_else(arity_error)?)),
861        "true" => CoreFunction::True,
862        "false" => CoreFunction::False,
863        "lang" => CoreFunction::Lang(Box::new(args.into_iter().next().ok_or_else(arity_error)?)),
864
865        // Unknown function
866        _ => return Err(nom::Err::Error(NomError::new(input, NomErrorKind::Verify))),
867    };
868
869    Ok((input, PrimaryExpr::Function(core_fn)))
870}
871
/// Parses any of the kind tests: `processing-instruction(...)`, `comment()`,
/// `text()` or `node()`.
fn kind_test(input: &str) -> IResult<&str, KindTest> {
    alt((pi_test, comment_test, text_test, any_kind_test)).parse(input)
}
875
876fn any_kind_test(input: &str) -> IResult<&str, KindTest> {
877    map((tag("node"), ws(char('(')), ws(char(')'))), |_| {
878        KindTest::Node
879    })
880    .parse(input)
881}
882
883fn text_test(input: &str) -> IResult<&str, KindTest> {
884    map((tag("text"), ws(char('(')), ws(char(')'))), |_| {
885        KindTest::Text
886    })
887    .parse(input)
888}
889
890fn comment_test(input: &str) -> IResult<&str, KindTest> {
891    map((tag("comment"), ws(char('(')), ws(char(')'))), |_| {
892        KindTest::Comment
893    })
894    .parse(input)
895}
896
897fn pi_test(input: &str) -> IResult<&str, KindTest> {
898    map(
899        (
900            tag("processing-instruction"),
901            ws(char('(')),
902            opt(ws(string_literal)),
903            ws(char(')')),
904        ),
905        |(_, _, literal, _)| {
906            KindTest::PI(literal.map(|l| match l {
907                Literal::String(s) => s,
908                _ => unreachable!(),
909            }))
910        },
911    )
912    .parse(input)
913}
914
/// Wraps `inner` in a parser that skips optional whitespace before and after
/// it and returns only `inner`'s output.
///
/// "Whitespace" is whatever `multispace0` accepts; since that parser matches
/// zero or more characters, the absence of whitespace is not an error.
fn ws<'a, F, O, E>(inner: F) -> impl Parser<&'a str, Output = O, Error = E>
where
    E: NomParseError<&'a str>,
    F: Parser<&'a str, Output = O, Error = E>,
{
    delimited(multispace0, inner, multispace0)
}
922
923fn integer_literal(input: &str) -> IResult<&str, Literal> {
924    map(recognize((opt(char('-')), digit1)), |s: &str| {
925        Literal::Numeric(NumericLiteral::Integer(s.parse().unwrap()))
926    })
927    .parse(input)
928}
929
930fn decimal_literal(input: &str) -> IResult<&str, Literal> {
931    map(
932        recognize((opt(char('-')), opt(digit1), char('.'), digit1)),
933        |s: &str| Literal::Numeric(NumericLiteral::Decimal(s.parse().unwrap())),
934    )
935    .parse(input)
936}
937
938fn string_literal(input: &str) -> IResult<&str, Literal> {
939    alt((
940        delimited(
941            char('"'),
942            map(take_while1(|c| c != '"'), |s: &str| {
943                Literal::String(s.to_string())
944            }),
945            char('"'),
946        ),
947        delimited(
948            char('\''),
949            map(take_while1(|c| c != '\''), |s: &str| {
950                Literal::String(s.to_string())
951            }),
952            char('\''),
953        ),
954    ))
955    .parse(input)
956}
957
958// QName parser
959fn qname(input: &str) -> IResult<&str, QName> {
960    let (input, prefix) = opt((ncname, char(':'))).parse(input)?;
961    let (input, local) = ncname(input)?;
962
963    Ok((
964        input,
965        QName {
966            prefix: prefix.map(|(p, _)| p.to_string()),
967            local_part: local.to_string(),
968        },
969    ))
970}
971
972// NCName parser
973fn ncname(input: &str) -> IResult<&str, &str> {
974    recognize(pair(
975        alpha1,
976        many0(alt((alphanumeric1, tag("-"), tag("_")))),
977    ))
978    .parse(input)
979}
980
// Tests for the parsers. Expected ASTs are assembled with the small builder
// helpers below so that each case reads close to the expression it describes.
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps each expression in a `PredicateExpr`, preserving order.
    fn predicate_list(exprs: Vec<Expr>) -> PredicateListExpr {
        PredicateListExpr {
            predicates: exprs
                .into_iter()
                .map(|expr| PredicateExpr { expr })
                .collect(),
        }
    }

    /// A relative (neither absolute nor descendant) path made of `steps`.
    fn relative_path(steps: Vec<StepExpr>) -> Expr {
        Expr::Path(PathExpr {
            is_absolute: false,
            is_descendant: false,
            steps,
        })
    }

    /// An absolute path introduced by `//`, made of `steps`.
    fn descendant_path(steps: Vec<StepExpr>) -> Expr {
        Expr::Path(PathExpr {
            is_absolute: true,
            is_descendant: true,
            steps,
        })
    }

    /// A relative path whose only step is `primary` with no predicates.
    fn bare_primary(primary: PrimaryExpr) -> Expr {
        relative_path(vec![StepExpr::Filter(FilterExpr {
            primary,
            predicates: predicate_list(vec![]),
        })])
    }

    /// A string literal expression.
    fn string_expr(text: &str) -> Expr {
        bare_primary(PrimaryExpr::Literal(Literal::String(text.to_string())))
    }

    /// A numeric literal expression.
    fn number_expr(number: NumericLiteral) -> Expr {
        bare_primary(PrimaryExpr::Literal(Literal::Numeric(number)))
    }

    /// A core function call expression.
    fn call_expr(function: CoreFunction) -> Expr {
        bare_primary(PrimaryExpr::Function(function))
    }

    /// A name test for `prefix:local_part` (or just `local_part`).
    fn name_test(prefix: Option<&str>, local_part: &str) -> NodeTest {
        NodeTest::Name(QName {
            prefix: prefix.map(str::to_string),
            local_part: local_part.to_string(),
        })
    }

    /// An axis step with one predicate per entry of `predicates`.
    fn axis_step(axis: Axis, test: NodeTest, predicates: Vec<Expr>) -> StepExpr {
        StepExpr::Axis(AxisStep {
            axis,
            node_test: test,
            predicates: predicate_list(predicates),
        })
    }

    /// The expression for an attribute reference such as `@class` or `@xml:id`.
    fn attribute_expr(prefix: Option<&str>, local_part: &str) -> Expr {
        relative_path(vec![axis_step(
            Axis::Attribute,
            name_test(prefix, local_part),
            vec![],
        )])
    }

    /// Runs `parse` on every input and compares the AST with the expected one.
    fn assert_parses_to(cases: Vec<(&str, Expr)>) {
        for (input, expected) in cases {
            let result =
                parse(input).unwrap_or_else(|e| panic!("Failed to parse '{}': {:?}", input, e));
            assert_eq!(result, expected);
        }
    }

    #[test]
    fn test_node_tests() {
        let cases = [
            ("node()", NodeTest::Kind(KindTest::Node)),
            ("text()", NodeTest::Kind(KindTest::Text)),
            ("comment()", NodeTest::Kind(KindTest::Comment)),
            (
                "processing-instruction()",
                NodeTest::Kind(KindTest::PI(None)),
            ),
            (
                "processing-instruction('test')",
                NodeTest::Kind(KindTest::PI(Some("test".to_string()))),
            ),
            ("*", NodeTest::Wildcard),
            ("prefix:*", name_test(Some("prefix"), "*")),
            ("div", name_test(None, "div")),
            ("ns:div", name_test(Some("ns"), "div")),
        ];

        for (input, expected) in cases {
            let (remaining, result) = node_test(input)
                .unwrap_or_else(|e| panic!("Failed to parse '{}': {:?}", input, e));
            assert!(remaining.is_empty(), "Parser didn't consume all input");
            assert_eq!(result, expected);
        }
    }

    #[test]
    fn test_filter_expr() {
        assert_parses_to(vec![
            (
                "processing-instruction('test')[2]",
                relative_path(vec![axis_step(
                    Axis::Child,
                    NodeTest::Kind(KindTest::PI(Some("test".to_string()))),
                    vec![number_expr(NumericLiteral::Integer(2))],
                )]),
            ),
            (
                "concat('hello', ' ', 'world')",
                call_expr(CoreFunction::Concat(vec![
                    string_expr("hello"),
                    string_expr(" "),
                    string_expr("world"),
                ])),
            ),
        ]);
    }

    #[test]
    fn test_complex_paths() {
        assert_parses_to(vec![
            (
                "//*[contains(@class, 'test')]",
                descendant_path(vec![axis_step(
                    Axis::Child,
                    NodeTest::Wildcard,
                    vec![call_expr(CoreFunction::Contains(
                        Box::new(attribute_expr(None, "class")),
                        Box::new(string_expr("test")),
                    ))],
                )]),
            ),
            (
                "//div[position() > 1]/*[last()]",
                descendant_path(vec![
                    axis_step(
                        Axis::Child,
                        name_test(None, "div"),
                        vec![Expr::Relational(
                            Box::new(call_expr(CoreFunction::Position)),
                            RelationalOp::Gt,
                            Box::new(number_expr(NumericLiteral::Integer(1))),
                        )],
                    ),
                    axis_step(
                        Axis::Child,
                        NodeTest::Wildcard,
                        vec![call_expr(CoreFunction::Last)],
                    ),
                ]),
            ),
            (
                "//mu[@xml:id=\"id1\"]//rho[@title][@xml:lang=\"en-GB\"]",
                descendant_path(vec![
                    axis_step(
                        Axis::Child,
                        name_test(None, "mu"),
                        vec![Expr::Equality(
                            Box::new(attribute_expr(Some("xml"), "id")),
                            EqualityOp::Eq,
                            Box::new(string_expr("id1")),
                        )],
                    ),
                    // Represents the second '//'
                    axis_step(
                        Axis::DescendantOrSelf,
                        NodeTest::Kind(KindTest::Node),
                        vec![],
                    ),
                    axis_step(
                        Axis::Child,
                        name_test(None, "rho"),
                        vec![
                            attribute_expr(None, "title"),
                            Expr::Equality(
                                Box::new(attribute_expr(Some("xml"), "lang")),
                                EqualityOp::Eq,
                                Box::new(string_expr("en-GB")),
                            ),
                        ],
                    ),
                ]),
            ),
        ]);
    }
}