Skip to content

Commit 354e841

Browse files
committed
Add DisableFunction API
This allows built-in functions to be disabled. Updates tailscale/corp#31396 Signed-off-by: Percy Wegmann <[email protected]>
1 parent 9328d04 commit 354e841

File tree

5 files changed

+49
-0
lines changed

5 files changed

+49
-0
lines changed

cgosqlite/cgosqlite.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,10 @@ func (db *DB) Prepare(query string, prepFlags sqliteh.PrepareFlags) (stmt sqlite
203203
return &Stmt{db: db, stmt: cStmtFromPtr(cstmt)}, remainingQuery, nil
204204
}
205205

206+
func (db *DB) DisableFunction(name string, numArgs int) error {
207+
return errCode(C.ts_sqlite3_disable_function(db.db, C.CString(name), C.int(numArgs)))
208+
}
209+
206210
func (stmt *Stmt) DBHandle() sqliteh.DB {
207211
cdb := C.sqlite3_db_handle(stmt.stmt.ptr())
208212
if cdb != nil {

cgosqlite/cgosqlite.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,7 @@ static double ts_sqlite3_column_double(handle_sqlite3_stmt stmt, int iCol) {
146146
static sqlite3_int64 ts_sqlite3_column_int64(handle_sqlite3_stmt stmt, int iCol) {
147147
return sqlite3_column_int64((sqlite3_stmt*)(stmt), iCol);
148148
}
149+
150+
static int ts_sqlite3_disable_function(sqlite3 *db, const char *zFunctionName, int nArg) {
151+
return sqlite3_create_function(db, zFunctionName, nArg, SQLITE_ANY, NULL, NULL, NULL, NULL);
152+
}

sqlite.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,6 +1036,18 @@ func Checkpoint(sqlconn SQLConn, dbName string, mode sqliteh.Checkpoint) (numFra
10361036
return numFrames, numFramesCheckpointed, err
10371037
}
10381038

1039+
// DisableFunction disables the named function on the given connection.
1040+
// numArgs must match the number of args of the function to be disabled.
1041+
func DisableFunction(sqlconn SQLConn, name string, numArgs int) error {
1042+
return sqlconn.Raw(func(driverConn any) error {
1043+
c, ok := driverConn.(*conn)
1044+
if !ok {
1045+
return fmt.Errorf("sqlite.DisableFunction: sql.Conn is not the sqlite driver: %T", driverConn)
1046+
}
1047+
return c.db.DisableFunction(name, numArgs)
1048+
})
1049+
}
1050+
10391051
// WithPersist makes a ctx instruct the sqlite driver to persist a prepared query.
10401052
//
10411053
// This should be used with recurring queries to avoid constant parsing and

sqlite_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,3 +1330,28 @@ func TestRegression(t *testing.T) {
13301330
t.Log("OK") // Reaching here at all means we didn't panic.
13311331
})
13321332
}
1333+
1334+
func TestDisableFunction(t *testing.T) {
1335+
db := openTestDB(t)
1336+
1337+
conn, err := db.Conn(context.Background())
1338+
if err != nil {
1339+
t.Fatal(err)
1340+
}
1341+
defer conn.Close()
1342+
1343+
exec(t, conn, "CREATE TABLE t (c)")
1344+
ctx := context.Background()
1345+
1346+
if _, err := conn.ExecContext(ctx, "SELECT LOWER('Hi') FROM t"); err != nil {
1347+
t.Fatal("Attempting to use the LOWER function before disabling should have been allowed")
1348+
}
1349+
1350+
if err := DisableFunction(conn, "lower", 1); err != nil {
1351+
t.Fatal(err)
1352+
}
1353+
1354+
if _, err := conn.ExecContext(ctx, "SELECT LOWER('Hi') FROM t"); err == nil {
1355+
t.Fatal("Attempting to use the LOWER function after disabling should have failed")
1356+
}
1357+
}

sqliteh/sqliteh.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ type DB interface {
6060
//
6161
// If hook is nil, the hook is removed.
6262
SetWALHook(hook func(dbName string, pages int))
63+
// DisableFunction allows disabling an existing function using
64+
// sqlite3_create_function.
65+
//
66+
DisableFunction(name string, numArgs int) error
6367
}
6468

6569
// Stmt is an sqlite3_stmt* database connection object.

0 commit comments

Comments
 (0)