diff --git a/src/workerd/api/sql-test.js b/src/workerd/api/sql-test.js index 94f7afd244d..16e47d3eab7 100644 --- a/src/workerd/api/sql-test.js +++ b/src/workerd/api/sql-test.js @@ -264,6 +264,7 @@ async function test(state) { () => sql.exec('CREATE TABLE _cf_invalid (name TEXT)'), /not authorized/ ) + storage.put("blah", 123); // force creation of _cf_KV table assert.throws( () => sql.exec('SELECT * FROM _cf_KV'), /access to _cf_KV.key is prohibited/ diff --git a/src/workerd/server/server.c++ b/src/workerd/server/server.c++ index 25c29136c57..f8b73b32961 100644 --- a/src/workerd/server/server.c++ +++ b/src/workerd/server/server.c++ @@ -1862,6 +1862,10 @@ public: auto db = kj::heap(*as, kj::Path({d.uniqueKey, kj::str(idPtr, ".sqlite")}), kj::WriteMode::CREATE | kj::WriteMode::MODIFY | kj::WriteMode::CREATE_PARENT); + + // Before we do anything, make sure the database is in WAL mode. + db->run("PRAGMA journal_mode=WAL;"); + return kj::heap(kj::mv(db), outputGate, []() -> kj::Promise { return kj::READY_NOW; }, *sqliteHooks).attach(kj::mv(sqliteHooks)); diff --git a/src/workerd/util/sqlite-kv.c++ b/src/workerd/util/sqlite-kv.c++ index a8caaefe853..63f8299604c 100644 --- a/src/workerd/util/sqlite-kv.c++ +++ b/src/workerd/util/sqlite-kv.c++ @@ -6,33 +6,50 @@ namespace workerd { -SqliteKv::SqliteKv(SqliteDatabase& db, bool): db(db) {} - -SqliteDatabase& SqliteKv::ensureInitialized(SqliteDatabase& db) { - // TODO(sqlite): Do this automatically at a lower layer? - db.run("PRAGMA journal_mode=WAL;"); - - db.run(R"( - CREATE TABLE IF NOT EXISTS _cf_KV ( - key TEXT PRIMARY KEY, - value BLOB - ) WITHOUT ROWID; - )"); +SqliteKv::SqliteKv(SqliteDatabase& db) { + if (db.run("SELECT name FROM sqlite_master WHERE type='table' AND name='_cf_KV'").isDone()) { + // The _cf_KV table doesn't exist. Defer initialization. + state = Uninitialized { db }; + } else { + // The KV table was initialized in the past. We can go ahead and prepare our statements. + // (We don't call ensureInitialized() here because the `CREATE TABLE IF NOT EXISTS` query it + // executes would be redundant.) + state = Initialized(db); + } +} - return db; +SqliteKv::Initialized& SqliteKv::ensureInitialized() { + KJ_SWITCH_ONEOF(state) { + KJ_CASE_ONEOF(uninitialized, Uninitialized) { + auto& db = uninitialized.db; + + db.run(R"( + CREATE TABLE IF NOT EXISTS _cf_KV ( + key TEXT PRIMARY KEY, + value BLOB + ) WITHOUT ROWID; + )"); + + return state.init(db); + } + KJ_CASE_ONEOF(initialized, Initialized) { + return initialized; + } + } + KJ_UNREACHABLE; } void SqliteKv::put(KeyPtr key, ValuePtr value) { - stmtPut.run(key, value); + ensureInitialized().stmtPut.run(key, value); } bool SqliteKv::delete_(KeyPtr key) { - auto query = stmtDelete.run(key); + auto query = ensureInitialized().stmtDelete.run(key); return query.changeCount() > 0; } uint SqliteKv::deleteAll() { - auto query = stmtDeleteAll.run(); + auto query = ensureInitialized().stmtDeleteAll.run(); return query.changeCount(); } diff --git a/src/workerd/util/sqlite-kv.h b/src/workerd/util/sqlite-kv.h index 867fcd4377e..1aab04c340c 100644 --- a/src/workerd/util/sqlite-kv.h +++ b/src/workerd/util/sqlite-kv.h @@ -17,7 +17,7 @@ namespace workerd { // obnoxious amount of string allocation.) class SqliteKv { public: - explicit SqliteKv(SqliteDatabase& db): SqliteKv(ensureInitialized(db), true) {} + explicit SqliteKv(SqliteDatabase& db); typedef kj::StringPtr KeyPtr; typedef kj::ArrayPtr ValuePtr; @@ -54,68 +54,79 @@ class SqliteKv { // byte blobs or strings containing NUL bytes. private: - SqliteDatabase& db; - - SqliteDatabase::Statement stmtGet = db.prepare(R"( - SELECT value FROM _cf_KV WHERE key = ? - )"); - SqliteDatabase::Statement stmtPut = db.prepare(R"( - INSERT INTO _cf_KV VALUES(?, ?) - ON CONFLICT DO UPDATE SET value = excluded.value; - )"); - SqliteDatabase::Statement stmtDelete = db.prepare(R"( - DELETE FROM _cf_KV WHERE key = ? - )"); - SqliteDatabase::Statement stmtList = db.prepare(R"( - SELECT * FROM _cf_KV - WHERE key >= ? - ORDER BY key - )"); - SqliteDatabase::Statement stmtListEnd = db.prepare(R"( - SELECT * FROM _cf_KV - WHERE key >= ? AND key < ? - ORDER BY key - )"); - SqliteDatabase::Statement stmtListLimit = db.prepare(R"( - SELECT * FROM _cf_KV - WHERE key >= ? - ORDER BY key - LIMIT ? - )"); - SqliteDatabase::Statement stmtListEndLimit = db.prepare(R"( - SELECT * FROM _cf_KV - WHERE key >= ? AND key < ? - ORDER BY key - LIMIT ? - )"); - SqliteDatabase::Statement stmtListReverse = db.prepare(R"( - SELECT * FROM _cf_KV - WHERE key >= ? - ORDER BY key DESC - )"); - SqliteDatabase::Statement stmtListEndReverse = db.prepare(R"( - SELECT * FROM _cf_KV - WHERE key >= ? AND key < ? - ORDER BY key DESC - )"); - SqliteDatabase::Statement stmtListLimitReverse = db.prepare(R"( - SELECT * FROM _cf_KV - WHERE key >= ? - ORDER BY key DESC - LIMIT ? - )"); - SqliteDatabase::Statement stmtListEndLimitReverse = db.prepare(R"( - SELECT * FROM _cf_KV - WHERE key >= ? AND key < ? - ORDER BY key DESC - LIMIT ? - )"); - SqliteDatabase::Statement stmtDeleteAll = db.prepare(R"( - DELETE FROM _cf_KV - )"); - - SqliteDatabase& ensureInitialized(SqliteDatabase& db); - // Make sure the KV table is created, then return the same object. + struct Uninitialized { + SqliteDatabase& db; + }; + + struct Initialized { + SqliteDatabase& db; + + SqliteDatabase::Statement stmtGet = db.prepare(R"( + SELECT value FROM _cf_KV WHERE key = ? + )"); + SqliteDatabase::Statement stmtPut = db.prepare(R"( + INSERT INTO _cf_KV VALUES(?, ?) + ON CONFLICT DO UPDATE SET value = excluded.value; + )"); + SqliteDatabase::Statement stmtDelete = db.prepare(R"( + DELETE FROM _cf_KV WHERE key = ? + )"); + SqliteDatabase::Statement stmtList = db.prepare(R"( + SELECT * FROM _cf_KV + WHERE key >= ? + ORDER BY key + )"); + SqliteDatabase::Statement stmtListEnd = db.prepare(R"( + SELECT * FROM _cf_KV + WHERE key >= ? AND key < ? + ORDER BY key + )"); + SqliteDatabase::Statement stmtListLimit = db.prepare(R"( + SELECT * FROM _cf_KV + WHERE key >= ? + ORDER BY key + LIMIT ? + )"); + SqliteDatabase::Statement stmtListEndLimit = db.prepare(R"( + SELECT * FROM _cf_KV + WHERE key >= ? AND key < ? + ORDER BY key + LIMIT ? + )"); + SqliteDatabase::Statement stmtListReverse = db.prepare(R"( + SELECT * FROM _cf_KV + WHERE key >= ? + ORDER BY key DESC + )"); + SqliteDatabase::Statement stmtListEndReverse = db.prepare(R"( + SELECT * FROM _cf_KV + WHERE key >= ? AND key < ? + ORDER BY key DESC + )"); + SqliteDatabase::Statement stmtListLimitReverse = db.prepare(R"( + SELECT * FROM _cf_KV + WHERE key >= ? + ORDER BY key DESC + LIMIT ? + )"); + SqliteDatabase::Statement stmtListEndLimitReverse = db.prepare(R"( + SELECT * FROM _cf_KV + WHERE key >= ? AND key < ? + ORDER BY key DESC + LIMIT ? + )"); + SqliteDatabase::Statement stmtDeleteAll = db.prepare(R"( + DELETE FROM _cf_KV + )"); + + Initialized(SqliteDatabase& db): db(db) {} + }; + + kj::OneOf state; + + Initialized& ensureInitialized(); + // Make sure the KV table is created and prepared statements are ready. Not called until the + // first write. SqliteKv(SqliteDatabase& db, bool); }; @@ -129,7 +140,9 @@ class SqliteKv { template bool SqliteKv::get(KeyPtr key, Func&& callback) { - auto query = stmtGet.run(key); + auto& stmts = KJ_UNWRAP_OR(state.tryGet(), return false); + + auto query = stmts.stmtGet.run(key); if (query.isDone()) { return false; @@ -142,6 +155,8 @@ bool SqliteKv::get(KeyPtr key, Func&& callback) { template uint SqliteKv::list(KeyPtr begin, kj::Maybe end, kj::Maybe limit, Order order, Func&& callback) { + auto& stmts = KJ_UNWRAP_OR(state.tryGet(), return 0); + auto iterate = [&](SqliteDatabase::Query&& query) { size_t count = 0; while (!query.isDone()) { @@ -155,29 +170,29 @@ uint SqliteKv::list(KeyPtr begin, kj::Maybe end, kj::Maybe limit, if (order == Order::FORWARD) { KJ_IF_SOME(e, end) { KJ_IF_SOME(l, limit) { - return iterate(stmtListEndLimit.run(begin, e, (int64_t)l)); + return iterate(stmts.stmtListEndLimit.run(begin, e, (int64_t)l)); } else { - return iterate(stmtListEnd.run(begin, e)); + return iterate(stmts.stmtListEnd.run(begin, e)); } } else { KJ_IF_SOME(l, limit) { - return iterate(stmtListLimit.run(begin, (int64_t)l)); + return iterate(stmts.stmtListLimit.run(begin, (int64_t)l)); } else { - return iterate(stmtList.run(begin)); + return iterate(stmts.stmtList.run(begin)); } } } else { KJ_IF_SOME(e, end) { KJ_IF_SOME(l, limit) { - return iterate(stmtListEndLimitReverse.run(begin, e, (int64_t)l)); + return iterate(stmts.stmtListEndLimitReverse.run(begin, e, (int64_t)l)); } else { - return iterate(stmtListEndReverse.run(begin, e)); + return iterate(stmts.stmtListEndReverse.run(begin, e)); } } else { KJ_IF_SOME(l, limit) { - return iterate(stmtListLimitReverse.run(begin, (int64_t)l)); + return iterate(stmts.stmtListLimitReverse.run(begin, (int64_t)l)); } else { - return iterate(stmtListReverse.run(begin)); + return iterate(stmts.stmtListReverse.run(begin)); } } }