Skip to content

Commit b07a594

Browse files
alisdairknyar
andcommitted
sqlite: add optional ConnLogger
ConnLogger can be used to log executed statements for a connection. The interface captures events for Begin, Exec, Commit, and Rollback calls. Updates tailscale/corp#33577 Co-authored-by: Anton Tolchanov <[email protected]>
1 parent 3a6395a commit b07a594

File tree

2 files changed

+237
-6
lines changed

2 files changed

+237
-6
lines changed

sqlite.go

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,19 @@ func Connector(sqliteURI string, connInitFunc ConnInitFunc, tracer sqliteh.Trace
127127
}
128128
}
129129

130+
func ConnectorWithLogger(sqliteURI string, connInitFunc ConnInitFunc, tracer sqliteh.Tracer, makeLogger func() ConnLogger) driver.Connector {
131+
return &connector{
132+
name: sqliteURI,
133+
tracer: tracer,
134+
makeLogger: makeLogger,
135+
connInitFunc: connInitFunc,
136+
}
137+
}
138+
130139
type connector struct {
131140
name string
132141
tracer sqliteh.Tracer
142+
makeLogger func() ConnLogger
133143
connInitFunc ConnInitFunc
134144
}
135145

@@ -152,12 +162,14 @@ func (p *connector) Connect(ctx context.Context) (driver.Conn, error) {
152162
}
153163
return nil, err
154164
}
155-
156165
c := &conn{
157166
db: db,
158167
tracer: p.tracer,
159168
id: sqliteh.TraceConnID(maxConnID.Add(1)),
160169
}
170+
if p.makeLogger != nil {
171+
c.logger = p.makeLogger()
172+
}
161173
if p.connInitFunc != nil {
162174
if err := p.connInitFunc(ctx, c); err != nil {
163175
db.Close()
@@ -179,6 +191,7 @@ type conn struct {
179191
db sqliteh.DB
180192
id sqliteh.TraceConnID
181193
tracer sqliteh.Tracer
194+
logger ConnLogger
182195
stmts map[string]*stmt // persisted statements
183196
txState txState
184197
readOnly bool
@@ -202,6 +215,7 @@ func (c *conn) Close() error {
202215
err := reserr(c.db, "Conn.Close", "", c.db.Close())
203216
return err
204217
}
218+
205219
func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
206220
persist := ctx.Value(persistQuery{}) != nil
207221
return c.prepare(ctx, query, persist)
@@ -341,6 +355,9 @@ func (c *conn) txInit(ctx context.Context) error {
341355
return err
342356
}
343357
} else {
358+
if c.logger != nil {
359+
c.logger.Begin()
360+
}
344361
// TODO(crawshaw): offer BEGIN DEFERRED (and BEGIN CONCURRENT?)
345362
// semantics via a context annotation function.
346363
if err := c.execInternal(ctx, "BEGIN IMMEDIATE"); err != nil {
@@ -351,15 +368,16 @@ func (c *conn) txInit(ctx context.Context) error {
351368
}
352369

353370
func (c *conn) txEnd(ctx context.Context, endStmt string) error {
354-
state, readOnly := c.txState, c.readOnly
355-
c.txState = txStateNone
356-
c.readOnly = false
357-
if state != txStateBegun {
371+
defer func() {
372+
c.txState = txStateNone
373+
c.readOnly = false
374+
}()
375+
if c.txState != txStateBegun {
358376
return nil
359377
}
360378

361379
err := c.execInternal(context.Background(), endStmt)
362-
if readOnly {
380+
if c.readOnly {
363381
if err2 := c.execInternal(ctx, "PRAGMA query_only=false"); err == nil {
364382
err = err2
365383
}
@@ -377,10 +395,14 @@ func (tx *connTx) Commit() error {
377395
return ErrClosed
378396
}
379397

398+
readonly := tx.conn.readOnly
380399
err := tx.conn.txEnd(context.Background(), "COMMIT")
381400
if tx.conn.tracer != nil {
382401
tx.conn.tracer.Commit(tx.conn.id, err)
383402
}
403+
if tx.conn.logger != nil && !readonly {
404+
tx.conn.logger.Commit(err)
405+
}
384406
return err
385407
}
386408

@@ -390,10 +412,14 @@ func (tx *connTx) Rollback() error {
390412
return ErrClosed
391413
}
392414

415+
readonly := tx.conn.readOnly
393416
err := tx.conn.txEnd(context.Background(), "ROLLBACK")
394417
if tx.conn.tracer != nil {
395418
tx.conn.tracer.Rollback(tx.conn.id, err)
396419
}
420+
if tx.conn.logger != nil && !readonly {
421+
tx.conn.logger.Rollback()
422+
}
397423
return err
398424
}
399425

@@ -490,6 +516,9 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
490516
if err := s.bindAll(args); err != nil {
491517
return nil, s.reserr("Stmt.Exec(Bind)", err)
492518
}
519+
if s.conn.logger != nil && !s.conn.readOnly {
520+
s.conn.logger.Statement(s.stmt.ExpandedSQL())
521+
}
493522

494523
if ctx.Value(queryCancelKey{}) != nil {
495524
done := make(chan struct{})
@@ -1068,3 +1097,21 @@ func WithQueryCancel(ctx context.Context) context.Context {
10681097

10691098
// queryCancelKey is a context key for query context enforcement.
10701099
type queryCancelKey struct{}
1100+
1101+
// ConnLogger is implemented by the caller to support statement-level logging for
1102+
// write transactions. Only Exec calls are logged, not Query calls, as this is
1103+
// intended as a mechanism to replay failed transactions.
1104+
type ConnLogger interface {
1105+
// Begin is called when a writable transaction is opened.
1106+
Begin()
1107+
1108+
// Statement is called with evaluated SQL when a statement is executed.
1109+
Statement(sql string)
1110+
1111+
// Commit is called after a commit statement, with the error resulting
1112+
// from the attempted commit.
1113+
Commit(error)
1114+
1115+
// Rollback is called after a rollback statement.
1116+
Rollback()
1117+
}

sqlite_test.go

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"os"
1414
"reflect"
1515
"runtime"
16+
"slices"
1617
"strconv"
1718
"strings"
1819
"sync"
@@ -1354,3 +1355,186 @@ func TestDisableFunction(t *testing.T) {
13541355
t.Fatal("Attempting to use the LOWER function after disabling should have failed")
13551356
}
13561357
}
1358+
1359+
type connLogger struct {
1360+
ch chan []string
1361+
statements []string
1362+
panicOnUse bool
1363+
}
1364+
1365+
func (cl *connLogger) Begin() {
1366+
if cl.panicOnUse {
1367+
panic("unexpected connLogger.Begin()")
1368+
}
1369+
cl.statements = nil
1370+
}
1371+
1372+
func (cl *connLogger) Statement(s string) {
1373+
if cl.panicOnUse {
1374+
panic("unexpected connLogger.Statement: " + s)
1375+
}
1376+
cl.statements = append(cl.statements, s)
1377+
}
1378+
1379+
func (cl *connLogger) Commit(err error) {
1380+
if cl.panicOnUse {
1381+
panic("unexpected connLogger.Commit()")
1382+
}
1383+
if err != nil {
1384+
return
1385+
}
1386+
cl.ch <- cl.statements
1387+
}
1388+
1389+
func (cl *connLogger) Rollback() {
1390+
if cl.panicOnUse {
1391+
panic("unexpected connLogger.Rollback()")
1392+
}
1393+
cl.statements = nil
1394+
}
1395+
1396+
func TestConnLogger_writable(t *testing.T) {
1397+
for _, commit := range []bool{true, false} {
1398+
doneStatement := "ROLLBACK"
1399+
if commit {
1400+
doneStatement = "COMMIT"
1401+
}
1402+
t.Run(doneStatement, func(t *testing.T) {
1403+
ctx := context.Background()
1404+
ch := make(chan []string, 1)
1405+
txl := connLogger{ch: ch}
1406+
makeLogger := func() ConnLogger { return &txl }
1407+
db := sql.OpenDB(ConnectorWithLogger("file:"+t.TempDir()+"/test.db", nil, nil, makeLogger))
1408+
configDB(t, db)
1409+
1410+
tx, err := db.BeginTx(ctx, nil)
1411+
if err != nil {
1412+
t.Fatal(err)
1413+
}
1414+
if _, err := tx.Exec("CREATE TABLE T (x INTEGER)"); err != nil {
1415+
t.Fatal(err)
1416+
}
1417+
if _, err := tx.Exec("INSERT INTO T VALUES (?)", 1); err != nil {
1418+
t.Fatal(err)
1419+
}
1420+
if _, err := tx.Query("SELECT x FROM T"); err != nil {
1421+
t.Fatal(err)
1422+
}
1423+
done := tx.Rollback
1424+
if commit {
1425+
done = tx.Commit
1426+
}
1427+
if err := done(); err != nil {
1428+
t.Fatal(err)
1429+
}
1430+
if !commit {
1431+
select {
1432+
case got := <-ch:
1433+
t.Errorf("unexpectedly logged statements for rollback:\n%s", strings.Join(got, "\n"))
1434+
default:
1435+
return
1436+
}
1437+
}
1438+
1439+
want := []string{
1440+
"BEGIN IMMEDIATE",
1441+
"CREATE TABLE T (x INTEGER)",
1442+
"INSERT INTO T VALUES (1)",
1443+
doneStatement,
1444+
}
1445+
select {
1446+
case got := <-ch:
1447+
if !slices.Equal(got, want) {
1448+
t.Errorf("unexpected log statements. got:\n%s\n\nwant:\n%s", strings.Join(got, "\n"), strings.Join(want, "\n"))
1449+
}
1450+
default:
1451+
t.Fatal("no logged statements after commit")
1452+
}
1453+
})
1454+
}
1455+
}
1456+
1457+
func TestConnLogger_commit_error(t *testing.T) {
1458+
ctx := context.Background()
1459+
ch := make(chan []string, 1)
1460+
txl := connLogger{ch: ch}
1461+
makeLogger := func() ConnLogger { return &txl }
1462+
db := sql.OpenDB(ConnectorWithLogger("file:"+t.TempDir()+"/test.db", nil, nil, makeLogger))
1463+
configDB(t, db)
1464+
1465+
if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
1466+
t.Fatal(err)
1467+
}
1468+
if _, err := db.Exec("CREATE TABLE A (x INTEGER PRIMARY KEY)"); err != nil {
1469+
t.Fatal(err)
1470+
}
1471+
if _, err := db.Exec("CREATE TABLE B (x INTEGER REFERENCES A(X) DEFERRABLE INITIALLY DEFERRED)"); err != nil {
1472+
t.Fatal(err)
1473+
}
1474+
1475+
tx, err := db.BeginTx(ctx, nil)
1476+
if err != nil {
1477+
t.Fatal(err)
1478+
}
1479+
if _, err := tx.Exec("INSERT INTO B VALUES (?)", 1); err != nil {
1480+
t.Fatal(err)
1481+
}
1482+
if err := tx.Commit(); err == nil {
1483+
t.Fatal("expected Commit to error, but didn't")
1484+
}
1485+
select {
1486+
case got := <-ch:
1487+
t.Errorf("unexpectedly logged statements for errored commit:\n%s", strings.Join(got, "\n"))
1488+
default:
1489+
return
1490+
}
1491+
}
1492+
1493+
func TestConnLogger_read_tx(t *testing.T) {
1494+
ctx := context.Background()
1495+
ch := make(chan []string, 1)
1496+
txl := connLogger{ch: ch}
1497+
makeLogger := func() ConnLogger { return &txl }
1498+
db := sql.OpenDB(ConnectorWithLogger("file:"+t.TempDir()+"/test.db", nil, nil, makeLogger))
1499+
configDB(t, db)
1500+
1501+
tx, err := db.BeginTx(ctx, nil)
1502+
if err != nil {
1503+
t.Fatal(err)
1504+
}
1505+
if _, err := tx.Exec("CREATE TABLE T (x INTEGER)"); err != nil {
1506+
t.Fatal(err)
1507+
}
1508+
if _, err := tx.Exec("INSERT INTO T VALUES (?)", 1); err != nil {
1509+
t.Fatal(err)
1510+
}
1511+
if err := tx.Commit(); err != nil {
1512+
t.Fatal(err)
1513+
}
1514+
select {
1515+
case got := <-ch:
1516+
if len(got) == 0 {
1517+
t.Errorf("expected logged statements for write tx")
1518+
}
1519+
default:
1520+
t.Errorf("expected logged statements for write tx")
1521+
}
1522+
1523+
txl.panicOnUse = true
1524+
for _, commit := range []bool{true, false} {
1525+
rx, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
1526+
if err != nil {
1527+
t.Fatal(err)
1528+
}
1529+
if _, err := rx.Query("SELECT x FROM T"); err != nil {
1530+
t.Fatal(err)
1531+
}
1532+
done := rx.Rollback
1533+
if commit {
1534+
done = rx.Commit
1535+
}
1536+
if err := done(); err != nil {
1537+
t.Fatal(err)
1538+
}
1539+
}
1540+
}

0 commit comments

Comments
 (0)