Skip to content

Commit 0d261b3

Browse files
dveedenlance6716
andauthoredFeb 7, 2025
Change how we import the mysql package (#982)
Co-authored-by: lance6716 <lance6716@gmail.com>
1 parent 6c3f3a6 commit 0d261b3

20 files changed

+625
-625
lines changed
 

‎client/auth.go

+37-37
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,16 @@ import (
66
"encoding/binary"
77
"fmt"
88

9-
. "github.com/go-mysql-org/go-mysql/mysql"
9+
"github.com/go-mysql-org/go-mysql/mysql"
1010
"github.com/go-mysql-org/go-mysql/packet"
1111
"github.com/pingcap/errors"
1212
"github.com/pingcap/tidb/pkg/parser/charset"
1313
)
1414

15-
const defaultAuthPluginName = AUTH_NATIVE_PASSWORD
15+
const defaultAuthPluginName = mysql.AUTH_NATIVE_PASSWORD
1616

1717
// defines the supported auth plugins
18-
var supportedAuthPlugins = []string{AUTH_NATIVE_PASSWORD, AUTH_SHA256_PASSWORD, AUTH_CACHING_SHA2_PASSWORD}
18+
var supportedAuthPlugins = []string{mysql.AUTH_NATIVE_PASSWORD, mysql.AUTH_SHA256_PASSWORD, mysql.AUTH_CACHING_SHA2_PASSWORD}
1919

2020
// helper function to determine what auth methods are allowed by this client
2121
func authPluginAllowed(pluginName string) bool {
@@ -38,12 +38,12 @@ func (c *Conn) readInitialHandshake() error {
3838
return errors.Trace(err)
3939
}
4040

41-
if data[0] == ERR_HEADER {
41+
if data[0] == mysql.ERR_HEADER {
4242
return errors.Annotate(c.handleErrorPacket(data), "read initial handshake error")
4343
}
4444

45-
if data[0] != ClassicProtocolVersion {
46-
if data[0] == XProtocolVersion {
45+
if data[0] != mysql.ClassicProtocolVersion {
46+
if data[0] == mysql.XProtocolVersion {
4747
return errors.Errorf(
4848
"invalid protocol version %d, expected 10. "+
4949
"This might be X Protocol, make sure to connect to the right port",
@@ -75,10 +75,10 @@ func (c *Conn) readInitialHandshake() error {
7575
// The lower 2 bytes of the Capabilities Flags
7676
c.capability = uint32(binary.LittleEndian.Uint16(data[pos : pos+2]))
7777
// check protocol
78-
if c.capability&CLIENT_PROTOCOL_41 == 0 {
78+
if c.capability&mysql.CLIENT_PROTOCOL_41 == 0 {
7979
return errors.New("the MySQL server can not support protocol 41 and above required by the client")
8080
}
81-
if c.capability&CLIENT_SSL == 0 && c.tlsConfig != nil {
81+
if c.capability&mysql.CLIENT_SSL == 0 && c.tlsConfig != nil {
8282
return errors.New("the MySQL Server does not support TLS required by the client")
8383
}
8484
pos += 2
@@ -97,15 +97,15 @@ func (c *Conn) readInitialHandshake() error {
9797

9898
// length of the combined auth_plugin_data (scramble), if auth_plugin_data_len is > 0
9999
authPluginDataLen := data[pos]
100-
if (c.capability&CLIENT_PLUGIN_AUTH == 0) && (authPluginDataLen > 0) {
100+
if (c.capability&mysql.CLIENT_PLUGIN_AUTH == 0) && (authPluginDataLen > 0) {
101101
return errors.Errorf("invalid auth plugin data filler %d", authPluginDataLen)
102102
}
103103
pos++
104104

105105
// skip reserved (all [00] ?)
106106
pos += 10
107107

108-
if c.capability&CLIENT_SECURE_CONNECTION != 0 {
108+
if c.capability&mysql.CLIENT_SECURE_CONNECTION != 0 {
109109
// Rest of the plugin provided data (scramble)
110110

111111
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html
@@ -125,7 +125,7 @@ func (c *Conn) readInitialHandshake() error {
125125
c.salt = append(c.salt, authPluginDataPart2...)
126126
}
127127

128-
if c.capability&CLIENT_PLUGIN_AUTH != 0 {
128+
if c.capability&mysql.CLIENT_PLUGIN_AUTH != 0 {
129129
c.authPluginName = string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)])
130130
pos += len(c.authPluginName)
131131

@@ -153,13 +153,13 @@ func (c *Conn) readInitialHandshake() error {
153153
func (c *Conn) genAuthResponse(authData []byte) ([]byte, bool, error) {
154154
// password hashing
155155
switch c.authPluginName {
156-
case AUTH_NATIVE_PASSWORD:
157-
return CalcPassword(authData[:20], []byte(c.password)), false, nil
158-
case AUTH_CACHING_SHA2_PASSWORD:
159-
return CalcCachingSha2Password(authData, c.password), false, nil
160-
case AUTH_CLEAR_PASSWORD:
156+
case mysql.AUTH_NATIVE_PASSWORD:
157+
return mysql.CalcPassword(authData[:20], []byte(c.password)), false, nil
158+
case mysql.AUTH_CACHING_SHA2_PASSWORD:
159+
return mysql.CalcCachingSha2Password(authData, c.password), false, nil
160+
case mysql.AUTH_CLEAR_PASSWORD:
161161
return []byte(c.password), true, nil
162-
case AUTH_SHA256_PASSWORD:
162+
case mysql.AUTH_SHA256_PASSWORD:
163163
if len(c.password) == 0 {
164164
return nil, true, nil
165165
}
@@ -186,10 +186,10 @@ func (c *Conn) genAttributes() []byte {
186186

187187
attrData := make([]byte, 0)
188188
for k, v := range c.attributes {
189-
attrData = append(attrData, PutLengthEncodedString([]byte(k))...)
190-
attrData = append(attrData, PutLengthEncodedString([]byte(v))...)
189+
attrData = append(attrData, mysql.PutLengthEncodedString([]byte(k))...)
190+
attrData = append(attrData, mysql.PutLengthEncodedString([]byte(v))...)
191191
}
192-
return append(PutLengthEncodedInt(uint64(len(attrData))), attrData...)
192+
return append(mysql.PutLengthEncodedInt(uint64(len(attrData))), attrData...)
193193
}
194194

195195
// See: http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
@@ -199,23 +199,23 @@ func (c *Conn) writeAuthHandshake() error {
199199
}
200200

201201
// Set default client capabilities that reflect the abilities of this library
202-
capability := CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION |
203-
CLIENT_LONG_PASSWORD | CLIENT_TRANSACTIONS | CLIENT_PLUGIN_AUTH
202+
capability := mysql.CLIENT_PROTOCOL_41 | mysql.CLIENT_SECURE_CONNECTION |
203+
mysql.CLIENT_LONG_PASSWORD | mysql.CLIENT_TRANSACTIONS | mysql.CLIENT_PLUGIN_AUTH
204204
// Adjust client capability flags based on server support
205-
capability |= c.capability & CLIENT_LONG_FLAG
206-
capability |= c.capability & CLIENT_QUERY_ATTRIBUTES
205+
capability |= c.capability & mysql.CLIENT_LONG_FLAG
206+
capability |= c.capability & mysql.CLIENT_QUERY_ATTRIBUTES
207207
// Adjust client capability flags on specific client requests
208208
// Only flags that would make any sense setting and aren't handled elsewhere
209209
// in the library are supported here
210-
capability |= c.ccaps&CLIENT_FOUND_ROWS | c.ccaps&CLIENT_IGNORE_SPACE |
211-
c.ccaps&CLIENT_MULTI_STATEMENTS | c.ccaps&CLIENT_MULTI_RESULTS |
212-
c.ccaps&CLIENT_PS_MULTI_RESULTS | c.ccaps&CLIENT_CONNECT_ATTRS |
213-
c.ccaps&CLIENT_COMPRESS | c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM |
214-
c.ccaps&CLIENT_LOCAL_FILES
210+
capability |= c.ccaps&mysql.CLIENT_FOUND_ROWS | c.ccaps&mysql.CLIENT_IGNORE_SPACE |
211+
c.ccaps&mysql.CLIENT_MULTI_STATEMENTS | c.ccaps&mysql.CLIENT_MULTI_RESULTS |
212+
c.ccaps&mysql.CLIENT_PS_MULTI_RESULTS | c.ccaps&mysql.CLIENT_CONNECT_ATTRS |
213+
c.ccaps&mysql.CLIENT_COMPRESS | c.ccaps&mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM |
214+
c.ccaps&mysql.CLIENT_LOCAL_FILES
215215

216216
// To enable TLS / SSL
217217
if c.tlsConfig != nil {
218-
capability |= CLIENT_SSL
218+
capability |= mysql.CLIENT_SSL
219219
}
220220

221221
auth, addNull, err := c.genAuthResponse(c.salt)
@@ -227,11 +227,11 @@ func (c *Conn) writeAuthHandshake() error {
227227
// here we use the Length-Encoded-Integer(LEI) as the data length may not fit into one byte
228228
// see: https://dev.mysql.com/doc/internals/en/integer.html#length-encoded-integer
229229
var authRespLEIBuf [9]byte
230-
authRespLEI := AppendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(auth)))
230+
authRespLEI := mysql.AppendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(auth)))
231231
if len(authRespLEI) > 1 {
232232
// if the length can not be written in 1 byte, it must be written as a
233233
// length encoded integer
234-
capability |= CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
234+
capability |= mysql.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
235235
}
236236

237237
// packet length
@@ -248,16 +248,16 @@ func (c *Conn) writeAuthHandshake() error {
248248
}
249249
// db name
250250
if len(c.db) > 0 {
251-
capability |= CLIENT_CONNECT_WITH_DB
251+
capability |= mysql.CLIENT_CONNECT_WITH_DB
252252
length += len(c.db) + 1
253253
}
254254
// connection attributes
255255
attrData := c.genAttributes()
256256
if len(attrData) > 0 {
257-
capability |= CLIENT_CONNECT_ATTRS
257+
capability |= mysql.CLIENT_CONNECT_ATTRS
258258
length += len(attrData)
259259
}
260-
if c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 {
260+
if c.ccaps&mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 {
261261
length++
262262
}
263263

@@ -279,7 +279,7 @@ func (c *Conn) writeAuthHandshake() error {
279279
// use default collation id 255 here, is `utf8mb4_0900_ai_ci`
280280
collationName := c.collation
281281
if len(collationName) == 0 {
282-
collationName = DEFAULT_COLLATION_NAME
282+
collationName = mysql.DEFAULT_COLLATION_NAME
283283
}
284284
collation, err := charset.GetCollationByName(collationName)
285285
if err != nil {
@@ -347,7 +347,7 @@ func (c *Conn) writeAuthHandshake() error {
347347
pos += copy(data[pos:], attrData)
348348
}
349349

350-
if c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 {
350+
if c.ccaps&mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 {
351351
// zstd_compression_level
352352
data[pos] = 0x03
353353
}

0 commit comments

Comments
 (0)