Skip to content

Commit 740b4c6

Browse files
committed
client,mysql: Add support for Query Attributes
1 parent d00dff7 commit 740b4c6

File tree

5 files changed

+165
-41
lines changed

5 files changed

+165
-41
lines changed

client/auth.go

+1
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ func (c *Conn) writeAuthHandshake() error {
203203
CLIENT_LONG_PASSWORD | CLIENT_TRANSACTIONS | CLIENT_PLUGIN_AUTH
204204
// Adjust client capability flags based on server support
205205
capability |= c.capability & CLIENT_LONG_FLAG
206+
capability |= c.capability & CLIENT_QUERY_ATTRIBUTES
206207
// Adjust client capability flags on specific client requests
207208
// Only flags that would make any sense setting and aren't handled elsewhere
208209
// in the library are supported here

client/conn.go

+38-1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ type Conn struct {
5656
authPluginName string
5757

5858
connectionID uint32
59+
60+
queryAttributes []QueryAttribute
5961
}
6062

6163
// This function will be called for every row in resultset from ExecuteSelectStreaming.
@@ -481,10 +483,40 @@ func (c *Conn) ReadOKPacket() (*Result, error) {
481483
return c.readOK()
482484
}
483485

486+
// Sends COM_QUERY
487+
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query.html
484488
func (c *Conn) exec(query string) (*Result, error) {
485-
if err := c.writeCommandStr(COM_QUERY, query); err != nil {
489+
var buf bytes.Buffer
490+
491+
if c.capability&CLIENT_QUERY_ATTRIBUTES > 0 {
492+
numParams := len(c.queryAttributes)
493+
buf.Write(PutLengthEncodedInt(uint64(numParams)))
494+
buf.WriteByte(0x1) // parameter_set_count, unused
495+
if numParams > 0 {
496+
// null_bitmap, length: (num_params+7)/8
497+
for i := 0; i < (numParams+7)/8; i++ {
498+
buf.WriteByte(0x0)
499+
}
500+
buf.WriteByte(0x1) // new_params_bind_flag, unused
501+
for _, qa := range c.queryAttributes {
502+
buf.Write(qa.TypeAndFlag())
503+
buf.Write(PutLengthEncodedString([]byte(qa.Name)))
504+
}
505+
for _, qa := range c.queryAttributes {
506+
buf.Write(qa.ValueBytes())
507+
}
508+
}
509+
}
510+
511+
_, err := buf.Write(utils.StringToByteSlice(query))
512+
if err != nil {
513+
return nil, err
514+
}
515+
516+
if err := c.writeCommandBuf(COM_QUERY, buf.Bytes()); err != nil {
486517
return nil, errors.Trace(err)
487518
}
519+
c.queryAttributes = nil
488520

489521
return c.readResult(false)
490522
}
@@ -619,3 +651,8 @@ func (c *Conn) StatusString() string {
619651

620652
return strings.Join(stats, "|")
621653
}
654+
655+
func (c *Conn) SetQueryAttributes(attrs ...QueryAttribute) error {
656+
c.queryAttributes = attrs
657+
return nil
658+
}

client/stmt.go

+71-40
Original file line numberDiff line numberDiff line change
@@ -56,18 +56,22 @@ func (s *Stmt) Close() error {
5656
return nil
5757
}
5858

59+
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html
5960
func (s *Stmt) write(args ...interface{}) error {
6061
paramsNum := s.params
6162

6263
if len(args) != paramsNum {
6364
return fmt.Errorf("argument mismatch, need %d but got %d", s.params, len(args))
6465
}
6566

66-
paramTypes := make([]byte, paramsNum<<1)
67-
paramValues := make([][]byte, paramsNum)
67+
qaLen := len(s.conn.queryAttributes)
68+
paramTypes := make([][]byte, paramsNum+qaLen)
69+
paramFlags := make([][]byte, paramsNum+qaLen)
70+
paramValues := make([][]byte, paramsNum+qaLen)
71+
paramNames := make([][]byte, paramsNum+qaLen)
6872

6973
//NULL-bitmap, length: (num-params+7)
70-
nullBitmap := make([]byte, (paramsNum+7)>>3)
74+
nullBitmap := make([]byte, (paramsNum+qaLen+7)>>3)
7175

7276
length := 1 + 4 + 1 + 4 + ((paramsNum + 7) >> 3) + 1 + (paramsNum << 1)
7377

@@ -76,76 +80,87 @@ func (s *Stmt) write(args ...interface{}) error {
7680
for i := range args {
7781
if args[i] == nil {
7882
nullBitmap[i/8] |= 1 << (uint(i) % 8)
79-
paramTypes[i<<1] = MYSQL_TYPE_NULL
83+
paramTypes[i] = []byte{MYSQL_TYPE_NULL}
8084
continue
8185
}
8286

8387
newParamBoundFlag = 1
8488

8589
switch v := args[i].(type) {
8690
case int8:
87-
paramTypes[i<<1] = MYSQL_TYPE_TINY
91+
paramTypes[i] = []byte{MYSQL_TYPE_TINY}
8892
paramValues[i] = []byte{byte(v)}
8993
case int16:
90-
paramTypes[i<<1] = MYSQL_TYPE_SHORT
94+
paramTypes[i] = []byte{MYSQL_TYPE_SHORT}
9195
paramValues[i] = Uint16ToBytes(uint16(v))
9296
case int32:
93-
paramTypes[i<<1] = MYSQL_TYPE_LONG
97+
paramTypes[i] = []byte{MYSQL_TYPE_LONG}
9498
paramValues[i] = Uint32ToBytes(uint32(v))
9599
case int:
96-
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
100+
paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG}
97101
paramValues[i] = Uint64ToBytes(uint64(v))
98102
case int64:
99-
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
103+
paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG}
100104
paramValues[i] = Uint64ToBytes(uint64(v))
101105
case uint8:
102-
paramTypes[i<<1] = MYSQL_TYPE_TINY
103-
paramTypes[(i<<1)+1] = 0x80
106+
paramTypes[i] = []byte{MYSQL_TYPE_TINY}
107+
paramFlags[i] = []byte{UNSIGNED_FLAG}
104108
paramValues[i] = []byte{v}
105109
case uint16:
106-
paramTypes[i<<1] = MYSQL_TYPE_SHORT
107-
paramTypes[(i<<1)+1] = 0x80
110+
paramTypes[i] = []byte{MYSQL_TYPE_SHORT}
111+
paramFlags[i] = []byte{UNSIGNED_FLAG}
108112
paramValues[i] = Uint16ToBytes(v)
109113
case uint32:
110-
paramTypes[i<<1] = MYSQL_TYPE_LONG
111-
paramTypes[(i<<1)+1] = 0x80
114+
paramTypes[i] = []byte{MYSQL_TYPE_LONG}
115+
paramFlags[i] = []byte{UNSIGNED_FLAG}
112116
paramValues[i] = Uint32ToBytes(v)
113117
case uint:
114-
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
115-
paramTypes[(i<<1)+1] = 0x80
118+
paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG}
119+
paramFlags[i] = []byte{UNSIGNED_FLAG}
116120
paramValues[i] = Uint64ToBytes(uint64(v))
117121
case uint64:
118-
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
119-
paramTypes[(i<<1)+1] = 0x80
122+
paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG}
123+
paramFlags[i] = []byte{UNSIGNED_FLAG}
120124
paramValues[i] = Uint64ToBytes(v)
121125
case bool:
122-
paramTypes[i<<1] = MYSQL_TYPE_TINY
126+
paramTypes[i] = []byte{MYSQL_TYPE_TINY}
123127
if v {
124128
paramValues[i] = []byte{1}
125129
} else {
126130
paramValues[i] = []byte{0}
127131
}
128132
case float32:
129-
paramTypes[i<<1] = MYSQL_TYPE_FLOAT
133+
paramTypes[i] = []byte{MYSQL_TYPE_FLOAT}
130134
paramValues[i] = Uint32ToBytes(math.Float32bits(v))
131135
case float64:
132-
paramTypes[i<<1] = MYSQL_TYPE_DOUBLE
136+
paramTypes[i] = []byte{MYSQL_TYPE_DOUBLE}
133137
paramValues[i] = Uint64ToBytes(math.Float64bits(v))
134138
case string:
135-
paramTypes[i<<1] = MYSQL_TYPE_STRING
139+
paramTypes[i] = []byte{MYSQL_TYPE_STRING}
136140
paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...)
137141
case []byte:
138-
paramTypes[i<<1] = MYSQL_TYPE_STRING
142+
paramTypes[i] = []byte{MYSQL_TYPE_STRING}
139143
paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...)
140144
case json.RawMessage:
141-
paramTypes[i<<1] = MYSQL_TYPE_STRING
145+
paramTypes[i] = []byte{MYSQL_TYPE_STRING}
142146
paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...)
143147
default:
144148
return fmt.Errorf("invalid argument type %T", args[i])
145149
}
150+
paramNames[i] = []byte{0} // lenght encoded, no name
151+
if paramFlags[i] == nil {
152+
paramFlags[i] = []byte{0}
153+
}
146154

