Skip to content

Commit 2ebf665

Browse files
alisdairknyar
andcommitted
sqlite: add optional ConnLogger
ConnLogger can be used to log executed statements for a connection. The interface captures events for Begin, Exec, Commit, and Rollback calls. Updates tailscale/corp#33577 Co-authored-by: Anton Tolchanov <[email protected]>
1 parent 3a6395a commit 2ebf665

File tree

2 files changed

+233
-6
lines changed

2 files changed

+233
-6
lines changed

sqlite.go

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,19 @@ func Connector(sqliteURI string, connInitFunc ConnInitFunc, tracer sqliteh.Trace
127127
}
128128
}
129129

130+
func ConnectorWithLogger(sqliteURI string, connInitFunc ConnInitFunc, tracer sqliteh.Tracer, makeLogger func() ConnLogger) driver.Connector {
131+
return &connector{
132+
name: sqliteURI,
133+
tracer: tracer,
134+
makeLogger: makeLogger,
135+
connInitFunc: connInitFunc,
136+
}
137+
}
138+
130139
type connector struct {
131140
name string
132141
tracer sqliteh.Tracer
142+
makeLogger func() ConnLogger
133143
connInitFunc ConnInitFunc
134144
}
135145

@@ -152,12 +162,14 @@ func (p *connector) Connect(ctx context.Context) (driver.Conn, error) {
152162
}
153163
return nil, err
154164
}
155-
156165
c := &conn{
157166
db: db,
158167
tracer: p.tracer,
159168
id: sqliteh.TraceConnID(maxConnID.Add(1)),
160169
}
170+
if p.makeLogger != nil {
171+
c.logger = p.makeLogger()
172+
}
161173
if p.connInitFunc != nil {
162174
if err := p.connInitFunc(ctx, c); err != nil {
163175
db.Close()
@@ -179,6 +191,7 @@ type conn struct {
179191
db sqliteh.DB
180192
id sqliteh.TraceConnID
181193
tracer sqliteh.Tracer
194+
logger ConnLogger
182195
stmts map[string]*stmt // persisted statements
183196
txState txState
184197
readOnly bool
@@ -202,6 +215,7 @@ func (c *conn) Close() error {
202215
err := reserr(c.db, "Conn.Close", "", c.db.Close())
203216
return err
204217
}
218+
205219
func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
206220
persist := ctx.Value(persistQuery{}) != nil
207221
return c.prepare(ctx, query, persist)
@@ -341,6 +355,9 @@ func (c *conn) txInit(ctx context.Context) error {
341355
return err
342356
}
343357
} else {
358+
if c.logger != nil {
359+
c.logger.Begin()
360+
}
344361
// TODO(crawshaw): offer BEGIN DEFERRED (and BEGIN CONCURRENT?)
345362
// semantics via a context annotation function.
346363
if err := c.execInternal(ctx, "BEGIN IMMEDIATE"); err != nil {
@@ -351,15 +368,16 @@ func (c *conn) txInit(ctx context.Context) error {
351368
}
352369

353370
func (c *conn) txEnd(ctx context.Context, endStmt string) error {
354-
state, readOnly := c.txState, c.readOnly
355-
c.txState = txStateNone
356-
c.readOnly = false
357-
if state != txStateBegun {
371+
defer func() {
372+
c.txState = txStateNone
373+
c.readOnly = false
374+
}()
375+
if c.txState != txStateBegun {
358376
return nil
359377
}
360378

361379
err := c.execInternal(context.Background(), endStmt)
362-
if readOnly {
380+
if c.readOnly {
363381
if err2 := c.execInternal(ctx, "PRAGMA query_only=false"); err == nil {
364382
err = err2
365383
}
@@ -377,10 +395,14 @@ func (tx *connTx) Commit() error {
377395
return ErrClosed
378396
}
379397

398+
readonly := tx.conn.readOnly
380399
err := tx.conn.txEnd(context.Background(), "COMMIT")
381400
if tx.conn.tracer != nil {
382401
tx.conn.tracer.Commit(tx.conn.id, err)
383402
}
403+
if tx.conn.logger != nil && err == nil && !readonly {
404+
tx.conn.logger.Commit()
405+
}
384406
return err
385407
}
386408

@@ -390,10 +412,14 @@ func (tx *connTx) Rollback() error {
390412
return ErrClosed
391413
}
392414

415+
readonly := tx.conn.readOnly
393416
err := tx.conn.txEnd(context.Background(), "ROLLBACK")
394417
if tx.conn.tracer != nil {
395418
tx.conn.tracer.Rollback(tx.conn.id, err)
396419
}
420+
if tx.conn.logger != nil && !readonly {
421+
tx.conn.logger.Rollback()
422+
}
397423
return err
398424
}
399425

@@ -490,6 +516,9 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
490516
if err := s.bindAll(args); err != nil {
491517
return nil, s.reserr("Stmt.Exec(Bind)", err)
492518
}
519+
if s.conn.logger != nil && !s.conn.readOnly {
520+
s.conn.logger.Statement(s.stmt.ExpandedSQL())
521+
}
493522

494523
if ctx.Value(queryCancelKey{}) != nil {
495524
done := make(chan struct{})
@@ -1068,3 +1097,20 @@ func WithQueryCancel(ctx context.Context) context.Context {
10681097

10691098
// queryCancelKey is a context key for query context enforcement.
10701099
type queryCancelKey struct{}
1100+
1101+
// ConnLogger is implemented by the caller to support statement-level logging for
1102+
// write transactions. Only Exec calls are logged, not Query calls, as this is
1103+
// intended as a mechanism to replay failed transactions.
1104+
type ConnLogger interface {
1105+
// Begin is called when a writable transaction is opened.
1106+
Begin()
1107+
1108+
// Statement is called with evaluated SQL when a statement is executed.
1109+
Statement(sql string)
1110+
1111+
// Commit is called when a transaction successfully commits.
1112+
Commit()
1113+
1114+
// Rollback is called when a transaction is rolled back.
1115+
Rollback()
1116+
}

sqlite_test.go

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"os"
1414
"reflect"
1515
"runtime"
16+
"slices"
1617
"strconv"
1718
"strings"
1819
"sync"
@@ -1354,3 +1355,183 @@ func TestDisableFunction(t *testing.T) {
13541355
t.Fatal("Attempting to use the LOWER function after disabling should have failed")
13551356
}
13561357
}
1358+
1359+
type connLogger struct {
1360+
ch chan []string
1361+
statements []string
1362+
panicOnUse bool
1363+
}
1364+
1365+
func (cl *connLogger) Begin() {
1366+
if cl.panicOnUse {
1367+
panic("unexpected connLogger.Begin()")
1368+
}
1369+
cl.statements = nil
1370+
}
1371+
1372+
func (cl *connLogger) Statement(s string) {
1373+
if cl.panicOnUse {
1374+
panic("unexpected connLogger.Statement: " + s)
1375+
}
1376+
cl.statements = append(cl.statements, s)
1377+
}
1378+
1379+
func (cl *connLogger) Commit() {
1380+
if cl.panicOnUse {
1381+
panic("unexpected connLogger.Commit()")
1382+
}
1383+
cl.ch <- cl.statements
1384+
}
1385+
1386+
func (cl *connLogger) Rollback() {
1387+
if cl.panicOnUse {
1388+
panic("unexpected connLogger.Rollback()")
1389+
}
1390+
cl.statements = nil
1391+
}
1392+
1393+
func TestConnLogger_writable(t *testing.T) {
1394+
for _, commit := range []bool{true, false} {
1395+
doneStatement := "ROLLBACK"
1396+
if commit {
1397+
doneStatement = "COMMIT"
1398+
}
1399+
t.Run(doneStatement, func(t *testing.T) {
1400+
ctx := context.Background()
1401+
ch := make(chan []string, 1)
1402+
txl := connLogger{ch: ch}
1403+
makeLogger := func() ConnLogger { return &txl }
1404+
db := sql.OpenDB(ConnectorWithLogger("file:"+t.TempDir()+"/test.db", nil, nil, makeLogger))
1405+
configDB(t, db)
1406+
1407+
tx, err := db.BeginTx(ctx, nil)
1408+
if err != nil {
1409+
t.Fatal(err)
1410+
}
1411+
if _, err := tx.Exec("CREATE TABLE T (x INTEGER)"); err != nil {
1412+
t.Fatal(err)
1413+
}
1414+
if _, err := tx.Exec("INSERT INTO T VALUES (?)", 1); err != nil {
1415+
t.Fatal(err)
1416+
}
1417+
if _, err := tx.Query("SELECT x FROM T"); err != nil {
1418+
t.Fatal(err)
1419+
}
1420+
done := tx.Rollback
1421+
if commit {
1422+
done = tx.Commit
1423+
}
1424+
if err := done(); err != nil {
1425+
t.Fatal(err)
1426+
}
1427+
if !commit {
1428+
select {
1429+
case got := <-ch:
1430+
t.Errorf("unexpectedly logged statements for rollback:\n%s", strings.Join(got, "\n"))
1431+
default:
1432+
return
1433+
}
1434+
}
1435+
1436+
want := []string{
1437+
"BEGIN IMMEDIATE",
1438+
"CREATE TABLE T (x INTEGER)",
1439+
"INSERT INTO T VALUES (1)",
1440+
doneStatement,
1441+
}
1442+
select {
1443+
case got := <-ch:
1444+
if !slices.Equal(got, want) {
1445+
t.Errorf("unexpected log statements. got:\n%s\n\nwant:\n%s", strings.Join(got, "\n"), strings.Join(want, "\n"))
1446+
}
1447+
default:
1448+
t.Fatal("no logged statements after commit")
1449+
}
1450+
})
1451+
}
1452+
}
1453+
1454+
func TestConnLogger_commit_error_retry(t *testing.T) {
1455+
ctx := context.Background()
1456+
ch := make(chan []string, 1)
1457+
txl := connLogger{ch: ch}
1458+
makeLogger := func() ConnLogger { return &txl }
1459+
db := sql.OpenDB(ConnectorWithLogger("file:"+t.TempDir()+"/test.db", nil, nil, makeLogger))
1460+
configDB(t, db)
1461+
1462+
if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
1463+
t.Fatal(err)
1464+
}
1465+
if _, err := db.Exec("CREATE TABLE A (x INTEGER PRIMARY KEY)"); err != nil {
1466+
t.Fatal(err)
1467+
}
1468+
if _, err := db.Exec("CREATE TABLE B (x INTEGER REFERENCES A(X) DEFERRABLE INITIALLY DEFERRED)"); err != nil {
1469+
t.Fatal(err)
1470+
}
1471+
1472+
tx, err := db.BeginTx(ctx, nil)
1473+
if err != nil {
1474+
t.Fatal(err)
1475+
}
1476+
if _, err := tx.Exec("INSERT INTO B VALUES (?)", 1); err != nil {
1477+
t.Fatal(err)
1478+
}
1479+
if err := tx.Commit(); err == nil {
1480+
t.Fatal("expected Commit to error, but didn't")
1481+
}
1482+
select {
1483+
case got := <-ch:
1484+
t.Errorf("unexpectedly logged statements for errored commit:\n%s", strings.Join(got, "\n"))
1485+
default:
1486+
return
1487+
}
1488+
}
1489+
1490+
func TestConnLogger_read_tx(t *testing.T) {
1491+
ctx := context.Background()
1492+
ch := make(chan []string, 1)
1493+
txl := connLogger{ch: ch}
1494+
makeLogger := func() ConnLogger { return &txl }
1495+
db := sql.OpenDB(ConnectorWithLogger("file:"+t.TempDir()+"/test.db", nil, nil, makeLogger))
1496+
configDB(t, db)
1497+
1498+
tx, err := db.BeginTx(ctx, nil)
1499+
if err != nil {
1500+
t.Fatal(err)
1501+
}
1502+
if _, err := tx.Exec("CREATE TABLE T (x INTEGER)"); err != nil {
1503+
t.Fatal(err)
1504+
}
1505+
if _, err := tx.Exec("INSERT INTO T VALUES (?)", 1); err != nil {
1506+
t.Fatal(err)
1507+
}
1508+
if err := tx.Commit(); err != nil {
1509+
t.Fatal(err)
1510+
}
1511+
select {
1512+
case got := <-ch:
1513+
if len(got) == 0 {
1514+
t.Errorf("expected logged statements for write tx")
1515+
}
1516+
default:
1517+
t.Errorf("expected logged statements for write tx")
1518+
}
1519+
1520+
txl.panicOnUse = true
1521+
for _, commit := range []bool{true, false} {
1522+
rx, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
1523+
if err != nil {
1524+
t.Fatal(err)
1525+
}
1526+
if _, err := rx.Query("SELECT x FROM T"); err != nil {
1527+
t.Fatal(err)
1528+
}
1529+
done := rx.Rollback
1530+
if commit {
1531+
done = rx.Commit
1532+
}
1533+
if err := done(); err != nil {
1534+
t.Fatal(err)
1535+
}
1536+
}
1537+
}

0 commit comments

Comments
 (0)