script/xpath/
parser.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
4
5use nom::branch::alt;
6use nom::bytes::complete::{tag, take_while1};
7use nom::character::complete::{char, digit1, multispace0};
8use nom::combinator::{map, opt, recognize, value};
9use nom::error::{Error as NomError, ErrorKind as NomErrorKind, ParseError as NomParseError};
10use nom::multi::{many0, separated_list0};
11use nom::sequence::{delimited, pair, preceded};
12use nom::{AsChar, Finish, IResult, Input, Parser};
13
14use crate::dom::bindings::xmlname::{is_valid_continuation, is_valid_start};
15
16pub(crate) fn parse(input: &str) -> Result<Expr, OwnedParserError> {
17    let (_, ast) = expr(input).finish().map_err(OwnedParserError::from)?;
18    Ok(ast)
19}
20
/// A node of the parsed XPath expression tree.
///
/// Binary variants store their left operand first and their right operand last; the
/// parsers in this file build chains of the same operator left-associatively.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum Expr {
    /// `lhs or rhs`
    Or(Box<Expr>, Box<Expr>),
    /// `lhs and rhs`
    And(Box<Expr>, Box<Expr>),
    /// `lhs = rhs` or `lhs != rhs`
    Equality(Box<Expr>, EqualityOp, Box<Expr>),
    /// `lhs < rhs`, `lhs > rhs`, `lhs <= rhs` or `lhs >= rhs`
    Relational(Box<Expr>, RelationalOp, Box<Expr>),
    /// `lhs + rhs` or `lhs - rhs`
    Additive(Box<Expr>, AdditiveOp, Box<Expr>),
    /// `lhs * rhs`, `lhs div rhs` or `lhs mod rhs`
    Multiplicative(Box<Expr>, MultiplicativeOp, Box<Expr>),
    /// A prefix operator applied to an operand (currently only unary minus)
    Unary(UnaryOp, Box<Expr>),
    /// `lhs | rhs`
    Union(Box<Expr>, Box<Expr>),
    /// A location path or a filter expression
    Path(PathExpr),
}
33
/// Operator of an [`Expr::Equality`] node.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum EqualityOp {
    /// `=`
    Eq,
    /// `!=`
    NotEq,
}
39
/// Operator of an [`Expr::Relational`] node.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum RelationalOp {
    /// `<`
    Lt,
    /// `>`
    Gt,
    /// `<=`
    LtEq,
    /// `>=`
    GtEq,
}
47
/// Operator of an [`Expr::Additive`] node.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum AdditiveOp {
    /// `+`
    Add,
    /// `-`
    Sub,
}
53
/// Operator of an [`Expr::Multiplicative`] node.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum MultiplicativeOp {
    /// `*`
    Mul,
    /// `div`
    Div,
    /// `mod`
    Mod,
}
60
/// Operator of an [`Expr::Unary`] node.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum UnaryOp {
    /// Prefix `-`
    Minus,
}
65
/// A sequence of steps, optionally anchored at the root.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) struct PathExpr {
    /// Set by the parser when the path starts with `/` or `//`.
    pub(crate) is_absolute: bool,
    /// Set by the parser when the path starts with `//`.
    pub(crate) is_descendant: bool,
    /// The steps of the path, in source order.
    pub(crate) steps: Vec<StepExpr>,
}
72
/// Zero or more `[...]` predicates attached to a step or filter expression.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) struct PredicateListExpr {
    /// The predicates, in source order.
    pub(crate) predicates: Vec<PredicateExpr>,
}
77
/// A single `[expr]` predicate.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) struct PredicateExpr {
    /// The expression written between the brackets.
    pub(crate) expr: Expr,
}
82
/// A primary expression followed by zero or more predicates.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) struct FilterExpr {
    /// The expression being filtered.
    pub(crate) primary: PrimaryExpr,
    /// The predicates written after the primary expression.
    pub(crate) predicates: PredicateListExpr,
}
88
/// One step of a [`PathExpr`].
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum StepExpr {
    /// A primary expression with optional predicates, e.g. `$var[1]` or `(a | b)`
    Filter(FilterExpr),
    /// An axis, a node test and optional predicates, e.g. `child::a[1]`
    Axis(AxisStep),
}
94
/// A location step: an axis, a node test and optional predicates.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) struct AxisStep {
    /// The axis along which the step selects nodes.
    pub(crate) axis: Axis,
    /// The test applied to nodes on that axis.
    pub(crate) node_test: NodeTest,
    /// The predicates written after the node test.
    pub(crate) predicates: PredicateListExpr,
}
101
/// The axes a step can be written with.
///
/// The first eight variants are parsed by `forward_axis`, the remaining five by
/// `reverse_axis`.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum Axis {
    /// `child::` (also the default for abbreviated steps)
    Child,
    /// `descendant::`
    Descendant,
    /// `attribute::` or the `@` abbreviation
    Attribute,
    /// `self::`
    Self_,
    /// `descendant-or-self::`
    DescendantOrSelf,
    /// `following-sibling::`
    FollowingSibling,
    /// `following::`
    Following,
    /// `namespace::`
    Namespace,
    /// `parent::` or the `..` abbreviation
    Parent,
    /// `ancestor::`
    Ancestor,
    /// `preceding-sibling::`
    PrecedingSibling,
    /// `preceding::`
    Preceding,
    /// `ancestor-or-self::`
    AncestorOrSelf,
}
118
/// The test part of a location step.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum NodeTest {
    /// A qualified name; `prefix:*` is stored as a name whose local part is `*`
    Name(QName),
    /// A bare `*`
    Wildcard,
    /// A kind test such as `text()` or `node()`
    Kind(KindTest),
}
125
/// A qualified name: an optional prefix and a local part.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) struct QName {
    /// The part before the `:`, if the name was written with one.
    pub(crate) prefix: Option<String>,
    /// The part after the `:`, or the whole name when there is no prefix.
    pub(crate) local_part: String,
}
131
132impl std::fmt::Display for QName {
133    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134        match &self.prefix {
135            Some(prefix) => write!(f, "{}:{}", prefix, self.local_part),
136            None => write!(f, "{}", self.local_part),
137        }
138    }
139}
140
/// A node test that selects by node kind rather than by name.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum KindTest {
    /// `processing-instruction()`, optionally with a string literal argument
    PI(Option<String>),
    /// `comment()`
    Comment,
    /// `text()`
    Text,
    /// `node()`
    Node,
}
148
/// The basic building blocks of a filter expression.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum PrimaryExpr {
    /// A numeric or string literal
    Literal(Literal),
    /// A variable reference, written `$name`
    Variable(QName),
    /// An expression wrapped in parentheses
    Parenthesized(Box<Expr>),
    /// The context item, written `.`
    ContextItem,
    /// We only support the built-in core functions
    Function(CoreFunction),
}
158
/// A literal value written directly in the expression.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum Literal {
    /// A number
    Numeric(NumericLiteral),
    /// The contents of a quoted string, without the quotes
    String(String),
}
164
/// A numeric literal, kept in the form it was written.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum NumericLiteral {
    /// A literal without a decimal point
    Integer(u64),
    /// A literal containing a decimal point
    Decimal(f64),
}
170
/// A call to one of the built-in XPath functions, together with its arguments.
///
/// In the DOM we do not support custom functions, so we can enumerate the usable ones.
/// Arguments that the function signature marks as optional are stored as `Option`.
#[derive(Clone, Debug, MallocSizeOf, PartialEq)]
pub(crate) enum CoreFunction {
    // Node Set Functions
    /// last()
    Last,
    /// position()
    Position,
    /// count(node-set)
    Count(Box<Expr>),
    /// id(object)
    Id(Box<Expr>),
    /// local-name(node-set?)
    LocalName(Option<Box<Expr>>),
    /// namespace-uri(node-set?)
    NamespaceUri(Option<Box<Expr>>),
    /// name(node-set?)
    Name(Option<Box<Expr>>),

