1use crate::provider::*;
6use alloc::vec::Vec;
7use icu_locid::{langid, LanguageIdentifier};
8use icu_provider::prelude::*;
9
10mod dictionary;
11use dictionary::*;
12mod language;
13use language::*;
14#[cfg(feature = "lstm")]
15mod lstm;
16#[cfg(feature = "lstm")]
17use lstm::*;
18
/// Owned per-language model slot when the `lstm` feature is disabled.
///
/// `Ok` carries a dictionary payload. The `Err` side is `Infallible`, so in this
/// configuration a slot can only ever hold a dictionary.
#[cfg(not(feature = "lstm"))]
type DictOrLstm = Result<DataPayload<UCharDictionaryBreakDataV1Marker>, core::convert::Infallible>;
/// Borrowed counterpart of [`DictOrLstm`] when the `lstm` feature is disabled.
#[cfg(not(feature = "lstm"))]
type DictOrLstmBorrowed<'a> =
    Result<&'a DataPayload<UCharDictionaryBreakDataV1Marker>, &'a core::convert::Infallible>;

/// Owned per-language model slot when the `lstm` feature is enabled.
///
/// `Result` is used as a two-way choice rather than as success/failure:
/// `Ok` carries a dictionary payload and `Err` carries an LSTM payload.
#[cfg(feature = "lstm")]
type DictOrLstm =
    Result<DataPayload<UCharDictionaryBreakDataV1Marker>, DataPayload<LstmForWordLineAutoV1Marker>>;
/// Borrowed counterpart of [`DictOrLstm`] when the `lstm` feature is enabled.
#[cfg(feature = "lstm")]
type DictOrLstmBorrowed<'a> = Result<
    &'a DataPayload<UCharDictionaryBreakDataV1Marker>,
    &'a DataPayload<LstmForWordLineAutoV1Marker>,
