icu_segmenter/complex/lstm/
mod.rs

1// This file is part of ICU4X. For terms of use, please see the file
2// called LICENSE at the top level of the ICU4X source tree
3// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).
4
5use crate::grapheme::GraphemeClusterSegmenter;
6use crate::provider::*;
7use alloc::vec::Vec;
8use core::char::{decode_utf16, REPLACEMENT_CHARACTER};
9use zerovec::{maps::ZeroMapBorrowed, ule::UnvalidatedStr};
10
11mod matrix;
12use matrix::*;
13
// A word break iterator using an LSTM model. The whole input string has to be in the same language.
15
/// Iterator over break positions in a `str`, expressed as UTF-8 byte offsets.
struct LstmSegmenterIterator<'s> {
    /// The text being segmented.
    input: &'s str,
    /// Byte offset into `input` just past the code points consumed so far.
    pos_utf8: usize,
    /// Per-element predictions; `true` means the element ends a segment.
    bies: BiesIterator<'s>,
}
21
22impl Iterator for LstmSegmenterIterator<'_> {
23    type Item = usize;
24
25    fn next(&mut self) -> Option<Self::Item> {
26        #[allow(clippy::indexing_slicing)] // pos_utf8 in range
27        loop {
28            let is_e = self.bies.next()?;
29            self.pos_utf8 += self.input[self.pos_utf8..].chars().next()?.len_utf8();
30            if is_e || self.bies.len() == 0 {
31                return Some(self.pos_utf8);
32            }
33        }
34    }
35}
36
/// Iterator over break positions for UTF-16 input.
struct LstmSegmenterIteratorUtf16<'s> {
    /// Per-element predictions; `true` means the element ends a segment.
    bies: BiesIterator<'s>,
    /// Counter bumped once per step of `bies`; its value is yielded as the break position.
    pos: usize,
}
41
42impl Iterator for LstmSegmenterIteratorUtf16<'_> {
43    type Item = usize;
44
45    fn next(&mut self) -> Option<Self::Item> {
46        loop {
47            self.pos += 1;
48            if self.bies.next()? || self.bies.len() == 0 {
49                return Some(self.pos);
50            }
51        }
52    }
53}
54
/// Borrowed view of an LSTM segmentation model: a dictionary, an embedding
/// matrix, forward and backward LSTM weights, and the output layer.
pub(super) struct LstmSegmenter<'l> {
    /// Maps an input element (code point or grapheme cluster, as a string) to its id.
    dic: ZeroMapBorrowed<'l, UnvalidatedStr, u16>,
    /// Embedding matrix; the row selected by an id is the input vector of the LSTM step.
    embedding: MatrixZero<'l, 2>,
    /// Forward LSTM weights `w`, `u` and bias `b`, as passed to `compute_hc`.
    fw_w: MatrixZero<'l, 3>,
    fw_u: MatrixZero<'l, 3>,
    fw_b: MatrixZero<'l, 2>,
    /// Backward LSTM weights `w`, `u` and bias `b`, as passed to `compute_hc`.
    bw_w: MatrixZero<'l, 3>,
    bw_u: MatrixZero<'l, 3>,
    bw_b: MatrixZero<'l, 2>,
    /// Output-layer weights applied to the forward state (submatrix 0 of `time_w`).
    timew_fw: MatrixZero<'l, 2>,
    /// Output-layer weights applied to the backward state (submatrix 1 of `time_w`).
    timew_bw: MatrixZero<'l, 2>,
    /// Output-layer bias, added to the four class weights.
    time_b: MatrixZero<'l, 1>,
    /// Grapheme break data; `Some` only when the model type is `ModelType::GraphemeClusters`.
    grapheme: Option<&'l RuleBreakDataV1<'l>>,
}
69
70impl<'l> LstmSegmenter<'l> {
71    /// Returns `Err` if grapheme data is required but not present
72    pub(super) fn new(lstm: &'l LstmDataV1<'l>, grapheme: &'l RuleBreakDataV1<'l>) -> Self {
73        let LstmDataV1::Float32(lstm) = lstm;
74        let time_w = MatrixZero::from(&lstm.time_w);
75        #[allow(clippy::unwrap_used)] // shape (2, 4, hunits)
76        let timew_fw = time_w.submatrix(0).unwrap();
77        #[allow(clippy::unwrap_used)] // shape (2, 4, hunits)
78        let timew_bw = time_w.submatrix(1).unwrap();
79        Self {
80            dic: lstm.dic.as_borrowed(),
81            embedding: MatrixZero::from(&lstm.embedding),
82            fw_w: MatrixZero::from(&lstm.fw_w),
83            fw_u: MatrixZero::from(&lstm.fw_u),
84            fw_b: MatrixZero::from(&lstm.fw_b),
85            bw_w: MatrixZero::from(&lstm.bw_w),
86            bw_u: MatrixZero::from(&lstm.bw_u),
87            bw_b: MatrixZero::from(&lstm.bw_b),
88            timew_fw,
89            timew_bw,
90            time_b: MatrixZero::from(&lstm.time_b),
91            grapheme: (lstm.model == ModelType::GraphemeClusters).then_some(grapheme),
92        }
93    }
94
95    /// Create an LSTM based break iterator for an `str` (a UTF-8 string).
96    pub(super) fn segment_str(&'l self, input: &'l str) -> impl Iterator<Item = usize> + 'l {
97        self.segment_str_p(input)
98    }
99
100    // For unit testing as we cannot inspect the opaque type's bies
101    fn segment_str_p(&'l self, input: &'l str) -> LstmSegmenterIterator<'l> {
102        let input_seq = if let Some(grapheme) = self.grapheme {
103            GraphemeClusterSegmenter::new_and_segment_str(input, grapheme)
104                .collect::<Vec<usize>>()
105                .windows(2)
106                .map(|chunk| {
107                    let range = if let [first, second, ..] = chunk {
108                        *first..*second
109                    } else {
110                        unreachable!()
111                    };
112                    let grapheme_cluster = if let Some(grapheme_cluster) = input.get(range) {
113                        grapheme_cluster
114                    } else {
115                        return self.dic.len() as u16;
116                    };
117
118                    self.dic
119                        .get_copied(UnvalidatedStr::from_str(grapheme_cluster))
120                        .unwrap_or_else(|| self.dic.len() as u16)
121                })
122                .collect()
123        } else {
124            input
125                .chars()
126                .map(|c| {
127                    self.dic
128                        .get_copied(UnvalidatedStr::from_str(c.encode_utf8(&mut [0; 4])))
129                        .unwrap_or_else(|| self.dic.len() as u16)
130                })
131                .collect()
132        };
133        LstmSegmenterIterator {
134            input,
135            pos_utf8: 0,
136            bies: BiesIterator::new(self, input_seq),
137        }
138    }
139
140    /// Create an LSTM based break iterator for a UTF-16 string.
141    pub(super) fn segment_utf16(&'l self, input: &[u16]) -> impl Iterator<Item = usize> + 'l {
142        let input_seq = if let Some(grapheme) = self.grapheme {
143            GraphemeClusterSegmenter::new_and_segment_utf16(input, grapheme)
144                .collect::<Vec<usize>>()
145                .windows(2)
146                .map(|chunk| {
147                    let range = if let [first, second, ..] = chunk {
148                        *first..*second
149                    } else {
150                        unreachable!()
151                    };
152                    let grapheme_cluster = if let Some(grapheme_cluster) = input.get(range) {
153                        grapheme_cluster
154                    } else {
155                        return self.dic.len() as u16;
156                    };
157
158                    self.dic
159                        .get_copied_by(|key| {
160                            key.as_bytes().iter().copied().cmp(
161                                decode_utf16(grapheme_cluster.iter().copied()).flat_map(|c| {
162                                    let mut buf = [0; 4];
163                                    let len = c
164                                        .unwrap_or(REPLACEMENT_CHARACTER)
165                                        .encode_utf8(&mut buf)
166                                        .len();
167                                    buf.into_iter().take(len)
168                                }),
169                            )
170                        })
171                        .unwrap_or_else(|| self.dic.len() as u16)
172                })
173                .collect()
174        } else {
175            decode_utf16(input.iter().copied())
176                .map(|c| c.unwrap_or(REPLACEMENT_CHARACTER))
177                .map(|c| {
178                    self.dic
179                        .get_copied(UnvalidatedStr::from_str(c.encode_utf8(&mut [0; 4])))
180                        .unwrap_or_else(|| self.dic.len() as u16)
181                })
182                .collect()
183        };
184        LstmSegmenterIteratorUtf16 {
185            bies: BiesIterator::new(self, input_seq),
186            pos: 0,
187        }
188    }
189}
190
/// Iterator that yields, for each input id, whether the model predicts that
/// the element ends a segment.
struct BiesIterator<'l> {
    /// Model whose weights are used for the forward pass and the output layer.
    segmenter: &'l LstmSegmenter<'l>,
    /// Ids not yet consumed, each paired with its index in the sequence.
    input_seq: core::iter::Enumerate<alloc::vec::IntoIter<u16>>,
    /// Backward-LSTM hidden state for every position, precomputed in `new`.
    h_bw: MatrixOwned<2>,
    /// Forward-LSTM hidden state, updated on every call to `next`.
    curr_fw: MatrixOwned<1>,
    /// Forward-LSTM cell state, updated on every call to `next`.
    c_fw: MatrixOwned<1>,
}
198
199impl<'l> BiesIterator<'l> {
200    // input_seq is a sequence of id numbers that represents grapheme clusters or code points in the input line. These ids are used later
201    // in the embedding layer of the model.
202    fn new(segmenter: &'l LstmSegmenter<'l>, input_seq: Vec<u16>) -> Self {
203        let hunits = segmenter.fw_u.dim().1;
204
205        // Backward LSTM
206        let mut c_bw = MatrixOwned::<1>::new_zero([hunits]);
207        let mut h_bw = MatrixOwned::<2>::new_zero([input_seq.len(), hunits]);
208        for (i, &g_id) in input_seq.iter().enumerate().rev() {
209            if i + 1 < input_seq.len() {
210                h_bw.as_mut().copy_submatrix::<1>(i + 1, i);
211            }
212            #[allow(clippy::unwrap_used)]
213            compute_hc(
214                segmenter.embedding.submatrix::<1>(g_id as usize).unwrap(), /* shape (dict.len() + 1, hunit), g_id is at most dict.len() */
215                h_bw.submatrix_mut(i).unwrap(), // shape (input_seq.len(), hunits)
216                c_bw.as_mut(),
217                segmenter.bw_w,
218                segmenter.bw_u,
219                segmenter.bw_b,
220            );
221        }
222
223        Self {
224            input_seq: input_seq.into_iter().enumerate(),
225            h_bw,
226            c_fw: MatrixOwned::<1>::new_zero([hunits]),
227            curr_fw: MatrixOwned::<1>::new_zero([hunits]),
228            segmenter,
229        }
230    }
231}
232
233impl ExactSizeIterator for BiesIterator<'_> {
234    fn len(&self) -> usize {
235        self.input_seq.len()
236    }
237}
238
239impl Iterator for BiesIterator<'_> {
240    type Item = bool;
241
242    fn next(&mut self) -> Option<Self::Item> {
243        let (i, g_id) = self.input_seq.next()?;
244
245        #[allow(clippy::unwrap_used)]
246        compute_hc(
247            self.segmenter
248                .embedding
249                .submatrix::<1>(g_id as usize)
250                .unwrap(), // shape (dict.len() + 1, hunit), g_id is at most dict.len()
251            self.curr_fw.as_mut(),
252            self.c_fw.as_mut(),
253            self.segmenter.fw_w,
254            self.segmenter.fw_u,
255            self.segmenter.fw_b,
256        );
257
258        #[allow(clippy::unwrap_used)] // shape (input_seq.len(), hunits)
259        let curr_bw = self.h_bw.submatrix::<1>(i).unwrap();
260        let mut weights = [0.0; 4];
261        let mut curr_est = MatrixBorrowedMut {
262            data: &mut weights,
263            dims: [4],
264        };
265        curr_est.add_dot_2d(self.curr_fw.as_borrowed(), self.segmenter.timew_fw);
266        curr_est.add_dot_2d(curr_bw, self.segmenter.timew_bw);
267        #[allow(clippy::unwrap_used)] // both shape (4)
268        curr_est.add(self.segmenter.time_b).unwrap();
269        // For correct BIES weight calculation we'd now have to apply softmax, however
270        // we're only doing a naive argmax, so a monotonic function doesn't make a difference.
271
272        Some(weights[2] > weights[0] && weights[2] > weights[1] && weights[2] > weights[3])
273    }
274}
275
276/// `compute_hc1` implemens the evaluation of one LSTM layer.
277fn compute_hc<'a>(
278    x_t: MatrixZero<'a, 1>,
279    mut h_tm1: MatrixBorrowedMut<'a, 1>,
280    mut c_tm1: MatrixBorrowedMut<'a, 1>,
281    w: MatrixZero<'a, 3>,
282    u: MatrixZero<'a, 3>,
283    b: MatrixZero<'a, 2>,
284) {
285    #[cfg(debug_assertions)]
286    {
287        let hunits = h_tm1.dim();
288        let embedd_dim = x_t.dim();
289        c_tm1.as_borrowed().debug_assert_dims([hunits]);
290        w.debug_assert_dims([4, hunits, embedd_dim]);
291        u.debug_assert_dims([4, hunits, hunits]);
292        b.debug_assert_dims([4, hunits]);
293    }
294
295    let mut s_t = b.to_owned();
296
297    s_t.as_mut().add_dot_3d_2(x_t, w);
298    s_t.as_mut().add_dot_3d_1(h_tm1.as_borrowed(), u);
299
300    #[allow(clippy::unwrap_used)] // first dimension is 4
301    s_t.submatrix_mut::<1>(0).unwrap().sigmoid_transform();
302    #[allow(clippy::unwrap_used)] // first dimension is 4
303    s_t.submatrix_mut::<1>(1).unwrap().sigmoid_transform();
304    #[allow(clippy::unwrap_used)] // first dimension is 4
305    s_t.submatrix_mut::<1>(2).unwrap().tanh_transform();
306    #[allow(clippy::unwrap_used)] // first dimension is 4
307    s_t.submatrix_mut::<1>(3).unwrap().sigmoid_transform();
308
309    #[allow(clippy::unwrap_used)] // first dimension is 4
310    c_tm1.convolve(
311        s_t.as_borrowed().submatrix(0).unwrap(),
312        s_t.as_borrowed().submatrix(2).unwrap(),
313        s_t.as_borrowed().submatrix(1).unwrap(),
314    );
315
316    #[allow(clippy::unwrap_used)] // first dimension is 4
317    h_tm1.mul_tanh(s_t.as_borrowed().submatrix(3).unwrap(), c_tm1.as_borrowed());
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use icu_locid::langid;
    use icu_provider::prelude::*;
    use serde::Deserialize;

    /// A single entry of the test fixture.
    ///
    /// `unseg` is the unsegmented line, `expected_bies` is the BIES sequence the
    /// model is expected to produce (this is what the test asserts on), and
    /// `true_bies` is the BIES sequence of the true segmentation (printed only).
    #[derive(PartialEq, Debug, Deserialize)]
    struct TestCase {
        unseg: String,
        expected_bies: String,
        true_bies: String,
    }

    /// The deserialized fixture file: a list of [`TestCase`]s.
    #[derive(PartialEq, Debug, Deserialize)]
    struct TestTextData {
        testcases: Vec<TestCase>,
    }

    /// Wrapper around the fixture data of one test text.
    #[derive(Debug)]
    struct TestText {
        data: TestTextData,
    }

    #[test]
    fn segment_file_by_lstm() {
        // Load the Thai model and build the segmenter.
        let payload: DataPayload<LstmForWordLineAutoV1Marker> = crate::provider::Baked
            .load(DataRequest {
                locale: &langid!("th").into(),
                metadata: Default::default(),
            })
            .unwrap()
            .take_payload()
            .unwrap();
        let segmenter = LstmSegmenter::new(
            payload.get(),
            crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1,
        );

        // Pick the fixture that matches the granularity the model works at.
        let fixture = if segmenter.grapheme.is_some() {
            include_str!("../../../tests/testdata/test_text_graphclust.json")
        } else {
            include_str!("../../../tests/testdata/test_text_codepoints.json")
        };
        let test_text = TestText {
            data: serde_json::from_str(fixture).expect("JSON syntax error"),
        };

        for test_case in &test_text.data.testcases {
            // Only 'e' is observable from the iterator; everything else becomes '?'.
            let mut lstm_output = String::new();
            for is_e in segmenter.segment_str_p(&test_case.unseg).bies {
                lstm_output.push(if is_e { 'e' } else { '?' });
            }
            println!("Test case      : {}", test_case.unseg);
            println!("Expected bies  : {}", test_case.expected_bies);
            println!("Estimated bies : {lstm_output}");
            println!("True bies      : {}", test_case.true_bies);
            println!("****************************************************");
            let expected = test_case.expected_bies.replace(['b', 'i', 's'], "?");
            assert_eq!(expected, lstm_output);
        }
    }
}