@@ -35,6 +35,8 @@ type Conn struct {
35
35
// Connection read and write timeouts to set on the connection
36
36
ReadTimeout time.Duration
37
37
WriteTimeout time.Duration
38
+ contexts chan context.Context
39
+ closed bool
38
40
39
41
// The buffer size to use in the packet connection
40
42
BufferSize int
@@ -136,6 +138,7 @@ func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbNam
136
138
c .password = password
137
139
c .db = dbName
138
140
c .proto = network
141
+ c .contexts = make (chan context.Context )
139
142
140
143
// use default charset here, utf-8
141
144
c .charset = mysql .DEFAULT_CHARSET
@@ -184,6 +187,21 @@ func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbNam
184
187
}
185
188
}
186
189
190
+ go func () {
191
+ ctx := context .Background ()
192
+ for {
193
+ var ok bool
194
+ select {
195
+ case <- ctx .Done ():
196
+ _ = c .Conn .SetDeadline (time .Unix (0 , 0 ))
197
+ case ctx , ok = <- c .contexts :
198
+ if ! ok {
199
+ return
200
+ }
201
+ }
202
+ }
203
+ }()
204
+
187
205
return c , nil
188
206
}
189
207
@@ -208,8 +226,19 @@ func (c *Conn) handshake() error {
208
226
return nil
209
227
}
210
228
229
+ func (c * Conn ) watchCtx (ctx context.Context ) func () {
230
+ c .contexts <- ctx
231
+ return func () {
232
+ c .contexts <- context .Background ()
233
+ }
234
+ }
235
+
211
236
// Close directly closes the connection. Use Quit() to first send COM_QUIT to the server and then close the connection.
212
237
func (c * Conn ) Close () error {
238
+ if ! c .closed {
239
+ close (c .contexts )
240
+ c .closed = true
241
+ }
213
242
return c .Conn .Close ()
214
243
}
215
244
@@ -309,6 +338,11 @@ func (c *Conn) Execute(command string, args ...interface{}) (*mysql.Result, erro
309
338
}
310
339
}
311
340
341
+ func (c * Conn ) ExecuteContext (ctx context.Context , command string , args ... interface {}) (* mysql.Result , error ) {
342
+ defer c .watchCtx (ctx )
343
+ return c .Execute (command , args ... )
344
+ }
345
+
312
346
// ExecuteMultiple will call perResultCallback for every result of the multiple queries
313
347
// that are executed.
314
348
//
@@ -363,6 +397,11 @@ func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCall
363
397
return mysql .NewResult (rs ), nil
364
398
}
365
399
400
+ func (c * Conn ) ExecuteMultipleContext (ctx context.Context , query string , perResultCallback ExecPerResultCallback ) (* mysql.Result , error ) {
401
+ defer c .watchCtx (ctx )
402
+ return c .ExecuteMultiple (query , perResultCallback )
403
+ }
404
+
366
405
// ExecuteSelectStreaming will call perRowCallback for every row in resultset
367
406
// WITHOUT saving any row data to Result.{Values/RawPkg/RowDatas} fields.
368
407
// When given, perResultCallback will be called once per result
@@ -376,11 +415,33 @@ func (c *Conn) ExecuteSelectStreaming(command string, result *mysql.Result, perR
376
415
return c .readResultStreaming (false , result , perRowCallback , perResultCallback )
377
416
}
378
417
418
+ func (c * Conn ) ExecuteSelectStreamingContext (ctx context.Context , command string , result * mysql.Result , perRowCallback SelectPerRowCallback , perResultCallback SelectPerResultCallback ) error {
419
+ defer c .watchCtx (ctx )
420
+ return c .ExecuteSelectStreaming (command , result , perRowCallback , perResultCallback )
421
+ }
422
+
379
423
func (c * Conn ) Begin () error {
380
424
_ , err := c .exec ("BEGIN" )
381
425
return errors .Trace (err )
382
426
}
383
427
428
+ func (c * Conn ) BeginTx (ctx context.Context , readOnly bool , txIsolation string ) error {
429
+ defer c .watchCtx (ctx )()
430
+
431
+ if txIsolation != "" {
432
+ if _ , err := c .exec ("SET TRANSACTION ISOLATION LEVEL " + txIsolation ); err != nil {
433
+ return errors .Trace (err )
434
+ }
435
+ }
436
+ var err error
437
+ if readOnly {
438
+ _ , err = c .exec ("START TRANSACTION READ ONLY" )
439
+ } else {
440
+ _ , err = c .exec ("START TRANSACTION" )
441
+ }
442
+ return errors .Trace (err )
443
+ }
444
+
384
445
func (c * Conn ) Commit () error {
385
446
_ , err := c .exec ("COMMIT" )
386
447
return errors .Trace (err )
0 commit comments