1#![allow(clippy::exhaustive_structs, clippy::exhaustive_enums)]
9
10use icu_provider::prelude::*;
11use zerovec::{ule::UnvalidatedStr, ZeroMap, ZeroVec};
12
/// Generates a matrix type called `$name` with `$generic` dimensions, together with
/// its constructors and its (feature-gated) `serde::Deserialize` and `databake::Bake`
/// implementations.
macro_rules! lstm_matrix {
    ($name:ident, $generic:literal) => {
        /// A matrix of `f32` values with a fixed number of dimensions.
        ///
        /// `dims` stores the extent of each dimension and `data` stores the elements as
        /// one flat sequence. `from_parts` and deserialization only produce values whose
        /// `data.len()` equals the product of `dims`; `from_parts_unchecked` performs no
        /// such check.
        #[derive(PartialEq, Debug, Clone, zerofrom::ZeroFrom, yoke::Yokeable)]
        #[cfg_attr(feature = "datagen", derive(serde::Serialize))]
        pub struct $name<'data> {
            #[allow(missing_docs)]
            pub(crate) dims: [u16; $generic],
            #[allow(missing_docs)]
            pub(crate) data: ZeroVec<'data, f32>,
        }

        impl<'data> $name<'data> {
            /// Creates a matrix from its dimensions and flattened data.
            ///
            /// # Errors
            ///
            /// Returns a custom [`DataError`] ("Dimension mismatch") when the product of
            /// `dims` is not equal to `data.len()`.
            #[cfg(any(feature = "serde", feature = "datagen"))]
            pub fn from_parts(
                dims: [u16; $generic],
                data: ZeroVec<'data, f32>,
            ) -> Result<Self, DataError> {
                // The product of several `u16` extents can exceed `usize::MAX` on targets
                // with a narrow `usize`. An unchecked product would then either panic or
                // wrap around and possibly accept a wrong length, so multiply with
                // overflow checking. A zero extent always yields an empty matrix, no
                // matter how large the remaining extents are.
                let expected_len = if dims.contains(&0) {
                    Some(0)
                } else {
                    dims.iter()
                        .try_fold(1usize, |acc, &dim| acc.checked_mul(usize::from(dim)))
                };

                if expected_len == Some(data.len()) {
                    Ok(Self { dims, data })
                } else {
                    Err(DataError::custom("Dimension mismatch"))
                }
            }

            /// Creates a matrix without checking that `dims` is consistent with `data`.
            ///
            /// This is the constructor emitted by the `databake::Bake` implementation.
            #[doc(hidden)]
            pub const fn from_parts_unchecked(
                dims: [u16; $generic],
                data: ZeroVec<'data, f32>,
            ) -> Self {
                Self { dims, data }
            }
        }

        #[cfg(feature = "serde")]
        impl<'de: 'data, 'data> serde::Deserialize<'de> for $name<'data> {
            /// Deserializes the raw fields and then validates them through `from_parts`,
            /// so that inconsistent dimensions are rejected.
            fn deserialize<S>(deserializer: S) -> Result<Self, S::Error>
            where
                S: serde::de::Deserializer<'de>,
            {
                // Unvalidated mirror of the matrix, used only as a deserialization target.
                #[derive(serde::Deserialize)]
                struct Raw<'data> {
                    dims: [u16; $generic],
                    #[serde(borrow)]
                    data: ZeroVec<'data, f32>,
                }

                let raw = Raw::deserialize(deserializer)?;

                use serde::de::Error;
                Self::from_parts(raw.dims, raw.data)
                    .map_err(|_| S::Error::custom("Dimension mismatch"))
            }
        }

        #[cfg(feature = "datagen")]
        impl databake::Bake for $name<'_> {
            /// Emits a call to `from_parts_unchecked` with the baked dimensions and data.
            fn bake(&self, env: &databake::CrateEnv) -> databake::TokenStream {
                let dims = self.dims.bake(env);
                let data = self.data.bake(env);
                databake::quote! {
                    icu_segmenter::provider::$name::from_parts_unchecked(#dims, #data)
                }
            }
        }
    };
}
90
// Instantiate the matrix types used by the model. The second argument is the number
// of dimensions, i.e. the length of each type's `dims` array.
lstm_matrix!(LstmMatrix1, 1);
lstm_matrix!(LstmMatrix2, 2);
lstm_matrix!(LstmMatrix3, 3);
94
/// The kind of LSTM model, stored in the `model` field of `LstmDataFloat32`.
///
/// Serializable and bakeable with the `datagen` feature, deserializable with the
/// `serde` feature.
#[derive(PartialEq, Debug, Clone, Copy)]
#[cfg_attr(
    feature = "datagen",
    derive(serde::Serialize,databake::Bake),
    databake(path = icu_segmenter::provider),
)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
pub enum ModelType {
    // NOTE(review): presumably a model whose input units are individual code points;
    // confirm against the segmenter implementation.
    Codepoints,
    // NOTE(review): presumably a model whose input units are grapheme clusters;
    // confirm against the segmenter implementation.
    GraphemeClusters,
}
115
/// The parts of an LSTM model, with all weights stored as `f32` matrices.
///
/// Values built through `try_from_parts` (which is also what deserialization uses)
/// satisfy the dimension relationships noted on the fields below, where `e` is
/// `embedding.dims[1]` and `h` is `fw_u.dims[2]`. Values built through
/// `from_parts_unchecked` are not validated.
#[derive(PartialEq, Debug, Clone, yoke::Yokeable, zerofrom::ZeroFrom)]
#[cfg_attr(feature = "datagen", derive(serde::Serialize))]
#[yoke(prove_covariance_manually)]
pub struct LstmDataFloat32<'data> {
    // The kind of model.
    pub(crate) model: ModelType,
    // Dictionary mapping strings to `u16` ids. Validation requires the number of
    // entries to fit in a `u16`; in builds with debug assertions it additionally
    // requires every id to be smaller than the number of entries.
    pub(crate) dic: ZeroMap<'data, UnvalidatedStr, u16>,
    // Validated dims: `[dic.len() + 1, e]`.
    pub(crate) embedding: LstmMatrix2<'data>,
    // Validated dims: `[4, h, e]`. NOTE(review): "fw" presumably means the forward
    // direction of a bidirectional LSTM; confirm.
    pub(crate) fw_w: LstmMatrix3<'data>,
    // Validated dims: `[4, h, h]`.
    pub(crate) fw_u: LstmMatrix3<'data>,
    // Validated dims: `[4, h]`.
    pub(crate) fw_b: LstmMatrix2<'data>,
    // Validated dims: `[4, h, e]`. NOTE(review): "bw" presumably means the backward
    // direction; confirm.
    pub(crate) bw_w: LstmMatrix3<'data>,
    // Validated dims: `[4, h, h]`.
    pub(crate) bw_u: LstmMatrix3<'data>,
    // Validated dims: `[4, h]`.
    pub(crate) bw_b: LstmMatrix2<'data>,
    // Validated dims: `[2, 4, h]`.
    pub(crate) time_w: LstmMatrix3<'data>,
    // Validated dims: `[4]`.
    pub(crate) time_b: LstmMatrix1<'data>,
}
150
151impl<'data> LstmDataFloat32<'data> {
152 #[doc(hidden)] #[allow(clippy::too_many_arguments)] pub const fn from_parts_unchecked(
155 model: ModelType,
156 dic: ZeroMap<'data, UnvalidatedStr, u16>,
157 embedding: LstmMatrix2<'data>,
158 fw_w: LstmMatrix3<'data>,
159 fw_u: LstmMatrix3<'data>,
160 fw_b: LstmMatrix2<'data>,
161 bw_w: LstmMatrix3<'data>,
162 bw_u: LstmMatrix3<'data>,
163 bw_b: LstmMatrix2<'data>,
164 time_w: LstmMatrix3<'data>,
165 time_b: LstmMatrix1<'data>,
166 ) -> Self {
167 Self {
168 model,
169 dic,
170 embedding,
171 fw_w,
172 fw_u,
173 fw_b,
174 bw_w,
175 bw_u,
176 bw_b,
177 time_w,
178 time_b,
179 }
180 }
181
182 #[cfg(any(feature = "serde", feature = "datagen"))]
183 #[allow(clippy::too_many_arguments)] pub fn try_from_parts(
186 model: ModelType,
187 dic: ZeroMap<'data, UnvalidatedStr, u16>,
188 embedding: LstmMatrix2<'data>,
189 fw_w: LstmMatrix3<'data>,
190 fw_u: LstmMatrix3<'data>,
191 fw_b: LstmMatrix2<'data>,
192 bw_w: LstmMatrix3<'data>,
193 bw_u: LstmMatrix3<'data>,
194 bw_b: LstmMatrix2<'data>,
195 time_w: LstmMatrix3<'data>,
196 time_b: LstmMatrix1<'data>,
197 ) -> Result<Self, DataError> {
198 let dic_len = u16::try_from(dic.len())
199 .map_err(|_| DataError::custom("Dictionary does not fit in u16"))?;
200
201 let num_classes = embedding.dims[0];
202 let embedd_dim = embedding.dims[1];
203 let hunits = fw_u.dims[2];
204 if num_classes - 1 != dic_len
205 || fw_w.dims != [4, hunits, embedd_dim]
206 || fw_u.dims != [4, hunits, hunits]
207 || fw_b.dims != [4, hunits]
208 || bw_w.dims != [4, hunits, embedd_dim]
209 || bw_u.dims != [4, hunits, hunits]
210 || bw_b.dims != [4, hunits]
211 || time_w.dims != [2, 4, hunits]
212 || time_b.dims != [4]
213 {
214 return Err(DataError::custom("LSTM dimension mismatch"));
215 }
216
217 #[cfg(debug_assertions)]
218 if !dic.iter_copied_values().all(|(_, g)| g < dic_len) {
219 return Err(DataError::custom("Invalid cluster id"));
220 }
221
222 Ok(Self {
223 model,
224 dic,
225 embedding,
226 fw_w,
227 fw_u,
228 fw_b,
229 bw_w,
230 bw_u,
231 bw_b,
232 time_w,
233 time_b,
234 })
235 }
236}
237
#[cfg(feature = "serde")]
impl<'de: 'data, 'data> serde::Deserialize<'de> for LstmDataFloat32<'data> {
    /// Deserializes the individual fields and then validates them as a whole with
    /// `try_from_parts`. Any validation failure is reported as the custom error
    /// "Invalid dimensions".
    fn deserialize<S>(deserializer: S) -> Result<Self, S::Error>
    where
        S: serde::de::Deserializer<'de>,
    {
        // Unvalidated mirror of `LstmDataFloat32`, used only as a deserialization
        // target. The enclosing impl is already gated on the `serde` feature, so the
        // field attributes can be written unconditionally.
        #[derive(serde::Deserialize)]
        struct Raw<'data> {
            model: ModelType,
            #[serde(borrow)]
            dic: ZeroMap<'data, UnvalidatedStr, u16>,
            #[serde(borrow)]
            embedding: LstmMatrix2<'data>,
            #[serde(borrow)]
            fw_w: LstmMatrix3<'data>,
            #[serde(borrow)]
            fw_u: LstmMatrix3<'data>,
            #[serde(borrow)]
            fw_b: LstmMatrix2<'data>,
            #[serde(borrow)]
            bw_w: LstmMatrix3<'data>,
            #[serde(borrow)]
            bw_u: LstmMatrix3<'data>,
            #[serde(borrow)]
            bw_b: LstmMatrix2<'data>,
            #[serde(borrow)]
            time_w: LstmMatrix3<'data>,
            #[serde(borrow)]
            time_b: LstmMatrix1<'data>,
        }

        let Raw {
            model,
            dic,
            embedding,
            fw_w,
            fw_u,
            fw_b,
            bw_w,
            bw_u,
            bw_b,
            time_w,
            time_b,
        } = <Raw<'data> as serde::Deserialize<'de>>::deserialize(deserializer)?;

        match Self::try_from_parts(
            model, dic, embedding, fw_w, fw_u, fw_b, bw_w, bw_u, bw_b, time_w, time_b,
        ) {
            Ok(validated) => Ok(validated),
            Err(_) => Err(<S::Error as serde::de::Error>::custom(
                "Invalid dimensions",
            )),
        }
    }
}
288
#[cfg(feature = "datagen")]
impl databake::Bake for LstmDataFloat32<'_> {
    /// Bakes every field and emits a call to
    /// `icu_segmenter::provider::LstmDataFloat32::from_parts_unchecked` with the
    /// baked values as arguments, in field order.
    fn bake(&self, env: &databake::CrateEnv) -> databake::TokenStream {
        let Self {
            model,
            dic,
            embedding,
            fw_w,
            fw_u,
            fw_b,
            bw_w,
            bw_u,
            bw_b,
            time_w,
            time_b,
        } = self;

        // Shadow each field reference with its baked token stream.
        let model = model.bake(env);
        let dic = dic.bake(env);
        let embedding = embedding.bake(env);
        let (fw_w, fw_u, fw_b) = (fw_w.bake(env), fw_u.bake(env), fw_b.bake(env));
        let (bw_w, bw_u, bw_b) = (bw_w.bake(env), bw_u.bake(env), bw_b.bake(env));
        let (time_w, time_b) = (time_w.bake(env), time_b.bake(env));

        databake::quote! {
            icu_segmenter::provider::LstmDataFloat32::from_parts_unchecked(
                #model,
                #dic,
                #embedding,
                #fw_w,
                #fw_u,
                #fw_b,
                #bw_w,
                #bw_u,
                #bw_b,
                #time_w,
                #time_b,
            )
        }
    }
}
320
#[icu_provider::data_struct(LstmForWordLineAutoV1Marker = "segmenter/lstm/wl_auto@1")]
/// The data struct registered for `LstmForWordLineAutoV1Marker` under the key
/// `segmenter/lstm/wl_auto@1`.
///
/// The enum is `#[non_exhaustive]`, so further variants can be added later without
/// breaking downstream `match` expressions that include a wildcard arm.
#[derive(Debug, PartialEq, Clone)]
#[cfg_attr(
    feature = "datagen",
    derive(serde::Serialize, databake::Bake),
    databake(path = icu_segmenter::provider),
)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
#[yoke(prove_covariance_manually)]
#[non_exhaustive]
pub enum LstmDataV1<'data> {
    /// A model whose weights are stored as `f32` values.
    Float32(#[cfg_attr(feature = "serde", serde(borrow))] LstmDataFloat32<'data>),
}