Skip to content

Commit c5f7361

Browse files
committed
sqlite: add optional ConnLogger
ConnLogger can be used to log executed statements for a connection. The interface captures events for Begin, Exec, Commit, and Rollback calls. Updates tailscale/corp#33577
1 parent 3a6395a commit c5f7361

File tree

2 files changed

+157
-1
lines changed

2 files changed

+157
-1
lines changed

sqlite.go

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,19 @@ func Connector(sqliteURI string, connInitFunc ConnInitFunc, tracer sqliteh.Trace
127127
}
128128
}
129129

130+
func ConnectorWithLogger(sqliteURI string, connInitFunc ConnInitFunc, tracer sqliteh.Tracer, makeLogger func() ConnLogger) driver.Connector {
131+
return &connector{
132+
name: sqliteURI,
133+
tracer: tracer,
134+
makeLogger: makeLogger,
135+
connInitFunc: connInitFunc,
136+
}
137+
}
138+
130139
type connector struct {
131140
name string
132141
tracer sqliteh.Tracer
142+
makeLogger func() ConnLogger
133143
connInitFunc ConnInitFunc
134144
}
135145

@@ -152,10 +162,10 @@ func (p *connector) Connect(ctx context.Context) (driver.Conn, error) {
152162
}
153163
return nil, err
154164
}
155-
156165
c := &conn{
157166
db: db,
158167
tracer: p.tracer,
168+
logger: p.makeLogger(),
159169
id: sqliteh.TraceConnID(maxConnID.Add(1)),
160170
}
161171
if p.connInitFunc != nil {
@@ -179,6 +189,7 @@ type conn struct {
179189
db sqliteh.DB
180190
id sqliteh.TraceConnID
181191
tracer sqliteh.Tracer
192+
logger ConnLogger
182193
stmts map[string]*stmt // persisted statements
183194
txState txState
184195
readOnly bool
@@ -341,6 +352,9 @@ func (c *conn) txInit(ctx context.Context) error {
341352
return err
342353
}
343354
} else {
355+
if c.logger != nil {
356+
c.logger.Begin()
357+
}
344358
// TODO(crawshaw): offer BEGIN DEFERRED (and BEGIN CONCURRENT?)
345359
// semantics via a context annotation function.
346360
if err := c.execInternal(ctx, "BEGIN IMMEDIATE"); err != nil {
@@ -381,6 +395,9 @@ func (tx *connTx) Commit() error {
381395
if tx.conn.tracer != nil {
382396
tx.conn.tracer.Commit(tx.conn.id, err)
383397
}
398+
if tx.conn.logger != nil && err == nil {
399+
tx.conn.logger.Commit()
400+
}
384401
return err
385402
}
386403

@@ -394,6 +411,9 @@ func (tx *connTx) Rollback() error {
394411
if tx.conn.tracer != nil {
395412
tx.conn.tracer.Rollback(tx.conn.id, err)
396413
}
414+
if tx.conn.logger != nil {
415+
tx.conn.logger.Rollback()
416+
}
397417
return err
398418
}
399419

@@ -490,6 +510,9 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
490510
if err := s.bindAll(args); err != nil {
491511
return nil, s.reserr("Stmt.Exec(Bind)", err)
492512
}
513+
if s.conn.logger != nil {
514+
s.conn.logger.Statement(s.stmt.ExpandedSQL())
515+
}
493516

494517
if ctx.Value(queryCancelKey{}) != nil {
495518
done := make(chan struct{})
@@ -1068,3 +1091,20 @@ func WithQueryCancel(ctx context.Context) context.Context {
10681091

10691092
// queryCancelKey is a context key for query context enforcement.
10701093
type queryCancelKey struct{}
1094+
1095+
// ConnLogger is implemented by the caller to support statement-level logging for
1096+
// write transactions. Only Exec calls are logged, not Query calls, as this is
1097+
// intended as a mechanism to replay failed transactions.
1098+
type ConnLogger interface {
1099+
// Begin is called when a writable transaction is opened.
1100+
Begin()
1101+
1102+
// Statement is called with evaluated SQL when a statement is executed.
1103+
Statement(sql string)
1104+
1105+
// Commit is called when a transaction successfully commits.
1106+
Commit()
1107+
1108+
// Rollback is called when a transaction is rolled back.
1109+
Rollback()
1110+
}

sqlite_test.go

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"os"
1414
"reflect"
1515
"runtime"
16+
"slices"
1617
"strconv"
1718
"strings"
1819
"sync"
@@ -1354,3 +1355,118 @@ func TestDisableFunction(t *testing.T) {
13541355
t.Fatal("Attempting to use the LOWER function after disabling should have failed")
13551356
}
13561357
}
1358+
1359+
type connLogger struct {
1360+
ch chan []string
1361+
statements []string
1362+
}
1363+
1364+
func (cl *connLogger) Begin() {
1365+
cl.statements = nil
1366+
}
1367+
1368+
func (cl *connLogger) Statement(s string) {
1369+
cl.statements = append(cl.statements, s)
1370+
}
1371+
1372+
func (cl *connLogger) Commit() {
1373+
cl.ch <- cl.statements
1374+
}
1375+
1376+
func (cl *connLogger) Rollback() {
1377+
cl.statements = nil
1378+
}
1379+
1380+
func TestConnLogger_writable(t *testing.T) {
1381+
for _, commit := range []bool{true, false} {
1382+
doneStatement := "ROLLBACK"
1383+
if commit {
1384+
doneStatement = "COMMIT"
1385+
}
1386+
t.Run(doneStatement, func(t *testing.T) {
1387+
ctx := context.Background()
1388+
ch := make(chan []string, 1)
1389+
txl := connLogger{ch: ch}
1390+
makeLogger := func() ConnLogger { return &txl }
1391+
db := sql.OpenDB(ConnectorWithLogger("file:"+t.TempDir()+"/test.db", nil, nil, makeLogger))
1392+
configDB(t, db)
1393+
1394+
tx, err := db.BeginTx(ctx, nil)
1395+
if err != nil {
1396+
t.Fatal(err)
1397+
}
1398+
if _, err := tx.Exec("CREATE TABLE T (x INTEGER)"); err != nil {
1399+
t.Fatal(err)
1400+
}
1401+
if _, err := tx.Exec("INSERT INTO T VALUES (?)", 1); err != nil {
1402+
t.Fatal(err)
1403+
}
1404+
done := tx.Rollback
1405+
if commit {
1406+
done = tx.Commit
1407+
}
1408+
if err := done(); err != nil {
1409+
t.Fatal(err)
1410+
}
1411+
if !commit {
1412+
select {
1413+
case got := <-ch:
1414+
t.Errorf("unexpectedly logged statements for rollback:\n%s", strings.Join(got, "\n"))
1415+
default:
1416+
return
1417+
}
1418+
}
1419+
1420+
want := []string{
1421+
"BEGIN IMMEDIATE",
1422+
"CREATE TABLE T (x INTEGER)",
1423+
"INSERT INTO T VALUES (1)",
1424+
doneStatement,
1425+
}
1426+
select {
1427+
case got := <-ch:
1428+
if !slices.Equal(got, want) {
1429+
t.Errorf("unexpected log statements. got:\n%s\n\nwant:\n%s", strings.Join(got, "\n"), strings.Join(want, "\n"))
1430+
}
1431+
default:
1432+
t.Fatal("no logged statements after commit")
1433+
}
1434+
})
1435+
}
1436+
}
1437+
1438+
func TestConnLogger_commit_error_retry(t *testing.T) {
1439+
ctx := context.Background()
1440+
ch := make(chan []string, 1)
1441+
txl := connLogger{ch: ch}
1442+
makeLogger := func() ConnLogger { return &txl }
1443+
db := sql.OpenDB(ConnectorWithLogger("file:"+t.TempDir()+"/test.db", nil, nil, makeLogger))
1444+
configDB(t, db)
1445+
1446+
if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
1447+
t.Fatal(err)
1448+
}
1449+
if _, err := db.Exec("CREATE TABLE A (x INTEGER PRIMARY KEY)"); err != nil {
1450+
t.Fatal(err)
1451+
}
1452+
if _, err := db.Exec("CREATE TABLE B (x INTEGER REFERENCES A(X) DEFERRABLE INITIALLY DEFERRED)"); err != nil {
1453+
t.Fatal(err)
1454+
}
1455+
1456+
tx, err := db.BeginTx(ctx, nil)
1457+
if err != nil {
1458+
t.Fatal(err)
1459+
}
1460+
if _, err := tx.Exec("INSERT INTO B VALUES (?)", 1); err != nil {
1461+
t.Fatal(err)
1462+
}
1463+
if err := tx.Commit(); err == nil {
1464+
t.Fatal("expected Commit to error, but didn't")
1465+
}
1466+
select {
1467+
case got := <-ch:
1468+
t.Errorf("unexpectedly logged statements for errored commit:\n%s", strings.Join(got, "\n"))
1469+
default:
1470+
return
1471+
}
1472+
}

0 commit comments

Comments
 (0)