use std::cell::Cell;
use std::cmp::Ordering;
use std::collections::HashSet;
use dom_struct::dom_struct;
use html5ever::{local_name, namespace_url, ns, LocalName, QualName};
use servo_atoms::Atom;
use style::str::split_html_space_chars;
use crate::dom::bindings::codegen::Bindings::HTMLCollectionBinding::HTMLCollectionMethods;
use crate::dom::bindings::inheritance::Castable;
use crate::dom::bindings::reflector::{reflect_dom_object, Reflector};
use crate::dom::bindings::root::{Dom, DomRoot, MutNullableDom};
use crate::dom::bindings::str::DOMString;
use crate::dom::bindings::trace::JSTraceable;
use crate::dom::bindings::xmlname::namespace_from_domstring;
use crate::dom::element::Element;
use crate::dom::node::{document_from_node, Node};
use crate::dom::window::Window;
use crate::script_runtime::CanGc;
/// A predicate that decides which elements belong to an `HTMLCollection`.
///
/// Implementors must be `JSTraceable` because `HTMLCollection` keeps its filter
/// as a boxed trait object inside a DOM object.
pub trait CollectionFilter: JSTraceable {
    /// Returns `true` if `elem`, encountered while walking the tree from the
    /// collection's `root`, should be part of the collection.
    fn filter<'a>(&self, elem: &'a Element, root: &'a Node) -> bool;
}
/// A `Copy`, packed stand-in for `Option<u32>` that reserves `u32::MAX` as the
/// "none" value.
///
/// Used inside `Cell`s for the collection's cached cursor index and cached length.
#[derive(Clone, Copy, JSTraceable, MallocSizeOf)]
struct OptionU32 {
    // Raw payload; `u32::MAX` encodes "no value".
    bits: u32,
}
impl OptionU32 {
    /// Unpacks into a regular `Option`, mapping the `u32::MAX` sentinel to `None`.
    fn to_option(self) -> Option<u32> {
        match self.bits {
            u32::MAX => None,
            value => Some(value),
        }
    }

    /// Packs a present value.
    ///
    /// # Panics
    ///
    /// Panics if `bits` equals `u32::MAX`, because that value is reserved to mean "none".
    fn some(bits: u32) -> OptionU32 {
        assert_ne!(bits, u32::MAX);
        OptionU32 { bits }
    }

    /// Returns the packed "no value" state.
    fn none() -> OptionU32 {
        let sentinel = u32::MAX;
        OptionU32 { bits: sentinel }
    }
}
/// A collection of the elements reachable from `root` that satisfy `filter`.
///
/// Membership is computed lazily by walking the tree. Results such as the length
/// and the last `Item` position are cached until `root`'s
/// `inclusive_descendants_version()` changes.
#[dom_struct]
pub struct HTMLCollection {
    reflector_: Reflector,
    /// The node the collection is rooted at; element walks start from here.
    root: Dom<Node>,
    /// Decides which elements are members of the collection.
    #[ignore_malloc_size_of = "Trait object (Box<dyn CollectionFilter>) cannot be sized"]
    filter: Box<dyn CollectionFilter + 'static>,
    /// Value of `root.inclusive_descendants_version()` when the caches below were last valid.
    cached_version: Cell<u64>,
    /// Element most recently found by `Item`; used as a starting point for later lookups.
    cached_cursor_element: MutNullableDom<Element>,
    /// Collection index of `cached_cursor_element`, or "none" when there is no cursor.
    cached_cursor_index: Cell<OptionU32>,
    /// Number of elements in the collection, or "none" if not yet computed.
    cached_length: Cell<OptionU32>,
}
impl HTMLCollection {
    /// Builds an unreflected collection rooted at `root` that uses `filter`.
    ///
    /// The cursor and length caches start empty. The cache version is snapshotted
    /// from `root`, so the caches count as valid for the tree as it is right now.
    #[allow(crown::unrooted_must_root)]
    pub fn new_inherited(
        root: &Node,
        filter: Box<dyn CollectionFilter + 'static>,
    ) -> HTMLCollection {
        HTMLCollection {
            reflector_: Reflector::new(),
            root: Dom::from_ref(root),
            filter,
            cached_version: Cell::new(root.inclusive_descendants_version()),
            cached_cursor_element: MutNullableDom::new(None),
            cached_cursor_index: Cell::new(OptionU32::none()),
            cached_length: Cell::new(OptionU32::none()),
        }
    }

