Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

driver: context support #997

Merged
merged 8 commits into from
Feb 24, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
move context management into driver
Close instead of SetTimeout hack
implement more interfaces, with assertions
serprex committed Feb 21, 2025
commit 1c3ba7d8d70344b14505fbea19cbd3db4e1bce1d
48 changes: 1 addition & 47 deletions client/conn.go
Original file line number Diff line number Diff line change
@@ -35,8 +35,6 @@ type Conn struct {
// Connection read and write timeouts to set on the connection
ReadTimeout time.Duration
WriteTimeout time.Duration
contexts chan context.Context
closed bool

// The buffer size to use in the packet connection
BufferSize int
@@ -138,7 +136,6 @@ func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbNam
c.password = password
c.db = dbName
c.proto = network
c.contexts = make(chan context.Context)

// use default charset here, utf-8
c.charset = mysql.DEFAULT_CHARSET
@@ -187,21 +184,6 @@ func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbNam
}
}

go func() {
ctx := context.Background()
for {
var ok bool
select {
case <-ctx.Done():
_ = c.Conn.SetDeadline(time.Unix(0, 0))
case ctx, ok = <-c.contexts:
if !ok {
return
}
}
}
}()

return c, nil
}

@@ -226,19 +208,8 @@ func (c *Conn) handshake() error {
return nil
}

func (c *Conn) watchCtx(ctx context.Context) func() {
c.contexts <- ctx
return func() {
c.contexts <- context.Background()
}
}

// Close directly closes the connection. Use Quit() to first send COM_QUIT to the server and then close the connection.
func (c *Conn) Close() error {
if !c.closed {
close(c.contexts)
c.closed = true
}
return c.Conn.Close()
}

@@ -338,11 +309,6 @@ func (c *Conn) Execute(command string, args ...interface{}) (*mysql.Result, erro
}
}

func (c *Conn) ExecuteContext(ctx context.Context, command string, args ...interface{}) (*mysql.Result, error) {
defer c.watchCtx(ctx)
return c.Execute(command, args...)
}

// ExecuteMultiple will call perResultCallback for every result of the multiple queries
// that are executed.
//
@@ -397,11 +363,6 @@ func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCall
return mysql.NewResult(rs), nil
}

func (c *Conn) ExecuteMultipleContext(ctx context.Context, query string, perResultCallback ExecPerResultCallback) (*mysql.Result, error) {
defer c.watchCtx(ctx)
return c.ExecuteMultiple(query, perResultCallback)
}

// ExecuteSelectStreaming will call perRowCallback for every row in resultset
// WITHOUT saving any row data to Result.{Values/RawPkg/RowDatas} fields.
// When given, perResultCallback will be called once per result
@@ -415,19 +376,12 @@ func (c *Conn) ExecuteSelectStreaming(command string, result *mysql.Result, perR
return c.readResultStreaming(false, result, perRowCallback, perResultCallback)
}

func (c *Conn) ExecuteSelectStreamingContext(ctx context.Context, command string, result *mysql.Result, perRowCallback SelectPerRowCallback, perResultCallback SelectPerResultCallback) error {
defer c.watchCtx(ctx)
return c.ExecuteSelectStreaming(command, result, perRowCallback, perResultCallback)
}

func (c *Conn) Begin() error {
_, err := c.exec("BEGIN")
return errors.Trace(err)
}

