icu_segmenter/complex/
dictionary.rs

1// This file is part of ICU4X. For terms of use, please see the file
2// called LICENSE at the top level of the ICU4X source tree
3// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).
4
5use crate::grapheme::*;
6use crate::indices::Utf16Indices;
7use crate::provider::*;
8use core::str::CharIndices;
9use icu_collections::char16trie::{Char16Trie, TrieResult};
10
/// Abstraction over the text encoding walked by the dictionary based iterator.
///
/// An implementor chooses the character representation and the iterator that
/// yields `(offset, character)` pairs, and provides the two conversions the
/// break iterator needs for every character it visits.
trait DictionaryType<'l, 's> {
    /// The character type produced by [`Self::IterAttr`].
    type CharType: Copy + Into<u32>;

    /// The iterator over `(offset, character)` pairs of the input.
    type IterAttr: Iterator<Item = (usize, Self::CharType)> + Clone;

    /// Converts a character of the input into the `char` handed to the trie.
    fn to_char(c: Self::CharType) -> char;

    /// Returns how many index units `c` occupies; the break iterator adds this
    /// to the character's offset to obtain the offset just past it.
    fn char_len(c: Self::CharType) -> usize;
}
22
23struct DictionaryBreakIterator<
24    'l,
25    's,
26    Y: DictionaryType<'l, 's> + ?Sized,
27    X: Iterator<Item = usize> + ?Sized,
28> {
29    trie: Char16Trie<'l>,
30    iter: Y::IterAttr,
31    len: usize,
32    grapheme_iter: X,
33    // TODO transform value for byte trie
34}
35
36/// Implement the [`Iterator`] trait over the segmenter break opportunities of the given string.
37/// Please see the [module-level documentation](crate) for its usages.
38///
39/// Lifetimes:
40/// - `'l` = lifetime of the segmenter object from which this iterator was created
41/// - `'s` = lifetime of the string being segmented
42///
43/// [`Iterator`]: core::iter::Iterator
44impl<'l, 's, Y: DictionaryType<'l, 's> + ?Sized, X: Iterator<Item = usize> + ?Sized> Iterator
45    for DictionaryBreakIterator<'l, 's, Y, X>
46{
47    type Item = usize;
48
49    fn next(&mut self) -> Option<Self::Item> {
50        let mut trie_iter = self.trie.iter();
51        let mut intermediate_length = 0;
52        let mut not_match = false;
53        let mut previous_match = None;
54        let mut last_grapheme_offset = 0;
55
56        while let Some(next) = self.iter.next() {
57            let ch = Y::to_char(next.1);
58            match trie_iter.next(ch) {
59                TrieResult::FinalValue(_) => {
60                    return Some(next.0 + Y::char_len(next.1));
61                }
62                TrieResult::Intermediate(_) => {
63                    // Dictionary has to match with grapheme cluster segment.
64                    // If not, we ignore it.
65                    while last_grapheme_offset < next.0 + Y::char_len(next.1) {
66                        if let Some(offset) = self.grapheme_iter.next() {
67                            last_grapheme_offset = offset;
68                            continue;
69                        }
70                        last_grapheme_offset = self.len;
71                        break;
72                    }
73                    if last_grapheme_offset != next.0 + Y::char_len(next.1) {
74                        continue;
75                    }
76
77                    intermediate_length = next.0 + Y::char_len(next.1);
78                    previous_match = Some(self.iter.clone());
79                }
80                TrieResult::NoMatch => {
81                    if intermediate_length > 0 {
82                        if let Some(previous_match) = previous_match {
83                            // Rewind previous match point
84                            self.iter = previous_match;
85                        }
86                        return Some(intermediate_length);
87                    }
88                    // Not found
89                    return Some(next.0 + Y::char_len(next.1));
90                }
91                TrieResult::NoValue => {
92                    // Prefix string is matched
93                    not_match = true;
94                }
95            }
96        }
97
98        if intermediate_length > 0 {
99            Some(intermediate_length)
100        } else if not_match {
101            // no match by scanning text
102            Some(self.len)
103        } else {
104            None
105        }
106    }
107}
108
109impl<'l, 's> DictionaryType<'l, 's> for u32 {
110    type IterAttr = Utf16Indices<'s>;
111    type CharType = u32;
112
113    fn to_char(c: u32) -> char {
114        char::from_u32(c).unwrap_or(char::REPLACEMENT_CHARACTER)
115    }
116
117    fn char_len(c: u32) -> usize {
118        if c >= 0x10000 {
119            2
120        } else {
121            1
122        }
123    }
124}
125
126impl<'l, 's> DictionaryType<'l, 's> for char {
127    type IterAttr = CharIndices<'s>;
128    type CharType = char;
129
130    fn to_char(c: char) -> char {
131        c
132    }
133
134    fn char_len(c: char) -> usize {
135        c.len_utf8()
136    }
137}
138
139pub(super) struct DictionarySegmenter<'l> {
140    dict: &'l UCharDictionaryBreakDataV1<'l>,
141    grapheme: &'l RuleBreakDataV1<'l>,
142}
143
144impl<'l> DictionarySegmenter<'l> {
145    pub(super) fn new(
146        dict: &'l UCharDictionaryBreakDataV1<'l>,
147        grapheme: &'l RuleBreakDataV1<'l>,
148    ) -> Self {
149        // TODO: no way to verify trie data
150        Self { dict, grapheme }
151    }
152
153    /// Create a dictionary based break iterator for an `str` (a UTF-8 string).
154    pub(super) fn segment_str(&'l self, input: &'l str) -> impl Iterator<Item = usize> + 'l {
155        let grapheme_iter = GraphemeClusterSegmenter::new_and_segment_str(input, self.grapheme);
156        DictionaryBreakIterator::<char, GraphemeClusterBreakIteratorUtf8> {
157            trie: Char16Trie::new(self.dict.trie_data.clone()),
158            iter: input.char_indices(),
159            len: input.len(),
160            grapheme_iter,
161        }
162    }
163
164    /// Create a dictionary based break iterator for a UTF-16 string.
165    pub(super) fn segment_utf16(&'l self, input: &'l [u16]) -> impl Iterator<Item = usize> + 'l {
166        let grapheme_iter = GraphemeClusterSegmenter::new_and_segment_utf16(input, self.grapheme);
167        DictionaryBreakIterator::<u32, GraphemeClusterBreakIteratorUtf16> {
168            trie: Char16Trie::new(self.dict.trie_data.clone()),
169            iter: Utf16Indices::new(input),
170            len: input.len(),
171            grapheme_iter,
172        }
173    }
174}
175
#[cfg(test)]
#[cfg(feature = "serde")]
mod tests {
    use super::*;
    use crate::{LineSegmenter, WordSegmenter};
    use icu_provider::prelude::*;

