storage/indexeddb/engines/
sqlite.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6
7use base::threadpool::ThreadPool;
8use log::{error, info};
9use rusqlite::{Connection, Error, OptionalExtension, params};
10use sea_query::{Condition, Expr, ExprTrait, IntoCondition, SqliteQueryBuilder};
11use sea_query_rusqlite::RusqliteBinder;
12use storage_traits::indexeddb::{
13    AsyncOperation, AsyncReadOnlyOperation, AsyncReadWriteOperation, BackendError,
14    CreateObjectResult, IndexedDBIndex, IndexedDBKeyRange, IndexedDBKeyType, IndexedDBRecord,
15    IndexedDBTxnMode, KeyPath, PutItemResult,
16};
17use tokio::sync::oneshot;
18
19use crate::indexeddb::IndexedDBDescription;
20use crate::indexeddb::engines::{KvsEngine, KvsTransaction};
21use crate::shared::{DB_INIT_PRAGMAS, DB_PRAGMAS};
22
23mod create;
24mod database_model;
25mod object_data_model;
26mod object_store_index_model;
27mod object_store_model;
28
29fn range_to_query(range: IndexedDBKeyRange) -> Condition {
30    // Special case for optimization
31    if let Some(singleton) = range.as_singleton() {
32        let encoded = postcard::to_stdvec(singleton).unwrap();
33        return Expr::column(object_data_model::Column::Key)
34            .eq(encoded)
35            .into_condition();
36    }
37    let mut parts = vec![];
38    if let Some(upper) = range.upper.as_ref() {
39        let upper_bytes = postcard::to_stdvec(upper).unwrap();
40        let query = if range.upper_open {
41            Expr::column(object_data_model::Column::Key).lt(upper_bytes)
42        } else {
43            Expr::column(object_data_model::Column::Key).lte(upper_bytes)
44        };
45        parts.push(query);
46    }
47    if let Some(lower) = range.lower.as_ref() {
48        let lower_bytes = postcard::to_stdvec(lower).unwrap();
49        let query = if range.lower_open {
50            Expr::column(object_data_model::Column::Key).gt(lower_bytes)
51        } else {
52            Expr::column(object_data_model::Column::Key).gte(lower_bytes)
53        };
54        parts.push(query);
55    }
56    let mut condition = Condition::all();
57    for part in parts {
58        condition = condition.add(part);
59    }
60    condition
61}
62
63pub struct SqliteEngine {
64    db_path: PathBuf,
65    connection: Connection,
66    read_pool: Arc<ThreadPool>,
67    write_pool: Arc<ThreadPool>,
68    created_db_path: bool,
69}
70
71impl SqliteEngine {
72    // TODO: intake dual pools
73    pub fn new(
74        base_dir: &Path,
75        db_info: &IndexedDBDescription,
76        pool: Arc<ThreadPool>,
77    ) -> Result<Self, Error> {
78        let mut db_path = PathBuf::new();
79        db_path.push(base_dir);
80        db_path.push(db_info.as_path());
81        let db_parent = db_path.clone();
82        db_path.push("db.sqlite");
83
84        let created_db_path = if !db_path.exists() {
85            std::fs::create_dir_all(db_parent).unwrap();
86            std::fs::File::create(&db_path).unwrap();
87            true
88        } else {
89            false
90        };
91
92        let connection = Self::init_db(&db_path, db_info)?;
93
94        for stmt in DB_PRAGMAS {
95            // TODO: Handle errors properly
96            let _ = connection.execute(stmt, ());
97        }
98
99        Ok(Self {
100            connection,
101            db_path,
102            read_pool: pool.clone(),
103            write_pool: pool,
104            created_db_path,
105        })
106    }
107
108    /// Returns whether the physical db was created as part of `new`.
109    pub(crate) fn created_db_path(&self) -> bool {
110        self.created_db_path
111    }
112
113    fn init_db(path: &Path, db_info: &IndexedDBDescription) -> Result<Connection, Error> {
114        let connection = Connection::open(path)?;
115        if connection.table_exists(None, "database")? {
116            // Database already exists, no need to initialize
117            return Ok(connection);
118        }
119        info!("Initializing indexeddb database at {:?}", path);
120        for stmt in DB_INIT_PRAGMAS {
121            // FIXME(arihant2math): this fails occasionally
122            let _ = connection.execute(stmt, ());
123        }
124        create::create_tables(&connection)?;
125        // From https://w3c.github.io/IndexedDB/#database-version:
126        // "When a database is first created, its version is 0 (zero)."
127        connection.execute(
128            "INSERT INTO database (name, origin, version) VALUES (?, ?, ?)",
129            params![
130                db_info.name.to_owned(),
131                db_info.origin.to_owned().ascii_serialization(),
132                i64::from_ne_bytes(0_u64.to_ne_bytes())
133            ],
134        )?;
135        Ok(connection)
136    }
137
138    fn get(
139        connection: &Connection,
140        store: object_store_model::Model,
141        key_range: IndexedDBKeyRange,
142    ) -> Result<Option<object_data_model::Model>, Error> {
143        let query = range_to_query(key_range);
144        let (sql, values) = sea_query::Query::select()
145            .from(object_data_model::Column::Table)
146            .columns(vec![
147                object_data_model::Column::ObjectStoreId,
148                object_data_model::Column::Key,
149                object_data_model::Column::Data,
150            ])
151            .and_where(query.and(Expr::col(object_data_model::Column::ObjectStoreId).is(store.id)))
152            .limit(1)
153            .build_rusqlite(SqliteQueryBuilder);
154        connection
155            .prepare(&sql)?
156            .query_one(&*values.as_params(), |row| {
157                object_data_model::Model::try_from(row)
158            })
159            .optional()
160    }
161
162    fn get_key(
163        connection: &Connection,
164        store: object_store_model::Model,
165        key_range: IndexedDBKeyRange,
166    ) -> Result<Option<Vec<u8>>, Error> {
167        Self::get(connection, store, key_range).map(|opt| opt.map(|model| model.key))
168    }
169
170    fn get_item(
171        connection: &Connection,
172        store: object_store_model::Model,
173        key_range: IndexedDBKeyRange,
174    ) -> Result<Option<Vec<u8>>, Error> {
175        Self::get(connection, store, key_range).map(|opt| opt.map(|model| model.data))
176    }
177
178    fn get_all(
179        connection: &Connection,
180        store: object_store_model::Model,
181        key_range: IndexedDBKeyRange,
182        count: Option<u32>,
183    ) -> Result<Vec<object_data_model::Model>, Error> {
184        let query = range_to_query(key_range);
185        let mut sql_query = sea_query::Query::select();
186        sql_query
187            .from(object_data_model::Column::Table)
188            .columns(vec![
189                object_data_model::Column::ObjectStoreId,
190                object_data_model::Column::Key,
191                object_data_model::Column::Data,
192            ])
193            .and_where(query.and(Expr::col(object_data_model::Column::ObjectStoreId).is(store.id)));
194        if let Some(count) = count {
195            sql_query.limit(count as u64);
196        }
197        let (sql, values) = sql_query.build_rusqlite(SqliteQueryBuilder);
198        let mut stmt = connection.prepare(&sql)?;
199        let models = stmt
200            .query_and_then(&*values.as_params(), |row| {
201                object_data_model::Model::try_from(row)
202            })?
203            .collect::<Result<Vec<_>, _>>()?;
204        Ok(models)
205    }
206
207    fn get_all_keys(
208        connection: &Connection,
209        store: object_store_model::Model,
210        key_range: IndexedDBKeyRange,
211        count: Option<u32>,
212    ) -> Result<Vec<Vec<u8>>, Error> {
213        Self::get_all(connection, store, key_range, count)
214            .map(|models| models.into_iter().map(|m| m.key).collect())
215    }
216
217    fn get_all_items(
218        connection: &Connection,
219        store: object_store_model::Model,
220        key_range: IndexedDBKeyRange,
221        count: Option<u32>,
222    ) -> Result<Vec<Vec<u8>>, Error> {
223        Self::get_all(connection, store, key_range, count)
224            .map(|models| models.into_iter().map(|m| m.data).collect())
225    }
226
227    #[expect(clippy::type_complexity)]
228    fn get_all_records(
229        connection: &Connection,
230        store: object_store_model::Model,
231        key_range: IndexedDBKeyRange,
232    ) -> Result<Vec<(Vec<u8>, Vec<u8>)>, Error> {
233        Self::get_all(connection, store, key_range, None)
234            .map(|models| models.into_iter().map(|m| (m.key, m.data)).collect())
235    }
236
237    fn put_item(
238        connection: &Connection,
239        store: object_store_model::Model,
240        serialized_key: Vec<u8>,
241        value: Vec<u8>,
242        should_overwrite: bool,
243    ) -> Result<PutItemResult, Error> {
244        let existing_item = connection
245            .prepare("SELECT * FROM object_data WHERE key = ? AND object_store_id = ?")
246            .and_then(|mut stmt| {
247                stmt.query_row(params![serialized_key, store.id], |row| {
248                    object_data_model::Model::try_from(row)
249                })
250                .optional()
251            })?;
252        if should_overwrite || existing_item.is_none() {
253            connection.execute(
254                "INSERT INTO object_data (object_store_id, key, data) VALUES (?, ?, ?)",
255                params![store.id, serialized_key, value],
256            )?;
257            Ok(PutItemResult::Success)
258        } else {
259            Ok(PutItemResult::CannotOverwrite)
260        }
261    }
262
263    fn delete_item(
264        connection: &Connection,
265        store: object_store_model::Model,
266        key_range: IndexedDBKeyRange,
267    ) -> Result<(), Error> {
268        let query = range_to_query(key_range);
269        let (sql, values) = sea_query::Query::delete()
270            .from_table(object_data_model::Column::Table)
271            .and_where(query.and(Expr::col(object_data_model::Column::ObjectStoreId).is(store.id)))
272            .build_rusqlite(SqliteQueryBuilder);
273        connection.prepare(&sql)?.execute(&*values.as_params())?;
274        Ok(())
275    }
276
277    fn clear(connection: &Connection, store: object_store_model::Model) -> Result<(), Error> {
278        connection.execute(
279            "DELETE FROM object_data WHERE object_store_id = ?",
280            params![store.id],
281        )?;
282        Ok(())
283    }
284
285    fn count(
286        connection: &Connection,
287        store: object_store_model::Model,
288        key_range: IndexedDBKeyRange,
289    ) -> Result<usize, Error> {
290        let query = range_to_query(key_range);
291        let (sql, values) = sea_query::Query::select()
292            .expr(Expr::col(object_data_model::Column::Key).count())
293            .from(object_data_model::Column::Table)
294            .and_where(query.and(Expr::col(object_data_model::Column::ObjectStoreId).is(store.id)))
295            .build_rusqlite(SqliteQueryBuilder);
296        connection
297            .prepare(&sql)?
298            .query_row(&*values.as_params(), |row| row.get(0))
299            .map(|count: i64| count as usize)
300    }
301
302    fn generate_key(
303        connection: &Connection,
304        store: &object_store_model::Model,
305    ) -> Result<IndexedDBKeyType, Error> {
306        if store.auto_increment == 0 {
307            unreachable!("Should be caught in the script thread");
308        }
309        // TODO: handle overflows, this also needs to be able to handle 2^53 as per spec
310        let new_key = store.auto_increment + 1;
311        connection.execute(
312            "UPDATE object_store SET auto_increment = ? WHERE id = ?",
313            params![new_key, store.id],
314        )?;
315        Ok(IndexedDBKeyType::Number(new_key as f64))
316    }
317}
318
319impl KvsEngine for SqliteEngine {
320    type Error = Error;
321
322    fn create_store(
323        &self,
324        store_name: &str,
325        key_path: Option<KeyPath>,
326        auto_increment: bool,
327    ) -> Result<CreateObjectResult, Self::Error> {
328        let mut stmt = self
329            .connection
330            .prepare("SELECT * FROM object_store WHERE name = ?")?;
331        if stmt.exists(params![store_name.to_string()])? {
332            // Store already exists
333            return Ok(CreateObjectResult::AlreadyExists);
334        }
335        self.connection.execute(
336            "INSERT INTO object_store (name, key_path, auto_increment) VALUES (?, ?, ?)",
337            params![
338                store_name.to_string(),
339                key_path.map(|v| postcard::to_stdvec(&v).unwrap()),
340                auto_increment as i32
341            ],
342        )?;
343
344        Ok(CreateObjectResult::Created)
345    }
346
347    fn delete_store(&self, store_name: &str) -> Result<(), Self::Error> {
348        let result = self.connection.execute(
349            "DELETE FROM object_store WHERE name = ?",
350            params![store_name.to_string()],
351        )?;
352        if result == 0 {
353            Err(Error::QueryReturnedNoRows)
354        } else if result > 1 {
355            Err(Error::QueryReturnedMoreThanOneRow)
356        } else {
357            Ok(())
358        }
359    }
360
361    fn close_store(&self, _store_name: &str) -> Result<(), Self::Error> {
362        // TODO: do something
363        Ok(())
364    }
365
366    fn delete_database(self) -> Result<(), Self::Error> {
367        // attempt to close the connection first
368        let _ = self.connection.close();
369        if self.db_path.exists() {
370            if let Err(e) = std::fs::remove_dir_all(self.db_path.parent().unwrap()) {
371                error!("Failed to delete database: {:?}", e);
372            }
373        }
374        Ok(())
375    }
376
377    fn process_transaction(
378        &self,
379        transaction: KvsTransaction,
380    ) -> oneshot::Receiver<Option<Vec<u8>>> {
381        let (tx, rx) = oneshot::channel();
382
383        let spawning_pool = if transaction.mode == IndexedDBTxnMode::Readonly {
384            self.read_pool.clone()
385        } else {
386            self.write_pool.clone()
387        };
388        let path = self.db_path.clone();
389        spawning_pool.spawn(move || {
390            let connection = match Connection::open(path) {
391                Ok(connection) => connection,
392                Err(error) => {
393                    for request in transaction.requests {
394                        request
395                            .operation
396                            .notify_error(BackendError::DbErr(format!("{error:?}")));
397                    }
398                    let _ = tx.send(None);
399                    return;
400                },
401            };
402            for request in transaction.requests {
403                let object_store = connection
404                    .prepare("SELECT * FROM object_store WHERE name = ?")
405                    .and_then(|mut stmt| {
406                        stmt.query_row(params![request.store_name.to_string()], |row| {
407                            object_store_model::Model::try_from(row)
408                        })
409                        .optional()
410                    });
411                let object_store = match object_store {
412                    Ok(Some(store)) => store,
413                    Ok(None) => {
414                        request.operation.notify_error(BackendError::StoreNotFound);
415                        continue;
416                    },
417                    Err(error) => {
418                        request
419                            .operation
420                            .notify_error(BackendError::DbErr(format!("{error:?}")));
421                        continue;
422                    },
423                };
424
425                match request.operation {
426                    AsyncOperation::ReadWrite(AsyncReadWriteOperation::PutItem {
427                        callback,
428                        key,
429                        value,
430                        should_overwrite,
431                    }) => {
432                        let key = match key
433                            .map(Ok)
434                            .unwrap_or_else(|| Self::generate_key(&connection, &object_store))
435                        {
436                            Ok(key) => key,
437                            Err(e) => {
438                                let _ = callback.send(Err(BackendError::DbErr(format!("{:?}", e))));
439                                continue;
440                            },
441                        };
442                        let serialized_key: Vec<u8> = postcard::to_stdvec(&key).unwrap();
443                        let _ = callback.send(
444                            Self::put_item(
445                                &connection,
446                                object_store,
447                                serialized_key,
448                                value,
449                                should_overwrite,
450                            )
451                            .map_err(|e| BackendError::DbErr(format!("{:?}", e))),
452                        );
453                    },
454                    AsyncOperation::ReadOnly(AsyncReadOnlyOperation::GetItem {
455                        callback,
456                        key_range,
457                    }) => {
458                        let _ = callback.send(
459                            Self::get_item(&connection, object_store, key_range)
460                                .map_err(|e| BackendError::DbErr(format!("{:?}", e))),
461                        );
462                    },
463                    AsyncOperation::ReadOnly(AsyncReadOnlyOperation::GetAllKeys {
464                        callback,
465                        key_range,
466                        count,
467                    }) => {
468                        let _ = callback.send(
469                            Self::get_all_keys(&connection, object_store, key_range, count)
470                                .map(|keys| {
471                                    keys.into_iter()
472                                        .map(|k| postcard::from_bytes(&k).unwrap())
473                                        .collect()
474                                })
475                                .map_err(|e| BackendError::DbErr(format!("{:?}", e))),
476                        );
477                    },
478                    AsyncOperation::ReadOnly(AsyncReadOnlyOperation::GetAllItems {
479                        callback,
480                        key_range,
481                        count,
482                    }) => {
483                        let _ = callback.send(
484                            Self::get_all_items(&connection, object_store, key_range, count)
485                                .map_err(|e| BackendError::DbErr(format!("{:?}", e))),
486                        );
487                    },
488                    AsyncOperation::ReadWrite(AsyncReadWriteOperation::RemoveItem {
489                        callback,
490                        key_range,
491                    }) => {
492                        let _ = callback.send(
493                            Self::delete_item(&connection, object_store, key_range)
494                                .map_err(|e| BackendError::DbErr(format!("{:?}", e))),
495                        );
496                    },
497                    AsyncOperation::ReadOnly(AsyncReadOnlyOperation::Count {
498                        callback,
499                        key_range,
500                    }) => {
501                        let _ = callback.send(
502                            Self::count(&connection, object_store, key_range)
503                                .map(|r| r as u64)
504                                .map_err(|e| BackendError::DbErr(format!("{:?}", e))),
505                        );
506                    },
507                    AsyncOperation::ReadOnly(AsyncReadOnlyOperation::Iterate {
508                        callback,
509                        key_range,
510                    }) => {
511                        let _ = callback.send(
512                            Self::get_all_records(&connection, object_store, key_range)
513                                .map(|records| {
514                                    records
515                                        .into_iter()
516                                        .map(|(key, data)| IndexedDBRecord {
517                                            key: postcard::from_bytes(&key).unwrap(),
518                                            primary_key: postcard::from_bytes(&key).unwrap(),
519                                            value: data,
520                                        })
521                                        .collect()
522                                })
523                                .map_err(|e| BackendError::DbErr(format!("{:?}", e))),
524                        );
525                    },
526                    AsyncOperation::ReadWrite(AsyncReadWriteOperation::Clear(sender)) => {
527                        let _ = sender.send(
528                            Self::clear(&connection, object_store)
529                                .map_err(|e| BackendError::DbErr(format!("{:?}", e))),
530                        );
531                    },
532                    AsyncOperation::ReadOnly(AsyncReadOnlyOperation::GetKey {
533                        callback,
534                        key_range,
535                    }) => {
536                        let _ = callback.send(
537                            Self::get_key(&connection, object_store, key_range)
538                                .map(|key| key.map(|k| postcard::from_bytes(&k).unwrap()))
539                                .map_err(|e| BackendError::DbErr(format!("{:?}", e))),
540                        );
541                    },
542                }
543            }
544            let _ = tx.send(None);
545        });
546        rx
547    }
548
549    // TODO: we should be able to error out here, maybe change the trait definition?
550    fn has_key_generator(&self, store_name: &str) -> bool {
551        self.connection
552            .prepare("SELECT * FROM object_store WHERE name = ?")
553            .and_then(|mut stmt| {
554                stmt.query_row(params![store_name.to_string()], |r| {
555                    let object_store = object_store_model::Model::try_from(r).unwrap();
556                    Ok(object_store.auto_increment)
557                })
558            })
559            .optional()
560            .unwrap()
561            // TODO: Wrong (change trait definition for this function)
562            .unwrap_or_default() !=
563            0
564    }
565
566    fn key_path(&self, store_name: &str) -> Option<KeyPath> {
567        self.connection
568            .prepare("SELECT * FROM object_store WHERE name = ?")
569            .and_then(|mut stmt| {
570                stmt.query_row(params![store_name.to_string()], |r| {
571                    let object_store = object_store_model::Model::try_from(r).unwrap();
572                    Ok(object_store
573                        .key_path
574                        .map(|key_path| postcard::from_bytes(&key_path).unwrap()))
575                })
576            })
577            .optional()
578            .unwrap()
579            // TODO: Wrong, same issues as has_key_generator
580            .unwrap_or_default()
581    }
582
583    fn indexes(&self, store_name: &str) -> Result<Vec<IndexedDBIndex>, Self::Error> {
584        let object_store = self.connection.query_row(
585            "SELECT * FROM object_store WHERE name = ?",
586            params![store_name.to_string()],
587            |row| object_store_model::Model::try_from(row),
588        )?;
589
590        let mut stmt = self
591            .connection
592            .prepare("SELECT * FROM object_store_index WHERE object_store_id = ?")?;
593        let indexes = stmt
594            .query_map(params![object_store.id], |row| {
595                let model = object_store_index_model::Model::try_from(row)?;
596                Ok(IndexedDBIndex {
597                    name: model.name,
598                    key_path: postcard::from_bytes(&model.key_path).unwrap(),
599                    unique: model.unique_index,
600                    multi_entry: model.multi_entry_index,
601                })
602            })?
603            .collect::<Result<Vec<_>, _>>()?;
604        Ok(indexes)
605    }
606
607    fn create_index(
608        &self,
609        store_name: &str,
610        index_name: String,
611        key_path: KeyPath,
612        unique: bool,
613        multi_entry: bool,
614    ) -> Result<CreateObjectResult, Self::Error> {
615        let object_store = self.connection.query_row(
616            "SELECT * FROM object_store WHERE name = ?",
617            params![store_name.to_string()],
618            |row| object_store_model::Model::try_from(row),
619        )?;
620
621        let index_exists: bool = self.connection.query_row(
622            "SELECT EXISTS(SELECT * FROM object_store_index WHERE name = ? AND object_store_id = ?)",
623            params![index_name.to_string(), object_store.id],
624            |row| row.get(0),
625        )?;
626        if index_exists {
627            return Ok(CreateObjectResult::AlreadyExists);
628        }
629
630        self.connection.execute(
631            "INSERT INTO object_store_index (object_store_id, name, key_path, unique_index, multi_entry_index)\
632            VALUES (?, ?, ?, ?, ?)",
633            params![
634                object_store.id,
635                index_name.to_string(),
636                postcard::to_stdvec(&key_path).unwrap(),
637                unique,
638                multi_entry,
639            ],
640        )?;
641        Ok(CreateObjectResult::Created)
642    }
643
644    fn delete_index(&self, store_name: &str, index_name: String) -> Result<(), Self::Error> {
645        let object_store = self.connection.query_row(
646            "SELECT * FROM object_store WHERE name = ?",
647            params![store_name.to_string()],
648            |r| Ok(object_store_model::Model::try_from(r).unwrap()),
649        )?;
650
651        // Delete the index if it exists
652        let _ = self.connection.execute(
653            "DELETE FROM object_store_index WHERE name = ? AND object_store_id = ?",
654            params![index_name.to_string(), object_store.id],
655        )?;
656        Ok(())
657    }
658
659    fn version(&self) -> Result<u64, Self::Error> {
660        let version: i64 =
661            self.connection
662                .query_row("SELECT version FROM database LIMIT 1", [], |row| row.get(0))?;
663        Ok(u64::from_ne_bytes(version.to_ne_bytes()))
664    }
665
666    fn set_version(&self, version: u64) -> Result<(), Self::Error> {
667        let rows_affected = self.connection.execute(
668            "UPDATE database SET version = ?",
669            params![i64::from_ne_bytes(version.to_ne_bytes())],
670        )?;
671        if rows_affected == 0 {
672            return Err(Error::QueryReturnedNoRows);
673        }
674        Ok(())
675    }
676}
677
678#[cfg(test)]
679mod tests {
680    use std::collections::VecDeque;
681    use std::sync::Arc;
682
683    use base::generic_channel::{self, GenericReceiver, GenericSender};
684    use base::threadpool::ThreadPool;
685    use profile_traits::generic_callback::GenericCallback;
686    use profile_traits::time::ProfilerChan;
687    use serde::{Deserialize, Serialize};
688    use servo_url::ImmutableOrigin;
689    use storage_traits::indexeddb::{
690        AsyncOperation, AsyncReadOnlyOperation, AsyncReadWriteOperation, CreateObjectResult,
691        IndexedDBKeyRange, IndexedDBKeyType, IndexedDBTxnMode, KeyPath, PutItemResult,
692    };
693    use url::Host;
694
695    use crate::indexeddb::IndexedDBDescription;
696    use crate::indexeddb::engines::{KvsEngine, KvsOperation, KvsTransaction, SqliteEngine};
697
698    fn test_origin() -> ImmutableOrigin {
699        ImmutableOrigin::Tuple(
700            "test_origin".to_string(),
701            Host::Domain("localhost".to_string()),
702            80,
703        )
704    }
705
706    fn get_pool() -> Arc<ThreadPool> {
707        Arc::new(ThreadPool::new(1, "test".to_string()))
708    }
709
710    #[test]
711    fn test_cycle() {
712        let base_dir = tempfile::tempdir().expect("Failed to create temp dir");
713        let thread_pool = get_pool();
714        // Test create
715        let _ = SqliteEngine::new(
716            base_dir.path(),
717            &IndexedDBDescription {
718                name: "test_db".to_string(),
719                origin: test_origin(),
720            },
721            thread_pool.clone(),
722        )
723        .unwrap();
724        // Test open
725        let db = SqliteEngine::new(
726            base_dir.path(),
727            &IndexedDBDescription {
728                name: "test_db".to_string(),
729                origin: test_origin(),
730            },
731            thread_pool.clone(),
732        )
733        .unwrap();
734        let version = db.version().expect("Failed to get version");
735        assert_eq!(version, 0);
736        db.set_version(5).unwrap();
737        let new_version = db.version().expect("Failed to get new version");
738        assert_eq!(new_version, 5);
739        db.delete_database().expect("Failed to delete database");
740    }
741
742    #[test]
743    fn test_create_store() {
744        let base_dir = tempfile::tempdir().expect("Failed to create temp dir");
745        let thread_pool = get_pool();
746        let db = SqliteEngine::new(
747            base_dir.path(),
748            &IndexedDBDescription {
749                name: "test_db".to_string(),
750                origin: test_origin(),
751            },
752            thread_pool,
753        )
754        .unwrap();
755        let store_name = "test_store";
756        let result = db.create_store(store_name, None, true);
757        assert!(result.is_ok());
758        let create_result = result.unwrap();
759        assert_eq!(create_result, CreateObjectResult::Created);
760        // Try to create the same store again
761        let result = db.create_store(store_name, None, false);
762        assert!(result.is_ok());
763        let create_result = result.unwrap();
764        assert_eq!(create_result, CreateObjectResult::AlreadyExists);
765        // Ensure store was not overwritten
766        assert!(db.has_key_generator(store_name));
767    }
768
769    #[test]
770    fn test_create_store_empty_name() {
771        let base_dir = tempfile::tempdir().expect("Failed to create temp dir");
772        let thread_pool = get_pool();
773        let db = SqliteEngine::new(
774            base_dir.path(),
775            &IndexedDBDescription {
776                name: "test_db".to_string(),
777                origin: test_origin(),
778            },
779            thread_pool,
780        )
781        .unwrap();
782        let store_name = "";
783        let result = db.create_store(store_name, None, true);
784        assert!(result.is_ok());
785        let create_result = result.unwrap();
786        assert_eq!(create_result, CreateObjectResult::Created);
787    }
788
789    #[test]
790    fn test_injection() {
791        let base_dir = tempfile::tempdir().expect("Failed to create temp dir");
792        let thread_pool = get_pool();
793        let db = SqliteEngine::new(
794            base_dir.path(),
795            &IndexedDBDescription {
796                name: "test_db".to_string(),
797                origin: test_origin(),
798            },
799            thread_pool,
800        )
801        .unwrap();
802        // Create a normal store
803        let store_name1 = "test_store";
804        let result = db.create_store(store_name1, None, true);
805        assert!(result.is_ok());
806        let create_result = result.unwrap();
807        assert_eq!(create_result, CreateObjectResult::Created);
808        // Injection
809        let store_name2 = "' OR 1=1 -- -";
810        let result = db.create_store(store_name2, None, false);
811        assert!(result.is_ok());
812        let create_result = result.unwrap();
813        assert_eq!(create_result, CreateObjectResult::Created);
814    }
815
816    #[test]
817    fn test_key_path() {
818        let base_dir = tempfile::tempdir().expect("Failed to create temp dir");
819        let thread_pool = get_pool();
820        let db = SqliteEngine::new(
821            base_dir.path(),
822            &IndexedDBDescription {
823                name: "test_db".to_string(),
824                origin: test_origin(),
825            },
826            thread_pool,
827        )
828        .unwrap();
829        let store_name = "test_store";
830        let result = db.create_store(store_name, Some(KeyPath::String("test".to_string())), true);
831        assert!(result.is_ok());
832        assert_eq!(
833            db.key_path(store_name),
834            Some(KeyPath::String("test".to_string()))
835        );
836    }
837
838    #[test]
839    fn test_delete_store() {
840        let base_dir = tempfile::tempdir().expect("Failed to create temp dir");
841        let thread_pool = get_pool();
842        let db = SqliteEngine::new(
843            base_dir.path(),
844            &IndexedDBDescription {
845                name: "test_db".to_string(),
846                origin: test_origin(),
847            },
848            thread_pool,
849        )
850        .unwrap();
851        db.create_store("test_store", None, false)
852            .expect("Failed to create store");
853        // Delete the store
854        db.delete_store("test_store")
855            .expect("Failed to delete store");
856        // Try to delete the same store again
857        let result = db.delete_store("test_store");
858        assert!(result.is_err());
859        // Try to delete a non-existing store
860        let result = db.delete_store("test_store");
861        // Should work as per spec
862        assert!(result.is_err());
863    }
864
    /// Runs a single readwrite transaction that exercises each async operation
    /// kind (put, duplicate put, get, get-all, count, remove, clear) against one
    /// store, then checks the result delivered to each request's callback.
    ///
    /// All requests are queued in one `KvsTransaction`. Per-request results are
    /// read only after `rx` has reported back. The expectations assume the
    /// requests are processed in queue order (e.g. `count` runs before
    /// `remove` and `clear`); presumably the engine guarantees this — confirm
    /// against `process_transaction`.
    #[test]
    fn test_async_operations() {
        // Creates the sender/receiver pair used to collect one request's
        // result; panics if the channel cannot be created.
        fn get_channel<T>() -> (GenericSender<T>, GenericReceiver<T>)
        where
            T: for<'de> Deserialize<'de> + Serialize,
        {
            generic_channel::channel().unwrap()
        }

        // Wraps `chan` in a `GenericCallback` that unwraps the incoming result
        // and forwards the value to `chan`, asserting that the send succeeds.
        fn get_callback<T>(chan: GenericSender<T>) -> GenericCallback<T>
        where
            T: for<'de> Deserialize<'de> + Serialize + Send + Sync,
        {
            GenericCallback::new(ProfilerChan(None), move |r| {
                assert!(chan.send(r.unwrap()).is_ok());
            })
            .expect("Could not construct callback")
        }

        let base_dir = tempfile::tempdir().expect("Failed to create temp dir");
        let thread_pool = get_pool();
        let db = SqliteEngine::new(
            base_dir.path(),
            &IndexedDBDescription {
                name: "test_db".to_string(),
                origin: test_origin(),
            },
            thread_pool,
        )
        .unwrap();
        let store_name = "test_store";
        db.create_store(store_name, None, false)
            .expect("Failed to create store");
        // One channel per queued request; `.0` is handed to the callback and
        // `.1` is read after the transaction completes.
        let put = get_channel();
        let put2 = get_channel();
        let put3 = get_channel();
        let put_dup = get_channel();
        let get_item_some = get_channel();
        let get_item_none = get_channel();
        let get_all_items = get_channel();
        let count = get_channel();
        let remove = get_channel();
        let clear = get_channel();
        let rx = db.process_transaction(KvsTransaction {
            mode: IndexedDBTxnMode::Readwrite,
            requests: VecDeque::from(vec![
                // Three inserts with distinct key types: number, string, array.
                KvsOperation {
                    store_name: store_name.to_owned(),
                    operation: AsyncOperation::ReadWrite(AsyncReadWriteOperation::PutItem {
                        callback: get_callback(put.0),
                        key: Some(IndexedDBKeyType::Number(1.0)),
                        value: vec![1, 2, 3],
                        should_overwrite: false,
                    }),
                },
                KvsOperation {
                    store_name: store_name.to_owned(),
                    operation: AsyncOperation::ReadWrite(AsyncReadWriteOperation::PutItem {
                        callback: get_callback(put2.0),
                        key: Some(IndexedDBKeyType::String("2.0".to_string())),
                        value: vec![4, 5, 6],
                        should_overwrite: false,
                    }),
                },
                KvsOperation {
                    store_name: store_name.to_owned(),
                    operation: AsyncOperation::ReadWrite(AsyncReadWriteOperation::PutItem {
                        callback: get_callback(put3.0),
                        key: Some(IndexedDBKeyType::Array(vec![
                            IndexedDBKeyType::String("3".to_string()),
                            IndexedDBKeyType::Number(0.0),
                        ])),
                        value: vec![7, 8, 9],
                        should_overwrite: false,
                    }),
                },
                // Try to put a duplicate key without overwrite (key 1.0 was
                // inserted by the first request above).
                KvsOperation {
                    store_name: store_name.to_owned(),
                    operation: AsyncOperation::ReadWrite(AsyncReadWriteOperation::PutItem {
                        callback: get_callback(put_dup.0),
                        key: Some(IndexedDBKeyType::Number(1.0)),
                        value: vec![10, 11, 12],
                        should_overwrite: false,
                    }),
                },
                // Lookup of an existing key (1.0).
                KvsOperation {
                    store_name: store_name.to_owned(),
                    operation: AsyncOperation::ReadOnly(AsyncReadOnlyOperation::GetItem {
                        callback: get_callback(get_item_some.0),
                        key_range: IndexedDBKeyRange::only(IndexedDBKeyType::Number(1.0)),
                    }),
                },
                // Lookup of a key that was never inserted (5.0).
                KvsOperation {
                    store_name: store_name.to_owned(),
                    operation: AsyncOperation::ReadOnly(AsyncReadOnlyOperation::GetItem {
                        callback: get_callback(get_item_none.0),
                        key_range: IndexedDBKeyRange::only(IndexedDBKeyType::Number(5.0)),
                    }),
                },
                // Range with lower bound 0 (`false` passed as the open flag) and
                // no count limit. The assertions below expect all three records
                // (number, string and array keys) to fall inside it.
                KvsOperation {
                    store_name: store_name.to_owned(),
                    operation: AsyncOperation::ReadOnly(AsyncReadOnlyOperation::GetAllItems {
                        callback: get_callback(get_all_items.0),
                        key_range: IndexedDBKeyRange::lower_bound(
                            IndexedDBKeyType::Number(0.0),
                            false,
                        ),
                        count: None,
                    }),
                },
                // Count records whose key is exactly 1.0 (queued before the
                // remove/clear below).
                KvsOperation {
                    store_name: store_name.to_owned(),
                    operation: AsyncOperation::ReadOnly(AsyncReadOnlyOperation::Count {
                        callback: get_callback(count.0),
                        key_range: IndexedDBKeyRange::only(IndexedDBKeyType::Number(1.0)),
                    }),
                },
                KvsOperation {
                    store_name: store_name.to_owned(),
                    operation: AsyncOperation::ReadWrite(AsyncReadWriteOperation::RemoveItem {
                        callback: get_callback(remove.0),
                        key_range: IndexedDBKeyRange::only(IndexedDBKeyType::Number(1.0)),
                    }),
                },
                KvsOperation {
                    store_name: store_name.to_owned(),
                    operation: AsyncOperation::ReadWrite(AsyncReadWriteOperation::Clear(
                        get_callback(clear.0),
                    )),
                },
            ]),
        });
        // Block until the engine reports back on `rx` (presumably once every
        // queued request has been handled) before reading per-request results.
        let _ = rx.blocking_recv().unwrap();
        put.1.recv().unwrap().unwrap();
        put2.1.recv().unwrap().unwrap();
        put3.1.recv().unwrap().unwrap();
        // The duplicate insert must be rejected rather than overwrite key 1.0.
        let err = put_dup.1.recv().unwrap().unwrap();
        assert_eq!(err, PutItemResult::CannotOverwrite);
        // Key 1.0 still holds the value from the first insert.
        let get_result = get_item_some.1.recv().unwrap();
        let value = get_result.unwrap();
        assert_eq!(value, Some(vec![1, 2, 3]));
        let get_result = get_item_none.1.recv().unwrap();
        let value = get_result.unwrap();
        assert_eq!(value, None);
        let all_items = get_all_items.1.recv().unwrap().unwrap();
        assert_eq!(all_items.len(), 3);
        // Check that all three items are present
        assert!(all_items.contains(&vec![1, 2, 3]));
        assert!(all_items.contains(&vec![4, 5, 6]));
        assert!(all_items.contains(&vec![7, 8, 9]));
        let amount = count.1.recv().unwrap().unwrap();
        assert_eq!(amount, 1);
        // Remove and clear are only checked for successful completion.
        remove.1.recv().unwrap().unwrap();
        clear.1.recv().unwrap().unwrap();
    }
1021}