@@ -22,6 +22,23 @@ import (
22
22
"github.com/pingcap/errors"
23
23
)
24
24
25
+ var (
26
+ _ sqldriver.Driver = & driver {}
27
+ _ sqldriver.DriverContext = & driver {}
28
+ _ sqldriver.Connector = & connInfo {}
29
+ _ sqldriver.NamedValueChecker = & conn {}
30
+ _ sqldriver.Validator = & conn {}
31
+ _ sqldriver.Conn = & conn {}
32
+ _ sqldriver.Pinger = & conn {}
33
+ _ sqldriver.ConnBeginTx = & conn {}
34
+ _ sqldriver.ConnPrepareContext = & conn {}
35
+ _ sqldriver.ExecerContext = & conn {}
36
+ _ sqldriver.QueryerContext = & conn {}
37
+ _ sqldriver.Stmt = & stmt {}
38
+ _ sqldriver.StmtExecContext = & stmt {}
39
+ _ sqldriver.StmtQueryContext = & stmt {}
40
+ )
41
+
25
42
var customTLSMutex sync.Mutex
26
43
27
44
// Map of dsn address (makes more sense than full dsn?) to tls Config
@@ -102,16 +119,19 @@ func parseDSN(dsn string) (connInfo, error) {
102
119
// Open takes a supplied DSN string and opens a connection
103
120
// See ParseDSN for more information on the form of the DSN
104
121
func (d driver ) Open (dsn string ) (sqldriver.Conn , error ) {
105
- var (
106
- c * client.Conn
107
- // by default database/sql driver retries will be enabled
108
- retries = true
109
- )
110
-
111
122
ci , err := parseDSN (dsn )
112
123
if err != nil {
113
124
return nil , err
114
125
}
126
+ return ci .Connect (context .Background ())
127
+
128
+ }
129
+
130
+ func (ci connInfo ) Connect (ctx context.Context ) (sqldriver.Conn , error ) {
131
+ var c * client.Conn
132
+ var err error
133
+ // by default database/sql driver retries will be enabled
134
+ retries := true
115
135
116
136
if ci .standardDSN {
117
137
var timeout time.Duration
@@ -160,48 +180,85 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) {
160
180
}
161
181
}
162
182
163
- if timeout > 0 {
164
- c , err = client .ConnectWithTimeout (ci .addr , ci .user , ci .password , ci .db , timeout , configuredOptions ... )
165
- } else {
166
- c , err = client .Connect (ci .addr , ci .user , ci .password , ci .db , configuredOptions ... )
183
+ if timeout <= 0 {
184
+ timeout = 10 * time .Second
167
185
}
186
+ c , err = client .ConnectWithContext (ctx , ci .addr , ci .user , ci .password , ci .db , timeout , configuredOptions ... )
168
187
} else {
169
188
// No more processing here. Let's only support url parameters with the newer style DSN
170
- c , err = client .Connect ( ci .addr , ci .user , ci .password , ci .db )
189
+ c , err = client .ConnectWithContext ( ctx , ci .addr , ci .user , ci .password , ci .db , 10 * time . Second )
171
190
}
172
191
if err != nil {
173
192
return nil , err
174
193
}
175
194
195
+ contexts := make (chan context.Context )
196
+ go func () {
197
+ ctx := context .Background ()
198
+ for {
199
+ var ok bool
200
+ select {
201
+ case <- ctx .Done ():
202
+ _ = c .Conn .Close ()
203
+ case ctx , ok = <- contexts :
204
+ if ! ok {
205
+ return
206
+ }
207
+ }
208
+ }
209
+ }()
210
+
176
211
// if retries are 'on' then return sqldriver.ErrBadConn which will trigger up to 3
177
212
// retries by the database/sql package. If retries are 'off' then we'll return
178
213
// the native go-mysql-org/go-mysql 'mysql.ErrBadConn' erorr which will prevent a retry.
179
214
// In this case the sqldriver.Validator interface is implemented and will return
180
215
// false for IsValid() signaling the connection is bad and should be discarded.
181
- return & conn {Conn : c , state : & state {valid : true , useStdLibErrors : retries }}, nil
216
+ return & conn {
217
+ Conn : c ,
218
+ state : & state {contexts : contexts , valid : true , useStdLibErrors : retries },
219
+ }, nil
182
220
}
183
221
184
- type CheckNamedValueFunc func (* sqldriver.NamedValue ) error
222
+ func (d driver ) OpenConnector (name string ) (sqldriver.Connector , error ) {
223
+ return parseDSN (name )
224
+ }
185
225
186
- var (
187
- _ sqldriver.NamedValueChecker = & conn {}
188
- _ sqldriver.Validator = & conn {}
189
- _ sqldriver.Conn = & conn {}
190
- _ sqldriver.ConnBeginTx = & conn {}
191
- _ sqldriver.QueryerContext = & conn {}
192
- )
226
+ func (ci connInfo ) Driver () sqldriver.Driver {
227
+ return driver {}
228
+ }
229
+
230
+ type CheckNamedValueFunc func (* sqldriver.NamedValue ) error
193
231
194
232
type state struct {
195
- valid bool
233
+ contexts chan context.Context
234
+ valid bool
196
235
// when true, the driver connection will return ErrBadConn from the golang Standard Library
197
236
useStdLibErrors bool
198
237
}
199
238
239
+ func (s * state ) watchCtx (ctx context.Context ) func () {
240
+ s .contexts <- ctx
241
+ return func () {
242
+ s .contexts <- context .Background ()
243
+ }
244
+ }
245
+
246
+ func (s * state ) Close () {
247
+ if s .contexts != nil {
248
+ close (s .contexts )
249
+ s .contexts = nil
250
+ }
251
+ }
252
+
200
253
type conn struct {
201
254
* client.Conn
202
255
state * state
203
256
}
204
257
258
+ func (c * conn ) watchCtx (ctx context.Context ) func () {
259
+ return c .state .watchCtx (ctx )
260
+ }
261
+
205
262
func (c * conn ) CheckNamedValue (nv * sqldriver.NamedValue ) error {
206
263
for _ , nvChecker := range namedValueCheckers {
207
264
err := nvChecker (nv )
@@ -224,6 +281,17 @@ func (c *conn) IsValid() bool {
224
281
return c .state .valid
225
282
}
226
283
284
+ func (c * conn ) Ping (ctx context.Context ) error {
285
+ defer c .watchCtx (ctx )()
286
+ if err := c .Conn .Ping (); err != nil {
287
+ if err == context .DeadlineExceeded || err == context .Canceled {
288
+ return err
289
+ }
290
+ return sqldriver .ErrBadConn
291
+ }
292
+ return nil
293
+ }
294
+
227
295
func (c * conn ) Prepare (query string ) (sqldriver.Stmt , error ) {
228
296
st , err := c .Conn .Prepare (query )
229
297
if err != nil {
@@ -233,7 +301,13 @@ func (c *conn) Prepare(query string) (sqldriver.Stmt, error) {
233
301
return & stmt {Stmt : st , connectionState : c .state }, nil
234
302
}
235
303
304
+ func (c * conn ) PrepareContext (ctx context.Context , query string ) (sqldriver.Stmt , error ) {
305
+ defer c .watchCtx (ctx )()
306
+ return c .Prepare (query )
307
+ }
308
+
236
309
func (c * conn ) Close () error {
310
+ c .state .Close ()
237
311
return c .Conn .Close ()
238
312
}
239
313
@@ -255,12 +329,14 @@ var isolationLevelTransactionIsolation = map[sql.IsolationLevel]string{
255
329
}
256
330
257
331
func (c * conn ) BeginTx (ctx context.Context , opts sqldriver.TxOptions ) (sqldriver.Tx , error ) {
332
+ defer c .watchCtx (ctx )()
333
+
258
334
isolation := sql .IsolationLevel (opts .Isolation )
259
335
txIsolation , ok := isolationLevelTransactionIsolation [isolation ]
260
336
if ! ok {
261
337
return nil , fmt .Errorf ("invalid mysql transaction isolation level %s" , isolation )
262
338
}
263
- err := c .Conn .BeginTx (ctx , opts .ReadOnly , txIsolation )
339
+ err := c .Conn .BeginTx (opts .ReadOnly , txIsolation )
264
340
if err != nil {
265
341
return nil , errors .Trace (err )
266
342
}
@@ -311,6 +387,16 @@ func (c *conn) Exec(query string, args []sqldriver.Value) (sqldriver.Result, err
311
387
return & result {r }, nil
312
388
}
313
389
390
+ func (c * conn ) ExecContext (ctx context.Context , query string , args []sqldriver.NamedValue ) (sqldriver.Result , error ) {
391
+ defer c .watchCtx (ctx )()
392
+ a := buildNamedArgs (args )
393
+ r , err := c .Conn .Execute (query , a ... )
394
+ if err != nil {
395
+ return nil , c .state .replyError (err )
396
+ }
397
+ return & result {r }, nil
398
+ }
399
+
314
400
func (c * conn ) Query (query string , args []sqldriver.Value ) (sqldriver.Rows , error ) {
315
401
a := buildArgs (args )
316
402
r , err := c .Conn .Execute (query , a ... )
@@ -321,8 +407,9 @@ func (c *conn) Query(query string, args []sqldriver.Value) (sqldriver.Rows, erro
321
407
}
322
408
323
409
func (c * conn ) QueryContext (ctx context.Context , query string , args []sqldriver.NamedValue ) (sqldriver.Rows , error ) {
410
+ defer c .watchCtx (ctx )()
324
411
a := buildNamedArgs (args )
325
- r , err := c .Conn .ExecuteContext ( ctx , query , a ... )
412
+ r , err := c .Conn .Execute ( query , a ... )
326
413
if err != nil {
327
414
return nil , c .state .replyError (err )
328
415
}
@@ -334,6 +421,10 @@ type stmt struct {
334
421
connectionState * state
335
422
}
336
423
424
+ func (s * stmt ) watchCtx (ctx context.Context ) func () {
425
+ return s .connectionState .watchCtx (ctx )
426
+ }
427
+
337
428
func (s * stmt ) Close () error {
338
429
return s .Stmt .Close ()
339
430
}
@@ -351,6 +442,17 @@ func (s *stmt) Exec(args []sqldriver.Value) (sqldriver.Result, error) {
351
442
return & result {r }, nil
352
443
}
353
444
445
+ func (s * stmt ) ExecContext (ctx context.Context , args []sqldriver.NamedValue ) (sqldriver.Result , error ) {
446
+ defer s .watchCtx (ctx )()
447
+
448
+ a := buildNamedArgs (args )
449
+ r , err := s .Stmt .Execute (a ... )
450
+ if err != nil {
451
+ return nil , s .connectionState .replyError (err )
452
+ }
453
+ return & result {r }, nil
454
+ }
455
+
354
456
func (s * stmt ) Query (args []sqldriver.Value ) (sqldriver.Rows , error ) {
355
457
a := buildArgs (args )
356
458
r , err := s .Stmt .Execute (a ... )
@@ -360,6 +462,17 @@ func (s *stmt) Query(args []sqldriver.Value) (sqldriver.Rows, error) {
360
462
return newRows (r .Resultset )
361
463
}
362
464
465
+ func (s * stmt ) QueryContext (ctx context.Context , args []sqldriver.NamedValue ) (sqldriver.Rows , error ) {
466
+ defer s .watchCtx (ctx )()
467
+
468
+ a := buildNamedArgs (args )
469
+ r , err := s .Stmt .Execute (a ... )
470
+ if err != nil {
471
+ return nil , s .connectionState .replyError (err )
472
+ }
473
+ return newRows (r .Resultset )
474
+ }
475
+
363
476
type tx struct {
364
477
* client.Conn
365
478
}
0 commit comments