147155
length += len(paramValues[i])
148156
}
157+
for i, qa := range s.conn.queryAttributes {
158+
tf := qa.TypeAndFlag()
159+
paramTypes[(i + paramsNum)] = []byte{tf[0]}
160+
paramFlags[i+paramsNum] = []byte{tf[1]}
161+
paramValues[i+paramsNum] = qa.ValueBytes()
162+
paramNames[i+paramsNum] = PutLengthEncodedString([]byte(qa.Name))
163+
}
149164

150165
data := utils.BytesBufferGet()
151166
defer func() {
@@ -159,30 +174,46 @@ func (s *Stmt) write(args ...interface{}) error {
159174
data.WriteByte(COM_STMT_EXECUTE)
160175
data.Write([]byte{byte(s.id), byte(s.id >> 8), byte(s.id >> 16), byte(s.id >> 24)})
161176

162-
//flag: CURSOR_TYPE_NO_CURSOR
163-
data.WriteByte(0x00)
177+
flags := CURSOR_TYPE_NO_CURSOR
178+
if s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 && len(s.conn.queryAttributes) > 0 {
179+
flags |= PARAMETER_COUNT_AVAILABLE
180+
}
181+
data.WriteByte(flags)
164182

165183
//iteration-count, always 1
166184
data.Write([]byte{1, 0, 0, 0})
167185

168-
if s.params > 0 {
169-
data.Write(nullBitmap)
170-
171-
//new-params-bound-flag
172-
data.WriteByte(newParamBoundFlag)
173-
174-
if newParamBoundFlag == 1 {
175-
//type of each parameter, length: num-params * 2
176-
data.Write(paramTypes)
177-
178-
//value of each parameter
179-
for _, v := range paramValues {
180-
data.Write(v)
186+
if paramsNum > 0 || (s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 && (flags&PARAMETER_COUNT_AVAILABLE > 0)) {
187+
if s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 {
188+
paramsNum += len(s.conn.queryAttributes)
189+
data.Write(PutLengthEncodedInt(uint64(paramsNum)))
190+
}
191+
if paramsNum > 0 {
192+
data.Write(nullBitmap)
193+
194+
//new-params-bound-flag
195+
data.WriteByte(newParamBoundFlag)
196+
197+
if newParamBoundFlag == 1 {
198+
for i := 0; i < paramsNum; i++ {
199+
data.Write(paramTypes[i])
200+
data.Write(paramFlags[i])
201+
202+
if s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 {
203+
data.Write(paramNames[i])
204+
}
205+
}
206+
207+
//value of each parameter
208+
for _, v := range paramValues {
209+
data.Write(v)
210+
}
181211
}
182212
}
183213
}
184214

185215
s.conn.ResetSequence()
216+
s.conn.queryAttributes = nil
186217

187218
return s.conn.WritePacket(data.Bytes())
188219
}

mysql/const.go

+9
Original file line numberDiff line numberDiff line change
@@ -209,3 +209,12 @@ const (
209209
MYSQL_COMPRESS_ZLIB
210210
MYSQL_COMPRESS_ZSTD
211211
)
212+
213+
// See enum_cursor_type in mysql.h
214+
const (
215+
CURSOR_TYPE_NO_CURSOR byte = 0x0
216+
CURSOR_TYPE_READ_ONLY byte = 0x1
217+
CURSOR_TYPE_FOR_UPDATE byte = 0x2
218+
CURSOR_TYPE_SCROLLABLE byte = 0x4
219+
PARAMETER_COUNT_AVAILABLE byte = 0x8
220+
)

mysql/queryattributes.go

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package mysql
2+
3+
import (
4+
"encoding/binary"
5+
6+
"github.com/siddontang/go-log/log"
7+
)
8+
9+
// Query Attributes in MySQL are key/value pairs passed along with COM_QUERY or COM_STMT_EXECUTE
10+
//
11+
// Resources:
12+
// - https://dev.mysql.com/doc/refman/8.4/en/query-attributes.html
13+
// - https://github.com/mysql/mysql-server/blob/trunk/include/mysql/components/services/mysql_query_attributes.h
14+
// - https://archive.fosdem.org/2021/schedule/event/mysql_protocl/
15+
type QueryAttribute struct {
16+
Name string
17+
Value interface{}
18+
}
19+
20+
// TypeAndFlag returns the type MySQL field type of the value and the field flag.
21+
func (qa *QueryAttribute) TypeAndFlag() []byte {
22+
switch v := qa.Value.(type) {
23+
case string:
24+
return []byte{MYSQL_TYPE_STRING, 0x0}
25+
case uint64:
26+
return []byte{MYSQL_TYPE_LONGLONG, UNSIGNED_FLAG}
27+
default:
28+
log.Warnf("query attribute with unsupported type %T", v)
29+
}
30+
return []byte{0x0, 0x0} // type 0x0, flag 0x0, to not break the protocol
31+
}
32+
33+
// ValueBytes returns the encoded value
34+
func (qa *QueryAttribute) ValueBytes() []byte {
35+
switch v := qa.Value.(type) {
36+
case string:
37+
return PutLengthEncodedString([]byte(v))
38+
case uint64:
39+
b := make([]byte, 8)
40+
binary.LittleEndian.PutUint64(b, v)
41+
return b
42+
default:
43+
log.Warnf("query attribute with unsupported type %T", v)
44+
}
45+
return []byte{0x0} // 0 length value to not break the protocol
46+
}

0 commit comments

Comments
 (0)