    /// Returns a collection rooted at `root` whose filter rejects every element.
    pub fn always_empty(window: &Window, root: &Node) -> DomRoot<Self> {
        #[derive(JSTraceable)]
        struct NoFilter;
        impl CollectionFilter for NoFilter {
            fn filter<'a>(&self, _: &'a Element, _: &'a Node) -> bool {
                false
            }
        }

        Self::new(window, root, Box::new(NoFilter))
    }

    /// Creates a collection rooted at `root` using `filter` and reflects it into `window`.
    #[allow(crown::unrooted_must_root)]
    pub fn new(
        window: &Window,
        root: &Node,
        filter: Box<dyn CollectionFilter + 'static>,
    ) -> DomRoot<Self> {
        reflect_dom_object(
            Box::new(Self::new_inherited(root, filter)),
            window,
            CanGc::note(),
        )
    }

    /// Creates a collection whose membership test is a plain function pointer.
    ///
    /// `filter_function` receives the candidate element and the collection root.
    pub(crate) fn new_with_filter_fn(
        window: &Window,
        root: &Node,
        filter_function: fn(&Element, &Node) -> bool,
    ) -> DomRoot<Self> {
        #[derive(JSTraceable, MallocSizeOf)]
        pub(crate) struct StaticFunctionFilter(
            // A bare `fn` pointer holds no GC-managed data, so there is nothing to trace or measure.
            #[no_trace]
            #[ignore_malloc_size_of = "Static function pointer"]
            fn(&Element, &Node) -> bool,
        );
        impl CollectionFilter for StaticFunctionFilter {
            fn filter(&self, element: &Element, root: &Node) -> bool {
                (self.0)(element, root)
            }
        }

        Self::new(
            window,
            root,
            Box::new(StaticFunctionFilter(filter_function)),
        )
    }

    /// Crate-internal alias for [`HTMLCollection::new`].
    pub(crate) fn create(
        window: &Window,
        root: &Node,
        filter: Box<dyn CollectionFilter + 'static>,
    ) -> DomRoot<Self> {
        Self::new(window, root, filter)
    }

    /// Clears every cached value if the tree under `root` has changed since the
    /// caches were filled, as reported by `inclusive_descendants_version()`.
    fn validate_cache(&self) {
        let cached_version = self.cached_version.get();
        let curr_version = self.root.inclusive_descendants_version();
        if curr_version != cached_version {
            self.cached_version.set(curr_version);
            self.cached_cursor_element.set(None);
            self.cached_length.set(OptionU32::none());
            self.cached_cursor_index.set(OptionU32::none());
        }
    }

    /// Remembers `element` as the cursor at position `index` and hands it back.
    ///
    /// When `element` is `None`, the existing cursor is left untouched and
    /// `None` is returned.
    fn set_cached_cursor(
        &self,
        index: u32,
        element: Option<DomRoot<Element>>,
    ) -> Option<DomRoot<Element>> {
        if let Some(element) = element {
            self.cached_cursor_index.set(OptionU32::some(index));
            self.cached_cursor_element.set(Some(&element));
            Some(element)
        } else {
            None
        }
    }

    /// Collection of the elements whose qualified name matches `qualified_name`.
    ///
    /// `*` matches every element. When the root is in an HTML document, HTML-namespace
    /// elements are compared against the ASCII-lowercased name and all other elements
    /// against the name as given.
    /// <https://dom.spec.whatwg.org/#concept-getelementsbytagname>
    pub fn by_qualified_name(
        window: &Window,
        root: &Node,
        qualified_name: LocalName,
    ) -> DomRoot<HTMLCollection> {
        if qualified_name == local_name!("*") {
            #[derive(JSTraceable, MallocSizeOf)]
            struct AllFilter;
            impl CollectionFilter for AllFilter {
                fn filter(&self, _elem: &Element, _root: &Node) -> bool {
                    true
                }
            }
            return HTMLCollection::create(window, root, Box::new(AllFilter));
        }

        #[derive(JSTraceable, MallocSizeOf)]
        struct HtmlDocumentFilter {
            #[no_trace]
            qualified_name: LocalName,
            #[no_trace]
            ascii_lower_qualified_name: LocalName,
        }
        impl CollectionFilter for HtmlDocumentFilter {
            fn filter(&self, elem: &Element, root: &Node) -> bool {
                if root.is_in_html_doc() && elem.namespace() == &ns!(html) {
                    HTMLCollection::match_element(elem, &self.ascii_lower_qualified_name)
                } else {
                    HTMLCollection::match_element(elem, &self.qualified_name)
                }
            }
        }

        let filter = HtmlDocumentFilter {
            ascii_lower_qualified_name: qualified_name.to_ascii_lowercase(),
            qualified_name,
        };
        HTMLCollection::create(window, root, Box::new(filter))
    }

    /// Returns whether `elem`'s qualified name is exactly `qualified_name`.
    ///
    /// An element's qualified name is its local name when it has no prefix, and
    /// `prefix ":" localName` otherwise.
    fn match_element(elem: &Element, qualified_name: &LocalName) -> bool {
        match elem.prefix().as_ref() {
            None => elem.local_name() == qualified_name,
            // Require an exact `prefix:local` match. The previous check only required
            // the name to *end with* the local name, so e.g. "svg:foog" matched `<svg:g>`.
            Some(prefix) => qualified_name
                .strip_prefix(&**prefix)
                .and_then(|rest| rest.strip_prefix(':'))
                .is_some_and(|local| local == &**elem.local_name()),
        }
    }

    /// Collection of the elements matching a namespace and local name given as strings.
    ///
    /// The namespace string is converted with `namespace_from_domstring`. Matching,
    /// including `*` wildcards, is done by [`HTMLCollection::by_qual_tag_name`].
    pub fn by_tag_name_ns(
        window: &Window,
        root: &Node,
        tag: DOMString,
        maybe_ns: Option<DOMString>,
    ) -> DomRoot<HTMLCollection> {
        let local = LocalName::from(tag);
        let ns = namespace_from_domstring(maybe_ns);
        let qname = QualName::new(None, ns, local);
        HTMLCollection::by_qual_tag_name(window, root, qname)
    }

    /// Collection of the elements matching `qname`'s namespace and local name.
    ///
    /// A `*` namespace or a `*` local name matches anything in that position.
    /// The prefix of `qname` is ignored.
    pub fn by_qual_tag_name(
        window: &Window,
        root: &Node,
        qname: QualName,
    ) -> DomRoot<HTMLCollection> {
        #[derive(JSTraceable, MallocSizeOf)]
        struct TagNameNSFilter {
            #[no_trace]
            qname: QualName,
        }
        impl CollectionFilter for TagNameNSFilter {
            fn filter(&self, elem: &Element, _root: &Node) -> bool {
                ((self.qname.ns == namespace_url!("*")) || (self.qname.ns == *elem.namespace())) &&
                    ((self.qname.local == local_name!("*")) ||
                        (self.qname.local == *elem.local_name()))
            }
        }

        let filter = TagNameNSFilter { qname };
        HTMLCollection::create(window, root, Box::new(filter))
    }

    /// Collection of the elements carrying every class in the whitespace-separated `classes`.
    pub fn by_class_name(
        window: &Window,
        root: &Node,
        classes: DOMString,
    ) -> DomRoot<HTMLCollection> {
        let class_atoms = split_html_space_chars(&classes).map(Atom::from).collect();
        HTMLCollection::by_atomic_class_name(window, root, class_atoms)
    }

    /// Collection of the elements carrying every class in `classes`.
    ///
    /// Class comparison follows the case sensitivity implied by the element's
    /// document quirks mode. An empty `classes` list yields an always-empty collection.
    pub fn by_atomic_class_name(
        window: &Window,
        root: &Node,
        classes: Vec<Atom>,
    ) -> DomRoot<HTMLCollection> {
        #[derive(JSTraceable, MallocSizeOf)]
        struct ClassNameFilter {
            #[no_trace]
            classes: Vec<Atom>,
        }
        impl CollectionFilter for ClassNameFilter {
            fn filter(&self, elem: &Element, _root: &Node) -> bool {
                let case_sensitivity = document_from_node(elem)
                    .quirks_mode()
                    .classes_and_ids_case_sensitivity();

                self.classes
                    .iter()
                    .all(|class| elem.has_class(class, case_sensitivity))
            }
        }

        if classes.is_empty() {
            return HTMLCollection::always_empty(window, root);
        }

        let filter = ClassNameFilter { classes };
        HTMLCollection::create(window, root, Box::new(filter))
    }

    /// Collection of the elements whose parent is `root`.
    pub fn children(window: &Window, root: &Node) -> DomRoot<HTMLCollection> {
        HTMLCollection::new_with_filter_fn(window, root, |element, root| {
            root.is_parent_of(element.upcast())
        })
    }

    /// Iterates over the collection's elements that come after `after`.
    ///
    /// The walk uses `after.following_nodes(root)`, keeps only elements, and applies
    /// the collection filter. `Item` relies on `after` itself not being yielded.
    pub fn elements_iter_after<'a>(
        &'a self,
        after: &'a Node,
    ) -> impl Iterator<Item = DomRoot<Element>> + 'a {
        after
            .following_nodes(&self.root)
            .filter_map(DomRoot::downcast)
            .filter(move |element| self.filter.filter(element, &self.root))
    }

    /// Iterates over all of the collection's elements, starting the walk at `root`.
    pub fn elements_iter(&self) -> impl Iterator<Item = DomRoot<Element>> + '_ {
        self.elements_iter_after(&self.root)
    }

    /// Iterates over the collection's elements that come before `before`.
    ///
    /// The walk uses `before.preceding_nodes(root)`. It presumably yields elements
    /// nearest-first (reverse tree order), which `Item`'s backward cursor walk
    /// depends on. TODO confirm against `Node::preceding_nodes`.
    pub fn elements_iter_before<'a>(
        &'a self,
        before: &'a Node,
    ) -> impl Iterator<Item = DomRoot<Element>> + 'a {
        before
            .preceding_nodes(&self.root)
            .filter_map(DomRoot::downcast)
            .filter(move |element| self.filter.filter(element, &self.root))
    }

    /// Returns a rooted reference to the node this collection is rooted at.
    pub fn root_node(&self) -> DomRoot<Node> {
        DomRoot::from_ref(&self.root)
    }
}
impl HTMLCollectionMethods<crate::DomTypeHolder> for HTMLCollection {
fn Length(&self) -> u32 {
self.validate_cache();
if let Some(cached_length) = self.cached_length.get().to_option() {
cached_length
} else {
let length = self.elements_iter().count() as u32;
self.cached_length.set(OptionU32::some(length));
length
}
}
fn Item(&self, index: u32) -> Option<DomRoot<Element>> {
self.validate_cache();
if let Some(element) = self.cached_cursor_element.get() {
if let Some(cached_index) = self.cached_cursor_index.get().to_option() {
match cached_index.cmp(&index) {
Ordering::Equal => {
Some(element)
},
Ordering::Less => {
let offset = index - (cached_index + 1);
let node: DomRoot<Node> = DomRoot::upcast(element);
let mut iter = self.elements_iter_after(&node);
self.set_cached_cursor(index, iter.nth(offset as usize))
},
Ordering::Greater => {
let offset = cached_index - (index + 1);
let node: DomRoot<Node> = DomRoot::upcast(element);
let mut iter = self.elements_iter_before(&node);
self.set_cached_cursor(index, iter.nth(offset as usize))
},
}
} else {
self.set_cached_cursor(index, self.elements_iter().nth(index as usize))
}
} else {
self.set_cached_cursor(index, self.elements_iter().nth(index as usize))
}
}
fn NamedItem(&self, key: DOMString) -> Option<DomRoot<Element>> {
if key.is_empty() {
return None;
}
let key = Atom::from(key);
self.elements_iter().find(|elem| {
elem.get_id().is_some_and(|id| id == key) ||
(elem.namespace() == &ns!(html) && elem.get_name().is_some_and(|id| id == key))
})
}
fn IndexedGetter(&self, index: u32) -> Option<DomRoot<Element>> {
self.Item(index)
}
fn NamedGetter(&self, name: DOMString) -> Option<DomRoot<Element>> {
self.NamedItem(name)
}
fn SupportedPropertyNames(&self) -> Vec<DOMString> {
let mut result = vec![];
for elem in self.elements_iter() {
if let Some(id_atom) = elem.get_id() {
let id_str = DOMString::from(&*id_atom);
if !result.contains(&id_str) {
result.push(id_str);
}
}
if *elem.namespace() == ns!(html) {
if let Some(name_atom) = elem.get_name() {
let name_str = DOMString::from(&*name_atom);
if !result.contains(&name_str) {
result.push(name_str)
}
}
}
}
result
}
}