Skip to content

Commit 5fc8c87

Browse files
serprexlance6716
andauthored
driver: context support (#997)
* wip * poc: context support * move context management into driver Close instead of SetTimeout hack implement more interfaces, with assertions * lint * avoid excessive spinning when context canceled & waiting on contexts channel to close * compressedHeader tweak * mysql doesn't have named parameters --------- Co-authored-by: lance6716 <[email protected]>
1 parent 08630ce commit 5fc8c87

File tree

3 files changed

+192
-20
lines changed

3 files changed

+192
-20
lines changed

client/conn.go

+15
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,21 @@ func (c *Conn) Begin() error {
382382
return errors.Trace(err)
383383
}
384384

385+
func (c *Conn) BeginTx(readOnly bool, txIsolation string) error {
386+
if txIsolation != "" {
387+
if _, err := c.exec("SET TRANSACTION ISOLATION LEVEL " + txIsolation); err != nil {
388+
return errors.Trace(err)
389+
}
390+
}
391+
var err error
392+
if readOnly {
393+
_, err = c.exec("START TRANSACTION READ ONLY")
394+
} else {
395+
_, err = c.exec("START TRANSACTION")
396+
}
397+
return errors.Trace(err)
398+
}
399+
385400
func (c *Conn) Commit() error {
386401
_, err := c.exec("COMMIT")
387402
return errors.Trace(err)

driver/driver.go

+175-18
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
package driver
44

55
import (
6+
"context"
67
"crypto/tls"
78
"database/sql"
89
sqldriver "database/sql/driver"
@@ -21,6 +22,23 @@ import (
2122
"github.com/pingcap/errors"
2223
)
2324

25+
var (
26+
_ sqldriver.Driver = &driver{}
27+
_ sqldriver.DriverContext = &driver{}
28+
_ sqldriver.Connector = &connInfo{}
29+
_ sqldriver.NamedValueChecker = &conn{}
30+
_ sqldriver.Validator = &conn{}
31+
_ sqldriver.Conn = &conn{}
32+
_ sqldriver.Pinger = &conn{}
33+
_ sqldriver.ConnBeginTx = &conn{}
34+
_ sqldriver.ConnPrepareContext = &conn{}
35+
_ sqldriver.ExecerContext = &conn{}
36+
_ sqldriver.QueryerContext = &conn{}
37+
_ sqldriver.Stmt = &stmt{}
38+
_ sqldriver.StmtExecContext = &stmt{}
39+
_ sqldriver.StmtQueryContext = &stmt{}
40+
)
41+
2442
var customTLSMutex sync.Mutex
2543

2644
// Map of dsn address (makes more sense than full dsn?) to tls Config
@@ -101,16 +119,18 @@ func parseDSN(dsn string) (connInfo, error) {
101119
// Open takes a supplied DSN string and opens a connection
102120
// See ParseDSN for more information on the form of the DSN
103121
func (d driver) Open(dsn string) (sqldriver.Conn, error) {
104-
var (
105-
c *client.Conn
106-
// by default database/sql driver retries will be enabled
107-
retries = true
108-
)
109-
110122
ci, err := parseDSN(dsn)
111123
if err != nil {
112124
return nil, err
113125
}
126+
return ci.Connect(context.Background())
127+
}
128+
129+
func (ci connInfo) Connect(ctx context.Context) (sqldriver.Conn, error) {
130+
var c *client.Conn
131+
var err error
132+
// by default database/sql driver retries will be enabled
133+
retries := true
114134

115135
if ci.standardDSN {
116136
var timeout time.Duration
@@ -159,45 +179,86 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) {
159179
}
160180
}
161181

162-
if timeout > 0 {
163-
c, err = client.ConnectWithTimeout(ci.addr, ci.user, ci.password, ci.db, timeout, configuredOptions...)
164-
} else {
165-
c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db, configuredOptions...)
182+
if timeout <= 0 {
183+
timeout = 10 * time.Second
166184
}
185+
c, err = client.ConnectWithContext(ctx, ci.addr, ci.user, ci.password, ci.db, timeout, configuredOptions...)
167186
} else {
168187
// No more processing here. Let's only support url parameters with the newer style DSN
169-
c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db)
188+
c, err = client.ConnectWithContext(ctx, ci.addr, ci.user, ci.password, ci.db, 10*time.Second)
170189
}
171190
if err != nil {
172191
return nil, err
173192
}
174193

194+
contexts := make(chan context.Context)
195+
go func() {
196+
ctx := context.Background()
197+
for {
198+
var ok bool
199+
select {
200+
case <-ctx.Done():
201+
ctx = context.Background()
202+
_ = c.Conn.Close()
203+
case ctx, ok = <-contexts:
204+
if !ok {
205+
return
206+
}
207+
}
208+
}
209+
}()
210+
175211
// if retries are 'on' then return sqldriver.ErrBadConn which will trigger up to 3
176212
// retries by the database/sql package. If retries are 'off' then we'll return
177213
// the native go-mysql-org/go-mysql 'mysql.ErrBadConn' erorr which will prevent a retry.
178214
// In this case the sqldriver.Validator interface is implemented and will return
179215
// false for IsValid() signaling the connection is bad and should be discarded.
180-
return &conn{Conn: c, state: &state{valid: true, useStdLibErrors: retries}}, nil
216+
return &conn{
217+
Conn: c,
218+
state: &state{contexts: contexts, valid: true, useStdLibErrors: retries},
219+
}, nil
181220
}
182221

183-
type CheckNamedValueFunc func(*sqldriver.NamedValue) error
222+
func (d driver) OpenConnector(name string) (sqldriver.Connector, error) {
223+
return parseDSN(name)
224+
}
184225

185-
var (
186-
_ sqldriver.NamedValueChecker = &conn{}
187-
_ sqldriver.Validator = &conn{}
188-
)
226+
func (ci connInfo) Driver() sqldriver.Driver {
227+
return driver{}
228+
}
229+
230+
type CheckNamedValueFunc func(*sqldriver.NamedValue) error
189231

190232
type state struct {
191-
valid bool
233+
contexts chan context.Context
234+
valid bool
192235
// when true, the driver connection will return ErrBadConn from the golang Standard Library
193236
useStdLibErrors bool
194237
}
195238

239+
func (s *state) watchCtx(ctx context.Context) func() {
240+
s.contexts <- ctx
241+
return func() {
242+
s.contexts <- context.Background()
243+
}
244+
}
245+
246+
func (s *state) Close() {
247+
if s.contexts != nil {
248+
close(s.contexts)
249+
s.contexts = nil
250+
}
251+
}
252+
196253
type conn struct {
197254
*client.Conn
198255
state *state
199256
}
200257

258+
func (c *conn) watchCtx(ctx context.Context) func() {
259+
return c.state.watchCtx(ctx)
260+
}
261+
201262
func (c *conn) CheckNamedValue(nv *sqldriver.NamedValue) error {
202263
for _, nvChecker := range namedValueCheckers {
203264
err := nvChecker(nv)
@@ -220,6 +281,17 @@ func (c *conn) IsValid() bool {
220281
return c.state.valid
221282
}
222283

284+
func (c *conn) Ping(ctx context.Context) error {
285+
defer c.watchCtx(ctx)()
286+
if err := c.Conn.Ping(); err != nil {
287+
if err == context.DeadlineExceeded || err == context.Canceled {
288+
return err
289+
}
290+
return sqldriver.ErrBadConn
291+
}
292+
return nil
293+
}
294+
223295
func (c *conn) Prepare(query string) (sqldriver.Stmt, error) {
224296
st, err := c.Conn.Prepare(query)
225297
if err != nil {
@@ -229,7 +301,13 @@ func (c *conn) Prepare(query string) (sqldriver.Stmt, error) {
229301
return &stmt{Stmt: st, connectionState: c.state}, nil
230302
}
231303

304+
func (c *conn) PrepareContext(ctx context.Context, query string) (sqldriver.Stmt, error) {
305+
defer c.watchCtx(ctx)()
306+
return c.Prepare(query)
307+
}
308+
232309
func (c *conn) Close() error {
310+
c.state.Close()
233311
return c.Conn.Close()
234312
}
235313

@@ -242,6 +320,29 @@ func (c *conn) Begin() (sqldriver.Tx, error) {
242320
return &tx{c.Conn}, nil
243321
}
244322

323+
var isolationLevelTransactionIsolation = map[sql.IsolationLevel]string{
324+
sql.LevelDefault: "",
325+
sql.LevelRepeatableRead: "REPEATABLE READ",
326+
sql.LevelReadCommitted: "READ COMMITTED",
327+
sql.LevelReadUncommitted: "READ UNCOMMITTED",
328+
sql.LevelSerializable: "SERIALIZABLE",
329+
}
330+
331+
func (c *conn) BeginTx(ctx context.Context, opts sqldriver.TxOptions) (sqldriver.Tx, error) {
332+
defer c.watchCtx(ctx)()
333+
334+
isolation := sql.IsolationLevel(opts.Isolation)
335+
txIsolation, ok := isolationLevelTransactionIsolation[isolation]
336+
if !ok {
337+
return nil, fmt.Errorf("invalid mysql transaction isolation level %s", isolation)
338+
}
339+
err := c.Conn.BeginTx(opts.ReadOnly, txIsolation)
340+
if err != nil {
341+
return nil, errors.Trace(err)
342+
}
343+
return &tx{c.Conn}, nil
344+
}
345+
245346
func buildArgs(args []sqldriver.Value) []interface{} {
246347
a := make([]interface{}, len(args))
247348

@@ -252,6 +353,16 @@ func buildArgs(args []sqldriver.Value) []interface{} {
252353
return a
253354
}
254355

356+
func buildNamedArgs(args []sqldriver.NamedValue) []interface{} {
357+
a := make([]interface{}, len(args))
358+
359+
for i, arg := range args {
360+
a[i] = arg.Value
361+
}
362+
363+
return a
364+
}
365+
255366
func (st *state) replyError(err error) error {
256367
isBadConnection := mysql.ErrorEqual(err, mysql.ErrBadConn)
257368

@@ -275,6 +386,16 @@ func (c *conn) Exec(query string, args []sqldriver.Value) (sqldriver.Result, err
275386
return &result{r}, nil
276387
}
277388

389+
func (c *conn) ExecContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Result, error) {
390+
defer c.watchCtx(ctx)()
391+
a := buildNamedArgs(args)
392+
r, err := c.Conn.Execute(query, a...)
393+
if err != nil {
394+
return nil, c.state.replyError(err)
395+
}
396+
return &result{r}, nil
397+
}
398+
278399
func (c *conn) Query(query string, args []sqldriver.Value) (sqldriver.Rows, error) {
279400
a := buildArgs(args)
280401
r, err := c.Conn.Execute(query, a...)
@@ -284,11 +405,25 @@ func (c *conn) Query(query string, args []sqldriver.Value) (sqldriver.Rows, erro
284405
return newRows(r.Resultset)
285406
}
286407

408+
func (c *conn) QueryContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Rows, error) {
409+
defer c.watchCtx(ctx)()
410+
a := buildNamedArgs(args)
411+
r, err := c.Conn.Execute(query, a...)
412+
if err != nil {
413+
return nil, c.state.replyError(err)
414+
}
415+
return newRows(r.Resultset)
416+
}
417+
287418
type stmt struct {
288419
*client.Stmt
289420
connectionState *state
290421
}
291422

423+
func (s *stmt) watchCtx(ctx context.Context) func() {
424+
return s.connectionState.watchCtx(ctx)
425+
}
426+
292427
func (s *stmt) Close() error {
293428
return s.Stmt.Close()
294429
}
@@ -306,6 +441,17 @@ func (s *stmt) Exec(args []sqldriver.Value) (sqldriver.Result, error) {
306441
return &result{r}, nil
307442
}
308443

444+
func (s *stmt) ExecContext(ctx context.Context, args []sqldriver.NamedValue) (sqldriver.Result, error) {
445+
defer s.watchCtx(ctx)()
446+
447+
a := buildNamedArgs(args)
448+
r, err := s.Stmt.Execute(a...)
449+
if err != nil {
450+
return nil, s.connectionState.replyError(err)
451+
}
452+
return &result{r}, nil
453+
}
454+
309455
func (s *stmt) Query(args []sqldriver.Value) (sqldriver.Rows, error) {
310456
a := buildArgs(args)
311457
r, err := s.Stmt.Execute(a...)
@@ -315,6 +461,17 @@ func (s *stmt) Query(args []sqldriver.Value) (sqldriver.Rows, error) {
315461
return newRows(r.Resultset)
316462
}
317463

464+
func (s *stmt) QueryContext(ctx context.Context, args []sqldriver.NamedValue) (sqldriver.Rows, error) {
465+
defer s.watchCtx(ctx)()
466+
467+
a := buildNamedArgs(args)
468+
r, err := s.Stmt.Execute(a...)
469+
if err != nil {
470+
return nil, s.connectionState.replyError(err)
471+
}
472+
return newRows(r.Resultset)
473+
}
474+
318475
type tx struct {
319476
*client.Conn
320477
}

packet/conn.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ func (c *Conn) writeCompressed(data []byte) (n int, err error) {
360360
var (
361361
compressedLength, uncompressedLength int
362362
payload *bytes.Buffer
363-
compressedHeader = make([]byte, 7)
363+
compressedHeader [7]byte
364364
)
365365

366366
if len(data) > MinCompressionLength {
@@ -406,7 +406,7 @@ func (c *Conn) writeCompressed(data []byte) (n int, err error) {
406406
compressedHeader[4] = byte(uncompressedLength)
407407
compressedHeader[5] = byte(uncompressedLength >> 8)
408408
compressedHeader[6] = byte(uncompressedLength >> 16)
409-
if _, err = compressedPacket.Write(compressedHeader); err != nil {
409+
if _, err = compressedPacket.Write(compressedHeader[:]); err != nil {
410410
return 0, err
411411
}
412412
c.CompressedSequence++

0 commit comments

Comments
 (0)