Skip to content

Commit 1c3ba7d

Browse files
committed
move context management into driver
Close instead of SetTimeout hack implement more interfaces, with assertions
1 parent 3409818 commit 1c3ba7d

File tree

2 files changed

+137
-70
lines changed

2 files changed

+137
-70
lines changed

client/conn.go

+1-47
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,6 @@ type Conn struct {
3535
// Connection read and write timeouts to set on the connection
3636
ReadTimeout time.Duration
3737
WriteTimeout time.Duration
38-
contexts chan context.Context
39-
closed bool
4038

4139
// The buffer size to use in the packet connection
4240
BufferSize int
@@ -138,7 +136,6 @@ func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbNam
138136
c.password = password
139137
c.db = dbName
140138
c.proto = network
141-
c.contexts = make(chan context.Context)
142139

143140
// use default charset here, utf-8
144141
c.charset = mysql.DEFAULT_CHARSET
@@ -187,21 +184,6 @@ func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbNam
187184
}
188185
}
189186

190-
go func() {
191-
ctx := context.Background()
192-
for {
193-
var ok bool
194-
select {
195-
case <-ctx.Done():
196-
_ = c.Conn.SetDeadline(time.Unix(0, 0))
197-
case ctx, ok = <-c.contexts:
198-
if !ok {
199-
return
200-
}
201-
}
202-
}
203-
}()
204-
205187
return c, nil
206188
}
207189

@@ -226,19 +208,8 @@ func (c *Conn) handshake() error {
226208
return nil
227209
}
228210

229-
func (c *Conn) watchCtx(ctx context.Context) func() {
230-
c.contexts <- ctx
231-
return func() {
232-
c.contexts <- context.Background()
233-
}
234-
}
235-
236211
// Close directly closes the connection. Use Quit() to first send COM_QUIT to the server and then close the connection.
237212
func (c *Conn) Close() error {
238-
if !c.closed {
239-
close(c.contexts)
240-
c.closed = true
241-
}
242213
return c.Conn.Close()
243214
}
244215

@@ -338,11 +309,6 @@ func (c *Conn) Execute(command string, args ...interface{}) (*mysql.Result, erro
338309
}
339310
}
340311

341-
func (c *Conn) ExecuteContext(ctx context.Context, command string, args ...interface{}) (*mysql.Result, error) {
342-
defer c.watchCtx(ctx)
343-
return c.Execute(command, args...)
344-
}
345-
346312
// ExecuteMultiple will call perResultCallback for every result of the multiple queries
347313
// that are executed.
348314
//
@@ -397,11 +363,6 @@ func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCall
397363
return mysql.NewResult(rs), nil
398364
}
399365

400-
func (c *Conn) ExecuteMultipleContext(ctx context.Context, query string, perResultCallback ExecPerResultCallback) (*mysql.Result, error) {
401-
defer c.watchCtx(ctx)
402-
return c.ExecuteMultiple(query, perResultCallback)
403-
}
404-
405366
// ExecuteSelectStreaming will call perRowCallback for every row in resultset
406367
// WITHOUT saving any row data to Result.{Values/RawPkg/RowDatas} fields.
407368
// When given, perResultCallback will be called once per result
@@ -415,19 +376,12 @@ func (c *Conn) ExecuteSelectStreaming(command string, result *mysql.Result, perR
415376
return c.readResultStreaming(false, result, perRowCallback, perResultCallback)
416377
}
417378

418-
func (c *Conn) ExecuteSelectStreamingContext(ctx context.Context, command string, result *mysql.Result, perRowCallback SelectPerRowCallback, perResultCallback SelectPerResultCallback) error {
419-
defer c.watchCtx(ctx)
420-
return c.ExecuteSelectStreaming(command, result, perRowCallback, perResultCallback)
421-
}
422-
423379
func (c *Conn) Begin() error {
424380
_, err := c.exec("BEGIN")
425381
return errors.Trace(err)
426382
}
427383

