3
3
package driver
4
4
5
5
import (
6
+ "context"
6
7
"crypto/tls"
7
8
"database/sql"
8
9
sqldriver "database/sql/driver"
@@ -21,6 +22,23 @@ import (
21
22
"github.com/pingcap/errors"
22
23
)
23
24
25
+ var (
26
+ _ sqldriver.Driver = & driver {}
27
+ _ sqldriver.DriverContext = & driver {}
28
+ _ sqldriver.Connector = & connInfo {}
29
+ _ sqldriver.NamedValueChecker = & conn {}
30
+ _ sqldriver.Validator = & conn {}
31
+ _ sqldriver.Conn = & conn {}
32
+ _ sqldriver.Pinger = & conn {}
33
+ _ sqldriver.ConnBeginTx = & conn {}
34
+ _ sqldriver.ConnPrepareContext = & conn {}
35
+ _ sqldriver.ExecerContext = & conn {}
36
+ _ sqldriver.QueryerContext = & conn {}
37
+ _ sqldriver.Stmt = & stmt {}
38
+ _ sqldriver.StmtExecContext = & stmt {}
39
+ _ sqldriver.StmtQueryContext = & stmt {}
40
+ )
41
+
24
42
var customTLSMutex sync.Mutex
25
43
26
44
// Map of dsn address (makes more sense than full dsn?) to tls Config
@@ -101,16 +119,18 @@ func parseDSN(dsn string) (connInfo, error) {
101
119
// Open takes a supplied DSN string and opens a connection
102
120
// See ParseDSN for more information on the form of the DSN
103
121
func (d driver ) Open (dsn string ) (sqldriver.Conn , error ) {
104
- var (
105
- c * client.Conn
106
- // by default database/sql driver retries will be enabled
107
- retries = true
108
- )
109
-
110
122
ci , err := parseDSN (dsn )
111
123
if err != nil {
112
124
return nil , err
113
125
}
126
+ return ci .Connect (context .Background ())
127
+ }
128
+
129
+ func (ci connInfo ) Connect (ctx context.Context ) (sqldriver.Conn , error ) {
130
+ var c * client.Conn
131
+ var err error
132
+ // by default database/sql driver retries will be enabled
133
+ retries := true
114
134
115
135
if ci .standardDSN {
116
136
var timeout time.Duration
@@ -159,45 +179,86 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) {
159
179
}
160
180
}
161
181
162
- if timeout > 0 {
163
- c , err = client .ConnectWithTimeout (ci .addr , ci .user , ci .password , ci .db , timeout , configuredOptions ... )
164
- } else {
165
- c , err = client .Connect (ci .addr , ci .user , ci .password , ci .db , configuredOptions ... )
182
+ if timeout <= 0 {
183
+ timeout = 10 * time .Second
166
184
}
185
+ c , err = client .ConnectWithContext (ctx , ci .addr , ci .user , ci .password , ci .db , timeout , configuredOptions ... )
167
186
} else {
168
187
// No more processing here. Let's only support url parameters with the newer style DSN
169
- c , err = client .Connect ( ci .addr , ci .user , ci .password , ci .db )
188
+ c , err = client .ConnectWithContext ( ctx , ci .addr , ci .user , ci .password , ci .db , 10 * time . Second )
170
189
}
171
190
if err != nil {
172
191
return nil , err
173
192
}
174
193
194
+ contexts := make (chan context.Context )
195
+ go func () {
196
+ ctx := context .Background ()
197
+ for {
198
+ var ok bool
199
+ select {
200
+ case <- ctx .Done ():
201
+ ctx = context .Background ()
202
+ _ = c .Conn .Close ()
203
+ case ctx , ok = <- contexts :
204
+ if ! ok {
205
+ return
206
+ }
207
+ }
208
+ }
209
+ }()
210
+
175
211
// if retries are 'on' then return sqldriver.ErrBadConn which will trigger up to 3
176
212
// retries by the database/sql package. If retries are 'off' then we'll return
177
213
// the native go-mysql-org/go-mysql 'mysql.ErrBadConn' erorr which will prevent a retry.
178
214
// In this case the sqldriver.Validator interface is implemented and will return
179
215
// false for IsValid() signaling the connection is bad and should be discarded.
180
- return & conn {Conn : c , state : & state {valid : true , useStdLibErrors : retries }}, nil
216
+ return & conn {
217
+ Conn : c ,
218
+ state : & state {contexts : contexts , valid : true , useStdLibErrors : retries },
219
+ }, nil
181
220
}
182
221
183
- type CheckNamedValueFunc func (* sqldriver.NamedValue ) error
222
+ func (d driver ) OpenConnector (name string ) (sqldriver.Connector , error ) {
223
+ return parseDSN (name )
224
+ }
184
225
185
- var (
186
- _ sqldriver.NamedValueChecker = & conn {}
187
- _ sqldriver.Validator = & conn {}
188
- )
226
+ func (ci connInfo ) Driver () sqldriver.Driver {
227
+ return driver {}
228
+ }
229
+
230
+ type CheckNamedValueFunc func (* sqldriver.NamedValue ) error
189
231
190
232
type state struct {
191
- valid bool
233
+ contexts chan context.Context
234
+ valid bool
192
235
// when true, the driver connection will return ErrBadConn from the golang Standard Library
193
236
useStdLibErrors bool
194
237
}
195
238
239
+ func (s * state ) watchCtx (ctx context.Context ) func () {
240
+ s .contexts <- ctx
241
+ return func () {
242
+ s .contexts <- context .Background ()
243
+ }
244
+ }
245
+
246
+ func (s * state ) Close () {
247
+ if s .contexts != nil {
248
+ close (s .contexts )
249
+ s .contexts = nil
250
+ }
251
+ }
252
+
196
253
type conn struct {
197
254
* client.Conn
198
255
state * state
199
256
}
200
257
258
+ func (c * conn ) watchCtx (ctx context.Context ) func () {
259
+ return c .state .watchCtx (ctx )
260
+ }
261
+
201
262
func (c * conn ) CheckNamedValue (nv * sqldriver.NamedValue ) error {
202
263
for _ , nvChecker := range namedValueCheckers {
203
264
err := nvChecker (nv )
@@ -220,6 +281,17 @@ func (c *conn) IsValid() bool {
220
281
return c .state .valid
221
282
}
222
283
284
+ func (c * conn ) Ping (ctx context.Context ) error {
285
+ defer c .watchCtx (ctx )()
286
+ if err := c .Conn .Ping (); err != nil {
287
+ if err == context .DeadlineExceeded || err == context .Canceled {
288
+ return err
289
+ }
290
+ return sqldriver .ErrBadConn
291
+ }
292
+ return nil
293
+ }
294
+
223
295
func (c * conn ) Prepare (query string ) (sqldriver.Stmt , error ) {
224
296
st , err := c .Conn .Prepare (query )
225
297
if err != nil {
@@ -229,7 +301,13 @@ func (c *conn) Prepare(query string) (sqldriver.Stmt, error) {
229
301
return & stmt {Stmt : st , connectionState : c .state }, nil
230
302
}
231
303
304
+ func (c * conn ) PrepareContext (ctx context.Context , query string ) (sqldriver.Stmt , error ) {
305
+ defer c .watchCtx (ctx )()
306
+ return c .Prepare (query )
307
+ }
308
+
232
309
func (c * conn ) Close () error {
310
+ c .state .Close ()
233
311
return c .Conn .Close ()
234
312
}
235
313
@@ -242,6 +320,29 @@ func (c *conn) Begin() (sqldriver.Tx, error) {
242
320
return & tx {c .Conn }, nil
243
321
}
244
322
323
+ var isolationLevelTransactionIsolation = map [sql.IsolationLevel ]string {
324
+ sql .LevelDefault : "" ,
325
+ sql .LevelRepeatableRead : "REPEATABLE READ" ,
326
+ sql .LevelReadCommitted : "READ COMMITTED" ,
327
+ sql .LevelReadUncommitted : "READ UNCOMMITTED" ,
328
+ sql .LevelSerializable : "SERIALIZABLE" ,
329
+ }
330
+
331
+ func (c * conn ) BeginTx (ctx context.Context , opts sqldriver.TxOptions ) (sqldriver.Tx , error ) {
332
+ defer c .watchCtx (ctx )()
333
+
334
+ isolation := sql .IsolationLevel (opts .Isolation )
335
+ txIsolation , ok := isolationLevelTransactionIsolation [isolation ]
336
+ if ! ok {
337
+ return nil , fmt .Errorf ("invalid mysql transaction isolation level %s" , isolation )
338
+ }
339
+ err := c .Conn .BeginTx (opts .ReadOnly , txIsolation )
340
+ if err != nil {
341
+ return nil , errors .Trace (err )
342
+ }
343
+ return & tx {c .Conn }, nil
344
+ }
345
+
245
346
func buildArgs (args []sqldriver.Value ) []interface {} {
246
347
a := make ([]interface {}, len (args ))
247
348
@@ -252,6 +353,16 @@ func buildArgs(args []sqldriver.Value) []interface{} {
252
353
return a
253
354
}
254
355
356
+ func buildNamedArgs (args []sqldriver.NamedValue ) []interface {} {
357
+ a := make ([]interface {}, len (args ))
358
+
359
+ for i , arg := range args {
360
+ a [i ] = arg .Value
361
+ }
362
+
363
+ return a
364
+ }
365
+
255
366
func (st * state ) replyError (err error ) error {
256
367
isBadConnection := mysql .ErrorEqual (err , mysql .ErrBadConn )
257
368
@@ -275,6 +386,16 @@ func (c *conn) Exec(query string, args []sqldriver.Value) (sqldriver.Result, err
275
386
return & result {r }, nil
276
387
}
277
388
389
+ func (c * conn ) ExecContext (ctx context.Context , query string , args []sqldriver.NamedValue ) (sqldriver.Result , error ) {
390
+ defer c .watchCtx (ctx )()
391
+ a := buildNamedArgs (args )
392
+ r , err := c .Conn .Execute (query , a ... )
393
+ if err != nil {
394
+ return nil , c .state .replyError (err )
395
+ }
396
+ return & result {r }, nil
397
+ }
398
+
278
399
func (c * conn ) Query (query string , args []sqldriver.Value ) (sqldriver.Rows , error ) {
279
400
a := buildArgs (args )
280
401
r , err := c .Conn .Execute (query , a ... )
@@ -284,11 +405,25 @@ func (c *conn) Query(query string, args []sqldriver.Value) (sqldriver.Rows, erro
284
405
return newRows (r .Resultset )
285
406
}
286
407
408
+ func (c * conn ) QueryContext (ctx context.Context , query string , args []sqldriver.NamedValue ) (sqldriver.Rows , error ) {
409
+ defer c .watchCtx (ctx )()
410
+ a := buildNamedArgs (args )
411
+ r , err := c .Conn .Execute (query , a ... )
412
+ if err != nil {
413
+ return nil , c .state .replyError (err )
414
+ }
415
+ return newRows (r .Resultset )
416
+ }
417
+
287
418
type stmt struct {
288
419
* client.Stmt
289
420
connectionState * state
290
421
}
291
422
423
+ func (s * stmt ) watchCtx (ctx context.Context ) func () {
424
+ return s .connectionState .watchCtx (ctx )
425
+ }
426
+
292
427
func (s * stmt ) Close () error {
293
428
return s .Stmt .Close ()
294
429
}
@@ -306,6 +441,17 @@ func (s *stmt) Exec(args []sqldriver.Value) (sqldriver.Result, error) {
306
441
return & result {r }, nil
307
442
}
308
443
444
+ func (s * stmt ) ExecContext (ctx context.Context , args []sqldriver.NamedValue ) (sqldriver.Result , error ) {
445
+ defer s .watchCtx (ctx )()
446
+
447
+ a := buildNamedArgs (args )
448
+ r , err := s .Stmt .Execute (a ... )
449
+ if err != nil {
450
+ return nil , s .connectionState .replyError (err )
451
+ }
452
+ return & result {r }, nil
453
+ }
454
+
309
455
func (s * stmt ) Query (args []sqldriver.Value ) (sqldriver.Rows , error ) {
310
456
a := buildArgs (args )
311
457
r , err := s .Stmt .Execute (a ... )
@@ -315,6 +461,17 @@ func (s *stmt) Query(args []sqldriver.Value) (sqldriver.Rows, error) {
315
461
return newRows (r .Resultset )
316
462
}
317
463
464
+ func (s * stmt ) QueryContext (ctx context.Context , args []sqldriver.NamedValue ) (sqldriver.Rows , error ) {
465
+ defer s .watchCtx (ctx )()
466
+
467
+ a := buildNamedArgs (args )
468
+ r , err := s .Stmt .Execute (a ... )
469
+ if err != nil {
470
+ return nil , s .connectionState .replyError (err )
471
+ }
472
+ return newRows (r .Resultset )
473
+ }
474
+
318
475
type tx struct {
319
476
* client.Conn
320
477
}
0 commit comments