Skip to content

Commit 87147ed

Browse files
committed
poc: context support
1 parent 019125f commit 87147ed

File tree

3 files changed

+110
-3
lines changed

3 files changed

+110
-3
lines changed

client/conn.go

+61
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ type Conn struct {
3535
// Connection read and write timeouts to set on the connection
3636
ReadTimeout time.Duration
3737
WriteTimeout time.Duration
38+
contexts chan context.Context
39+
closed bool
3840

3941
// The buffer size to use in the packet connection
4042
BufferSize int
@@ -136,6 +138,7 @@ func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbNam
136138
c.password = password
137139
c.db = dbName
138140
c.proto = network
141+
c.contexts = make(chan context.Context)
139142

140143
// use default charset here, utf-8
141144
c.charset = mysql.DEFAULT_CHARSET
@@ -184,6 +187,21 @@ func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbNam
184187
}
185188
}
186189

190+
go func() {
191+
ctx := context.Background()
192+
for {
193+
var ok bool
194+
select {
195+
case <-ctx.Done():
196+
_ = c.Conn.SetDeadline(time.Unix(0, 0))
197+
case ctx, ok = <-c.contexts:
198+
if !ok {
199+
return
200+
}
201+
}
202+
}
203+
}()
204+
187205
return c, nil
188206
}
189207

@@ -208,8 +226,19 @@ func (c *Conn) handshake() error {
208226
return nil
209227
}
210228

229+
func (c *Conn) watchCtx(ctx context.Context) func() {
230+
c.contexts <- ctx
231+
return func() {
232+
c.contexts <- context.Background()
233+
}
234+
}
235+
211236
// Close directly closes the connection. Use Quit() to first send COM_QUIT to the server and then close the connection.
212237
func (c *Conn) Close() error {
238+
if !c.closed {
239+
close(c.contexts)
240+
c.closed = true
241+
}
213242
return c.Conn.Close()
214243
}
215244

@@ -309,6 +338,11 @@ func (c *Conn) Execute(command string, args ...interface{}) (*mysql.Result, erro
309338
}
310339
}
311340

341+
func (c *Conn) ExecuteContext(ctx context.Context, command string, args ...interface{}) (*mysql.Result, error) {
342+
defer c.watchCtx(ctx)
343+
return c.Execute(command, args)
344+
}
345+
312346
// ExecuteMultiple will call perResultCallback for every result of the multiple queries
313347
// that are executed.
314348
//
@@ -363,6 +397,11 @@ func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCall
363397
return mysql.NewResult(rs), nil
364398
}
365399

400+
func (c *Conn) ExecuteMultipleContext(ctx context.Context, query string, perResultCallback ExecPerResultCallback) (*mysql.Result, error) {
401+
defer c.watchCtx(ctx)
402+
return c.ExecuteMultiple(query, perResultCallback)
403+
}
404+
366405
// ExecuteSelectStreaming will call perRowCallback for every row in resultset
367406
// WITHOUT saving any row data to Result.{Values/RawPkg/RowDatas} fields.
368407
// When given, perResultCallback will be called once per result
@@ -376,11 +415,33 @@ func (c *Conn) ExecuteSelectStreaming(command string, result *mysql.Result, perR
376415
return c.readResultStreaming(false, result, perRowCallback, perResultCallback)
377416
}
378417

418+
func (c *Conn) ExecuteSelectStreamingContext(ctx context.Context, command string, result *mysql.Result, perRowCallback SelectPerRowCallback, perResultCallback SelectPerResultCallback) error {
419+
defer c.watchCtx(ctx)
420+
return c.ExecuteSelectStreaming(command, result, perRowCallback, perResultCallback)
421+
}
422+
379423
func (c *Conn) Begin() error {
380424
_, err := c.exec("BEGIN")
381425
return errors.Trace(err)
382426
}
383427

