diff --git a/mysql/result.go b/mysql/result.go index d55cb0e27..f103194b3 100644 --- a/mysql/result.go +++ b/mysql/result.go @@ -36,6 +36,9 @@ func (r *Result) Close() { } func (r *Result) HasResultset() bool { + if r == nil { + return false + } if r.Resultset != nil && len(r.Resultset.Fields) > 0 { return true } diff --git a/mysql/result_test.go b/mysql/result_test.go index 0c9344f03..113ffbcf3 100644 --- a/mysql/result_test.go +++ b/mysql/result_test.go @@ -9,11 +9,25 @@ import ( func TestHasResultset_false(t *testing.T) { r := NewResultReserveResultset(0) b := r.HasResultset() - require.Equal(t, false, b) + require.False(t, b) } func TestHasResultset_true(t *testing.T) { r := NewResultReserveResultset(1) b := r.HasResultset() - require.Equal(t, true, b) + require.True(t, b) +} + +// this shouldn't happen after d02e79a, but test just in case +func TestHasResultset_nilset(t *testing.T) { + r := NewResultReserveResultset(0) + r.Resultset = nil + b := r.HasResultset() + require.False(t, b) +} + +func TestHasResultset_nil(t *testing.T) { + var r *Result + b := r.HasResultset() + require.False(t, b) } diff --git a/server/resp_test.go b/server/resp_test.go index 3608c60c3..babb8b6ea 100644 --- a/server/resp_test.go +++ b/server/resp_test.go @@ -14,10 +14,9 @@ func TestConnWriteOK(t *testing.T) { clientConn := &mockconn.MockConn{} conn := &Conn{Conn: packet.NewConn(clientConn)} - result := &mysql.Result{ - AffectedRows: 1, - InsertId: 2, - } + result := mysql.NewResultReserveResultset(0) + result.AffectedRows = 1 + result.InsertId = 2 // write ok with insertid and affectedrows set err := conn.writeOK(result) @@ -229,3 +228,36 @@ func TestConnWriteFieldValues(t *testing.T) { // EOF require.Equal(t, []byte{1, 0, 0, 4, mysql.EOF_HEADER}, clientConn.WriteBuffered[43:]) } + +func TestWriteValue(t *testing.T) { + clientConn := &mockconn.MockConn{MultiWrite: true} + conn := &Conn{Conn: packet.NewConn(clientConn)} + + // simple OK + err := conn.WriteValue(mysql.NewResultReserveResultset(0)) + require.NoError(t, err) + expected := []byte{3, 0, 0, 0, mysql.OK_HEADER, 0, 0} + require.Equal(t, expected, clientConn.WriteBuffered) + + // reset write buffer + clientConn.WriteBuffered = []byte{} + + // resultset with no rows + rs := mysql.NewResultReserveResultset(1) + rs.Fields = []*mysql.Field{{Name: []byte("a")}} + err = conn.WriteValue(rs) + require.NoError(t, err) + expected = []byte{1, 0, 0, 1, mysql.MORE_DATE_HEADER} + require.Equal(t, expected, clientConn.WriteBuffered[:5]) + + // reset write buffer + clientConn.WriteBuffered = []byte{} + + // resultset with rows + rs.Fields = []*mysql.Field{{Name: []byte("a")}} + rs.RowDatas = []mysql.RowData{[]byte{1, 2, 3}} + err = conn.WriteValue(rs) + require.NoError(t, err) + expected = []byte{1, 0, 0, 5, mysql.MORE_DATE_HEADER} + require.Equal(t, expected, clientConn.WriteBuffered[:5]) +} diff --git a/server/server_test.go b/server/server_test.go index d29451022..2ad8f4e7b 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -241,30 +241,16 @@ func (h *testHandler) handleQuery(query string, binary bool) (*mysql.Result, err if err != nil { return nil, errors.Trace(err) } else { - return &mysql.Result{ - Status: 0, - Warnings: 0, - InsertId: 0, - AffectedRows: 0, - Resultset: r, - }, nil + return mysql.NewResult(r), nil } case "insert": - return &mysql.Result{ - Status: 0, - Warnings: 0, - InsertId: 1, - AffectedRows: 0, - Resultset: nil, - }, nil + res := mysql.NewResultReserveResultset(0) + res.InsertId = 1 + return res, nil case "delete", "update", "replace": - return &mysql.Result{ - Status: 0, - Warnings: 0, - InsertId: 0, - AffectedRows: 1, - Resultset: nil, - }, nil + res := mysql.NewResultReserveResultset(0) + res.AffectedRows = 1 + return res, nil default: return nil, fmt.Errorf("invalid query %s", query) }