From f7c4036e84af00cc4f39949fef38db1f25c8ce28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Thu, 20 Feb 2025 17:47:30 +0000 Subject: [PATCH 1/7] wip --- packet/conn.go | 1 + 1 file changed, 1 insertion(+) diff --git a/packet/conn.go b/packet/conn.go index 27ead18cf..172714e0b 100644 --- a/packet/conn.go +++ b/packet/conn.go @@ -31,6 +31,7 @@ type Conn struct { readTimeout time.Duration writeTimeout time.Duration + ctx context.Context // Buffered reader for net.Conn in Non-TLS connection only to address replication performance issue. // See https://github.com/go-mysql-org/go-mysql/pull/422 for more details. From 34098189f7e440d1dfeab1f81cc94dbc29d01124 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Fri, 21 Feb 2025 04:01:44 +0000 Subject: [PATCH 2/7] poc: context support --- client/conn.go | 61 ++++++++++++++++++++++++++++++++++++++++++++++++ driver/driver.go | 45 +++++++++++++++++++++++++++++++++++ packet/conn.go | 1 - 3 files changed, 106 insertions(+), 1 deletion(-) diff --git a/client/conn.go b/client/conn.go index 928e1dfd1..b3ba07ea2 100644 --- a/client/conn.go +++ b/client/conn.go @@ -35,6 +35,8 @@ type Conn struct { // Connection read and write timeouts to set on the connection ReadTimeout time.Duration WriteTimeout time.Duration + contexts chan context.Context + closed bool // The buffer size to use in the packet connection BufferSize int @@ -136,6 +138,7 @@ func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbNam c.password = password c.db = dbName c.proto = network + c.contexts = make(chan context.Context) // use default charset here, utf-8 c.charset = mysql.DEFAULT_CHARSET @@ -184,6 +187,21 @@ func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbNam } } + go func() { + ctx := context.Background() + for { + var ok bool + select { + case <-ctx.Done(): + _ = c.Conn.SetDeadline(time.Unix(0, 0)) + case ctx, ok = <-c.contexts: + if !ok { + return + } + } + } + }() + return c, nil } @@ -208,8 +226,19 @@ func (c *Conn) handshake() error { return nil } +func (c *Conn) watchCtx(ctx context.Context) func() { + c.contexts <- ctx + return func() { + c.contexts <- context.Background() + } +} + // Close directly closes the connection. Use Quit() to first send COM_QUIT to the server and then close the connection. func (c *Conn) Close() error { + if !c.closed { + close(c.contexts) + c.closed = true + } return c.Conn.Close() } @@ -309,6 +338,11 @@ func (c *Conn) Execute(command string, args ...interface{}) (*mysql.Result, erro } } +func (c *Conn) ExecuteContext(ctx context.Context, command string, args ...interface{}) (*mysql.Result, error) { + defer c.watchCtx(ctx) + return c.Execute(command, args...) +} + // ExecuteMultiple will call perResultCallback for every result of the multiple queries // that are executed. // @@ -363,6 +397,11 @@ func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCall return mysql.NewResult(rs), nil } +func (c *Conn) ExecuteMultipleContext(ctx context.Context, query string, perResultCallback ExecPerResultCallback) (*mysql.Result, error) { + defer c.watchCtx(ctx) + return c.ExecuteMultiple(query, perResultCallback) +} + // ExecuteSelectStreaming will call perRowCallback for every row in resultset // WITHOUT saving any row data to Result.{Values/RawPkg/RowDatas} fields. // When given, perResultCallback will be called once per result @@ -376,11 +415,33 @@ func (c *Conn) ExecuteSelectStreaming(command string, result *mysql.Result, perR return c.readResultStreaming(false, result, perRowCallback, perResultCallback) } +func (c *Conn) ExecuteSelectStreamingContext(ctx context.Context, command string, result *mysql.Result, perRowCallback SelectPerRowCallback, perResultCallback SelectPerResultCallback) error { + defer c.watchCtx(ctx) + return c.ExecuteSelectStreaming(command, result, perRowCallback, perResultCallback) +} + func (c *Conn) Begin() error { _, err := c.exec("BEGIN") return errors.Trace(err) } +func (c *Conn) BeginTx(ctx context.Context, readOnly bool, txIsolation string) error { + defer c.watchCtx(ctx)() + + if txIsolation != "" { + if _, err := c.exec("SET TRANSACTION ISOLATION LEVEL " + txIsolation); err != nil { + return errors.Trace(err) + } + } + var err error + if readOnly { + _, err = c.exec("START TRANSACTION READ ONLY") + } else { + _, err = c.exec("START TRANSACTION") + } + return errors.Trace(err) +} + func (c *Conn) Commit() error { _, err := c.exec("COMMIT") return errors.Trace(err) diff --git a/driver/driver.go b/driver/driver.go index fc82c2572..bc6215269 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -3,6 +3,7 @@ package driver import ( + "context" "crypto/tls" "database/sql" sqldriver "database/sql/driver" @@ -185,6 +186,9 @@ type CheckNamedValueFunc func(*sqldriver.NamedValue) error var ( _ sqldriver.NamedValueChecker = &conn{} _ sqldriver.Validator = &conn{} + _ sqldriver.Conn = &conn{} + _ sqldriver.ConnBeginTx = &conn{} + _ sqldriver.QueryerContext = &conn{} ) type state struct { @@ -242,6 +246,27 @@ func (c *conn) Begin() (sqldriver.Tx, error) { return &tx{c.Conn}, nil } +var isolationLevelTransactionIsolation = map[sql.IsolationLevel]string{ + sql.LevelDefault: "", + sql.LevelRepeatableRead: "REPEATABLE READ", + sql.LevelReadCommitted: "READ COMMITTED", + sql.LevelReadUncommitted: "READ UNCOMMITTED", + sql.LevelSerializable: "SERIALIZABLE", +} + +func (c *conn) BeginTx(ctx context.Context, opts sqldriver.TxOptions) (sqldriver.Tx, error) { + isolation := sql.IsolationLevel(opts.Isolation) + txIsolation, ok := isolationLevelTransactionIsolation[isolation] + if !ok { + return nil, fmt.Errorf("invalid mysql transaction isolation level %s", isolation) + } + err := c.Conn.BeginTx(ctx, opts.ReadOnly, txIsolation) + if err != nil { + return nil, errors.Trace(err) + } + return &tx{c.Conn}, nil +} + func buildArgs(args []sqldriver.Value) []interface{} { a := make([]interface{}, len(args)) @@ -252,6 +277,17 @@ func buildArgs(args []sqldriver.Value) []interface{} { return a } +func buildNamedArgs(args []sqldriver.NamedValue) []interface{} { + a := make([]interface{}, len(args)) + + for i, arg := range args { + // TODO named parameter support + a[i] = arg.Value + } + + return a +} + func (st *state) replyError(err error) error { isBadConnection := mysql.ErrorEqual(err, mysql.ErrBadConn) @@ -284,6 +320,15 @@ func (c *conn) Query(query string, args []sqldriver.Value) (sqldriver.Rows, erro return newRows(r.Resultset) } +func (c *conn) QueryContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Rows, error) { + a := buildNamedArgs(args) + r, err := c.Conn.ExecuteContext(ctx, query, a...) + if err != nil { + return nil, c.state.replyError(err) + } + return newRows(r.Resultset) +} + type stmt struct { *client.Stmt connectionState *state diff --git a/packet/conn.go b/packet/conn.go index 172714e0b..27ead18cf 100644 --- a/packet/conn.go +++ b/packet/conn.go @@ -31,7 +31,6 @@ type Conn struct { readTimeout time.Duration writeTimeout time.Duration - ctx context.Context // Buffered reader for net.Conn in Non-TLS connection only to address replication performance issue. // See https://github.com/go-mysql-org/go-mysql/pull/422 for more details. From 1c3ba7d8d70344b14505fbea19cbd3db4e1bce1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Fri, 21 Feb 2025 21:14:06 +0000 Subject: [PATCH 3/7] move context management into driver Close instead of SetTimeout hack implement more interfaces, with assertions --- client/conn.go | 48 +------------- driver/driver.go | 159 ++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 137 insertions(+), 70 deletions(-) diff --git a/client/conn.go b/client/conn.go index b3ba07ea2..c1b65a6a2 100644 --- a/client/conn.go +++ b/client/conn.go @@ -35,8 +35,6 @@ type Conn struct { // Connection read and write timeouts to set on the connection ReadTimeout time.Duration WriteTimeout time.Duration - contexts chan context.Context - closed bool // The buffer size to use in the packet connection BufferSize int @@ -138,7 +136,6 @@ func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbNam c.password = password c.db = dbName c.proto = network - c.contexts = make(chan context.Context) // use default charset here, utf-8 c.charset = mysql.DEFAULT_CHARSET @@ -187,21 +184,6 @@ func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbNam } } - go func() { - ctx := context.Background() - for { - var ok bool - select { - case <-ctx.Done(): - _ = c.Conn.SetDeadline(time.Unix(0, 0)) - case ctx, ok = <-c.contexts: - if !ok { - return - } - } - } - }() - return c, nil } @@ -226,19 +208,8 @@ func (c *Conn) handshake() error { return nil } -func (c *Conn) watchCtx(ctx context.Context) func() { - c.contexts <- ctx - return func() { - c.contexts <- context.Background() - } -} - // Close directly closes the connection. Use Quit() to first send COM_QUIT to the server and then close the connection. func (c *Conn) Close() error { - if !c.closed { - close(c.contexts) - c.closed = true - } return c.Conn.Close() } @@ -338,11 +309,6 @@ func (c *Conn) Execute(command string, args ...interface{}) (*mysql.Result, erro } } -func (c *Conn) ExecuteContext(ctx context.Context, command string, args ...interface{}) (*mysql.Result, error) { - defer c.watchCtx(ctx) - return c.Execute(command, args...) -} - // ExecuteMultiple will call perResultCallback for every result of the multiple queries // that are executed. // @@ -397,11 +363,6 @@ func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCall return mysql.NewResult(rs), nil } -func (c *Conn) ExecuteMultipleContext(ctx context.Context, query string, perResultCallback ExecPerResultCallback) (*mysql.Result, error) { - defer c.watchCtx(ctx) - return c.ExecuteMultiple(query, perResultCallback) -} - // ExecuteSelectStreaming will call perRowCallback for every row in resultset // WITHOUT saving any row data to Result.{Values/RawPkg/RowDatas} fields. // When given, perResultCallback will be called once per result @@ -415,19 +376,12 @@ func (c *Conn) ExecuteSelectStreaming(command string, result *mysql.Result, perR return c.readResultStreaming(false, result, perRowCallback, perResultCallback) } -func (c *Conn) ExecuteSelectStreamingContext(ctx context.Context, command string, result *mysql.Result, perRowCallback SelectPerRowCallback, perResultCallback SelectPerResultCallback) error { - defer c.watchCtx(ctx) - return c.ExecuteSelectStreaming(command, result, perRowCallback, perResultCallback) -} - func (c *Conn) Begin() error { _, err := c.exec("BEGIN") return errors.Trace(err) } -func (c *Conn) BeginTx(ctx context.Context, readOnly bool, txIsolation string) error { - defer c.watchCtx(ctx)() - +func (c *Conn) BeginTx(readOnly bool, txIsolation string) error { if txIsolation != "" { if _, err := c.exec("SET TRANSACTION ISOLATION LEVEL " + txIsolation); err != nil { return errors.Trace(err) diff --git a/driver/driver.go b/driver/driver.go index bc6215269..8d5183623 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -22,6 +22,23 @@ import ( "github.com/pingcap/errors" ) +var ( + _ sqldriver.Driver = &driver{} + _ sqldriver.DriverContext = &driver{} + _ sqldriver.Connector = &connInfo{} + _ sqldriver.NamedValueChecker = &conn{} + _ sqldriver.Validator = &conn{} + _ sqldriver.Conn = &conn{} + _ sqldriver.Pinger = &conn{} + _ sqldriver.ConnBeginTx = &conn{} + _ sqldriver.ConnPrepareContext = &conn{} + _ sqldriver.ExecerContext = &conn{} + _ sqldriver.QueryerContext = &conn{} + _ sqldriver.Stmt = &stmt{} + _ sqldriver.StmtExecContext = &stmt{} + _ sqldriver.StmtQueryContext = &stmt{} +) + var customTLSMutex sync.Mutex // Map of dsn address (makes more sense than full dsn?) to tls Config @@ -102,16 +119,19 @@ func parseDSN(dsn string) (connInfo, error) { // Open takes a supplied DSN string and opens a connection // See ParseDSN for more information on the form of the DSN func (d driver) Open(dsn string) (sqldriver.Conn, error) { - var ( - c *client.Conn - // by default database/sql driver retries will be enabled - retries = true - ) - ci, err := parseDSN(dsn) if err != nil { return nil, err } + return ci.Connect(context.Background()) + +} + +func (ci connInfo) Connect(ctx context.Context) (sqldriver.Conn, error) { + var c *client.Conn + var err error + // by default database/sql driver retries will be enabled + retries := true if ci.standardDSN { var timeout time.Duration @@ -160,48 +180,85 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) { } } - if timeout > 0 { - c, err = client.ConnectWithTimeout(ci.addr, ci.user, ci.password, ci.db, timeout, configuredOptions...) - } else { - c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db, configuredOptions...) + if timeout <= 0 { + timeout = 10 * time.Second } + c, err = client.ConnectWithContext(ctx, ci.addr, ci.user, ci.password, ci.db, timeout, configuredOptions...) } else { // No more processing here. Let's only support url parameters with the newer style DSN - c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db) + c, err = client.ConnectWithContext(ctx, ci.addr, ci.user, ci.password, ci.db, 10*time.Second) } if err != nil { return nil, err } + contexts := make(chan context.Context) + go func() { + ctx := context.Background() + for { + var ok bool + select { + case <-ctx.Done(): + _ = c.Conn.Close() + case ctx, ok = <-contexts: + if !ok { + return + } + } + } + }() + // if retries are 'on' then return sqldriver.ErrBadConn which will trigger up to 3 // retries by the database/sql package. If retries are 'off' then we'll return // the native go-mysql-org/go-mysql 'mysql.ErrBadConn' erorr which will prevent a retry. // In this case the sqldriver.Validator interface is implemented and will return // false for IsValid() signaling the connection is bad and should be discarded. - return &conn{Conn: c, state: &state{valid: true, useStdLibErrors: retries}}, nil + return &conn{ + Conn: c, + state: &state{contexts: contexts, valid: true, useStdLibErrors: retries}, + }, nil } -type CheckNamedValueFunc func(*sqldriver.NamedValue) error +func (d driver) OpenConnector(name string) (sqldriver.Connector, error) { + return parseDSN(name) +} -var ( - _ sqldriver.NamedValueChecker = &conn{} - _ sqldriver.Validator = &conn{} - _ sqldriver.Conn = &conn{} - _ sqldriver.ConnBeginTx = &conn{} - _ sqldriver.QueryerContext = &conn{} -) +func (ci connInfo) Driver() sqldriver.Driver { + return driver{} +} + +type CheckNamedValueFunc func(*sqldriver.NamedValue) error type state struct { - valid bool + contexts chan context.Context + valid bool // when true, the driver connection will return ErrBadConn from the golang Standard Library useStdLibErrors bool } +func (s *state) watchCtx(ctx context.Context) func() { + s.contexts <- ctx + return func() { + s.contexts <- context.Background() + } +} + +func (s *state) Close() { + if s.contexts != nil { + close(s.contexts) + s.contexts = nil + } +} + type conn struct { *client.Conn state *state } +func (c *conn) watchCtx(ctx context.Context) func() { + return c.state.watchCtx(ctx) +} + func (c *conn) CheckNamedValue(nv *sqldriver.NamedValue) error { for _, nvChecker := range namedValueCheckers { err := nvChecker(nv) @@ -224,6 +281,17 @@ func (c *conn) IsValid() bool { return c.state.valid } +func (c *conn) Ping(ctx context.Context) error { + defer c.watchCtx(ctx)() + if err := c.Conn.Ping(); err != nil { + if err == context.DeadlineExceeded || err == context.Canceled { + return err + } + return sqldriver.ErrBadConn + } + return nil +} + func (c *conn) Prepare(query string) (sqldriver.Stmt, error) { st, err := c.Conn.Prepare(query) if err != nil { @@ -233,7 +301,13 @@ func (c *conn) Prepare(query string) (sqldriver.Stmt, error) { return &stmt{Stmt: st, connectionState: c.state}, nil } +func (c *conn) PrepareContext(ctx context.Context, query string) (sqldriver.Stmt, error) { + defer c.watchCtx(ctx)() + return c.Prepare(query) +} + func (c *conn) Close() error { + c.state.Close() return c.Conn.Close() } @@ -255,12 +329,14 @@ var isolationLevelTransactionIsolation = map[sql.IsolationLevel]string{ } func (c *conn) BeginTx(ctx context.Context, opts sqldriver.TxOptions) (sqldriver.Tx, error) { + defer c.watchCtx(ctx)() + isolation := sql.IsolationLevel(opts.Isolation) txIsolation, ok := isolationLevelTransactionIsolation[isolation] if !ok { return nil, fmt.Errorf("invalid mysql transaction isolation level %s", isolation) } - err := c.Conn.BeginTx(ctx, opts.ReadOnly, txIsolation) + err := c.Conn.BeginTx(opts.ReadOnly, txIsolation) if err != nil { return nil, errors.Trace(err) } @@ -311,6 +387,16 @@ func (c *conn) Exec(query string, args []sqldriver.Value) (sqldriver.Result, err return &result{r}, nil } +func (c *conn) ExecContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Result, error) { + defer c.watchCtx(ctx)() + a := buildNamedArgs(args) + r, err := c.Conn.Execute(query, a...) + if err != nil { + return nil, c.state.replyError(err) + } + return &result{r}, nil +} + func (c *conn) Query(query string, args []sqldriver.Value) (sqldriver.Rows, error) { a := buildArgs(args) r, err := c.Conn.Execute(query, a...) @@ -321,8 +407,9 @@ func (c *conn) Query(query string, args []sqldriver.Value) (sqldriver.Rows, erro } func (c *conn) QueryContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Rows, error) { + defer c.watchCtx(ctx)() a := buildNamedArgs(args) - r, err := c.Conn.ExecuteContext(ctx, query, a...) + r, err := c.Conn.Execute(query, a...) if err != nil { return nil, c.state.replyError(err) } @@ -334,6 +421,10 @@ type stmt struct { connectionState *state } +func (s *stmt) watchCtx(ctx context.Context) func() { + return s.connectionState.watchCtx(ctx) +} + func (s *stmt) Close() error { return s.Stmt.Close() } @@ -351,6 +442,17 @@ func (s *stmt) Exec(args []sqldriver.Value) (sqldriver.Result, error) { return &result{r}, nil } +func (s *stmt) ExecContext(ctx context.Context, args []sqldriver.NamedValue) (sqldriver.Result, error) { + defer s.watchCtx(ctx)() + + a := buildNamedArgs(args) + r, err := s.Stmt.Execute(a...) + if err != nil { + return nil, s.connectionState.replyError(err) + } + return &result{r}, nil +} + func (s *stmt) Query(args []sqldriver.Value) (sqldriver.Rows, error) { a := buildArgs(args) r, err := s.Stmt.Execute(a...) @@ -360,6 +462,17 @@ func (s *stmt) Query(args []sqldriver.Value) (sqldriver.Rows, error) { return newRows(r.Resultset) } +func (s *stmt) QueryContext(ctx context.Context, args []sqldriver.NamedValue) (sqldriver.Rows, error) { + defer s.watchCtx(ctx)() + + a := buildNamedArgs(args) + r, err := s.Stmt.Execute(a...) + if err != nil { + return nil, s.connectionState.replyError(err) + } + return newRows(r.Resultset) +} + type tx struct { *client.Conn } From f830d5000b56a752f556c8e3d6717a678026cd53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Fri, 21 Feb 2025 21:17:32 +0000 Subject: [PATCH 4/7] lint --- driver/driver.go | 1 - 1 file changed, 1 deletion(-) diff --git a/driver/driver.go b/driver/driver.go index 8d5183623..d6629f1c3 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -124,7 +124,6 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) { return nil, err } return ci.Connect(context.Background()) - } func (ci connInfo) Connect(ctx context.Context) (sqldriver.Conn, error) { From 52d1f177eb7ce02d3ebefa4fa3d416f01288e5a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Fri, 21 Feb 2025 21:19:55 +0000 Subject: [PATCH 5/7] avoid excessive spinning when context canceled & waiting on contexts channel to close --- driver/driver.go | 1 + 1 file changed, 1 insertion(+) diff --git a/driver/driver.go b/driver/driver.go index d6629f1c3..65193c7f4 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -198,6 +198,7 @@ func (ci connInfo) Connect(ctx context.Context) (sqldriver.Conn, error) { var ok bool select { case <-ctx.Done(): + ctx = context.Background() _ = c.Conn.Close() case ctx, ok = <-contexts: if !ok { From aaf9d0e5b67e9734ca47fb876bb35de62859a476 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Fri, 21 Feb 2025 22:40:02 +0000 Subject: [PATCH 6/7] compressedHeader tweak --- packet/conn.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packet/conn.go b/packet/conn.go index 27ead18cf..5b3930dae 100644 --- a/packet/conn.go +++ b/packet/conn.go @@ -360,7 +360,7 @@ func (c *Conn) writeCompressed(data []byte) (n int, err error) { var ( compressedLength, uncompressedLength int payload *bytes.Buffer - compressedHeader = make([]byte, 7) + compressedHeader [7]byte ) if len(data) > MinCompressionLength { @@ -406,7 +406,7 @@ func (c *Conn) writeCompressed(data []byte) (n int, err error) { compressedHeader[4] = byte(uncompressedLength) compressedHeader[5] = byte(uncompressedLength >> 8) compressedHeader[6] = byte(uncompressedLength >> 16) - if _, err = compressedPacket.Write(compressedHeader); err != nil { + if _, err = compressedPacket.Write(compressedHeader[:]); err != nil { return 0, err } c.CompressedSequence++ From 33ce3af54da247cc2229309dfbd8bf44df32715e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Sat, 22 Feb 2025 02:26:13 +0000 Subject: [PATCH 7/7] mysql doesn't have named parameters --- driver/driver.go | 1 - 1 file changed, 1 deletion(-) diff --git a/driver/driver.go b/driver/driver.go index 65193c7f4..71827df8c 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -357,7 +357,6 @@ func buildNamedArgs(args []sqldriver.NamedValue) []interface{} { a := make([]interface{}, len(args)) for i, arg := range args { - // TODO named parameter support a[i] = arg.Value }