1use alloc::vec;
6use alloc::vec::Vec;
7use core::ops::Range;
8#[allow(unused_imports)]
9use core_maths::*;
10use zerovec::ule::AsULE;
11use zerovec::ZeroSlice;
12
13#[derive(Debug, Clone)]
19pub(super) struct MatrixOwned<const D: usize> {
20 data: Vec<f32>,
21 dims: [usize; D],
22}
23
24impl<const D: usize> MatrixOwned<D> {
25 pub(super) fn as_borrowed(&self) -> MatrixBorrowed<D> {
26 MatrixBorrowed {
27 data: &self.data,
28 dims: self.dims,
29 }
30 }
31
32 pub(super) fn new_zero(dims: [usize; D]) -> Self {
33 let total_len = dims.iter().product::<usize>();
34 MatrixOwned {
35 data: vec![0.0; total_len],
36 dims,
37 }
38 }
39
40 #[inline]
47 pub(super) fn submatrix<const M: usize>(&self, index: usize) -> Option<MatrixBorrowed<M>> {
48 assert_eq!(M, D - 1);
50 let (range, dims) = self.as_borrowed().submatrix_range(index);
51 let data = &self.data.get(range)?;
52 Some(MatrixBorrowed { data, dims })
53 }
54
55 pub(super) fn as_mut(&mut self) -> MatrixBorrowedMut<D> {
56 MatrixBorrowedMut {
57 data: &mut self.data,
58 dims: self.dims,
59 }
60 }
61
62 #[inline]
64 pub(super) fn submatrix_mut<const M: usize>(
65 &mut self,
66 index: usize,
67 ) -> Option<MatrixBorrowedMut<M>> {
68 assert_eq!(M, D - 1);
70 let (range, dims) = self.as_borrowed().submatrix_range(index);
71 let data = self.data.get_mut(range)?;
72 Some(MatrixBorrowedMut { data, dims })
73 }
74}
75
76#[derive(Debug, Clone, Copy)]
78pub(super) struct MatrixBorrowed<'a, const D: usize> {
79 data: &'a [f32],
80 dims: [usize; D],
81}
82
83impl<'a, const D: usize> MatrixBorrowed<'a, D> {
84 #[cfg(debug_assertions)]
85 pub(super) fn debug_assert_dims(&self, dims: [usize; D]) {
86 debug_assert_eq!(dims, self.dims);
87 let expected_len = dims.iter().product::<usize>();
88 debug_assert_eq!(expected_len, self.data.len());
89 }
90
91 pub(super) fn as_slice(&self) -> &'a [f32] {
92 self.data
93 }
94
95 #[inline]
97 pub(super) fn submatrix<const M: usize>(&self, index: usize) -> Option<MatrixBorrowed<'a, M>> {
98 assert_eq!(M, D - 1);
100 let (range, dims) = self.submatrix_range(index);
101 let data = &self.data.get(range)?;
102 Some(MatrixBorrowed { data, dims })
103 }
104
105 #[inline]
106 fn submatrix_range<const M: usize>(&self, index: usize) -> (Range<usize>, [usize; M]) {
107 assert_eq!(M, D - 1);
109 #[allow(clippy::indexing_slicing, clippy::unwrap_used)]
111 let sub_dims: [usize; M] = self.dims[1..].try_into().unwrap();
112 let n = sub_dims.iter().product::<usize>();
113 (n * index..n * (index + 1), sub_dims)
114 }
115}
116
/// Adds a `dim()` accessor to the rank-1, rank-2 and rank-3 instantiations of a matrix
/// type. Each `dim()` returns the type's `dims` array unpacked into a scalar or tuple.
///
/// The three arguments are the rank-1, rank-2 and rank-3 types, in that order. Each type
/// must have a `dims: [usize; N]` field of the matching rank.
macro_rules! impl_basic_dim {
    ($t1:path, $t2:path, $t3:path) => {
        impl<'a> $t1 {
            /// Returns the length of the single axis.
            #[allow(dead_code)]
            pub(super) fn dim(&self) -> usize {
                let [len] = self.dims;
                len
            }
        }
        impl<'a> $t2 {
            /// Returns the extents of both axes as `(rows, cols)`.
            #[allow(dead_code)]
            pub(super) fn dim(&self) -> (usize, usize) {
                let [rows, cols] = self.dims;
                (rows, cols)
            }
        }
        impl<'a> $t3 {
            /// Returns the extents of all three axes, outermost first.
            #[allow(dead_code)]
            pub(super) fn dim(&self) -> (usize, usize, usize) {
                let [outer, middle, inner] = self.dims;
                (outer, middle, inner)
            }
        }
    };
}
142
143impl_basic_dim!(MatrixOwned<1>, MatrixOwned<2>, MatrixOwned<3>);
144impl_basic_dim!(
145 MatrixBorrowed<'a, 1>,
146 MatrixBorrowed<'a, 2>,
147 MatrixBorrowed<'a, 3>
148);
149impl_basic_dim!(
150 MatrixBorrowedMut<'a, 1>,
151 MatrixBorrowedMut<'a, 2>,
152 MatrixBorrowedMut<'a, 3>
153);
154impl_basic_dim!(MatrixZero<'a, 1>, MatrixZero<'a, 2>, MatrixZero<'a, 3>);
155
156pub(super) struct MatrixBorrowedMut<'a, const D: usize> {
158 pub(super) data: &'a mut [f32],
159 pub(super) dims: [usize; D],
160}
161
162impl<'a, const D: usize> MatrixBorrowedMut<'a, D> {
163 pub(super) fn as_borrowed(&self) -> MatrixBorrowed<D> {
164 MatrixBorrowed {
165 data: self.data,
166 dims: self.dims,
167 }
168 }
169
170 pub(super) fn as_mut_slice(&mut self) -> &mut [f32] {
171 self.data
172 }
173
174 pub(super) fn copy_submatrix<const M: usize>(&mut self, from: usize, to: usize) {
175 let (range_from, _) = self.as_borrowed().submatrix_range::<M>(from);
176 let (range_to, _) = self.as_borrowed().submatrix_range::<M>(to);
177 if let (Some(_), Some(_)) = (
178 self.data.get(range_from.clone()),
179 self.data.get(range_to.clone()),
180 ) {
181 self.data.copy_within(range_from, range_to.start);
183 }
184 }
185
186 #[must_use]
187 pub(super) fn add(&mut self, other: MatrixZero<'_, D>) -> Option<()> {
188 debug_assert_eq!(self.dims, other.dims);
189 for i in 0..self.data.len() {
191 *self.data.get_mut(i)? += other.data.get(i)?;
192 }
193 Some(())
194 }
195
196 #[allow(dead_code)] pub(super) fn softmax_transform(&mut self) {
199 for v in self.data.iter_mut() {
200 *v = v.exp();
201 }
202 let sm = 1.0 / self.data.iter().sum::<f32>();
203 for v in self.data.iter_mut() {
204 *v *= sm;
205 }
206 }
207
208 pub(super) fn sigmoid_transform(&mut self) {
209 for x in &mut self.data.iter_mut() {
210 *x = 1.0 / (1.0 + (-*x).exp());
211 }
212 }
213
214 pub(super) fn tanh_transform(&mut self) {
215 for x in &mut self.data.iter_mut() {
216 *x = x.tanh();
217 }
218 }
219
220 pub(super) fn convolve(
221 &mut self,
222 i: MatrixBorrowed<'_, D>,
223 c: MatrixBorrowed<'_, D>,
224 f: MatrixBorrowed<'_, D>,
225 ) {
226 let i = i.as_slice();
227 let c = c.as_slice();
228 let f = f.as_slice();
229 let len = self.data.len();
230 if len != i.len() || len != c.len() || len != f.len() {
231 debug_assert!(false, "LSTM matrices not the correct dimensions");
232 return;
233 }
234 for idx in 0..len {
235 unsafe {
237 *self.data.get_unchecked_mut(idx) = i.get_unchecked(idx) * c.get_unchecked(idx)
238 + self.data.get_unchecked(idx) * f.get_unchecked(idx)
239 }
240 }
241 }
242
243 pub(super) fn mul_tanh(&mut self, o: MatrixBorrowed<'_, D>, c: MatrixBorrowed<'_, D>) {
244 let o = o.as_slice();
245 let c = c.as_slice();
246 let len = self.data.len();
247 if len != o.len() || len != c.len() {
248 debug_assert!(false, "LSTM matrices not the correct dimensions");
249 return;
250 }
251 for idx in 0..len {
252 unsafe {
254 *self.data.get_unchecked_mut(idx) =
255 o.get_unchecked(idx) * c.get_unchecked(idx).tanh();
256 }
257 }
258 }
259}
260
261impl<'a> MatrixBorrowed<'a, 1> {
262 #[allow(dead_code)] pub(super) fn dot_1d(&self, other: MatrixZero<1>) -> f32 {
264 debug_assert_eq!(self.dims, other.dims);
265 unrolled_dot_1(self.data, other.data)
266 }
267}
268
269impl<'a> MatrixBorrowedMut<'a, 1> {
270 pub(super) fn add_dot_2d(&mut self, a: MatrixBorrowed<1>, b: MatrixZero<2>) {
275 let m = a.dim();
276 let n = self.as_borrowed().dim();
277 debug_assert_eq!(
278 m,
279 b.dim().1,
280 "dims: {:?}/{:?}/{:?}",
281 self.as_borrowed().dim(),
282 a.dim(),
283 b.dim()
284 );
285 debug_assert_eq!(
286 n,
287 b.dim().0,
288 "dims: {:?}/{:?}/{:?}",
289 self.as_borrowed().dim(),
290 a.dim(),
291 b.dim()
292 );
293 for i in 0..n {
294 if let (Some(dest), Some(b_sub)) = (self.as_mut_slice().get_mut(i), b.submatrix::<1>(i))
295 {
296 *dest += unrolled_dot_1(a.data, b_sub.data);
297 } else {
298 debug_assert!(false, "unreachable: dims checked above");
299 }
300 }
301 }
302}
303
304impl<'a> MatrixBorrowedMut<'a, 2> {
305 pub(super) fn add_dot_3d_1(&mut self, a: MatrixBorrowed<1>, b: MatrixZero<3>) {
309 let m = a.dim();
310 let n = self.as_borrowed().dim().0 * self.as_borrowed().dim().1;
311 debug_assert_eq!(
312 m,
313 b.dim().2,
314 "dims: {:?}/{:?}/{:?}",
315 self.as_borrowed().dim(),
316 a.dim(),
317 b.dim()
318 );
319 debug_assert_eq!(
320 n,
321 b.dim().0 * b.dim().1,
322 "dims: {:?}/{:?}/{:?}",
323 self.as_borrowed().dim(),
324 a.dim(),
325 b.dim()
326 );
327 let lhs = a.as_slice();
333 for i in 0..n {
334 if let (Some(dest), Some(rhs)) = (
335 self.as_mut_slice().get_mut(i),
336 b.as_slice().get_subslice(i * m..(i + 1) * m),
337 ) {
338 *dest += unrolled_dot_1(lhs, rhs);
339 } else {
340 debug_assert!(false, "unreachable: dims checked above");
341 }
342 }
343 }
344
345 pub(super) fn add_dot_3d_2(&mut self, a: MatrixZero<1>, b: MatrixZero<3>) {
349 let m = a.dim();
350 let n = self.as_borrowed().dim().0 * self.as_borrowed().dim().1;
351 debug_assert_eq!(
352 m,
353 b.dim().2,
354 "dims: {:?}/{:?}/{:?}",
355 self.as_borrowed().dim(),
356 a.dim(),
357 b.dim()
358 );
359 debug_assert_eq!(
360 n,
361 b.dim().0 * b.dim().1,
362 "dims: {:?}/{:?}/{:?}",
363 self.as_borrowed().dim(),
364 a.dim(),
365 b.dim()
366 );
367 let lhs = a.as_slice();
373 for i in 0..n {
374 if let (Some(dest), Some(rhs)) = (
375 self.as_mut_slice().get_mut(i),
376 b.as_slice().get_subslice(i * m..(i + 1) * m),
377 ) {
378 *dest += unrolled_dot_2(lhs, rhs);
379 } else {
380 debug_assert!(false, "unreachable: dims checked above");
381 }
382 }
383 }
384}
385
386#[derive(Debug, Clone, Copy)]
388pub(super) struct MatrixZero<'a, const D: usize> {
389 data: &'a ZeroSlice<f32>,
390 dims: [usize; D],
391}
392
393impl<'a> From<&'a crate::provider::LstmMatrix1<'a>> for MatrixZero<'a, 1> {
394 fn from(other: &'a crate::provider::LstmMatrix1<'a>) -> Self {
395 Self {
396 data: &other.data,
397 dims: other.dims.map(|x| x as usize),
398 }
399 }
400}
401
402impl<'a> From<&'a crate::provider::LstmMatrix2<'a>> for MatrixZero<'a, 2> {
403 fn from(other: &'a crate::provider::LstmMatrix2<'a>) -> Self {
404 Self {
405 data: &other.data,
406 dims: other.dims.map(|x| x as usize),
407 }
408 }
409}
410
411impl<'a> From<&'a crate::provider::LstmMatrix3<'a>> for MatrixZero<'a, 3> {
412 fn from(other: &'a crate::provider::LstmMatrix3<'a>) -> Self {
413 Self {
414 data: &other.data,
415 dims: other.dims.map(|x| x as usize),
416 }
417 }
418}
419
420impl<'a, const D: usize> MatrixZero<'a, D> {
421 #[allow(clippy::wrong_self_convention)] pub(super) fn to_owned(&self) -> MatrixOwned<D> {
423 MatrixOwned {
424 data: self.data.iter().collect(),
425 dims: self.dims,
426 }
427 }
428
429 pub(super) fn as_slice(&self) -> &ZeroSlice<f32> {
430 self.data
431 }
432
433 #[cfg(debug_assertions)]
434 pub(super) fn debug_assert_dims(&self, dims: [usize; D]) {
435 debug_assert_eq!(dims, self.dims);
436 let expected_len = dims.iter().product::<usize>();
437 debug_assert_eq!(expected_len, self.data.len());
438 }
439
440 #[inline]
442 pub(super) fn submatrix<const M: usize>(&self, index: usize) -> Option<MatrixZero<'a, M>> {
443 assert_eq!(M, D - 1);
445 let (range, dims) = self.submatrix_range(index);
446 let data = &self.data.get_subslice(range)?;
447 Some(MatrixZero { data, dims })
448 }
449
450 #[inline]
451 fn submatrix_range<const M: usize>(&self, index: usize) -> (Range<usize>, [usize; M]) {
452 assert_eq!(M, D - 1);
454 #[allow(clippy::indexing_slicing, clippy::unwrap_used)]
456 let sub_dims: [usize; M] = self.dims[1..].try_into().unwrap();
457 let n = sub_dims.iter().product::<usize>();
458 (n * index..n * (index + 1), sub_dims)
459 }
460}
461
/// Decodes one unaligned `f32` ULE value into a native `f32` via `AsULE::from_unaligned`.
macro_rules! f32c {
    ($ule:expr) => {
        <f32 as AsULE>::from_unaligned($ule)
    };
}
467
/// Computes the dot product of a native `f32` slice with a zero-copy `f32` slice.
///
/// Mathematically this equals
/// `xs.iter().zip(ys.iter()).map(|(x, y)| x * y).sum::<f32>()`. The loop is instead
/// unrolled into eight independent partial sums, so floating-point rounding can differ
/// from a strictly sequential sum. Do not reorder the accumulation without accepting
/// that change.
///
/// `xs` and `ys` are expected to have the same length. This is only checked in debug
/// builds. With mismatched lengths in release builds the result is not meaningful.
fn unrolled_dot_1(xs: &[f32], ys: &ZeroSlice<f32>) -> f32 {
    debug_assert_eq!(xs.len(), ys.len());
    // One accumulator per position within an 8-element chunk.
    let mut p = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0);
    // Iterate `ys` as raw ULE values; each is decoded with `f32c!` where it is used.
    let xit = xs.chunks_exact(8);
    let yit = ys.as_ule_slice().chunks_exact(8);
    // The trailing elements that do not fill a whole chunk are summed first, on their own.
    let sum = xit
        .remainder()
        .iter()
        .zip(yit.remainder().iter())
        .map(|(x, y)| x * f32c!(*y))
        .sum::<f32>();
    for (xx, yy) in xit.zip(yit) {
        // `chunks_exact(8)` only yields 8-element chunks, so these conversions cannot fail.
        #[allow(clippy::unwrap_used)]
        let [x0, x1, x2, x3, x4, x5, x6, x7] = *<&[f32; 8]>::try_from(xx).unwrap();
        #[allow(clippy::unwrap_used)]
        let [y0, y1, y2, y3, y4, y5, y6, y7] = *<&[<f32 as AsULE>::ULE; 8]>::try_from(yy).unwrap();
        p.0 += x0 * f32c!(y0);
        p.1 += x1 * f32c!(y1);
        p.2 += x2 * f32c!(y2);
        p.3 += x3 * f32c!(y3);
        p.4 += x4 * f32c!(y4);
        p.5 += x5 * f32c!(y5);
        p.6 += x6 * f32c!(y6);
        p.7 += x7 * f32c!(y7);
    }
    // Combine the lanes in pairs (0+4, 1+5, 2+6, 3+7) and add them onto the remainder sum.
    sum + (p.0 + p.4) + (p.1 + p.5) + (p.2 + p.6) + (p.3 + p.7)
}
504
/// Computes the dot product of two zero-copy `f32` slices.
///
/// Same algorithm as `unrolled_dot_1`, but both operands are decoded from ULE values.
/// The result mathematically equals
/// `xs.iter().zip(ys.iter()).map(|(x, y)| x * y).sum::<f32>()`. The eight independent
/// partial sums mean floating-point rounding can differ from a sequential sum.
///
/// `xs` and `ys` are expected to have the same length. This is only checked in debug
/// builds. With mismatched lengths in release builds the result is not meaningful.
fn unrolled_dot_2(xs: &ZeroSlice<f32>, ys: &ZeroSlice<f32>) -> f32 {
    debug_assert_eq!(xs.len(), ys.len());
    // One accumulator per position within an 8-element chunk.
    let mut p = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0);
    // Iterate both operands as raw ULE values; each is decoded with `f32c!` where it is used.
    let xit = xs.as_ule_slice().chunks_exact(8);
    let yit = ys.as_ule_slice().chunks_exact(8);
    // The trailing elements that do not fill a whole chunk are summed first, on their own.
    let sum = xit
        .remainder()
        .iter()
        .zip(yit.remainder().iter())
        .map(|(x, y)| f32c!(*x) * f32c!(*y))
        .sum::<f32>();
    for (xx, yy) in xit.zip(yit) {
        // `chunks_exact(8)` only yields 8-element chunks, so these conversions cannot fail.
        #[allow(clippy::unwrap_used)]
        let [x0, x1, x2, x3, x4, x5, x6, x7] = *<&[<f32 as AsULE>::ULE; 8]>::try_from(xx).unwrap();
        #[allow(clippy::unwrap_used)]
        let [y0, y1, y2, y3, y4, y5, y6, y7] = *<&[<f32 as AsULE>::ULE; 8]>::try_from(yy).unwrap();
        p.0 += f32c!(x0) * f32c!(y0);
        p.1 += f32c!(x1) * f32c!(y1);
        p.2 += f32c!(x2) * f32c!(y2);
        p.3 += f32c!(x3) * f32c!(y3);
        p.4 += f32c!(x4) * f32c!(y4);
        p.5 += f32c!(x5) * f32c!(y5);
        p.6 += f32c!(x6) * f32c!(y6);
        p.7 += f32c!(x7) * f32c!(y7);
    }
    // Combine the lanes in pairs (0+4, 1+5, 2+6, 3+7) and add them onto the remainder sum.
    sum + (p.0 + p.4) + (p.1 + p.5) + (p.2 + p.6) + (p.3 + p.7)
}