@@ -56,18 +56,22 @@ func (s *Stmt) Close() error {
56
56
return nil
57
57
}
58
58
59
+ // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html
59
60
func (s * Stmt ) write (args ... interface {}) error {
60
61
paramsNum := s .params
61
62
62
63
if len (args ) != paramsNum {
63
64
return fmt .Errorf ("argument mismatch, need %d but got %d" , s .params , len (args ))
64
65
}
65
66
66
- paramTypes := make ([]byte , paramsNum << 1 )
67
- paramValues := make ([][]byte , paramsNum )
67
+ qaLen := len (s .conn .queryAttributes )
68
+ paramTypes := make ([][]byte , paramsNum + qaLen )
69
+ paramFlags := make ([][]byte , paramsNum + qaLen )
70
+ paramValues := make ([][]byte , paramsNum + qaLen )
71
+ paramNames := make ([][]byte , paramsNum + qaLen )
68
72
69
73
//NULL-bitmap, length: (num-params+7)
70
- nullBitmap := make ([]byte , (paramsNum + 7 )>> 3 )
74
+ nullBitmap := make ([]byte , (paramsNum + qaLen + 7 )>> 3 )
71
75
72
76
length := 1 + 4 + 1 + 4 + ((paramsNum + 7 ) >> 3 ) + 1 + (paramsNum << 1 )
73
77
@@ -76,76 +80,87 @@ func (s *Stmt) write(args ...interface{}) error {
76
80
for i := range args {
77
81
if args [i ] == nil {
78
82
nullBitmap [i / 8 ] |= 1 << (uint (i ) % 8 )
79
- paramTypes [i << 1 ] = MYSQL_TYPE_NULL
83
+ paramTypes [i ] = [] byte { MYSQL_TYPE_NULL }
80
84
continue
81
85
}
82
86
83
87
newParamBoundFlag = 1
84
88
85
89
switch v := args [i ].(type ) {
86
90
case int8 :
87
- paramTypes [i << 1 ] = MYSQL_TYPE_TINY
91
+ paramTypes [i ] = [] byte { MYSQL_TYPE_TINY }
88
92
paramValues [i ] = []byte {byte (v )}
89
93
case int16 :
90
- paramTypes [i << 1 ] = MYSQL_TYPE_SHORT
94
+ paramTypes [i ] = [] byte { MYSQL_TYPE_SHORT }
91
95
paramValues [i ] = Uint16ToBytes (uint16 (v ))
92
96
case int32 :
93
- paramTypes [i << 1 ] = MYSQL_TYPE_LONG
97
+ paramTypes [i ] = [] byte { MYSQL_TYPE_LONG }
94
98
paramValues [i ] = Uint32ToBytes (uint32 (v ))
95
99
case int :
96
- paramTypes [i << 1 ] = MYSQL_TYPE_LONGLONG
100
+ paramTypes [i ] = [] byte { MYSQL_TYPE_LONGLONG }
97
101
paramValues [i ] = Uint64ToBytes (uint64 (v ))
98
102
case int64 :
99
- paramTypes [i << 1 ] = MYSQL_TYPE_LONGLONG
103
+ paramTypes [i ] = [] byte { MYSQL_TYPE_LONGLONG }
100
104
paramValues [i ] = Uint64ToBytes (uint64 (v ))
101
105
case uint8 :
102
- paramTypes [i << 1 ] = MYSQL_TYPE_TINY
103
- paramTypes [( i << 1 ) + 1 ] = 0x80
106
+ paramTypes [i ] = [] byte { MYSQL_TYPE_TINY }
107
+ paramFlags [ i ] = [] byte { UNSIGNED_FLAG }
104
108
paramValues [i ] = []byte {v }
105
109
case uint16 :
106
- paramTypes [i << 1 ] = MYSQL_TYPE_SHORT
107
- paramTypes [( i << 1 ) + 1 ] = 0x80
110
+ paramTypes [i ] = [] byte { MYSQL_TYPE_SHORT }
111
+ paramFlags [ i ] = [] byte { UNSIGNED_FLAG }
108
112
paramValues [i ] = Uint16ToBytes (v )
109
113
case uint32 :
110
- paramTypes [i << 1 ] = MYSQL_TYPE_LONG
111
- paramTypes [( i << 1 ) + 1 ] = 0x80
114
+ paramTypes [i ] = [] byte { MYSQL_TYPE_LONG }
115
+ paramFlags [ i ] = [] byte { UNSIGNED_FLAG }
112
116
paramValues [i ] = Uint32ToBytes (v )
113
117
case uint :
114
- paramTypes [i << 1 ] = MYSQL_TYPE_LONGLONG
115
- paramTypes [( i << 1 ) + 1 ] = 0x80
118
+ paramTypes [i ] = [] byte { MYSQL_TYPE_LONGLONG }
119
+ paramFlags [ i ] = [] byte { UNSIGNED_FLAG }
116
120
paramValues [i ] = Uint64ToBytes (uint64 (v ))
117
121
case uint64 :
118
- paramTypes [i << 1 ] = MYSQL_TYPE_LONGLONG
119
- paramTypes [( i << 1 ) + 1 ] = 0x80
122
+ paramTypes [i ] = [] byte { MYSQL_TYPE_LONGLONG }
123
+ paramFlags [ i ] = [] byte { UNSIGNED_FLAG }
120
124
paramValues [i ] = Uint64ToBytes (v )
121
125
case bool :
122
- paramTypes [i << 1 ] = MYSQL_TYPE_TINY
126
+ paramTypes [i ] = [] byte { MYSQL_TYPE_TINY }
123
127
if v {
124
128
paramValues [i ] = []byte {1 }
125
129
} else {
126
130
paramValues [i ] = []byte {0 }
127
131
}
128
132
case float32 :
129
- paramTypes [i << 1 ] = MYSQL_TYPE_FLOAT
133
+ paramTypes [i ] = [] byte { MYSQL_TYPE_FLOAT }
130
134
paramValues [i ] = Uint32ToBytes (math .Float32bits (v ))
131
135
case float64 :
132
- paramTypes [i << 1 ] = MYSQL_TYPE_DOUBLE
136
+ paramTypes [i ] = [] byte { MYSQL_TYPE_DOUBLE }
133
137
paramValues [i ] = Uint64ToBytes (math .Float64bits (v ))
134
138
case string :
135
- paramTypes [i << 1 ] = MYSQL_TYPE_STRING
139
+ paramTypes [i ] = [] byte { MYSQL_TYPE_STRING }
136
140
paramValues [i ] = append (PutLengthEncodedInt (uint64 (len (v ))), v ... )
137
141
case []byte :
138
- paramTypes [i << 1 ] = MYSQL_TYPE_STRING
142
+ paramTypes [i ] = [] byte { MYSQL_TYPE_STRING }
139
143
paramValues [i ] = append (PutLengthEncodedInt (uint64 (len (v ))), v ... )
140
144
case json.RawMessage :
141
- paramTypes [i << 1 ] = MYSQL_TYPE_STRING
145
+ paramTypes [i ] = [] byte { MYSQL_TYPE_STRING }
142
146
paramValues [i ] = append (PutLengthEncodedInt (uint64 (len (v ))), v ... )
143
147
default :
144
148
return fmt .Errorf ("invalid argument type %T" , args [i ])
145
149
}
150
+ paramNames [i ] = []byte {0 } // lenght encoded, no name
151
+ if paramFlags [i ] == nil {
152
+ paramFlags [i ] = []byte {0 }
153
+ }
146
154
147
155
length += len (paramValues [i ])
148
156
}
157
+ for i , qa := range s .conn .queryAttributes {
158
+ tf := qa .TypeAndFlag ()
159
+ paramTypes [(i + paramsNum )] = []byte {tf [0 ]}
160
+ paramFlags [i + paramsNum ] = []byte {tf [1 ]}
161
+ paramValues [i + paramsNum ] = qa .ValueBytes ()
162
+ paramNames [i + paramsNum ] = PutLengthEncodedString ([]byte (qa .Name ))
163
+ }
149
164
150
165
data := utils .BytesBufferGet ()
151
166
defer func () {
@@ -159,30 +174,46 @@ func (s *Stmt) write(args ...interface{}) error {
159
174
data .WriteByte (COM_STMT_EXECUTE )
160
175
data .Write ([]byte {byte (s .id ), byte (s .id >> 8 ), byte (s .id >> 16 ), byte (s .id >> 24 )})
161
176
162
- //flag: CURSOR_TYPE_NO_CURSOR
163
- data .WriteByte (0x00 )
177
+ flags := CURSOR_TYPE_NO_CURSOR
178
+ if s .conn .capability & CLIENT_QUERY_ATTRIBUTES > 0 && len (s .conn .queryAttributes ) > 0 {
179
+ flags |= PARAMETER_COUNT_AVAILABLE
180
+ }
181
+ data .WriteByte (flags )
164
182
165
183
//iteration-count, always 1
166
184
data .Write ([]byte {1 , 0 , 0 , 0 })
167
185
168
- if s .params > 0 {
169
- data .Write (nullBitmap )
170
-
171
- //new-params-bound-flag
172
- data .WriteByte (newParamBoundFlag )
173
-
174
- if newParamBoundFlag == 1 {
175
- //type of each parameter, length: num-params * 2
176
- data .Write (paramTypes )
177
-
178
- //value of each parameter
179
- for _ , v := range paramValues {
180
- data .Write (v )
186
+ if paramsNum > 0 || (s .conn .capability & CLIENT_QUERY_ATTRIBUTES > 0 && (flags & PARAMETER_COUNT_AVAILABLE > 0 )) {
187
+ if s .conn .capability & CLIENT_QUERY_ATTRIBUTES > 0 {
188
+ paramsNum += len (s .conn .queryAttributes )
189
+ data .Write (PutLengthEncodedInt (uint64 (paramsNum )))
190
+ }
191
+ if paramsNum > 0 {
192
+ data .Write (nullBitmap )
193
+
194
+ //new-params-bound-flag
195
+ data .WriteByte (newParamBoundFlag )
196
+
197
+ if newParamBoundFlag == 1 {
198
+ for i := 0 ; i < paramsNum ; i ++ {
199
+ data .Write (paramTypes [i ])
200
+ data .Write (paramFlags [i ])
201
+
202
+ if s .conn .capability & CLIENT_QUERY_ATTRIBUTES > 0 {
203
+ data .Write (paramNames [i ])
204
+ }
205
+ }
206
+
207
+ //value of each parameter
208
+ for _ , v := range paramValues {
209
+ data .Write (v )
210
+ }
181
211
}
182
212
}
183
213
}
184
214
185
215
s .conn .ResetSequence ()
216
+ s .conn .queryAttributes = nil
186
217
187
218
return s .conn .WritePacket (data .Bytes ())
188
219
}
0 commit comments