icu_segmenter/
word.rs

1// This file is part of ICU4X. For terms of use, please see the file
2// called LICENSE at the top level of the ICU4X source tree
3// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).
4
5use crate::complex::*;
6use crate::indices::{Latin1Indices, Utf16Indices};
7use crate::iterator_helpers::derive_usize_iterator_with_type;
8use crate::provider::*;
9use crate::rule_segmenter::*;
10use crate::SegmenterError;
11use alloc::string::String;
12use alloc::vec;
13use alloc::vec::Vec;
14use core::str::CharIndices;
15use icu_provider::prelude::*;
16use utf8_iter::Utf8CharIndices;
17
18/// Implements the [`Iterator`] trait over the word boundaries of the given string.
19///
20/// Lifetimes:
21///
22/// - `'l` = lifetime of the segmenter object from which this iterator was created
23/// - `'s` = lifetime of the string being segmented
24///
25/// The [`Iterator::Item`] is an [`usize`] representing index of a code unit
26/// _after_ the boundary (for a boundary at the end of text, this index is the length
27/// of the [`str`] or array of code units).
28///
29/// For examples of use, see [`WordSegmenter`].
30#[derive(Debug)]
31pub struct WordBreakIterator<'l, 's, Y: RuleBreakType<'l, 's> + ?Sized>(
32    RuleBreakIterator<'l, 's, Y>,
33);
34
35derive_usize_iterator_with_type!(WordBreakIterator);
36
37/// The word type tag that is returned by [`WordBreakIterator::word_type()`].
38#[non_exhaustive]
39#[derive(Copy, Clone, PartialEq, Debug)]
40#[repr(u8)]
41#[zerovec::make_ule(WordTypeULE)]
42pub enum WordType {
43    /// No category tag.
44    None = 0,
45    /// Number category tag.
46    Number = 1,
47    /// Letter category tag, including CJK.
48    Letter = 2,
49}
50
51impl WordType {
52    /// Whether the segment is word-like; word-like segments include numbers, as
53    /// well as segments made up of letters (including CJKV ideographs).
54    pub fn is_word_like(&self) -> bool {
55        self != &WordType::None
56    }
57}
58
59impl<'l, 's, Y: RuleBreakType<'l, 's> + ?Sized> WordBreakIterator<'l, 's, Y> {
60    /// Returns the word type of the segment preceding the current boundary.
61    #[inline]
62    pub fn word_type(&self) -> WordType {
63        self.0.word_type()
64    }
65
66    /// Returns an iterator over pairs of boundary position and word type.
67    pub fn iter_with_word_type<'i: 'l + 's>(
68        &'i mut self,
69    ) -> impl Iterator<Item = (usize, WordType)> + '_ {
70        core::iter::from_fn(move || self.next().map(|i| (i, self.word_type())))
71    }
72
73    /// Returns `true` when the segment preceding the current boundary is word-like,
74    /// such as letters, numbers, or CJKV ideographs.
75    #[inline]
76    pub fn is_word_like(&self) -> bool {
77        self.word_type().is_word_like()
78    }
79}
80
81/// Word break iterator for an `str` (a UTF-8 string).
82///
83/// For examples of use, see [`WordSegmenter`].
84pub type WordBreakIteratorUtf8<'l, 's> = WordBreakIterator<'l, 's, WordBreakTypeUtf8>;
85
86/// Word break iterator for a potentially invalid UTF-8 string.
87///
88/// For examples of use, see [`WordSegmenter`].
89pub type WordBreakIteratorPotentiallyIllFormedUtf8<'l, 's> =
90    WordBreakIterator<'l, 's, WordBreakTypePotentiallyIllFormedUtf8>;
91
92/// Word break iterator for a Latin-1 (8-bit) string.
93///
94/// For examples of use, see [`WordSegmenter`].
95pub type WordBreakIteratorLatin1<'l, 's> = WordBreakIterator<'l, 's, RuleBreakTypeLatin1>;
96
97/// Word break iterator for a UTF-16 string.
98///
99/// For examples of use, see [`WordSegmenter`].
100pub type WordBreakIteratorUtf16<'l, 's> = WordBreakIterator<'l, 's, WordBreakTypeUtf16>;
101
102/// Supports loading word break data, and creating word break iterators for different string
103/// encodings.
104///
105/// # Examples
106///
107/// Segment a string:
108///
109/// ```rust
110/// use icu::segmenter::WordSegmenter;
111/// let segmenter = WordSegmenter::new_auto();
112///
113/// let breakpoints: Vec<usize> =
114///     segmenter.segment_str("Hello World").collect();
115/// assert_eq!(&breakpoints, &[0, 5, 6, 11]);
116/// ```
117///
118/// Segment a Latin1 byte string:
119///
120/// ```rust
121/// use icu::segmenter::WordSegmenter;
122/// let segmenter = WordSegmenter::new_auto();
123///
124/// let breakpoints: Vec<usize> =
125///     segmenter.segment_latin1(b"Hello World").collect();
126/// assert_eq!(&breakpoints, &[0, 5, 6, 11]);
127/// ```
128///
129/// Successive boundaries can be used to retrieve the segments.
130/// In particular, the first boundary is always 0, and the last one is the
131/// length of the segmented text in code units.
132///
133/// ```rust
134/// # use icu::segmenter::WordSegmenter;
135/// # let segmenter = WordSegmenter::new_auto();
136/// use itertools::Itertools;
137/// let text = "Mark’d ye his words?";
138/// let segments: Vec<&str> = segmenter
139///     .segment_str(text)
140///     .tuple_windows()
141///     .map(|(i, j)| &text[i..j])
142///     .collect();
143/// assert_eq!(
144///     &segments,
145///     &["Mark’d", " ", "ye", " ", "his", " ", "words", "?"]
146/// );
147/// ```
148///
149/// Not all segments delimited by word boundaries are words; some are interword
150/// segments such as spaces and punctuation.
151/// The [`WordBreakIterator::word_type()`] of a boundary can be used to
152/// classify the preceding segment; [`WordBreakIterator::iter_with_word_type()`]
153/// associates each boundary with its status.
154/// ```rust
155/// # use itertools::Itertools;
156/// # use icu::segmenter::{WordType, WordSegmenter};
157/// # let segmenter = WordSegmenter::new_auto();
158/// # let text = "Mark’d ye his words?";
159/// let words: Vec<&str> = segmenter
160///     .segment_str(text)
161///     .iter_with_word_type()
162///     .tuple_windows()
163///     .filter(|(_, (_, segment_type))| segment_type.is_word_like())
164///     .map(|((i, _), (j, _))| &text[i..j])
165///     .collect();
166/// assert_eq!(&words, &["Mark’d", "ye", "his", "words"]);
167/// ```
168#[derive(Debug)]
169pub struct WordSegmenter {
170    payload: DataPayload<WordBreakDataV1Marker>,
171    complex: ComplexPayloads,
172}
173
174impl WordSegmenter {
175    /// Constructs a [`WordSegmenter`] with an invariant locale and the best available compiled data for
176    /// complex scripts (Chinese, Japanese, Khmer, Lao, Myanmar, and Thai).
177    ///
178    /// The current behavior, which is subject to change, is to use the LSTM model when available
179    /// and the dictionary model for Chinese and Japanese.
180    ///
181    /// ✨ *Enabled with the `compiled_data` and `auto` Cargo features.*
182    ///
183    /// [📚 Help choosing a constructor](icu_provider::constructors)
184    ///
185    /// # Examples
186    ///
187    /// Behavior with complex scripts:
188    ///
189    /// ```
190    /// use icu::segmenter::WordSegmenter;
191    ///
192    /// let th_str = "ทุกสองสัปดาห์";
193    /// let ja_str = "こんにちは世界";
194    ///
195    /// let segmenter = WordSegmenter::new_auto();
196    ///
197    /// let th_bps = segmenter.segment_str(th_str).collect::<Vec<_>>();
198    /// let ja_bps = segmenter.segment_str(ja_str).collect::<Vec<_>>();
199    ///
200    /// assert_eq!(th_bps, [0, 9, 18, 39]);
201    /// assert_eq!(ja_bps, [0, 15, 21]);
202    /// ```
203    #[cfg(feature = "compiled_data")]
204    #[cfg(feature = "auto")]
205    pub fn new_auto() -> Self {
206        Self {
207            payload: DataPayload::from_static_ref(
208                crate::provider::Baked::SINGLETON_SEGMENTER_WORD_V1,
209            ),
210            complex: ComplexPayloads::new_auto(),
211        }
212    }
213
214    #[cfg(feature = "auto")]
215    icu_provider::gen_any_buffer_data_constructors!(
216        locale: skip,
217        options: skip,
218        error: SegmenterError,
219        #[cfg(skip)]
220        functions: [
221            try_new_auto,
222            try_new_auto_with_any_provider,
223            try_new_auto_with_buffer_provider,
224            try_new_auto_unstable,
225            Self
226        ]
227    );
228
229    #[cfg(feature = "auto")]
230    #[doc = icu_provider::gen_any_buffer_unstable_docs!(UNSTABLE, Self::new_auto)]
231    pub fn try_new_auto_unstable<D>(provider: &D) -> Result<Self, SegmenterError>
232    where
233        D: DataProvider<WordBreakDataV1Marker>
234            + DataProvider<DictionaryForWordOnlyAutoV1Marker>
235            + DataProvider<LstmForWordLineAutoV1Marker>
236            + DataProvider<GraphemeClusterBreakDataV1Marker>
237            + ?Sized,
238    {
239        Ok(Self {
240            payload: provider.load(Default::default())?.take_payload()?,
241            complex: ComplexPayloads::try_new_auto(provider)?,
242        })
243    }
244
245    /// Constructs a [`WordSegmenter`] with an invariant locale and compiled LSTM data for
246    /// complex scripts (Burmese, Khmer, Lao, and Thai).
247    ///
248    /// The LSTM, or Long Term Short Memory, is a machine learning model. It is smaller than
249    /// the full dictionary but more expensive during segmentation (inference).
250    ///
251    /// Warning: there is not currently an LSTM model for Chinese or Japanese, so the [`WordSegmenter`]
252    /// created by this function will have unexpected behavior in spans of those scripts.
253    ///
254    /// ✨ *Enabled with the `compiled_data` and `lstm` Cargo features.*
255    ///
256    /// [📚 Help choosing a constructor](icu_provider::constructors)
257    ///
258    /// # Examples
259    ///
260    /// Behavior with complex scripts:
261    ///
262    /// ```
263    /// use icu::segmenter::WordSegmenter;
264    ///
265    /// let th_str = "ทุกสองสัปดาห์";
266    /// let ja_str = "こんにちは世界";
267    ///
268    /// let segmenter = WordSegmenter::new_lstm();
269    ///
270    /// let th_bps = segmenter.segment_str(th_str).collect::<Vec<_>>();
271    /// let ja_bps = segmenter.segment_str(ja_str).collect::<Vec<_>>();
272    ///
273    /// assert_eq!(th_bps, [0, 9, 18, 39]);
274    ///
275    /// // Note: We aren't able to find a suitable breakpoint in Chinese/Japanese.
276    /// assert_eq!(ja_bps, [0, 21]);
277    /// ```
278    #[cfg(feature = "compiled_data")]
279    #[cfg(feature = "lstm")]
280    pub fn new_lstm() -> Self {
281        Self {
282            payload: DataPayload::from_static_ref(
283                crate::provider::Baked::SINGLETON_SEGMENTER_WORD_V1,
284            ),
285            complex: ComplexPayloads::new_lstm(),
286        }
287    }
288
289    #[cfg(feature = "lstm")]
290    icu_provider::gen_any_buffer_data_constructors!(
291        locale: skip,
292        options: skip,
293        error: SegmenterError,
294        #[cfg(skip)]
295        functions: [
296            new_lstm,
297            try_new_lstm_with_any_provider,
298            try_new_lstm_with_buffer_provider,
299            try_new_lstm_unstable,
300            Self
301        ]
302    );
303
304    #[cfg(feature = "lstm")]
305    #[doc = icu_provider::gen_any_buffer_unstable_docs!(UNSTABLE, Self::new_lstm)]
306    pub fn try_new_lstm_unstable<D>(provider: &D) -> Result<Self, SegmenterError>
307    where
308        D: DataProvider<WordBreakDataV1Marker>
309            + DataProvider<LstmForWordLineAutoV1Marker>
310            + DataProvider<GraphemeClusterBreakDataV1Marker>
311            + ?Sized,
312    {
313        Ok(Self {
314            payload: provider.load(Default::default())?.take_payload()?,
315            complex: ComplexPayloads::try_new_lstm(provider)?,
316        })
317    }
318
319    /// Construct a [`WordSegmenter`] with an invariant locale and compiled dictionary data for
320    /// complex scripts (Chinese, Japanese, Khmer, Lao, Myanmar, and Thai).
321    ///
322    /// The dictionary model uses a list of words to determine appropriate breakpoints. It is
323    /// faster than the LSTM model but requires more data.
324    ///
325    /// ✨ *Enabled with the `compiled_data` Cargo feature.*
326    ///
327    /// [📚 Help choosing a constructor](icu_provider::constructors)
328    ///
329    /// # Examples
330    ///
331    /// Behavior with complex scripts:
332    ///
333    /// ```
334    /// use icu::segmenter::WordSegmenter;
335    ///
336    /// let th_str = "ทุกสองสัปดาห์";
337    /// let ja_str = "こんにちは世界";
338    ///
339    /// let segmenter = WordSegmenter::new_dictionary();
340    ///
341    /// let th_bps = segmenter.segment_str(th_str).collect::<Vec<_>>();
342    /// let ja_bps = segmenter.segment_str(ja_str).collect::<Vec<_>>();
343    ///
344    /// assert_eq!(th_bps, [0, 9, 18, 39]);
345    /// assert_eq!(ja_bps, [0, 15, 21]);
346    /// ```
347    #[cfg(feature = "compiled_data")]
348    pub fn new_dictionary() -> Self {
349        Self {
350            payload: DataPayload::from_static_ref(
351                crate::provider::Baked::SINGLETON_SEGMENTER_WORD_V1,
352            ),
353            complex: ComplexPayloads::new_dict(),
354        }
355    }
356
357    icu_provider::gen_any_buffer_data_constructors!(
358        locale: skip,
359        options: skip,
360        error: SegmenterError,
361        #[cfg(skip)]
362        functions: [
363            new_dictionary,
364            try_new_dictionary_with_any_provider,
365            try_new_dictionary_with_buffer_provider,
366            try_new_dictionary_unstable,
367            Self
368        ]
369    );
370
371    #[doc = icu_provider::gen_any_buffer_unstable_docs!(UNSTABLE, Self::new_dictionary)]
372    pub fn try_new_dictionary_unstable<D>(provider: &D) -> Result<Self, SegmenterError>
373    where
374        D: DataProvider<WordBreakDataV1Marker>
375            + DataProvider<DictionaryForWordOnlyAutoV1Marker>
376            + DataProvider<DictionaryForWordLineExtendedV1Marker>
377            + DataProvider<GraphemeClusterBreakDataV1Marker>
378            + ?Sized,
379    {
380        Ok(Self {
381            payload: provider.load(Default::default())?.take_payload()?,
382            complex: ComplexPayloads::try_new_dict(provider)?,
383        })
384    }
385
386    /// Creates a word break iterator for an `str` (a UTF-8 string).
387    ///
388    /// There are always breakpoints at 0 and the string length, or only at 0 for the empty string.
389    pub fn segment_str<'l, 's>(&'l self, input: &'s str) -> WordBreakIteratorUtf8<'l, 's> {
390        WordBreakIterator(RuleBreakIterator {
391            iter: input.char_indices(),
392            len: input.len(),
393            current_pos_data: None,
394            result_cache: Vec::new(),
395            data: self.payload.get(),
396            complex: Some(&self.complex),
397            boundary_property: 0,
398        })
399    }
400
401    /// Creates a word break iterator for a potentially ill-formed UTF8 string
402    ///
403    /// Invalid characters are treated as REPLACEMENT CHARACTER
404    ///
405    /// There are always breakpoints at 0 and the string length, or only at 0 for the empty string.
406    pub fn segment_utf8<'l, 's>(
407        &'l self,
408        input: &'s [u8],
409    ) -> WordBreakIteratorPotentiallyIllFormedUtf8<'l, 's> {
410        WordBreakIterator(RuleBreakIterator {
411            iter: Utf8CharIndices::new(input),
412            len: input.len(),
413            current_pos_data: None,
414            result_cache: Vec::new(),
415            data: self.payload.get(),
416            complex: Some(&self.complex),
417            boundary_property: 0,
418        })
419    }
420
421    /// Creates a word break iterator for a Latin-1 (8-bit) string.
422    ///
423    /// There are always breakpoints at 0 and the string length, or only at 0 for the empty string.
424    pub fn segment_latin1<'l, 's>(&'l self, input: &'s [u8]) -> WordBreakIteratorLatin1<'l, 's> {
425        WordBreakIterator(RuleBreakIterator {
426            iter: Latin1Indices::new(input),
427            len: input.len(),
428            current_pos_data: None,
429            result_cache: Vec::new(),
430            data: self.payload.get(),
431            complex: Some(&self.complex),
432            boundary_property: 0,
433        })
434    }
435
436    /// Creates a word break iterator for a UTF-16 string.
437    ///
438    /// There are always breakpoints at 0 and the string length, or only at 0 for the empty string.
439    pub fn segment_utf16<'l, 's>(&'l self, input: &'s [u16]) -> WordBreakIteratorUtf16<'l, 's> {
440        WordBreakIterator(RuleBreakIterator {
441            iter: Utf16Indices::new(input),
442            len: input.len(),
443            current_pos_data: None,
444            result_cache: Vec::new(),
445            data: self.payload.get(),
446            complex: Some(&self.complex),
447            boundary_property: 0,
448        })
449    }
450}
451
452#[derive(Debug)]
453pub struct WordBreakTypeUtf8;
454
455impl<'l, 's> RuleBreakType<'l, 's> for WordBreakTypeUtf8 {
456    type IterAttr = CharIndices<'s>;
457    type CharType = char;
458
459    fn get_current_position_character_len(iter: &RuleBreakIterator<Self>) -> usize {
460        iter.get_current_codepoint().map_or(0, |c| c.len_utf8())
461    }
462
463    fn handle_complex_language(
464        iter: &mut RuleBreakIterator<'l, 's, Self>,
465        left_codepoint: Self::CharType,
466    ) -> Option<usize> {
467        handle_complex_language_utf8(iter, left_codepoint)
468    }
469}
470
471#[derive(Debug)]
472pub struct WordBreakTypePotentiallyIllFormedUtf8;
473
474impl<'l, 's> RuleBreakType<'l, 's> for WordBreakTypePotentiallyIllFormedUtf8 {
475    type IterAttr = Utf8CharIndices<'s>;
476    type CharType = char;
477
478    fn get_current_position_character_len(iter: &RuleBreakIterator<Self>) -> usize {
479        iter.get_current_codepoint().map_or(0, |c| c.len_utf8())
480    }
481
482    fn handle_complex_language(
483        iter: &mut RuleBreakIterator<'l, 's, Self>,
484        left_codepoint: Self::CharType,
485    ) -> Option<usize> {
486        handle_complex_language_utf8(iter, left_codepoint)
487    }
488}
489
490/// handle_complex_language impl for UTF8 iterators
491fn handle_complex_language_utf8<'l, 's, T>(
492    iter: &mut RuleBreakIterator<'l, 's, T>,
493    left_codepoint: T::CharType,
494) -> Option<usize>
495where
496    T: RuleBreakType<'l, 's, CharType = char>,
497{
498    // word segmenter doesn't define break rules for some languages such as Thai.
499    let start_iter = iter.iter.clone();
500    let start_point = iter.current_pos_data;
501    let mut s = String::new();
502    s.push(left_codepoint);
503    loop {
504        debug_assert!(!iter.is_eof());
505        s.push(iter.get_current_codepoint()?);
506        iter.advance_iter();
507        if let Some(current_break_property) = iter.get_current_break_property() {
508            if current_break_property != iter.data.complex_property {
509                break;
510            }
511        } else {
512            // EOF
513            break;
514        }
515    }
516
517    // Restore iterator to move to head of complex string
518    iter.iter = start_iter;
519    iter.current_pos_data = start_point;
520    #[allow(clippy::unwrap_used)] // iter.complex present for word segmenter
521    let breaks = complex_language_segment_str(iter.complex.unwrap(), &s);
522    iter.result_cache = breaks;
523    let first_pos = *iter.result_cache.first()?;
524    let mut i = left_codepoint.len_utf8();
525    loop {
526        if i == first_pos {
527            // Re-calculate breaking offset
528            iter.result_cache = iter.result_cache.iter().skip(1).map(|r| r - i).collect();
529            return iter.get_current_position();
530        }
531        debug_assert!(
532            i < first_pos,
533            "we should always arrive at first_pos: near index {:?}",
534            iter.get_current_position()
535        );
536        i += T::get_current_position_character_len(iter);
537        iter.advance_iter();
538        if iter.is_eof() {
539            iter.result_cache.clear();
540            return Some(iter.len);
541        }
542    }
543}
544
545#[derive(Debug)]
546pub struct WordBreakTypeUtf16;
547
548impl<'l, 's> RuleBreakType<'l, 's> for WordBreakTypeUtf16 {
549    type IterAttr = Utf16Indices<'s>;
550    type CharType = u32;
551
552    fn get_current_position_character_len(iter: &RuleBreakIterator<Self>) -> usize {
553        match iter.get_current_codepoint() {
554            None => 0,
555            Some(ch) if ch >= 0x10000 => 2,
556            _ => 1,
557        }
558    }
559
560    fn handle_complex_language(
561        iter: &mut RuleBreakIterator<Self>,
562        left_codepoint: Self::CharType,
563    ) -> Option<usize> {
564        // word segmenter doesn't define break rules for some languages such as Thai.
565        let start_iter = iter.iter.clone();
566        let start_point = iter.current_pos_data;
567        let mut s = vec![left_codepoint as u16];
568        loop {
569            debug_assert!(!iter.is_eof());
570            s.push(iter.get_current_codepoint()? as u16);
571            iter.advance_iter();
572            if let Some(current_break_property) = iter.get_current_break_property() {
573                if current_break_property != iter.data.complex_property {
574                    break;
575                }
576            } else {
577                // EOF
578                break;
579            }
580        }
581
582        // Restore iterator to move to head of complex string
583        iter.iter = start_iter;
584        iter.current_pos_data = start_point;
585        #[allow(clippy::unwrap_used)] // iter.complex present for word segmenter
586        let breaks = complex_language_segment_utf16(iter.complex.unwrap(), &s);
587        iter.result_cache = breaks;
588        // result_cache vector is utf-16 index that is in BMP.
589        let first_pos = *iter.result_cache.first()?;
590        let mut i = 1;
591        loop {
592            if i == first_pos {
593                // Re-calculate breaking offset
594                iter.result_cache = iter.result_cache.iter().skip(1).map(|r| r - i).collect();
595                return iter.get_current_position();
596            }
597            debug_assert!(
598                i < first_pos,
599                "we should always arrive at first_pos: near index {:?}",
600                iter.get_current_position()
601            );
602            i += 1;
603            iter.advance_iter();
604            if iter.is_eof() {
605                iter.result_cache.clear();
606                return Some(iter.len);
607            }
608        }
609    }
610}
611
// `WordSegmenter::new_auto` only exists with the `compiled_data` and `auto` features,
// so the test is gated on those (it does not use `serde`).
#[cfg(all(test, feature = "compiled_data", feature = "auto"))]
#[test]
fn empty_string() {
    let segmenter = WordSegmenter::new_auto();

    // Every `segment_*` method documents a single breakpoint at 0 for empty input.
    let breaks: Vec<usize> = segmenter.segment_str("").collect();
    assert_eq!(breaks, [0]);

    let breaks: Vec<usize> = segmenter.segment_utf8(b"").collect();
    assert_eq!(breaks, [0]);

    let breaks: Vec<usize> = segmenter.segment_latin1(b"").collect();
    assert_eq!(breaks, [0]);

    let breaks: Vec<usize> = segmenter.segment_utf16(&[]).collect();
    assert_eq!(breaks, [0]);
}