diff --git a/client/auth.go b/client/auth.go index a491dd82f..5e1e4c45f 100644 --- a/client/auth.go +++ b/client/auth.go @@ -203,6 +203,7 @@ func (c *Conn) writeAuthHandshake() error { CLIENT_LONG_PASSWORD | CLIENT_TRANSACTIONS | CLIENT_PLUGIN_AUTH // Adjust client capability flags based on server support capability |= c.capability & CLIENT_LONG_FLAG + capability |= c.capability & CLIENT_QUERY_ATTRIBUTES // Adjust client capability flags on specific client requests // Only flags that would make any sense setting and aren't handled elsewhere // in the library are supported here diff --git a/client/conn.go b/client/conn.go index 310146415..523113b4a 100644 --- a/client/conn.go +++ b/client/conn.go @@ -57,6 +57,11 @@ type Conn struct { authPluginName string connectionID uint32 + + queryAttributes []QueryAttribute + + // Include the file + line as query attribute. The number set which frame in the stack should be used. + includeLine int } // This function will be called for every row in resultset from ExecuteSelectStreaming. @@ -100,6 +105,7 @@ type Dialer func(ctx context.Context, network, address string) (net.Conn, error) func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbName string, dialer Dialer, options ...Option) (*Conn, error) { c := new(Conn) + c.includeLine = -1 c.BufferSize = defaultBufferSize c.attributes = map[string]string{ "_client_name": "go-mysql", @@ -310,7 +316,7 @@ func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) { // flag set to signal the server multiple queries are executed. Handling the responses // is up to the implementation of perResultCallback. func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCallback) (*Result, error) { - if err := c.writeCommandStr(COM_QUERY, query); err != nil { + if err := c.execSend(query); err != nil { return nil, errors.Trace(err) } @@ -363,7 +369,7 @@ func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCall // // ExecuteSelectStreaming should be used only for SELECT queries with a large response resultset for memory preserving. func (c *Conn) ExecuteSelectStreaming(command string, result *Result, perRowCallback SelectPerRowCallback, perResultCallback SelectPerResultCallback) error { - if err := c.writeCommandStr(COM_QUERY, command); err != nil { + if err := c.execSend(command); err != nil { return errors.Trace(err) } @@ -489,14 +495,64 @@ func (c *Conn) ReadOKPacket() (*Result, error) { return c.readOK() } +// Send COM_QUERY and read the result func (c *Conn) exec(query string) (*Result, error) { - if err := c.writeCommandStr(COM_QUERY, query); err != nil { + err := c.execSend(query) + if err != nil { return nil, errors.Trace(err) } - return c.readResult(false) } +// Sends COM_QUERY +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query.html +func (c *Conn) execSend(query string) error { + var buf bytes.Buffer + defer clear(c.queryAttributes) + + if c.capability&CLIENT_QUERY_ATTRIBUTES > 0 { + if c.includeLine >= 0 { + _, file, line, ok := runtime.Caller(c.includeLine) + if ok { + lineAttr := QueryAttribute{ + Name: "_line", + Value: fmt.Sprintf("%s:%d", file, line), + } + c.queryAttributes = append(c.queryAttributes, lineAttr) + } + } + + numParams := len(c.queryAttributes) + buf.Write(PutLengthEncodedInt(uint64(numParams))) + buf.WriteByte(0x1) // parameter_set_count, unused + if numParams > 0 { + // null_bitmap, length: (num_params+7)/8 + for i := 0; i < (numParams+7)/8; i++ { + buf.WriteByte(0x0) + } + buf.WriteByte(0x1) // new_params_bind_flag, unused + for _, qa := range c.queryAttributes { + buf.Write(qa.TypeAndFlag()) + buf.Write(PutLengthEncodedString([]byte(qa.Name))) + } + for _, qa := range c.queryAttributes { + buf.Write(qa.ValueBytes()) + } + } + } + + _, err := buf.Write(utils.StringToByteSlice(query)) + if err != nil { + return err + } + + if err := c.writeCommandBuf(COM_QUERY, buf.Bytes()); err != nil { + return errors.Trace(err) + } + + return nil +} + // CapabilityString is returning a string with the names of capability flags // separated by "|". Examples of capability names are CLIENT_DEPRECATE_EOF and CLIENT_PROTOCOL_41. // These are defined as constants in the mysql package. @@ -627,3 +683,17 @@ func (c *Conn) StatusString() string { return strings.Join(stats, "|") } + +// SetQueryAttributes sets the query attributes to be send along with the next query +func (c *Conn) SetQueryAttributes(attrs ...QueryAttribute) error { + c.queryAttributes = attrs + return nil +} + +// IncludeLine can be passed as option when connecting to include the file name and line number +// of the caller as query attribute `_line` when sending queries. +// The argument is used the dept in the stack. The top level is go-mysql and then there are the +// levels of the application. +func (c *Conn) IncludeLine(frame int) { + c.includeLine = frame +} diff --git a/client/conn_test.go b/client/conn_test.go index 72776355f..b34270cd0 100644 --- a/client/conn_test.go +++ b/client/conn_test.go @@ -194,3 +194,19 @@ func (s *connTestSuite) TestAttributes() { require.Equal(s.T(), "go-mysql", s.c.attributes["_client_name"]) require.Equal(s.T(), "attrvalue", s.c.attributes["attrtest"]) } + +func (s *connTestSuite) TestSetQueryAttributes() { + qa := mysql.QueryAttribute{ + Name: "qattr1", + Value: "qattr1val", + } + err := s.c.SetQueryAttributes(qa) + require.NoError(s.T(), err) + expected := []mysql.QueryAttribute{ + mysql.QueryAttribute{ + Name: "qattr1", + Value: "qattr1val", + }, + } + require.Equal(s.T(), expected, s.c.queryAttributes) +} diff --git a/client/stmt.go b/client/stmt.go index 82c760d72..b9866704d 100644 --- a/client/stmt.go +++ b/client/stmt.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "math" + "runtime" . "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/utils" @@ -56,18 +57,34 @@ func (s *Stmt) Close() error { return nil } +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html func (s *Stmt) write(args ...interface{}) error { + defer clear(s.conn.queryAttributes) paramsNum := s.params if len(args) != paramsNum { return fmt.Errorf("argument mismatch, need %d but got %d", s.params, len(args)) } - paramTypes := make([]byte, paramsNum<<1) - paramValues := make([][]byte, paramsNum) + if (s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0) && (s.conn.includeLine >= 0) { + _, file, line, ok := runtime.Caller(s.conn.includeLine) + if ok { + lineAttr := QueryAttribute{ + Name: "_line", + Value: fmt.Sprintf("%s:%d", file, line), + } + s.conn.queryAttributes = append(s.conn.queryAttributes, lineAttr) + } + } + + qaLen := len(s.conn.queryAttributes) + paramTypes := make([][]byte, paramsNum+qaLen) + paramFlags := make([][]byte, paramsNum+qaLen) + paramValues := make([][]byte, paramsNum+qaLen) + paramNames := make([][]byte, paramsNum+qaLen) //NULL-bitmap, length: (num-params+7) - nullBitmap := make([]byte, (paramsNum+7)>>3) + nullBitmap := make([]byte, (paramsNum+qaLen+7)>>3) length := 1 + 4 + 1 + 4 + ((paramsNum + 7) >> 3) + 1 + (paramsNum << 1) @@ -76,7 +93,9 @@ func (s *Stmt) write(args ...interface{}) error { for i := range args { if args[i] == nil { nullBitmap[i/8] |= 1 << (uint(i) % 8) - paramTypes[i<<1] = MYSQL_TYPE_NULL + paramTypes[i] = []byte{MYSQL_TYPE_NULL} + paramNames[i] = []byte{0} // length encoded, no name + paramFlags[i] = []byte{0} continue } @@ -84,68 +103,79 @@ func (s *Stmt) write(args ...interface{}) error { switch v := args[i].(type) { case int8: - paramTypes[i<<1] = MYSQL_TYPE_TINY + paramTypes[i] = []byte{MYSQL_TYPE_TINY} paramValues[i] = []byte{byte(v)} case int16: - paramTypes[i<<1] = MYSQL_TYPE_SHORT + paramTypes[i] = []byte{MYSQL_TYPE_SHORT} paramValues[i] = Uint16ToBytes(uint16(v)) case int32: - paramTypes[i<<1] = MYSQL_TYPE_LONG + paramTypes[i] = []byte{MYSQL_TYPE_LONG} paramValues[i] = Uint32ToBytes(uint32(v)) case int: - paramTypes[i<<1] = MYSQL_TYPE_LONGLONG + paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} paramValues[i] = Uint64ToBytes(uint64(v)) case int64: - paramTypes[i<<1] = MYSQL_TYPE_LONGLONG + paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} paramValues[i] = Uint64ToBytes(uint64(v)) case uint8: - paramTypes[i<<1] = MYSQL_TYPE_TINY - paramTypes[(i<<1)+1] = 0x80 + paramTypes[i] = []byte{MYSQL_TYPE_TINY} + paramFlags[i] = []byte{PARAM_UNSIGNED} paramValues[i] = []byte{v} case uint16: - paramTypes[i<<1] = MYSQL_TYPE_SHORT - paramTypes[(i<<1)+1] = 0x80 + paramTypes[i] = []byte{MYSQL_TYPE_SHORT} + paramFlags[i] = []byte{PARAM_UNSIGNED} paramValues[i] = Uint16ToBytes(v) case uint32: - paramTypes[i<<1] = MYSQL_TYPE_LONG - paramTypes[(i<<1)+1] = 0x80 + paramTypes[i] = []byte{MYSQL_TYPE_LONG} + paramFlags[i] = []byte{PARAM_UNSIGNED} paramValues[i] = Uint32ToBytes(v) case uint: - paramTypes[i<<1] = MYSQL_TYPE_LONGLONG - paramTypes[(i<<1)+1] = 0x80 + paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} + paramFlags[i] = []byte{PARAM_UNSIGNED} paramValues[i] = Uint64ToBytes(uint64(v)) case uint64: - paramTypes[i<<1] = MYSQL_TYPE_LONGLONG - paramTypes[(i<<1)+1] = 0x80 + paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} + paramFlags[i] = []byte{PARAM_UNSIGNED} paramValues[i] = Uint64ToBytes(v) case bool: - paramTypes[i<<1] = MYSQL_TYPE_TINY + paramTypes[i] = []byte{MYSQL_TYPE_TINY} if v { paramValues[i] = []byte{1} } else { paramValues[i] = []byte{0} } case float32: - paramTypes[i<<1] = MYSQL_TYPE_FLOAT + paramTypes[i] = []byte{MYSQL_TYPE_FLOAT} paramValues[i] = Uint32ToBytes(math.Float32bits(v)) case float64: - paramTypes[i<<1] = MYSQL_TYPE_DOUBLE + paramTypes[i] = []byte{MYSQL_TYPE_DOUBLE} paramValues[i] = Uint64ToBytes(math.Float64bits(v)) case string: - paramTypes[i<<1] = MYSQL_TYPE_STRING + paramTypes[i] = []byte{MYSQL_TYPE_STRING} paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...) case []byte: - paramTypes[i<<1] = MYSQL_TYPE_STRING + paramTypes[i] = []byte{MYSQL_TYPE_STRING} paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...) case json.RawMessage: - paramTypes[i<<1] = MYSQL_TYPE_STRING + paramTypes[i] = []byte{MYSQL_TYPE_STRING} paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...) default: return fmt.Errorf("invalid argument type %T", args[i]) } + paramNames[i] = []byte{0} // length encoded, no name + if paramFlags[i] == nil { + paramFlags[i] = []byte{0} + } length += len(paramValues[i]) } + for i, qa := range s.conn.queryAttributes { + tf := qa.TypeAndFlag() + paramTypes[(i + paramsNum)] = []byte{tf[0]} + paramFlags[i+paramsNum] = []byte{tf[1]} + paramValues[i+paramsNum] = qa.ValueBytes() + paramNames[i+paramsNum] = PutLengthEncodedString([]byte(qa.Name)) + } data := utils.BytesBufferGet() defer func() { @@ -159,25 +189,40 @@ func (s *Stmt) write(args ...interface{}) error { data.WriteByte(COM_STMT_EXECUTE) data.Write([]byte{byte(s.id), byte(s.id >> 8), byte(s.id >> 16), byte(s.id >> 24)}) - //flag: CURSOR_TYPE_NO_CURSOR - data.WriteByte(0x00) + flags := CURSOR_TYPE_NO_CURSOR + if paramsNum > 0 { + flags |= PARAMETER_COUNT_AVAILABLE + } + data.WriteByte(flags) //iteration-count, always 1 data.Write([]byte{1, 0, 0, 0}) - if s.params > 0 { - data.Write(nullBitmap) - - //new-params-bound-flag - data.WriteByte(newParamBoundFlag) - - if newParamBoundFlag == 1 { - //type of each parameter, length: num-params * 2 - data.Write(paramTypes) - - //value of each parameter - for _, v := range paramValues { - data.Write(v) + if paramsNum > 0 || (s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 && (flags&PARAMETER_COUNT_AVAILABLE > 0)) { + if s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 { + paramsNum += len(s.conn.queryAttributes) + data.Write(PutLengthEncodedInt(uint64(paramsNum))) + } + if paramsNum > 0 { + data.Write(nullBitmap) + + //new-params-bound-flag + data.WriteByte(newParamBoundFlag) + + if newParamBoundFlag == 1 { + for i := 0; i < paramsNum; i++ { + data.Write(paramTypes[i]) + data.Write(paramFlags[i]) + + if s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 { + data.Write(paramNames[i]) + } + } + + //value of each parameter + for _, v := range paramValues { + data.Write(v) + } } } } diff --git a/mysql/const.go b/mysql/const.go index 293382a38..d361e8f8f 100644 --- a/mysql/const.go +++ b/mysql/const.go @@ -178,6 +178,10 @@ const ( UNIQUE_FLAG = 65536 ) +const ( + PARAM_UNSIGNED = 128 +) + const ( DEFAULT_ADDR = "127.0.0.1:3306" DEFAULT_IPV6_ADDR = "[::1]:3306" @@ -209,3 +213,12 @@ const ( MYSQL_COMPRESS_ZLIB MYSQL_COMPRESS_ZSTD ) + +// See enum_cursor_type in mysql.h +const ( + CURSOR_TYPE_NO_CURSOR byte = 0x0 + CURSOR_TYPE_READ_ONLY byte = 0x1 + CURSOR_TYPE_FOR_UPDATE byte = 0x2 + CURSOR_TYPE_SCROLLABLE byte = 0x4 + PARAMETER_COUNT_AVAILABLE byte = 0x8 +) diff --git a/mysql/queryattributes.go b/mysql/queryattributes.go new file mode 100644 index 000000000..8e15c7dcb --- /dev/null +++ b/mysql/queryattributes.go @@ -0,0 +1,48 @@ +package mysql + +import ( + "encoding/binary" + + "github.com/siddontang/go-log/log" +) + +// Query Attributes in MySQL are key/value pairs passed along with COM_QUERY or COM_STMT_EXECUTE +// +// Supported Value types: string, uint64 +// +// Resources: +// - https://dev.mysql.com/doc/refman/8.4/en/query-attributes.html +// - https://github.com/mysql/mysql-server/blob/trunk/include/mysql/components/services/mysql_query_attributes.h +// - https://archive.fosdem.org/2021/schedule/event/mysql_protocl/ +type QueryAttribute struct { + Name string + Value interface{} +} + +// TypeAndFlag returns the type MySQL field type of the value and the field flag. +func (qa *QueryAttribute) TypeAndFlag() []byte { + switch v := qa.Value.(type) { + case string: + return []byte{MYSQL_TYPE_STRING, 0x0} + case uint64: + return []byte{MYSQL_TYPE_LONGLONG, PARAM_UNSIGNED} + default: + log.Warnf("query attribute with unsupported type %T", v) + } + return []byte{0x0, 0x0} // type 0x0, flag 0x0, to not break the protocol +} + +// ValueBytes returns the encoded value +func (qa *QueryAttribute) ValueBytes() []byte { + switch v := qa.Value.(type) { + case string: + return PutLengthEncodedString([]byte(v)) + case uint64: + b := make([]byte, 8) + binary.LittleEndian.PutUint64(b, v) + return b + default: + log.Warnf("query attribute with unsupported type %T", v) + } + return []byte{0x0} // 0 length value to not break the protocol +} diff --git a/mysql/queryattributes_test.go b/mysql/queryattributes_test.go new file mode 100644 index 000000000..67b30969b --- /dev/null +++ b/mysql/queryattributes_test.go @@ -0,0 +1,33 @@ +package mysql + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTypeAndFlag_string(t *testing.T) { + qattr := QueryAttribute{ + Name: "attrname", + Value: "attrvalue", + } + + tf := qattr.TypeAndFlag() + require.Equal(t, []byte{0xfe, 0x0}, tf) + + vb := qattr.ValueBytes() + require.Equal(t, []byte{0x9, 0x61, 0x74, 0x74, 0x72, 0x76, 0x61, 0x6c, 0x75, 0x65}, vb) +} + +func TestTypeAndFlag_uint64(t *testing.T) { + qattr := QueryAttribute{ + Name: "attrname", + Value: uint64(12345), + } + + tf := qattr.TypeAndFlag() + require.Equal(t, []byte{0x08, 0x80}, tf) + + vb := qattr.ValueBytes() + require.Equal(t, []byte{0x39, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, vb) +} diff --git a/server/stmt.go b/server/stmt.go index bc7902e19..be79ad3ad 100644 --- a/server/stmt.go +++ b/server/stmt.go @@ -101,15 +101,30 @@ func (c *Conn) handleStmtExecute(data []byte) (*Result, error) { s, ok := c.stmts[id] if !ok { - return nil, NewDefaultError(ER_UNKNOWN_STMT_HANDLER, + return nil, NewDefaultError(ER_UNKNOWN_STMT_HANDLER, 5, strconv.FormatUint(uint64(id), 10), "stmt_execute") } flag := data[pos] pos++ - //now we only support CURSOR_TYPE_NO_CURSOR flag - if flag != 0 { - return nil, NewError(ER_UNKNOWN_ERROR, fmt.Sprintf("unsupported flag %d", flag)) + // Supported types: + // - CURSOR_TYPE_NO_CURSOR + // - PARAMETER_COUNT_AVAILABLE + + // Make sure the first 4 bits are 0. + if flag>>4 != 0 { + return nil, NewError(ER_UNKNOWN_ERROR, fmt.Sprintf("unsupported flags 0x%x", flag)) + } + + // Test for unsupported flags in the remaining 4 bits. + if flag&CURSOR_TYPE_READ_ONLY > 0 { + return nil, NewError(ER_UNKNOWN_ERROR, "unsupported flag CURSOR_TYPE_READ_ONLY") + } + if flag&CURSOR_TYPE_FOR_UPDATE > 0 { + return nil, NewError(ER_UNKNOWN_ERROR, "unsupported flag CURSOR_TYPE_FOR_UPDATE") + } + if flag&CURSOR_TYPE_SCROLLABLE > 0 { + return nil, NewError(ER_UNKNOWN_ERROR, "unsupported flag CURSOR_TYPE_SCROLLABLE") } //skip iteration-count, always 1 @@ -324,7 +339,7 @@ func (c *Conn) handleStmtReset(data []byte) (*Result, error) { s, ok := c.stmts[id] if !ok { - return nil, NewDefaultError(ER_UNKNOWN_STMT_HANDLER, + return nil, NewDefaultError(ER_UNKNOWN_STMT_HANDLER, 5, strconv.FormatUint(uint64(id), 10), "stmt_reset") } diff --git a/server/stmt_test.go b/server/stmt_test.go new file mode 100644 index 000000000..7141c2f15 --- /dev/null +++ b/server/stmt_test.go @@ -0,0 +1,48 @@ +package server + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestHandleStmtExecute(t *testing.T) { + c := Conn{} + c.stmts = map[uint32]*Stmt{ + 1: &Stmt{}, + } + testcases := []struct { + data []byte + errtext string + }{ + { + []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, + "ERROR 1243 (HY000): Unknown prepared statement handler (0) given to stmt_execute", + }, + { + []byte{0x1, 0x0, 0x0, 0x0, 0xff, 0x0, 0x0, 0x0, 0x0, 0x0}, + "ERROR 1105 (HY000): unsupported flags 0xff", + }, + { + []byte{0x1, 0x0, 0x0, 0x0, 0x01, 0x0, 0x0, 0x0, 0x0, 0x0}, + "ERROR 1105 (HY000): unsupported flag CURSOR_TYPE_READ_ONLY", + }, + { + []byte{0x1, 0x0, 0x0, 0x0, 0x02, 0x0, 0x0, 0x0, 0x0, 0x0}, + "ERROR 1105 (HY000): unsupported flag CURSOR_TYPE_FOR_UPDATE", + }, + { + []byte{0x1, 0x0, 0x0, 0x0, 0x04, 0x0, 0x0, 0x0, 0x0, 0x0}, + "ERROR 1105 (HY000): unsupported flag CURSOR_TYPE_SCROLLABLE", + }, + } + + for _, tc := range testcases { + _, err := c.handleStmtExecute(tc.data) + if tc.errtext == "" { + require.NoError(t, err) + } else { + require.ErrorContains(t, err, tc.errtext) + } + } +}