Skip to content

Commit e1b8e65

Browse files
committed
Add tests and fix error
1 parent 9865fd5 commit e1b8e65

File tree

2 files changed

+50
-2
lines changed

2 files changed

+50
-2
lines changed

server/stmt.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ func (c *Conn) handleStmtExecute(data []byte) (*Result, error) {
101101

102102
s, ok := c.stmts[id]
103103
if !ok {
104-
return nil, NewDefaultError(ER_UNKNOWN_STMT_HANDLER,
104+
return nil, NewDefaultError(ER_UNKNOWN_STMT_HANDLER, 5,
105105
strconv.FormatUint(uint64(id), 10), "stmt_execute")
106106
}
107107

@@ -339,7 +339,7 @@ func (c *Conn) handleStmtReset(data []byte) (*Result, error) {
339339

340340
s, ok := c.stmts[id]
341341
if !ok {
342-
return nil, NewDefaultError(ER_UNKNOWN_STMT_HANDLER,
342+
return nil, NewDefaultError(ER_UNKNOWN_STMT_HANDLER, 5,
343343
strconv.FormatUint(uint64(id), 10), "stmt_reset")
344344
}
345345

server/stmt_test.go

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package server
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
)
8+
9+
func TestHandleStmtExecute(t *testing.T) {
10+
c := Conn{}
11+
c.stmts = map[uint32]*Stmt{
12+
1: &Stmt{},
13+
}
14+
testcases := []struct {
15+
data []byte
16+
errtext string
17+
}{
18+
{
19+
[]byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
20+
"ERROR 1243 (HY000): Unknown prepared statement handler (0) given to stmt_execute",
21+
},
22+
{
23+
[]byte{0x1, 0x0, 0x0, 0x0, 0xff, 0x0, 0x0, 0x0, 0x0, 0x0},
24+
"ERROR 1105 (HY000): unsupported flags 0xff",
25+
},
26+
{
27+
[]byte{0x1, 0x0, 0x0, 0x0, 0x01, 0x0, 0x0, 0x0, 0x0, 0x0},
28+
"ERROR 1105 (HY000): unsupported flag CURSOR_TYPE_READ_ONLY",
29+
},
30+
{
31+
[]byte{0x1, 0x0, 0x0, 0x0, 0x02, 0x0, 0x0, 0x0, 0x0, 0x0},
32+
"ERROR 1105 (HY000): unsupported flag CURSOR_TYPE_FOR_UPDATE",
33+
},
34+
{
35+
[]byte{0x1, 0x0, 0x0, 0x0, 0x04, 0x0, 0x0, 0x0, 0x0, 0x0},
36+
"ERROR 1105 (HY000): unsupported flag CURSOR_TYPE_SCROLLABLE",
37+
},
38+
}
39+
40+
for _, tc := range testcases {
41+
_, err := c.handleStmtExecute(tc.data)
42+
if tc.errtext == "" {
43+
require.NoError(t, err)
44+
} else {
45+
require.ErrorContains(t, err, tc.errtext)
46+
}
47+
}
48+
}

0 commit comments

Comments
 (0)