    /// Burmese text through the dictionary line segmenter, in UTF-8 and UTF-16.
    #[test]
    fn burmese_dictionary_test() {
        // From css/css-text/word-break/word-break-normal-my-000.html
        let text = "မြန်မာစာမြန်မာစာမြန်မာစာ";
        let line_segmenter = LineSegmenter::new_dictionary();

        let utf8_breaks = line_segmenter.segment_str(text).collect::<Vec<usize>>();
        assert_eq!(utf8_breaks, [0, 18, 24, 42, 48, 66, 72]);

        let units = text.encode_utf16().collect::<Vec<u16>>();
        let utf16_breaks = line_segmenter
            .segment_utf16(&units)
            .collect::<Vec<usize>>();
        assert_eq!(utf16_breaks, [0, 6, 8, 14, 16, 22, 24]);
    }

    /// Chinese/Japanese text through both the raw dictionary segmenter and the
    /// dictionary word segmenter.
    #[test]
    fn cj_dictionary_test() {
        let dict_payload: DataPayload<DictionaryForWordOnlyAutoV1Marker> = crate::provider::Baked
            .load(DataRequest {
                locale: &icu_locid::langid!("ja").into(),
                metadata: Default::default(),
            })
            .unwrap()
            .take_payload()
            .unwrap();
        let words = WordSegmenter::new_dictionary();
        let dictionary = DictionarySegmenter::new(
            dict_payload.get(),
            crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1,
        );

        // Match case
        let text = "龟山岛龟山岛";
        let units = text.encode_utf16().collect::<Vec<u16>>();

        let breaks = dictionary.segment_str(text).collect::<Vec<usize>>();
        assert_eq!(breaks, [9, 18]);

        let breaks = words.segment_str(text).collect::<Vec<usize>>();
        assert_eq!(breaks, [0, 9, 18]);

        let breaks = dictionary.segment_utf16(&units).collect::<Vec<usize>>();
        assert_eq!(breaks, [3, 6]);

        let breaks = words.segment_utf16(&units).collect::<Vec<usize>>();
        assert_eq!(breaks, [0, 3, 6]);

        // Match case, then no match case
        let text = "エディターエディ";
        let units = text.encode_utf16().collect::<Vec<u16>>();

        let breaks = dictionary.segment_str(text).collect::<Vec<usize>>();
        assert_eq!(breaks, [15, 24]);

        // TODO(#3236): Why is WordSegmenter not returning the middle segment?
        let breaks = words.segment_str(text).collect::<Vec<usize>>();
        assert_eq!(breaks, [0, 24]);

        let breaks = dictionary.segment_utf16(&units).collect::<Vec<usize>>();
        assert_eq!(breaks, [5, 8]);

        // TODO(#3236): Why is WordSegmenter not returning the middle segment?
        let breaks = words.segment_utf16(&units).collect::<Vec<usize>>();
        assert_eq!(breaks, [0, 8]);
    }

    /// Khmer text through the dictionary line segmenter, in UTF-8 and UTF-16.
    #[test]
    fn khmer_dictionary_test() {
        let text = "ភាសាខ្មែរភាសាខ្មែរភាសាខ្មែរ";
        let line_segmenter = LineSegmenter::new_dictionary();

        let utf8_breaks = line_segmenter.segment_str(text).collect::<Vec<usize>>();
        assert_eq!(utf8_breaks, [0, 27, 54, 81]);

        let units = text.encode_utf16().collect::<Vec<u16>>();
        let utf16_breaks = line_segmenter
            .segment_utf16(&units)
            .collect::<Vec<usize>>();
        assert_eq!(utf16_breaks, [0, 9, 18, 27]);
    }

    /// Lao text through the dictionary line segmenter, in UTF-8 and UTF-16.
    #[test]
    fn lao_dictionary_test() {
        let text = "ພາສາລາວພາສາລາວພາສາລາວ";
        let line_segmenter = LineSegmenter::new_dictionary();

        let utf8_breaks = line_segmenter.segment_str(text).collect::<Vec<usize>>();
        assert_eq!(utf8_breaks, [0, 12, 21, 33, 42, 54, 63]);

        let units = text.encode_utf16().collect::<Vec<u16>>();
        let utf16_breaks = line_segmenter
            .segment_utf16(&units)
            .collect::<Vec<usize>>();
        assert_eq!(utf16_breaks, [0, 4, 7, 11, 14, 18, 21]);
    }
}