icu_segmenter/complex/lstm/
mod.rs1use crate::grapheme::GraphemeClusterSegmenter;
6use crate::provider::*;
7use alloc::vec::Vec;
8use core::char::{decode_utf16, REPLACEMENT_CHARACTER};
9use zerovec::{maps::ZeroMapBorrowed, ule::UnvalidatedStr};
10
11mod matrix;
12use matrix::*;
13
14struct LstmSegmenterIterator<'s> {
17 input: &'s str,
18 pos_utf8: usize,
19 bies: BiesIterator<'s>,
20}
21
22impl Iterator for LstmSegmenterIterator<'_> {
23 type Item = usize;
24
25 fn next(&mut self) -> Option<Self::Item> {
26 #[allow(clippy::indexing_slicing)] loop {
28 let is_e = self.bies.next()?;
29 self.pos_utf8 += self.input[self.pos_utf8..].chars().next()?.len_utf8();
30 if is_e || self.bies.len() == 0 {
31 return Some(self.pos_utf8);
32 }
33 }
34 }
35}
36
37struct LstmSegmenterIteratorUtf16<'s> {
38 bies: BiesIterator<'s>,
39 pos: usize,
40}
41
42impl Iterator for LstmSegmenterIteratorUtf16<'_> {
43 type Item = usize;
44
45 fn next(&mut self) -> Option<Self::Item> {
46 loop {
47 self.pos += 1;
48 if self.bies.next()? || self.bies.len() == 0 {
49 return Some(self.pos);
50 }
51 }
52 }
53}
54
55pub(super) struct LstmSegmenter<'l> {
56 dic: ZeroMapBorrowed<'l, UnvalidatedStr, u16>,
57 embedding: MatrixZero<'l, 2>,
58 fw_w: MatrixZero<'l, 3>,
59 fw_u: MatrixZero<'l, 3>,
60 fw_b: MatrixZero<'l, 2>,
61 bw_w: MatrixZero<'l, 3>,
62 bw_u: MatrixZero<'l, 3>,
63 bw_b: MatrixZero<'l, 2>,
64 timew_fw: MatrixZero<'l, 2>,
65 timew_bw: MatrixZero<'l, 2>,
66 time_b: MatrixZero<'l, 1>,
67 grapheme: Option<&'l RuleBreakDataV1<'l>>,
68}
69
70impl<'l> LstmSegmenter<'l> {
71 pub(super) fn new(lstm: &'l LstmDataV1<'l>, grapheme: &'l RuleBreakDataV1<'l>) -> Self {
73 let LstmDataV1::Float32(lstm) = lstm;
74 let time_w = MatrixZero::from(&lstm.time_w);
75 #[allow(clippy::unwrap_used)] let timew_fw = time_w.submatrix(0).unwrap();
77 #[allow(clippy::unwrap_used)] let timew_bw = time_w.submatrix(1).unwrap();
79 Self {
80 dic: lstm.dic.as_borrowed(),
81 embedding: MatrixZero::from(&lstm.embedding),
82 fw_w: MatrixZero::from(&lstm.fw_w),
83 fw_u: MatrixZero::from(&lstm.fw_u),
84 fw_b: MatrixZero::from(&lstm.fw_b),
85 bw_w: MatrixZero::from(&lstm.bw_w),
86 bw_u: MatrixZero::from(&lstm.bw_u),
87 bw_b: MatrixZero::from(&lstm.bw_b),
88 timew_fw,
89 timew_bw,
90 time_b: MatrixZero::from(&lstm.time_b),
91 grapheme: (lstm.model == ModelType::GraphemeClusters).then_some(grapheme),
92 }
93 }
94
95 pub(super) fn segment_str(&'l self, input: &'l str) -> impl Iterator<Item = usize> + 'l {
97 self.segment_str_p(input)
98 }
99
100 fn segment_str_p(&'l self, input: &'l str) -> LstmSegmenterIterator<'l> {
102 let input_seq = if let Some(grapheme) = self.grapheme {
103 GraphemeClusterSegmenter::new_and_segment_str(input, grapheme)
104 .collect::<Vec<usize>>()
105 .windows(2)
106 .map(|chunk| {
107 let range = if let [first, second, ..] = chunk {
108 *first..*second
109 } else {
110 unreachable!()
111 };
112 let grapheme_cluster = if let Some(grapheme_cluster) = input.get(range) {
113 grapheme_cluster
114 } else {
115 return self.dic.len() as u16;
116 };
117
118 self.dic
119 .get_copied(UnvalidatedStr::from_str(grapheme_cluster))
120 .unwrap_or_else(|| self.dic.len() as u16)
121 })
122 .collect()
123 } else {
124 input
125 .chars()
126 .map(|c| {
127 self.dic
128 .get_copied(UnvalidatedStr::from_str(c.encode_utf8(&mut [0; 4])))
129 .unwrap_or_else(|| self.dic.len() as u16)
130 })
131 .collect()
132 };
133 LstmSegmenterIterator {
134 input,
135 pos_utf8: 0,
136 bies: BiesIterator::new(self, input_seq),
137 }
138 }
139
140 pub(super) fn segment_utf16(&'l self, input: &[u16]) -> impl Iterator<Item = usize> + 'l {
142 let input_seq = if let Some(grapheme) = self.grapheme {
143 GraphemeClusterSegmenter::new_and_segment_utf16(input, grapheme)
144 .collect::<Vec<usize>>()
145 .windows(2)
146 .map(|chunk| {
147 let range = if let [first, second, ..] = chunk {
148 *first..*second
149 } else {
150 unreachable!()
151 };
152 let grapheme_cluster = if let Some(grapheme_cluster) = input.get(range) {
153 grapheme_cluster
154 } else {
155 return self.dic.len() as u16;
156 };
157
158 self.dic
159 .get_copied_by(|key| {
160 key.as_bytes().iter().copied().cmp(
161 decode_utf16(grapheme_cluster.iter().copied()).flat_map(|c| {
162 let mut buf = [0; 4];
163 let len = c
164 .unwrap_or(REPLACEMENT_CHARACTER)
165 .encode_utf8(&mut buf)
166 .len();
167 buf.into_iter().take(len)
168 }),
169 )
170 })
171 .unwrap_or_else(|| self.dic.len() as u16)
172 })
173 .collect()
174 } else {
175 decode_utf16(input.iter().copied())
176 .map(|c| c.unwrap_or(REPLACEMENT_CHARACTER))
177 .map(|c| {
178 self.dic
179 .get_copied(UnvalidatedStr::from_str(c.encode_utf8(&mut [0; 4])))
180 .unwrap_or_else(|| self.dic.len() as u16)
181 })
182 .collect()
183 };
184 LstmSegmenterIteratorUtf16 {
185 bies: BiesIterator::new(self, input_seq),
186 pos: 0,
187 }
188 }
189}
190
191struct BiesIterator<'l> {
192 segmenter: &'l LstmSegmenter<'l>,
193 input_seq: core::iter::Enumerate<alloc::vec::IntoIter<u16>>,
194 h_bw: MatrixOwned<2>,
195 curr_fw: MatrixOwned<1>,
196 c_fw: MatrixOwned<1>,
197}
198
199impl<'l> BiesIterator<'l> {
200 fn new(segmenter: &'l LstmSegmenter<'l>, input_seq: Vec<u16>) -> Self {
203 let hunits = segmenter.fw_u.dim().1;
204
205 let mut c_bw = MatrixOwned::<1>::new_zero([hunits]);
207 let mut h_bw = MatrixOwned::<2>::new_zero([input_seq.len(), hunits]);
208 for (i, &g_id) in input_seq.iter().enumerate().rev() {
209 if i + 1 < input_seq.len() {
210 h_bw.as_mut().copy_submatrix::<1>(i + 1, i);
211 }
212 #[allow(clippy::unwrap_used)]
213 compute_hc(
214 segmenter.embedding.submatrix::<1>(g_id as usize).unwrap(), h_bw.submatrix_mut(i).unwrap(), c_bw.as_mut(),
217 segmenter.bw_w,
218 segmenter.bw_u,
219 segmenter.bw_b,
220 );
221 }
222
223 Self {
224 input_seq: input_seq.into_iter().enumerate(),
225 h_bw,
226 c_fw: MatrixOwned::<1>::new_zero([hunits]),
227 curr_fw: MatrixOwned::<1>::new_zero([hunits]),
228 segmenter,
229 }
230 }
231}
232
233impl ExactSizeIterator for BiesIterator<'_> {
234 fn len(&self) -> usize {
235 self.input_seq.len()
236 }
237}
238
239impl Iterator for BiesIterator<'_> {
240 type Item = bool;
241
242 fn next(&mut self) -> Option<Self::Item> {
243 let (i, g_id) = self.input_seq.next()?;
244
245 #[allow(clippy::unwrap_used)]
246 compute_hc(
247 self.segmenter
248 .embedding
249 .submatrix::<1>(g_id as usize)
250 .unwrap(), self.curr_fw.as_mut(),
252 self.c_fw.as_mut(),
253 self.segmenter.fw_w,
254 self.segmenter.fw_u,
255 self.segmenter.fw_b,
256 );
257
258 #[allow(clippy::unwrap_used)] let curr_bw = self.h_bw.submatrix::<1>(i).unwrap();
260 let mut weights = [0.0; 4];
261 let mut curr_est = MatrixBorrowedMut {
262 data: &mut weights,
263 dims: [4],
264 };
265 curr_est.add_dot_2d(self.curr_fw.as_borrowed(), self.segmenter.timew_fw);
266 curr_est.add_dot_2d(curr_bw, self.segmenter.timew_bw);
267 #[allow(clippy::unwrap_used)] curr_est.add(self.segmenter.time_b).unwrap();
269 Some(weights[2] > weights[0] && weights[2] > weights[1] && weights[2] > weights[3])
273 }
274}
275
276fn compute_hc<'a>(
278 x_t: MatrixZero<'a, 1>,
279 mut h_tm1: MatrixBorrowedMut<'a, 1>,
280 mut c_tm1: MatrixBorrowedMut<'a, 1>,
281 w: MatrixZero<'a, 3>,
282 u: MatrixZero<'a, 3>,
283 b: MatrixZero<'a, 2>,
284) {
285 #[cfg(debug_assertions)]
286 {
287 let hunits = h_tm1.dim();
288 let embedd_dim = x_t.dim();
289 c_tm1.as_borrowed().debug_assert_dims([hunits]);
290 w.debug_assert_dims([4, hunits, embedd_dim]);
291 u.debug_assert_dims([4, hunits, hunits]);
292 b.debug_assert_dims([4, hunits]);
293 }
294
295 let mut s_t = b.to_owned();
296
297 s_t.as_mut().add_dot_3d_2(x_t, w);
298 s_t.as_mut().add_dot_3d_1(h_tm1.as_borrowed(), u);
299
300 #[allow(clippy::unwrap_used)] s_t.submatrix_mut::<1>(0).unwrap().sigmoid_transform();
302 #[allow(clippy::unwrap_used)] s_t.submatrix_mut::<1>(1).unwrap().sigmoid_transform();
304 #[allow(clippy::unwrap_used)] s_t.submatrix_mut::<1>(2).unwrap().tanh_transform();
306 #[allow(clippy::unwrap_used)] s_t.submatrix_mut::<1>(3).unwrap().sigmoid_transform();
308
309 #[allow(clippy::unwrap_used)] c_tm1.convolve(
311 s_t.as_borrowed().submatrix(0).unwrap(),
312 s_t.as_borrowed().submatrix(2).unwrap(),
313 s_t.as_borrowed().submatrix(1).unwrap(),
314 );
315
316 #[allow(clippy::unwrap_used)] h_tm1.mul_tanh(s_t.as_borrowed().submatrix(3).unwrap(), c_tm1.as_borrowed());
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use icu_locid::langid;
    use icu_provider::prelude::*;
    use serde::Deserialize;

    /// One entry of the JSON test data.
    #[derive(PartialEq, Debug, Deserialize)]
    struct TestCase {
        /// Text to feed to the segmenter.
        unseg: String,
        /// Labels the model is expected to produce for `unseg`.
        expected_bies: String,
        /// Reference labels; only printed, never asserted on.
        true_bies: String,
    }

    /// Top-level shape of the JSON test data.
    #[derive(PartialEq, Debug, Deserialize)]
    struct TestTextData {
        testcases: Vec<TestCase>,
    }

    /// Wrapper around the parsed test data.
    #[derive(Debug)]
    struct TestText {
        data: TestTextData,
    }

    /// Checks, for every test case, that the items the model marks as segment
    /// ends are exactly the `e` positions of `expected_bies`.
    #[test]
    fn segment_file_by_lstm() {
        let payload: DataPayload<LstmForWordLineAutoV1Marker> = crate::provider::Baked
            .load(DataRequest {
                locale: &langid!("th").into(),
                metadata: Default::default(),
            })
            .unwrap()
            .take_payload()
            .unwrap();
        let segmenter = LstmSegmenter::new(
            payload.get(),
            crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1,
        );

        // The data file must match the kind of item the model was trained on.
        let json = if segmenter.grapheme.is_some() {
            include_str!("../../../tests/testdata/test_text_graphclust.json")
        } else {
            include_str!("../../../tests/testdata/test_text_codepoints.json")
        };
        let test_text = TestText {
            data: serde_json::from_str(json).expect("JSON syntax error"),
        };

        for test_case in &test_text.data.testcases {
            // Only "end" vs. "not end" is observable, so every other label is
            // rendered as '?'.
            let lstm_output: String = segmenter
                .segment_str_p(&test_case.unseg)
                .bies
                .map(|is_e| if is_e { 'e' } else { '?' })
                .collect();

            println!("Test case : {}", test_case.unseg);
            println!("Expected bies : {}", test_case.expected_bies);
            println!("Estimated bies : {lstm_output}");
            println!("True bies : {}", test_case.true_bies);
            println!("****************************************************");

            let expected = test_case.expected_bies.replace(['b', 'i', 's'], "?");
            assert_eq!(expected, lstm_output);
        }
    }
}