icu_segmenter/complex/lstm/matrix.rs

1// This file is part of ICU4X. For terms of use, please see the file
2// called LICENSE at the top level of the ICU4X source tree
3// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).
4
5use alloc::vec;
6use alloc::vec::Vec;
7use core::ops::Range;
8#[allow(unused_imports)]
9use core_maths::*;
10use zerovec::ule::AsULE;
11use zerovec::ZeroSlice;
12
13/// A `D`-dimensional, heap-allocated matrix.
14///
15/// This matrix implementation supports slicing matrices into tightly-packed
16/// submatrices. For example, indexing into a matrix of size 5x4x3 returns a
17/// matrix of size 4x3. For more information, see [`MatrixOwned::submatrix`].
18#[derive(Debug, Clone)]
19pub(super) struct MatrixOwned<const D: usize> {
20    data: Vec<f32>,
21    dims: [usize; D],
22}
23
24impl<const D: usize> MatrixOwned<D> {
25    pub(super) fn as_borrowed(&self) -> MatrixBorrowed<D> {
26        MatrixBorrowed {
27            data: &self.data,
28            dims: self.dims,
29        }
30    }
31
32    pub(super) fn new_zero(dims: [usize; D]) -> Self {
33        let total_len = dims.iter().product::<usize>();
34        MatrixOwned {
35            data: vec![0.0; total_len],
36            dims,
37        }
38    }
39
40    /// Returns the tightly packed submatrix at _index_, or `None` if _index_ is out of range.
41    ///
42    /// For example, if the matrix is 5x4x3, this function returns a matrix sized 4x3. If the
43    /// matrix is 4x3, then this function returns a linear matrix of length 3.
44    ///
45    /// The type parameter `M` should be `D - 1`.
46    #[inline]
47    pub(super) fn submatrix<const M: usize>(&self, index: usize) -> Option<MatrixBorrowed<M>> {
48        // This assertion is based on const generics; it should always succeed and be elided.
49        assert_eq!(M, D - 1);
50        let (range, dims) = self.as_borrowed().submatrix_range(index);
51        let data = &self.data.get(range)?;
52        Some(MatrixBorrowed { data, dims })
53    }
54
55    pub(super) fn as_mut(&mut self) -> MatrixBorrowedMut<D> {
56        MatrixBorrowedMut {
57            data: &mut self.data,
58            dims: self.dims,
59        }
60    }
61
62    /// A mutable version of [`Self::submatrix`].
63    #[inline]
64    pub(super) fn submatrix_mut<const M: usize>(
65        &mut self,
66        index: usize,
67    ) -> Option<MatrixBorrowedMut<M>> {
68        // This assertion is based on const generics; it should always succeed and be elided.
69        assert_eq!(M, D - 1);
70        let (range, dims) = self.as_borrowed().submatrix_range(index);
71        let data = self.data.get_mut(range)?;
72        Some(MatrixBorrowedMut { data, dims })
73    }
74}
75
76/// A `D`-dimensional, borrowed matrix.
77#[derive(Debug, Clone, Copy)]
78pub(super) struct MatrixBorrowed<'a, const D: usize> {
79    data: &'a [f32],
80    dims: [usize; D],
81}
82
83impl<'a, const D: usize> MatrixBorrowed<'a, D> {
84    #[cfg(debug_assertions)]
85    pub(super) fn debug_assert_dims(&self, dims: [usize; D]) {
86        debug_assert_eq!(dims, self.dims);
87        let expected_len = dims.iter().product::<usize>();
88        debug_assert_eq!(expected_len, self.data.len());
89    }
90
91    pub(super) fn as_slice(&self) -> &'a [f32] {
92        self.data
93    }
94
95    /// See [`MatrixOwned::submatrix`].
96    #[inline]
97    pub(super) fn submatrix<const M: usize>(&self, index: usize) -> Option<MatrixBorrowed<'a, M>> {
98        // This assertion is based on const generics; it should always succeed and be elided.
99        assert_eq!(M, D - 1);
100        let (range, dims) = self.submatrix_range(index);
101        let data = &self.data.get(range)?;
102        Some(MatrixBorrowed { data, dims })
103    }
104
105    #[inline]
106    fn submatrix_range<const M: usize>(&self, index: usize) -> (Range<usize>, [usize; M]) {
107        // This assertion is based on const generics; it should always succeed and be elided.
108        assert_eq!(M, D - 1);
109        // The above assertion guarantees that the following line will succeed
110        #[allow(clippy::indexing_slicing, clippy::unwrap_used)]
111        let sub_dims: [usize; M] = self.dims[1..].try_into().unwrap();
112        let n = sub_dims.iter().product::<usize>();
113        (n * index..n * (index + 1), sub_dims)
114    }
115}
116
// Implements a `dim()` accessor on the 1-, 2-, and 3-dimensional instantiations of a
// matrix type. Rank 1 returns a scalar length; ranks 2 and 3 return tuples of extents,
// outermost first.
macro_rules! impl_basic_dim {
    ($one:path, $two:path, $three:path) => {
        impl<'a> $one {
            /// Returns the length of this 1-dimensional matrix.
            #[allow(dead_code)]
            pub(super) fn dim(&self) -> usize {
                let [len] = self.dims;
                len
            }
        }
        impl<'a> $two {
            /// Returns the extents of this 2-dimensional matrix, outermost first.
            #[allow(dead_code)]
            pub(super) fn dim(&self) -> (usize, usize) {
                let [outer, inner] = self.dims;
                (outer, inner)
            }
        }
        impl<'a> $three {
            /// Returns the extents of this 3-dimensional matrix, outermost first.
            #[allow(dead_code)]
            pub(super) fn dim(&self) -> (usize, usize, usize) {
                let [outer, middle, inner] = self.dims;
                (outer, middle, inner)
            }
        }
    };
}
142
143impl_basic_dim!(MatrixOwned<1>, MatrixOwned<2>, MatrixOwned<3>);
144impl_basic_dim!(
145    MatrixBorrowed<'a, 1>,
146    MatrixBorrowed<'a, 2>,
147    MatrixBorrowed<'a, 3>
148);
149impl_basic_dim!(
150    MatrixBorrowedMut<'a, 1>,
151    MatrixBorrowedMut<'a, 2>,
152    MatrixBorrowedMut<'a, 3>
153);
154impl_basic_dim!(MatrixZero<'a, 1>, MatrixZero<'a, 2>, MatrixZero<'a, 3>);
155
156/// A `D`-dimensional, mutably borrowed matrix.
157pub(super) struct MatrixBorrowedMut<'a, const D: usize> {
158    pub(super) data: &'a mut [f32],
159    pub(super) dims: [usize; D],
160}
161
162impl<'a, const D: usize> MatrixBorrowedMut<'a, D> {
163    pub(super) fn as_borrowed(&self) -> MatrixBorrowed<D> {
164        MatrixBorrowed {
165            data: self.data,
166            dims: self.dims,
167        }
168    }
169
170    pub(super) fn as_mut_slice(&mut self) -> &mut [f32] {
171        self.data
172    }
173
174    pub(super) fn copy_submatrix<const M: usize>(&mut self, from: usize, to: usize) {
175        let (range_from, _) = self.as_borrowed().submatrix_range::<M>(from);
176        let (range_to, _) = self.as_borrowed().submatrix_range::<M>(to);
177        if let (Some(_), Some(_)) = (
178            self.data.get(range_from.clone()),
179            self.data.get(range_to.clone()),
180        ) {
181            // This function is panicky, but we just validated the ranges
182            self.data.copy_within(range_from, range_to.start);
183        }
184    }
185
186    #[must_use]
187    pub(super) fn add(&mut self, other: MatrixZero<'_, D>) -> Option<()> {
188        debug_assert_eq!(self.dims, other.dims);
189        // TODO: Vectorize?
190        for i in 0..self.data.len() {
191            *self.data.get_mut(i)? += other.data.get(i)?;
192        }
193        Some(())
194    }
195
196    #[allow(dead_code)] // maybe needed for more complicated bies calculations
197    /// Mutates this matrix by applying a softmax transformation.
198    pub(super) fn softmax_transform(&mut self) {
199        for v in self.data.iter_mut() {
200            *v = v.exp();
201        }
202        let sm = 1.0 / self.data.iter().sum::<f32>();
203        for v in self.data.iter_mut() {
204            *v *= sm;
205        }
206    }
207
208    pub(super) fn sigmoid_transform(&mut self) {
209        for x in &mut self.data.iter_mut() {
210            *x = 1.0 / (1.0 + (-*x).exp());
211        }
212    }
213
214    pub(super) fn tanh_transform(&mut self) {
215        for x in &mut self.data.iter_mut() {
216            *x = x.tanh();
217        }
218    }
219
220    pub(super) fn convolve(
221        &mut self,
222        i: MatrixBorrowed<'_, D>,
223        c: MatrixBorrowed<'_, D>,
224        f: MatrixBorrowed<'_, D>,
225    ) {
226        let i = i.as_slice();
227        let c = c.as_slice();
228        let f = f.as_slice();
229        let len = self.data.len();
230        if len != i.len() || len != c.len() || len != f.len() {
231            debug_assert!(false, "LSTM matrices not the correct dimensions");
232            return;
233        }
234        for idx in 0..len {
235            // Safety: The lengths are all the same (checked above)
236            unsafe {
237                *self.data.get_unchecked_mut(idx) = i.get_unchecked(idx) * c.get_unchecked(idx)
238                    + self.data.get_unchecked(idx) * f.get_unchecked(idx)
239            }
240        }
241    }
242
243    pub(super) fn mul_tanh(&mut self, o: MatrixBorrowed<'_, D>, c: MatrixBorrowed<'_, D>) {
244        let o = o.as_slice();
245        let c = c.as_slice();
246        let len = self.data.len();
247        if len != o.len() || len != c.len() {
248            debug_assert!(false, "LSTM matrices not the correct dimensions");
249            return;
250        }
251        for idx in 0..len {
252            // Safety: The lengths are all the same (checked above)
253            unsafe {
254                *self.data.get_unchecked_mut(idx) =
255                    o.get_unchecked(idx) * c.get_unchecked(idx).tanh();
256            }
257        }
258    }
259}
260
261impl<'a> MatrixBorrowed<'a, 1> {
262    #[allow(dead_code)] // could be useful
263    pub(super) fn dot_1d(&self, other: MatrixZero<1>) -> f32 {
264        debug_assert_eq!(self.dims, other.dims);
265        unrolled_dot_1(self.data, other.data)
266    }
267}
268
269impl<'a> MatrixBorrowedMut<'a, 1> {
270    /// Calculate the dot product of a and b, adding the result to self.
271    ///
272    /// Note: For better dot product efficiency, if `b` is MxN, then `a` should be N;
273    /// this is the opposite of standard practice.
274    pub(super) fn add_dot_2d(&mut self, a: MatrixBorrowed<1>, b: MatrixZero<2>) {
275        let m = a.dim();
276        let n = self.as_borrowed().dim();
277        debug_assert_eq!(
278            m,
279            b.dim().1,
280            "dims: {:?}/{:?}/{:?}",
281            self.as_borrowed().dim(),
282            a.dim(),
283            b.dim()
284        );
285        debug_assert_eq!(
286            n,
287            b.dim().0,
288            "dims: {:?}/{:?}/{:?}",
289            self.as_borrowed().dim(),
290            a.dim(),
291            b.dim()
292        );
293        for i in 0..n {
294            if let (Some(dest), Some(b_sub)) = (self.as_mut_slice().get_mut(i), b.submatrix::<1>(i))
295            {
296                *dest += unrolled_dot_1(a.data, b_sub.data);
297            } else {
298                debug_assert!(false, "unreachable: dims checked above");
299            }
300        }
301    }
302}
303
304impl<'a> MatrixBorrowedMut<'a, 2> {
305    /// Calculate the dot product of a and b, adding the result to self.
306    ///
307    /// Self should be _MxN_; `a`, _O_; and `b`, _MxNxO_.
308    pub(super) fn add_dot_3d_1(&mut self, a: MatrixBorrowed<1>, b: MatrixZero<3>) {
309        let m = a.dim();
310        let n = self.as_borrowed().dim().0 * self.as_borrowed().dim().1;
311        debug_assert_eq!(
312            m,
313            b.dim().2,
314            "dims: {:?}/{:?}/{:?}",
315            self.as_borrowed().dim(),
316            a.dim(),
317            b.dim()
318        );
319        debug_assert_eq!(
320            n,
321            b.dim().0 * b.dim().1,
322            "dims: {:?}/{:?}/{:?}",
323            self.as_borrowed().dim(),
324            a.dim(),
325            b.dim()
326        );
327        // Note: The following two loops are equivalent, but the second has more opportunity for
328        // vectorization since it allows the vectorization to span submatrices.
329        // for i in 0..b.dim().0 {
330        //     self.submatrix_mut::<1>(i).add_dot_2d(a, b.submatrix(i));
331        // }
332        let lhs = a.as_slice();
333        for i in 0..n {
334            if let (Some(dest), Some(rhs)) = (
335                self.as_mut_slice().get_mut(i),
336                b.as_slice().get_subslice(i * m..(i + 1) * m),
337            ) {
338                *dest += unrolled_dot_1(lhs, rhs);
339            } else {
340                debug_assert!(false, "unreachable: dims checked above");
341            }
342        }
343    }
344
345    /// Calculate the dot product of a and b, adding the result to self.
346    ///
347    /// Self should be _MxN_; `a`, _O_; and `b`, _MxNxO_.
348    pub(super) fn add_dot_3d_2(&mut self, a: MatrixZero<1>, b: MatrixZero<3>) {
349        let m = a.dim();
350        let n = self.as_borrowed().dim().0 * self.as_borrowed().dim().1;
351        debug_assert_eq!(
352            m,
353            b.dim().2,
354            "dims: {:?}/{:?}/{:?}",
355            self.as_borrowed().dim(),
356            a.dim(),
357            b.dim()
358        );
359        debug_assert_eq!(
360            n,
361            b.dim().0 * b.dim().1,
362            "dims: {:?}/{:?}/{:?}",
363            self.as_borrowed().dim(),
364            a.dim(),
365            b.dim()
366        );
367        // Note: The following two loops are equivalent, but the second has more opportunity for
368        // vectorization since it allows the vectorization to span submatrices.
369        // for i in 0..b.dim().0 {
370        //     self.submatrix_mut::<1>(i).add_dot_2d(a, b.submatrix(i));
371        // }
372        let lhs = a.as_slice();
373        for i in 0..n {
374            if let (Some(dest), Some(rhs)) = (
375                self.as_mut_slice().get_mut(i),
376                b.as_slice().get_subslice(i * m..(i + 1) * m),
377            ) {
378                *dest += unrolled_dot_2(lhs, rhs);
379            } else {
380                debug_assert!(false, "unreachable: dims checked above");
381            }
382        }
383    }
384}
385
386/// A `D`-dimensional matrix borrowed from a [`ZeroSlice`].
387#[derive(Debug, Clone, Copy)]
388pub(super) struct MatrixZero<'a, const D: usize> {
389    data: &'a ZeroSlice<f32>,
390    dims: [usize; D],
391}
392
393impl<'a> From<&'a crate::provider::LstmMatrix1<'a>> for MatrixZero<'a, 1> {
394    fn from(other: &'a crate::provider::LstmMatrix1<'a>) -> Self {
395        Self {
396            data: &other.data,
397            dims: other.dims.map(|x| x as usize),
398        }
399    }
400}
401
402impl<'a> From<&'a crate::provider::LstmMatrix2<'a>> for MatrixZero<'a, 2> {
403    fn from(other: &'a crate::provider::LstmMatrix2<'a>) -> Self {
404        Self {
405            data: &other.data,
406            dims: other.dims.map(|x| x as usize),
407        }
408    }
409}
410
411impl<'a> From<&'a crate::provider::LstmMatrix3<'a>> for MatrixZero<'a, 3> {
412    fn from(other: &'a crate::provider::LstmMatrix3<'a>) -> Self {
413        Self {
414            data: &other.data,
415            dims: other.dims.map(|x| x as usize),
416        }
417    }
418}
419
420impl<'a, const D: usize> MatrixZero<'a, D> {
421    #[allow(clippy::wrong_self_convention)] // same convention as slice::to_vec
422    pub(super) fn to_owned(&self) -> MatrixOwned<D> {
423        MatrixOwned {
424            data: self.data.iter().collect(),
425            dims: self.dims,
426        }
427    }
428
429    pub(super) fn as_slice(&self) -> &ZeroSlice<f32> {
430        self.data
431    }
432
433    #[cfg(debug_assertions)]
434    pub(super) fn debug_assert_dims(&self, dims: [usize; D]) {
435        debug_assert_eq!(dims, self.dims);
436        let expected_len = dims.iter().product::<usize>();
437        debug_assert_eq!(expected_len, self.data.len());
438    }
439
440    /// See [`MatrixOwned::submatrix`].
441    #[inline]
442    pub(super) fn submatrix<const M: usize>(&self, index: usize) -> Option<MatrixZero<'a, M>> {
443        // This assertion is based on const generics; it should always succeed and be elided.
444        assert_eq!(M, D - 1);
445        let (range, dims) = self.submatrix_range(index);
446        let data = &self.data.get_subslice(range)?;
447        Some(MatrixZero { data, dims })
448    }
449
450    #[inline]
451    fn submatrix_range<const M: usize>(&self, index: usize) -> (Range<usize>, [usize; M]) {
452        // This assertion is based on const generics; it should always succeed and be elided.
453        assert_eq!(M, D - 1);
454        // The above assertion guarantees that the following line will succeed
455        #[allow(clippy::indexing_slicing, clippy::unwrap_used)]
456        let sub_dims: [usize; M] = self.dims[1..].try_into().unwrap();
457        let n = sub_dims.iter().product::<usize>();
458        (n * index..n * (index + 1), sub_dims)
459    }
460}
461
// Decodes one unaligned `f32` ULE value into a native `f32`.
macro_rules! f32c {
    ($raw:expr) => {
        <f32 as AsULE>::from_unaligned($raw)
    };
}
467
468/// Compute the dot product of an aligned and an unaligned f32 slice.
469///
470/// `xs` and `ys` must be the same length
471///
472/// (Based on ndarray 0.15.6)
473fn unrolled_dot_1(xs: &[f32], ys: &ZeroSlice<f32>) -> f32 {
474    debug_assert_eq!(xs.len(), ys.len());
475    // eightfold unrolled so that floating point can be vectorized
476    // (even with strict floating point accuracy semantics)
477    let mut p = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0);
478    let xit = xs.chunks_exact(8);
479    let yit = ys.as_ule_slice().chunks_exact(8);
480    let sum = xit
481        .remainder()
482        .iter()
483        .zip(yit.remainder().iter())
484        .map(|(x, y)| x * f32c!(*y))
485        .sum::<f32>();
486    for (xx, yy) in xit.zip(yit) {
487        // TODO: Use array_chunks once stable to avoid the unwrap.
488        // <https://github.com/rust-lang/rust/issues/74985>
489        #[allow(clippy::unwrap_used)]
490        let [x0, x1, x2, x3, x4, x5, x6, x7] = *<&[f32; 8]>::try_from(xx).unwrap();
491        #[allow(clippy::unwrap_used)]
492        let [y0, y1, y2, y3, y4, y5, y6, y7] = *<&[<f32 as AsULE>::ULE; 8]>::try_from(yy).unwrap();
493        p.0 += x0 * f32c!(y0);
494        p.1 += x1 * f32c!(y1);
495        p.2 += x2 * f32c!(y2);
496        p.3 += x3 * f32c!(y3);
497        p.4 += x4 * f32c!(y4);
498        p.5 += x5 * f32c!(y5);
499        p.6 += x6 * f32c!(y6);
500        p.7 += x7 * f32c!(y7);
501    }
502    sum + (p.0 + p.4) + (p.1 + p.5) + (p.2 + p.6) + (p.3 + p.7)
503}
504
505/// Compute the dot product of two unaligned f32 slices.
506///
507/// `xs` and `ys` must be the same length
508///
509/// (Based on ndarray 0.15.6)
510fn unrolled_dot_2(xs: &ZeroSlice<f32>, ys: &ZeroSlice<f32>) -> f32 {
511    debug_assert_eq!(xs.len(), ys.len());
512    // eightfold unrolled so that floating point can be vectorized
513    // (even with strict floating point accuracy semantics)
514    let mut p = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0);
515    let xit = xs.as_ule_slice().chunks_exact(8);
516    let yit = ys.as_ule_slice().chunks_exact(8);
517    let sum = xit
518        .remainder()
519        .iter()
520        .zip(yit.remainder().iter())
521        .map(|(x, y)| f32c!(*x) * f32c!(*y))
522        .sum::<f32>();
523    for (xx, yy) in xit.zip(yit) {
524        // TODO: Use array_chunks once stable to avoid the unwrap.
525        // <https://github.com/rust-lang/rust/issues/74985>
526        #[allow(clippy::unwrap_used)]
527        let [x0, x1, x2, x3, x4, x5, x6, x7] = *<&[<f32 as AsULE>::ULE; 8]>::try_from(xx).unwrap();
528        #[allow(clippy::unwrap_used)]
529        let [y0, y1, y2, y3, y4, y5, y6, y7] = *<&[<f32 as AsULE>::ULE; 8]>::try_from(yy).unwrap();
530        p.0 += f32c!(x0) * f32c!(y0);
531        p.1 += f32c!(x1) * f32c!(y1);
532        p.2 += f32c!(x2) * f32c!(y2);
533        p.3 += f32c!(x3) * f32c!(y3);
534        p.4 += f32c!(x4) * f32c!(y4);
535        p.5 += f32c!(x5) * f32c!(y5);
536        p.6 += f32c!(x6) * f32c!(y6);
537        p.7 += f32c!(x7) * f32c!(y7);
538    }
539    sum + (p.0 + p.4) + (p.1 + p.5) + (p.2 + p.6) + (p.3 + p.7)
540}