func (c *Conn) BeginTx(ctx context.Context, readOnly bool, txIsolation string) error {
defer c.watchCtx(ctx)()

func (c *Conn) BeginTx(readOnly bool, txIsolation string) error {
if txIsolation != "" {
if _, err := c.exec("SET TRANSACTION ISOLATION LEVEL " + txIsolation); err != nil {
return errors.Trace(err)
159 changes: 136 additions & 23 deletions driver/driver.go
Original file line number Diff line number Diff line change
@@ -22,6 +22,23 @@
"github.com/pingcap/errors"
)

var (
_ sqldriver.Driver = &driver{}
_ sqldriver.DriverContext = &driver{}
_ sqldriver.Connector = &connInfo{}
_ sqldriver.NamedValueChecker = &conn{}
_ sqldriver.Validator = &conn{}
_ sqldriver.Conn = &conn{}
_ sqldriver.Pinger = &conn{}
_ sqldriver.ConnBeginTx = &conn{}
_ sqldriver.ConnPrepareContext = &conn{}
_ sqldriver.ExecerContext = &conn{}
_ sqldriver.QueryerContext = &conn{}
_ sqldriver.Stmt = &stmt{}
_ sqldriver.StmtExecContext = &stmt{}
_ sqldriver.StmtQueryContext = &stmt{}
)

var customTLSMutex sync.Mutex

// Map of dsn address (makes more sense than full dsn?) to tls Config
@@ -102,16 +119,19 @@
// Open takes a supplied DSN string and opens a connection
// See ParseDSN for more information on the form of the DSN
func (d driver) Open(dsn string) (sqldriver.Conn, error) {
var (
c *client.Conn
// by default database/sql driver retries will be enabled
retries = true
)

ci, err := parseDSN(dsn)
if err != nil {
return nil, err
}
return ci.Connect(context.Background())

Check failure on line 127 in driver/driver.go

GitHub Actions / golangci

File is not properly formatted (gofumpt)
}

Check failure on line 128 in driver/driver.go

GitHub Actions / golangci

unnecessary trailing newline (whitespace)

func (ci connInfo) Connect(ctx context.Context) (sqldriver.Conn, error) {
var c *client.Conn
var err error
// by default database/sql driver retries will be enabled
retries := true

if ci.standardDSN {
var timeout time.Duration
@@ -160,48 +180,85 @@
}
}

if timeout > 0 {
c, err = client.ConnectWithTimeout(ci.addr, ci.user, ci.password, ci.db, timeout, configuredOptions...)
} else {
c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db, configuredOptions...)
if timeout <= 0 {
timeout = 10 * time.Second
}
c, err = client.ConnectWithContext(ctx, ci.addr, ci.user, ci.password, ci.db, timeout, configuredOptions...)
} else {
// No more processing here. Let's only support url parameters with the newer style DSN
c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db)
c, err = client.ConnectWithContext(ctx, ci.addr, ci.user, ci.password, ci.db, 10*time.Second)
}
if err != nil {
return nil, err
}

contexts := make(chan context.Context)
go func() {
ctx := context.Background()
for {
var ok bool
select {
case <-ctx.Done():
_ = c.Conn.Close()
case ctx, ok = <-contexts:
if !ok {
return
}
}
}
}()

// if retries are 'on' then return sqldriver.ErrBadConn which will trigger up to 3
// retries by the database/sql package. If retries are 'off' then we'll return
// the native go-mysql-org/go-mysql 'mysql.ErrBadConn' erorr which will prevent a retry.
// In this case the sqldriver.Validator interface is implemented and will return
// false for IsValid() signaling the connection is bad and should be discarded.
return &conn{Conn: c, state: &state{valid: true, useStdLibErrors: retries}}, nil
return &conn{
Conn: c,
state: &state{contexts: contexts, valid: true, useStdLibErrors: retries},
}, nil
}

type CheckNamedValueFunc func(*sqldriver.NamedValue) error
func (d driver) OpenConnector(name string) (sqldriver.Connector, error) {
return parseDSN(name)
}

var (
_ sqldriver.NamedValueChecker = &conn{}
_ sqldriver.Validator = &conn{}
_ sqldriver.Conn = &conn{}
_ sqldriver.ConnBeginTx = &conn{}
_ sqldriver.QueryerContext = &conn{}
)
func (ci connInfo) Driver() sqldriver.Driver {
return driver{}
}

type CheckNamedValueFunc func(*sqldriver.NamedValue) error

type state struct {
valid bool
contexts chan context.Context
valid bool
// when true, the driver connection will return ErrBadConn from the golang Standard Library
useStdLibErrors bool
}

func (s *state) watchCtx(ctx context.Context) func() {
s.contexts <- ctx
return func() {
s.contexts <- context.Background()
}
}

func (s *state) Close() {
if s.contexts != nil {
close(s.contexts)
s.contexts = nil
}
}

