diff --git a/client/conn.go b/client/conn.go index f82b3c465..00b72843a 100644 --- a/client/conn.go +++ b/client/conn.go @@ -382,6 +382,21 @@ func (c *Conn) Begin() error { return errors.Trace(err) } +func (c *Conn) BeginTx(readOnly bool, txIsolation string) error { + if txIsolation != "" { + if _, err := c.exec("SET TRANSACTION ISOLATION LEVEL " + txIsolation); err != nil { + return errors.Trace(err) + } + } + var err error + if readOnly { + _, err = c.exec("START TRANSACTION READ ONLY") + } else { + _, err = c.exec("START TRANSACTION") + } + return errors.Trace(err) +} + func (c *Conn) Commit() error { _, err := c.exec("COMMIT") return errors.Trace(err) diff --git a/driver/driver.go b/driver/driver.go index fc82c2572..71827df8c 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -3,6 +3,7 @@ package driver import ( + "context" "crypto/tls" "database/sql" sqldriver "database/sql/driver" @@ -21,6 +22,23 @@ import ( "github.com/pingcap/errors" ) +var ( + _ sqldriver.Driver = &driver{} + _ sqldriver.DriverContext = &driver{} + _ sqldriver.Connector = &connInfo{} + _ sqldriver.NamedValueChecker = &conn{} + _ sqldriver.Validator = &conn{} + _ sqldriver.Conn = &conn{} + _ sqldriver.Pinger = &conn{} + _ sqldriver.ConnBeginTx = &conn{} + _ sqldriver.ConnPrepareContext = &conn{} + _ sqldriver.ExecerContext = &conn{} + _ sqldriver.QueryerContext = &conn{} + _ sqldriver.Stmt = &stmt{} + _ sqldriver.StmtExecContext = &stmt{} + _ sqldriver.StmtQueryContext = &stmt{} +) + var customTLSMutex sync.Mutex // Map of dsn address (makes more sense than full dsn?) to tls Config @@ -101,16 +119,18 @@ func parseDSN(dsn string) (connInfo, error) { // Open takes a supplied DSN string and opens a connection // See ParseDSN for more information on the form of the DSN func (d driver) Open(dsn string) (sqldriver.Conn, error) { - var ( - c *client.Conn - // by default database/sql driver retries will be enabled - retries = true - ) - ci, err := parseDSN(dsn) if err != nil { return nil, err } + return ci.Connect(context.Background()) +} + +func (ci connInfo) Connect(ctx context.Context) (sqldriver.Conn, error) { + var c *client.Conn + var err error + // by default database/sql driver retries will be enabled + retries := true if ci.standardDSN { var timeout time.Duration @@ -159,45 +179,86 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) { } } - if timeout > 0 { - c, err = client.ConnectWithTimeout(ci.addr, ci.user, ci.password, ci.db, timeout, configuredOptions...) - } else { - c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db, configuredOptions...) + if timeout <= 0 { + timeout = 10 * time.Second } + c, err = client.ConnectWithContext(ctx, ci.addr, ci.user, ci.password, ci.db, timeout, configuredOptions...) } else { // No more processing here. Let's only support url parameters with the newer style DSN - c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db) + c, err = client.ConnectWithContext(ctx, ci.addr, ci.user, ci.password, ci.db, 10*time.Second) } if err != nil { return nil, err } + contexts := make(chan context.Context) + go func() { + ctx := context.Background() + for { + var ok bool + select { + case <-ctx.Done(): + ctx = context.Background() + _ = c.Conn.Close() + case ctx, ok = <-contexts: + if !ok { + return + } + } + } + }() + // if retries are 'on' then return sqldriver.ErrBadConn which will trigger up to 3 // retries by the database/sql package. If retries are 'off' then we'll return // the native go-mysql-org/go-mysql 'mysql.ErrBadConn' erorr which will prevent a retry. // In this case the sqldriver.Validator interface is implemented and will return // false for IsValid() signaling the connection is bad and should be discarded. - return &conn{Conn: c, state: &state{valid: true, useStdLibErrors: retries}}, nil + return &conn{ + Conn: c, + state: &state{contexts: contexts, valid: true, useStdLibErrors: retries}, + }, nil } -type CheckNamedValueFunc func(*sqldriver.NamedValue) error +func (d driver) OpenConnector(name string) (sqldriver.Connector, error) { + return parseDSN(name) +} -var ( - _ sqldriver.NamedValueChecker = &conn{} - _ sqldriver.Validator = &conn{} -) +func (ci connInfo) Driver() sqldriver.Driver { + return driver{} +} + +type CheckNamedValueFunc func(*sqldriver.NamedValue) error type state struct { - valid bool + contexts chan context.Context + valid bool // when true, the driver connection will return ErrBadConn from the golang Standard Library useStdLibErrors bool } +func (s *state) watchCtx(ctx context.Context) func() { + s.contexts <- ctx + return func() { + s.contexts <- context.Background() + } +} + +func (s *state) Close() { + if s.contexts != nil { + close(s.contexts) + s.contexts = nil + } +} + type conn struct { *client.Conn state *state } +func (c *conn) watchCtx(ctx context.Context) func() { + return c.state.watchCtx(ctx) +} + func (c *conn) CheckNamedValue(nv *sqldriver.NamedValue) error { for _, nvChecker := range namedValueCheckers { err := nvChecker(nv) @@ -220,6 +281,17 @@ func (c *conn) IsValid() bool { return c.state.valid } +func (c *conn) Ping(ctx context.Context) error { + defer c.watchCtx(ctx)() + if err := c.Conn.Ping(); err != nil { + if err == context.DeadlineExceeded || err == context.Canceled { + return err + } + return sqldriver.ErrBadConn + } + return nil +} + func (c *conn) Prepare(query string) (sqldriver.Stmt, error) { st, err := c.Conn.Prepare(query) if err != nil { @@ -229,7 +301,13 @@ func (c *conn) Prepare(query string) (sqldriver.Stmt, error) { return &stmt{Stmt: st, connectionState: c.state}, nil } +func (c *conn) PrepareContext(ctx context.Context, query string) (sqldriver.Stmt, error) { + defer c.watchCtx(ctx)() + return c.Prepare(query) +} + func (c *conn) Close() error { + c.state.Close() return c.Conn.Close() } @@ -242,6 +320,29 @@ func (c *conn) Begin() (sqldriver.Tx, error) { return &tx{c.Conn}, nil } +var isolationLevelTransactionIsolation = map[sql.IsolationLevel]string{ + sql.LevelDefault: "", + sql.LevelRepeatableRead: "REPEATABLE READ", + sql.LevelReadCommitted: "READ COMMITTED", + sql.LevelReadUncommitted: "READ UNCOMMITTED", + sql.LevelSerializable: "SERIALIZABLE", +} + +func (c *conn) BeginTx(ctx context.Context, opts sqldriver.TxOptions) (sqldriver.Tx, error) { + defer c.watchCtx(ctx)() + + isolation := sql.IsolationLevel(opts.Isolation) + txIsolation, ok := isolationLevelTransactionIsolation[isolation] + if !ok { + return nil, fmt.Errorf("invalid mysql transaction isolation level %s", isolation) + } + err := c.Conn.BeginTx(opts.ReadOnly, txIsolation) + if err != nil { + return nil, errors.Trace(err) + } + return &tx{c.Conn}, nil +} + func buildArgs(args []sqldriver.Value) []interface{} { a := make([]interface{}, len(args)) @@ -252,6 +353,16 @@ func buildArgs(args []sqldriver.Value) []interface{} { return a } +func buildNamedArgs(args []sqldriver.NamedValue) []interface{} { + a := make([]interface{}, len(args)) + + for i, arg := range args { + a[i] = arg.Value + } + + return a +} + func (st *state) replyError(err error) error { isBadConnection := mysql.ErrorEqual(err, mysql.ErrBadConn) @@ -275,6 +386,16 @@ func (c *conn) Exec(query string, args []sqldriver.Value) (sqldriver.Result, err return &result{r}, nil } +func (c *conn) ExecContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Result, error) { + defer c.watchCtx(ctx)() + a := buildNamedArgs(args) + r, err := c.Conn.Execute(query, a...) + if err != nil { + return nil, c.state.replyError(err) + } + return &result{r}, nil +} + func (c *conn) Query(query string, args []sqldriver.Value) (sqldriver.Rows, error) { a := buildArgs(args) r, err := c.Conn.Execute(query, a...) @@ -284,11 +405,25 @@ func (c *conn) Query(query string, args []sqldriver.Value) (sqldriver.Rows, erro return newRows(r.Resultset) } +func (c *conn) QueryContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Rows, error) { + defer c.watchCtx(ctx)() + a := buildNamedArgs(args) + r, err := c.Conn.Execute(query, a...) + if err != nil { + return nil, c.state.replyError(err) + } + return newRows(r.Resultset) +} + type stmt struct { *client.Stmt connectionState *state } +func (s *stmt) watchCtx(ctx context.Context) func() { + return s.connectionState.watchCtx(ctx) +} + func (s *stmt) Close() error { return s.Stmt.Close() } @@ -306,6 +441,17 @@ func (s *stmt) Exec(args []sqldriver.Value) (sqldriver.Result, error) { return &result{r}, nil } +func (s *stmt) ExecContext(ctx context.Context, args []sqldriver.NamedValue) (sqldriver.Result, error) { + defer s.watchCtx(ctx)() + + a := buildNamedArgs(args) + r, err := s.Stmt.Execute(a...) + if err != nil { + return nil, s.connectionState.replyError(err) + } + return &result{r}, nil +} + func (s *stmt) Query(args []sqldriver.Value) (sqldriver.Rows, error) { a := buildArgs(args) r, err := s.Stmt.Execute(a...) @@ -315,6 +461,17 @@ func (s *stmt) Query(args []sqldriver.Value) (sqldriver.Rows, error) { return newRows(r.Resultset) } +func (s *stmt) QueryContext(ctx context.Context, args []sqldriver.NamedValue) (sqldriver.Rows, error) { + defer s.watchCtx(ctx)() + + a := buildNamedArgs(args) + r, err := s.Stmt.Execute(a...) + if err != nil { + return nil, s.connectionState.replyError(err) + } + return newRows(r.Resultset) +} + type tx struct { *client.Conn } diff --git a/packet/conn.go b/packet/conn.go index 27ead18cf..5b3930dae 100644 --- a/packet/conn.go +++ b/packet/conn.go @@ -360,7 +360,7 @@ func (c *Conn) writeCompressed(data []byte) (n int, err error) { var ( compressedLength, uncompressedLength int payload *bytes.Buffer - compressedHeader = make([]byte, 7) + compressedHeader [7]byte ) if len(data) > MinCompressionLength { @@ -406,7 +406,7 @@ func (c *Conn) writeCompressed(data []byte) (n int, err error) { compressedHeader[4] = byte(uncompressedLength) compressedHeader[5] = byte(uncompressedLength >> 8) compressedHeader[6] = byte(uncompressedLength >> 16) - if _, err = compressedPacket.Write(compressedHeader); err != nil { + if _, err = compressedPacket.Write(compressedHeader[:]); err != nil { return 0, err } c.CompressedSequence++