// icu_segmenter/complex/mod.rs

// This file is part of ICU4X. For terms of use, please see the file
// called LICENSE at the top level of the ICU4X source tree
// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).

5use crate::provider::*;
6use alloc::vec::Vec;
7use icu_locid::{langid, LanguageIdentifier};
8use icu_provider::prelude::*;
9
10mod dictionary;
11use dictionary::*;
12mod language;
13use language::*;
14#[cfg(feature = "lstm")]
15mod lstm;
16#[cfg(feature = "lstm")]
17use lstm::*;
18
19#[cfg(not(feature = "lstm"))]
20type DictOrLstm = Result<DataPayload<UCharDictionaryBreakDataV1Marker>, core::convert::Infallible>;
21#[cfg(not(feature = "lstm"))]
22type DictOrLstmBorrowed<'a> =
23    Result<&'a DataPayload<UCharDictionaryBreakDataV1Marker>, &'a core::convert::Infallible>;
24
25#[cfg(feature = "lstm")]
26type DictOrLstm =
27    Result<DataPayload<UCharDictionaryBreakDataV1Marker>, DataPayload<LstmForWordLineAutoV1Marker>>;
28#[cfg(feature = "lstm")]
29type DictOrLstmBorrowed<'a> = Result<
30    &'a DataPayload<UCharDictionaryBreakDataV1Marker>,
31    &'a DataPayload<LstmForWordLineAutoV1Marker>,
32>;
33
34#[derive(Debug)]
35pub(crate) struct ComplexPayloads {
36    grapheme: DataPayload<GraphemeClusterBreakDataV1Marker>,
37    my: Option<DictOrLstm>,
38    km: Option<DictOrLstm>,
39    lo: Option<DictOrLstm>,
40    th: Option<DictOrLstm>,
41    ja: Option<DataPayload<UCharDictionaryBreakDataV1Marker>>,
42}
43
44impl ComplexPayloads {
45    fn select(&self, language: Language) -> Option<DictOrLstmBorrowed> {
46        const ERR: DataError = DataError::custom("No segmentation model for language");
47        match language {
48            Language::Burmese => self.my.as_ref().map(Result::as_ref).or_else(|| {
49                ERR.with_display_context("my");
50                None
51            }),
52            Language::Khmer => self.km.as_ref().map(Result::as_ref).or_else(|| {
53                ERR.with_display_context("km");
54                None
55            }),
56            Language::Lao => self.lo.as_ref().map(Result::as_ref).or_else(|| {
57                ERR.with_display_context("lo");
58                None
59            }),
60            Language::Thai => self.th.as_ref().map(Result::as_ref).or_else(|| {
61                ERR.with_display_context("th");
62                None
63            }),
64            Language::ChineseOrJapanese => self.ja.as_ref().map(Ok).or_else(|| {
65                ERR.with_display_context("ja");
66                None
67            }),
68            Language::Unknown => None,
69        }
70    }
71
72    #[cfg(feature = "lstm")]
73    #[cfg(feature = "compiled_data")]
74    pub(crate) fn new_lstm() -> Self {
75        #[allow(clippy::unwrap_used)]
76        // try_load is infallible if the provider only returns `MissingLocale`.
77        Self {
78            grapheme: DataPayload::from_static_ref(
79                crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1,
80            ),
81            my: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, langid!("my"))
82                .unwrap()
83                .map(DataPayload::cast)
84                .map(Err),
85            km: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, langid!("km"))
86                .unwrap()
87                .map(DataPayload::cast)
88                .map(Err),
89            lo: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, langid!("lo"))
90                .unwrap()
91                .map(DataPayload::cast)
92                .map(Err),
93            th: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, langid!("th"))
94                .unwrap()
95                .map(DataPayload::cast)
96                .map(Err),
97            ja: None,
98        }
99    }
100
101    #[cfg(feature = "lstm")]
102    pub(crate) fn try_new_lstm<D>(provider: &D) -> Result<Self, DataError>
103    where
104        D: DataProvider<GraphemeClusterBreakDataV1Marker>
105            + DataProvider<LstmForWordLineAutoV1Marker>
106            + ?Sized,
107    {
108        Ok(Self {
109            grapheme: provider.load(Default::default())?.take_payload()?,
110            my: try_load::<LstmForWordLineAutoV1Marker, D>(provider, langid!("my"))?
111                .map(DataPayload::cast)
112                .map(Err),
113            km: try_load::<LstmForWordLineAutoV1Marker, D>(provider, langid!("km"))?
114                .map(DataPayload::cast)
115                .map(Err),
116            lo: try_load::<LstmForWordLineAutoV1Marker, D>(provider, langid!("lo"))?
117                .map(DataPayload::cast)
118                .map(Err),
119            th: try_load::<LstmForWordLineAutoV1Marker, D>(provider, langid!("th"))?
120                .map(DataPayload::cast)
121                .map(Err),
122            ja: None,
123        })
124    }
125
126    #[cfg(feature = "compiled_data")]
127    pub(crate) fn new_dict() -> Self {
128        #[allow(clippy::unwrap_used)]
129        // try_load is infallible if the provider only returns `MissingLocale`.
130        Self {
131            grapheme: DataPayload::from_static_ref(
132                crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1,
133            ),
134            my: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
135                &crate::provider::Baked,
136                langid!("my"),
137            )
138            .unwrap()
139            .map(DataPayload::cast)
140            .map(Ok),
141            km: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
142                &crate::provider::Baked,
143                langid!("km"),
144            )
145            .unwrap()
146            .map(DataPayload::cast)
147            .map(Ok),
148            lo: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
149                &crate::provider::Baked,
150                langid!("lo"),
151            )
152            .unwrap()
153            .map(DataPayload::cast)
154            .map(Ok),
155            th: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
156                &crate::provider::Baked,
157                langid!("th"),
158            )
159            .unwrap()
160            .map(DataPayload::cast)
161            .map(Ok),
162            ja: try_load::<DictionaryForWordOnlyAutoV1Marker, _>(
163                &crate::provider::Baked,
164                langid!("ja"),
165            )
166            .unwrap()
167            .map(DataPayload::cast),
168        }
169    }
170
171    pub(crate) fn try_new_dict<D>(provider: &D) -> Result<Self, DataError>
172    where
173        D: DataProvider<GraphemeClusterBreakDataV1Marker>
174            + DataProvider<DictionaryForWordLineExtendedV1Marker>
175            + DataProvider<DictionaryForWordOnlyAutoV1Marker>
176            + ?Sized,
177    {
178        Ok(Self {
179            grapheme: provider.load(Default::default())?.take_payload()?,
180            my: try_load::<DictionaryForWordLineExtendedV1Marker, D>(provider, langid!("my"))?
181                .map(DataPayload::cast)
182                .map(Ok),
183            km: try_load::<DictionaryForWordLineExtendedV1Marker, D>(provider, langid!("km"))?
184                .map(DataPayload::cast)
185                .map(Ok),
186            lo: try_load::<DictionaryForWordLineExtendedV1Marker, D>(provider, langid!("lo"))?
187                .map(DataPayload::cast)
188                .map(Ok),
189            th: try_load::<DictionaryForWordLineExtendedV1Marker, D>(provider, langid!("th"))?
190                .map(DataPayload::cast)
191                .map(Ok),
192            ja: try_load::<DictionaryForWordOnlyAutoV1Marker, D>(provider, langid!("ja"))?
193                .map(DataPayload::cast),
194        })
195    }
196
197    #[cfg(feature = "auto")] // Use by WordSegmenter with "auto" enabled.
198    #[cfg(feature = "compiled_data")]
199    pub(crate) fn new_auto() -> Self {
200        #[allow(clippy::unwrap_used)]
201        // try_load is infallible if the provider only returns `MissingLocale`.
202        Self {
203            grapheme: DataPayload::from_static_ref(
204                crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1,
205            ),
206            my: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, langid!("my"))
207                .unwrap()
208                .map(DataPayload::cast)
209                .map(Err),
210            km: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, langid!("km"))
211                .unwrap()
212                .map(DataPayload::cast)
213                .map(Err),
214            lo: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, langid!("lo"))
215                .unwrap()
216                .map(DataPayload::cast)
217                .map(Err),
218            th: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, langid!("th"))
219                .unwrap()
220                .map(DataPayload::cast)
221                .map(Err),
222            ja: try_load::<DictionaryForWordOnlyAutoV1Marker, _>(
223                &crate::provider::Baked,
224                langid!("ja"),
225            )
226            .unwrap()
227            .map(DataPayload::cast),
228        }
229    }
230
231    #[cfg(feature = "auto")] // Use by WordSegmenter with "auto" enabled.
232    pub(crate) fn try_new_auto<D>(provider: &D) -> Result<Self, DataError>
233    where
234        D: DataProvider<GraphemeClusterBreakDataV1Marker>
235            + DataProvider<LstmForWordLineAutoV1Marker>
236            + DataProvider<DictionaryForWordOnlyAutoV1Marker>
237            + ?Sized,
238    {
239        Ok(Self {
240            grapheme: provider.load(Default::default())?.take_payload()?,
241            my: try_load::<LstmForWordLineAutoV1Marker, D>(provider, langid!("my"))?
242                .map(DataPayload::cast)
243                .map(Err),
244            km: try_load::<LstmForWordLineAutoV1Marker, D>(provider, langid!("km"))?
245                .map(DataPayload::cast)
246                .map(Err),
247            lo: try_load::<LstmForWordLineAutoV1Marker, D>(provider, langid!("lo"))?
248                .map(DataPayload::cast)
249                .map(Err),
250            th: try_load::<LstmForWordLineAutoV1Marker, D>(provider, langid!("th"))?
251                .map(DataPayload::cast)
252                .map(Err),
253            ja: try_load::<DictionaryForWordOnlyAutoV1Marker, D>(provider, langid!("ja"))?
254                .map(DataPayload::cast),
255        })
256    }
257
258    #[cfg(feature = "compiled_data")]
259    pub(crate) fn new_southeast_asian() -> Self {
260        #[allow(clippy::unwrap_used)]
261        // try_load is infallible if the provider only returns `MissingLocale`.
262        Self {
263            grapheme: DataPayload::from_static_ref(
264                crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1,
265            ),
266            my: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
267                &crate::provider::Baked,
268                langid!("my"),
269            )
270            .unwrap()
271            .map(DataPayload::cast)
272            .map(Ok),
273            km: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
274                &crate::provider::Baked,
275                langid!("km"),
276            )
277            .unwrap()
278            .map(DataPayload::cast)
279            .map(Ok),
280            lo: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
281                &crate::provider::Baked,
282                langid!("lo"),
283            )
284            .unwrap()
285            .map(DataPayload::cast)
286            .map(Ok),
287            th: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
288                &crate::provider::Baked,
289                langid!("th"),
290            )
291            .unwrap()
292            .map(DataPayload::cast)
293            .map(Ok),
294            ja: None,
295        }
296    }
297
298    pub(crate) fn try_new_southeast_asian<D>(provider: &D) -> Result<Self, DataError>
299    where
300        D: DataProvider<DictionaryForWordLineExtendedV1Marker>
301            + DataProvider<GraphemeClusterBreakDataV1Marker>
302            + ?Sized,
303    {
304        Ok(Self {
305            grapheme: provider.load(Default::default())?.take_payload()?,
306            my: try_load::<DictionaryForWordLineExtendedV1Marker, _>(provider, langid!("my"))?
307                .map(DataPayload::cast)
308                .map(Ok),
309            km: try_load::<DictionaryForWordLineExtendedV1Marker, _>(provider, langid!("km"))?
310                .map(DataPayload::cast)
311                .map(Ok),
312            lo: try_load::<DictionaryForWordLineExtendedV1Marker, _>(provider, langid!("lo"))?
313                .map(DataPayload::cast)
314                .map(Ok),
315            th: try_load::<DictionaryForWordLineExtendedV1Marker, _>(provider, langid!("th"))?
316                .map(DataPayload::cast)
317                .map(Ok),
318            ja: None,
319        })
320    }
321}
322
323fn try_load<M: KeyedDataMarker, P: DataProvider<M> + ?Sized>(
324    provider: &P,
325    locale: LanguageIdentifier,
326) -> Result<Option<DataPayload<M>>, DataError> {
327    match provider.load(DataRequest {
328        locale: &locale.into(),
329        metadata: {
330            let mut m = DataRequestMetadata::default();
331            m.silent = true;
332            m
333        },
334    }) {
335        Ok(response) => Ok(Some(response.take_payload()?)),
336        Err(DataError {
337            kind: DataErrorKind::MissingLocale,
338            ..
339        }) => Ok(None),
340        Err(e) => Err(e),
341    }
342}
343
344/// Return UTF-16 segment offset array using dictionary or lstm segmenter.
345pub(crate) fn complex_language_segment_utf16(
346    payloads: &ComplexPayloads,
347    input: &[u16],
348) -> Vec<usize> {
349    let mut result = Vec::new();
350    let mut offset = 0;
351    for (slice, lang) in LanguageIteratorUtf16::new(input) {
352        match payloads.select(lang) {
353            Some(Ok(dict)) => {
354                result.extend(
355                    DictionarySegmenter::new(dict.get(), payloads.grapheme.get())
356                        .segment_utf16(slice)
357                        .map(|n| offset + n),
358                );
359            }
360            #[cfg(feature = "lstm")]
361            Some(Err(lstm)) => {
362                result.extend(
363                    LstmSegmenter::new(lstm.get(), payloads.grapheme.get())
364                        .segment_utf16(slice)
365                        .map(|n| offset + n),
366                );
367            }
368            #[cfg(not(feature = "lstm"))]
369            Some(Err(_infallible)) => {} // should be refutable
370            None => {
371                result.push(offset + slice.len());
372            }
373        }
374        offset += slice.len();
375    }
376    result
377}
378
379/// Return UTF-8 segment offset array using dictionary or lstm segmenter.
380pub(crate) fn complex_language_segment_str(payloads: &ComplexPayloads, input: &str) -> Vec<usize> {
381    let mut result = Vec::new();
382    let mut offset = 0;
383    for (slice, lang) in LanguageIterator::new(input) {
384        match payloads.select(lang) {
385            Some(Ok(dict)) => {
386                result.extend(
387                    DictionarySegmenter::new(dict.get(), payloads.grapheme.get())
388                        .segment_str(slice)
389                        .map(|n| offset + n),
390                );
391            }
392            #[cfg(feature = "lstm")]
393            Some(Err(lstm)) => {
394                result.extend(
395                    LstmSegmenter::new(lstm.get(), payloads.grapheme.get())
396                        .segment_str(slice)
397                        .map(|n| offset + n),
398                );
399            }
400            #[cfg(not(feature = "lstm"))]
401            Some(Err(_infallible)) => {} // should be refutable
402            None => {
403                result.push(offset + slice.len());
404            }
405        }
406        offset += slice.len();
407    }
408    result
409}
410
#[cfg(test)]
// The test builds payloads with `new_lstm` (which needs the `lstm` and `compiled_data` features)
// and `new_dict` (which needs `compiled_data`). It uses nothing from serde, so it is gated on
// the features it actually needs. Gating on `serde` broke `--features serde` builds without
// `lstm`, and skipped the test in default builds without serde.
#[cfg(all(feature = "lstm", feature = "compiled_data"))]
mod tests {
    use super::*;

    #[test]
    fn thai_word_break() {
        // "ภาษาไทย" twice; expected words: ภาษา | ไทย | ภาษา | ไทย.
        const TEST_STR: &str = "ภาษาไทยภาษาไทย";
        let utf16: Vec<u16> = TEST_STR.encode_utf16().collect();

        let lstm = ComplexPayloads::new_lstm();
        let dict = ComplexPayloads::new_dict();

        // Each Thai character here is 3 bytes in UTF-8 and 1 code unit in UTF-16.
        assert_eq!(
            complex_language_segment_str(&lstm, TEST_STR),
            [12, 21, 33, 42]
        );
        assert_eq!(
            complex_language_segment_utf16(&lstm, &utf16),
            [4, 7, 11, 14]
        );

        assert_eq!(
            complex_language_segment_str(&dict, TEST_STR),
            [12, 21, 33, 42]
        );
        assert_eq!(
            complex_language_segment_utf16(&dict, &utf16),
            [4, 7, 11, 14]
        );
    }
}