type conn struct {
*client.Conn
state *state
}

func (c *conn) watchCtx(ctx context.Context) func() {
return c.state.watchCtx(ctx)
}

func (c *conn) CheckNamedValue(nv *sqldriver.NamedValue) error {
for _, nvChecker := range namedValueCheckers {
err := nvChecker(nv)
@@ -224,6 +281,17 @@
return c.state.valid
}

func (c *conn) Ping(ctx context.Context) error {
defer c.watchCtx(ctx)()
if err := c.Conn.Ping(); err != nil {
if err == context.DeadlineExceeded || err == context.Canceled {
return err
}
return sqldriver.ErrBadConn
}
return nil
}

func (c *conn) Prepare(query string) (sqldriver.Stmt, error) {
st, err := c.Conn.Prepare(query)
if err != nil {
@@ -233,7 +301,13 @@
return &stmt{Stmt: st, connectionState: c.state}, nil
}

func (c *conn) PrepareContext(ctx context.Context, query string) (sqldriver.Stmt, error) {
defer c.watchCtx(ctx)()
return c.Prepare(query)
}

func (c *conn) Close() error {
c.state.Close()
return c.Conn.Close()
}

@@ -255,12 +329,14 @@
}

func (c *conn) BeginTx(ctx context.Context, opts sqldriver.TxOptions) (sqldriver.Tx, error) {
defer c.watchCtx(ctx)()

isolation := sql.IsolationLevel(opts.Isolation)
txIsolation, ok := isolationLevelTransactionIsolation[isolation]
if !ok {
return nil, fmt.Errorf("invalid mysql transaction isolation level %s", isolation)
}
err := c.Conn.BeginTx(ctx, opts.ReadOnly, txIsolation)
err := c.Conn.BeginTx(opts.ReadOnly, txIsolation)
if err != nil {
return nil, errors.Trace(err)
}
@@ -311,6 +387,16 @@
return &result{r}, nil
}

func (c *conn) ExecContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Result, error) {
defer c.watchCtx(ctx)()
a := buildNamedArgs(args)
r, err := c.Conn.Execute(query, a...)
if err != nil {
return nil, c.state.replyError(err)
}
return &result{r}, nil
}

func (c *conn) Query(query string, args []sqldriver.Value) (sqldriver.Rows, error) {
a := buildArgs(args)
r, err := c.Conn.Execute(query, a...)
@@ -321,8 +407,9 @@
}

func (c *conn) QueryContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Rows, error) {
defer c.watchCtx(ctx)()
a := buildNamedArgs(args)
r, err := c.Conn.ExecuteContext(ctx, query, a...)
r, err := c.Conn.Execute(query, a...)
if err != nil {
return nil, c.state.replyError(err)
}
@@ -334,6 +421,10 @@
connectionState *state
}

func (s *stmt) watchCtx(ctx context.Context) func() {
return s.connectionState.watchCtx(ctx)
}

func (s *stmt) Close() error {
return s.Stmt.Close()
}
@@ -351,6 +442,17 @@
return &result{r}, nil
}

func (s *stmt) ExecContext(ctx context.Context, args []sqldriver.NamedValue) (sqldriver.Result, error) {
defer s.watchCtx(ctx)()

a := buildNamedArgs(args)
r, err := s.Stmt.Execute(a...)
if err != nil {
return nil, s.connectionState.replyError(err)
}
return &result{r}, nil
}

func (s *stmt) Query(args []sqldriver.Value) (sqldriver.Rows, error) {
a := buildArgs(args)
r, err := s.Stmt.Execute(a...)
@@ -360,6 +462,17 @@
return newRows(r.Resultset)
}

func (s *stmt) QueryContext(ctx context.Context, args []sqldriver.NamedValue) (sqldriver.Rows, error) {
defer s.watchCtx(ctx)()

a := buildNamedArgs(args)
r, err := s.Stmt.Execute(a...)
if err != nil {
return nil, s.connectionState.replyError(err)
}
return newRows(r.Resultset)
}

type tx struct {
*client.Conn
}