Skip to content

Commit 316d3ca

Browse files
committed
Add parsing of query response columns
1 parent 8d62392 commit 316d3ca

File tree

5 files changed

+346
-13
lines changed

5 files changed

+346
-13
lines changed

Diff for: com_query_response.go

+85-12
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,31 @@
11
package mysqlproto
22

3+
import (
4+
"errors"
5+
"fmt"
6+
)
7+
38
type ResultSet struct {
9+
Columns []Column
10+
411
conn Conn
512
}
613

14+
// https://dev.mysql.com/doc/internals/en/com-query-response.html#column-definition
15+
type Column struct {
16+
Catalog string
17+
Schema string
18+
Table string
19+
OrgTable string
20+
Name string
21+
OrgName string
22+
CharacterSet uint16
23+
ColumnLength uint64
24+
ColumnType Type
25+
Flags uint16
26+
Decimals byte
27+
}
28+
729
func (r ResultSet) Row() ([]byte, error) {
830
packet, err := r.conn.NextPacket()
931
if err != nil {
@@ -17,28 +39,79 @@ func (r ResultSet) Row() ([]byte, error) {
1739
return packet.Payload, nil
1840
}
1941

42+
// https://dev.mysql.com/doc/internals/en/com-query-response.html
2043
func ComQueryResponse(conn Conn) (ResultSet, error) {
21-
packet, err := conn.NextPacket()
22-
if err != nil {
23-
return ResultSet{}, err
44+
read := func() ([]byte, error) {
45+
packet, err := conn.NextPacket()
46+
if err != nil {
47+
return nil, err
48+
}
49+
if len(packet.Payload) == 0 {
50+
return nil, errors.New("mysqlproto: empty payload")
51+
}
52+
if packet.Payload[0] == ERR_PACKET {
53+
return nil, parseError(packet.Payload, conn.CapabilityFlags)
54+
}
55+
return packet.Payload, nil
2456
}
2557

26-
if packet.Payload[0] == ERR_PACKET {
27-
return ResultSet{}, parseError(packet.Payload, conn.CapabilityFlags)
58+
payload, err := read()
59+
if err != nil {
60+
return ResultSet{}, err
2861
}
2962

30-
columns, _, _ := lenDecInt(packet.Payload)
31-
skip := int(columns) + 1 // skip column definition + first EOF
32-
for i := 0; i < skip; i++ {
33-
packet, err := conn.NextPacket()
63+
colCount, _, _ := lenDecInt(payload)
64+
columns := make([]Column, int(colCount))
65+
for i := 0; i < int(colCount); i++ {
66+
payload, err := read()
3467
if err != nil {
3568
return ResultSet{}, err
3669
}
3770

38-
if packet.Payload[0] == ERR_PACKET {
39-
return ResultSet{}, parseError(packet.Payload, conn.CapabilityFlags)
71+
column := Column{}
72+
bytes, offset, _ := ReadRowValue(payload, 0)
73+
column.Catalog = string(bytes)
74+
75+
bytes, offset, _ = ReadRowValue(payload, offset)
76+
column.Schema = string(bytes)
77+
78+
bytes, offset, _ = ReadRowValue(payload, offset)
79+
column.Table = string(bytes)
80+
81+
bytes, offset, _ = ReadRowValue(payload, offset)
82+
column.OrgTable = string(bytes)
83+
84+
bytes, offset, _ = ReadRowValue(payload, offset)
85+
column.Name = string(bytes)
86+
87+
bytes, offset, _ = ReadRowValue(payload, offset)
88+
column.OrgName = string(bytes)
89+
90+
bytes, _, _ = ReadRowValue(payload, offset)
91+
if len(bytes) < 10 {
92+
return ResultSet{}, fmt.Errorf("mysqlproto: invalid column payload: %x", bytes)
4093
}
94+
95+
column.CharacterSet = uint16(bytes[0]) | uint16(bytes[1])<<8
96+
column.ColumnLength = uint64(bytes[2]) | uint64(bytes[3])<<8 | uint64(bytes[4])<<16 | uint64(bytes[5])<<32
97+
column.ColumnType = Type(bytes[6])
98+
column.Flags = uint16(bytes[7]) | uint16(bytes[9])<<8
99+
column.Decimals = bytes[9]
100+
101+
columns[i] = column
41102
}
42103

43-
return ResultSet{conn}, nil
104+
payload, err = read()
105+
if err != nil {
106+
return ResultSet{}, err
107+
}
108+
if payload[0] != EOF_PACKET {
109+
return ResultSet{}, parseError(payload, conn.CapabilityFlags)
110+
}
111+
112+
rs := ResultSet{
113+
Columns: columns,
114+
conn: conn,
115+
}
116+
return rs, nil
44117
}

Diff for: com_query_response_test.go

+105
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
package mysqlproto
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func TestComQueryResponseColumnReader(t *testing.T) {
10+
buf := newBuffer([]byte{
11+
// DB name "test"
12+
// table name "people" AS "p"
13+
14+
// total records
15+
0x01, 0x00, 0x00, 0x01, 0x05,
16+
17+
// id INT
18+
0x25, 0x00, 0x00, 0x02, 0x03, 0x64, 0x65, 0x66, 0x04, 0x74, 0x65, 0x73, 0x74, 0x01, 0x70, 0x06, 0x70, 0x65, 0x6f, 0x70, 0x6c, 0x65, 0x02, 0x69, 0x64, 0x02, 0x69, 0x64, 0x0c, 0x3f, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x03, 0x03, 0x42, 0x00, 0x00, 0x00,
19+
20+
// firstname VARCHAR(255) AS name
21+
0x2e, 0x00, 0x00, 0x03, 0x03, 0x64, 0x65, 0x66, 0x04, 0x74, 0x65, 0x73, 0x74, 0x01, 0x70, 0x06, 0x70, 0x65, 0x6f, 0x70, 0x6c, 0x65, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x09, 0x66, 0x69, 0x72, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x0c, 0x21, 0x00, 0xfd, 0x02, 0x00, 0x00, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00,
22+
23+
// married TINYINT
24+
0x2f, 0x00, 0x00, 0x04, 0x03, 0x64, 0x65, 0x66, 0x04, 0x74, 0x65, 0x73, 0x74, 0x01, 0x70, 0x06, 0x70, 0x65, 0x6f, 0x70, 0x6c, 0x65, 0x07, 0x6d, 0x61, 0x72, 0x72, 0x69, 0x65, 0x64, 0x07, 0x6d, 0x61, 0x72, 0x72, 0x69, 0x65, 0x64, 0x0c, 0x3f, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00,
25+
26+
// score DECIMAL(6,2)
27+
0x2b, 0x00, 0x00, 0x05, 0x03, 0x64, 0x65, 0x66, 0x04, 0x74, 0x65, 0x73, 0x74, 0x01, 0x70, 0x06, 0x70, 0x65, 0x6f, 0x70, 0x6c, 0x65, 0x05, 0x73, 0x63, 0x6f, 0x72, 0x65, 0x05, 0x73, 0x63, 0x6f, 0x72, 0x65, 0x0c, 0x3f, 0x00, 0x08, 0x00, 0x00, 0x00, 0xf6, 0x00, 0x00, 0x02, 0x00, 0x00,
28+
29+
// note TEXT
30+
0x29, 0x00, 0x00, 0x06, 0x03, 0x64, 0x65, 0x66, 0x04, 0x74, 0x65, 0x73, 0x74, 0x01, 0x70, 0x06, 0x70, 0x65, 0x6f, 0x70, 0x6c, 0x65, 0x04, 0x6e, 0x6f, 0x74, 0x65, 0x04, 0x6e, 0x6f, 0x74, 0x65, 0x0c, 0x21, 0x00, 0xfd, 0xff, 0x02, 0x00, 0xfc, 0x10, 0x00, 0x00, 0x00, 0x00,
31+
32+
// EOF
33+
0x05, 0x00, 0x00, 0x07, 0xfe, 0x00, 0x00, 0x22, 0x00,
34+
})
35+
36+
conn := Conn{Stream: &Stream{stream: buf}}
37+
rs, err := ComQueryResponse(conn)
38+
assert.NoError(t, err)
39+
assert.Len(t, rs.Columns, 5)
40+
41+
id := rs.Columns[0]
42+
assert.Equal(t, id.Catalog, "def")
43+
assert.Equal(t, id.Schema, "test")
44+
assert.Equal(t, id.Table, "p")
45+
assert.Equal(t, id.OrgTable, "people")
46+
assert.Equal(t, id.Name, "id")
47+
assert.Equal(t, id.OrgName, "id")
48+
assert.Equal(t, id.CharacterSet, uint16(63))
49+
assert.Equal(t, id.ColumnLength, uint64(11))
50+
assert.Equal(t, id.ColumnType.String(), "LONG")
51+
assert.Equal(t, id.Flags, uint16(3))
52+
assert.Equal(t, id.Decimals, uint8(0))
53+
54+
name := rs.Columns[1]
55+
assert.Equal(t, name.Catalog, "def")
56+
assert.Equal(t, name.Schema, "test")
57+
assert.Equal(t, name.Table, "p")
58+
assert.Equal(t, name.OrgTable, "people")
59+
assert.Equal(t, name.Name, "name")
60+
assert.Equal(t, name.OrgName, "firstname")
61+
assert.Equal(t, name.CharacterSet, uint16(33))
62+
assert.Equal(t, name.ColumnLength, uint64(765))
63+
assert.Equal(t, name.ColumnType.String(), "VAR_STRING")
64+
assert.Equal(t, name.Flags, uint16(0))
65+
assert.Equal(t, name.Decimals, uint8(0))
66+
67+
married := rs.Columns[2]
68+
assert.Equal(t, married.Catalog, "def")
69+
assert.Equal(t, married.Schema, "test")
70+
assert.Equal(t, married.Table, "p")
71+
assert.Equal(t, married.OrgTable, "people")
72+
assert.Equal(t, married.Name, "married")
73+
assert.Equal(t, married.OrgName, "married")
74+
assert.Equal(t, married.CharacterSet, uint16(63))
75+
assert.Equal(t, married.ColumnLength, uint64(4))
76+
assert.Equal(t, married.ColumnType.String(), "TINY")
77+
assert.Equal(t, married.Flags, uint16(0))
78+
assert.Equal(t, married.Decimals, uint8(0))
79+
80+
score := rs.Columns[3]
81+
assert.Equal(t, score.Catalog, "def")
82+
assert.Equal(t, score.Schema, "test")
83+
assert.Equal(t, score.Table, "p")
84+
assert.Equal(t, score.OrgTable, "people")
85+
assert.Equal(t, score.Name, "score")
86+
assert.Equal(t, score.OrgName, "score")
87+
assert.Equal(t, score.CharacterSet, uint16(63))
88+
assert.Equal(t, score.ColumnLength, uint64(8))
89+
assert.Equal(t, score.ColumnType.String(), "NEWDECIMAL")
90+
assert.Equal(t, score.Flags, uint16(512))
91+
assert.Equal(t, score.Decimals, uint8(2))
92+
93+
note := rs.Columns[4]
94+
assert.Equal(t, note.Catalog, "def")
95+
assert.Equal(t, note.Schema, "test")
96+
assert.Equal(t, note.Table, "p")
97+
assert.Equal(t, note.OrgTable, "people")
98+
assert.Equal(t, note.Name, "note")
99+
assert.Equal(t, note.OrgName, "note")
100+
assert.Equal(t, note.CharacterSet, uint16(33))
101+
assert.Equal(t, note.ColumnLength, uint64(196605))
102+
assert.Equal(t, note.ColumnType.String(), "BLOB")
103+
assert.Equal(t, note.Flags, uint16(16))
104+
assert.Equal(t, note.Decimals, uint8(0))
105+
}

Diff for: packet_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import (
66
"github.com/stretchr/testify/assert"
77
)
88

9-
func TestParseOKPacketInvalidPayout(t *testing.T) {
9+
func TestParseOKPacketInvalidPayload(t *testing.T) {
1010
data := []byte{0xff}
1111
_, err := ParseOKPacket(data, 0)
1212
assert.Equal(t, err.Error(), "mysqlproto: invalid OK_PACKET payload: ff")

Diff for: types.go

+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
package mysqlproto
2+
3+
// https://dev.mysql.com/doc/internals/en/com-query-response.html#column-type
4+
type Type byte
5+
6+
const (
7+
TypeDecimal Type = 0x00
8+
TypeTiny Type = 0x01
9+
TypeShort Type = 0x02
10+
TypeLong Type = 0x03
11+
TypeFloat Type = 0x04
12+
TypeDouble Type = 0x05
13+
TypeNULL Type = 0x06
14+
TypeTimestamp Type = 0x07
15+
TypeLongLong Type = 0x08
16+
TypeInt24 Type = 0x09
17+
TypeDate Type = 0x0a
18+
TypeTime Type = 0x0b
19+
TypeDateTime Type = 0x0c
20+
TypeYear Type = 0x0d
21+
TypeNewDate Type = 0x0e
22+
TypeVarchar Type = 0x0f
23+
TypeBit Type = 0x10
24+
TypeTimestamp2 Type = 0x11
25+
TypeDateTime2 Type = 0x12
26+
TypeTime2 Type = 0x13
27+
TypeNewDecimal Type = 0xf6
28+
TypeEnum Type = 0xf7
29+
TypeSet Type = 0xf8
30+
TypeTinyBLOB Type = 0xf9
31+
TypeMediumBLOB Type = 0xfa
32+
TypeLongBLOB Type = 0xfb
33+
TypeBLOB Type = 0xfc
34+
TypeVarString Type = 0xfd
35+
TypeString Type = 0xfe
36+
TypeGEOMETRY Type = 0xff
37+
)
38+
39+
func (t Type) String() string {
40+
switch t {
41+
case TypeDecimal:
42+
return "DECIMAL"
43+
case TypeTiny:
44+
return "TINY"
45+
case TypeShort:
46+
return "SHORT"
47+
case TypeLong:
48+
return "LONG"
49+
case TypeFloat:
50+
return "FLOAT"
51+
case TypeDouble:
52+
return "DOUBLE"
53+
case TypeNULL:
54+
return "NULL"
55+
case TypeTimestamp:
56+
return "TIMESTAMP"
57+
case TypeLongLong:
58+
return "LONGLONG"
59+
case TypeInt24:
60+
return "INT24"
61+
case TypeDate:
62+
return "DATE"
63+
case TypeTime:
64+
return "TIME"
65+
case TypeDateTime:
66+
return "DATETIME"
67+
case TypeYear:
68+
return "YEAR"
69+
case TypeNewDate:
70+
return "NEWDATE"
71+
case TypeVarchar:
72+
return "VARCHAR"
73+
case TypeBit:
74+
return "BIT"
75+
case TypeTimestamp2:
76+
return "TIMESTAMP2"
77+
case TypeDateTime2:
78+
return "DATETIME2"
79+
case TypeTime2:
80+
return "TIME2"
81+
case TypeNewDecimal:
82+
return "NEWDECIMAL"
83+
case TypeEnum:
84+
return "ENUM"
85+
case TypeSet:
86+
return "SET"
87+
case TypeTinyBLOB:
88+
return "TINY_BLOB"
89+
case TypeMediumBLOB:
90+
return "MEDIUM_BLOB"
91+
case TypeLongBLOB:
92+
return "LONG_BLOB"
93+
case TypeBLOB:
94+
return "BLOB"
95+
case TypeVarString:
96+
return "VAR_STRING"
97+
case TypeString:
98+
return "STRING"
99+
case TypeGEOMETRY:
100+
return "GEOMETRY"
101+
default:
102+
return "UNKNOWN"
103+
}
104+
}

Diff for: types_test.go

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package mysqlproto
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func TestTypeString(t *testing.T) {
10+
testCases := []struct {
11+
typ Type
12+
hex byte
13+
str string
14+
}{
15+
{typ: TypeDecimal, hex: 0x00, str: "DECIMAL"},
16+
{typ: TypeTiny, hex: 0x01, str: "TINY"},
17+
{typ: TypeShort, hex: 0x02, str: "SHORT"},
18+
{typ: TypeLong, hex: 0x03, str: "LONG"},
19+
{typ: TypeFloat, hex: 0x04, str: "FLOAT"},
20+
{typ: TypeDouble, hex: 0x05, str: "DOUBLE"},
21+
{typ: TypeNULL, hex: 0x06, str: "NULL"},
22+
{typ: TypeTimestamp, hex: 0x07, str: "TIMESTAMP"},
23+
{typ: TypeLongLong, hex: 0x08, str: "LONGLONG"},
24+
{typ: TypeInt24, hex: 0x09, str: "INT24"},
25+
{typ: TypeDate, hex: 0x0a, str: "DATE"},
26+
{typ: TypeTime, hex: 0x0b, str: "TIME"},
27+
{typ: TypeDateTime, hex: 0x0c, str: "DATETIME"},
28+
{typ: TypeYear, hex: 0x0d, str: "YEAR"},
29+
{typ: TypeNewDate, hex: 0x0e, str: "NEWDATE"},
30+
{typ: TypeVarchar, hex: 0x0f, str: "VARCHAR"},
31+
{typ: TypeBit, hex: 0x10, str: "BIT"},
32+
{typ: TypeTimestamp2, hex: 0x11, str: "TIMESTAMP2"},
33+
{typ: TypeDateTime2, hex: 0x12, str: "DATETIME2"},
34+
{typ: TypeTime2, hex: 0x13, str: "TIME2"},
35+
{typ: TypeNewDecimal, hex: 0xf6, str: "NEWDECIMAL"},
36+
{typ: TypeEnum, hex: 0xf7, str: "ENUM"},
37+
{typ: TypeSet, hex: 0xf8, str: "SET"},
38+
{typ: TypeTinyBLOB, hex: 0xf9, str: "TINY_BLOB"},
39+
{typ: TypeMediumBLOB, hex: 0xfa, str: "MEDIUM_BLOB"},
40+
{typ: TypeLongBLOB, hex: 0xfb, str: "LONG_BLOB"},
41+
{typ: TypeBLOB, hex: 0xfc, str: "BLOB"},
42+
{typ: TypeVarString, hex: 0xfd, str: "VAR_STRING"},
43+
{typ: TypeString, hex: 0xfe, str: "STRING"},
44+
{typ: TypeGEOMETRY, hex: 0xff, str: "GEOMETRY"},
45+
}
46+
47+
for i, tc := range testCases {
48+
assert.Equal(t, byte(tc.typ), tc.hex, "test case %v", i)
49+
assert.Equal(t, tc.typ.String(), tc.str, "test case %v", i)
50+
}
51+
}

0 commit comments

Comments
 (0)