script/xpath/
eval_function.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
4
5use style::Atom;
6
7use super::Value;
8use super::context::EvaluationCtx;
9use super::eval::{Error, Evaluatable, try_extract_nodeset};
10use super::parser::CoreFunction;
11use crate::dom::bindings::codegen::Bindings::NodeBinding::NodeMethods;
12use crate::dom::bindings::inheritance::{Castable, NodeTypeId};
13use crate::dom::bindings::root::DomRoot;
14use crate::dom::element::Element;
15use crate::dom::node::Node;
16
17/// Returns e.g. "rect" for `<svg:rect>`
18fn local_name(node: &Node) -> Option<String> {
19    if matches!(Node::type_id(node), NodeTypeId::Element(_)) {
20        let element = node.downcast::<Element>().unwrap();
21        Some(element.local_name().to_string())
22    } else {
23        None
24    }
25}
26
27/// Returns e.g. "svg:rect" for `<svg:rect>`
28fn name(node: &Node) -> Option<String> {
29    if matches!(Node::type_id(node), NodeTypeId::Element(_)) {
30        let element = node.downcast::<Element>().unwrap();
31        if let Some(prefix) = element.prefix().as_ref() {
32            Some(format!("{}:{}", prefix, element.local_name()))
33        } else {
34            Some(element.local_name().to_string())
35        }
36    } else {
37        None
38    }
39}
40
41/// Returns e.g. the SVG namespace URI for `<svg:rect>`
42fn namespace_uri(node: &Node) -> Option<String> {
43    if matches!(Node::type_id(node), NodeTypeId::Element(_)) {
44        let element = node.downcast::<Element>().unwrap();
45        Some(element.namespace().to_string())
46    } else {
47        None
48    }
49}
50
51/// Returns the text contents of the Node, or empty string if none.
52fn string_value(node: &Node) -> String {
53    node.GetTextContent().unwrap_or_default().to_string()
54}
55
/// <https://www.w3.org/TR/1999/REC-xpath-19991116/#function-substring-before>
///
/// Returns the part of `s1` that precedes the first occurrence of `s2`, or the empty
/// string if `s1` does not contain `s2`. An empty `s2` matches at offset 0, so the
/// result is then also empty.
fn substring_before(s1: &str, s2: &str) -> String {
    s1.find(s2)
        .map(|pos| s1[..pos].to_string())
        .unwrap_or_default()
}
63
/// <https://www.w3.org/TR/1999/REC-xpath-19991116/#function-substring-after>
///
/// Returns the part of `s1` that follows the first occurrence of `s2`, or the empty
/// string if `s1` does not contain `s2`. An empty `s2` matches at offset 0, so the
/// result is then all of `s1`.
fn substring_after(s1: &str, s2: &str) -> String {
    s1.find(s2)
        .map(|pos| s1[pos + s2.len()..].to_string())
        .unwrap_or_default()
}
71
/// Returns up to `len` characters of `s`, starting at the zero-based character index
/// `start_idx`.
///
/// Indices and lengths count `char`s rather than bytes, matching how `string-length()`
/// measures strings in this module, so non-ASCII text is never sliced through the middle
/// of a character.
///
/// * A negative `start_idx` is treated as 0 (before `len` is applied).
/// * A negative `len` selects nothing; `None` selects everything up to the end of `s`.
/// * A `start_idx` at or beyond the end of `s` yields an empty string instead of panicking.
fn substring(s: &str, start_idx: isize, len: Option<isize>) -> String {
    let skipped = s.chars().skip(usize::try_from(start_idx).unwrap_or(0));
    match len {
        Some(len) => skipped.take(usize::try_from(len).unwrap_or(0)).collect(),
        None => skipped.collect(),
    }
}
79
/// <https://www.w3.org/TR/1999/REC-xpath-19991116/#function-normalize-space>
///
/// Removes leading and trailing XPath whitespace (space, tab, carriage return, line feed)
/// and collapses every interior run of it into a single space. Any other character,
/// including other Unicode whitespace, is kept as-is.
pub(crate) fn normalize_space(s: &str) -> String {
    let is_xpath_whitespace = |c: char| matches!(c, '\x20' | '\x09' | '\x0D' | '\x0A');
    s.split(is_xpath_whitespace)
        .filter(|word| !word.is_empty())
        .collect::<Vec<_>>()
        .join(" ")
}
106
/// <https://www.w3.org/TR/1999/REC-xpath-19991116/#function-lang>
///
/// Returns true if `context_lang` equals `target_lang` or is a sublanguage of it (it
/// continues with a `-` immediately after `target_lang`), comparing ASCII letters
/// case-insensitively. Returns false when there is no context language (`None`).
fn lang_matches(context_lang: Option<&str>, target_lang: &str) -> bool {
    let Some(context_lang) = context_lang else {
        return false;
    };

    // Compare the leading `target_lang.len()` bytes. `get` returns `None` when the context
    // language is shorter, or when the cut would split one of its characters; in neither
    // case can it match `target_lang`.
    let Some(prefix) = context_lang.get(..target_lang.len()) else {
        return false;
    };
    if !prefix.eq_ignore_ascii_case(target_lang) {
        return false;
    }

    // Either an exact match, or the remainder must start with a hyphen so that e.g.
    // "england" does not match "en". Indexing by byte length is correct here because the
    // prefix ends on a character boundary.
    matches!(
        context_lang.as_bytes().get(target_lang.len()),
        None | Some(b'-')
    )
}
131
132impl Evaluatable for CoreFunction {
133    fn evaluate(&self, context: &EvaluationCtx) -> Result<Value, Error> {
134        match self {
135            CoreFunction::Last => {
136                let predicate_ctx = context.predicate_ctx.ok_or_else(|| Error::Internal {
137                    msg: "[CoreFunction] last() is only usable as a predicate".to_string(),
138                })?;
139                Ok(Value::Number(predicate_ctx.size as f64))
140            },
141            CoreFunction::Position => {
142                let predicate_ctx = context.predicate_ctx.ok_or_else(|| Error::Internal {
143                    msg: "[CoreFunction] position() is only usable as a predicate".to_string(),
144                })?;
145                Ok(Value::Number(predicate_ctx.index as f64))
146            },
147            CoreFunction::Count(expr) => {
148                let nodes = expr.evaluate(context).and_then(try_extract_nodeset)?;
149                Ok(Value::Number(nodes.len() as f64))
150            },
151            CoreFunction::String(expr_opt) => match expr_opt {
152                Some(expr) => Ok(Value::String(expr.evaluate(context)?.string())),
153                None => Ok(Value::String(string_value(&context.context_node))),
154            },
155            CoreFunction::Concat(exprs) => {
156                let strings: Result<Vec<_>, _> = exprs
157                    .iter()
158                    .map(|e| Ok(e.evaluate(context)?.string()))
159                    .collect();
160                Ok(Value::String(strings?.join("")))
161            },
162            CoreFunction::Id(expr) => {
163                let args_str = expr.evaluate(context)?.string();
164                let args_normalized = normalize_space(&args_str);
165                let args = args_normalized.split(' ');
166
167                let document = context.context_node.owner_doc();
168                let mut result = Vec::new();
169                for arg in args {
170                    for element in document.get_elements_with_id(&Atom::from(arg)).iter() {
171                        result.push(DomRoot::from_ref(element.upcast::<Node>()));
172                    }
173                }
174                Ok(Value::Nodeset(result))
175            },
176            CoreFunction::LocalName(expr_opt) => {
177                let node = match expr_opt {
178                    Some(expr) => expr
179                        .evaluate(context)
180                        .and_then(try_extract_nodeset)?
181                        .first()
182                        .cloned(),
183                    None => Some(context.context_node.clone()),
184                };
185                let name = node.and_then(|n| local_name(&n)).unwrap_or_default();
186                Ok(Value::String(name.to_string()))
187            },
188            CoreFunction::NamespaceUri(expr_opt) => {
189                let node = match expr_opt {
190                    Some(expr) => expr
191                        .evaluate(context)
192                        .and_then(try_extract_nodeset)?
193                        .first()
194                        .cloned(),
195                    None => Some(context.context_node.clone()),
196                };
197                let ns = node.and_then(|n| namespace_uri(&n)).unwrap_or_default();
198                Ok(Value::String(ns.to_string()))
199            },
200            CoreFunction::Name(expr_opt) => {
201                let node = match expr_opt {
202                    Some(expr) => expr
203                        .evaluate(context)
204                        .and_then(try_extract_nodeset)?
205                        .first()
206                        .cloned(),
207                    None => Some(context.context_node.clone()),
208                };
209                let name = node.and_then(|n| name(&n)).unwrap_or_default();
210                Ok(Value::String(name))
211            },
212            CoreFunction::StartsWith(str1, str2) => {
213                let s1 = str1.evaluate(context)?.string();
214                let s2 = str2.evaluate(context)?.string();
215                Ok(Value::Boolean(s1.starts_with(&s2)))
216            },
217            CoreFunction::Contains(str1, str2) => {
218                let s1 = str1.evaluate(context)?.string();
219                let s2 = str2.evaluate(context)?.string();
220                Ok(Value::Boolean(s1.contains(&s2)))
221            },
222            CoreFunction::SubstringBefore(str1, str2) => {
223                let s1 = str1.evaluate(context)?.string();
224                let s2 = str2.evaluate(context)?.string();
225                Ok(Value::String(substring_before(&s1, &s2)))
226            },
227            CoreFunction::SubstringAfter(str1, str2) => {
228                let s1 = str1.evaluate(context)?.string();
229                let s2 = str2.evaluate(context)?.string();
230                Ok(Value::String(substring_after(&s1, &s2)))
231            },
232            CoreFunction::Substring(str1, start, length_opt) => {
233                let s = str1.evaluate(context)?.string();
234                let start_idx = start.evaluate(context)?.number().round() as isize - 1;
235                let len = match length_opt {
236                    Some(len_expr) => Some(len_expr.evaluate(context)?.number().round() as isize),
237                    None => None,
238                };
239                Ok(Value::String(substring(&s, start_idx, len)))
240            },
241            CoreFunction::StringLength(expr_opt) => {
242                let s = match expr_opt {
243                    Some(expr) => expr.evaluate(context)?.string(),
244                    None => string_value(&context.context_node),
245                };
246                Ok(Value::Number(s.chars().count() as f64))
247            },
248            CoreFunction::NormalizeSpace(expr_opt) => {
249                let s = match expr_opt {
250                    Some(expr) => expr.evaluate(context)?.string(),
251                    None => string_value(&context.context_node),
252                };
253
254                Ok(Value::String(normalize_space(&s)))
255            },
256            CoreFunction::Translate(str1, str2, str3) => {
257                let s = str1.evaluate(context)?.string();
258                let from = str2.evaluate(context)?.string();
259                let to = str3.evaluate(context)?.string();
260                let result = s
261                    .chars()
262                    .map(|c| match from.find(c) {
263                        Some(i) if i < to.chars().count() => to.chars().nth(i).unwrap(),
264                        _ => c,
265                    })
266                    .collect();
267                Ok(Value::String(result))
268            },
269            CoreFunction::Number(expr_opt) => {
270                let val = match expr_opt {
271                    Some(expr) => expr.evaluate(context)?,
272                    None => Value::String(string_value(&context.context_node)),
273                };
274                Ok(Value::Number(val.number()))
275            },
276            CoreFunction::Sum(expr) => {
277                let nodes = expr.evaluate(context).and_then(try_extract_nodeset)?;
278                let sum = nodes
279                    .iter()
280                    .map(|n| Value::String(string_value(n)).number())
281                    .sum();
282                Ok(Value::Number(sum))
283            },
284            CoreFunction::Floor(expr) => {
285                let num = expr.evaluate(context)?.number();
286                Ok(Value::Number(num.floor()))
287            },
288            CoreFunction::Ceiling(expr) => {
289                let num = expr.evaluate(context)?.number();
290                Ok(Value::Number(num.ceil()))
291            },
292            CoreFunction::Round(expr) => {
293                let num = expr.evaluate(context)?.number();
294                Ok(Value::Number(num.round()))
295            },
296            CoreFunction::Boolean(expr) => Ok(Value::Boolean(expr.evaluate(context)?.boolean())),
297            CoreFunction::Not(expr) => Ok(Value::Boolean(!expr.evaluate(context)?.boolean())),
298            CoreFunction::True => Ok(Value::Boolean(true)),
299            CoreFunction::False => Ok(Value::Boolean(false)),
300            CoreFunction::Lang(expr) => {
301                let context_lang = context.context_node.get_lang();
302                let lang = expr.evaluate(context)?.string();
303                Ok(Value::Boolean(lang_matches(context_lang.as_deref(), &lang)))
304            },
305        }
306    }
307
308    fn is_primitive(&self) -> bool {
309        match self {
310            CoreFunction::Last => false,
311            CoreFunction::Position => false,
312            CoreFunction::Count(_) => false,
313            CoreFunction::Id(_) => false,
314            CoreFunction::LocalName(_) => false,
315            CoreFunction::NamespaceUri(_) => false,
316            CoreFunction::Name(_) => false,
317            CoreFunction::String(expr_opt) => expr_opt
318                .as_ref()
319                .map(|expr| expr.is_primitive())
320                .unwrap_or(false),
321            CoreFunction::Concat(vec) => vec.iter().all(|expr| expr.is_primitive()),
322            CoreFunction::StartsWith(expr, substr) => expr.is_primitive() && substr.is_primitive(),
323            CoreFunction::Contains(expr, substr) => expr.is_primitive() && substr.is_primitive(),
324            CoreFunction::SubstringBefore(expr, substr) => {
325                expr.is_primitive() && substr.is_primitive()
326            },
327            CoreFunction::SubstringAfter(expr, substr) => {
328                expr.is_primitive() && substr.is_primitive()
329            },
330            CoreFunction::Substring(expr, start_pos, length_opt) => {
331                expr.is_primitive() &&
332                    start_pos.is_primitive() &&
333                    length_opt
334                        .as_ref()
335                        .map(|length| length.is_primitive())
336                        .unwrap_or(false)
337            },
338            CoreFunction::StringLength(expr_opt) => expr_opt
339                .as_ref()
340                .map(|expr| expr.is_primitive())
341                .unwrap_or(false),
342            CoreFunction::NormalizeSpace(expr_opt) => expr_opt
343                .as_ref()
344                .map(|expr| expr.is_primitive())
345                .unwrap_or(false),
346            CoreFunction::Translate(expr, from_chars, to_chars) => {
347                expr.is_primitive() && from_chars.is_primitive() && to_chars.is_primitive()
348            },
349            CoreFunction::Number(expr_opt) => expr_opt
350                .as_ref()
351                .map(|expr| expr.is_primitive())
352                .unwrap_or(false),
353            CoreFunction::Sum(expr) => expr.is_primitive(),
354            CoreFunction::Floor(expr) => expr.is_primitive(),
355            CoreFunction::Ceiling(expr) => expr.is_primitive(),
356            CoreFunction::Round(expr) => expr.is_primitive(),
357            CoreFunction::Boolean(expr) => expr.is_primitive(),
358            CoreFunction::Not(expr) => expr.is_primitive(),
359            CoreFunction::True => true,
360            CoreFunction::False => true,
361            CoreFunction::Lang(_) => false,
362        }
363    }
364}
#[cfg(test)]
mod tests {
    use super::{lang_matches, substring, substring_after, substring_before};