    // String Functions
    /// string(object?)
    String(Option<Box<Expr>>),
    /// concat(string, string, ...)
    Concat(Vec<Expr>),
    /// starts-with(string, string)
    StartsWith(Box<Expr>, Box<Expr>),
    /// contains(string, string)
    Contains(Box<Expr>, Box<Expr>),
    /// substring-before(string, string)
    SubstringBefore(Box<Expr>, Box<Expr>),
    /// substring-after(string, string)
    SubstringAfter(Box<Expr>, Box<Expr>),
    /// substring(string, number, number?)
    Substring(Box<Expr>, Box<Expr>, Option<Box<Expr>>),
    /// string-length(string?)
    StringLength(Option<Box<Expr>>),
    /// normalize-space(string?)
    NormalizeSpace(Option<Box<Expr>>),
    /// translate(string, string, string)
    Translate(Box<Expr>, Box<Expr>, Box<Expr>),

    // Number Functions
    /// number(object?)
    Number(Option<Box<Expr>>),
    /// sum(node-set)
    Sum(Box<Expr>),
    /// floor(number)
    Floor(Box<Expr>),
    /// ceiling(number)
    Ceiling(Box<Expr>),
    /// round(number)
    Round(Box<Expr>),

    // Boolean Functions
    /// boolean(object)
    Boolean(Box<Expr>),
    /// not(boolean)
    Not(Box<Expr>),
    /// true()
    True,
    /// false()
    False,
    /// lang(string)
    Lang(Box<Expr>),
}
236
237impl CoreFunction {
238    pub(crate) fn name(&self) -> &'static str {
239        match self {
240            CoreFunction::Last => "last",
241            CoreFunction::Position => "position",
242            CoreFunction::Count(_) => "count",
243            CoreFunction::Id(_) => "id",
244            CoreFunction::LocalName(_) => "local-name",
245            CoreFunction::NamespaceUri(_) => "namespace-uri",
246            CoreFunction::Name(_) => "name",
247            CoreFunction::String(_) => "string",
248            CoreFunction::Concat(_) => "concat",
249            CoreFunction::StartsWith(_, _) => "starts-with",
250            CoreFunction::Contains(_, _) => "contains",
251            CoreFunction::SubstringBefore(_, _) => "substring-before",
252            CoreFunction::SubstringAfter(_, _) => "substring-after",
253            CoreFunction::Substring(_, _, _) => "substring",
254            CoreFunction::StringLength(_) => "string-length",
255            CoreFunction::NormalizeSpace(_) => "normalize-space",
256            CoreFunction::Translate(_, _, _) => "translate",
257            CoreFunction::Number(_) => "number",
258            CoreFunction::Sum(_) => "sum",
259            CoreFunction::Floor(_) => "floor",
260            CoreFunction::Ceiling(_) => "ceiling",
261            CoreFunction::Round(_) => "round",
262            CoreFunction::Boolean(_) => "boolean",
263            CoreFunction::Not(_) => "not",
264            CoreFunction::True => "true",
265            CoreFunction::False => "false",
266            CoreFunction::Lang(_) => "lang",
267        }
268    }
269
270    pub(crate) fn min_args(&self) -> usize {
271        match self {
272            // No args
273            CoreFunction::Last |
274            CoreFunction::Position |
275            CoreFunction::True |
276            CoreFunction::False => 0,
277
278            // Optional single arg
279            CoreFunction::LocalName(_) |
280            CoreFunction::NamespaceUri(_) |
281            CoreFunction::Name(_) |
282            CoreFunction::String(_) |
283            CoreFunction::StringLength(_) |
284            CoreFunction::NormalizeSpace(_) |
285            CoreFunction::Number(_) => 0,
286
287            // Required single arg
288            CoreFunction::Count(_) |
289            CoreFunction::Id(_) |
290            CoreFunction::Sum(_) |
291            CoreFunction::Floor(_) |
292            CoreFunction::Ceiling(_) |
293            CoreFunction::Round(_) |
294            CoreFunction::Boolean(_) |
295            CoreFunction::Not(_) |
296            CoreFunction::Lang(_) => 1,
297
298            // Required two args
299            CoreFunction::StartsWith(_, _) |
300            CoreFunction::Contains(_, _) |
301            CoreFunction::SubstringBefore(_, _) |
302            CoreFunction::SubstringAfter(_, _) => 2,
303
304            // Special cases
305            CoreFunction::Concat(_) => 2,          // Minimum 2 args
306            CoreFunction::Substring(_, _, _) => 2, // 2 or 3 args
307            CoreFunction::Translate(_, _, _) => 3, // Exactly 3 args
308        }
309    }
310
311    pub(crate) fn max_args(&self) -> Option<usize> {
312        match self {
313            // No args
314            CoreFunction::Last |
315            CoreFunction::Position |
316            CoreFunction::True |
317            CoreFunction::False => Some(0),
318
319            // Optional single arg (0 or 1)
320            CoreFunction::LocalName(_) |
321            CoreFunction::NamespaceUri(_) |
322            CoreFunction::Name(_) |
323            CoreFunction::String(_) |
324            CoreFunction::StringLength(_) |
325            CoreFunction::NormalizeSpace(_) |
326            CoreFunction::Number(_) => Some(1),
327
328            // Exactly one arg
329            CoreFunction::Count(_) |
330            CoreFunction::Id(_) |
331            CoreFunction::Sum(_) |
332            CoreFunction::Floor(_) |
333            CoreFunction::Ceiling(_) |
334            CoreFunction::Round(_) |
335            CoreFunction::Boolean(_) |
336            CoreFunction::Not(_) |
337            CoreFunction::Lang(_) => Some(1),
338
339            // Exactly two args
340            CoreFunction::StartsWith(_, _) |
341            CoreFunction::Contains(_, _) |
342            CoreFunction::SubstringBefore(_, _) |
343            CoreFunction::SubstringAfter(_, _) => Some(2),
344
345            // Special cases
346            CoreFunction::Concat(_) => None, // Unlimited args
347            CoreFunction::Substring(_, _, _) => Some(3), // 2 or 3 args
348            CoreFunction::Translate(_, _, _) => Some(3), // Exactly 3 args
349        }
350    }
351
352    /// Returns true if the number of arguments is valid for this function
353    pub(crate) fn is_valid_arity(&self, num_args: usize) -> bool {
354        let min = self.min_args();
355        let max = self.max_args();
356
357        num_args >= min && max.is_none_or(|max| num_args <= max)
358    }
359}
360
/// A parse error that owns its data, so it does not borrow from the parsed string.
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct OwnedParserError {
    /// Copy of the input text the error refers to.
    input: String,
    /// The nom error kind that was reported.
    kind: NomErrorKind,
}
366
367impl<'a> From<NomError<&'a str>> for OwnedParserError {
368    fn from(err: NomError<&'a str>) -> Self {
369        OwnedParserError {
370            input: err.input.to_string(),
371            kind: err.code,
372        }
373    }
374}
375
376impl std::fmt::Display for OwnedParserError {
377    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
378        write!(f, "error {:?} at: {}", self.kind, self.input)
379    }
380}
381
// Marker impl so the parse error can be used wherever a `std::error::Error` is expected;
// the default trait methods are sufficient.
impl std::error::Error for OwnedParserError {}
383
/// Top-level parser: parses one expression from the start of `input`.
///
/// Returns the unparsed remainder together with the expression. Currently this simply
/// delegates to [`expr_single`].
fn expr(input: &str) -> IResult<&str, Expr> {
    expr_single(input)
}
388
/// Parses a single expression, starting at the lowest-precedence operator (`or`).
///
/// Also used for each argument of a function call.
fn expr_single(input: &str) -> IResult<&str, Expr> {
    or_expr(input)
}
392
393fn or_expr(input: &str) -> IResult<&str, Expr> {
394    let (input, first) = and_expr(input)?;
395    let (input, rest) = many0(preceded(ws(tag("or")), and_expr)).parse(input)?;
396
397    Ok((
398        input,
399        rest.into_iter()
400            .fold(first, |acc, expr| Expr::Or(Box::new(acc), Box::new(expr))),
401    ))
402}
403
404fn and_expr(input: &str) -> IResult<&str, Expr> {
405    let (input, first) = equality_expr(input)?;
406    let (input, rest) = many0(preceded(ws(tag("and")), equality_expr)).parse(input)?;
407
408    Ok((
409        input,
410        rest.into_iter()
411            .fold(first, |acc, expr| Expr::And(Box::new(acc), Box::new(expr))),
412    ))
413}
414
415fn equality_expr(input: &str) -> IResult<&str, Expr> {
416    let (input, first) = relational_expr(input)?;
417    let (input, rest) = many0((
418        ws(alt((
419            map(tag("="), |_| EqualityOp::Eq),
420            map(tag("!="), |_| EqualityOp::NotEq),
421        ))),
422        relational_expr,
423    ))
424    .parse(input)?;
425
426    Ok((
427        input,
428        rest.into_iter().fold(first, |acc, (op, expr)| {
429            Expr::Equality(Box::new(acc), op, Box::new(expr))
430        }),
431    ))
432}
433
434fn relational_expr(input: &str) -> IResult<&str, Expr> {
435    let (input, first) = additive_expr(input)?;
436    let (input, rest) = many0((
437        ws(alt((
438            map(tag("<="), |_| RelationalOp::LtEq),
439            map(tag(">="), |_| RelationalOp::GtEq),
440            map(tag("<"), |_| RelationalOp::Lt),
441            map(tag(">"), |_| RelationalOp::Gt),
442        ))),
443        additive_expr,
444    ))
445    .parse(input)?;
446
447    Ok((
448        input,
449        rest.into_iter().fold(first, |acc, (op, expr)| {
450            Expr::Relational(Box::new(acc), op, Box::new(expr))
451        }),
452    ))
453}
454
455fn additive_expr(input: &str) -> IResult<&str, Expr> {
456    let (input, first) = multiplicative_expr(input)?;
457    let (input, rest) = many0((
458        ws(alt((
459            map(tag("+"), |_| AdditiveOp::Add),
460            map(tag("-"), |_| AdditiveOp::Sub),
461        ))),
462        multiplicative_expr,
463    ))
464    .parse(input)?;
465
466    Ok((
467        input,
468        rest.into_iter().fold(first, |acc, (op, expr)| {
469            Expr::Additive(Box::new(acc), op, Box::new(expr))
470        }),
471    ))
472}
473
474fn multiplicative_expr(input: &str) -> IResult<&str, Expr> {
475    let (input, first) = unary_expr(input)?;
476    let (input, rest) = many0((
477        ws(alt((
478            map(tag("*"), |_| MultiplicativeOp::Mul),
479            map(tag("div"), |_| MultiplicativeOp::Div),
480            map(tag("mod"), |_| MultiplicativeOp::Mod),
481        ))),
482        unary_expr,
483    ))
484    .parse(input)?;
485
486    Ok((
487        input,
488        rest.into_iter().fold(first, |acc, (op, expr)| {
489            Expr::Multiplicative(Box::new(acc), op, Box::new(expr))
490        }),
491    ))
492}
493
494fn unary_expr(input: &str) -> IResult<&str, Expr> {
495    let (input, minus_count) = many0(ws(char('-'))).parse(input)?;
496    let (input, expr) = union_expr(input)?;
497
498    Ok((
499        input,
500        (0..minus_count.len()).fold(expr, |acc, _| Expr::Unary(UnaryOp::Minus, Box::new(acc))),
501    ))
502}
503
504fn union_expr(input: &str) -> IResult<&str, Expr> {
505    let (input, first) = path_expr(input)?;
506    let (input, rest) = many0(preceded(ws(char('|')), path_expr)).parse(input)?;
507
508    Ok((
509        input,
510        rest.into_iter().fold(first, |acc, expr| {
511            Expr::Union(Box::new(acc), Box::new(expr))
512        }),
513    ))
514}
515
516fn path_expr(input: &str) -> IResult<&str, Expr> {
517    ws(alt((
518        // "//" RelativePathExpr
519        map(
520            pair(tag("//"), move |i| relative_path_expr(true, i)),
521            |(_, rel_path)| {
522                Expr::Path(PathExpr {
523                    is_absolute: true,
524                    is_descendant: true,
525                    steps: match rel_path {
526                        Expr::Path(p) => p.steps,
527                        _ => unreachable!(),
528                    },
529                })
530            },
531        ),
532        // "/" RelativePathExpr?
533        map(
534            pair(char('/'), opt(move |i| relative_path_expr(false, i))),
535            |(_, rel_path)| {
536                Expr::Path(PathExpr {
537                    is_absolute: true,
538                    is_descendant: false,
539                    steps: rel_path
540                        .map(|p| match p {
541                            Expr::Path(p) => p.steps,
542                            _ => unreachable!(),
543                        })
544                        .unwrap_or_default(),
545                })
546            },
547        ),
548        // RelativePathExpr
549        move |i| relative_path_expr(false, i),
550    )))
551    .parse(input)
552}
553
554fn relative_path_expr(is_descendant: bool, input: &str) -> IResult<&str, Expr> {
555    let (input, first) = step_expr(is_descendant, input)?;
556    let (input, steps) = many0(pair(
557        ws(alt((value(true, tag("//")), value(false, char('/'))))),
558        move |i| step_expr(is_descendant, i),
559    ))
560    .parse(input)?;
561
562    let mut all_steps = vec![first];
563    for (is_descendant, step) in steps {
564        if is_descendant {
565            // Insert an implicit descendant-or-self::node() step
566            all_steps.push(StepExpr::Axis(AxisStep {
567                axis: Axis::DescendantOrSelf,
568                node_test: NodeTest::Kind(KindTest::Node),
569                predicates: PredicateListExpr { predicates: vec![] },
570            }));
571        }
572        all_steps.push(step);
573    }
574
575    Ok((
576        input,
577        Expr::Path(PathExpr {
578            is_absolute: false,
579            is_descendant: false,
580            steps: all_steps,
581        }),
582    ))
583}
584
/// Parses one step of a path.
///
/// A filter expression (primary expression plus predicates) is tried first; if that
/// fails, the input is parsed as an axis step. `is_descendant` is forwarded to the axis
/// step parser, where it selects the default axis of an abbreviated step.
fn step_expr(is_descendant: bool, input: &str) -> IResult<&str, StepExpr> {
    alt((
        map(filter_expr, StepExpr::Filter),
        map(|i| axis_step(is_descendant, i), StepExpr::Axis),
    ))
    .parse(input)
}
592
593fn axis_step(is_descendant: bool, input: &str) -> IResult<&str, AxisStep> {
594    let (input, (step, predicates)) = pair(
595        alt((move |i| forward_step(is_descendant, i), reverse_step)),
596        predicate_list,
597    )
598    .parse(input)?;
599
600    let (axis, node_test) = step;
601    Ok((
602        input,
603        AxisStep {
604            axis,
605            node_test,
606            predicates,
607        },
608    ))
609}
610
611fn forward_step(is_descendant: bool, input: &str) -> IResult<&str, (Axis, NodeTest)> {
612    alt((pair(forward_axis, node_test), move |i| {
613        abbrev_forward_step(is_descendant, i)
614    }))
615    .parse(input)
616}
617
618fn forward_axis(input: &str) -> IResult<&str, Axis> {
619    let (input, axis) = alt((
620        value(Axis::Child, tag("child::")),
621        value(Axis::Descendant, tag("descendant::")),
622        value(Axis::Attribute, tag("attribute::")),
623        value(Axis::Self_, tag("self::")),
624        value(Axis::DescendantOrSelf, tag("descendant-or-self::")),
625        value(Axis::FollowingSibling, tag("following-sibling::")),
626        value(Axis::Following, tag("following::")),
627        value(Axis::Namespace, tag("namespace::")),
628    ))
629    .parse(input)?;
630
631    Ok((input, axis))
632}
633
634fn abbrev_forward_step(is_descendant: bool, input: &str) -> IResult<&str, (Axis, NodeTest)> {
635    let (input, attr) = opt(char('@')).parse(input)?;
636    let (input, test) = node_test(input)?;
637
638    Ok((
639        input,
640        (
641            if attr.is_some() {
642                Axis::Attribute
643            } else if is_descendant {
644                Axis::DescendantOrSelf
645            } else {
646                Axis::Child
647            },
648            test,
649        ),
650    ))
651}
652
653fn reverse_step(input: &str) -> IResult<&str, (Axis, NodeTest)> {
654    alt((
655        // ReverseAxis NodeTest
656        pair(reverse_axis, node_test),
657        // AbbrevReverseStep
658        abbrev_reverse_step,
659    ))
660    .parse(input)
661}
662
663fn reverse_axis(input: &str) -> IResult<&str, Axis> {
664    alt((
665        value(Axis::Parent, tag("parent::")),
666        value(Axis::Ancestor, tag("ancestor::")),
667        value(Axis::PrecedingSibling, tag("preceding-sibling::")),
668        value(Axis::Preceding, tag("preceding::")),
669        value(Axis::AncestorOrSelf, tag("ancestor-or-self::")),
670    ))
671    .parse(input)
672}
673
674fn abbrev_reverse_step(input: &str) -> IResult<&str, (Axis, NodeTest)> {
675    map(tag(".."), |_| {
676        (Axis::Parent, NodeTest::Kind(KindTest::Node))
677    })
678    .parse(input)
679}
680
681fn node_test(input: &str) -> IResult<&str, NodeTest> {
682    alt((
683        map(kind_test, NodeTest::Kind),
684        map(name_test, |name| match name {
685            NameTest::Wildcard => NodeTest::Wildcard,
686            NameTest::QName(qname) => NodeTest::Name(qname),
687        }),
688    ))
689    .parse(input)
690}
691
/// Intermediate result of `name_test`, converted into a [`NodeTest`] by `node_test`.
#[derive(Clone, Debug, PartialEq)]
enum NameTest {
    /// A qualified name, including the `prefix:*` form
    QName(QName),
    /// A bare `*`
    Wildcard,
}
697
698fn name_test(input: &str) -> IResult<&str, NameTest> {
699    alt((
700        // NCName ":" "*"
701        map((ncname, char(':'), char('*')), |(prefix, _, _)| {
702            NameTest::QName(QName {
703                prefix: Some(prefix.to_string()),
704                local_part: "*".to_string(),
705            })
706        }),
707        // "*"
708        value(NameTest::Wildcard, char('*')),
709        // QName
710        map(qname, NameTest::QName),
711    ))
712    .parse(input)
713}
714
715fn filter_expr(input: &str) -> IResult<&str, FilterExpr> {
716    let (input, primary) = primary_expr(input)?;
717    let (input, predicates) = predicate_list(input)?;
718
719    Ok((
720        input,
721        FilterExpr {
722            primary,
723            predicates,
724        },
725    ))
726}
727
728fn predicate_list(input: &str) -> IResult<&str, PredicateListExpr> {
729    let (input, predicates) = many0(predicate).parse(input)?;
730
731    Ok((input, PredicateListExpr { predicates }))
732}
733
734fn predicate(input: &str) -> IResult<&str, PredicateExpr> {
735    let (input, expr) = delimited(ws(char('[')), expr, ws(char(']'))).parse(input)?;
736    Ok((input, PredicateExpr { expr }))
737}
738
739fn primary_expr(input: &str) -> IResult<&str, PrimaryExpr> {
740    alt((
741        literal,
742        var_ref,
743        map(parenthesized_expr, |e| {
744            PrimaryExpr::Parenthesized(Box::new(e))
745        }),
746        context_item_expr,
747        function_call,
748    ))
749    .parse(input)
750}
751
752fn literal(input: &str) -> IResult<&str, PrimaryExpr> {
753    map(alt((numeric_literal, string_literal)), |lit| {
754        PrimaryExpr::Literal(lit)
755    })
756    .parse(input)
757}
758
/// Parses a numeric literal.
///
/// The decimal form is tried first so that the digits before a `.` are not consumed
/// as an integer on their own.
fn numeric_literal(input: &str) -> IResult<&str, Literal> {
    alt((decimal_literal, integer_literal)).parse(input)
}
762
763fn var_ref(input: &str) -> IResult<&str, PrimaryExpr> {
764    let (input, _) = char('$').parse(input)?;
765    let (input, name) = qname(input)?;
766    Ok((input, PrimaryExpr::Variable(name)))
767}
768
769fn parenthesized_expr(input: &str) -> IResult<&str, Expr> {
770    delimited(ws(char('(')), expr, ws(char(')'))).parse(input)
771}
772
773fn context_item_expr(input: &str) -> IResult<&str, PrimaryExpr> {
774    map(char('.'), |_| PrimaryExpr::ContextItem).parse(input)
775}
776
777fn function_call(input: &str) -> IResult<&str, PrimaryExpr> {
778    let (input, name) = qname(input)?;
779    let (input, args) = delimited(
780        ws(char('(')),
781        separated_list0(ws(char(',')), expr_single),
782        ws(char(')')),
783    )
784    .parse(input)?;
785
786    // Helper to create error
787    let arity_error = || nom::Err::Error(NomError::new(input, NomErrorKind::Verify));
788
789    let core_fn = match name.local_part.as_str() {
790        // Node Set Functions
791        "last" => CoreFunction::Last,
792        "position" => CoreFunction::Position,
793        "count" => CoreFunction::Count(Box::new(args.into_iter().next().ok_or_else(arity_error)?)),
794        "id" => CoreFunction::Id(Box::new(args.into_iter().next().ok_or_else(arity_error)?)),
795        "local-name" => CoreFunction::LocalName(args.into_iter().next().map(Box::new)),
796        "namespace-uri" => CoreFunction::NamespaceUri(args.into_iter().next().map(Box::new)),
797        "name" => CoreFunction::Name(args.into_iter().next().map(Box::new)),
798
799        // String Functions
800        "string" => CoreFunction::String(args.into_iter().next().map(Box::new)),
801        "concat" => CoreFunction::Concat(args.into_iter().collect()),
802        "starts-with" => {
803            let mut args = args.into_iter();
804            CoreFunction::StartsWith(
805                Box::new(args.next().ok_or_else(arity_error)?),
806                Box::new(args.next().ok_or_else(arity_error)?),
807            )
808        },
809        "contains" => {
810            let mut args = args.into_iter();
811            CoreFunction::Contains(
812                Box::new(args.next().ok_or_else(arity_error)?),
813                Box::new(args.next().ok_or_else(arity_error)?),
814            )
815        },
816        "substring-before" => {
817            let mut args = args.into_iter();
818            CoreFunction::SubstringBefore(
819                Box::new(args.next().ok_or_else(arity_error)?),
820                Box::new(args.next().ok_or_else(arity_error)?),
821            )
822        },
823        "substring-after" => {
824            let mut args = args.into_iter();
825            CoreFunction::SubstringAfter(
826                Box::new(args.next().ok_or_else(arity_error)?),
827                Box::new(args.next().ok_or_else(arity_error)?),
828            )
829        },
830        "substring" => {
831            let mut args = args.into_iter();
832            CoreFunction::Substring(
833                Box::new(args.next().ok_or_else(arity_error)?),
834                Box::new(args.next().ok_or_else(arity_error)?),
835                args.next().map(Box::new),
836            )
837        },
838        "string-length" => CoreFunction::StringLength(args.into_iter().next().map(Box::new)),
839        "normalize-space" => CoreFunction::NormalizeSpace(args.into_iter().next().map(Box::new)),
840        "translate" => {
841            let mut args = args.into_iter();
842            CoreFunction::Translate(
843                Box::new(args.next().ok_or_else(arity_error)?),
844                Box::new(args.next().ok_or_else(arity_error)?),
845                Box::new(args.next().ok_or_else(arity_error)?),
846            )
847        },
848
849        // Number Functions
850        "number" => CoreFunction::Number(args.into_iter().next().map(Box::new)),
851        "sum" => CoreFunction::Sum(Box::new(args.into_iter().next().ok_or_else(arity_error)?)),
852        "floor" => CoreFunction::Floor(Box::new(args.into_iter().next().ok_or_else(arity_error)?)),
853        "ceiling" => {
854            CoreFunction::Ceiling(Box::new(args.into_iter().next().ok_or_else(arity_error)?))
855        },
856        "round" => CoreFunction::Round(Box::new(args.into_iter().next().ok_or_else(arity_error)?)),
857
858        // Boolean Functions
859        "boolean" => {
860            CoreFunction::Boolean(Box::new(args.into_iter().next().ok_or_else(arity_error)?))
861        },
862        "not" => CoreFunction::Not(Box::new(args.into_iter().next().ok_or_else(arity_error)?)),
863        "true" => CoreFunction::True,
864        "false" => CoreFunction::False,
865        "lang" => CoreFunction::Lang(Box::new(args.into_iter().next().ok_or_else(arity_error)?)),
866
867        // Unknown function
868        _ => return Err(nom::Err::Error(NomError::new(input, NomErrorKind::Verify))),
869    };
870
871    Ok((input, PrimaryExpr::Function(core_fn)))
872}
873
/// Parses a kind test, trying `processing-instruction(...)`, `comment()`, `text()` and
/// `node()` in that order.
fn kind_test(input: &str) -> IResult<&str, KindTest> {
    alt((pi_test, comment_test, text_test, any_kind_test)).parse(input)
}
877
878fn any_kind_test(input: &str) -> IResult<&str, KindTest> {
879    map((tag("node"), ws(char('(')), ws(char(')'))), |_| {
880        KindTest::Node
881    })
882    .parse(input)
883}
884
885fn text_test(input: &str) -> IResult<&str, KindTest> {
886    map((tag("text"), ws(char('(')), ws(char(')'))), |_| {
887        KindTest::Text
888    })
889    .parse(input)
890}
891
892fn comment_test(input: &str) -> IResult<&str, KindTest> {
893    map((tag("comment"), ws(char('(')), ws(char(')'))), |_| {
894        KindTest::Comment
895    })
896    .parse(input)
897}
898
899fn pi_test(input: &str) -> IResult<&str, KindTest> {
900    map(
901        (
902            tag("processing-instruction"),
903            ws(char('(')),
904            opt(ws(string_literal)),
905            ws(char(')')),
906        ),
907        |(_, _, literal, _)| {
908            KindTest::PI(literal.map(|l| match l {
909                Literal::String(s) => s,
910                _ => unreachable!(),
911            }))
912        },
913    )
914    .parse(input)
915}
916
/// Wraps `inner` so that any whitespace before and after it is skipped.
///
/// The returned parser yields exactly what `inner` yields; whitespace is optional on
/// both sides.
fn ws<'a, F, O, E>(inner: F) -> impl Parser<&'a str, Output = O, Error = E>
where
    E: NomParseError<&'a str>,
    F: Parser<&'a str, Output = O, Error = E>,
{
    delimited(multispace0, inner, multispace0)
}
924
925fn integer_literal(input: &str) -> IResult<&str, Literal> {
926    map(recognize((opt(char('-')), digit1)), |s: &str| {
927        Literal::Numeric(NumericLiteral::Integer(s.parse().unwrap()))
928    })
929    .parse(input)
930}
931
932fn decimal_literal(input: &str) -> IResult<&str, Literal> {
933    map(
934        recognize((opt(char('-')), opt(digit1), char('.'), digit1)),
935        |s: &str| Literal::Numeric(NumericLiteral::Decimal(s.parse().unwrap())),
936    )
937    .parse(input)
938}
939
940fn string_literal(input: &str) -> IResult<&str, Literal> {
941    alt((
942        delimited(
943            char('"'),
944            map(take_while1(|c| c != '"'), |s: &str| {
945                Literal::String(s.to_string())
946            }),
947            char('"'),
948        ),
949        delimited(
950            char('\''),
951            map(take_while1(|c| c != '\''), |s: &str| {
952                Literal::String(s.to_string())
953            }),
954            char('\''),
955        ),
956    ))
957    .parse(input)
958}
959
960/// <https://www.w3.org/TR/REC-xml-names/#NT-QName>
961fn qname(input: &str) -> IResult<&str, QName> {
962    let (input, prefix) = opt((ncname, char(':'))).parse(input)?;
963    let (input, local) = ncname(input)?;
964
965    Ok((
966        input,
967        QName {
968            prefix: prefix.map(|(p, _)| p.to_string()),
969            local_part: local.to_string(),
970        },
971    ))
972}
973
/// Parses a name that contains no colon and returns the matched slice of the input.
///
/// <https://www.w3.org/TR/REC-xml-names/#NT-NCName>
fn ncname(input: &str) -> IResult<&str, &str> {
    /// Consumes a non-empty run of characters that are valid name start characters,
    /// stopping at the first character that is not one or that is a `:`.
    fn name_start_character<T, E: NomParseError<T>>(input: T) -> IResult<T, T, E>
    where
        T: Input,
        <T as Input>::Item: AsChar,
    {
        input.split_at_position1_complete(
            |character| !is_valid_start(character.as_char()) || character.as_char() == ':',
            NomErrorKind::OneOf,
        )
    }

    /// Consumes a non-empty run of characters that are valid name continuation
    /// characters, stopping at the first character that is not one or that is a `:`.
    fn name_character<T, E: NomParseError<T>>(input: T) -> IResult<T, T, E>
    where
        T: Input,
        <T as Input>::Item: AsChar,
    {
        input.split_at_position1_complete(
            |character| !is_valid_continuation(character.as_char()) || character.as_char() == ':',
            NomErrorKind::OneOf,
        )
    }

    // `recognize` discards the individual pieces and yields the whole matched span.
    recognize(pair(name_start_character, many0(name_character))).parse(input)
}
1000
1001// Test functions to verify the parsers:
1002#[cfg(test)]
1003mod tests {
1004    use super::*;
1005
1006    #[test]
1007    fn test_node_tests() {
1008        let cases = vec![
1009            ("node()", NodeTest::Kind(KindTest::Node)),
1010            ("text()", NodeTest::Kind(KindTest::Text)),
1011            ("comment()", NodeTest::Kind(KindTest::Comment)),
1012            (
1013                "processing-instruction()",
1014                NodeTest::Kind(KindTest::PI(None)),
1015            ),
1016            (
1017                "processing-instruction('test')",
1018                NodeTest::Kind(KindTest::PI(Some("test".to_string()))),
1019            ),
1020            ("*", NodeTest::Wildcard),
1021            (
1022                "prefix:*",
1023                NodeTest::Name(QName {
1024                    prefix: Some("prefix".to_string()),
1025                    local_part: "*".to_string(),
1026                }),
1027            ),
1028            (
1029                "div",
1030                NodeTest::Name(QName {
1031                    prefix: None,
1032                    local_part: "div".to_string(),
1033                }),
1034            ),
1035            (
1036                "ns:div",
1037                NodeTest::Name(QName {
1038                    prefix: Some("ns".to_string()),
1039                    local_part: "div".to_string(),
1040                }),
1041            ),
1042        ];
1043
1044        for (input, expected) in cases {
1045            match node_test(input) {
1046                Ok((remaining, result)) => {
1047                    assert!(remaining.is_empty(), "Parser didn't consume all input");
1048                    assert_eq!(result, expected);
1049                },
1050                Err(e) => panic!("Failed to parse '{}': {:?}", input, e),
1051            }
1052        }
1053    }
1054
1055    #[test]
1056    fn test_filter_expr() {
1057        let cases = vec![
1058            (
1059                "processing-instruction('test')[2]",
1060                Expr::Path(PathExpr {
1061                    is_absolute: false,
1062                    is_descendant: false,
1063                    steps: vec![StepExpr::Axis(AxisStep {
1064                        axis: Axis::Child,
1065                        node_test: NodeTest::Kind(KindTest::PI(Some("test".to_string()))),
1066                        predicates: PredicateListExpr {
1067                            predicates: vec![PredicateExpr {
1068                                expr: Expr::Path(PathExpr {
1069                                    is_absolute: false,
1070                                    is_descendant: false,
1071                                    steps: vec![StepExpr::Filter(FilterExpr {
1072                                        primary: PrimaryExpr::Literal(Literal::Numeric(
1073                                            NumericLiteral::Integer(2),
1074                                        )),
1075                                        predicates: PredicateListExpr { predicates: vec![] },
1076                                    })],
1077                                }),
1078                            }],
1079                        },
1080                    })],
1081                }),
1082            ),
1083            (
1084                "concat('hello', ' ', 'world')",
1085                Expr::Path(PathExpr {
1086                    is_absolute: false,
1087                    is_descendant: false,
1088                    steps: vec![StepExpr::Filter(FilterExpr {
1089                        primary: PrimaryExpr::Function(CoreFunction::Concat(vec![
1090                            Expr::Path(PathExpr {
1091                                is_absolute: false,
1092                                is_descendant: false,
1093                                steps: vec![StepExpr::Filter(FilterExpr {
1094                                    primary: PrimaryExpr::Literal(Literal::String(
1095                                        "hello".to_string(),
1096                                    )),
1097                                    predicates: PredicateListExpr { predicates: vec![] },
1098                                })],
1099                            }),
1100                            Expr::Path(PathExpr {
1101                                is_absolute: false,
1102                                is_descendant: false,
1103                                steps: vec![StepExpr::Filter(FilterExpr {
1104                                    primary: PrimaryExpr::Literal(Literal::String(" ".to_string())),
1105                                    predicates: PredicateListExpr { predicates: vec![] },
1106                                })],
1107                            }),
1108                            Expr::Path(PathExpr {
1109                                is_absolute: false,
1110                                is_descendant: false,
1111                                steps: vec![StepExpr::Filter(FilterExpr {
1112                                    primary: PrimaryExpr::Literal(Literal::String(
1113                                        "world".to_string(),
1114                                    )),
1115                                    predicates: PredicateListExpr { predicates: vec![] },
1116                                })],
1117                            }),
1118                        ])),
1119                        predicates: PredicateListExpr { predicates: vec![] },
1120                    })],
1121                }),
1122            ),
1123        ];
1124
1125        for (input, expected) in cases {
1126            match parse(input) {
1127                Ok(result) => {
1128                    assert_eq!(result, expected);
1129                },
1130                Err(e) => panic!("Failed to parse '{}': {:?}", input, e),
1131            }
1132        }
1133    }
1134
1135    #[test]
1136    fn test_complex_paths() {
1137        let cases = vec![
1138            (
1139                "//*[contains(@class, 'test')]",
1140                Expr::Path(PathExpr {
1141                    is_absolute: true,
1142                    is_descendant: true,
1143                    steps: vec![StepExpr::Axis(AxisStep {
1144                        axis: Axis::Child,
1145                        node_test: NodeTest::Wildcard,
1146                        predicates: PredicateListExpr {
1147                            predicates: vec![PredicateExpr {
1148                                expr: Expr::Path(PathExpr {
1149                                    is_absolute: false,
1150                                    is_descendant: false,
1151                                    steps: vec![StepExpr::Filter(FilterExpr {
1152                                        primary: PrimaryExpr::Function(CoreFunction::Contains(
1153                                            Box::new(Expr::Path(PathExpr {
1154                                                is_absolute: false,
1155                                                is_descendant: false,
1156                                                steps: vec![StepExpr::Axis(AxisStep {
1157                                                    axis: Axis::Attribute,
1158                                                    node_test: NodeTest::Name(QName {
1159                                                        prefix: None,
1160                                                        local_part: "class".to_string(),
1161                                                    }),
1162                                                    predicates: PredicateListExpr {
1163                                                        predicates: vec![],
1164                                                    },
1165                                                })],
1166                                            })),
1167                                            Box::new(Expr::Path(PathExpr {
1168                                                is_absolute: false,
1169                                                is_descendant: false,
1170                                                steps: vec![StepExpr::Filter(FilterExpr {
1171                                                    primary: PrimaryExpr::Literal(Literal::String(
1172                                                        "test".to_string(),
1173                                                    )),
1174                                                    predicates: PredicateListExpr {
1175                                                        predicates: vec![],
1176                                                    },
1177                                                })],
1178                                            })),
1179                                        )),
1180                                        predicates: PredicateListExpr { predicates: vec![] },
1181                                    })],
1182                                }),
1183                            }],
1184                        },
1185                    })],
1186                }),
1187            ),
1188            (
1189                "//div[position() > 1]/*[last()]",
1190                Expr::Path(PathExpr {
1191                    is_absolute: true,
1192                    is_descendant: true,
1193                    steps: vec![
1194                        StepExpr::Axis(AxisStep {
1195                            axis: Axis::Child,
1196                            node_test: NodeTest::Name(QName {
1197                                prefix: None,
1198                                local_part: "div".to_string(),
1199                            }),
1200                            predicates: PredicateListExpr {
1201                                predicates: vec![PredicateExpr {
1202                                    expr: Expr::Relational(
1203                                        Box::new(Expr::Path(PathExpr {
1204                                            is_absolute: false,
1205                                            is_descendant: false,
1206                                            steps: vec![StepExpr::Filter(FilterExpr {
1207                                                primary: PrimaryExpr::Function(
1208                                                    CoreFunction::Position,
1209                                                ),
1210                                                predicates: PredicateListExpr {
1211                                                    predicates: vec![],
1212                                                },
1213                                            })],
1214                                        })),
1215                                        RelationalOp::Gt,
1216                                        Box::new(Expr::Path(PathExpr {
1217                                            is_absolute: false,
1218                                            is_descendant: false,
1219                                            steps: vec![StepExpr::Filter(FilterExpr {
1220                                                primary: PrimaryExpr::Literal(Literal::Numeric(
1221                                                    NumericLiteral::Integer(1),
1222                                                )),
1223                                                predicates: PredicateListExpr {
1224                                                    predicates: vec![],
1225                                                },
1226                                            })],
1227                                        })),
1228                                    ),
1229                                }],
1230                            },
1231                        }),
1232                        StepExpr::Axis(AxisStep {
1233                            axis: Axis::Child,
1234                            node_test: NodeTest::Wildcard,
1235                            predicates: PredicateListExpr {
1236                                predicates: vec![PredicateExpr {
1237                                    expr: Expr::Path(PathExpr {
1238                                        is_absolute: false,
1239                                        is_descendant: false,
1240                                        steps: vec![StepExpr::Filter(FilterExpr {
1241                                            primary: PrimaryExpr::Function(CoreFunction::Last),
1242                                            predicates: PredicateListExpr { predicates: vec![] },
1243                                        })],
1244                                    }),
1245                                }],
1246                            },
1247                        }),
1248                    ],
1249                }),
1250            ),
1251            (
1252                "//mu[@xml:id=\"id1\"]//rho[@title][@xml:lang=\"en-GB\"]",
1253                Expr::Path(PathExpr {
1254                    is_absolute: true,
1255                    is_descendant: true,
1256                    steps: vec![
1257                        StepExpr::Axis(AxisStep {
1258                            axis: Axis::Child,
1259                            node_test: NodeTest::Name(QName {
1260                                prefix: None,
1261                                local_part: "mu".to_string(),
1262                            }),
1263                            predicates: PredicateListExpr {
1264                                predicates: vec![PredicateExpr {
1265                                    expr: Expr::Equality(
1266                                        Box::new(Expr::Path(PathExpr {
1267                                            is_absolute: false,
1268                                            is_descendant: false,
1269                                            steps: vec![StepExpr::Axis(AxisStep {
1270                                                axis: Axis::Attribute,
1271                                                node_test: NodeTest::Name(QName {
1272                                                    prefix: Some("xml".to_string()),
1273                                                    local_part: "id".to_string(),
1274                                                }),
1275                                                predicates: PredicateListExpr {
1276                                                    predicates: vec![],
1277                                                },
1278                                            })],
1279                                        })),
1280                                        EqualityOp::Eq,
1281                                        Box::new(Expr::Path(PathExpr {
1282                                            is_absolute: false,
1283                                            is_descendant: false,
1284                                            steps: vec![StepExpr::Filter(FilterExpr {
1285                                                primary: PrimaryExpr::Literal(Literal::String(
1286                                                    "id1".to_string(),
1287                                                )),
1288                                                predicates: PredicateListExpr {
1289                                                    predicates: vec![],
1290                                                },
1291                                            })],
1292                                        })),
1293                                    ),
1294                                }],
1295                            },
1296                        }),
1297                        StepExpr::Axis(AxisStep {
1298                            axis: Axis::DescendantOrSelf, // Represents the second '//'
1299                            node_test: NodeTest::Kind(KindTest::Node),
1300                            predicates: PredicateListExpr { predicates: vec![] },
1301                        }),
1302                        StepExpr::Axis(AxisStep {
1303                            axis: Axis::Child,
1304                            node_test: NodeTest::Name(QName {
1305                                prefix: None,
1306                                local_part: "rho".to_string(),
1307                            }),
1308                            predicates: PredicateListExpr {
1309                                predicates: vec![
1310                                    PredicateExpr {
1311                                        expr: Expr::Path(PathExpr {
1312                                            is_absolute: false,
1313                                            is_descendant: false,
1314                                            steps: vec![StepExpr::Axis(AxisStep {
1315                                                axis: Axis::Attribute,
1316                                                node_test: NodeTest::Name(QName {
1317                                                    prefix: None,
1318                                                    local_part: "title".to_string(),
1319                                                }),
1320                                                predicates: PredicateListExpr {
1321                                                    predicates: vec![],
1322                                                },
1323                                            })],
1324                                        }),
1325                                    },
1326                                    PredicateExpr {
1327                                        expr: Expr::Equality(
1328                                            Box::new(Expr::Path(PathExpr {
1329                                                is_absolute: false,
1330                                                is_descendant: false,
1331                                                steps: vec![StepExpr::Axis(AxisStep {
1332                                                    axis: Axis::Attribute,
1333                                                    node_test: NodeTest::Name(QName {
1334                                                        prefix: Some("xml".to_string()),
1335                                                        local_part: "lang".to_string(),
1336                                                    }),
1337                                                    predicates: PredicateListExpr {
1338                                                        predicates: vec![],
1339                                                    },
1340                                                })],
1341                                            })),
1342                                            EqualityOp::Eq,
1343                                            Box::new(Expr::Path(PathExpr {
1344                                                is_absolute: false,
1345                                                is_descendant: false,
1346                                                steps: vec![StepExpr::Filter(FilterExpr {
1347                                                    primary: PrimaryExpr::Literal(Literal::String(
1348                                                        "en-GB".to_string(),
1349                                                    )),
1350                                                    predicates: PredicateListExpr {
1351                                                        predicates: vec![],
1352                                                    },
1353                                                })],
1354                                            })),
1355                                        ),
1356                                    },
1357                                ],
1358                            },
1359                        }),
1360                    ],
1361                }),
1362            ),
1363        ];
1364
1365        for (input, expected) in cases {
1366            match parse(input) {
1367                Ok(result) => {
1368                    assert_eq!(result, expected);
1369                },
1370                Err(e) => panic!("Failed to parse '{}': {:?}", input, e),
1371            }
1372        }
1373    }
1374}