Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

client,mysql: Add support for Query Attributes #976

Merged
merged 24 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions client/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ func (c *Conn) writeAuthHandshake() error {
CLIENT_LONG_PASSWORD | CLIENT_TRANSACTIONS | CLIENT_PLUGIN_AUTH
// Adjust client capability flags based on server support
capability |= c.capability & CLIENT_LONG_FLAG
capability |= c.capability & CLIENT_QUERY_ATTRIBUTES
// Adjust client capability flags on specific client requests
// Only flags that would make any sense setting and aren't handled elsewhere
// in the library are supported here
Expand Down
78 changes: 74 additions & 4 deletions client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ type Conn struct {
authPluginName string

connectionID uint32

queryAttributes []QueryAttribute

// Include the file + line as query attribute. The number set which frame in the stack should be used.
includeLine int
}

// This function will be called for every row in resultset from ExecuteSelectStreaming.
Expand Down Expand Up @@ -100,6 +105,7 @@ type Dialer func(ctx context.Context, network, address string) (net.Conn, error)
func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbName string, dialer Dialer, options ...Option) (*Conn, error) {
c := new(Conn)

c.includeLine = -1
c.BufferSize = defaultBufferSize
c.attributes = map[string]string{
"_client_name": "go-mysql",
Expand Down Expand Up @@ -310,7 +316,7 @@ func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) {
// flag set to signal the server multiple queries are executed. Handling the responses
// is up to the implementation of perResultCallback.
func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCallback) (*Result, error) {
if err := c.writeCommandStr(COM_QUERY, query); err != nil {
if err := c.execSend(query); err != nil {
return nil, errors.Trace(err)
}

Expand Down Expand Up @@ -363,7 +369,7 @@ func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCall
//
// ExecuteSelectStreaming should be used only for SELECT queries with a large response resultset for memory preserving.
func (c *Conn) ExecuteSelectStreaming(command string, result *Result, perRowCallback SelectPerRowCallback, perResultCallback SelectPerResultCallback) error {
if err := c.writeCommandStr(COM_QUERY, command); err != nil {
if err := c.execSend(command); err != nil {
return errors.Trace(err)
}

Expand Down Expand Up @@ -489,14 +495,64 @@ func (c *Conn) ReadOKPacket() (*Result, error) {
return c.readOK()
}

// Send COM_QUERY and read the result
func (c *Conn) exec(query string) (*Result, error) {
if err := c.writeCommandStr(COM_QUERY, query); err != nil {
err := c.execSend(query)
if err != nil {
return nil, errors.Trace(err)
}

return c.readResult(false)
}

// Sends COM_QUERY
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query.html
func (c *Conn) execSend(query string) error {
var buf bytes.Buffer
defer clear(c.queryAttributes)

if c.capability&CLIENT_QUERY_ATTRIBUTES > 0 {
if c.includeLine >= 0 {
_, file, line, ok := runtime.Caller(c.includeLine)
if ok {
lineAttr := QueryAttribute{
Name: "_line",
Value: fmt.Sprintf("%s:%d", file, line),
}
c.queryAttributes = append(c.queryAttributes, lineAttr)
}
}

numParams := len(c.queryAttributes)
buf.Write(PutLengthEncodedInt(uint64(numParams)))
buf.WriteByte(0x1) // parameter_set_count, unused
if numParams > 0 {
// null_bitmap, length: (num_params+7)/8
for i := 0; i < (numParams+7)/8; i++ {
buf.WriteByte(0x0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we allow NULL query attributes so the bitmap will have 1? 🤔 we can add it in next PR.

}
buf.WriteByte(0x1) // new_params_bind_flag, unused
for _, qa := range c.queryAttributes {
buf.Write(qa.TypeAndFlag())
buf.Write(PutLengthEncodedString([]byte(qa.Name)))
}
for _, qa := range c.queryAttributes {
buf.Write(qa.ValueBytes())
}
}
}

_, err := buf.Write(utils.StringToByteSlice(query))
if err != nil {
return err
}

if err := c.writeCommandBuf(COM_QUERY, buf.Bytes()); err != nil {
return errors.Trace(err)
}

return nil
}

// CapabilityString is returning a string with the names of capability flags
// separated by "|". Examples of capability names are CLIENT_DEPRECATE_EOF and CLIENT_PROTOCOL_41.
// These are defined as constants in the mysql package.
Expand Down Expand Up @@ -627,3 +683,17 @@ func (c *Conn) StatusString() string {

return strings.Join(stats, "|")
}

// SetQueryAttributes sets the query attributes to be send along with the next query
func (c *Conn) SetQueryAttributes(attrs ...QueryAttribute) error {
c.queryAttributes = attrs
return nil
}

// IncludeLine can be passed as option when connecting to include the file name and line number
// of the caller as query attribute `_line` when sending queries.
// The argument is used the dept in the stack. The top level is go-mysql and then there are the
// levels of the application.
func (c *Conn) IncludeLine(frame int) {
c.includeLine = frame
}
16 changes: 16 additions & 0 deletions client/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,19 @@ func (s *connTestSuite) TestAttributes() {
require.Equal(s.T(), "go-mysql", s.c.attributes["_client_name"])
require.Equal(s.T(), "attrvalue", s.c.attributes["attrtest"])
}

func (s *connTestSuite) TestSetQueryAttributes() {
qa := mysql.QueryAttribute{
Name: "qattr1",
Value: "qattr1val",
}
err := s.c.SetQueryAttributes(qa)
require.NoError(s.T(), err)
expected := []mysql.QueryAttribute{
mysql.QueryAttribute{
Name: "qattr1",
Value: "qattr1val",
},
}
require.Equal(s.T(), expected, s.c.queryAttributes)
}
125 changes: 85 additions & 40 deletions client/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"math"
"runtime"

. "github.com/go-mysql-org/go-mysql/mysql"
"github.com/go-mysql-org/go-mysql/utils"
Expand Down Expand Up @@ -56,18 +57,34 @@ func (s *Stmt) Close() error {
return nil
}

// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html
func (s *Stmt) write(args ...interface{}) error {
defer clear(s.conn.queryAttributes)
paramsNum := s.params

if len(args) != paramsNum {
return fmt.Errorf("argument mismatch, need %d but got %d", s.params, len(args))
}

paramTypes := make([]byte, paramsNum<<1)
paramValues := make([][]byte, paramsNum)
if (s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0) && (s.conn.includeLine >= 0) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0) && (s.conn.includeLine >= 0) {
if s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 && s.conn.includeLine >= 0 {

the old code has clear priority. It's OK to keep it.

_, file, line, ok := runtime.Caller(s.conn.includeLine)
if ok {
lineAttr := QueryAttribute{
Name: "_line",
Value: fmt.Sprintf("%s:%d", file, line),
}
s.conn.queryAttributes = append(s.conn.queryAttributes, lineAttr)
}
}

qaLen := len(s.conn.queryAttributes)
paramTypes := make([][]byte, paramsNum+qaLen)
paramFlags := make([][]byte, paramsNum+qaLen)
paramValues := make([][]byte, paramsNum+qaLen)
paramNames := make([][]byte, paramsNum+qaLen)

//NULL-bitmap, length: (num-params+7)
nullBitmap := make([]byte, (paramsNum+7)>>3)
nullBitmap := make([]byte, (paramsNum+qaLen+7)>>3)

length := 1 + 4 + 1 + 4 + ((paramsNum + 7) >> 3) + 1 + (paramsNum << 1)

Expand All @@ -76,76 +93,89 @@ func (s *Stmt) write(args ...interface{}) error {
for i := range args {
if args[i] == nil {
nullBitmap[i/8] |= 1 << (uint(i) % 8)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this part is a bit complicated, I'll review later 😂

paramTypes[i<<1] = MYSQL_TYPE_NULL
paramTypes[i] = []byte{MYSQL_TYPE_NULL}
paramNames[i] = []byte{0} // length encoded, no name
paramFlags[i] = []byte{0}
continue
}

newParamBoundFlag = 1

switch v := args[i].(type) {
case int8:
paramTypes[i<<1] = MYSQL_TYPE_TINY
paramTypes[i] = []byte{MYSQL_TYPE_TINY}
paramValues[i] = []byte{byte(v)}
case int16:
paramTypes[i<<1] = MYSQL_TYPE_SHORT
paramTypes[i] = []byte{MYSQL_TYPE_SHORT}
paramValues[i] = Uint16ToBytes(uint16(v))
case int32:
paramTypes[i<<1] = MYSQL_TYPE_LONG
paramTypes[i] = []byte{MYSQL_TYPE_LONG}
paramValues[i] = Uint32ToBytes(uint32(v))
case int:
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG}
paramValues[i] = Uint64ToBytes(uint64(v))
case int64:
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG}
paramValues[i] = Uint64ToBytes(uint64(v))
case uint8:
paramTypes[i<<1] = MYSQL_TYPE_TINY
paramTypes[(i<<1)+1] = 0x80
paramTypes[i] = []byte{MYSQL_TYPE_TINY}
paramFlags[i] = []byte{PARAM_UNSIGNED}
paramValues[i] = []byte{v}
case uint16:
paramTypes[i<<1] = MYSQL_TYPE_SHORT
paramTypes[(i<<1)+1] = 0x80
paramTypes[i] = []byte{MYSQL_TYPE_SHORT}
paramFlags[i] = []byte{PARAM_UNSIGNED}
paramValues[i] = Uint16ToBytes(v)
case uint32:
paramTypes[i<<1] = MYSQL_TYPE_LONG
paramTypes[(i<<1)+1] = 0x80
paramTypes[i] = []byte{MYSQL_TYPE_LONG}
paramFlags[i] = []byte{PARAM_UNSIGNED}
paramValues[i] = Uint32ToBytes(v)
case uint:
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
paramTypes[(i<<1)+1] = 0x80
paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG}
paramFlags[i] = []byte{PARAM_UNSIGNED}
paramValues[i] = Uint64ToBytes(uint64(v))
case uint64:
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
paramTypes[(i<<1)+1] = 0x80
paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG}
paramFlags[i] = []byte{PARAM_UNSIGNED}
paramValues[i] = Uint64ToBytes(v)
case bool:
paramTypes[i<<1] = MYSQL_TYPE_TINY
paramTypes[i] = []byte{MYSQL_TYPE_TINY}
if v {
paramValues[i] = []byte{1}
} else {
paramValues[i] = []byte{0}
}
case float32:
paramTypes[i<<1] = MYSQL_TYPE_FLOAT
paramTypes[i] = []byte{MYSQL_TYPE_FLOAT}
paramValues[i] = Uint32ToBytes(math.Float32bits(v))
case float64:
paramTypes[i<<1] = MYSQL_TYPE_DOUBLE
paramTypes[i] = []byte{MYSQL_TYPE_DOUBLE}
paramValues[i] = Uint64ToBytes(math.Float64bits(v))
case string:
paramTypes[i<<1] = MYSQL_TYPE_STRING
paramTypes[i] = []byte{MYSQL_TYPE_STRING}
paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...)
case []byte:
paramTypes[i<<1] = MYSQL_TYPE_STRING
paramTypes[i] = []byte{MYSQL_TYPE_STRING}
paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...)
case json.RawMessage:
paramTypes[i<<1] = MYSQL_TYPE_STRING
paramTypes[i] = []byte{MYSQL_TYPE_STRING}
paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...)
default:
return fmt.Errorf("invalid argument type %T", args[i])
}
paramNames[i] = []byte{0} // length encoded, no name
if paramFlags[i] == nil {
paramFlags[i] = []byte{0}
}

length += len(paramValues[i])
}
for i, qa := range s.conn.queryAttributes {
tf := qa.TypeAndFlag()
paramTypes[(i + paramsNum)] = []byte{tf[0]}
paramFlags[i+paramsNum] = []byte{tf[1]}
paramValues[i+paramsNum] = qa.ValueBytes()
paramNames[i+paramsNum] = PutLengthEncodedString([]byte(qa.Name))
}

data := utils.BytesBufferGet()
defer func() {
Expand All @@ -159,25 +189,40 @@ func (s *Stmt) write(args ...interface{}) error {
data.WriteByte(COM_STMT_EXECUTE)
data.Write([]byte{byte(s.id), byte(s.id >> 8), byte(s.id >> 16), byte(s.id >> 24)})

//flag: CURSOR_TYPE_NO_CURSOR
data.WriteByte(0x00)
flags := CURSOR_TYPE_NO_CURSOR
if paramsNum > 0 {
flags |= PARAMETER_COUNT_AVAILABLE
}
data.WriteByte(flags)

//iteration-count, always 1
data.Write([]byte{1, 0, 0, 0})

if s.params > 0 {
data.Write(nullBitmap)

//new-params-bound-flag
data.WriteByte(newParamBoundFlag)

if newParamBoundFlag == 1 {
//type of each parameter, length: num-params * 2
data.Write(paramTypes)

//value of each parameter
for _, v := range paramValues {
data.Write(v)
if paramsNum > 0 || (s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 && (flags&PARAMETER_COUNT_AVAILABLE > 0)) {
if s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 {
paramsNum += len(s.conn.queryAttributes)
data.Write(PutLengthEncodedInt(uint64(paramsNum)))
}
if paramsNum > 0 {
data.Write(nullBitmap)

//new-params-bound-flag
data.WriteByte(newParamBoundFlag)

if newParamBoundFlag == 1 {
for i := 0; i < paramsNum; i++ {
data.Write(paramTypes[i])
data.Write(paramFlags[i])

if s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 {
data.Write(paramNames[i])
}
}

//value of each parameter
for _, v := range paramValues {
data.Write(v)
}
}
}
}
Expand Down
13 changes: 13 additions & 0 deletions mysql/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,10 @@ const (
UNIQUE_FLAG = 65536
)

const (
PARAM_UNSIGNED = 128
)

const (
DEFAULT_ADDR = "127.0.0.1:3306"
DEFAULT_IPV6_ADDR = "[::1]:3306"
Expand Down Expand Up @@ -209,3 +213,12 @@ const (
MYSQL_COMPRESS_ZLIB
MYSQL_COMPRESS_ZSTD
)

// See enum_cursor_type in mysql.h
const (
CURSOR_TYPE_NO_CURSOR byte = 0x0
CURSOR_TYPE_READ_ONLY byte = 0x1
CURSOR_TYPE_FOR_UPDATE byte = 0x2
CURSOR_TYPE_SCROLLABLE byte = 0x4
PARAMETER_COUNT_AVAILABLE byte = 0x8
)
Loading
Loading