diff --git a/src/workerd/api/actor-state.c++ b/src/workerd/api/actor-state.c++ index 2d405a398dd..cb629363107 100644 --- a/src/workerd/api/actor-state.c++ +++ b/src/workerd/api/actor-state.c++ @@ -510,24 +510,9 @@ jsg::Promise DurableObjectStorage::deleteAll( jsg::Lock& js, jsg::Optional maybeOptions) { auto options = configureOptions(kj::mv(maybeOptions).orDefault(PutOptions{})); - auto& context = IoContext::current(); - { - // Log to get a sense of whether users are potentially depending on alarms being kept around - // after deleteAll is called. - auto getOptions = configureOptions(GetOptions{}); - context.addTask(context - .awaitJs(js, - transformCacheResult(js, cache->getAlarm(getOptions), getOptions, - [](jsg::Lock&, kj::Maybe alarmValue) { - if (alarmValue != kj::none) { - LOG_WARNING_PERIODICALLY("NOSENTRY deleteAll called with an alarm still set"); - } - return alarmValue; - })).ignoreResult()); - } - auto deleteAll = cache->deleteAll(options); + auto& context = IoContext::current(); context.addTask(updateStorageDeletes(context, currentActorMetrics(), kj::mv(deleteAll.count))); return transformMaybeBackpressure(js, options, kj::mv(deleteAll.backpressure)); diff --git a/src/workerd/api/sql-test.js b/src/workerd/api/sql-test.js index 5a75f147564..78cf942df06 100644 --- a/src/workerd/api/sql-test.js +++ b/src/workerd/api/sql-test.js @@ -1177,6 +1177,11 @@ export class DurableObjectExample { } else if (req.url.endsWith('/streaming-ingestion')) { await testStreamingIngestion(req, this.state.storage); return Response.json({ ok: true }); + } else if (req.url.endsWith('/deleteAll')) { + this.state.storage.put('counter', 888); // will be deleted + this.state.storage.deleteAll(); + assert.strictEqual(await this.state.storage.get('counter'), undefined); + return Response.json({ ok: true }); } throw new Error('unknown url: ' + req.url); @@ -1251,6 +1256,11 @@ export default { // Everything's still consistent. assert.equal(await doReq('increment'), 3); + + // Delete all: increments start over + await doReq('deleteAll'); + assert.equal(await doReq('increment'), 1); + assert.equal(await doReq('increment'), 2); }, }; diff --git a/src/workerd/io/actor-sqlite.c++ b/src/workerd/io/actor-sqlite.c++ index e1781c3da13..9984ae00053 100644 --- a/src/workerd/io/actor-sqlite.c++ +++ b/src/workerd/io/actor-sqlite.c++ @@ -50,6 +50,18 @@ void ActorSqlite::ImplicitTxn::commit() { } } +void ActorSqlite::ImplicitTxn::rollback() { + // Ignore redundant commit()s. + if (!committed) { + // As of this writing, rollback() is only called when the database is about to be reset. + // Preparing a statement for it would be a waste since that statement would never be executed + // more than once, since resetting requires repreparing all statements anyway. So we don't + // bother. + parent.db->run("ROLLBACK TRANSACTION"); + committed = true; + } +} + ActorSqlite::ExplicitTxn::ExplicitTxn(ActorSqlite& actorSqlite): actorSqlite(actorSqlite) { KJ_SWITCH_ONEOF(actorSqlite.currentTxn) { KJ_CASE_ONEOF(_, NoTxn) { @@ -283,6 +295,46 @@ kj::Own ActorSqlite::startTransaction() { ActorCacheInterface::DeleteAllResults ActorSqlite::deleteAll(WriteOptions options) { requireNotBroken(); + // deleteAll() cannot be part of a transaction because it deletes the database altogether. So, + // we have to close our transactions or fail. + KJ_SWITCH_ONEOF(currentTxn) { + KJ_CASE_ONEOF(_, NoTxn) { + // good + } + KJ_CASE_ONEOF(implicit, ImplicitTxn*) { + // Whatever the implicit transaction did, it's about to be blown away anyway. Roll it back + // so we don't waste time flushing these writes anywhere. + implicit->rollback(); + currentTxn = NoTxn(); + } + KJ_CASE_ONEOF(exp, ExplicitTxn*) { + // Keep in mind: + // + // ctx.storage.transaction(txn => { + // txn.deleteAll(); // calls `DurableObjectTransaction::deleteAll()` + // ctx.storage.deleteAll(); // calls this method, `ActorSqlite::deleteAll()` + // }); + // + // `DurableObjectTransaction::deleteAll()` throws this exception, since `deleteAll()` is not + // supported inside a transaction. Under the new SQLite-backed storage system, directly + // calling `cxt.storage` inside a transaction (as opposed to using the `txn` object) should + // still be treated as part of the transaction, and so should throw the same thing. + JSG_FAIL_REQUIRE(Error, "Cannot call deleteAll() within a transaction"); + } + } + + if (!deleteAllCommitScheduled) { + // We'll want to make sure the commit callback is called for the deleteAll(). + commitTasks.add(outputGate.lockWhile(kj::evalLater([this]() mutable -> kj::Promise { + // Don't commit if shutdown() has been called. + requireNotBroken(); + + deleteAllCommitScheduled = false; + return commitCallback(); + }))); + deleteAllCommitScheduled = true; + } + uint count = kv.deleteAll(); return { .backpressure = kj::none, diff --git a/src/workerd/io/actor-sqlite.h b/src/workerd/io/actor-sqlite.h index 56bf363163f..4c305a1f74d 100644 --- a/src/workerd/io/actor-sqlite.h +++ b/src/workerd/io/actor-sqlite.h @@ -46,7 +46,7 @@ class ActorSqlite final: public ActorCacheInterface, private kj::TaskSet::ErrorH Hooks& hooks = const_cast(Hooks::DEFAULT)); bool isCommitScheduled() { - return !currentTxn.is(); + return !currentTxn.is() || deleteAllCommitScheduled; } kj::Maybe getSqliteDatabase() override { @@ -101,6 +101,7 @@ class ActorSqlite final: public ActorCacheInterface, private kj::TaskSet::ErrorH KJ_DISALLOW_COPY_AND_MOVE(ImplicitTxn); void commit(); + void rollback(); private: ActorSqlite& parent; @@ -156,6 +157,9 @@ class ActorSqlite final: public ActorCacheInterface, private kj::TaskSet::ErrorH // transactions should be used in the meantime. kj::OneOf currentTxn = NoTxn(); + // If true, then a commit is scheduled as a result of deleteAll() having been called. + bool deleteAllCommitScheduled = false; + kj::TaskSet commitTasks; void onWrite(); diff --git a/src/workerd/server/alarm-scheduler.c++ b/src/workerd/server/alarm-scheduler.c++ index 1c954c94bd3..3ef9f52e143 100644 --- a/src/workerd/server/alarm-scheduler.c++ +++ b/src/workerd/server/alarm-scheduler.c++ @@ -25,12 +25,12 @@ std::default_random_engine makeSeededRandomEngine() { } // namespace AlarmScheduler::AlarmScheduler( - const kj::Clock& clock, kj::Timer& timer, const SqliteDatabase::Vfs& vfs, kj::PathPtr path) + const kj::Clock& clock, kj::Timer& timer, const SqliteDatabase::Vfs& vfs, kj::Path path) : clock(clock), timer(timer), random(makeSeededRandomEngine()), db([&] { - auto db = kj::heap(vfs, path, + auto db = kj::heap(vfs, kj::mv(path), kj::WriteMode::CREATE | kj::WriteMode::MODIFY | kj::WriteMode::CREATE_PARENT); ensureInitialized(*db); return kj::mv(db); diff --git a/src/workerd/server/alarm-scheduler.h b/src/workerd/server/alarm-scheduler.h index 2cf1086afd5..1a8aa2a1676 100644 --- a/src/workerd/server/alarm-scheduler.h +++ b/src/workerd/server/alarm-scheduler.h @@ -62,7 +62,7 @@ class AlarmScheduler final: kj::TaskSet::ErrorHandler { using GetActorFn = kj::Function(kj::String)>; AlarmScheduler( - const kj::Clock& clock, kj::Timer& timer, const SqliteDatabase::Vfs& vfs, kj::PathPtr path); + const kj::Clock& clock, kj::Timer& timer, const SqliteDatabase::Vfs& vfs, kj::Path path); kj::Maybe getAlarm(ActorKey actor); bool setAlarm(ActorKey actor, kj::Date scheduledTime); diff --git a/src/workerd/server/server.c++ b/src/workerd/server/server.c++ index 8a64902f6e2..d98427a751b 100644 --- a/src/workerd/server/server.c++ +++ b/src/workerd/server/server.c++ @@ -1941,8 +1941,11 @@ public: kj::Path({d.uniqueKey, kj::str(idPtr, ".sqlite")}), kj::WriteMode::CREATE | kj::WriteMode::MODIFY | kj::WriteMode::CREATE_PARENT); - // Before we do anything, make sure the database is in WAL mode. - db->run("PRAGMA journal_mode=WAL;"); + // Before we do anything, make sure the database is in WAL mode. We also need to + // do this after reset() is used, so register a callback for that. + auto setWalMode = [](SqliteDatabase& db) { db.run("PRAGMA journal_mode=WAL;"); }; + setWalMode(*db); + db->afterReset(kj::mv(setWalMode)); return kj::heap(kj::mv(db), outputGate, []() -> kj::Promise { return kj::READY_NOW; }, *sqliteHooks) diff --git a/src/workerd/util/sqlite-kv-test.c++ b/src/workerd/util/sqlite-kv-test.c++ index 28a3b4e439a..2897f570374 100644 --- a/src/workerd/util/sqlite-kv-test.c++ +++ b/src/workerd/util/sqlite-kv-test.c++ @@ -87,8 +87,27 @@ KJ_TEST("SQLite-KV") { kv.put("foo", "hello"_kj.asBytes()); KJ_EXPECT(list(nullptr, kj::none, kj::none, F) == "bar=def, foo=hello, qux=321"); - kv.deleteAll(); + // deleteAll() + KJ_EXPECT(kv.deleteAll() == 3); KJ_EXPECT(list(nullptr, kj::none, kj::none, F) == ""); + + KJ_EXPECT(!kv.get("bar", [&](kj::ArrayPtr value) { + KJ_FAIL_EXPECT("should not call callback when no match", value.asChars()); + })); + + kv.put("bar", "ghi"_kj.asBytes()); + kv.put("corge", "garply"_kj.asBytes()); + + KJ_EXPECT(list(nullptr, kj::none, kj::none, F) == "bar=ghi, corge=garply"); + + { + bool called = false; + KJ_EXPECT(kv.get("bar", [&](kj::ArrayPtr value) { + KJ_EXPECT(kj::str(value.asChars()) == "ghi"); + called = true; + })); + KJ_EXPECT(called); + } } } // namespace diff --git a/src/workerd/util/sqlite-kv.c++ b/src/workerd/util/sqlite-kv.c++ index 4aa25204484..2f78f863a40 100644 --- a/src/workerd/util/sqlite-kv.c++ +++ b/src/workerd/util/sqlite-kv.c++ @@ -6,30 +6,33 @@ namespace workerd { -SqliteKv::SqliteKv(SqliteDatabase& db) { +SqliteKv::SqliteKv(SqliteDatabase& db): ResetListener(db) { if (db.run("SELECT name FROM sqlite_master WHERE type='table' AND name='_cf_KV'").isDone()) { // The _cf_KV table doesn't exist. Defer initialization. - state = Uninitialized{db}; + state.init(Uninitialized{}); } else { // The KV table was initialized in the past. We can go ahead and prepare our statements. // (We don't call ensureInitialized() here because the `CREATE TABLE IF NOT EXISTS` query it // executes would be redundant.) - state = Initialized(db); + tableCreated = true; + state.init(db); } } SqliteKv::Initialized& SqliteKv::ensureInitialized() { - KJ_SWITCH_ONEOF(state) { - KJ_CASE_ONEOF(uninitialized, Uninitialized) { - auto& db = uninitialized.db; + if (!tableCreated) { + db.run(R"( + CREATE TABLE IF NOT EXISTS _cf_KV ( + key TEXT PRIMARY KEY, + value BLOB + ) WITHOUT ROWID; + )"); - db.run(R"( - CREATE TABLE IF NOT EXISTS _cf_KV ( - key TEXT PRIMARY KEY, - value BLOB - ) WITHOUT ROWID; - )"); + tableCreated = true; + } + KJ_SWITCH_ONEOF(state) { + KJ_CASE_ONEOF(uninitialized, Uninitialized) { return state.init(db); } KJ_CASE_ONEOF(initialized, Initialized) { @@ -49,8 +52,17 @@ bool SqliteKv::delete_(KeyPtr key) { } uint SqliteKv::deleteAll() { - auto query = ensureInitialized().stmtDeleteAll.run(); - return query.changeCount(); + // TODO(perf): Consider introducing a compatibility flag that causes deleteAll() to always return + // 1. Apps almost certainly don't care about the return value but historically we returned the + // count of keys deleted, so now we're stuck counting the table size for no good reason. + uint count = tableCreated ? ensureInitialized().stmtCountKeys.run().getInt(0) : 0; + db.reset(); + return count; +} + +void SqliteKv::beforeSqliteReset() { + // We'll need to recreate the table on the next operation. + tableCreated = false; } } // namespace workerd diff --git a/src/workerd/util/sqlite-kv.h b/src/workerd/util/sqlite-kv.h index 74cc1d74b28..638bdcb41c8 100644 --- a/src/workerd/util/sqlite-kv.h +++ b/src/workerd/util/sqlite-kv.h @@ -15,7 +15,7 @@ namespace workerd { // perform direct SQL queries, we can block it from accessing any table prefixed with `_cf_`. // (Ideally this class would allow configuring the table name, but this would require a somewhat // obnoxious amount of string allocation.) -class SqliteKv { +class SqliteKv: private SqliteDatabase::ResetListener { public: explicit SqliteKv(SqliteDatabase& db); @@ -51,11 +51,11 @@ class SqliteKv { // byte blobs or strings containing NUL bytes. private: - struct Uninitialized { - SqliteDatabase& db; - }; + struct Uninitialized {}; struct Initialized { + // This reference is redundant but storing it here makes the prepared statement code below + // easier to manage. SqliteDatabase& db; SqliteDatabase::Statement stmtGet = db.prepare(R"( @@ -112,8 +112,8 @@ class SqliteKv { ORDER BY key DESC LIMIT ? )"); - SqliteDatabase::Statement stmtDeleteAll = db.prepare(R"( - DELETE FROM _cf_KV + SqliteDatabase::Statement stmtCountKeys = db.prepare(R"( + SELECT count(*) FROM _cf_KV )"); Initialized(SqliteDatabase& db): db(db) {} @@ -121,11 +121,15 @@ class SqliteKv { kj::OneOf state; + // Has the _cf_KV table been created? This is separate from Uninitialized/Initialized since it + // has to be repeated after a reset, whereas the statements do not need to be recreated. + bool tableCreated = false; + Initialized& ensureInitialized(); // Make sure the KV table is created and prepared statements are ready. Not called until the // first write. - SqliteKv(SqliteDatabase& db, bool); + void beforeSqliteReset() override; }; // ======================================================================================= @@ -137,6 +141,7 @@ class SqliteKv { template bool SqliteKv::get(KeyPtr key, Func&& callback) { + if (!tableCreated) return 0; auto& stmts = KJ_UNWRAP_OR(state.tryGet(), return false); auto query = stmts.stmtGet.run(key); @@ -152,6 +157,7 @@ bool SqliteKv::get(KeyPtr key, Func&& callback) { template uint SqliteKv::list( KeyPtr begin, kj::Maybe end, kj::Maybe limit, Order order, Func&& callback) { + if (!tableCreated) return 0; auto& stmts = KJ_UNWRAP_OR(state.tryGet(), return 0); auto iterate = [&](SqliteDatabase::Query&& query) { diff --git a/src/workerd/util/sqlite-test.c++ b/src/workerd/util/sqlite-test.c++ index 2b5fc863674..843ffe2886e 100644 --- a/src/workerd/util/sqlite-test.c++ +++ b/src/workerd/util/sqlite-test.c++ @@ -740,5 +740,47 @@ KJ_TEST("DELETE with LIMIT") { KJ_EXPECT(q.getInt(0) == 3); } +KJ_TEST("reset database") { + auto dir = kj::newInMemoryDirectory(kj::nullClock()); + SqliteDatabase::Vfs vfs(*dir); + SqliteDatabase db(vfs, kj::Path({"foo"}), kj::WriteMode::CREATE | kj::WriteMode::MODIFY); + + db.run("PRAGMA journal_mode=WAL;"); + + db.run("CREATE TABLE things (id INTEGER PRIMARY KEY)"); + + db.run("INSERT INTO things VALUES (123)"); + db.run("INSERT INTO things VALUES (321)"); + + auto stmt = db.prepare("SELECT * FROM things"); + + auto query = stmt.run(); + KJ_ASSERT(!query.isDone()); + KJ_EXPECT(query.getInt(0) == 123); + + db.reset(); + db.run("PRAGMA journal_mode=WAL;"); + + // The query was canceled. + KJ_EXPECT_THROW_MESSAGE("query canceled because reset()", query.nextRow()); + KJ_EXPECT_THROW_MESSAGE("query canceled because reset()", query.getInt(0)); + + // The statement doesn't work because the table is gone. + KJ_EXPECT_THROW_MESSAGE("no such table: things: SQLITE_ERROR", stmt.run()); + + // But we can recreate it. + db.run("CREATE TABLE things (id INTEGER PRIMARY KEY)"); + db.run("INSERT INTO things VALUES (456)"); + + // Now the statement works. + { + auto q2 = stmt.run(); + KJ_ASSERT(!q2.isDone()); + KJ_EXPECT(q2.getInt(0) == 456); + q2.nextRow(); + KJ_EXPECT(q2.isDone()); + } +} + } // namespace } // namespace workerd diff --git a/src/workerd/util/sqlite.c++ b/src/workerd/util/sqlite.c++ index 0e9f139a06b..dd7b7745a4f 100644 --- a/src/workerd/util/sqlite.c++ +++ b/src/workerd/util/sqlite.c++ @@ -369,8 +369,17 @@ static constexpr PragmaInfo ALLOWED_PRAGMAS[] = {{"data_version"_kj, PragmaSigna constexpr SqliteDatabase::Regulator SqliteDatabase::TRUSTED; -SqliteDatabase::SqliteDatabase( - const Vfs& vfs, kj::PathPtr path, kj::Maybe maybeMode) { +SqliteDatabase::SqliteDatabase(const Vfs& vfs, kj::Path path, kj::Maybe maybeMode) + : vfs(vfs), + path(kj::mv(path)), + readOnly(maybeMode == kj::none) { + init(maybeMode); +} + +void SqliteDatabase::init(kj::Maybe maybeMode) { + KJ_ASSERT(maybeDb == kj::none); + sqlite3* db = nullptr; + KJ_IF_SOME(mode, maybeMode) { int flags = SQLITE_OPEN_READWRITE; if (kj::has(mode, kj::WriteMode::CREATE)) { @@ -406,10 +415,14 @@ SqliteDatabase::SqliteDatabase( KJ_ON_SCOPE_FAILURE(sqlite3_close_v2(db)); - setupSecurity(); + setupSecurity(db); + + maybeDb = *db; } SqliteDatabase::~SqliteDatabase() noexcept(false) { + sqlite3* db = &KJ_UNWRAP_OR(maybeDb, return); + auto err = sqlite3_close(db); if (err == SQLITE_BUSY) { KJ_LOG(ERROR, "sqlite database destroyed while dependent objects still exist"); @@ -423,6 +436,10 @@ SqliteDatabase::~SqliteDatabase() noexcept(false) { } } +SqliteDatabase::operator sqlite3*() { + return &KJ_ASSERT_NONNULL(maybeDb, "previous reset() failed"); +} + void SqliteDatabase::notifyWrite() { KJ_IF_SOME(cb, onWriteCallback) { cb(); @@ -441,6 +458,8 @@ kj::StringPtr SqliteDatabase::getCurrentQueryForDebug() { // statement. kj::Own SqliteDatabase::prepareSql( const Regulator& regulator, kj::StringPtr sqlCode, uint prepFlags, Multi multi) { + sqlite3* db = &KJ_ASSERT_NONNULL(maybeDb, "previous reset() failed"); + KJ_ASSERT(currentRegulator == kj::none, "recursive prepareSql()?"); KJ_DEFER(currentRegulator = kj::none); currentRegulator = regulator; @@ -545,6 +564,34 @@ void SqliteDatabase::executeWithRegulator( func(); } +void SqliteDatabase::reset() { + KJ_REQUIRE(!readOnly, "can't reset() read-only database"); + + // Temporarily disable the on-write callback while resetting. + auto writeCb = kj::mv(onWriteCallback); + KJ_DEFER(onWriteCallback = kj::mv(writeCb)); + + KJ_IF_SOME(db, maybeDb) { + for (auto& listener: resetListeners) { + listener.beforeSqliteReset(); + } + + auto err = sqlite3_close(&db); + KJ_REQUIRE(err == SQLITE_OK, "can't reset() database because dependent objects still exist", + sqlite3_errstr(err)); + + maybeDb = kj::none; + vfs.directory.remove(path); + } + + KJ_ON_SCOPE_FAILURE(maybeDb = kj::none); + init(kj::WriteMode::CREATE | kj::WriteMode::MODIFY); + + KJ_IF_SOME(resetCb, afterResetCallback) { + resetCb(*this); + } +} + bool SqliteDatabase::isAuthorized(int actionCode, kj::Maybe param1, kj::Maybe param2, @@ -810,7 +857,7 @@ bool SqliteDatabase::isAuthorizedTemp(int actionCode, // Set up security restrictions. // See: https://www.sqlite.org/security.html -void SqliteDatabase::setupSecurity() { +void SqliteDatabase::setupSecurity(sqlite3* db) { // 1. Set defensive mode. SQLITE_CALL_NODB(sqlite3_db_config(db, SQLITE_DBCONFIG_DEFENSIVE, 1, nullptr)); @@ -890,13 +937,30 @@ SqliteDatabase::Statement SqliteDatabase::prepare( *this, regulator, prepareSql(regulator, sqlCode, SQLITE_PREPARE_PERSISTENT, SINGLE)); } +SqliteDatabase::Statement::operator sqlite3_stmt*() { + KJ_IF_SOME(sqlCode, stmt.tryGet()) { + // Database was reset. Recompile the statement against the new database. (This could throw, + // of course, if the statement depends on tables that haven't been recreated yet.) + stmt = db.prepareSql(regulator, sqlCode, SQLITE_PREPARE_PERSISTENT, SINGLE); + } + + return KJ_ASSERT_NONNULL(stmt.tryGet>()).get(); +} + +void SqliteDatabase::Statement::beforeSqliteReset() { + KJ_IF_SOME(prepared, stmt.tryGet>()) { + // Pull the original SQL code out of the statement and store it. + stmt = kj::str(sqlite3_sql(prepared)); + } +} + SqliteDatabase::Query::Query(SqliteDatabase& db, const Regulator& regulator, Statement& statement, kj::ArrayPtr bindings) - : db(db), + : ResetListener(db), regulator(regulator), - statement(statement) { + maybeStatement(statement) { // If the statement was used for a previous query, then its row counters contain data from that // query's execution. Reset them to zero. resetRowCounters(); @@ -907,10 +971,10 @@ SqliteDatabase::Query::Query(SqliteDatabase& db, const Regulator& regulator, kj::StringPtr sqlCode, kj::ArrayPtr bindings) - : db(db), + : ResetListener(db), regulator(regulator), ownStatement(db.prepareSql(regulator, sqlCode, 0, MULTI)), - statement(ownStatement) { + maybeStatement(ownStatement) { init(bindings); } @@ -918,17 +982,21 @@ SqliteDatabase::Query::~Query() noexcept(false) { // We only need to reset the statement if we don't own it. If we own it, it's about to be // destroyed anyway. if (ownStatement.get() == nullptr) { - // The error code returned by sqlite3_reset() actually represents the last error encountered - // when stepping the statement. This doesn't mean that the reset failed. - sqlite3_reset(statement); - - // sqlite3_clear_bindings() returns int, but there is no documentation on how the return code - // should be interpreted, so we ignore it. - sqlite3_clear_bindings(statement); + KJ_IF_SOME(statement, maybeStatement) { + // The error code returned by sqlite3_reset() actually represents the last error encountered + // when stepping the statement. This doesn't mean that the reset failed. + sqlite3_reset(&statement); + + // sqlite3_clear_bindings() returns int, but there is no documentation on how the return code + // should be interpreted, so we ignore it. + sqlite3_clear_bindings(&statement); + } } } void SqliteDatabase::Query::checkRequirements(size_t size) { + sqlite3_stmt* statement = getStatement(); + SQLITE_REQUIRE(!sqlite3_stmt_busy(statement), "A SQL prepared statement can only be executed once at a time."); SQLITE_REQUIRE(size == sqlite3_bind_parameter_count(statement), @@ -952,6 +1020,8 @@ void SqliteDatabase::Query::init(kj::ArrayPtr bindings) { } void SqliteDatabase::Query::bind(uint i, ValuePtr value) { + sqlite3_stmt* statement = getStatement(); + KJ_SWITCH_ONEOF(value) { KJ_CASE_ONEOF(blob, kj::ArrayPtr) { SQLITE_CALL(sqlite3_bind_blob(statement, i + 1, blob.begin(), blob.size(), SQLITE_STATIC)); @@ -972,42 +1042,50 @@ void SqliteDatabase::Query::bind(uint i, ValuePtr value) { } uint64_t SqliteDatabase::Query::getRowsRead() { + sqlite3_stmt* statement = getStatement(); KJ_REQUIRE(statement != nullptr); return sqlite3_stmt_status(statement, LIBSQL_STMTSTATUS_ROWS_READ, 0); } uint64_t SqliteDatabase::Query::getRowsWritten() { - KJ_REQUIRE(statement != nullptr); + sqlite3_stmt* statement = getStatement(); return sqlite3_stmt_status(statement, LIBSQL_STMTSTATUS_ROWS_WRITTEN, 0); } void SqliteDatabase::Query::resetRowCounters() { - KJ_REQUIRE(statement != nullptr); + sqlite3_stmt* statement = getStatement(); sqlite3_stmt_status(statement, LIBSQL_STMTSTATUS_ROWS_READ, 1); sqlite3_stmt_status(statement, LIBSQL_STMTSTATUS_ROWS_WRITTEN, 1); } void SqliteDatabase::Query::bind(uint i, kj::ArrayPtr value) { + sqlite3_stmt* statement = getStatement(); SQLITE_CALL(sqlite3_bind_blob(statement, i + 1, value.begin(), value.size(), SQLITE_STATIC)); } void SqliteDatabase::Query::bind(uint i, kj::StringPtr value) { + sqlite3_stmt* statement = getStatement(); SQLITE_CALL(sqlite3_bind_text(statement, i + 1, value.begin(), value.size(), SQLITE_STATIC)); } void SqliteDatabase::Query::bind(uint i, long long value) { + sqlite3_stmt* statement = getStatement(); SQLITE_CALL(sqlite3_bind_int64(statement, i + 1, value)); } void SqliteDatabase::Query::bind(uint i, double value) { + sqlite3_stmt* statement = getStatement(); SQLITE_CALL(sqlite3_bind_double(statement, i + 1, value)); } void SqliteDatabase::Query::bind(uint i, decltype(nullptr)) { + sqlite3_stmt* statement = getStatement(); SQLITE_CALL(sqlite3_bind_null(statement, i + 1)); } void SqliteDatabase::Query::nextRow() { + sqlite3_stmt* statement = getStatement(); + KJ_ASSERT(db.currentStatement == kj::none, "recursive nextRow()?"); KJ_DEFER(db.currentStatement = kj::none); db.currentStatement = *statement; @@ -1033,10 +1111,12 @@ uint SqliteDatabase::Query::changeCount() { } uint SqliteDatabase::Query::columnCount() { + sqlite3_stmt* statement = getStatement(); return sqlite3_column_count(statement); } SqliteDatabase::Query::ValuePtr SqliteDatabase::Query::getValue(uint column) { + sqlite3_stmt* statement = getStatement(); switch (sqlite3_column_type(statement, column)) { case SQLITE_INTEGER: return getInt64(column); @@ -1053,35 +1133,57 @@ SqliteDatabase::Query::ValuePtr SqliteDatabase::Query::getValue(uint column) { } kj::StringPtr SqliteDatabase::Query::getColumnName(uint column) { + sqlite3_stmt* statement = getStatement(); return sqlite3_column_name(statement, column); } kj::ArrayPtr SqliteDatabase::Query::getBlob(uint column) { + sqlite3_stmt* statement = getStatement(); const byte* ptr = reinterpret_cast(sqlite3_column_blob(statement, column)); return kj::arrayPtr(ptr, sqlite3_column_bytes(statement, column)); } kj::StringPtr SqliteDatabase::Query::getText(uint column) { + sqlite3_stmt* statement = getStatement(); const char* ptr = reinterpret_cast(sqlite3_column_text(statement, column)); return kj::StringPtr(ptr, sqlite3_column_bytes(statement, column)); } int SqliteDatabase::Query::getInt(uint column) { + sqlite3_stmt* statement = getStatement(); return sqlite3_column_int(statement, column); } int64_t SqliteDatabase::Query::getInt64(uint column) { + sqlite3_stmt* statement = getStatement(); return sqlite3_column_int64(statement, column); } double SqliteDatabase::Query::getDouble(uint column) { + sqlite3_stmt* statement = getStatement(); return sqlite3_column_double(statement, column); } bool SqliteDatabase::Query::isNull(uint column) { + sqlite3_stmt* statement = getStatement(); return sqlite3_column_type(statement, column) == SQLITE_NULL; } +sqlite3_stmt* SqliteDatabase::Query::getStatement() { + return &KJ_UNWRAP_OR(maybeStatement, { + regulator.onError("SQLite query was canceled because the database was deleted."); + KJ_FAIL_REQUIRE("query canceled because reset() was called on the database"); + }); +} + +void SqliteDatabase::Query::beforeSqliteReset() { + // Note that if we don't own the statement, then `maybeStatement` is probably already dangling + // here. Luckily, we don't need to reset it or anything because the statement will be destroyed + // by Statement::beforeSqliteReset(). + maybeStatement = kj::none; + ownStatement = nullptr; +} + // ======================================================================================= // VFS diff --git a/src/workerd/util/sqlite.h b/src/workerd/util/sqlite.h index f4e125b2d47..e82dc65c821 100644 --- a/src/workerd/util/sqlite.h +++ b/src/workerd/util/sqlite.h @@ -6,6 +6,7 @@ #include #include +#include #include struct sqlite3; @@ -42,15 +43,13 @@ class SqliteDatabase { uint64_t statementCount; }; - SqliteDatabase(const Vfs& vfs, kj::PathPtr path, kj::Maybe maybeMode = kj::none); + SqliteDatabase(const Vfs& vfs, kj::Path path, kj::Maybe maybeMode = kj::none); ~SqliteDatabase() noexcept(false); KJ_DISALLOW_COPY_AND_MOVE(SqliteDatabase); // Allows a SqliteDatabase to be passed directly into SQLite API functions where `sqlite*` is // expected. - operator sqlite3*() { - return db; - } + operator sqlite3*(); // Class which regulates a SQL query, especially to control how queries created in JavaScript // application code are handled. @@ -130,6 +129,9 @@ class SqliteDatabase { // callback is called just before executing the query. // // Durable Objects uses this to automatically begin a transaction and close the output gate. + // + // Note that the write callback is NOT called before (or at any point during) a reset(). Use the + // `ResetListener` mechanism or `afterReset()` instead for that case. void onWrite(kj::Function callback) { onWriteCallback = kj::mv(callback); } @@ -157,8 +159,59 @@ class SqliteDatabase { // Execute a function with the given regulator. void executeWithRegulator(const Regulator& regulator, kj::FunctionParam func); + // Resets the database to an empty state by deleting the underlying database file and creating + // a new one in its place. This is the recommended way to "drop database" in SQLite, and is used + // to implement deleteAll() in Workers. + // + // reset() will cancel all outstanding queries (further attempts to use the cursors will throw). + // Prepared statements will be automatically reprepared the next time they are executed (which + // may throw if they depend on tables that haven't been recreated yet). + void reset(); + + // Objects that need to be notified when reset() is called may inherit `ResetListener`. + class ResetListener { + public: + ResetListener(SqliteDatabase& db): db(db) { + db.resetListeners.add(*this); + } + ~ResetListener() { + if (link.isLinked()) db.resetListeners.remove(*this); + } + ResetListener(ResetListener&& other): db(other.db) { + db.resetListeners.remove(other); + db.resetListeners.add(*this); + } + + // When the database's `reset()` method is called, all listeners' `beforeSqliteReset()` will be + // called before actually resetting the database. + virtual void beforeSqliteReset() = 0; + + protected: // so that subclasess don't have to store their own copy of the `db` reference + SqliteDatabase& db; + + private: + kj::ListLink link; + + friend class SqliteDatabase; + }; + + // Registers a callback to call after a reset completes. This can be used to do basic database + // initialization, e.g. set WAL mode. (To get notified *before* a reset, use `ResetListener`.) + // + // Note that the on-write callback is disabled during reset(), including while calling the + // after-reset callback. So, queries performed by the after-reset callback will not trigger the + // on-write callback. + void afterReset(kj::Function callback) { + afterResetCallback = kj::mv(callback); + } + private: - sqlite3* db; + const Vfs& vfs; + kj::Path path; + bool readOnly; + + // This pointer can be left null if a call to reset() failed to re-open the database. + kj::Maybe maybeDb; // Set while a query is compiling. kj::Maybe currentRegulator; @@ -167,8 +220,11 @@ class SqliteDatabase { kj::Maybe currentStatement; kj::Maybe> onWriteCallback; + kj::Maybe> afterResetCallback; - void close(); + kj::List resetListeners; + + void init(kj::Maybe maybeMode); enum Multi { SINGLE, MULTI }; @@ -194,11 +250,11 @@ class SqliteDatabase { const kj::Maybe& param2, const Regulator& regulator); - void setupSecurity(); + void setupSecurity(sqlite3* db); }; // Represents a prepared SQL statement, which can be executed many times. -class SqliteDatabase::Statement { +class SqliteDatabase::Statement final: private ResetListener { public: // Convenience method to start a query. This is equivalent to: // @@ -214,20 +270,20 @@ class SqliteDatabase::Statement { template Query run(Params&&... bindings); - operator sqlite3_stmt*() { - return stmt; - } + // Convert to sqlite3_stmt, creating it on-demand if needed. + operator sqlite3_stmt*(); private: - SqliteDatabase& db; const Regulator& regulator; - kj::Own stmt; + kj::OneOf> stmt; Statement(SqliteDatabase& db, const Regulator& regulator, kj::Own stmt) - : db(db), + : ResetListener(db), regulator(regulator), stmt(kj::mv(stmt)) {} + void beforeSqliteReset() override; + friend class SqliteDatabase; }; @@ -235,7 +291,7 @@ class SqliteDatabase::Statement { // // Only one Query can exist at a time, for a given database. It should probably be allocated on // the stack. -class SqliteDatabase::Query { +class SqliteDatabase::Query final: private ResetListener { public: using ValuePtr = kj::OneOf, kj::StringPtr, int64_t, double, decltype(nullptr)>; @@ -329,10 +385,9 @@ class SqliteDatabase::Query { } private: - SqliteDatabase& db; const Regulator& regulator; - kj::Own ownStatement; // for one-off queries - sqlite3_stmt* statement; + kj::Own ownStatement; // for one-off queries + kj::Maybe maybeStatement; // null if database was reset bool done = false; friend class SqliteDatabase; @@ -347,17 +402,17 @@ class SqliteDatabase::Query { kj::ArrayPtr bindings); template Query(SqliteDatabase& db, const Regulator& regulator, Statement& statement, Params&&... bindings) - : db(db), + : ResetListener(db), regulator(regulator), - statement(statement) { + maybeStatement(statement) { bindAll(std::index_sequence_for(), kj::fwd(bindings)...); } template Query(SqliteDatabase& db, const Regulator& regulator, kj::StringPtr sqlCode, Params&&... bindings) - : db(db), + : ResetListener(db), regulator(regulator), ownStatement(db.prepareSql(regulator, sqlCode, 0, MULTI)), - statement(ownStatement) { + maybeStatement(ownStatement) { bindAll(std::index_sequence_for(), kj::fwd(bindings)...); } @@ -394,6 +449,10 @@ class SqliteDatabase::Query { (bind(i, value), ...); nextRow(); } + + sqlite3_stmt* getStatement(); + + void beforeSqliteReset() override; }; // Options affecting SqliteDatabase::Vfs onstructor.