@@ -13,6 +13,7 @@ import (
1313 "os"
1414 "reflect"
1515 "runtime"
16+ "slices"
1617 "strconv"
1718 "strings"
1819 "sync"
@@ -1354,3 +1355,183 @@ func TestDisableFunction(t *testing.T) {
13541355 t .Fatal ("Attempting to use the LOWER function after disabling should have failed" )
13551356 }
13561357}
1358+
1359+ type connLogger struct {
1360+ ch chan []string
1361+ statements []string
1362+ panicOnUse bool
1363+ }
1364+
1365+ func (cl * connLogger ) Begin () {
1366+ if cl .panicOnUse {
1367+ panic ("unexpected connLogger.Begin()" )
1368+ }
1369+ cl .statements = nil
1370+ }
1371+
1372+ func (cl * connLogger ) Statement (s string ) {
1373+ if cl .panicOnUse {
1374+ panic ("unexpected connLogger.Statement: " + s )
1375+ }
1376+ cl .statements = append (cl .statements , s )
1377+ }
1378+
1379+ func (cl * connLogger ) Commit () {
1380+ if cl .panicOnUse {
1381+ panic ("unexpected connLogger.Commit()" )
1382+ }
1383+ cl .ch <- cl .statements
1384+ }
1385+
1386+ func (cl * connLogger ) Rollback () {
1387+ if cl .panicOnUse {
1388+ panic ("unexpected connLogger.Rollback()" )
1389+ }
1390+ cl .statements = nil
1391+ }
1392+
1393+ func TestConnLogger_writable (t * testing.T ) {
1394+ for _ , commit := range []bool {true , false } {
1395+ doneStatement := "ROLLBACK"
1396+ if commit {
1397+ doneStatement = "COMMIT"
1398+ }
1399+ t .Run (doneStatement , func (t * testing.T ) {
1400+ ctx := context .Background ()
1401+ ch := make (chan []string , 1 )
1402+ txl := connLogger {ch : ch }
1403+ makeLogger := func () ConnLogger { return & txl }
1404+ db := sql .OpenDB (ConnectorWithLogger ("file:" + t .TempDir ()+ "/test.db" , nil , nil , makeLogger ))
1405+ configDB (t , db )
1406+
1407+ tx , err := db .BeginTx (ctx , nil )
1408+ if err != nil {
1409+ t .Fatal (err )
1410+ }
1411+ if _ , err := tx .Exec ("CREATE TABLE T (x INTEGER)" ); err != nil {
1412+ t .Fatal (err )
1413+ }
1414+ if _ , err := tx .Exec ("INSERT INTO T VALUES (?)" , 1 ); err != nil {
1415+ t .Fatal (err )
1416+ }
1417+ if _ , err := tx .Query ("SELECT x FROM T" ); err != nil {
1418+ t .Fatal (err )
1419+ }
1420+ done := tx .Rollback
1421+ if commit {
1422+ done = tx .Commit
1423+ }
1424+ if err := done (); err != nil {
1425+ t .Fatal (err )
1426+ }
1427+ if ! commit {
1428+ select {
1429+ case got := <- ch :
1430+ t .Errorf ("unexpectedly logged statements for rollback:\n %s" , strings .Join (got , "\n " ))
1431+ default :
1432+ return
1433+ }
1434+ }
1435+
1436+ want := []string {
1437+ "BEGIN IMMEDIATE" ,
1438+ "CREATE TABLE T (x INTEGER)" ,
1439+ "INSERT INTO T VALUES (1)" ,
1440+ doneStatement ,
1441+ }
1442+ select {
1443+ case got := <- ch :
1444+ if ! slices .Equal (got , want ) {
1445+ t .Errorf ("unexpected log statements. got:\n %s\n \n want:\n %s" , strings .Join (got , "\n " ), strings .Join (want , "\n " ))
1446+ }
1447+ default :
1448+ t .Fatal ("no logged statements after commit" )
1449+ }
1450+ })
1451+ }
1452+ }
1453+
1454+ func TestConnLogger_commit_error_retry (t * testing.T ) {
1455+ ctx := context .Background ()
1456+ ch := make (chan []string , 1 )
1457+ txl := connLogger {ch : ch }
1458+ makeLogger := func () ConnLogger { return & txl }
1459+ db := sql .OpenDB (ConnectorWithLogger ("file:" + t .TempDir ()+ "/test.db" , nil , nil , makeLogger ))
1460+ configDB (t , db )
1461+
1462+ if _ , err := db .Exec ("PRAGMA foreign_keys = ON" ); err != nil {
1463+ t .Fatal (err )
1464+ }
1465+ if _ , err := db .Exec ("CREATE TABLE A (x INTEGER PRIMARY KEY)" ); err != nil {
1466+ t .Fatal (err )
1467+ }
1468+ if _ , err := db .Exec ("CREATE TABLE B (x INTEGER REFERENCES A(X) DEFERRABLE INITIALLY DEFERRED)" ); err != nil {
1469+ t .Fatal (err )
1470+ }
1471+
1472+ tx , err := db .BeginTx (ctx , nil )
1473+ if err != nil {
1474+ t .Fatal (err )
1475+ }
1476+ if _ , err := tx .Exec ("INSERT INTO B VALUES (?)" , 1 ); err != nil {
1477+ t .Fatal (err )
1478+ }
1479+ if err := tx .Commit (); err == nil {
1480+ t .Fatal ("expected Commit to error, but didn't" )
1481+ }
1482+ select {
1483+ case got := <- ch :
1484+ t .Errorf ("unexpectedly logged statements for errored commit:\n %s" , strings .Join (got , "\n " ))
1485+ default :
1486+ return
1487+ }
1488+ }
1489+
1490+ func TestConnLogger_read_tx (t * testing.T ) {
1491+ ctx := context .Background ()
1492+ ch := make (chan []string , 1 )
1493+ txl := connLogger {ch : ch }
1494+ makeLogger := func () ConnLogger { return & txl }
1495+ db := sql .OpenDB (ConnectorWithLogger ("file:" + t .TempDir ()+ "/test.db" , nil , nil , makeLogger ))
1496+ configDB (t , db )
1497+
1498+ tx , err := db .BeginTx (ctx , nil )
1499+ if err != nil {
1500+ t .Fatal (err )
1501+ }
1502+ if _ , err := tx .Exec ("CREATE TABLE T (x INTEGER)" ); err != nil {
1503+ t .Fatal (err )
1504+ }
1505+ if _ , err := tx .Exec ("INSERT INTO T VALUES (?)" , 1 ); err != nil {
1506+ t .Fatal (err )
1507+ }
1508+ if err := tx .Commit (); err != nil {
1509+ t .Fatal (err )
1510+ }
1511+ select {
1512+ case got := <- ch :
1513+ if len (got ) == 0 {
1514+ t .Errorf ("expected logged statements for write tx" )
1515+ }
1516+ default :
1517+ t .Errorf ("expected logged statements for write tx" )
1518+ }
1519+
1520+ txl .panicOnUse = true
1521+ for _ , commit := range []bool {true , false } {
1522+ rx , err := db .BeginTx (ctx , & sql.TxOptions {ReadOnly : true })
1523+ if err != nil {
1524+ t .Fatal (err )
1525+ }
1526+ if _ , err := rx .Query ("SELECT x FROM T" ); err != nil {
1527+ t .Fatal (err )
1528+ }
1529+ done := rx .Rollback
1530+ if commit {
1531+ done = rx .Commit
1532+ }
1533+ if err := done (); err != nil {
1534+ t .Fatal (err )
1535+ }
1536+ }
1537+ }
0 commit comments