storage/indexeddb/engines/
sqlite.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6
7use base::threadpool::ThreadPool;
8use log::{error, info};
9use profile_traits::generic_callback::GenericCallback;
10use rusqlite::{Connection, Error, OptionalExtension, params};
11use sea_query::{Condition, Expr, ExprTrait, IntoCondition, SqliteQueryBuilder};
12use sea_query_rusqlite::RusqliteBinder;
13use serde::{Deserialize, Serialize};
14use storage_traits::indexeddb::{
15    AsyncOperation, AsyncReadOnlyOperation, AsyncReadWriteOperation, BackendError, BackendResult,
16    CreateObjectResult, IndexedDBKeyRange, IndexedDBKeyType, IndexedDBRecord, IndexedDBTxnMode,
17    KeyPath, PutItemResult,
18};
19use tokio::sync::oneshot;
20
21use crate::indexeddb::IndexedDBDescription;
22use crate::indexeddb::engines::{KvsEngine, KvsTransaction};
23use crate::shared::{DB_INIT_PRAGMAS, DB_PRAGMAS};
24
25mod create;
26mod database_model;
27mod object_data_model;
28mod object_store_index_model;
29mod object_store_model;
30
31fn range_to_query(range: IndexedDBKeyRange) -> Condition {
32    // Special case for optimization
33    if let Some(singleton) = range.as_singleton() {
34        let encoded = bincode::serialize(singleton).unwrap();
35        return Expr::column(object_data_model::Column::Key)
36            .eq(encoded)
37            .into_condition();
38    }
39    let mut parts = vec![];
40    if let Some(upper) = range.upper.as_ref() {
41        let upper_bytes = bincode::serialize(upper).unwrap();
42        let query = if range.upper_open {
43            Expr::column(object_data_model::Column::Key).lt(upper_bytes)
44        } else {
45            Expr::column(object_data_model::Column::Key).lte(upper_bytes)
46        };
47        parts.push(query);
48    }
49    if let Some(lower) = range.lower.as_ref() {
50        let lower_bytes = bincode::serialize(lower).unwrap();
51        let query = if range.lower_open {
52            Expr::column(object_data_model::Column::Key).gt(lower_bytes)
53        } else {
54            Expr::column(object_data_model::Column::Key).gte(lower_bytes)
55        };
56        parts.push(query);
57    }
58    let mut condition = Condition::all();
59    for part in parts {
60        condition = condition.add(part);
61    }
62    condition
63}
64
/// IndexedDB storage engine that keeps one IndexedDB database in a single
/// SQLite file.
pub struct SqliteEngine {
    /// Path of the SQLite file (`<base_dir>/<db>/db.sqlite`). Its parent
    /// directory is what `delete_database` removes.
    db_path: PathBuf,
    /// Connection used by the synchronous engine methods (schema and metadata).
    /// Transactions open their own connection on a pool thread.
    connection: Connection,
    /// Pool on which read-only transactions run.
    read_pool: Arc<ThreadPool>,
    /// Pool on which read-write transactions run. `new` currently passes the
    /// same pool as `read_pool`.
    write_pool: Arc<ThreadPool>,
    /// Whether `new` had to create the database file.
    created_db_path: bool,
}
72
73impl SqliteEngine {
74    // TODO: intake dual pools
75    pub fn new(
76        base_dir: &Path,
77        db_info: &IndexedDBDescription,
78        pool: Arc<ThreadPool>,
79    ) -> Result<Self, Error> {
80        let mut db_path = PathBuf::new();
81        db_path.push(base_dir);
82        db_path.push(db_info.as_path());
83        let db_parent = db_path.clone();
84        db_path.push("db.sqlite");
85
86        let created_db_path = if !db_path.exists() {
87            std::fs::create_dir_all(db_parent).unwrap();
88            std::fs::File::create(&db_path).unwrap();
89            true
90        } else {
91            false
92        };
93
94        let connection = Self::init_db(&db_path, db_info)?;
95
96        for stmt in DB_PRAGMAS {
97            // TODO: Handle errors properly
98            let _ = connection.execute(stmt, ());
99        }
100
101        Ok(Self {
102            connection,
103            db_path,
104            read_pool: pool.clone(),
105            write_pool: pool,
106            created_db_path,
107        })
108    }
109
110    /// Returns whether the physical db was created as part of `new`.
111    pub(crate) fn created_db_path(&self) -> bool {
112        self.created_db_path
113    }
114
115    fn init_db(path: &Path, db_info: &IndexedDBDescription) -> Result<Connection, Error> {
116        let connection = Connection::open(path)?;
117        if connection.table_exists(None, "database")? {
118            // Database already exists, no need to initialize
119            return Ok(connection);
120        }
121        info!("Initializing indexeddb database at {:?}", path);
122        for stmt in DB_INIT_PRAGMAS {
123            // FIXME(arihant2math): this fails occasionally
124            let _ = connection.execute(stmt, ());
125        }
126        create::create_tables(&connection)?;
127        // From https://w3c.github.io/IndexedDB/#database-version:
128        // "When a database is first created, its version is 0 (zero)."
129        connection.execute(
130            "INSERT INTO database (name, origin, version) VALUES (?, ?, ?)",
131            params![
132                db_info.name.to_owned(),
133                db_info.origin.to_owned().ascii_serialization(),
134                i64::from_ne_bytes(0_u64.to_ne_bytes())
135            ],
136        )?;
137        Ok(connection)
138    }
139
140    fn get(
141        connection: &Connection,
142        store: object_store_model::Model,
143        key_range: IndexedDBKeyRange,
144    ) -> Result<Option<object_data_model::Model>, Error> {
145        let query = range_to_query(key_range);
146        let (sql, values) = sea_query::Query::select()
147            .from(object_data_model::Column::Table)
148            .columns(vec![
149                object_data_model::Column::ObjectStoreId,
150                object_data_model::Column::Key,
151                object_data_model::Column::Data,
152            ])
153            .and_where(query.and(Expr::col(object_data_model::Column::ObjectStoreId).is(store.id)))
154            .limit(1)
155            .build_rusqlite(SqliteQueryBuilder);
156        connection
157            .prepare(&sql)?
158            .query_one(&*values.as_params(), |row| {
159                object_data_model::Model::try_from(row)
160            })
161            .optional()
162    }
163
164    fn get_key(
165        connection: &Connection,
166        store: object_store_model::Model,
167        key_range: IndexedDBKeyRange,
168    ) -> Result<Option<Vec<u8>>, Error> {
169        Self::get(connection, store, key_range).map(|opt| opt.map(|model| model.key))
170    }
171
172    fn get_item(
173        connection: &Connection,
174        store: object_store_model::Model,
175        key_range: IndexedDBKeyRange,
176    ) -> Result<Option<Vec<u8>>, Error> {
177        Self::get(connection, store, key_range).map(|opt| opt.map(|model| model.data))
178    }
179
180    fn get_all(
181        connection: &Connection,
182        store: object_store_model::Model,
183        key_range: IndexedDBKeyRange,
184        count: Option<u32>,
185    ) -> Result<Vec<object_data_model::Model>, Error> {
186        let query = range_to_query(key_range);
187        let mut sql_query = sea_query::Query::select();
188        sql_query
189            .from(object_data_model::Column::Table)
190            .columns(vec![
191                object_data_model::Column::ObjectStoreId,
192                object_data_model::Column::Key,
193                object_data_model::Column::Data,
194            ])
195            .and_where(query.and(Expr::col(object_data_model::Column::ObjectStoreId).is(store.id)));
196        if let Some(count) = count {
197            sql_query.limit(count as u64);
198        }
199        let (sql, values) = sql_query.build_rusqlite(SqliteQueryBuilder);
200        let mut stmt = connection.prepare(&sql)?;
201        let models = stmt
202            .query_and_then(&*values.as_params(), |row| {
203                object_data_model::Model::try_from(row)
204            })?
205            .collect::<Result<Vec<_>, _>>()?;
206        Ok(models)
207    }
208
209    fn get_all_keys(
210        connection: &Connection,
211        store: object_store_model::Model,
212        key_range: IndexedDBKeyRange,
213        count: Option<u32>,
214    ) -> Result<Vec<Vec<u8>>, Error> {
215        Self::get_all(connection, store, key_range, count)
216            .map(|models| models.into_iter().map(|m| m.key).collect())
217    }
218
219    fn get_all_items(
220        connection: &Connection,
221        store: object_store_model::Model,
222        key_range: IndexedDBKeyRange,
223        count: Option<u32>,
224    ) -> Result<Vec<Vec<u8>>, Error> {
225        Self::get_all(connection, store, key_range, count)
226            .map(|models| models.into_iter().map(|m| m.data).collect())
227    }
228
229    #[expect(clippy::type_complexity)]
230    fn get_all_records(
231        connection: &Connection,
232        store: object_store_model::Model,
233        key_range: IndexedDBKeyRange,
234    ) -> Result<Vec<(Vec<u8>, Vec<u8>)>, Error> {
235        Self::get_all(connection, store, key_range, None)
236            .map(|models| models.into_iter().map(|m| (m.key, m.data)).collect())
237    }
238
239    fn put_item(
240        connection: &Connection,
241        store: object_store_model::Model,
242        serialized_key: Vec<u8>,
243        value: Vec<u8>,
244        should_overwrite: bool,
245    ) -> Result<PutItemResult, Error> {
246        let existing_item = connection
247            .prepare("SELECT * FROM object_data WHERE key = ? AND object_store_id = ?")
248            .and_then(|mut stmt| {
249                stmt.query_row(params![serialized_key, store.id], |row| {
250                    object_data_model::Model::try_from(row)
251                })
252                .optional()
253            })?;
254        if should_overwrite || existing_item.is_none() {
255            connection.execute(
256                "INSERT INTO object_data (object_store_id, key, data) VALUES (?, ?, ?)",
257                params![store.id, serialized_key, value],
258            )?;
259            Ok(PutItemResult::Success)
260        } else {
261            Ok(PutItemResult::CannotOverwrite)
262        }
263    }
264
265    fn delete_item(
266        connection: &Connection,
267        store: object_store_model::Model,
268        key_range: IndexedDBKeyRange,
269    ) -> Result<(), Error> {
270        let query = range_to_query(key_range);
271        let (sql, values) = sea_query::Query::delete()
272            .from_table(object_data_model::Column::Table)
273            .and_where(query.and(Expr::col(object_data_model::Column::ObjectStoreId).is(store.id)))
274            .build_rusqlite(SqliteQueryBuilder);
275        connection.prepare(&sql)?.execute(&*values.as_params())?;
276        Ok(())
277    }
278
279    fn clear(connection: &Connection, store: object_store_model::Model) -> Result<(), Error> {
280        connection.execute(
281            "DELETE FROM object_data WHERE object_store_id = ?",
282            params![store.id],
283        )?;
284        Ok(())
285    }
286
287    fn count(
288        connection: &Connection,
289        store: object_store_model::Model,
290        key_range: IndexedDBKeyRange,
291    ) -> Result<usize, Error> {
292        let query = range_to_query(key_range);
293        let (sql, values) = sea_query::Query::select()
294            .expr(Expr::col(object_data_model::Column::Key).count())
295            .from(object_data_model::Column::Table)
296            .and_where(query.and(Expr::col(object_data_model::Column::ObjectStoreId).is(store.id)))
297            .build_rusqlite(SqliteQueryBuilder);
298        connection
299            .prepare(&sql)?
300            .query_row(&*values.as_params(), |row| row.get(0))
301            .map(|count: i64| count as usize)
302    }
303
304    fn generate_key(
305        connection: &Connection,
306        store: &object_store_model::Model,
307    ) -> Result<IndexedDBKeyType, Error> {
308        if store.auto_increment == 0 {
309            unreachable!("Should be caught in the script thread");
310        }
311        // TODO: handle overflows, this also needs to be able to handle 2^53 as per spec
312        let new_key = store.auto_increment + 1;
313        connection.execute(
314            "UPDATE object_store SET auto_increment = ? WHERE id = ?",
315            params![new_key, store.id],
316        )?;
317        Ok(IndexedDBKeyType::Number(new_key as f64))
318    }
319}
320
321impl KvsEngine for SqliteEngine {
322    type Error = Error;
323
324    fn create_store(
325        &self,
326        store_name: &str,
327        key_path: Option<KeyPath>,
328        auto_increment: bool,
329    ) -> Result<CreateObjectResult, Self::Error> {
330        let mut stmt = self
331            .connection
332            .prepare("SELECT * FROM object_store WHERE name = ?")?;
333        if stmt.exists(params![store_name.to_string()])? {
334            // Store already exists
335            return Ok(CreateObjectResult::AlreadyExists);
336        }
337        self.connection.execute(
338            "INSERT INTO object_store (name, key_path, auto_increment) VALUES (?, ?, ?)",
339            params![
340                store_name.to_string(),
341                key_path.map(|v| bincode::serialize(&v).unwrap()),
342                auto_increment as i32
343            ],
344        )?;
345
346        Ok(CreateObjectResult::Created)
347    }
348
349    fn delete_store(&self, store_name: &str) -> Result<(), Self::Error> {
350        let result = self.connection.execute(
351            "DELETE FROM object_store WHERE name = ?",
352            params![store_name.to_string()],
353        )?;
354        if result == 0 {
355            Err(Error::QueryReturnedNoRows)
356        } else if result > 1 {
357            Err(Error::QueryReturnedMoreThanOneRow)
358        } else {
359            Ok(())
360        }
361    }
362
363    fn close_store(&self, _store_name: &str) -> Result<(), Self::Error> {
364        // TODO: do something
365        Ok(())
366    }
367
368    fn delete_database(self) -> Result<(), Self::Error> {
369        // attempt to close the connection first
370        let _ = self.connection.close();
371        if self.db_path.exists() {
372            if let Err(e) = std::fs::remove_dir_all(self.db_path.parent().unwrap()) {
373                error!("Failed to delete database: {:?}", e);
374            }
375        }
376        Ok(())
377    }
378
379    fn process_transaction(
380        &self,
381        transaction: KvsTransaction,
382    ) -> oneshot::Receiver<Option<Vec<u8>>> {
383        let (tx, rx) = oneshot::channel();
384
385        let spawning_pool = if transaction.mode == IndexedDBTxnMode::Readonly {
386            self.read_pool.clone()
387        } else {
388            self.write_pool.clone()
389        };
390        let path = self.db_path.clone();
391        spawning_pool.spawn(move || {
392            let connection = Connection::open(path).unwrap();
393            for request in transaction.requests {
394                let object_store = connection
395                    .prepare("SELECT * FROM object_store WHERE name = ?")
396                    .and_then(|mut stmt| {
397                        stmt.query_row(params![request.store_name.to_string()], |row| {
398                            object_store_model::Model::try_from(row)
399                        })
400                        .optional()
401                    });
402                fn process_object_store<T: Send + Serialize + for<'de> Deserialize<'de>>(
403                    object_store: Result<Option<object_store_model::Model>, Error>,
404                    callback: &GenericCallback<BackendResult<T>>,
405                ) -> Result<object_store_model::Model, ()> {
406                    match object_store {
407                        Ok(Some(store)) => Ok(store),
408                        Ok(None) => {
409                            let _ = callback.send(Err(BackendError::StoreNotFound));
410                            Err(())
411                        },
412                        Err(e) => {
413                            let _ = callback.send(Err(BackendError::DbErr(format!("{:?}", e))));
414                            Err(())
415                        },
416                    }
417                }
418
419                match request.operation {
420                    AsyncOperation::ReadWrite(AsyncReadWriteOperation::PutItem {
421                        callback,
422                        key,
423                        value,
424                        should_overwrite,
425                    }) => {
426                        let Ok(object_store) = process_object_store(object_store, &callback) else {
427                            continue;
428                        };
429                        let key = match key
430                            .map(Ok)
431                            .unwrap_or_else(|| Self::generate_key(&connection, &object_store))
432                        {
433                            Ok(key) => key,
434                            Err(e) => {
435                                let _ = callback.send(Err(BackendError::DbErr(format!("{:?}", e))));
436                                continue;
437                            },
438                        };
439                        let serialized_key: Vec<u8> = bincode::serialize(&key).unwrap();
440                        let _ = callback.send(
441                            Self::put_item(
442                                &connection,
443                                object_store,
444                                serialized_key,
445                                value,
446                                should_overwrite,
447                            )
448                            .map_err(|e| BackendError::DbErr(format!("{:?}", e))),
449                        );
450                    },
451                    AsyncOperation::ReadOnly(AsyncReadOnlyOperation::GetItem {
452                        callback,
453                        key_range,
454                    }) => {
455                        let Ok(object_store) = process_object_store(object_store, &callback) else {
456                            continue;
457                        };
458                        let _ = callback.send(
459                            Self::get_item(&connection, object_store, key_range)
460                                .map_err(|e| BackendError::DbErr(format!("{:?}", e))),
461                        );
462                    },
463                    AsyncOperation::ReadOnly(AsyncReadOnlyOperation::GetAllKeys {
464                        callback,
465                        key_range,
466                        count,
467                    }) => {
468                        let Ok(object_store) = process_object_store(object_store, &callback) else {
469                            continue;
470                        };
471                        let _ = callback.send(
472                            Self::get_all_keys(&connection, object_store, key_range, count)
473                                .map(|keys| {
474                                    keys.into_iter()
475                                        .map(|k| bincode::deserialize(&k).unwrap())
476                                        .collect()
477                                })
478                                .map_err(|e| BackendError::DbErr(format!("{:?}", e))),
479                        );
480                    },
481                    AsyncOperation::ReadOnly(AsyncReadOnlyOperation::GetAllItems {
482                        callback,
483                        key_range,
484                        count,
485                    }) => {
486                        let Ok(object_store) = process_object_store(object_store, &callback) else {
487                            continue;
488                        };
489                        let _ = callback.send(
490                            Self::get_all_items(&connection, object_store, key_range, count)
491                                .map_err(|e| BackendError::DbErr(format!("{:?}", e))),
492                        );
493                    },
494                    AsyncOperation::ReadWrite(AsyncReadWriteOperation::RemoveItem {
495                        callback,
496                        key_range,
497                    }) => {
498                        let Ok(object_store) = process_object_store(object_store, &callback) else {
499                            continue;
500                        };
501                        let _ = callback.send(
502                            Self::delete_item(&connection, object_store, key_range)
503                                .map_err(|e| BackendError::DbErr(format!("{:?}", e))),
504                        );
505                    },
506                    AsyncOperation::ReadOnly(AsyncReadOnlyOperation::Count {
507                        callback,
508                        key_range,
509                    }) => {
510                        let Ok(object_store) = process_object_store(object_store, &callback) else {
511                            continue;
512                        };
513                        let _ = callback.send(
514                            Self::count(&connection, object_store, key_range)
515                                .map(|r| r as u64)
516                                .map_err(|e| BackendError::DbErr(format!("{:?}", e))),
517                        );
518                    },
519                    AsyncOperation::ReadOnly(AsyncReadOnlyOperation::Iterate {
520                        callback,
521                        key_range,
522                    }) => {
523                        let Ok(object_store) = process_object_store(object_store, &callback) else {
524                            continue;
525                        };
526                        let _ = callback.send(
527                            Self::get_all_records(&connection, object_store, key_range)
528                                .map(|records| {
529                                    records
530                                        .into_iter()
531                                        .map(|(key, data)| IndexedDBRecord {
532                                            key: bincode::deserialize(&key).unwrap(),
533                                            primary_key: bincode::deserialize(&key).unwrap(),
534                                            value: data,
535                                        })
536                                        .collect()
537                                })
538                                .map_err(|e| BackendError::DbErr(format!("{:?}", e))),
539                        );
540                    },
541                    AsyncOperation::ReadWrite(AsyncReadWriteOperation::Clear(sender)) => {
542                        let Ok(object_store) = process_object_store(object_store, &sender) else {
543                            continue;
544                        };
545                        let _ = sender.send(
546                            Self::clear(&connection, object_store)
547                                .map_err(|e| BackendError::DbErr(format!("{:?}", e))),
548                        );
549                    },
550                    AsyncOperation::ReadOnly(AsyncReadOnlyOperation::GetKey {
551                        callback,
552                        key_range,
553                    }) => {
554                        let Ok(object_store) = process_object_store(object_store, &callback) else {
555                            continue;
556                        };
557                        let _ = callback.send(
558                            Self::get_key(&connection, object_store, key_range)
559                                .map(|key| key.map(|k| bincode::deserialize(&k).unwrap()))
560                                .map_err(|e| BackendError::DbErr(format!("{:?}", e))),
561                        );
562                    },
563                }
564            }
565            let _ = tx.send(None);
566        });
567        rx
568    }
569
570    // TODO: we should be able to error out here, maybe change the trait definition?
571    fn has_key_generator(&self, store_name: &str) -> bool {
572        self.connection
573            .prepare("SELECT * FROM object_store WHERE name = ?")
574            .and_then(|mut stmt| {
575                stmt.query_row(params![store_name.to_string()], |r| {
576                    let object_store = object_store_model::Model::try_from(r).unwrap();
577                    Ok(object_store.auto_increment)
578                })
579            })
580            .optional()
581            .unwrap()
582            // TODO: Wrong (change trait definition for this function)
583            .unwrap_or_default() !=
584            0
585    }
586
587    fn key_path(&self, store_name: &str) -> Option<KeyPath> {
588        self.connection
589            .prepare("SELECT * FROM object_store WHERE name = ?")
590            .and_then(|mut stmt| {
591                stmt.query_row(params![store_name.to_string()], |r| {
592                    let object_store = object_store_model::Model::try_from(r).unwrap();
593                    Ok(object_store
594                        .key_path
595                        .map(|key_path| bincode::deserialize(&key_path).unwrap()))
596                })
597            })
598            .optional()
599            .unwrap()
600            // TODO: Wrong, same issues as has_key_generator
601            .unwrap_or_default()
602    }
603
604    fn create_index(
605        &self,
606        store_name: &str,
607        index_name: String,
608        key_path: KeyPath,
609        unique: bool,
610        multi_entry: bool,
611    ) -> Result<CreateObjectResult, Self::Error> {
612        let object_store = self.connection.query_row(
613            "SELECT * FROM object_store WHERE name = ?",
614            params![store_name.to_string()],
615            |row| object_store_model::Model::try_from(row),
616        )?;
617
618        let index_exists: bool = self.connection.query_row(
619            "SELECT EXISTS(SELECT * FROM object_store_index WHERE name = ? AND object_store_id = ?)",
620            params![index_name.to_string(), object_store.id],
621            |row| row.get(0),
622        )?;
623        if index_exists {
624            return Ok(CreateObjectResult::AlreadyExists);
625        }
626
627        self.connection.execute(
628            "INSERT INTO object_store_index (object_store_id, name, key_path, unique_index, multi_entry_index)\
629            VALUES (?, ?, ?, ?, ?)",
630            params![
631                object_store.id,
632                index_name.to_string(),
633                bincode::serialize(&key_path).unwrap(),
634                unique,
635                multi_entry,
636            ],
637        )?;
638        Ok(CreateObjectResult::Created)
639    }
640
641    fn delete_index(&self, store_name: &str, index_name: String) -> Result<(), Self::Error> {
642        let object_store = self.connection.query_row(
643            "SELECT * FROM object_store WHERE name = ?",
644            params![store_name.to_string()],
645            |r| Ok(object_store_model::Model::try_from(r).unwrap()),
646        )?;
647
648        // Delete the index if it exists
649        let _ = self.connection.execute(
650            "DELETE FROM object_store_index WHERE name = ? AND object_store_id = ?",
651            params![index_name.to_string(), object_store.id],
652        )?;
653        Ok(())
654    }
655
656    fn version(&self) -> Result<u64, Self::Error> {
657        let version: i64 =
658            self.connection
659                .query_row("SELECT version FROM database LIMIT 1", [], |row| row.get(0))?;
660        Ok(u64::from_ne_bytes(version.to_ne_bytes()))
661    }
662
663    fn set_version(&self, version: u64) -> Result<(), Self::Error> {
664        let rows_affected = self.connection.execute(
665            "UPDATE database SET version = ?",
666            params![i64::from_ne_bytes(version.to_ne_bytes())],
667        )?;
668        if rows_affected == 0 {
669            return Err(Error::QueryReturnedNoRows);
670        }
671        Ok(())
672    }
673}
674
675#[cfg(test)]
676mod tests {
677    use std::collections::VecDeque;
678    use std::sync::Arc;
679
680    use base::generic_channel::{self, GenericReceiver, GenericSender};
681    use base::threadpool::ThreadPool;
682    use profile_traits::generic_callback::GenericCallback;
683    use profile_traits::time::ProfilerChan;
684    use serde::{Deserialize, Serialize};
685    use servo_url::ImmutableOrigin;
686    use storage_traits::indexeddb::{
687        AsyncOperation, AsyncReadOnlyOperation, AsyncReadWriteOperation, CreateObjectResult,
688        IndexedDBKeyRange, IndexedDBKeyType, IndexedDBTxnMode, KeyPath, PutItemResult,
689    };
690    use url::Host;
691
692    use crate::indexeddb::IndexedDBDescription;
693    use crate::indexeddb::engines::{KvsEngine, KvsOperation, KvsTransaction, SqliteEngine};
694
695    fn test_origin() -> ImmutableOrigin {
696        ImmutableOrigin::Tuple(
697            "test_origin".to_string(),
698            Host::Domain("localhost".to_string()),
699            80,
700        )
701    }
702
703    fn get_pool() -> Arc<ThreadPool> {
704        Arc::new(ThreadPool::new(1, "test".to_string()))
705    }
706
707    #[test]
708    fn test_cycle() {
709        let base_dir = tempfile::tempdir().expect("Failed to create temp dir");
710        let thread_pool = get_pool();
711        // Test create
712        let _ = SqliteEngine::new(
713            base_dir.path(),
714            &IndexedDBDescription {
715                name: "test_db".to_string(),
716                origin: test_origin(),
717            },
718            thread_pool.clone(),
719        )
720        .unwrap();
721        // Test open
722        let db = SqliteEngine::new(
723            base_dir.path(),
724            &IndexedDBDescription {
725                name: "test_db".to_string(),
726                origin: test_origin(),
727            },
728            thread_pool.clone(),
729        )
730        .unwrap();
731        let version = db.version().expect("Failed to get version");
732        assert_eq!(version, 0);
733        db.set_version(5).unwrap();
734        let new_version = db.version().expect("Failed to get new version");
735        assert_eq!(new_version, 5);
736        db.delete_database().expect("Failed to delete database");
737    }
738
739    #[test]
740    fn test_create_store() {
741        let base_dir = tempfile::tempdir().expect("Failed to create temp dir");
742        let thread_pool = get_pool();
743        let db = SqliteEngine::new(
744            base_dir.path(),
745            &IndexedDBDescription {
746                name: "test_db".to_string(),
747                origin: test_origin(),
748            },
749            thread_pool,
750        )
751        .unwrap();
752        let store_name = "test_store";
753        let result = db.create_store(store_name, None, true);
754        assert!(result.is_ok());
755        let create_result = result.unwrap();
756        assert_eq!(create_result, CreateObjectResult::Created);
757        // Try to create the same store again
758        let result = db.create_store(store_name, None, false);
759        assert!(result.is_ok());
760        let create_result = result.unwrap();
761        assert_eq!(create_result, CreateObjectResult::AlreadyExists);
762        // Ensure store was not overwritten
763        assert!(db.has_key_generator(store_name));
764    }
765
766    #[test]
767    fn test_create_store_empty_name() {
768        let base_dir = tempfile::tempdir().expect("Failed to create temp dir");
769        let thread_pool = get_pool();
770        let db = SqliteEngine::new(
771            base_dir.path(),
772            &IndexedDBDescription {
773                name: "test_db".to_string(),
774                origin: test_origin(),
775            },
776            thread_pool,
777        )
778        .unwrap();
779        let store_name = "";
780        let result = db.create_store(store_name, None, true);
781        assert!(result.is_ok());
782        let create_result = result.unwrap();
783        assert_eq!(create_result, CreateObjectResult::Created);
784    }
785
786    #[test]
787    fn test_injection() {
788        let base_dir = tempfile::tempdir().expect("Failed to create temp dir");
789        let thread_pool = get_pool();
790        let db = SqliteEngine::new(
791            base_dir.path(),
792            &IndexedDBDescription {
793                name: "test_db".to_string(),
794                origin: test_origin(),
795            },
796            thread_pool,
797        )
798        .unwrap();
799        // Create a normal store
800        let store_name1 = "test_store";
801        let result = db.create_store(store_name1, None, true);
802        assert!(result.is_ok());
803        let create_result = result.unwrap();
804        assert_eq!(create_result, CreateObjectResult::Created);
805        // Injection
806        let store_name2 = "' OR 1=1 -- -";
807        let result = db.create_store(store_name2, None, false);
808        assert!(result.is_ok());
809        let create_result = result.unwrap();
810        assert_eq!(create_result, CreateObjectResult::Created);
811    }
812
813    #[test]
814    fn test_key_path() {
815        let base_dir = tempfile::tempdir().expect("Failed to create temp dir");
816        let thread_pool = get_pool();
817        let db = SqliteEngine::new(
818            base_dir.path(),
819            &IndexedDBDescription {
820                name: "test_db".to_string(),
821                origin: test_origin(),
822            },
823            thread_pool,
824        )
825        .unwrap();
826        let store_name = "test_store";
827        let result = db.create_store(store_name, Some(KeyPath::String("test".to_string())), true);
828        assert!(result.is_ok());
829        assert_eq!(
830            db.key_path(store_name),
831            Some(KeyPath::String("test".to_string()))
832        );
833    }
834
835    #[test]
836    fn test_delete_store() {
837        let base_dir = tempfile::tempdir().expect("Failed to create temp dir");
838        let thread_pool = get_pool();
839        let db = SqliteEngine::new(
840            base_dir.path(),
841            &IndexedDBDescription {
842                name: "test_db".to_string(),
843                origin: test_origin(),
844            },
845            thread_pool,
846        )
847        .unwrap();
848        db.create_store("test_store", None, false)
849            .expect("Failed to create store");
850        // Delete the store
851        db.delete_store("test_store")
852            .expect("Failed to delete store");
853        // Try to delete the same store again
854        let result = db.delete_store("test_store");
855        assert!(result.is_err());
856        // Try to delete a non-existing store
857        let result = db.delete_store("test_store");
858        // Should work as per spec
859        assert!(result.is_err());
860    }
861
862    #[test]
863    fn test_async_operations() {
864        fn get_channel<T>() -> (GenericSender<T>, GenericReceiver<T>)
865        where
866            T: for<'de> Deserialize<'de> + Serialize,
867        {
868            generic_channel::channel().unwrap()
869        }
870
871        fn get_callback<T>(chan: GenericSender<T>) -> GenericCallback<T>
872        where
873            T: for<'de> Deserialize<'de> + Serialize + Send + Sync,
874        {
875            GenericCallback::new(ProfilerChan(None), move |r| {
876                assert!(chan.send(r.unwrap()).is_ok());
877            })
878            .expect("Could not construct callback")
879        }
880
881        let base_dir = tempfile::tempdir().expect("Failed to create temp dir");
882        let thread_pool = get_pool();
883        let db = SqliteEngine::new(
884            base_dir.path(),
885            &IndexedDBDescription {
886                name: "test_db".to_string(),
887                origin: test_origin(),
888            },
889            thread_pool,
890        )
891        .unwrap();
892        let store_name = "test_store";
893        db.create_store(store_name, None, false)
894            .expect("Failed to create store");
895        let put = get_channel();
896        let put2 = get_channel();
897        let put3 = get_channel();
898        let put_dup = get_channel();
899        let get_item_some = get_channel();
900        let get_item_none = get_channel();
901        let get_all_items = get_channel();
902        let count = get_channel();
903        let remove = get_channel();
904        let clear = get_channel();
905        let rx = db.process_transaction(KvsTransaction {
906            mode: IndexedDBTxnMode::Readwrite,
907            requests: VecDeque::from(vec![
908                KvsOperation {
909                    store_name: store_name.to_owned(),
910                    operation: AsyncOperation::ReadWrite(AsyncReadWriteOperation::PutItem {
911                        callback: get_callback(put.0),
912                        key: Some(IndexedDBKeyType::Number(1.0)),
913                        value: vec![1, 2, 3],
914                        should_overwrite: false,
915                    }),
916                },
917                KvsOperation {
918                    store_name: store_name.to_owned(),
919                    operation: AsyncOperation::ReadWrite(AsyncReadWriteOperation::PutItem {
920                        callback: get_callback(put2.0),
921                        key: Some(IndexedDBKeyType::String("2.0".to_string())),
922                        value: vec![4, 5, 6],
923                        should_overwrite: false,
924                    }),
925                },
926                KvsOperation {
927                    store_name: store_name.to_owned(),
928                    operation: AsyncOperation::ReadWrite(AsyncReadWriteOperation::PutItem {
929                        callback: get_callback(put3.0),
930                        key: Some(IndexedDBKeyType::Array(vec![
931                            IndexedDBKeyType::String("3".to_string()),
932                            IndexedDBKeyType::Number(0.0),
933                        ])),
934                        value: vec![7, 8, 9],
935                        should_overwrite: false,
936                    }),
937                },
938                // Try to put a duplicate key without overwrite
939                KvsOperation {
940                    store_name: store_name.to_owned(),
941                    operation: AsyncOperation::ReadWrite(AsyncReadWriteOperation::PutItem {
942                        callback: get_callback(put_dup.0),
943                        key: Some(IndexedDBKeyType::Number(1.0)),
944                        value: vec![10, 11, 12],
945                        should_overwrite: false,
946                    }),
947                },
948                KvsOperation {
949                    store_name: store_name.to_owned(),
950                    operation: AsyncOperation::ReadOnly(AsyncReadOnlyOperation::GetItem {
951                        callback: get_callback(get_item_some.0),
952                        key_range: IndexedDBKeyRange::only(IndexedDBKeyType::Number(1.0)),
953                    }),
954                },
955                KvsOperation {
956                    store_name: store_name.to_owned(),
957                    operation: AsyncOperation::ReadOnly(AsyncReadOnlyOperation::GetItem {
958                        callback: get_callback(get_item_none.0),
959                        key_range: IndexedDBKeyRange::only(IndexedDBKeyType::Number(5.0)),
960                    }),
961                },
962                KvsOperation {
963                    store_name: store_name.to_owned(),
964                    operation: AsyncOperation::ReadOnly(AsyncReadOnlyOperation::GetAllItems {
965                        callback: get_callback(get_all_items.0),
966                        key_range: IndexedDBKeyRange::lower_bound(
967                            IndexedDBKeyType::Number(0.0),
968                            false,
969                        ),
970                        count: None,
971                    }),
972                },
973                KvsOperation {
974                    store_name: store_name.to_owned(),
975                    operation: AsyncOperation::ReadOnly(AsyncReadOnlyOperation::Count {
976                        callback: get_callback(count.0),
977                        key_range: IndexedDBKeyRange::only(IndexedDBKeyType::Number(1.0)),
978                    }),
979                },
980                KvsOperation {
981                    store_name: store_name.to_owned(),
982                    operation: AsyncOperation::ReadWrite(AsyncReadWriteOperation::RemoveItem {
983                        callback: get_callback(remove.0),
984                        key_range: IndexedDBKeyRange::only(IndexedDBKeyType::Number(1.0)),
985                    }),
986                },
987                KvsOperation {
988                    store_name: store_name.to_owned(),
989                    operation: AsyncOperation::ReadWrite(AsyncReadWriteOperation::Clear(
990                        get_callback(clear.0),
991                    )),
992                },
993            ]),
994        });
995        let _ = rx.blocking_recv().unwrap();
996        put.1.recv().unwrap().unwrap();
997        put2.1.recv().unwrap().unwrap();
998        put3.1.recv().unwrap().unwrap();
999        let err = put_dup.1.recv().unwrap().unwrap();
1000        assert_eq!(err, PutItemResult::CannotOverwrite);
1001        let get_result = get_item_some.1.recv().unwrap();
1002        let value = get_result.unwrap();
1003        assert_eq!(value, Some(vec![1, 2, 3]));
1004        let get_result = get_item_none.1.recv().unwrap();
1005        let value = get_result.unwrap();
1006        assert_eq!(value, None);
1007        let all_items = get_all_items.1.recv().unwrap().unwrap();
1008        assert_eq!(all_items.len(), 3);
1009        // Check that all three items are present
1010        assert!(all_items.contains(&vec![1, 2, 3]));
1011        assert!(all_items.contains(&vec![4, 5, 6]));
1012        assert!(all_items.contains(&vec![7, 8, 9]));
1013        let amount = count.1.recv().unwrap().unwrap();
1014        assert_eq!(amount, 1);
1015        remove.1.recv().unwrap().unwrap();
1016        clear.1.recv().unwrap().unwrap();
1017    }
1018}