From 740b4c6b98d67c69d31bd6cee9aff40ae4dd59ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Mon, 27 Jan 2025 08:08:07 +0100 Subject: [PATCH 01/21] client,mysql: Add support for Query Attributes --- client/auth.go | 1 + client/conn.go | 39 +++++++++++++- client/stmt.go | 111 +++++++++++++++++++++++++-------------- mysql/const.go | 9 ++++ mysql/queryattributes.go | 46 ++++++++++++++++ 5 files changed, 165 insertions(+), 41 deletions(-) create mode 100644 mysql/queryattributes.go diff --git a/client/auth.go b/client/auth.go index 200009609..6b6b9b866 100644 --- a/client/auth.go +++ b/client/auth.go @@ -203,6 +203,7 @@ func (c *Conn) writeAuthHandshake() error { CLIENT_LONG_PASSWORD | CLIENT_TRANSACTIONS | CLIENT_PLUGIN_AUTH // Adjust client capability flags based on server support capability |= c.capability & CLIENT_LONG_FLAG + capability |= c.capability & CLIENT_QUERY_ATTRIBUTES // Adjust client capability flags on specific client requests // Only flags that would make any sense setting and aren't handled elsewhere // in the library are supported here diff --git a/client/conn.go b/client/conn.go index 9cdf93afc..f7ab12c72 100644 --- a/client/conn.go +++ b/client/conn.go @@ -56,6 +56,8 @@ type Conn struct { authPluginName string connectionID uint32 + + queryAttributes []QueryAttribute } // This function will be called for every row in resultset from ExecuteSelectStreaming. @@ -481,10 +483,40 @@ func (c *Conn) ReadOKPacket() (*Result, error) { return c.readOK() } +// Sends COM_QUERY +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query.html func (c *Conn) exec(query string) (*Result, error) { - if err := c.writeCommandStr(COM_QUERY, query); err != nil { + var buf bytes.Buffer + + if c.capability&CLIENT_QUERY_ATTRIBUTES > 0 { + numParams := len(c.queryAttributes) + buf.Write(PutLengthEncodedInt(uint64(numParams))) + buf.WriteByte(0x1) // parameter_set_count, unused + if numParams > 0 { + // null_bitmap, length: (num_params+7)/8 + for i := 0; i < (numParams+7)/8; i++ { + buf.WriteByte(0x0) + } + buf.WriteByte(0x1) // new_params_bind_flag, unused + for _, qa := range c.queryAttributes { + buf.Write(qa.TypeAndFlag()) + buf.Write(PutLengthEncodedString([]byte(qa.Name))) + } + for _, qa := range c.queryAttributes { + buf.Write(qa.ValueBytes()) + } + } + } + + _, err := buf.Write(utils.StringToByteSlice(query)) + if err != nil { + return nil, err + } + + if err := c.writeCommandBuf(COM_QUERY, buf.Bytes()); err != nil { return nil, errors.Trace(err) } + c.queryAttributes = nil return c.readResult(false) } @@ -619,3 +651,8 @@ func (c *Conn) StatusString() string { return strings.Join(stats, "|") } + +func (c *Conn) SetQueryAttributes(attrs ...QueryAttribute) error { + c.queryAttributes = attrs + return nil +} diff --git a/client/stmt.go b/client/stmt.go index 82c760d72..ac3e3015d 100644 --- a/client/stmt.go +++ b/client/stmt.go @@ -56,6 +56,7 @@ func (s *Stmt) Close() error { return nil } +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html func (s *Stmt) write(args ...interface{}) error { paramsNum := s.params @@ -63,11 +64,14 @@ func (s *Stmt) write(args ...interface{}) error { return fmt.Errorf("argument mismatch, need %d but got %d", s.params, len(args)) } - paramTypes := make([]byte, paramsNum<<1) - paramValues := make([][]byte, paramsNum) + qaLen := len(s.conn.queryAttributes) + paramTypes := make([][]byte, paramsNum+qaLen) + paramFlags := make([][]byte, paramsNum+qaLen) + paramValues := make([][]byte, paramsNum+qaLen) + paramNames := make([][]byte, paramsNum+qaLen) //NULL-bitmap, length: (num-params+7) - nullBitmap := make([]byte, (paramsNum+7)>>3) + nullBitmap := make([]byte, (paramsNum+qaLen+7)>>3) length := 1 + 4 + 1 + 4 + ((paramsNum + 7) >> 3) + 1 + (paramsNum << 1) @@ -76,7 +80,7 @@ func (s *Stmt) write(args ...interface{}) error { for i := range args { if args[i] == nil { nullBitmap[i/8] |= 1 << (uint(i) % 8) - paramTypes[i<<1] = MYSQL_TYPE_NULL + paramTypes[i] = []byte{MYSQL_TYPE_NULL} continue } @@ -84,68 +88,79 @@ func (s *Stmt) write(args ...interface{}) error { switch v := args[i].(type) { case int8: - paramTypes[i<<1] = MYSQL_TYPE_TINY + paramTypes[i] = []byte{MYSQL_TYPE_TINY} paramValues[i] = []byte{byte(v)} case int16: - paramTypes[i<<1] = MYSQL_TYPE_SHORT + paramTypes[i] = []byte{MYSQL_TYPE_SHORT} paramValues[i] = Uint16ToBytes(uint16(v)) case int32: - paramTypes[i<<1] = MYSQL_TYPE_LONG + paramTypes[i] = []byte{MYSQL_TYPE_LONG} paramValues[i] = Uint32ToBytes(uint32(v)) case int: - paramTypes[i<<1] = MYSQL_TYPE_LONGLONG + paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} paramValues[i] = Uint64ToBytes(uint64(v)) case int64: - paramTypes[i<<1] = MYSQL_TYPE_LONGLONG + paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} paramValues[i] = Uint64ToBytes(uint64(v)) case uint8: - paramTypes[i<<1] = MYSQL_TYPE_TINY - paramTypes[(i<<1)+1] = 0x80 + paramTypes[i] = []byte{MYSQL_TYPE_TINY} + paramFlags[i] = []byte{UNSIGNED_FLAG} paramValues[i] = []byte{v} case uint16: - paramTypes[i<<1] = MYSQL_TYPE_SHORT - paramTypes[(i<<1)+1] = 0x80 + paramTypes[i] = []byte{MYSQL_TYPE_SHORT} + paramFlags[i] = []byte{UNSIGNED_FLAG} paramValues[i] = Uint16ToBytes(v) case uint32: - paramTypes[i<<1] = MYSQL_TYPE_LONG - paramTypes[(i<<1)+1] = 0x80 + paramTypes[i] = []byte{MYSQL_TYPE_LONG} + paramFlags[i] = []byte{UNSIGNED_FLAG} paramValues[i] = Uint32ToBytes(v) case uint: - paramTypes[i<<1] = MYSQL_TYPE_LONGLONG - paramTypes[(i<<1)+1] = 0x80 + paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} + paramFlags[i] = []byte{UNSIGNED_FLAG} paramValues[i] = Uint64ToBytes(uint64(v)) case uint64: - paramTypes[i<<1] = MYSQL_TYPE_LONGLONG - paramTypes[(i<<1)+1] = 0x80 + paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} + paramFlags[i] = []byte{UNSIGNED_FLAG} paramValues[i] = Uint64ToBytes(v) case bool: - paramTypes[i<<1] = MYSQL_TYPE_TINY + paramTypes[i] = []byte{MYSQL_TYPE_TINY} if v { paramValues[i] = []byte{1} } else { paramValues[i] = []byte{0} } case float32: - paramTypes[i<<1] = MYSQL_TYPE_FLOAT + paramTypes[i] = []byte{MYSQL_TYPE_FLOAT} paramValues[i] = Uint32ToBytes(math.Float32bits(v)) case float64: - paramTypes[i<<1] = MYSQL_TYPE_DOUBLE + paramTypes[i] = []byte{MYSQL_TYPE_DOUBLE} paramValues[i] = Uint64ToBytes(math.Float64bits(v)) case string: - paramTypes[i<<1] = MYSQL_TYPE_STRING + paramTypes[i] = []byte{MYSQL_TYPE_STRING} paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...) case []byte: - paramTypes[i<<1] = MYSQL_TYPE_STRING + paramTypes[i] = []byte{MYSQL_TYPE_STRING} paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...) case json.RawMessage: - paramTypes[i<<1] = MYSQL_TYPE_STRING + paramTypes[i] = []byte{MYSQL_TYPE_STRING} paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...) default: return fmt.Errorf("invalid argument type %T", args[i]) } + paramNames[i] = []byte{0} // lenght encoded, no name + if paramFlags[i] == nil { + paramFlags[i] = []byte{0} + } length += len(paramValues[i]) } + for i, qa := range s.conn.queryAttributes { + tf := qa.TypeAndFlag() + paramTypes[(i + paramsNum)] = []byte{tf[0]} + paramFlags[i+paramsNum] = []byte{tf[1]} + paramValues[i+paramsNum] = qa.ValueBytes() + paramNames[i+paramsNum] = PutLengthEncodedString([]byte(qa.Name)) + } data := utils.BytesBufferGet() defer func() { @@ -159,30 +174,46 @@ func (s *Stmt) write(args ...interface{}) error { data.WriteByte(COM_STMT_EXECUTE) data.Write([]byte{byte(s.id), byte(s.id >> 8), byte(s.id >> 16), byte(s.id >> 24)}) - //flag: CURSOR_TYPE_NO_CURSOR - data.WriteByte(0x00) + flags := CURSOR_TYPE_NO_CURSOR + if s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 && len(s.conn.queryAttributes) > 0 { + flags |= PARAMETER_COUNT_AVAILABLE + } + data.WriteByte(flags) //iteration-count, always 1 data.Write([]byte{1, 0, 0, 0}) - if s.params > 0 { - data.Write(nullBitmap) - - //new-params-bound-flag - data.WriteByte(newParamBoundFlag) - - if newParamBoundFlag == 1 { - //type of each parameter, length: num-params * 2 - data.Write(paramTypes) - - //value of each parameter - for _, v := range paramValues { - data.Write(v) + if paramsNum > 0 || (s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 && (flags&PARAMETER_COUNT_AVAILABLE > 0)) { + if s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 { + paramsNum += len(s.conn.queryAttributes) + data.Write(PutLengthEncodedInt(uint64(paramsNum))) + } + if paramsNum > 0 { + data.Write(nullBitmap) + + //new-params-bound-flag + data.WriteByte(newParamBoundFlag) + + if newParamBoundFlag == 1 { + for i := 0; i < paramsNum; i++ { + data.Write(paramTypes[i]) + data.Write(paramFlags[i]) + + if s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 { + data.Write(paramNames[i]) + } + } + + //value of each parameter + for _, v := range paramValues { + data.Write(v) + } } } } s.conn.ResetSequence() + s.conn.queryAttributes = nil return s.conn.WritePacket(data.Bytes()) } diff --git a/mysql/const.go b/mysql/const.go index 09c62cc3d..d6259e5f4 100644 --- a/mysql/const.go +++ b/mysql/const.go @@ -209,3 +209,12 @@ const ( MYSQL_COMPRESS_ZLIB MYSQL_COMPRESS_ZSTD ) + +// See enum_cursor_type in mysql.h +const ( + CURSOR_TYPE_NO_CURSOR byte = 0x0 + CURSOR_TYPE_READ_ONLY byte = 0x1 + CURSOR_TYPE_FOR_UPDATE byte = 0x2 + CURSOR_TYPE_SCROLLABLE byte = 0x4 + PARAMETER_COUNT_AVAILABLE byte = 0x8 +) diff --git a/mysql/queryattributes.go b/mysql/queryattributes.go new file mode 100644 index 000000000..ca5e8c99a --- /dev/null +++ b/mysql/queryattributes.go @@ -0,0 +1,46 @@ +package mysql + +import ( + "encoding/binary" + + "github.com/siddontang/go-log/log" +) + +// Query Attributes in MySQL are key/value pairs passed along with COM_QUERY or COM_STMT_EXECUTE +// +// Resources: +// - https://dev.mysql.com/doc/refman/8.4/en/query-attributes.html +// - https://github.com/mysql/mysql-server/blob/trunk/include/mysql/components/services/mysql_query_attributes.h +// - https://archive.fosdem.org/2021/schedule/event/mysql_protocl/ +type QueryAttribute struct { + Name string + Value interface{} +} + +// TypeAndFlag returns the type MySQL field type of the value and the field flag. +func (qa *QueryAttribute) TypeAndFlag() []byte { + switch v := qa.Value.(type) { + case string: + return []byte{MYSQL_TYPE_STRING, 0x0} + case uint64: + return []byte{MYSQL_TYPE_LONGLONG, UNSIGNED_FLAG} + default: + log.Warnf("query attribute with unsupported type %T", v) + } + return []byte{0x0, 0x0} // type 0x0, flag 0x0, to not break the protocol +} + +// ValueBytes returns the encoded value +func (qa *QueryAttribute) ValueBytes() []byte { + switch v := qa.Value.(type) { + case string: + return PutLengthEncodedString([]byte(v)) + case uint64: + b := make([]byte, 8) + binary.LittleEndian.PutUint64(b, v) + return b + default: + log.Warnf("query attribute with unsupported type %T", v) + } + return []byte{0x0} // 0 length value to not break the protocol +} From 1390b5891d4d4015d3ccd746a2e9db168c64ddbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Mon, 27 Jan 2025 08:24:18 +0100 Subject: [PATCH 02/21] fixup --- client/stmt.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/stmt.go b/client/stmt.go index ac3e3015d..bc1a063f2 100644 --- a/client/stmt.go +++ b/client/stmt.go @@ -147,7 +147,7 @@ func (s *Stmt) write(args ...interface{}) error { default: return fmt.Errorf("invalid argument type %T", args[i]) } - paramNames[i] = []byte{0} // lenght encoded, no name + paramNames[i] = []byte{0} // length encoded, no name if paramFlags[i] == nil { paramFlags[i] = []byte{0} } From f56205095888f705890d3046c6bd620d6177a470 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Mon, 3 Feb 2025 11:02:55 +0100 Subject: [PATCH 03/21] Add includeLine option --- client/conn.go | 21 +++++++++++++++++++++ client/stmt.go | 13 +++++++++++++ 2 files changed, 34 insertions(+) diff --git a/client/conn.go b/client/conn.go index f7ab12c72..319bd3ad1 100644 --- a/client/conn.go +++ b/client/conn.go @@ -7,6 +7,7 @@ import ( "fmt" "net" "runtime" + "runtime/debug" "strings" "time" @@ -58,6 +59,8 @@ type Conn struct { connectionID uint32 queryAttributes []QueryAttribute + + includeLine bool } // This function will be called for every row in resultset from ExecuteSelectStreaming. @@ -488,6 +491,18 @@ func (c *Conn) ReadOKPacket() (*Result, error) { func (c *Conn) exec(query string) (*Result, error) { var buf bytes.Buffer + if c.includeLine { + _, file, line, ok := runtime.Caller(2) + if ok { + lineAttr := QueryAttribute{ + Name: "_line", + Value: fmt.Sprintf("%s:%d", file, line), + } + c.queryAttributes = append(c.queryAttributes, lineAttr) + } + + } + if c.capability&CLIENT_QUERY_ATTRIBUTES > 0 { numParams := len(c.queryAttributes) buf.Write(PutLengthEncodedInt(uint64(numParams))) @@ -656,3 +671,9 @@ func (c *Conn) SetQueryAttributes(attrs ...QueryAttribute) error { c.queryAttributes = attrs return nil } + +// IncludeLine can be passed as option when connecting to include the file name and line number +// of the caller as query attribute when sending queries. +func (c *Conn) IncludeLine() { + c.includeLine = true +} diff --git a/client/stmt.go b/client/stmt.go index bc1a063f2..c9d89d9d6 100644 --- a/client/stmt.go +++ b/client/stmt.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "math" + "runtime" . "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/utils" @@ -64,6 +65,18 @@ func (s *Stmt) write(args ...interface{}) error { return fmt.Errorf("argument mismatch, need %d but got %d", s.params, len(args)) } + if s.conn.includeLine { + _, file, line, ok := runtime.Caller(2) + if ok { + lineAttr := QueryAttribute{ + Name: "_line", + Value: fmt.Sprintf("%s:%d", file, line), + } + s.conn.queryAttributes = append(s.conn.queryAttributes, lineAttr) + } + + } + qaLen := len(s.conn.queryAttributes) paramTypes := make([][]byte, paramsNum+qaLen) paramFlags := make([][]byte, paramsNum+qaLen) From 4cbdd13df703e4ae1642ba8d0cab2d15cdfed18e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Mon, 3 Feb 2025 11:03:18 +0100 Subject: [PATCH 04/21] Add _client_version connattr --- client/conn.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/client/conn.go b/client/conn.go index 319bd3ad1..73eeb414a 100644 --- a/client/conn.go +++ b/client/conn.go @@ -105,9 +105,13 @@ func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbNam c := new(Conn) c.BufferSize = defaultBufferSize + clientVersion := "unknown" + if buildInfo, ok := debug.ReadBuildInfo(); ok { + clientVersion = buildInfo.Main.Version + } c.attributes = map[string]string{ - "_client_name": "go-mysql", - // "_client_version": "0.1", + "_client_name": "go-mysql", + "_client_version": clientVersion, "_os": runtime.GOOS, "_platform": runtime.GOARCH, "_runtime_version": runtime.Version(), From deb44927671561d6ba7ceb0ad4681eed3ebcfb4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Mon, 3 Feb 2025 12:38:34 +0100 Subject: [PATCH 05/21] Fix nil args --- client/stmt.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/client/stmt.go b/client/stmt.go index c9d89d9d6..d4d460648 100644 --- a/client/stmt.go +++ b/client/stmt.go @@ -94,6 +94,8 @@ func (s *Stmt) write(args ...interface{}) error { if args[i] == nil { nullBitmap[i/8] |= 1 << (uint(i) % 8) paramTypes[i] = []byte{MYSQL_TYPE_NULL} + paramNames[i] = []byte{0} // length encoded, no name + paramFlags[i] = []byte{0} continue } From 2a29ba0cfe6fe8959510fe79a7409f7ff4f4b4e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Tue, 4 Feb 2025 09:27:11 +0100 Subject: [PATCH 06/21] update --- client/conn.go | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/client/conn.go b/client/conn.go index 73eeb414a..1cdf6c368 100644 --- a/client/conn.go +++ b/client/conn.go @@ -7,7 +7,6 @@ import ( "fmt" "net" "runtime" - "runtime/debug" "strings" "time" @@ -105,13 +104,9 @@ func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbNam c := new(Conn) c.BufferSize = defaultBufferSize - clientVersion := "unknown" - if buildInfo, ok := debug.ReadBuildInfo(); ok { - clientVersion = buildInfo.Main.Version - } c.attributes = map[string]string{ - "_client_name": "go-mysql", - "_client_version": clientVersion, + "_client_name": "go-mysql", + // "_client_version": "0.1", "_os": runtime.GOOS, "_platform": runtime.GOARCH, "_runtime_version": runtime.Version(), From 6f41c366d0962e3d4d1ab65b0e6f7c8bb03fcf1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Tue, 4 Feb 2025 10:35:22 +0100 Subject: [PATCH 07/21] update --- client/stmt.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/stmt.go b/client/stmt.go index d4d460648..1d612cb93 100644 --- a/client/stmt.go +++ b/client/stmt.go @@ -190,7 +190,7 @@ func (s *Stmt) write(args ...interface{}) error { data.Write([]byte{byte(s.id), byte(s.id >> 8), byte(s.id >> 16), byte(s.id >> 24)}) flags := CURSOR_TYPE_NO_CURSOR - if s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 && len(s.conn.queryAttributes) > 0 { + if paramsNum > 0 { flags |= PARAMETER_COUNT_AVAILABLE } data.WriteByte(flags) From 5ad9e060360b62d28cc61ff647859f637add39ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Tue, 4 Feb 2025 10:41:02 +0100 Subject: [PATCH 08/21] fix flags --- client/stmt.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/client/stmt.go b/client/stmt.go index 1d612cb93..0bd546539 100644 --- a/client/stmt.go +++ b/client/stmt.go @@ -119,23 +119,23 @@ func (s *Stmt) write(args ...interface{}) error { paramValues[i] = Uint64ToBytes(uint64(v)) case uint8: paramTypes[i] = []byte{MYSQL_TYPE_TINY} - paramFlags[i] = []byte{UNSIGNED_FLAG} + paramFlags[i] = []byte{BINARY_FLAG} paramValues[i] = []byte{v} case uint16: paramTypes[i] = []byte{MYSQL_TYPE_SHORT} - paramFlags[i] = []byte{UNSIGNED_FLAG} + paramFlags[i] = []byte{BINARY_FLAG} paramValues[i] = Uint16ToBytes(v) case uint32: paramTypes[i] = []byte{MYSQL_TYPE_LONG} - paramFlags[i] = []byte{UNSIGNED_FLAG} + paramFlags[i] = []byte{BINARY_FLAG} paramValues[i] = Uint32ToBytes(v) case uint: paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} - paramFlags[i] = []byte{UNSIGNED_FLAG} + paramFlags[i] = []byte{BINARY_FLAG} paramValues[i] = Uint64ToBytes(uint64(v)) case uint64: paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} - paramFlags[i] = []byte{UNSIGNED_FLAG} + paramFlags[i] = []byte{BINARY_FLAG} paramValues[i] = Uint64ToBytes(v) case bool: paramTypes[i] = []byte{MYSQL_TYPE_TINY} From 3143ff87252c1712c6fa656ceb09e28f2fce1faf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Tue, 4 Feb 2025 13:05:31 +0100 Subject: [PATCH 09/21] update --- client/conn.go | 21 +++++++++++++++------ client/stmt.go | 10 +++++----- mysql/const.go | 4 ++++ 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/client/conn.go b/client/conn.go index 1cdf6c368..500bfb198 100644 --- a/client/conn.go +++ b/client/conn.go @@ -306,7 +306,7 @@ func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) { // flag set to signal the server multiple queries are executed. Handling the responses // is up to the implementation of perResultCallback. func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCallback) (*Result, error) { - if err := c.writeCommandStr(COM_QUERY, query); err != nil { + if err := c.exec_send(query); err != nil { return nil, errors.Trace(err) } @@ -359,7 +359,7 @@ func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCall // // ExecuteSelectStreaming should be used only for SELECT queries with a large response resultset for memory preserving. func (c *Conn) ExecuteSelectStreaming(command string, result *Result, perRowCallback SelectPerRowCallback, perResultCallback SelectPerResultCallback) error { - if err := c.writeCommandStr(COM_QUERY, command); err != nil { + if err := c.exec_send(command); err != nil { return errors.Trace(err) } @@ -485,9 +485,18 @@ func (c *Conn) ReadOKPacket() (*Result, error) { return c.readOK() } +// Send COM_QUERY and read the result +func (c *Conn) exec(query string) (*Result, error) { + err := c.exec_send(query) + if err != nil { + return nil, err + } + return c.readResult(false) +} + // Sends COM_QUERY // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query.html -func (c *Conn) exec(query string) (*Result, error) { +func (c *Conn) exec_send(query string) error { var buf bytes.Buffer if c.includeLine { @@ -524,15 +533,15 @@ func (c *Conn) exec(query string) (*Result, error) { _, err := buf.Write(utils.StringToByteSlice(query)) if err != nil { - return nil, err + return err } if err := c.writeCommandBuf(COM_QUERY, buf.Bytes()); err != nil { - return nil, errors.Trace(err) + return errors.Trace(err) } c.queryAttributes = nil - return c.readResult(false) + return nil } // CapabilityString is returning a string with the names of capability flags diff --git a/client/stmt.go b/client/stmt.go index 0bd546539..1adf322d4 100644 --- a/client/stmt.go +++ b/client/stmt.go @@ -119,23 +119,23 @@ func (s *Stmt) write(args ...interface{}) error { paramValues[i] = Uint64ToBytes(uint64(v)) case uint8: paramTypes[i] = []byte{MYSQL_TYPE_TINY} - paramFlags[i] = []byte{BINARY_FLAG} + paramFlags[i] = []byte{PARAM_UNSIGNED} paramValues[i] = []byte{v} case uint16: paramTypes[i] = []byte{MYSQL_TYPE_SHORT} - paramFlags[i] = []byte{BINARY_FLAG} + paramFlags[i] = []byte{PARAM_UNSIGNED} paramValues[i] = Uint16ToBytes(v) case uint32: paramTypes[i] = []byte{MYSQL_TYPE_LONG} - paramFlags[i] = []byte{BINARY_FLAG} + paramFlags[i] = []byte{PARAM_UNSIGNED} paramValues[i] = Uint32ToBytes(v) case uint: paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} - paramFlags[i] = []byte{BINARY_FLAG} + paramFlags[i] = []byte{PARAM_UNSIGNED} paramValues[i] = Uint64ToBytes(uint64(v)) case uint64: paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} - paramFlags[i] = []byte{BINARY_FLAG} + paramFlags[i] = []byte{PARAM_UNSIGNED} paramValues[i] = Uint64ToBytes(v) case bool: paramTypes[i] = []byte{MYSQL_TYPE_TINY} diff --git a/mysql/const.go b/mysql/const.go index d6259e5f4..9d0a5391f 100644 --- a/mysql/const.go +++ b/mysql/const.go @@ -178,6 +178,10 @@ const ( UNIQUE_FLAG = 65536 ) +const ( + PARAM_UNSIGNED = 128 +) + const ( DEFAULT_ADDR = "127.0.0.1:3306" DEFAULT_IPV6_ADDR = "[::1]:3306" From a4f390b62b2f002a6831cbb444c5f2b3d9f6362d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Tue, 4 Feb 2025 13:09:21 +0100 Subject: [PATCH 10/21] updates --- client/conn.go | 2 +- client/stmt.go | 1 - mysql/queryattributes.go | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/client/conn.go b/client/conn.go index 500bfb198..e55c0248f 100644 --- a/client/conn.go +++ b/client/conn.go @@ -508,7 +508,6 @@ func (c *Conn) exec_send(query string) error { } c.queryAttributes = append(c.queryAttributes, lineAttr) } - } if c.capability&CLIENT_QUERY_ATTRIBUTES > 0 { @@ -675,6 +674,7 @@ func (c *Conn) StatusString() string { return strings.Join(stats, "|") } +// SetQueryAttributes sets the query attributes to be send along with the next query func (c *Conn) SetQueryAttributes(attrs ...QueryAttribute) error { c.queryAttributes = attrs return nil diff --git a/client/stmt.go b/client/stmt.go index 1adf322d4..5f8853988 100644 --- a/client/stmt.go +++ b/client/stmt.go @@ -74,7 +74,6 @@ func (s *Stmt) write(args ...interface{}) error { } s.conn.queryAttributes = append(s.conn.queryAttributes, lineAttr) } - } qaLen := len(s.conn.queryAttributes) diff --git a/mysql/queryattributes.go b/mysql/queryattributes.go index ca5e8c99a..d70bd73e7 100644 --- a/mysql/queryattributes.go +++ b/mysql/queryattributes.go @@ -23,7 +23,7 @@ func (qa *QueryAttribute) TypeAndFlag() []byte { case string: return []byte{MYSQL_TYPE_STRING, 0x0} case uint64: - return []byte{MYSQL_TYPE_LONGLONG, UNSIGNED_FLAG} + return []byte{MYSQL_TYPE_LONGLONG, PARAM_UNSIGNED} default: log.Warnf("query attribute with unsupported type %T", v) } From 2d288d4dd5b9bbc0f39c1bad221d8f190793ebc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Tue, 4 Feb 2025 13:41:55 +0100 Subject: [PATCH 11/21] server: accept PARAMETER_COUNT_AVAILABLE flag --- server/stmt.go | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/server/stmt.go b/server/stmt.go index bc7902e19..99d950fd2 100644 --- a/server/stmt.go +++ b/server/stmt.go @@ -107,9 +107,28 @@ func (c *Conn) handleStmtExecute(data []byte) (*Result, error) { flag := data[pos] pos++ - //now we only support CURSOR_TYPE_NO_CURSOR flag - if flag != 0 { - return nil, NewError(ER_UNKNOWN_ERROR, fmt.Sprintf("unsupported flag %d", flag)) + + // Supported types: + // - CURSOR_TYPE_NO_CURSOR + // - PARAMETER_COUNT_AVAILABLE + + // Make sure the first 4 bits are 0. + if flag>>4 != 0 { + return nil, NewError(ER_UNKNOWN_ERROR, fmt.Sprintf("unsupported flags 0x%x", flag)) + } + + // Test for unsupported flags in the remaining 4 bits. + if flag&CURSOR_TYPE_READ_ONLY > 0 { + return nil, NewError(ER_UNKNOWN_ERROR, + fmt.Sprintf("unsupported flag %s", "CURSOR_TYPE_READ_ONLY")) + } + if flag&CURSOR_TYPE_FOR_UPDATE > 0 { + return nil, NewError(ER_UNKNOWN_ERROR, + fmt.Sprintf("unsupported flag %s", "CURSOR_TYPE_FOR_UPDATE")) + } + if flag&CURSOR_TYPE_READ_ONLY > 0 { + return nil, NewError(ER_UNKNOWN_ERROR, + fmt.Sprintf("unsupported flag %s", "CURSOR_TYPE_SCROLLABLE")) } //skip iteration-count, always 1 From 5b4d37feb3982f290f2036d934c41f0ff607e332 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Tue, 4 Feb 2025 13:50:06 +0100 Subject: [PATCH 12/21] fixup --- server/stmt.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/server/stmt.go b/server/stmt.go index 99d950fd2..0d68e2479 100644 --- a/server/stmt.go +++ b/server/stmt.go @@ -107,7 +107,6 @@ func (c *Conn) handleStmtExecute(data []byte) (*Result, error) { flag := data[pos] pos++ - // Supported types: // - CURSOR_TYPE_NO_CURSOR // - PARAMETER_COUNT_AVAILABLE @@ -126,7 +125,7 @@ func (c *Conn) handleStmtExecute(data []byte) (*Result, error) { return nil, NewError(ER_UNKNOWN_ERROR, fmt.Sprintf("unsupported flag %s", "CURSOR_TYPE_FOR_UPDATE")) } - if flag&CURSOR_TYPE_READ_ONLY > 0 { + if flag&CURSOR_TYPE_SCROLLABLE > 0 { return nil, NewError(ER_UNKNOWN_ERROR, fmt.Sprintf("unsupported flag %s", "CURSOR_TYPE_SCROLLABLE")) } From 9e7520824bd56c0a2a39ad6f4950bf39d43d3fd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Wed, 5 Feb 2025 11:27:53 +0100 Subject: [PATCH 13/21] Apply suggestions from code review Co-authored-by: lance6716 --- client/conn.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/client/conn.go b/client/conn.go index 65e7db65a..d65d3bedd 100644 --- a/client/conn.go +++ b/client/conn.go @@ -497,9 +497,9 @@ func (c *Conn) ReadOKPacket() (*Result, error) { func (c *Conn) exec(query string) (*Result, error) { err := c.exec_send(query) if err != nil { - return nil, err + return nil, errors.Trace(err) } - return c.readResult(false) + return errors.Trace(c.readResult(false)) } // Sends COM_QUERY @@ -689,7 +689,7 @@ func (c *Conn) SetQueryAttributes(attrs ...QueryAttribute) error { } // IncludeLine can be passed as option when connecting to include the file name and line number -// of the caller as query attribute when sending queries. +// of the caller as query attribute `_line` when sending queries. func (c *Conn) IncludeLine() { c.includeLine = true } From 7cad496949f177a0a339bcafd214dccaebec44c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Wed, 5 Feb 2025 11:28:54 +0100 Subject: [PATCH 14/21] Allow setting of frame --- client/conn.go | 14 +++++++++----- client/stmt.go | 4 ++-- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/client/conn.go b/client/conn.go index d65d3bedd..ea6429189 100644 --- a/client/conn.go +++ b/client/conn.go @@ -60,7 +60,8 @@ type Conn struct { queryAttributes []QueryAttribute - includeLine bool + // Include the file + line as query attribute. The number set which frame in the stack should be used. + includeLine int } // This function will be called for every row in resultset from ExecuteSelectStreaming. @@ -104,6 +105,7 @@ type Dialer func(ctx context.Context, network, address string) (net.Conn, error) func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbName string, dialer Dialer, options ...Option) (*Conn, error) { c := new(Conn) + c.includeLine = -1 c.BufferSize = defaultBufferSize c.attributes = map[string]string{ "_client_name": "go-mysql", @@ -507,8 +509,8 @@ func (c *Conn) exec(query string) (*Result, error) { func (c *Conn) exec_send(query string) error { var buf bytes.Buffer - if c.includeLine { - _, file, line, ok := runtime.Caller(2) + if c.includeLine >= 0 { + _, file, line, ok := runtime.Caller(c.includeLine) if ok { lineAttr := QueryAttribute{ Name: "_line", @@ -690,6 +692,8 @@ func (c *Conn) SetQueryAttributes(attrs ...QueryAttribute) error { // IncludeLine can be passed as option when connecting to include the file name and line number // of the caller as query attribute `_line` when sending queries. -func (c *Conn) IncludeLine() { - c.includeLine = true +// The argument is used the dept in the stack. The top level is go-mysql and then there are the +// levels of the application. +func (c *Conn) IncludeLine(frame int) { + c.includeLine = frame } diff --git a/client/stmt.go b/client/stmt.go index 5f8853988..99105d0e1 100644 --- a/client/stmt.go +++ b/client/stmt.go @@ -65,8 +65,8 @@ func (s *Stmt) write(args ...interface{}) error { return fmt.Errorf("argument mismatch, need %d but got %d", s.params, len(args)) } - if s.conn.includeLine { - _, file, line, ok := runtime.Caller(2) + if s.conn.includeLine >= 0 { + _, file, line, ok := runtime.Caller(s.conn.includeLine) if ok { lineAttr := QueryAttribute{ Name: "_line", From d330e39c88567eb7fccf96e3a2e45689a5321d5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Wed, 5 Feb 2025 11:30:40 +0100 Subject: [PATCH 15/21] Apply suggestion --- client/conn.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/client/conn.go b/client/conn.go index ea6429189..018fd493b 100644 --- a/client/conn.go +++ b/client/conn.go @@ -316,7 +316,7 @@ func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) { // flag set to signal the server multiple queries are executed. Handling the responses // is up to the implementation of perResultCallback. func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCallback) (*Result, error) { - if err := c.exec_send(query); err != nil { + if err := c.execSend(query); err != nil { return nil, errors.Trace(err) } @@ -369,7 +369,7 @@ func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCall // // ExecuteSelectStreaming should be used only for SELECT queries with a large response resultset for memory preserving. func (c *Conn) ExecuteSelectStreaming(command string, result *Result, perRowCallback SelectPerRowCallback, perResultCallback SelectPerResultCallback) error { - if err := c.exec_send(command); err != nil { + if err := c.execSend(command); err != nil { return errors.Trace(err) } @@ -497,7 +497,7 @@ func (c *Conn) ReadOKPacket() (*Result, error) { // Send COM_QUERY and read the result func (c *Conn) exec(query string) (*Result, error) { - err := c.exec_send(query) + err := c.execSend(query) if err != nil { return nil, errors.Trace(err) } @@ -506,7 +506,7 @@ func (c *Conn) exec(query string) (*Result, error) { // Sends COM_QUERY // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query.html -func (c *Conn) exec_send(query string) error { +func (c *Conn) execSend(query string) error { var buf bytes.Buffer if c.includeLine >= 0 { From 25f3bae85de6b37c704b20c8c14b18e7ee9c290b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Wed, 5 Feb 2025 11:40:21 +0100 Subject: [PATCH 16/21] fixup --- client/conn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/conn.go b/client/conn.go index 018fd493b..99d8c2c72 100644 --- a/client/conn.go +++ b/client/conn.go @@ -501,7 +501,7 @@ func (c *Conn) exec(query string) (*Result, error) { if err != nil { return nil, errors.Trace(err) } - return errors.Trace(c.readResult(false)) + return c.readResult(false) } // Sends COM_QUERY From 9865fd5b07be69339616daaa0f96b3dfaba9599b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Wed, 5 Feb 2025 17:43:42 +0100 Subject: [PATCH 17/21] Update based on review --- client/conn.go | 20 ++++++++++---------- client/stmt.go | 4 ++-- mysql/queryattributes.go | 2 ++ server/stmt.go | 9 +++------ 4 files changed, 17 insertions(+), 18 deletions(-) diff --git a/client/conn.go b/client/conn.go index 99d8c2c72..523113b4a 100644 --- a/client/conn.go +++ b/client/conn.go @@ -508,19 +508,20 @@ func (c *Conn) exec(query string) (*Result, error) { // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query.html func (c *Conn) execSend(query string) error { var buf bytes.Buffer + defer clear(c.queryAttributes) - if c.includeLine >= 0 { - _, file, line, ok := runtime.Caller(c.includeLine) - if ok { - lineAttr := QueryAttribute{ - Name: "_line", - Value: fmt.Sprintf("%s:%d", file, line), + if c.capability&CLIENT_QUERY_ATTRIBUTES > 0 { + if c.includeLine >= 0 { + _, file, line, ok := runtime.Caller(c.includeLine) + if ok { + lineAttr := QueryAttribute{ + Name: "_line", + Value: fmt.Sprintf("%s:%d", file, line), + } + c.queryAttributes = append(c.queryAttributes, lineAttr) } - c.queryAttributes = append(c.queryAttributes, lineAttr) } - } - if c.capability&CLIENT_QUERY_ATTRIBUTES > 0 { numParams := len(c.queryAttributes) buf.Write(PutLengthEncodedInt(uint64(numParams))) buf.WriteByte(0x1) // parameter_set_count, unused @@ -548,7 +549,6 @@ func (c *Conn) execSend(query string) error { if err := c.writeCommandBuf(COM_QUERY, buf.Bytes()); err != nil { return errors.Trace(err) } - c.queryAttributes = nil return nil } diff --git a/client/stmt.go b/client/stmt.go index 99105d0e1..b9866704d 100644 --- a/client/stmt.go +++ b/client/stmt.go @@ -59,13 +59,14 @@ func (s *Stmt) Close() error { // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html func (s *Stmt) write(args ...interface{}) error { + defer clear(s.conn.queryAttributes) paramsNum := s.params if len(args) != paramsNum { return fmt.Errorf("argument mismatch, need %d but got %d", s.params, len(args)) } - if s.conn.includeLine >= 0 { + if (s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0) && (s.conn.includeLine >= 0) { _, file, line, ok := runtime.Caller(s.conn.includeLine) if ok { lineAttr := QueryAttribute{ @@ -227,7 +228,6 @@ func (s *Stmt) write(args ...interface{}) error { } s.conn.ResetSequence() - s.conn.queryAttributes = nil return s.conn.WritePacket(data.Bytes()) } diff --git a/mysql/queryattributes.go b/mysql/queryattributes.go index d70bd73e7..8e15c7dcb 100644 --- a/mysql/queryattributes.go +++ b/mysql/queryattributes.go @@ -8,6 +8,8 @@ import ( // Query Attributes in MySQL are key/value pairs passed along with COM_QUERY or COM_STMT_EXECUTE // +// Supported Value types: string, uint64 +// // Resources: // - https://dev.mysql.com/doc/refman/8.4/en/query-attributes.html // - https://github.com/mysql/mysql-server/blob/trunk/include/mysql/components/services/mysql_query_attributes.h diff --git a/server/stmt.go b/server/stmt.go index 0d68e2479..452219585 100644 --- a/server/stmt.go +++ b/server/stmt.go @@ -118,16 +118,13 @@ func (c *Conn) handleStmtExecute(data []byte) (*Result, error) { // Test for unsupported flags in the remaining 4 bits. if flag&CURSOR_TYPE_READ_ONLY > 0 { - return nil, NewError(ER_UNKNOWN_ERROR, - fmt.Sprintf("unsupported flag %s", "CURSOR_TYPE_READ_ONLY")) + return nil, NewError(ER_UNKNOWN_ERROR, "unsupported flag CURSOR_TYPE_READ_ONLY") } if flag&CURSOR_TYPE_FOR_UPDATE > 0 { - return nil, NewError(ER_UNKNOWN_ERROR, - fmt.Sprintf("unsupported flag %s", "CURSOR_TYPE_FOR_UPDATE")) + return nil, NewError(ER_UNKNOWN_ERROR, "unsupported flag CURSOR_TYPE_FOR_UPDATE") } if flag&CURSOR_TYPE_SCROLLABLE > 0 { - return nil, NewError(ER_UNKNOWN_ERROR, - fmt.Sprintf("unsupported flag %s", "CURSOR_TYPE_SCROLLABLE")) + return nil, NewError(ER_UNKNOWN_ERROR, "unsupported flag CURSOR_TYPE_SCROLLABLE") } //skip iteration-count, always 1 From e1b8e65f5f26e3735f3070bd6ef1e6ed7e63bb03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Thu, 6 Feb 2025 14:26:06 +0100 Subject: [PATCH 18/21] Add tests and fix error --- server/stmt.go | 4 ++-- server/stmt_test.go | 48 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 2 deletions(-) create mode 100644 server/stmt_test.go diff --git a/server/stmt.go b/server/stmt.go index 452219585..be79ad3ad 100644 --- a/server/stmt.go +++ b/server/stmt.go @@ -101,7 +101,7 @@ func (c *Conn) handleStmtExecute(data []byte) (*Result, error) { s, ok := c.stmts[id] if !ok { - return nil, NewDefaultError(ER_UNKNOWN_STMT_HANDLER, + return nil, NewDefaultError(ER_UNKNOWN_STMT_HANDLER, 5, strconv.FormatUint(uint64(id), 10), "stmt_execute") } @@ -339,7 +339,7 @@ func (c *Conn) handleStmtReset(data []byte) (*Result, error) { s, ok := c.stmts[id] if !ok { - return nil, NewDefaultError(ER_UNKNOWN_STMT_HANDLER, + return nil, NewDefaultError(ER_UNKNOWN_STMT_HANDLER, 5, strconv.FormatUint(uint64(id), 10), "stmt_reset") } diff --git a/server/stmt_test.go b/server/stmt_test.go new file mode 100644 index 000000000..7141c2f15 --- /dev/null +++ b/server/stmt_test.go @@ -0,0 +1,48 @@ +package server + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestHandleStmtExecute(t *testing.T) { + c := Conn{} + c.stmts = map[uint32]*Stmt{ + 1: &Stmt{}, + } + testcases := []struct { + data []byte + errtext string + }{ + { + []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, + "ERROR 1243 (HY000): Unknown prepared statement handler (0) given to stmt_execute", + }, + { + []byte{0x1, 0x0, 0x0, 0x0, 0xff, 0x0, 0x0, 0x0, 0x0, 0x0}, + "ERROR 1105 (HY000): unsupported flags 0xff", + }, + { + []byte{0x1, 0x0, 0x0, 0x0, 0x01, 0x0, 0x0, 0x0, 0x0, 0x0}, + "ERROR 1105 (HY000): unsupported flag CURSOR_TYPE_READ_ONLY", + }, + { + []byte{0x1, 0x0, 0x0, 0x0, 0x02, 0x0, 0x0, 0x0, 0x0, 0x0}, + "ERROR 1105 (HY000): unsupported flag CURSOR_TYPE_FOR_UPDATE", + }, + { + []byte{0x1, 0x0, 0x0, 0x0, 0x04, 0x0, 0x0, 0x0, 0x0, 0x0}, + "ERROR 1105 (HY000): unsupported flag CURSOR_TYPE_SCROLLABLE", + }, + } + + for _, tc := range testcases { + _, err := c.handleStmtExecute(tc.data) + if tc.errtext == "" { + require.NoError(t, err) + } else { + require.ErrorContains(t, err, tc.errtext) + } + } +} From 16b678e0ecf50cbe2639656b6c5c9b72460728f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Thu, 6 Feb 2025 14:34:31 +0100 Subject: [PATCH 19/21] more tests --- mysql/queryattributes_test.go | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 mysql/queryattributes_test.go diff --git a/mysql/queryattributes_test.go b/mysql/queryattributes_test.go new file mode 100644 index 000000000..67b30969b --- /dev/null +++ b/mysql/queryattributes_test.go @@ -0,0 +1,33 @@ +package mysql + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTypeAndFlag_string(t *testing.T) { + qattr := QueryAttribute{ + Name: "attrname", + Value: "attrvalue", + } + + tf := qattr.TypeAndFlag() + require.Equal(t, []byte{0xfe, 0x0}, tf) + + vb := qattr.ValueBytes() + require.Equal(t, []byte{0x9, 0x61, 0x74, 0x74, 0x72, 0x76, 0x61, 0x6c, 0x75, 0x65}, vb) +} + +func TestTypeAndFlag_uint64(t *testing.T) { + qattr := QueryAttribute{ + Name: "attrname", + Value: uint64(12345), + } + + tf := qattr.TypeAndFlag() + require.Equal(t, []byte{0x08, 0x80}, tf) + + vb := qattr.ValueBytes() + require.Equal(t, []byte{0x39, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, vb) +} From e939d95353084d089253a0e6c14a613d044866f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Thu, 6 Feb 2025 14:39:33 +0100 Subject: [PATCH 20/21] Add another test --- client/conn_test.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/client/conn_test.go b/client/conn_test.go index 72776355f..ae8ba8d44 100644 --- a/client/conn_test.go +++ b/client/conn_test.go @@ -194,3 +194,13 @@ func (s *connTestSuite) TestAttributes() { require.Equal(s.T(), "go-mysql", s.c.attributes["_client_name"]) require.Equal(s.T(), "attrvalue", s.c.attributes["attrtest"]) } + +func (s *connTestSuite) TestSetQueryAttributes() { + qa := mysql.QueryAttribute{ + Name: "qattr1", + Value: "qattr1val", + } + s.c.SetQueryAttributes(qa) + expected := []mysql.QueryAttribute([]mysql.QueryAttribute{mysql.QueryAttribute{Name: "qattr1", Value: "qattr1val"}}) + require.Equal(s.T(), expected, s.c.queryAttributes) +} From 622b0ba80c6a9f8ce6d2d3ca286d79e300d71f33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Fri, 7 Feb 2025 07:49:40 +0100 Subject: [PATCH 21/21] Fix linter found issues --- client/conn_test.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/client/conn_test.go b/client/conn_test.go index ae8ba8d44..b34270cd0 100644 --- a/client/conn_test.go +++ b/client/conn_test.go @@ -200,7 +200,13 @@ func (s *connTestSuite) TestSetQueryAttributes() { Name: "qattr1", Value: "qattr1val", } - s.c.SetQueryAttributes(qa) - expected := []mysql.QueryAttribute([]mysql.QueryAttribute{mysql.QueryAttribute{Name: "qattr1", Value: "qattr1val"}}) + err := s.c.SetQueryAttributes(qa) + require.NoError(s.T(), err) + expected := []mysql.QueryAttribute{ + mysql.QueryAttribute{ + Name: "qattr1", + Value: "qattr1val", + }, + } require.Equal(s.T(), expected, s.c.queryAttributes) }