428-
func (c *Conn) BeginTx(ctx context.Context, readOnly bool, txIsolation string) error {
429-
defer c.watchCtx(ctx)()
430-
384+
func (c *Conn) BeginTx(readOnly bool, txIsolation string) error {
431385
if txIsolation != "" {
432386
if _, err := c.exec("SET TRANSACTION ISOLATION LEVEL " + txIsolation); err != nil {
433387
return errors.Trace(err)

driver/driver.go

+136-23
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,23 @@ import (
2222
"github.com/pingcap/errors"
2323
)
2424

25+
var (
26+
_ sqldriver.Driver = &driver{}
27+
_ sqldriver.DriverContext = &driver{}
28+
_ sqldriver.Connector = &connInfo{}
29+
_ sqldriver.NamedValueChecker = &conn{}
30+
_ sqldriver.Validator = &conn{}
31+
_ sqldriver.Conn = &conn{}
32+
_ sqldriver.Pinger = &conn{}
33+
_ sqldriver.ConnBeginTx = &conn{}
34+
_ sqldriver.ConnPrepareContext = &conn{}
35+
_ sqldriver.ExecerContext = &conn{}
36+
_ sqldriver.QueryerContext = &conn{}
37+
_ sqldriver.Stmt = &stmt{}
38+
_ sqldriver.StmtExecContext = &stmt{}
39+
_ sqldriver.StmtQueryContext = &stmt{}
40+
)
41+
2542
var customTLSMutex sync.Mutex
2643

2744
// Map of dsn address (makes more sense than full dsn?) to tls Config
@@ -102,16 +119,19 @@ func parseDSN(dsn string) (connInfo, error) {
102119
// Open takes a supplied DSN string and opens a connection
103120
// See ParseDSN for more information on the form of the DSN
104121
func (d driver) Open(dsn string) (sqldriver.Conn, error) {
105-
var (
106-
c *client.Conn
107-
// by default database/sql driver retries will be enabled
108-
retries = true
109-
)
110-
111122
ci, err := parseDSN(dsn)
112123
if err != nil {
113124
return nil, err
114125
}
126+
return ci.Connect(context.Background())
127+
128+
}
129+
130+
func (ci connInfo) Connect(ctx context.Context) (sqldriver.Conn, error) {
131+
var c *client.Conn
132+
var err error
133+
// by default database/sql driver retries will be enabled
134+
retries := true
115135

116136
if ci.standardDSN {
117137
var timeout time.Duration
@@ -160,48 +180,85 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) {
160180
}
161181
}
162182

163-
if timeout > 0 {
164-
c, err = client.ConnectWithTimeout(ci.addr, ci.user, ci.password, ci.db, timeout, configuredOptions...)
165-
} else {
166-
c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db, configuredOptions...)
183+
if timeout <= 0 {
184+
timeout = 10 * time.Second
167185
}
186+
c, err = client.ConnectWithContext(ctx, ci.addr, ci.user, ci.password, ci.db, timeout, configuredOptions...)
168187
} else {
169188
// No more processing here. Let's only support url parameters with the newer style DSN
170-
c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db)
189+
c, err = client.ConnectWithContext(ctx, ci.addr, ci.user, ci.password, ci.db, 10*time.Second)
171190
}
172191
if err != nil {
173192
return nil, err
174193
}
175194

195+
contexts := make(chan context.Context)
196+
go func() {
197+
ctx := context.Background()
198+
for {
199+
var ok bool
200+
select {
201+
case <-ctx.Done():
202+
_ = c.Conn.Close()
203+
case ctx, ok = <-contexts:
204+
if !ok {
205+
return
206+
}
207+
}
208+
}
209+
}()
210+
176211
// if retries are 'on' then return sqldriver.ErrBadConn which will trigger up to 3
177212
// retries by the database/sql package. If retries are 'off' then we'll return
178213
// the native go-mysql-org/go-mysql 'mysql.ErrBadConn' erorr which will prevent a retry.
179214
// In this case the sqldriver.Validator interface is implemented and will return
180215
// false for IsValid() signaling the connection is bad and should be discarded.
181-
return &conn{Conn: c, state: &state{valid: true, useStdLibErrors: retries}}, nil
216+
return &conn{
217+
Conn: c,
218+
state: &state{contexts: contexts, valid: true, useStdLibErrors: retries},
219+
}, nil
182220
}
183221

184-
type CheckNamedValueFunc func(*sqldriver.NamedValue) error
222+
func (d driver) OpenConnector(name string) (sqldriver.Connector, error) {
223+
return parseDSN(name)
224+
}
185225

186-
var (
187-
_ sqldriver.NamedValueChecker = &conn{}
188-
_ sqldriver.Validator = &conn{}
189-
_ sqldriver.Conn = &conn{}
190-
_ sqldriver.ConnBeginTx = &conn{}
191-
_ sqldriver.QueryerContext = &conn{}
192-
)
226+
func (ci connInfo) Driver() sqldriver.Driver {
227+
return driver{}
228+
}
229+
230+
type CheckNamedValueFunc func(*sqldriver.NamedValue) error
193231

194232
type state struct {
195-
valid bool
233+
contexts chan context.Context
234+
valid bool
196235
// when true, the driver connection will return ErrBadConn from the golang Standard Library
197236
useStdLibErrors bool
198237
}
199238

239+
func (s *state) watchCtx(ctx context.Context) func() {
240+
s.contexts <- ctx
241+
return func() {
242+
s.contexts <- context.Background()
243+
}
244+
}
245+
246+
func (s *state) Close() {
247+
if s.contexts != nil {
248+
close(s.contexts)
249+
s.contexts = nil
250+
}
251+
}
252+
200253
type conn struct {
201254
*client.Conn
202255
state *state
203256
}
204257

258+
func (c *conn) watchCtx(ctx context.Context) func() {
259+
return c.state.watchCtx(ctx)
260+
}
261+
205262
func (c *conn) CheckNamedValue(nv *sqldriver.NamedValue) error {
206263
for _, nvChecker := range namedValueCheckers {
207264
err := nvChecker(nv)
@@ -224,6 +281,17 @@ func (c *conn) IsValid() bool {
224281
return c.state.valid
225282
}
226283

284+
func (c *conn) Ping(ctx context.Context) error {
285+
defer c.watchCtx(ctx)()
286+
if err := c.Conn.Ping(); err != nil {
287+
if err == context.DeadlineExceeded || err == context.Canceled {
288+
return err
289+
}
290+
return sqldriver.ErrBadConn
291+
}
292+
return nil
293+
}
294+
227295
func (c *conn) Prepare(query string) (sqldriver.Stmt, error) {
228296
st, err := c.Conn.Prepare(query)
229297
if err != nil {
@@ -233,7 +301,13 @@ func (c *conn) Prepare(query string) (sqldriver.Stmt, error) {
233301
return &stmt{Stmt: st, connectionState: c.state}, nil
234302
}
235303

304+
func (c *conn) PrepareContext(ctx context.Context, query string) (sqldriver.Stmt, error) {
305+
defer c.watchCtx(ctx)()
306+
return c.Prepare(query)
307+
}
308+
236309
func (c *conn) Close() error {
310+
c.state.Close()
237311
return c.Conn.Close()
238312
}
239313

@@ -255,12 +329,14 @@ var isolationLevelTransactionIsolation = map[sql.IsolationLevel]string{
255329
}
256330

257331
func (c *conn) BeginTx(ctx context.Context, opts sqldriver.TxOptions) (sqldriver.Tx, error) {
332+
defer c.watchCtx(ctx)()
333+
258334
isolation := sql.IsolationLevel(opts.Isolation)
259335
txIsolation, ok := isolationLevelTransactionIsolation[isolation]
260336
if !ok {
261337
return nil, fmt.Errorf("invalid mysql transaction isolation level %s", isolation)
262338
}
263-
err := c.Conn.BeginTx(ctx, opts.ReadOnly, txIsolation)
339+
err := c.Conn.BeginTx(opts.ReadOnly, txIsolation)
264340
if err != nil {
265341
return nil, errors.Trace(err)
266342
}
@@ -311,6 +387,16 @@ func (c *conn) Exec(query string, args []sqldriver.Value) (sqldriver.Result, err
311387
return &result{r}, nil
312388
}
313389

390+
func (c *conn) ExecContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Result, error) {
391+
defer c.watchCtx(ctx)()
392+
a := buildNamedArgs(args)
393+
r, err := c.Conn.Execute(query, a...)
394+
if err != nil {
395+
return nil, c.state.replyError(err)
396+
}
397+
return &result{r}, nil
398+
}
399+
314400
func (c *conn) Query(query string, args []sqldriver.Value) (sqldriver.Rows, error) {
315401
a := buildArgs(args)
316402
r, err := c.Conn.Execute(query, a...)
@@ -321,8 +407,9 @@ func (c *conn) Query(query string, args []sqldriver.Value) (sqldriver.Rows, erro
321407
}
322408

323409
func (c *conn) QueryContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Rows, error) {
410+
defer c.watchCtx(ctx)()
324411
a := buildNamedArgs(args)
325-
r, err := c.Conn.ExecuteContext(ctx, query, a...)
412+
r, err := c.Conn.Execute(query, a...)
326413
if err != nil {
327414
return nil, c.state.replyError(err)
328415
}
@@ -334,6 +421,10 @@ type stmt struct {
334421
connectionState *state
335422
}
336423

424+
func (s *stmt) watchCtx(ctx context.Context) func() {
425+
return s.connectionState.watchCtx(ctx)
426+
}
427+
337428
func (s *stmt) Close() error {
338429
return s.Stmt.Close()
339430
}
@@ -351,6 +442,17 @@ func (s *stmt) Exec(args []sqldriver.Value) (sqldriver.Result, error) {
351442
return &result{r}, nil
352443
}
353444

445+
func (s *stmt) ExecContext(ctx context.Context, args []sqldriver.NamedValue) (sqldriver.Result, error) {
446+
defer s.watchCtx(ctx)()
447+
448+
a := buildNamedArgs(args)
449+
r, err := s.Stmt.Execute(a...)
450+
if err != nil {
451+
return nil, s.connectionState.replyError(err)
452+
}
453+
return &result{r}, nil
454+
}
455+
354456
func (s *stmt) Query(args []sqldriver.Value) (sqldriver.Rows, error) {
355457
a := buildArgs(args)
356458
r, err := s.Stmt.Execute(a...)
@@ -360,6 +462,17 @@ func (s *stmt) Query(args []sqldriver.Value) (sqldriver.Rows, error) {
360462
return newRows(r.Resultset)
361463
}
362464

465+
func (s *stmt) QueryContext(ctx context.Context, args []sqldriver.NamedValue) (sqldriver.Rows, error) {
466+
defer s.watchCtx(ctx)()
467+
468+
a := buildNamedArgs(args)
469+
r, err := s.Stmt.Execute(a...)
470+
if err != nil {
471+
return nil, s.connectionState.replyError(err)
472+
}
473+
return newRows(r.Resultset)
474+
}
475+
363476
type tx struct {
364477
*client.Conn
365478
}

0 commit comments

Comments
 (0)