428+
func (c *Conn) BeginTx(ctx context.Context, readOnly bool, txIsolation string) error {
429+
defer c.watchCtx(ctx)()
430+
431+
if txIsolation != "" {
432+
if _, err := c.exec("SET TRANSACTION ISOLATION LEVEL " + txIsolation); err != nil {
433+
return errors.Trace(err)
434+
}
435+
}
436+
var err error
437+
if readOnly {
438+
_, err = c.exec("START TRANSACTION READ ONLY")
439+
} else {
440+
_, err = c.exec("START TRANSACTION")
441+
}
442+
return errors.Trace(err)
443+
}
444+
384445
func (c *Conn) Commit() error {
385446
_, err := c.exec("COMMIT")
386447
return errors.Trace(err)

driver/driver.go

+49-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
package driver
44

55
import (
6+
"context"
67
"crypto/tls"
78
"database/sql"
89
sqldriver "database/sql/driver"
@@ -184,8 +185,13 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) {
184185

185186
type CheckNamedValueFunc func(*sqldriver.NamedValue) error
186187

187-
var _ sqldriver.NamedValueChecker = &conn{}
188-
var _ sqldriver.Validator = &conn{}
188+
var (
189+
_ sqldriver.NamedValueChecker = &conn{}
190+
_ sqldriver.Validator = &conn{}
191+
_ sqldriver.Conn = &conn{}
192+
_ sqldriver.ConnBeginTx = &conn{}
193+
_ sqldriver.QueryerContext = &conn{}
194+
)
189195

190196
type state struct {
191197
valid bool
@@ -242,6 +248,27 @@ func (c *conn) Begin() (sqldriver.Tx, error) {
242248
return &tx{c.Conn}, nil
243249
}
244250

251+
var isolationLevelTransactionIsolation = map[sql.IsolationLevel]string{
252+
sql.LevelDefault: "",
253+
sql.LevelRepeatableRead: "REPEATABLE READ",
254+
sql.LevelReadCommitted: "READ COMMITTED",
255+
sql.LevelReadUncommitted: "READ UNCOMMITTED",
256+
sql.LevelSerializable: "SERIALIZABLE",
257+
}
258+
259+
func (c *conn) BeginTx(ctx context.Context, opts sqldriver.TxOptions) (sqldriver.Tx, error) {
260+
isolation := sql.IsolationLevel(opts.Isolation)
261+
txIsolation, ok := isolationLevelTransactionIsolation[isolation]
262+
if !ok {
263+
return nil, fmt.Errorf("invalid mysql transaction isolation level %s", isolation)
264+
}
265+
err := c.Conn.BeginTx(ctx, opts.ReadOnly, txIsolation)
266+
if err != nil {
267+
return nil, errors.Trace(err)
268+
}
269+
return &tx{c.Conn}, nil
270+
}
271+
245272
func buildArgs(args []sqldriver.Value) []interface{} {
246273
a := make([]interface{}, len(args))
247274

@@ -252,6 +279,17 @@ func buildArgs(args []sqldriver.Value) []interface{} {
252279
return a
253280
}
254281

282+
func buildNamedArgs(args []sqldriver.NamedValue) []interface{} {
283+
a := make([]interface{}, len(args))
284+
285+
for i, arg := range args {
286+
// TODO named parameter support
287+
a[i] = arg.Value
288+
}
289+
290+
return a
291+
}
292+
255293
func (st *state) replyError(err error) error {
256294
isBadConnection := mysql.ErrorEqual(err, mysql.ErrBadConn)
257295

@@ -284,6 +322,15 @@ func (c *conn) Query(query string, args []sqldriver.Value) (sqldriver.Rows, erro
284322
return newRows(r.Resultset)
285323
}
286324

325+
func (c *conn) QueryContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Rows, error) {
326+
a := buildNamedArgs(args)
327+
r, err := c.Conn.ExecuteContext(ctx, query, a...)
328+
if err != nil {
329+
return nil, c.state.replyError(err)
330+
}
331+
return newRows(r.Resultset)
332+
}
333+
287334
type stmt struct {
288335
*client.Stmt
289336
connectionState *state

packet/conn.go

-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ type Conn struct {
2929

3030
readTimeout time.Duration
3131
writeTimeout time.Duration
32-
ctx context.Context
3332

3433
// Buffered reader for net.Conn in Non-TLS connection only to address replication performance issue.
3534
// See https://github.com/go-mysql-org/go-mysql/pull/422 for more details.

0 commit comments

Comments
 (0)