>;
33
/// Data payloads used to segment text in languages that need a dictionary or LSTM model.
///
/// Each language slot is `None` when no model was loaded for it. In that case the
/// segmentation functions in this module emit a single break at the end of the run.
#[derive(Debug)]
pub(crate) struct ComplexPayloads {
    /// Grapheme cluster break data, passed to every dictionary/LSTM segmenter.
    grapheme: DataPayload<GraphemeClusterBreakDataV1Marker>,
    /// Model selected for `Language::Burmese`.
    my: Option<DictOrLstm>,
    /// Model selected for `Language::Khmer`.
    km: Option<DictOrLstm>,
    /// Model selected for `Language::Lao`.
    lo: Option<DictOrLstm>,
    /// Model selected for `Language::Thai`.
    th: Option<DictOrLstm>,
    /// Dictionary selected for `Language::ChineseOrJapanese` (always a dictionary, never an LSTM).
    ja: Option<DataPayload<UCharDictionaryBreakDataV1Marker>>,
}
43
44impl ComplexPayloads {
45 fn select(&self, language: Language) -> Option<DictOrLstmBorrowed> {
46 const ERR: DataError = DataError::custom("No segmentation model for language");
47 match language {
48 Language::Burmese => self.my.as_ref().map(Result::as_ref).or_else(|| {
49 ERR.with_display_context("my");
50 None
51 }),
52 Language::Khmer => self.km.as_ref().map(Result::as_ref).or_else(|| {
53 ERR.with_display_context("km");
54 None
55 }),
56 Language::Lao => self.lo.as_ref().map(Result::as_ref).or_else(|| {
57 ERR.with_display_context("lo");
58 None
59 }),
60 Language::Thai => self.th.as_ref().map(Result::as_ref).or_else(|| {
61 ERR.with_display_context("th");
62 None
63 }),
64 Language::ChineseOrJapanese => self.ja.as_ref().map(Ok).or_else(|| {
65 ERR.with_display_context("ja");
66 None
67 }),
68 Language::Unknown => None,
69 }
70 }
71
72 #[cfg(feature = "lstm")]
73 #[cfg(feature = "compiled_data")]
74 pub(crate) fn new_lstm() -> Self {
75 #[allow(clippy::unwrap_used)]
76 Self {
78 grapheme: DataPayload::from_static_ref(
79 crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1,
80 ),
81 my: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, langid!("my"))
82 .unwrap()
83 .map(DataPayload::cast)
84 .map(Err),
85 km: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, langid!("km"))
86 .unwrap()
87 .map(DataPayload::cast)
88 .map(Err),
89 lo: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, langid!("lo"))
90 .unwrap()
91 .map(DataPayload::cast)
92 .map(Err),
93 th: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, langid!("th"))
94 .unwrap()
95 .map(DataPayload::cast)
96 .map(Err),
97 ja: None,
98 }
99 }
100
101 #[cfg(feature = "lstm")]
102 pub(crate) fn try_new_lstm<D>(provider: &D) -> Result<Self, DataError>
103 where
104 D: DataProvider<GraphemeClusterBreakDataV1Marker>
105 + DataProvider<LstmForWordLineAutoV1Marker>
106 + ?Sized,
107 {
108 Ok(Self {
109 grapheme: provider.load(Default::default())?.take_payload()?,
110 my: try_load::<LstmForWordLineAutoV1Marker, D>(provider, langid!("my"))?
111 .map(DataPayload::cast)
112 .map(Err),
113 km: try_load::<LstmForWordLineAutoV1Marker, D>(provider, langid!("km"))?
114 .map(DataPayload::cast)
115 .map(Err),
116 lo: try_load::<LstmForWordLineAutoV1Marker, D>(provider, langid!("lo"))?
117 .map(DataPayload::cast)
118 .map(Err),
119 th: try_load::<LstmForWordLineAutoV1Marker, D>(provider, langid!("th"))?
120 .map(DataPayload::cast)
121 .map(Err),
122 ja: None,
123 })
124 }
125
126 #[cfg(feature = "compiled_data")]
127 pub(crate) fn new_dict() -> Self {
128 #[allow(clippy::unwrap_used)]
129 Self {
131 grapheme: DataPayload::from_static_ref(
132 crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1,
133 ),
134 my: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
135 &crate::provider::Baked,
136 langid!("my"),
137 )
138 .unwrap()
139 .map(DataPayload::cast)
140 .map(Ok),
141 km: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
142 &crate::provider::Baked,
143 langid!("km"),
144 )
145 .unwrap()
146 .map(DataPayload::cast)
147 .map(Ok),
148 lo: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
149 &crate::provider::Baked,
150 langid!("lo"),
151 )
152 .unwrap()
153 .map(DataPayload::cast)
154 .map(Ok),
155 th: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
156 &crate::provider::Baked,
157 langid!("th"),
158 )
159 .unwrap()
160 .map(DataPayload::cast)
161 .map(Ok),
162 ja: try_load::<DictionaryForWordOnlyAutoV1Marker, _>(
163 &crate::provider::Baked,
164 langid!("ja"),
165 )
166 .unwrap()
167 .map(DataPayload::cast),
168 }
169 }
170
171 pub(crate) fn try_new_dict<D>(provider: &D) -> Result<Self, DataError>
172 where
173 D: DataProvider<GraphemeClusterBreakDataV1Marker>
174 + DataProvider<DictionaryForWordLineExtendedV1Marker>
175 + DataProvider<DictionaryForWordOnlyAutoV1Marker>
176 + ?Sized,
177 {
178 Ok(Self {
179 grapheme: provider.load(Default::default())?.take_payload()?,
180 my: try_load::<DictionaryForWordLineExtendedV1Marker, D>(provider, langid!("my"))?
181 .map(DataPayload::cast)
182 .map(Ok),
183 km: try_load::<DictionaryForWordLineExtendedV1Marker, D>(provider, langid!("km"))?
184 .map(DataPayload::cast)
185 .map(Ok),
186 lo: try_load::<DictionaryForWordLineExtendedV1Marker, D>(provider, langid!("lo"))?
187 .map(DataPayload::cast)
188 .map(Ok),
189 th: try_load::<DictionaryForWordLineExtendedV1Marker, D>(provider, langid!("th"))?
190 .map(DataPayload::cast)
191 .map(Ok),
192 ja: try_load::<DictionaryForWordOnlyAutoV1Marker, D>(provider, langid!("ja"))?
193 .map(DataPayload::cast),
194 })
195 }
196
197 #[cfg(feature = "auto")] #[cfg(feature = "compiled_data")]
199 pub(crate) fn new_auto() -> Self {
200 #[allow(clippy::unwrap_used)]
201 Self {
203 grapheme: DataPayload::from_static_ref(
204 crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1,
205 ),
206 my: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, langid!("my"))
207 .unwrap()
208 .map(DataPayload::cast)
209 .map(Err),
210 km: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, langid!("km"))
211 .unwrap()
212 .map(DataPayload::cast)
213 .map(Err),
214 lo: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, langid!("lo"))
215 .unwrap()
216 .map(DataPayload::cast)
217 .map(Err),
218 th: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, langid!("th"))
219 .unwrap()
220 .map(DataPayload::cast)
221 .map(Err),
222 ja: try_load::<DictionaryForWordOnlyAutoV1Marker, _>(
223 &crate::provider::Baked,
224 langid!("ja"),
225 )
226 .unwrap()
227 .map(DataPayload::cast),
228 }
229 }
230
231 #[cfg(feature = "auto")] pub(crate) fn try_new_auto<D>(provider: &D) -> Result<Self, DataError>
233 where
234 D: DataProvider<GraphemeClusterBreakDataV1Marker>
235 + DataProvider<LstmForWordLineAutoV1Marker>
236 + DataProvider<DictionaryForWordOnlyAutoV1Marker>
237 + ?Sized,
238 {
239 Ok(Self {
240 grapheme: provider.load(Default::default())?.take_payload()?,
241 my: try_load::<LstmForWordLineAutoV1Marker, D>(provider, langid!("my"))?
242 .map(DataPayload::cast)
243 .map(Err),
244 km: try_load::<LstmForWordLineAutoV1Marker, D>(provider, langid!("km"))?
245 .map(DataPayload::cast)
246 .map(Err),
247 lo: try_load::<LstmForWordLineAutoV1Marker, D>(provider, langid!("lo"))?
248 .map(DataPayload::cast)
249 .map(Err),
250 th: try_load::<LstmForWordLineAutoV1Marker, D>(provider, langid!("th"))?
251 .map(DataPayload::cast)
252 .map(Err),
253 ja: try_load::<DictionaryForWordOnlyAutoV1Marker, D>(provider, langid!("ja"))?
254 .map(DataPayload::cast),
255 })
256 }
257
258 #[cfg(feature = "compiled_data")]
259 pub(crate) fn new_southeast_asian() -> Self {
260 #[allow(clippy::unwrap_used)]
261 Self {
263 grapheme: DataPayload::from_static_ref(
264 crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1,
265 ),
266 my: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
267 &crate::provider::Baked,
268 langid!("my"),
269 )
270 .unwrap()
271 .map(DataPayload::cast)
272 .map(Ok),
273 km: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
274 &crate::provider::Baked,
275 langid!("km"),
276 )
277 .unwrap()
278 .map(DataPayload::cast)
279 .map(Ok),
280 lo: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
281 &crate::provider::Baked,
282 langid!("lo"),
283 )
284 .unwrap()
285 .map(DataPayload::cast)
286 .map(Ok),
287 th: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
288 &crate::provider::Baked,
289 langid!("th"),
290 )
291 .unwrap()
292 .map(DataPayload::cast)
293 .map(Ok),
294 ja: None,
295 }
296 }
297
298 pub(crate) fn try_new_southeast_asian<D>(provider: &D) -> Result<Self, DataError>
299 where
300 D: DataProvider<DictionaryForWordLineExtendedV1Marker>
301 + DataProvider<GraphemeClusterBreakDataV1Marker>
302 + ?Sized,
303 {
304 Ok(Self {
305 grapheme: provider.load(Default::default())?.take_payload()?,
306 my: try_load::<DictionaryForWordLineExtendedV1Marker, _>(provider, langid!("my"))?
307 .map(DataPayload::cast)
308 .map(Ok),
309 km: try_load::<DictionaryForWordLineExtendedV1Marker, _>(provider, langid!("km"))?
310 .map(DataPayload::cast)
311 .map(Ok),
312 lo: try_load::<DictionaryForWordLineExtendedV1Marker, _>(provider, langid!("lo"))?
313 .map(DataPayload::cast)
314 .map(Ok),
315 th: try_load::<DictionaryForWordLineExtendedV1Marker, _>(provider, langid!("th"))?
316 .map(DataPayload::cast)
317 .map(Ok),
318 ja: None,
319 })
320 }
321}
322
323fn try_load<M: KeyedDataMarker, P: DataProvider<M> + ?Sized>(
324 provider: &P,
325 locale: LanguageIdentifier,
326) -> Result<Option<DataPayload<M>>, DataError> {
327 match provider.load(DataRequest {
328 locale: &locale.into(),
329 metadata: {
330 let mut m = DataRequestMetadata::default();
331 m.silent = true;
332 m
333 },
334 }) {
335 Ok(response) => Ok(Some(response.take_payload()?)),
336 Err(DataError {
337 kind: DataErrorKind::MissingLocale,
338 ..
339 }) => Ok(None),
340 Err(e) => Err(e),
341 }
342}
343
344pub(crate) fn complex_language_segment_utf16(
346 payloads: &ComplexPayloads,
347 input: &[u16],
348) -> Vec<usize> {
349 let mut result = Vec::new();
350 let mut offset = 0;
351 for (slice, lang) in LanguageIteratorUtf16::new(input) {
352 match payloads.select(lang) {
353 Some(Ok(dict)) => {
354 result.extend(
355 DictionarySegmenter::new(dict.get(), payloads.grapheme.get())
356 .segment_utf16(slice)
357 .map(|n| offset + n),
358 );
359 }
360 #[cfg(feature = "lstm")]
361 Some(Err(lstm)) => {
362 result.extend(
363 LstmSegmenter::new(lstm.get(), payloads.grapheme.get())
364 .segment_utf16(slice)
365 .map(|n| offset + n),
366 );
367 }
368 #[cfg(not(feature = "lstm"))]
369 Some(Err(_infallible)) => {} None => {
371 result.push(offset + slice.len());
372 }
373 }
374 offset += slice.len();
375 }
376 result
377}
378
379pub(crate) fn complex_language_segment_str(payloads: &ComplexPayloads, input: &str) -> Vec<usize> {
381 let mut result = Vec::new();
382 let mut offset = 0;
383 for (slice, lang) in LanguageIterator::new(input) {
384 match payloads.select(lang) {
385 Some(Ok(dict)) => {
386 result.extend(
387 DictionarySegmenter::new(dict.get(), payloads.grapheme.get())
388 .segment_str(slice)
389 .map(|n| offset + n),
390 );
391 }
392 #[cfg(feature = "lstm")]
393 Some(Err(lstm)) => {
394 result.extend(
395 LstmSegmenter::new(lstm.get(), payloads.grapheme.get())
396 .segment_str(slice)
397 .map(|n| offset + n),
398 );
399 }
400 #[cfg(not(feature = "lstm"))]
401 Some(Err(_infallible)) => {} None => {
403 result.push(offset + slice.len());
404 }
405 }
406 offset += slice.len();
407 }
408 result
409}
410
// `new_lstm` needs the `lstm` and `compiled_data` features and `new_dict` needs
// `compiled_data`. Gating on them here keeps a test build without those features from
// failing to compile.
// NOTE(review): the `serde` requirement is kept from the original gate; nothing in this
// module visibly needs it — confirm whether it can be dropped.
#[cfg(all(
    test,
    feature = "serde",
    feature = "compiled_data",
    feature = "lstm"
))]
mod tests {
    use super::*;

    /// The LSTM and dictionary models must agree on the breaks of a repeated Thai phrase,
    /// for both the UTF-8 (byte offsets) and UTF-16 (code unit offsets) entry points.
    #[test]
    fn thai_word_break() {
        const TEST_STR: &str = "ภาษาไทยภาษาไทย";
        let utf16: Vec<u16> = TEST_STR.encode_utf16().collect();

        let lstm = ComplexPayloads::new_lstm();
        let dict = ComplexPayloads::new_dict();

        for (label, payloads) in [("lstm", &lstm), ("dict", &dict)] {
            assert_eq!(
                complex_language_segment_str(payloads, TEST_STR),
                [12, 21, 33, 42],
                "str segmentation with {label} payloads"
            );
            assert_eq!(
                complex_language_segment_utf16(payloads, &utf16),
                [4, 7, 11, 14],
                "utf16 segmentation with {label} payloads"
            );
        }
    }
}