Skip to content

Commit ffd15ad

Browse files
dveedenlance6716
andauthored
client,mysql: Add support for Query Attributes (#976)
* client,mysql: Add support for Query Attributes * fixup * Add includeLine option * Add _client_version connattr * Fix nil args * update * update * fix flags * update * updates * server: accept PARAMETER_COUNT_AVAILABLE flag * fixup * Apply suggestions from code review Co-authored-by: lance6716 <[email protected]> * Allow setting of frame * Apply suggestion * fixup * Update based on review * Add tests and fix error * more tests * Add another test * Fix linter found issues --------- Co-authored-by: lance6716 <[email protected]>
1 parent a5c10ba commit ffd15ad

9 files changed

+338
-49
lines changed

client/auth.go

+1
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ func (c *Conn) writeAuthHandshake() error {
203203
CLIENT_LONG_PASSWORD | CLIENT_TRANSACTIONS | CLIENT_PLUGIN_AUTH
204204
// Adjust client capability flags based on server support
205205
capability |= c.capability & CLIENT_LONG_FLAG
206+
capability |= c.capability & CLIENT_QUERY_ATTRIBUTES
206207
// Adjust client capability flags on specific client requests
207208
// Only flags that would make any sense setting and aren't handled elsewhere
208209
// in the library are supported here

client/conn.go

+74-4
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ type Conn struct {
5757
authPluginName string
5858

5959
connectionID uint32
60+
61+
queryAttributes []QueryAttribute
62+
63+
// Include the file + line as query attribute. The number set which frame in the stack should be used.
64+
includeLine int
6065
}
6166

6267
// This function will be called for every row in resultset from ExecuteSelectStreaming.
@@ -100,6 +105,7 @@ type Dialer func(ctx context.Context, network, address string) (net.Conn, error)
100105
func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbName string, dialer Dialer, options ...Option) (*Conn, error) {
101106
c := new(Conn)
102107

108+
c.includeLine = -1
103109
c.BufferSize = defaultBufferSize
104110
c.attributes = map[string]string{
105111
"_client_name": "go-mysql",
@@ -310,7 +316,7 @@ func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) {
310316
// flag set to signal the server multiple queries are executed. Handling the responses
311317
// is up to the implementation of perResultCallback.
312318
func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCallback) (*Result, error) {
313-
if err := c.writeCommandStr(COM_QUERY, query); err != nil {
319+
if err := c.execSend(query); err != nil {
314320
return nil, errors.Trace(err)
315321
}
316322

@@ -363,7 +369,7 @@ func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCall
363369
//
364370
// ExecuteSelectStreaming should be used only for SELECT queries with a large response resultset for memory preserving.
365371
func (c *Conn) ExecuteSelectStreaming(command string, result *Result, perRowCallback SelectPerRowCallback, perResultCallback SelectPerResultCallback) error {
366-
if err := c.writeCommandStr(COM_QUERY, command); err != nil {
372+
if err := c.execSend(command); err != nil {
367373
return errors.Trace(err)
368374
}
369375

@@ -489,14 +495,64 @@ func (c *Conn) ReadOKPacket() (*Result, error) {
489495
return c.readOK()
490496
}
491497

498+
// Send COM_QUERY and read the result
492499
func (c *Conn) exec(query string) (*Result, error) {
493-
if err := c.writeCommandStr(COM_QUERY, query); err != nil {
500+
err := c.execSend(query)
501+
if err != nil {
494502
return nil, errors.Trace(err)
495503
}
496-
497504
return c.readResult(false)
498505
}
499506

507+
// Sends COM_QUERY
508+
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query.html
509+
func (c *Conn) execSend(query string) error {
510+
var buf bytes.Buffer
511+
defer clear(c.queryAttributes)
512+
513+
if c.capability&CLIENT_QUERY_ATTRIBUTES > 0 {
514+
if c.includeLine >= 0 {
515+
_, file, line, ok := runtime.Caller(c.includeLine)
516+
if ok {
517+
lineAttr := QueryAttribute{
518+
Name: "_line",
519+
Value: fmt.Sprintf("%s:%d", file, line),
520+
}
521+
c.queryAttributes = append(c.queryAttributes, lineAttr)
522+
}
523+
}
524+
525+
numParams := len(c.queryAttributes)
526+
buf.Write(PutLengthEncodedInt(uint64(numParams)))
527+
buf.WriteByte(0x1) // parameter_set_count, unused
528+
if numParams > 0 {
529+
// null_bitmap, length: (num_params+7)/8
530+
for i := 0; i < (numParams+7)/8; i++ {
531+
buf.WriteByte(0x0)
532+
}
533+
buf.WriteByte(0x1) // new_params_bind_flag, unused
534+
for _, qa := range c.queryAttributes {
535+
buf.Write(qa.TypeAndFlag())
536+
buf.Write(PutLengthEncodedString([]byte(qa.Name)))
537+
}
538+
for _, qa := range c.queryAttributes {
539+
buf.Write(qa.ValueBytes())
540+
}
541+
}
542+
}
543+
544+
_, err := buf.Write(utils.StringToByteSlice(query))
545+
if err != nil {
546+
return err
547+
}
548+
549+
if err := c.writeCommandBuf(COM_QUERY, buf.Bytes()); err != nil {
550+
return errors.Trace(err)
551+
}
552+
553+
return nil
554+
}
555+
500556
// CapabilityString is returning a string with the names of capability flags
501557
// separated by "|". Examples of capability names are CLIENT_DEPRECATE_EOF and CLIENT_PROTOCOL_41.
502558
// These are defined as constants in the mysql package.
@@ -627,3 +683,17 @@ func (c *Conn) StatusString() string {
627683

628684
return strings.Join(stats, "|")
629685
}
686+
687+
// SetQueryAttributes sets the query attributes to be send along with the next query
688+
func (c *Conn) SetQueryAttributes(attrs ...QueryAttribute) error {
689+
c.queryAttributes = attrs
690+
return nil
691+
}
692+
693+
// IncludeLine can be passed as option when connecting to include the file name and line number
694+
// of the caller as query attribute `_line` when sending queries.
695+
// The argument is used the dept in the stack. The top level is go-mysql and then there are the
696+
// levels of the application.
697+
func (c *Conn) IncludeLine(frame int) {
698+
c.includeLine = frame
699+
}

client/conn_test.go

+16
Original file line numberDiff line numberDiff line change
@@ -194,3 +194,19 @@ func (s *connTestSuite) TestAttributes() {
194194
require.Equal(s.T(), "go-mysql", s.c.attributes["_client_name"])
195195
require.Equal(s.T(), "attrvalue", s.c.attributes["attrtest"])
196196
}
197+
198+
func (s *connTestSuite) TestSetQueryAttributes() {
199+
qa := mysql.QueryAttribute{
200+
Name: "qattr1",
201+
Value: "qattr1val",
202+
}
203+
err := s.c.SetQueryAttributes(qa)
204+
require.NoError(s.T(), err)
205+
expected := []mysql.QueryAttribute{
206+
mysql.QueryAttribute{
207+
Name: "qattr1",
208+
Value: "qattr1val",
209+
},
210+
}
211+
require.Equal(s.T(), expected, s.c.queryAttributes)
212+
}

client/stmt.go

+85-40
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/json"
66
"fmt"
77
"math"
8+
"runtime"
89

910
. "github.com/go-mysql-org/go-mysql/mysql"
1011
"github.com/go-mysql-org/go-mysql/utils"
@@ -56,18 +57,34 @@ func (s *Stmt) Close() error {
5657
return nil
5758
}
5859

60+
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html
5961
func (s *Stmt) write(args ...interface{}) error {
62+
defer clear(s.conn.queryAttributes)
6063
paramsNum := s.params
6164

6265
if len(args) != paramsNum {
6366
return fmt.Errorf("argument mismatch, need %d but got %d", s.params, len(args))
6467
}
6568

66-
paramTypes := make([]byte, paramsNum<<1)
67-
paramValues := make([][]byte, paramsNum)
69+
if (s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0) && (s.conn.includeLine >= 0) {
70+
_, file, line, ok := runtime.Caller(s.conn.includeLine)
71+
if ok {
72+
lineAttr := QueryAttribute{
73+
Name: "_line",
74+
Value: fmt.Sprintf("%s:%d", file, line),
75+
}
76+
s.conn.queryAttributes = append(s.conn.queryAttributes, lineAttr)
77+
}
78+
}
79+
80+
qaLen := len(s.conn.queryAttributes)
81+
paramTypes := make([][]byte, paramsNum+qaLen)
82+
paramFlags := make([][]byte, paramsNum+qaLen)
83+
paramValues := make([][]byte, paramsNum+qaLen)
84+
paramNames := make([][]byte, paramsNum+qaLen)
6885

6986
//NULL-bitmap, length: (num-params+7)
70-
nullBitmap := make([]byte, (paramsNum+7)>>3)
87+
nullBitmap := make([]byte, (paramsNum+qaLen+7)>>3)
7188

7289
length := 1 + 4 + 1 + 4 + ((paramsNum + 7) >> 3) + 1 + (paramsNum << 1)
7390

@@ -76,76 +93,89 @@ func (s *Stmt) write(args ...interface{}) error {
7693
for i := range args {
7794
if args[i] == nil {
7895
nullBitmap[i/8] |= 1 << (uint(i) % 8)
79-
paramTypes[i<<1] = MYSQL_TYPE_NULL
96+
paramTypes[i] = []byte{MYSQL_TYPE_NULL}
97+
paramNames[i] = []byte{0} // length encoded, no name
98+
paramFlags[i] = []byte{0}
8099
continue
81100
}
82101

83102
newParamBoundFlag = 1
84103

85104
switch v := args[i].(type) {
86105
case int8:
87-
paramTypes[i<<1] = MYSQL_TYPE_TINY
106+
paramTypes[i] = []byte{MYSQL_TYPE_TINY}
88107
paramValues[i] = []byte{byte(v)}
89108
case int16:
90-
paramTypes[i<<1] = MYSQL_TYPE_SHORT
109+
paramTypes[i] = []byte{MYSQL_TYPE_SHORT}
91110
paramValues[i] = Uint16ToBytes(uint16(v))
92111
case int32:
93-
paramTypes[i<<1] = MYSQL_TYPE_LONG
112+
paramTypes[i] = []byte{MYSQL_TYPE_LONG}
94113
paramValues[i] = Uint32ToBytes(uint32(v))
95114
case int:
96-
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
115+
paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG}
97116
paramValues[i] = Uint64ToBytes(uint64(v))
98117
case int64:
99-
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
118+
paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG}
100119
paramValues[i] = Uint64ToBytes(uint64(v))
101120
case uint8:
102-
paramTypes[i<<1] = MYSQL_TYPE_TINY
103-
paramTypes[(i<<1)+1] = 0x80
121+
paramTypes[i] = []byte{MYSQL_TYPE_TINY}
122+
paramFlags[i] = []byte{PARAM_UNSIGNED}
104123
paramValues[i] = []byte{v}
105124
case uint16:
106-
paramTypes[i<<1] = MYSQL_TYPE_SHORT
107-
paramTypes[(i<<1)+1] = 0x80
125+
paramTypes[i] = []byte{MYSQL_TYPE_SHORT}
126+
paramFlags[i] = []byte{PARAM_UNSIGNED}
108127
paramValues[i] = Uint16ToBytes(v)
109128
case uint32:
110-
paramTypes[i<<1] = MYSQL_TYPE_LONG
111-
paramTypes[(i<<1)+1] = 0x80
129+
paramTypes[i] = []byte{MYSQL_TYPE_LONG}
130+
paramFlags[i] = []byte{PARAM_UNSIGNED}
112131
paramValues[i] = Uint32ToBytes(v)
113132
case uint:
114-
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
115-
paramTypes[(i<<1)+1] = 0x80
133+
paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG}
134+
paramFlags[i] = []byte{PARAM_UNSIGNED}
116135
paramValues[i] = Uint64ToBytes(uint64(v))
117136
case uint64:
118-
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
119-
paramTypes[(i<<1)+1] = 0x80
137+
paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG}
138+
paramFlags[i] = []byte{PARAM_UNSIGNED}
120139
paramValues[i] = Uint64ToBytes(v)
121140
case bool:
122-
paramTypes[i<<1] = MYSQL_TYPE_TINY
141+
paramTypes[i] = []byte{MYSQL_TYPE_TINY}
123142
if v {
124143
paramValues[i] = []byte{1}
125144
} else {
126145
paramValues[i] = []byte{0}
127146
}
128147
case float32:
129-
paramTypes[i<<1] = MYSQL_TYPE_FLOAT
148+
paramTypes[i] = []byte{MYSQL_TYPE_FLOAT}
130149
paramValues[i] = Uint32ToBytes(math.Float32bits(v))
131150
case float64:
132-
paramTypes[i<<1] = MYSQL_TYPE_DOUBLE
151+
paramTypes[i] = []byte{MYSQL_TYPE_DOUBLE}
133152
paramValues[i] = Uint64ToBytes(math.Float64bits(v))
134153
case string:
135-
paramTypes[i<<1] = MYSQL_TYPE_STRING
154+
paramTypes[i] = []byte{MYSQL_TYPE_STRING}
136155
paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...)
137156
case []byte:
138-
paramTypes[i<<1] = MYSQL_TYPE_STRING
157+
paramTypes[i] = []byte{MYSQL_TYPE_STRING}
139158
paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...)
140159
case json.RawMessage:
141-
paramTypes[i<<1] = MYSQL_TYPE_STRING
160+
paramTypes[i] = []byte{MYSQL_TYPE_STRING}
142161
paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...)
143162
default:
144163
return fmt.Errorf("invalid argument type %T", args[i])
145164
}
165+
paramNames[i] = []byte{0} // length encoded, no name
166+
if paramFlags[i] == nil {
167+
paramFlags[i] = []byte{0}
168+
}
146169

147170
length += len(paramValues[i])
148171
}
172+
for i, qa := range s.conn.queryAttributes {
173+
tf := qa.TypeAndFlag()
174+
paramTypes[(i + paramsNum)] = []byte{tf[0]}
175+
paramFlags[i+paramsNum] = []byte{tf[1]}
176+
paramValues[i+paramsNum] = qa.ValueBytes()
177+
paramNames[i+paramsNum] = PutLengthEncodedString([]byte(qa.Name))
178+
}
149179

150180
data := utils.BytesBufferGet()
151181
defer func() {
@@ -159,25 +189,40 @@ func (s *Stmt) write(args ...interface{}) error {
159189
data.WriteByte(COM_STMT_EXECUTE)
160190
data.Write([]byte{byte(s.id), byte(s.id >> 8), byte(s.id >> 16), byte(s.id >> 24)})
161191

162-
//flag: CURSOR_TYPE_NO_CURSOR
163-
data.WriteByte(0x00)
192+
flags := CURSOR_TYPE_NO_CURSOR
193+
if paramsNum > 0 {
194+
flags |= PARAMETER_COUNT_AVAILABLE
195+
}
196+
data.WriteByte(flags)
164197

165198
//iteration-count, always 1
166199
data.Write([]byte{1, 0, 0, 0})
167200

168-
if s.params > 0 {
169-
data.Write(nullBitmap)
170-
171-
//new-params-bound-flag
172-
data.WriteByte(newParamBoundFlag)
173-
174-
if newParamBoundFlag == 1 {
175-
//type of each parameter, length: num-params * 2
176-
data.Write(paramTypes)
177-
178-
//value of each parameter
179-
for _, v := range paramValues {
180-
data.Write(v)
201+
if paramsNum > 0 || (s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 && (flags&PARAMETER_COUNT_AVAILABLE > 0)) {
202+
if s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 {
203+
paramsNum += len(s.conn.queryAttributes)
204+
data.Write(PutLengthEncodedInt(uint64(paramsNum)))
205+
}
206+
if paramsNum > 0 {
207+
data.Write(nullBitmap)
208+
209+
//new-params-bound-flag
210+
data.WriteByte(newParamBoundFlag)
211+
212+
if newParamBoundFlag == 1 {
213+
for i := 0; i < paramsNum; i++ {
214+
data.Write(paramTypes[i])
215+
data.Write(paramFlags[i])
216+
217+
if s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 {
218+
data.Write(paramNames[i])
219+
}
220+
}
221+
222+
//value of each parameter
223+
for _, v := range paramValues {
224+
data.Write(v)
225+
}
181226
}
182227
}
183228
}

mysql/const.go

+13
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,10 @@ const (
178178
UNIQUE_FLAG = 65536
179179
)
180180

181+
const (
182+
PARAM_UNSIGNED = 128
183+
)
184+
181185
const (
182186
DEFAULT_ADDR = "127.0.0.1:3306"
183187
DEFAULT_IPV6_ADDR = "[::1]:3306"
@@ -209,3 +213,12 @@ const (
209213
MYSQL_COMPRESS_ZLIB
210214
MYSQL_COMPRESS_ZSTD
211215
)
216+
217+
// See enum_cursor_type in mysql.h
218+
const (
219+
CURSOR_TYPE_NO_CURSOR byte = 0x0
220+
CURSOR_TYPE_READ_ONLY byte = 0x1
221+
CURSOR_TYPE_FOR_UPDATE byte = 0x2
222+
CURSOR_TYPE_SCROLLABLE byte = 0x4
223+
PARAMETER_COUNT_AVAILABLE byte = 0x8
224+
)

0 commit comments

Comments
 (0)