// File: icu_segmenter/provider/lstm.rs

// This file is part of ICU4X. For terms of use, please see the file
// called LICENSE at the top level of the ICU4X source tree
// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).

//! Data provider struct definitions for the LSTM segmentation model.
6
// Provider structs must be stable, so exhaustive structs and enums are allowed in this module.
8#![allow(clippy::exhaustive_structs, clippy::exhaustive_enums)]
9
10use icu_provider::prelude::*;
11use zerovec::{ule::UnvalidatedStr, ZeroMap, ZeroVec};
12
/// Returns `true` if the product of `dims` is exactly `len`.
///
/// The product uses checked multiplication. On 32-bit targets three `u16` dimensions can
/// multiply past `usize::MAX`. An unchecked product would then panic in debug builds, or wrap in
/// release builds and possibly match an unrelated `len`. An overflowing product never compares
/// equal to any `len`.
// Only called from the feature-gated `from_parts` constructors generated by `lstm_matrix!`.
#[cfg_attr(not(any(feature = "serde", feature = "datagen")), allow(dead_code))]
fn dims_product_equals(dims: &[u16], len: usize) -> bool {
    dims.iter()
        .try_fold(1usize, |acc, &d| acc.checked_mul(usize::from(d)))
        == Some(len)
}

// We do this instead of const generics because ZeroFrom and Yokeable derives, as well as serde
// don't support them
macro_rules! lstm_matrix {
    ($name:ident, $generic:literal) => {
        /// The struct that stores a LSTM's matrix.
        ///
        /// <div class="stab unstable">
        /// 🚧 This code is considered unstable; it may change at any time, in breaking or non-breaking ways,
        /// including in SemVer minor releases. While the serde representation of data structs is guaranteed
        /// to be stable, their Rust representation might not be. Use with caution.
        /// </div>
        #[derive(PartialEq, Debug, Clone, zerofrom::ZeroFrom, yoke::Yokeable)]
        #[cfg_attr(feature = "datagen", derive(serde::Serialize))]
        pub struct $name<'data> {
            // Invariant: dims.product() == data.len()
            /// The extent of each dimension of the matrix.
            pub(crate) dims: [u16; $generic],
            /// The flattened matrix entries; the length equals the product of `dims`.
            pub(crate) data: ZeroVec<'data, f32>,
        }

        impl<'data> $name<'data> {
            #[cfg(any(feature = "serde", feature = "datagen"))]
            /// Creates a matrix with the given dimensions and data.
            ///
            /// # Errors
            ///
            /// Returns an error if the product of `dims` does not equal `data.len()`,
            /// including when that product does not fit in `usize`.
            pub fn from_parts(
                dims: [u16; $generic],
                data: ZeroVec<'data, f32>,
            ) -> Result<Self, DataError> {
                if dims_product_equals(&dims, data.len()) {
                    Ok(Self { dims, data })
                } else {
                    Err(DataError::custom("Dimension mismatch"))
                }
            }

            /// Creates a matrix without checking that `dims` matches `data.len()`.
            #[doc(hidden)] // databake
            pub const fn from_parts_unchecked(
                dims: [u16; $generic],
                data: ZeroVec<'data, f32>,
            ) -> Self {
                Self { dims, data }
            }
        }

        #[cfg(feature = "serde")]
        impl<'de: 'data, 'data> serde::Deserialize<'de> for $name<'data> {
            fn deserialize<S>(deserializer: S) -> Result<Self, S::Error>
            where
                S: serde::de::Deserializer<'de>,
            {
                // Mirror of the public struct that serde can derive; the dimension
                // invariant is enforced afterwards by `from_parts`.
                #[derive(serde::Deserialize)]
                struct Raw<'data> {
                    dims: [u16; $generic],
                    #[serde(borrow)]
                    data: ZeroVec<'data, f32>,
                }

                let raw = Raw::deserialize(deserializer)?;

                use serde::de::Error;
                Self::from_parts(raw.dims, raw.data)
                    .map_err(|_| S::Error::custom("Dimension mismatch"))
            }
        }

        #[cfg(feature = "datagen")]
        impl databake::Bake for $name<'_> {
            fn bake(&self, env: &databake::CrateEnv) -> databake::TokenStream {
                let dims = self.dims.bake(env);
                let data = self.data.bake(env);
                databake::quote! {
                    icu_segmenter::provider::$name::from_parts_unchecked(#dims, #data)
                }
            }
        }
    };
}
90
// Concrete matrix types of rank 1, 2 and 3 (the second macro argument is the length of
// `dims`). These are the matrix field types used by `LstmDataFloat32`.
lstm_matrix!(LstmMatrix1, 1);
lstm_matrix!(LstmMatrix2, 2);
lstm_matrix!(LstmMatrix3, 3);
94
/// Which kind of input unit an LSTM model operates on.
///
/// <div class="stab unstable">
/// 🚧 This code is considered unstable; it may change at any time, in breaking or non-breaking ways,
/// including in SemVer minor releases. While the serde representation of data structs is guaranteed
/// to be stable, their Rust representation might not be. Use with caution.
/// </div>
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
#[cfg_attr(
    feature = "datagen",
    derive(serde::Serialize, databake::Bake),
    databake(path = icu_segmenter::provider),
)]
pub enum ModelType {
    /// The model consumes one code point at a time.
    Codepoints,
    /// The model consumes one grapheme cluster at a time.
    GraphemeClusters,
}
115
116/// The struct that stores a LSTM model.
117///
118/// <div class="stab unstable">
119/// 🚧 This code is considered unstable; it may change at any time, in breaking or non-breaking ways,
120/// including in SemVer minor releases. While the serde representation of data structs is guaranteed
121/// to be stable, their Rust representation might not be. Use with caution.
122/// </div>
123#[derive(PartialEq, Debug, Clone, yoke::Yokeable, zerofrom::ZeroFrom)]
124#[cfg_attr(feature = "datagen", derive(serde::Serialize))]
125#[yoke(prove_covariance_manually)]
126pub struct LstmDataFloat32<'data> {
127    /// Type of the model
128    pub(crate) model: ModelType,
129    /// The grapheme cluster dictionary used to train the model
130    pub(crate) dic: ZeroMap<'data, UnvalidatedStr, u16>,
131    /// The embedding layer. Shape (dic.len + 1, e)
132    pub(crate) embedding: LstmMatrix2<'data>,
133    /// The forward layer's first matrix. Shape (h, 4, e)
134    pub(crate) fw_w: LstmMatrix3<'data>,
135    /// The forward layer's second matrix. Shape (h, 4, h)
136    pub(crate) fw_u: LstmMatrix3<'data>,
137    /// The forward layer's bias. Shape (h, 4)
138    pub(crate) fw_b: LstmMatrix2<'data>,
139    /// The backward layer's first matrix. Shape (h, 4, e)
140    pub(crate) bw_w: LstmMatrix3<'data>,
141    /// The backward layer's second matrix. Shape (h, 4, h)
142    pub(crate) bw_u: LstmMatrix3<'data>,
143    /// The backward layer's bias. Shape (h, 4)
144    pub(crate) bw_b: LstmMatrix2<'data>,
145    /// The output layer's weights. Shape (2, 4, h)
146    pub(crate) time_w: LstmMatrix3<'data>,
147    /// The output layer's bias. Shape (4)
148    pub(crate) time_b: LstmMatrix1<'data>,
149}
150
151impl<'data> LstmDataFloat32<'data> {
152    #[doc(hidden)] // databake
153    #[allow(clippy::too_many_arguments)] // constructor
154    pub const fn from_parts_unchecked(
155        model: ModelType,
156        dic: ZeroMap<'data, UnvalidatedStr, u16>,
157        embedding: LstmMatrix2<'data>,
158        fw_w: LstmMatrix3<'data>,
159        fw_u: LstmMatrix3<'data>,
160        fw_b: LstmMatrix2<'data>,
161        bw_w: LstmMatrix3<'data>,
162        bw_u: LstmMatrix3<'data>,
163        bw_b: LstmMatrix2<'data>,
164        time_w: LstmMatrix3<'data>,
165        time_b: LstmMatrix1<'data>,
166    ) -> Self {
167        Self {
168            model,
169            dic,
170            embedding,
171            fw_w,
172            fw_u,
173            fw_b,
174            bw_w,
175            bw_u,
176            bw_b,
177            time_w,
178            time_b,
179        }
180    }
181
182    #[cfg(any(feature = "serde", feature = "datagen"))]
183    /// Creates a LstmDataFloat32 with the given data. Fails if the matrix dimensions are inconsistent.
184    #[allow(clippy::too_many_arguments)] // constructor
185    pub fn try_from_parts(
186        model: ModelType,
187        dic: ZeroMap<'data, UnvalidatedStr, u16>,
188        embedding: LstmMatrix2<'data>,
189        fw_w: LstmMatrix3<'data>,
190        fw_u: LstmMatrix3<'data>,
191        fw_b: LstmMatrix2<'data>,
192        bw_w: LstmMatrix3<'data>,
193        bw_u: LstmMatrix3<'data>,
194        bw_b: LstmMatrix2<'data>,
195        time_w: LstmMatrix3<'data>,
196        time_b: LstmMatrix1<'data>,
197    ) -> Result<Self, DataError> {
198        let dic_len = u16::try_from(dic.len())
199            .map_err(|_| DataError::custom("Dictionary does not fit in u16"))?;
200
201        let num_classes = embedding.dims[0];
202        let embedd_dim = embedding.dims[1];
203        let hunits = fw_u.dims[2];
204        if num_classes - 1 != dic_len
205            || fw_w.dims != [4, hunits, embedd_dim]
206            || fw_u.dims != [4, hunits, hunits]
207            || fw_b.dims != [4, hunits]
208            || bw_w.dims != [4, hunits, embedd_dim]
209            || bw_u.dims != [4, hunits, hunits]
210            || bw_b.dims != [4, hunits]
211            || time_w.dims != [2, 4, hunits]
212            || time_b.dims != [4]
213        {
214            return Err(DataError::custom("LSTM dimension mismatch"));
215        }
216
217        #[cfg(debug_assertions)]
218        if !dic.iter_copied_values().all(|(_, g)| g < dic_len) {
219            return Err(DataError::custom("Invalid cluster id"));
220        }
221
222        Ok(Self {
223            model,
224            dic,
225            embedding,
226            fw_w,
227            fw_u,
228            fw_b,
229            bw_w,
230            bw_u,
231            bw_b,
232            time_w,
233            time_b,
234        })
235    }
236}
237
#[cfg(feature = "serde")]
impl<'de: 'data, 'data> serde::Deserialize<'de> for LstmDataFloat32<'data> {
    // Deserializes the raw fields, then validates them with `try_from_parts`. The validation
    // error is forwarded so callers can tell a shape mismatch from, e.g., an invalid cluster id.
    fn deserialize<S>(deserializer: S) -> Result<Self, S::Error>
    where
        S: serde::de::Deserializer<'de>,
    {
        use serde::de::Error;

        // Field-for-field mirror of `LstmDataFloat32` that serde can derive. The field ORDER
        // must match the serialized layout (non-self-describing formats read fields
        // positionally). This impl only exists with the `serde` feature, so the `borrow`
        // attributes need no extra `cfg_attr` gating.
        #[derive(serde::Deserialize)]
        struct Raw<'data> {
            model: ModelType,
            #[serde(borrow)]
            dic: ZeroMap<'data, UnvalidatedStr, u16>,
            #[serde(borrow)]
            embedding: LstmMatrix2<'data>,
            #[serde(borrow)]
            fw_w: LstmMatrix3<'data>,
            #[serde(borrow)]
            fw_u: LstmMatrix3<'data>,
            #[serde(borrow)]
            fw_b: LstmMatrix2<'data>,
            #[serde(borrow)]
            bw_w: LstmMatrix3<'data>,
            #[serde(borrow)]
            bw_u: LstmMatrix3<'data>,
            #[serde(borrow)]
            bw_b: LstmMatrix2<'data>,
            #[serde(borrow)]
            time_w: LstmMatrix3<'data>,
            #[serde(borrow)]
            time_b: LstmMatrix1<'data>,
        }

        let Raw {
            model,
            dic,
            embedding,
            fw_w,
            fw_u,
            fw_b,
            bw_w,
            bw_u,
            bw_b,
            time_w,
            time_b,
        } = Raw::deserialize(deserializer)?;

        Self::try_from_parts(
            model, dic, embedding, fw_w, fw_u, fw_b, bw_w, bw_u, bw_b, time_w, time_b,
        )
        .map_err(S::Error::custom)
    }
}
288
#[cfg(feature = "datagen")]
impl databake::Bake for LstmDataFloat32<'_> {
    /// Bakes every field, then emits a call to the `const` constructor
    /// `from_parts_unchecked` with the baked fields in declaration order.
    fn bake(&self, env: &databake::CrateEnv) -> databake::TokenStream {
        let Self {
            model,
            dic,
            embedding,
            fw_w,
            fw_u,
            fw_b,
            bw_w,
            bw_u,
            bw_b,
            time_w,
            time_b,
        } = self;

        // Shadow each borrowed field with its baked token stream (same order as the fields).
        let model = model.bake(env);
        let dic = dic.bake(env);
        let embedding = embedding.bake(env);
        let (fw_w, fw_u, fw_b) = (fw_w.bake(env), fw_u.bake(env), fw_b.bake(env));
        let (bw_w, bw_u, bw_b) = (bw_w.bake(env), bw_u.bake(env), bw_b.bake(env));
        let (time_w, time_b) = (time_w.bake(env), time_b.bake(env));

        databake::quote! {
            icu_segmenter::provider::LstmDataFloat32::from_parts_unchecked(
                #model,
                #dic,
                #embedding,
                #fw_w,
                #fw_u,
                #fw_b,
                #bw_w,
                #bw_u,
                #bw_b,
                #time_w,
                #time_b,
            )
        }
    }
}
320
321/// The data to power the LSTM segmentation model.
322///
323/// This data enum is extensible: more backends may be added in the future.
324/// Old data can be used with newer code but not vice versa.
325///
326/// Examples of possible future extensions:
327///
328/// 1. Variant to store data in 16 instead of 32 bits
329/// 2. Minor changes to the LSTM model, such as different forward/backward matrix sizes
330///
331/// <div class="stab unstable">
332/// 🚧 This code is considered unstable; it may change at any time, in breaking or non-breaking ways,
333/// including in SemVer minor releases. While the serde representation of data structs is guaranteed
334/// to be stable, their Rust representation might not be. Use with caution.
335/// </div>
336#[icu_provider::data_struct(LstmForWordLineAutoV1Marker = "segmenter/lstm/wl_auto@1")]
337#[derive(Debug, PartialEq, Clone)]
338#[cfg_attr(
339    feature = "datagen", 
340    derive(serde::Serialize, databake::Bake),
341    databake(path = icu_segmenter::provider),
342)]
343#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
344#[yoke(prove_covariance_manually)]
345#[non_exhaustive]
346pub enum LstmDataV1<'data> {
347    /// The data as matrices of zerovec f32 values.
348    Float32(#[cfg_attr(feature = "serde", serde(borrow))] LstmDataFloat32<'data>),
349    // new variants should go BELOW existing ones
350    // Serde serializes based on variant name and index in the enum
351    // https://docs.rs/serde/latest/serde/trait.Serializer.html#tymethod.serialize_unit_variant
352}