storage/indexeddb/engines/
sqlite.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6
7use base::threadpool::ThreadPool;
8use ipc_channel::ipc::IpcSender;
9use log::{error, info};
10use rusqlite::{Connection, Error, OptionalExtension, params};
11use sea_query::{Condition, Expr, ExprTrait, IntoCondition, SqliteQueryBuilder};
12use sea_query_rusqlite::RusqliteBinder;
13use serde::Serialize;
14use storage_traits::indexeddb_thread::{
15    AsyncOperation, AsyncReadOnlyOperation, AsyncReadWriteOperation, BackendError, BackendResult,
16    CreateObjectResult, IndexedDBKeyRange, IndexedDBKeyType, IndexedDBRecord, IndexedDBTxnMode,
17    KeyPath, PutItemResult,
18};
19use tokio::sync::oneshot;
20
21use crate::indexeddb::IndexedDBDescription;
22use crate::indexeddb::engines::{KvsEngine, KvsTransaction};
23use crate::shared::{DB_INIT_PRAGMAS, DB_PRAGMAS};
24
25mod create;
26mod database_model;
27mod object_data_model;
28mod object_store_index_model;
29mod object_store_model;
30
31fn range_to_query(range: IndexedDBKeyRange) -> Condition {
32    // Special case for optimization
33    if let Some(singleton) = range.as_singleton() {
34        let encoded = bincode::serialize(singleton).unwrap();
35        return Expr::column(object_data_model::Column::Key)
36            .eq(encoded)
37            .into_condition();
38    }
39    let mut parts = vec![];
40    if let Some(upper) = range.upper.as_ref() {
41        let upper_bytes = bincode::serialize(upper).unwrap();
42        let query = if range.upper_open {
43            Expr::column(object_data_model::Column::Key).lt(upper_bytes)
44        } else {
45            Expr::column(object_data_model::Column::Key).lte(upper_bytes)
46        };
47        parts.push(query);
48    }
49    if let Some(lower) = range.lower.as_ref() {
50        let lower_bytes = bincode::serialize(lower).unwrap();
51        let query = if range.upper_open {
52            Expr::column(object_data_model::Column::Key).gt(lower_bytes)
53        } else {
54            Expr::column(object_data_model::Column::Key).gte(lower_bytes)
55        };
56        parts.push(query);
57    }
58    let mut condition = Condition::all();
59    for part in parts {
60        condition = condition.add(part);
61    }
62    condition
63}
64
/// IndexedDB storage engine backed by a SQLite database file.
pub struct SqliteEngine {
    /// Full path of the `db.sqlite` file backing this database.
    db_path: PathBuf,
    /// Connection opened in `new`; used by the synchronous trait methods.
    connection: Connection,
    /// Pool on which read-only transactions are spawned.
    read_pool: Arc<ThreadPool>,
    /// Pool on which every other transaction is spawned.
    write_pool: Arc<ThreadPool>,
}
71
72impl SqliteEngine {
73    // TODO: intake dual pools
74    pub fn new(
75        base_dir: &Path,
76        db_info: &IndexedDBDescription,
77        pool: Arc<ThreadPool>,
78    ) -> Result<Self, Error> {
79        let mut db_path = PathBuf::new();
80        db_path.push(base_dir);
81        db_path.push(db_info.as_path());
82        let db_parent = db_path.clone();
83        db_path.push("db.sqlite");
84
85        if !db_path.exists() {
86            std::fs::create_dir_all(db_parent).unwrap();
87            std::fs::File::create(&db_path).unwrap();
88        }
89        let connection = Self::init_db(&db_path, db_info)?;
90
91        for stmt in DB_PRAGMAS {
92            // TODO: Handle errors properly
93            let _ = connection.execute(stmt, ());
94        }
95
96        Ok(Self {
97            connection,
98            db_path,
99            read_pool: pool.clone(),
100            write_pool: pool,
101        })
102    }
103
104    fn init_db(path: &Path, db_info: &IndexedDBDescription) -> Result<Connection, Error> {
105        let connection = Connection::open(path)?;
106        if connection.table_exists(None, "database")? {
107            // Database already exists, no need to initialize
108            return Ok(connection);
109        }
110        info!("Initializing indexeddb database at {:?}", path);
111        for stmt in DB_INIT_PRAGMAS {
112            // FIXME(arihant2math): this fails occasionally
113            let _ = connection.execute(stmt, ());
114        }
115        create::create_tables(&connection)?;
116        // From https://w3c.github.io/IndexedDB/#database-version:
117        // "When a database is first created, its version is 0 (zero)."
118        connection.execute(
119            "INSERT INTO database (name, origin, version) VALUES (?, ?, ?)",
120            params![
121                db_info.name.to_owned(),
122                db_info.origin.to_owned().ascii_serialization(),
123                i64::from_ne_bytes(0_u64.to_ne_bytes())
124            ],
125        )?;
126        Ok(connection)
127    }
128
129    fn get(
130        connection: &Connection,
131        store: object_store_model::Model,
132        key_range: IndexedDBKeyRange,
133    ) -> Result<Option<object_data_model::Model>, Error> {
134        let query = range_to_query(key_range);
135        let (sql, values) = sea_query::Query::select()
136            .from(object_data_model::Column::Table)
137            .columns(vec![
138                object_data_model::Column::ObjectStoreId,
139                object_data_model::Column::Key,
140                object_data_model::Column::Data,
141            ])
142            .and_where(query.and(Expr::col(object_data_model::Column::ObjectStoreId).is(store.id)))
143            .limit(1)
144            .build_rusqlite(SqliteQueryBuilder);
145        connection
146            .prepare(&sql)?
147            .query_one(&*values.as_params(), |row| {
148                object_data_model::Model::try_from(row)
149            })
150            .optional()
151    }
152
153    fn get_key(
154        connection: &Connection,
155        store: object_store_model::Model,
156        key_range: IndexedDBKeyRange,
157    ) -> Result<Option<Vec<u8>>, Error> {
158        Self::get(connection, store, key_range).map(|opt| opt.map(|model| model.key))
159    }
160
161    fn get_item(
162        connection: &Connection,
163        store: object_store_model::Model,
164        key_range: IndexedDBKeyRange,
165    ) -> Result<Option<Vec<u8>>, Error> {
166        Self::get(connection, store, key_range).map(|opt| opt.map(|model| model.data))
167    }
168
169    fn get_all(
170        connection: &Connection,
171        store: object_store_model::Model,
172        key_range: IndexedDBKeyRange,
173        count: Option<u32>,
174    ) -> Result<Vec<object_data_model::Model>, Error> {
175        let query = range_to_query(key_range);
176        let mut sql_query = sea_query::Query::select();
177        sql_query
178            .from(object_data_model::Column::Table)
179            .columns(vec![
180                object_data_model::Column::ObjectStoreId,
181                object_data_model::Column::Key,
182                object_data_model::Column::Data,
183            ])
184            .and_where(query.and(Expr::col(object_data_model::Column::ObjectStoreId).is(store.id)));
185        if let Some(count) = count {
186            sql_query.limit(count as u64);
187        }
188        let (sql, values) = sql_query.build_rusqlite(SqliteQueryBuilder);
189        let mut stmt = connection.prepare(&sql)?;
190        let models = stmt
191            .query_and_then(&*values.as_params(), |row| {
192                object_data_model::Model::try_from(row)
193            })?
194            .collect::<Result<Vec<_>, _>>()?;
195        Ok(models)
196    }
197
198    fn get_all_keys(
199        connection: &Connection,
200        store: object_store_model::Model,
201        key_range: IndexedDBKeyRange,
202        count: Option<u32>,
203    ) -> Result<Vec<Vec<u8>>, Error> {
204        Self::get_all(connection, store, key_range, count)
205            .map(|models| models.into_iter().map(|m| m.key).collect())
206    }
207
208    fn get_all_items(
209        connection: &Connection,
210        store: object_store_model::Model,
211        key_range: IndexedDBKeyRange,
212        count: Option<u32>,
213    ) -> Result<Vec<Vec<u8>>, Error> {
214        Self::get_all(connection, store, key_range, count)
215            .map(|models| models.into_iter().map(|m| m.data).collect())
216    }
217
218    #[allow(clippy::type_complexity)]
219    fn get_all_records(
220        connection: &Connection,
221        store: object_store_model::Model,
222        key_range: IndexedDBKeyRange,
223    ) -> Result<Vec<(Vec<u8>, Vec<u8>)>, Error> {
224        Self::get_all(connection, store, key_range, None)
225            .map(|models| models.into_iter().map(|m| (m.key, m.data)).collect())
226    }
227
228    fn put_item(
229        connection: &Connection,
230        store: object_store_model::Model,
231        serialized_key: Vec<u8>,
232        value: Vec<u8>,
233        should_overwrite: bool,
234    ) -> Result<PutItemResult, Error> {
235        let existing_item = connection
236            .prepare("SELECT * FROM object_data WHERE key = ? AND object_store_id = ?")
237            .and_then(|mut stmt| {
238                stmt.query_row(params![serialized_key, store.id], |row| {
239                    object_data_model::Model::try_from(row)
240                })
241                .optional()
242            })?;
243        if should_overwrite || existing_item.is_none() {
244            connection.execute(
245                "INSERT INTO object_data (object_store_id, key, data) VALUES (?, ?, ?)",
246                params![store.id, serialized_key, value],
247            )?;
248            Ok(PutItemResult::Success)
249        } else {
250            Ok(PutItemResult::CannotOverwrite)
251        }
252    }
253
254    fn delete_item(
255        connection: &Connection,
256        store: object_store_model::Model,
257        serialized_key: Vec<u8>,
258    ) -> Result<(), Error> {
259        connection.execute(
260            "DELETE FROM object_data WHERE key = ? AND object_store_id = ?",
261            params![serialized_key, store.id],
262        )?;
263        Ok(())
264    }
265
266    fn clear(connection: &Connection, store: object_store_model::Model) -> Result<(), Error> {
267        connection.execute(
268            "DELETE FROM object_data WHERE object_store_id = ?",
269            params![store.id],
270        )?;
271        Ok(())
272    }
273
274    fn count(
275        connection: &Connection,
276        store: object_store_model::Model,
277        key_range: IndexedDBKeyRange,
278    ) -> Result<usize, Error> {
279        let query = range_to_query(key_range);
280        let (sql, values) = sea_query::Query::select()
281            .expr(Expr::col(object_data_model::Column::Key).count())
282            .from(object_data_model::Column::Table)
283            .and_where(query.and(Expr::col(object_data_model::Column::ObjectStoreId).is(store.id)))
284            .build_rusqlite(SqliteQueryBuilder);
285        connection
286            .prepare(&sql)?
287            .query_row(&*values.as_params(), |row| row.get(0))
288            .map(|count: i64| count as usize)
289    }
290
291    fn generate_key(
292        connection: &Connection,
293        store: &object_store_model::Model,
294    ) -> Result<IndexedDBKeyType, Error> {
295        if store.auto_increment == 0 {
296            unreachable!("Should be caught in the script thread");
297        }
298        // TODO: handle overflows, this also needs to be able to handle 2^53 as per spec
299        let new_key = store.auto_increment + 1;
300        connection.execute(
301            "UPDATE object_store SET auto_increment = ? WHERE id = ?",
302            params![new_key, store.id],
303        )?;
304        Ok(IndexedDBKeyType::Number(new_key as f64))
305    }
306}
307
308impl KvsEngine for SqliteEngine {
309    type Error = Error;
310
311    fn create_store(
312        &self,
313        store_name: &str,
314        key_path: Option<KeyPath>,
315        auto_increment: bool,
316    ) -> Result<CreateObjectResult, Self::Error> {
317        let mut stmt = self
318            .connection
319            .prepare("SELECT * FROM object_store WHERE name = ?")?;
320        if stmt.exists(params![store_name.to_string()])? {
321            // Store already exists
322            return Ok(CreateObjectResult::AlreadyExists);
323        }
324        self.connection.execute(
325            "INSERT INTO object_store (name, key_path, auto_increment) VALUES (?, ?, ?)",
326            params![
327                store_name.to_string(),
328                key_path.map(|v| bincode::serialize(&v).unwrap()),
329                auto_increment as i32
330            ],
331        )?;
332
333        Ok(CreateObjectResult::Created)
334    }
335
336    fn delete_store(&self, store_name: &str) -> Result<(), Self::Error> {
337        let result = self.connection.execute(
338            "DELETE FROM object_store WHERE name = ?",
339            params![store_name.to_string()],
340        )?;
341        if result == 0 {
342            Err(Error::QueryReturnedNoRows)
343        } else if result > 1 {
344            Err(Error::QueryReturnedMoreThanOneRow)
345        } else {
346            Ok(())
347        }
348    }
349
    /// Closes the named object store.
    ///
    /// Currently a no-op: the name is ignored and `Ok(())` is always returned.
    fn close_store(&self, _store_name: &str) -> Result<(), Self::Error> {
        // TODO: do something
        Ok(())
    }
354
355    fn delete_database(self) -> Result<(), Self::Error> {
356        // attempt to close the connection first
357        let _ = self.connection.close();
358        if self.db_path.exists() {
359            if let Err(e) = std::fs::remove_dir_all(self.db_path.parent().unwrap()) {
360                error!("Failed to delete database: {:?}", e);
361            }
362        }
363        Ok(())
364    }
365
366    fn process_transaction(
367        &self,
368        transaction: KvsTransaction,
369    ) -> oneshot::Receiver<Option<Vec<u8>>> {
370        let (tx, rx) = oneshot::channel();
371
372        let spawning_pool = if transaction.mode == IndexedDBTxnMode::Readonly {
373            self.read_pool.clone()
374        } else {
375            self.write_pool.clone()
376        };
377        let path = self.db_path.clone();
378        spawning_pool.spawn(move || {
379            let connection = Connection::open(path).unwrap();
380            for request in transaction.requests {
381                let object_store = connection
382                    .prepare("SELECT * FROM object_store WHERE name = ?")
383                    .and_then(|mut stmt| {
384                        stmt.query_row(params![request.store_name.to_string()], |row| {
385                            object_store_model::Model::try_from(row)
386                        })
387                        .optional()
388                    });
389                fn process_object_store<T>(
390                    object_store: Result<Option<object_store_model::Model>, Error>,
391                    sender: &IpcSender<BackendResult<T>>,
392                ) -> Result<object_store_model::Model, ()>
393                where
394                    T: Serialize,
395                {
396                    match object_store {
397                        Ok(Some(store)) => Ok(store),
398                        Ok(None) => {
399                            let _ = sender.send(Err(BackendError::StoreNotFound));
400                            Err(())
401                        },
402                        Err(e) => {
403                            let _ = sender.send(Err(BackendError::DbErr(format!("{:?}", e))));
404                            Err(())
405                        },
406                    }
407                }
408
409                match request.operation {
410                    AsyncOperation::ReadWrite(AsyncReadWriteOperation::PutItem {
411                        sender,
412                        key,
413                        value,
414                        should_overwrite,
415                    }) => {
416                        let Ok(object_store) = process_object_store(object_store, &sender) else {
417                            continue;
418                        };
419                        let key = match key
420                            .map(Ok)
421                            .unwrap_or_else(|| Self::generate_key(&connection, &object_store))
422                        {
423                            Ok(key) => key,
424                            Err(e) => {
425                                let _ = sender.send(Err(BackendError::DbErr(format!("{:?}", e))));
426                                continue;
427                            },
428                        };
429                        let serialized_key: Vec<u8> = bincode::serialize(&key).unwrap();
430                        let _ = sender.send(
431                            Self::put_item(
432                                &connection,
433                                object_store,
434                                serialized_key,
435                                value,
436                                should_overwrite,
437                            )
438                            .map_err(|e| BackendError::DbErr(format!("{:?}", e))),
439                        );
440                    },
441                    AsyncOperation::ReadOnly(AsyncReadOnlyOperation::GetItem {
442                        sender,
443                        key_range,
444                    }) => {
445                        let Ok(object_store) = process_object_store(object_store, &sender) else {
446                            continue;
447                        };
448                        let _ = sender.send(
449                            Self::get_item(&connection, object_store, key_range)
450                                .map_err(|e| BackendError::DbErr(format!("{:?}", e))),
451                        );
452                    },
453                    AsyncOperation::ReadOnly(AsyncReadOnlyOperation::GetAllKeys {
454                        sender,
455                        key_range,
456                        count,
457                    }) => {
458                        let Ok(object_store) = process_object_store(object_store, &sender) else {
459                            continue;
460                        };
461                        let _ = sender.send(
462                            Self::get_all_keys(&connection, object_store, key_range, count)
463                                .map(|keys| {
464                                    keys.into_iter()
465                                        .map(|k| bincode::deserialize(&k).unwrap())
466                                        .collect()
467                                })
468                                .map_err(|e| BackendError::DbErr(format!("{:?}", e))),
469                        );
470                    },
471                    AsyncOperation::ReadOnly(AsyncReadOnlyOperation::GetAllItems {
472                        sender,
473                        key_range,
474                        count,
475                    }) => {
476                        let Ok(object_store) = process_object_store(object_store, &sender) else {
477                            continue;
478                        };
479                        let _ = sender.send(
480                            Self::get_all_items(&connection, object_store, key_range, count)
481                                .map_err(|e| BackendError::DbErr(format!("{:?}", e))),
482                        );
483                    },
484                    AsyncOperation::ReadWrite(AsyncReadWriteOperation::RemoveItem {
485                        sender,
486                        key,
487                    }) => {
488                        let Ok(object_store) = process_object_store(object_store, &sender) else {
489                            continue;
490                        };
491                        let serialized_key: Vec<u8> = bincode::serialize(&key).unwrap();
492                        let _ = sender.send(
493                            Self::delete_item(&connection, object_store, serialized_key)
494                                .map_err(|e| BackendError::DbErr(format!("{:?}", e))),
495                        );
496                    },
497                    AsyncOperation::ReadOnly(AsyncReadOnlyOperation::Count {
498                        sender,
499                        key_range,
500                    }) => {
501                        let Ok(object_store) = process_object_store(object_store, &sender) else {
502                            continue;
503                        };
504                        let _ = sender.send(
505                            Self::count(&connection, object_store, key_range)
506                                .map(|r| r as u64)
507                                .map_err(|e| BackendError::DbErr(format!("{:?}", e))),
508                        );
509                    },
510                    AsyncOperation::ReadOnly(AsyncReadOnlyOperation::Iterate {
511                        sender,
512                        key_range,
513                    }) => {
514                        let Ok(object_store) = process_object_store(object_store, &sender) else {
515                            continue;
516                        };
517                        let _ = sender.send(
518                            Self::get_all_records(&connection, object_store, key_range)
519                                .map(|records| {
520                                    records
521                                        .into_iter()
522                                        .map(|(key, data)| IndexedDBRecord {
523                                            key: bincode::deserialize(&key).unwrap(),
524                                            primary_key: bincode::deserialize(&key).unwrap(),
525                                            value: data,
526                                        })
527                                        .collect()
528                                })
529                                .map_err(|e| BackendError::DbErr(format!("{:?}", e))),
530                        );
531                    },
532                    AsyncOperation::ReadWrite(AsyncReadWriteOperation::Clear(sender)) => {
533                        let Ok(object_store) = process_object_store(object_store, &sender) else {
534                            continue;
535                        };
536                        let _ = sender.send(
537                            Self::clear(&connection, object_store)
538                                .map_err(|e| BackendError::DbErr(format!("{:?}", e))),
539                        );
540                    },
541                    AsyncOperation::ReadOnly(AsyncReadOnlyOperation::GetKey {
542                        sender,
543                        key_range,
544                    }) => {
545                        let Ok(object_store) = process_object_store(object_store, &sender) else {
546                            continue;
547                        };
548                        let _ = sender.send(
549                            Self::get_key(&connection, object_store, key_range)
550                                .map(|key| key.map(|k| bincode::deserialize(&k).unwrap()))
551                                .map_err(|e| BackendError::DbErr(format!("{:?}", e))),
552                        );
553                    },
554                }
555            }
556            let _ = tx.send(None);
557        });
558        rx
559    }
560
561    // TODO: we should be able to error out here, maybe change the trait definition?
562    fn has_key_generator(&self, store_name: &str) -> bool {
563        self.connection
564            .prepare("SELECT * FROM object_store WHERE name = ?")
565            .and_then(|mut stmt| {
566                stmt.query_row(params![store_name.to_string()], |r| {
567                    let object_store = object_store_model::Model::try_from(r).unwrap();
568                    Ok(object_store.auto_increment)
569                })
570            })
571            .optional()
572            .unwrap()
573            // TODO: Wrong (change trait definition for this function)
574            .unwrap_or_default() !=
575            0
576    }
577
578    fn key_path(&self, store_name: &str) -> Option<KeyPath> {
579        self.connection
580            .prepare("SELECT * FROM object_store WHERE name = ?")
581            .and_then(|mut stmt| {
582                stmt.query_row(params![store_name.to_string()], |r| {
583                    let object_store = object_store_model::Model::try_from(r).unwrap();
584                    Ok(object_store
585                        .key_path
586                        .map(|key_path| bincode::deserialize(&key_path).unwrap()))
587                })
588            })
589            .optional()
590            .unwrap()
591            // TODO: Wrong, same issues as has_key_generator
592            .unwrap_or_default()
593    }
594
595    fn create_index(
596        &self,
597        store_name: &str,
598        index_name: String,
599        key_path: KeyPath,
600        unique: bool,
601        multi_entry: bool,
602    ) -> Result<CreateObjectResult, Self::Error> {
603        let object_store = self.connection.query_row(
604            "SELECT * FROM object_store WHERE name = ?",
605            params![store_name.to_string()],
606            |row| object_store_model::Model::try_from(row),
607        )?;
608
609        let index_exists: bool = self.connection.query_row(
610            "SELECT EXISTS(SELECT * FROM object_store_index WHERE name = ? AND object_store_id = ?)",
611            params![index_name.to_string(), object_store.id],
612            |row| row.get(0),
613        )?;
614        if index_exists {
615            return Ok(CreateObjectResult::AlreadyExists);
616        }
617
618        self.connection.execute(
619            "INSERT INTO object_store_index (object_store_id, name, key_path, unique_index, multi_entry_index)\
620            VALUES (?, ?, ?, ?, ?)",
621            params![
622                object_store.id,
623                index_name.to_string(),
624                bincode::serialize(&key_path).unwrap(),
625                unique,
626                multi_entry,
627            ],
628        )?;
629        Ok(CreateObjectResult::Created)
630    }
631
632    fn delete_index(&self, store_name: &str, index_name: String) -> Result<(), Self::Error> {
633        let object_store = self.connection.query_row(
634            "SELECT * FROM object_store WHERE name = ?",
635            params![store_name.to_string()],
636            |r| Ok(object_store_model::Model::try_from(r).unwrap()),
637        )?;
638
639        // Delete the index if it exists
640        let _ = self.connection.execute(
641            "DELETE FROM object_store_index WHERE name = ? AND object_store_id = ?",
642            params![index_name.to_string(), object_store.id],
643        )?;
644        Ok(())
645    }
646
647    fn version(&self) -> Result<u64, Self::Error> {
648        let version: i64 =
649            self.connection
650                .query_row("SELECT version FROM database LIMIT 1", [], |row| row.get(0))?;
651        Ok(u64::from_ne_bytes(version.to_ne_bytes()))
652    }
653
654    fn set_version(&self, version: u64) -> Result<(), Self::Error> {
655        let rows_affected = self.connection.execute(
656            "UPDATE database SET version = ?",
657            params![i64::from_ne_bytes(version.to_ne_bytes())],
658        )?;
659        if rows_affected == 0 {
660            return Err(Error::QueryReturnedNoRows);
661        }
662        Ok(())
663    }
664}
665
666#[cfg(test)]
667mod tests {
668    use std::collections::VecDeque;
669    use std::sync::Arc;
670
671    use base::threadpool::ThreadPool;
672    use serde::{Deserialize, Serialize};
673    use servo_url::ImmutableOrigin;
674    use storage_traits::indexeddb_thread::{
675        AsyncOperation, AsyncReadOnlyOperation, AsyncReadWriteOperation, CreateObjectResult,
676        IndexedDBKeyRange, IndexedDBKeyType, IndexedDBTxnMode, KeyPath, PutItemResult,
677    };
678    use url::Host;
679
680    use crate::indexeddb::IndexedDBDescription;
681    use crate::indexeddb::engines::{KvsEngine, KvsOperation, KvsTransaction, SqliteEngine};
682
683    fn test_origin() -> ImmutableOrigin {
684        ImmutableOrigin::Tuple(
685            "test_origin".to_string(),
686            Host::Domain("localhost".to_string()),
687            80,
688        )
689    }
690
691    fn get_pool() -> Arc<ThreadPool> {
692        Arc::new(ThreadPool::new(1, "test".to_string()))
693    }
694
695    #[test]
696    fn test_cycle() {
697        let base_dir = tempfile::tempdir().expect("Failed to create temp dir");
698        let thread_pool = get_pool();
699        // Test create
700        let _ = SqliteEngine::new(
701            base_dir.path(),
702            &IndexedDBDescription {
703                name: "test_db".to_string(),
704                origin: test_origin(),
705            },
706            thread_pool.clone(),
707        )
708        .unwrap();
709        // Test open
710        let db = SqliteEngine::new(
711            base_dir.path(),
712            &IndexedDBDescription {
713                name: "test_db".to_string(),
714                origin: test_origin(),
715            },
716            thread_pool.clone(),
717        )
718        .unwrap();
719        let version = db.version().expect("Failed to get version");
720        assert_eq!(version, 0);
721        db.set_version(5).unwrap();
722        let new_version = db.version().expect("Failed to get new version");
723        assert_eq!(new_version, 5);
724        db.delete_database().expect("Failed to delete database");
725    }
726
727    #[test]
728    fn test_create_store() {
729        let base_dir = tempfile::tempdir().expect("Failed to create temp dir");
730        let thread_pool = get_pool();
731        let db = SqliteEngine::new(
732            base_dir.path(),
733            &IndexedDBDescription {
734                name: "test_db".to_string(),
735                origin: test_origin(),
736            },
737            thread_pool,
738        )
739        .unwrap();
740        let store_name = "test_store";
741        let result = db.create_store(store_name, None, true);
742        assert!(result.is_ok());
743        let create_result = result.unwrap();
744        assert_eq!(create_result, CreateObjectResult::Created);
745        // Try to create the same store again
746        let result = db.create_store(store_name, None, false);
747        assert!(result.is_ok());
748        let create_result = result.unwrap();
749        assert_eq!(create_result, CreateObjectResult::AlreadyExists);
750        // Ensure store was not overwritten
751        assert!(db.has_key_generator(store_name));
752    }
753
754    #[test]
755    fn test_create_store_empty_name() {
756        let base_dir = tempfile::tempdir().expect("Failed to create temp dir");
757        let thread_pool = get_pool();
758        let db = SqliteEngine::new(
759            base_dir.path(),
760            &IndexedDBDescription {
761                name: "test_db".to_string(),
762                origin: test_origin(),
763            },
764            thread_pool,
765        )
766        .unwrap();
767        let store_name = "";
768        let result = db.create_store(store_name, None, true);
769        assert!(result.is_ok());
770        let create_result = result.unwrap();
771        assert_eq!(create_result, CreateObjectResult::Created);
772    }
773
774    #[test]
775    fn test_injection() {
776        let base_dir = tempfile::tempdir().expect("Failed to create temp dir");
777        let thread_pool = get_pool();
778        let db = SqliteEngine::new(
779            base_dir.path(),
780            &IndexedDBDescription {
781                name: "test_db".to_string(),
782                origin: test_origin(),
783            },
784            thread_pool,
785        )
786        .unwrap();
787        // Create a normal store
788        let store_name1 = "test_store";
789        let result = db.create_store(store_name1, None, true);
790        assert!(result.is_ok());
791        let create_result = result.unwrap();
792        assert_eq!(create_result, CreateObjectResult::Created);
793        // Injection
794        let store_name2 = "' OR 1=1 -- -";
795        let result = db.create_store(store_name2, None, false);
796        assert!(result.is_ok());
797        let create_result = result.unwrap();
798        assert_eq!(create_result, CreateObjectResult::Created);
799    }
800
801    #[test]
802    fn test_key_path() {
803        let base_dir = tempfile::tempdir().expect("Failed to create temp dir");
804        let thread_pool = get_pool();
805        let db = SqliteEngine::new(
806            base_dir.path(),
807            &IndexedDBDescription {
808                name: "test_db".to_string(),
809                origin: test_origin(),
810            },
811            thread_pool,
812        )
813        .unwrap();
814        let store_name = "test_store";
815        let result = db.create_store(store_name, Some(KeyPath::String("test".to_string())), true);
816        assert!(result.is_ok());
817        assert_eq!(
818            db.key_path(store_name),
819            Some(KeyPath::String("test".to_string()))
820        );
821    }
822
823    #[test]
824    fn test_delete_store() {
825        let base_dir = tempfile::tempdir().expect("Failed to create temp dir");
826        let thread_pool = get_pool();
827        let db = SqliteEngine::new(
828            base_dir.path(),
829            &IndexedDBDescription {
830                name: "test_db".to_string(),
831                origin: test_origin(),
832            },
833            thread_pool,
834        )
835        .unwrap();
836        db.create_store("test_store", None, false)
837            .expect("Failed to create store");
838        // Delete the store
839        db.delete_store("test_store")
840            .expect("Failed to delete store");
841        // Try to delete the same store again
842        let result = db.delete_store("test_store");
843        assert!(result.is_err());
844        // Try to delete a non-existing store
845        let result = db.delete_store("test_store");
846        // Should work as per spec
847        assert!(result.is_err());
848    }
849
850    #[test]
851    fn test_async_operations() {
852        fn get_channel<T>() -> (
853            ipc_channel::ipc::IpcSender<T>,
854            ipc_channel::ipc::IpcReceiver<T>,
855        )
856        where
857            T: for<'de> Deserialize<'de> + Serialize,
858        {
859            ipc_channel::ipc::channel().unwrap()
860        }
861
862        let base_dir = tempfile::tempdir().expect("Failed to create temp dir");
863        let thread_pool = get_pool();
864        let db = SqliteEngine::new(
865            base_dir.path(),
866            &IndexedDBDescription {
867                name: "test_db".to_string(),
868                origin: test_origin(),
869            },
870            thread_pool,
871        )
872        .unwrap();
873        let store_name = "test_store";
874        db.create_store(store_name, None, false)
875            .expect("Failed to create store");
876        let put = get_channel();
877        let put2 = get_channel();
878        let put3 = get_channel();
879        let put_dup = get_channel();
880        let get_item_some = get_channel();
881        let get_item_none = get_channel();
882        let get_all_items = get_channel();
883        let count = get_channel();
884        let remove = get_channel();
885        let clear = get_channel();
886        let rx = db.process_transaction(KvsTransaction {
887            mode: IndexedDBTxnMode::Readwrite,
888            requests: VecDeque::from(vec![
889                KvsOperation {
890                    store_name: store_name.to_owned(),
891                    operation: AsyncOperation::ReadWrite(AsyncReadWriteOperation::PutItem {
892                        sender: put.0,
893                        key: Some(IndexedDBKeyType::Number(1.0)),
894                        value: vec![1, 2, 3],
895                        should_overwrite: false,
896                    }),
897                },
898                KvsOperation {
899                    store_name: store_name.to_owned(),
900                    operation: AsyncOperation::ReadWrite(AsyncReadWriteOperation::PutItem {
901                        sender: put2.0,
902                        key: Some(IndexedDBKeyType::String("2.0".to_string())),
903                        value: vec![4, 5, 6],
904                        should_overwrite: false,
905                    }),
906                },
907                KvsOperation {
908                    store_name: store_name.to_owned(),
909                    operation: AsyncOperation::ReadWrite(AsyncReadWriteOperation::PutItem {
910                        sender: put3.0,
911                        key: Some(IndexedDBKeyType::Array(vec![
912                            IndexedDBKeyType::String("3".to_string()),
913                            IndexedDBKeyType::Number(0.0),
914                        ])),
915                        value: vec![7, 8, 9],
916                        should_overwrite: false,
917                    }),
918                },
919                // Try to put a duplicate key without overwrite
920                KvsOperation {
921                    store_name: store_name.to_owned(),
922                    operation: AsyncOperation::ReadWrite(AsyncReadWriteOperation::PutItem {
923                        sender: put_dup.0,
924                        key: Some(IndexedDBKeyType::Number(1.0)),
925                        value: vec![10, 11, 12],
926                        should_overwrite: false,
927                    }),
928                },
929                KvsOperation {
930                    store_name: store_name.to_owned(),
931                    operation: AsyncOperation::ReadOnly(AsyncReadOnlyOperation::GetItem {
932                        sender: get_item_some.0,
933                        key_range: IndexedDBKeyRange::only(IndexedDBKeyType::Number(1.0)),
934                    }),
935                },
936                KvsOperation {
937                    store_name: store_name.to_owned(),
938                    operation: AsyncOperation::ReadOnly(AsyncReadOnlyOperation::GetItem {
939                        sender: get_item_none.0,
940                        key_range: IndexedDBKeyRange::only(IndexedDBKeyType::Number(5.0)),
941                    }),
942                },
943                KvsOperation {
944                    store_name: store_name.to_owned(),
945                    operation: AsyncOperation::ReadOnly(AsyncReadOnlyOperation::GetAllItems {
946                        sender: get_all_items.0,
947                        key_range: IndexedDBKeyRange::lower_bound(
948                            IndexedDBKeyType::Number(0.0),
949                            false,
950                        ),
951                        count: None,
952                    }),
953                },
954                KvsOperation {
955                    store_name: store_name.to_owned(),
956                    operation: AsyncOperation::ReadOnly(AsyncReadOnlyOperation::Count {
957                        sender: count.0,
958                        key_range: IndexedDBKeyRange::only(IndexedDBKeyType::Number(1.0)),
959                    }),
960                },
961                KvsOperation {
962                    store_name: store_name.to_owned(),
963                    operation: AsyncOperation::ReadWrite(AsyncReadWriteOperation::RemoveItem {
964                        sender: remove.0,
965                        key: IndexedDBKeyType::Number(1.0),
966                    }),
967                },
968                KvsOperation {
969                    store_name: store_name.to_owned(),
970                    operation: AsyncOperation::ReadWrite(AsyncReadWriteOperation::Clear(clear.0)),
971                },
972            ]),
973        });
974        let _ = rx.blocking_recv().unwrap();
975        put.1.recv().unwrap().unwrap();
976        put2.1.recv().unwrap().unwrap();
977        put3.1.recv().unwrap().unwrap();
978        let err = put_dup.1.recv().unwrap().unwrap();
979        assert_eq!(err, PutItemResult::CannotOverwrite);
980        let get_result = get_item_some.1.recv().unwrap();
981        let value = get_result.unwrap();
982        assert_eq!(value, Some(vec![1, 2, 3]));
983        let get_result = get_item_none.1.recv().unwrap();
984        let value = get_result.unwrap();
985        assert_eq!(value, None);
986        let all_items = get_all_items.1.recv().unwrap().unwrap();
987        assert_eq!(all_items.len(), 3);
988        // Check that all three items are present
989        assert!(all_items.contains(&vec![1, 2, 3]));
990        assert!(all_items.contains(&vec![4, 5, 6]));
991        assert!(all_items.contains(&vec![7, 8, 9]));
992        let amount = count.1.recv().unwrap().unwrap();
993        assert_eq!(amount, 1);
994        remove.1.recv().unwrap().unwrap();
995        clear.1.recv().unwrap().unwrap();
996    }
997}