    /// substring-before(): the text preceding the first match, or "" without a match.
    #[test]
    fn test_substring_before() {
        let cases = [
            ("hello world", "world", "hello "),
            ("prefix:name", ":", "prefix"),
            ("no-separator", "xyz", ""),
            ("", "anything", ""),
            ("multiple:colons:here", ":", "multiple"),
            ("start-match-test", "start", ""),
        ];
        for (haystack, needle, expected) in cases {
            assert_eq!(substring_before(haystack, needle), expected);
        }
    }

    /// substring-after(): the text following the first match, or "" without a match.
    #[test]
    fn test_substring_after() {
        let cases = [
            ("hello world", "hello ", "world"),
            ("prefix:name", ":", "name"),
            ("no-separator", "xyz", ""),
            ("", "anything", ""),
            ("multiple:colons:here", ":", "colons:here"),
            ("test-end-match", "match", ""),
        ];
        for (haystack, needle, expected) in cases {
            assert_eq!(substring_after(haystack, needle), expected);
        }
    }

    /// The zero-based `substring` helper, including clamped and out-of-range arguments.
    #[test]
    fn test_substring() {
        let cases = [
            ("hello world", 0, Some(5), "hello"),
            ("hello world", 6, Some(5), "world"),
            ("hello", 1, Some(3), "ell"),
            ("hello", -5, Some(2), "he"),
            ("hello", 0, None, "hello"),
            ("hello", 2, Some(10), "llo"),
            ("hello", 5, Some(1), ""),
            ("", 0, Some(5), ""),
            ("hello", 0, Some(0), ""),
            ("hello", 0, Some(-5), ""),
        ];
        for (text, start_idx, len, expected) in cases {
            assert_eq!(substring(text, start_idx, len), expected);
        }
    }

    /// lang(): exact or hyphenated-sublanguage matches, ignoring ASCII case.
    #[test]
    fn test_lang_matches() {
        let cases = [
            (Some("en"), "en", true),
            (Some("EN"), "en", true),
            (Some("en"), "EN", true),
            (Some("en-US"), "en", true),
            (Some("en-GB"), "en", true),
            (Some("eng"), "en", false),
            (Some("fr"), "en", false),
            (Some("fr-en"), "en", false),
            (None, "en", false),
        ];
        for (context_lang, target_lang, expected) in cases {
            assert_eq!(lang_matches(context_lang, target_lang), expected);
        }
    }
}