-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathhandshake_v10.go
102 lines (81 loc) · 2.24 KB
/
handshake_v10.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeV10
package mysqlproto
import (
"bytes"
"errors"
"fmt"
)
// HandshakeV10 holds the fields decoded from a Protocol::HandshakeV10 packet
// by ReadHandshakeV10.
type HandshakeV10 struct {
	// ProtocolVersion is the first byte of the packet payload.
	ProtocolVersion byte

	// ServerVersion is the null-terminated string that follows the protocol
	// version, stored without its terminator.
	ServerVersion string

	// ConnectionID is the 4 raw bytes that follow the server version, kept in
	// the order they appear in the payload.
	ConnectionID [4]byte

	// AuthPluginData is the authentication data: the first 8-byte part,
	// followed by the second part when CLIENT_SECURE_CONNECTION is set.
	AuthPluginData []byte

	// CapabilityFlags combines the lower 16 bits and, when present in the
	// packet, the upper 16 bits of the capability flags.
	CapabilityFlags uint32

	// CharacterSet is the single character-set byte of the packet.
	CharacterSet byte

	// StatusFlags is the 2 raw status bytes, kept in payload order.
	StatusFlags [2]byte

	// AuthPluginName is the plugin name read from the end of the packet when
	// CLIENT_PLUGIN_AUTH is set; otherwise it is left empty.
	AuthPluginName string
}
// ReadHandshakeV10 reads the next packet from stream and decodes its payload
// as a Protocol::HandshakeV10 packet.
//
// An error is returned when the packet cannot be read, when the payload
// starts with EOF_PACKET (the whole payload becomes the error text), when the
// server version is not null-terminated, or when the payload is too short for
// the fields being decoded. On error the returned HandshakeV10 is the zero
// value.
//
// A payload that ends right after the lower capability flags is accepted as
// the short form; AuthPluginData then aliases the 8 bytes inside the packet
// payload instead of copying them. In the long form AuthPluginData is only
// filled when CLIENT_SECURE_CONNECTION is set.
func ReadHandshakeV10(stream *Stream) (HandshakeV10, error) {
	pkt, err := stream.NextPacket()
	if err != nil {
		return HandshakeV10{}, err
	}

	data := pkt.Payload
	truncated := func() (HandshakeV10, error) {
		return HandshakeV10{}, fmt.Errorf("mysqlproto: ReadHandshakeV10: truncated packet: %v", data)
	}

	if len(data) == 0 {
		return truncated()
	}
	if data[0] == EOF_PACKET {
		return HandshakeV10{}, errors.New(string(data))
	}

	packet := HandshakeV10{ProtocolVersion: data[0]}
	pos := 1

	null := bytes.IndexByte(data[pos:], 0x00)
	if null == -1 {
		return HandshakeV10{}, fmt.Errorf("mysqlproto: ReadHandshakeV10: expected 0x00: %v", data)
	}
	packet.ServerVersion = string(data[pos : pos+null])
	pos += null + 1 // skip null terminator

	// Fixed section: connection id (4) + auth data part 1 (8) + filler (1) +
	// lower capability flags (2).
	if len(data)-pos < 4+8+1+2 {
		return truncated()
	}

	copy(packet.ConnectionID[:], data[pos:pos+4])
	pos += 4

	authDataPos := pos
	pos += 8 + 1 // 8 bytes of auth data (part 1) plus 1 filler byte

	packet.CapabilityFlags = uint32(data[pos]) | uint32(data[pos+1])<<8
	pos += 2

	// Short form: nothing follows the lower capability flags.
	if len(data) == pos {
		packet.AuthPluginData = data[authDataPos : authDataPos+8]
		return packet, nil
	}

	// Extended section: character set (1) + status flags (2) + upper
	// capability flags (2) + auth data length (1) + reserved (10).
	if len(data)-pos < 1+2+2+1+10 {
		return truncated()
	}

	packet.CharacterSet = data[pos]
	pos++

	packet.StatusFlags = [2]byte{data[pos], data[pos+1]}
	pos += 2

	packet.CapabilityFlags |= (uint32(data[pos]) | uint32(data[pos+1])<<8) << 16
	pos += 2

	// The length byte is only meaningful with CLIENT_PLUGIN_AUTH. It is kept
	// as an int so that the subtraction below cannot wrap around for values
	// smaller than 8.
	authDataLen := 0
	if packet.CapabilityFlags&CLIENT_PLUGIN_AUTH != 0 {
		authDataLen = int(data[pos])
	}
	pos++

	pos += 10 // skip reserved 10 bytes

	if packet.CapabilityFlags&CLIENT_SECURE_CONNECTION != 0 {
		// Part 2 occupies max(13, authDataLen-8) bytes.
		read := 13
		if authDataLen-8 > read {
			read = authDataLen - 8
		}

		// The last byte of part 2 is the null character and is not copied,
		// so only read-1 bytes have to be present.
		if len(data)-pos < read-1 {
			return truncated()
		}

		packet.AuthPluginData = make([]byte, 8+read-1)
		copy(packet.AuthPluginData, data[authDataPos:authDataPos+8])
		copy(packet.AuthPluginData[8:], data[pos:pos+read-1])

		pos += read
		if pos > len(data) {
			pos = len(data)
		}
	}

	if packet.CapabilityFlags&CLIENT_PLUGIN_AUTH != 0 {
		// Some servers omit the terminating null of the plugin name; in that
		// case the name runs to the end of the payload.
		name := data[pos:]
		if end := bytes.IndexByte(name, 0x00); end != -1 {
			name = name[:end]
		}
		packet.AuthPluginName = string(name)
	}

	return packet, nil
}