diff --git a/client/resp.go b/client/resp.go index cc9441467..0bf2d0c83 100644 --- a/client/resp.go +++ b/client/resp.go @@ -222,15 +222,26 @@ func (c *Conn) readResult(binary bool) (*Result, error) { return nil, errors.Trace(err) } + var result *Result if firstPkgBuf[0] == OK_HEADER { - return c.handleOKPacket(firstPkgBuf) + result, err = c.handleOKPacket(firstPkgBuf) } else if firstPkgBuf[0] == ERR_HEADER { - return nil, c.handleErrorPacket(append([]byte{}, firstPkgBuf...)) + err = c.handleErrorPacket(append([]byte{}, firstPkgBuf...)) } else if firstPkgBuf[0] == LocalInFile_HEADER { - return nil, ErrMalformPacket + err = ErrMalformPacket + } else { + result, err = c.readResultset(firstPkgBuf, binary) + } + + // if there are more results, chain resultsets of following results to this + // result + if err == nil && result.Status&SERVER_MORE_RESULTS_EXISTS > 0 { + if res, err := c.readResult(binary); err == nil { + result.ChainResult(res) + } } - return c.readResultset(firstPkgBuf, binary) + return result, err } func (c *Conn) readResultStreaming(binary bool, result *Result, perRowCb SelectPerRowCallback) error { diff --git a/mysql/result.go b/mysql/result.go index 797a4af75..131f01d0a 100644 --- a/mysql/result.go +++ b/mysql/result.go @@ -7,6 +7,8 @@ type Result struct { AffectedRows uint64 *Resultset + + Next *Result } type Executer interface { @@ -19,3 +21,23 @@ func (r *Result) Close() { r.Resultset = nil } } + +func (r *Result) lastChained() (int, *Result) { + count := 1 + var lastRes *Result + for lastRes = r; lastRes.Next != nil; lastRes = lastRes.Next { + count++ + } + + return count, lastRes +} + +func (r *Result) ChainResult(cr *Result) { + _, lastRes := r.lastChained() + lastRes.Next = cr +} + +func (r *Result) Length() int { + n, _ := r.lastChained() + return n +} diff --git a/mysql/result_test.go b/mysql/result_test.go new file mode 100644 index 000000000..3574da965 --- /dev/null +++ b/mysql/result_test.go @@ -0,0 +1,30 @@ +package mysql + +import ( + "github.com/pingcap/check" +) + +type resultTestSuite struct { +} + +var _ = check.Suite(&resultTestSuite{}) + +func (t *resultTestSuite) TestLastChained(c *check.C) { + r1 := &Result{} + n, last := r1.lastChained() + c.Assert(last == r1, check.IsTrue) + c.Assert(n, check.Equals, 1) + + r2 := &Result{} + r1.ChainResult(r2) + n, last = r1.lastChained() + c.Assert(last == r2, check.IsTrue) + c.Assert(n, check.Equals, 2) + + n, last = r2.lastChained() + c.Assert(last == r2, check.IsTrue) + c.Assert(n, check.Equals, 1) + + c.Assert(r1.Length(), check.Equals, 2) + c.Assert(r2.Length(), check.Equals, 1) +} diff --git a/server/caching_sha2_cache_test.go b/server/caching_sha2_cache_test.go index d423b8775..c00488108 100644 --- a/server/caching_sha2_cache_test.go +++ b/server/caching_sha2_cache_test.go @@ -192,16 +192,16 @@ func (h *testCacheHandler) handleQuery(query string, binary bool) (*mysql.Result if err != nil { return nil, errors.Trace(err) } else { - return &mysql.Result{0, 0, 0, r}, nil + return &mysql.Result{0, 0, 0, r, nil}, nil } case "insert": - return &mysql.Result{0, 1, 0, nil}, nil + return &mysql.Result{0, 1, 0, nil, nil}, nil case "delete": - return &mysql.Result{0, 0, 1, nil}, nil + return &mysql.Result{0, 0, 1, nil, nil}, nil case "update": - return &mysql.Result{0, 0, 1, nil}, nil + return &mysql.Result{0, 0, 1, nil, nil}, nil case "replace": - return &mysql.Result{0, 0, 1, nil}, nil + return &mysql.Result{0, 0, 1, nil, nil}, nil default: return nil, fmt.Errorf("invalid query %s", query) } diff --git a/server/resp.go b/server/resp.go index f0ec5486e..65804aa0c 100644 --- a/server/resp.go +++ b/server/resp.go @@ -115,6 +115,29 @@ func (c *Conn) writeAuthMoreDataFastAuth() error { return c.WritePacket(data) } +func (c *Conn) writeResultsets(r *Result) error { + var err error + for res := r; res != nil; res = res.Next { + if res.Next != nil { + c.status |= SERVER_MORE_RESULTS_EXISTS + } + + if res.Resultset == nil { + err = c.writeOK(res) + } else { + err = c.writeResultset(res.Resultset) + } + + c.status &= ^SERVER_MORE_RESULTS_EXISTS + + if err != nil { + return err + } + } + + return nil +} + func (c *Conn) writeResultset(r *Resultset) error { columnLen := PutLengthEncodedInt(uint64(len(r.Fields))) @@ -183,8 +206,8 @@ func (c *Conn) writeValue(value interface{}) error { case nil: return c.writeOK(nil) case *Result: - if v != nil && v.Resultset != nil { - return c.writeResultset(v.Resultset) + if v != nil { + return c.writeResultsets(v) } else { return c.writeOK(v) } diff --git a/server/server_test.go b/server/server_test.go index 03d97baff..58d47f1e7 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -239,16 +239,16 @@ func (h *testHandler) handleQuery(query string, binary bool) (*mysql.Result, err if err != nil { return nil, errors.Trace(err) } else { - return &mysql.Result{0, 0, 0, r}, nil + return &mysql.Result{0, 0, 0, r, nil}, nil } case "insert": - return &mysql.Result{0, 1, 0, nil}, nil + return &mysql.Result{0, 1, 0, nil, nil}, nil case "delete": - return &mysql.Result{0, 0, 1, nil}, nil + return &mysql.Result{0, 0, 1, nil, nil}, nil case "update": - return &mysql.Result{0, 0, 1, nil}, nil + return &mysql.Result{0, 0, 1, nil, nil}, nil case "replace": - return &mysql.Result{0, 0, 1, nil}, nil + return &mysql.Result{0, 0, 1, nil, nil}, nil default: return nil, fmt.Errorf("invalid query %s", query) }