diff --git a/bench_test.go b/bench_test.go index 611edf87..e71f41d0 100644 --- a/bench_test.go +++ b/bench_test.go @@ -325,7 +325,7 @@ var testIntBytes = []byte("1234") func BenchmarkDecodeInt64(b *testing.B) { for i := 0; i < b.N; i++ { - decode(¶meterStatus{}, testIntBytes, oid.T_int8) + decode(¶meterStatus{}, testIntBytes, oid.T_int8, formatText) } } @@ -333,7 +333,7 @@ var testFloatBytes = []byte("3.14159") func BenchmarkDecodeFloat64(b *testing.B) { for i := 0; i < b.N; i++ { - decode(¶meterStatus{}, testFloatBytes, oid.T_float8) + decode(¶meterStatus{}, testFloatBytes, oid.T_float8, formatText) } } @@ -341,7 +341,7 @@ var testBoolBytes = []byte{'t'} func BenchmarkDecodeBool(b *testing.B) { for i := 0; i < b.N; i++ { - decode(¶meterStatus{}, testBoolBytes, oid.T_bool) + decode(¶meterStatus{}, testBoolBytes, oid.T_bool, formatText) } } @@ -358,7 +358,7 @@ var testTimestamptzBytes = []byte("2013-09-17 22:15:32.360754-07") func BenchmarkDecodeTimestamptz(b *testing.B) { for i := 0; i < b.N; i++ { - decode(¶meterStatus{}, testTimestamptzBytes, oid.T_timestamptz) + decode(¶meterStatus{}, testTimestamptzBytes, oid.T_timestamptz, formatText) } } @@ -371,7 +371,7 @@ func BenchmarkDecodeTimestamptzMultiThread(b *testing.B) { f := func(wg *sync.WaitGroup, loops int) { defer wg.Done() for i := 0; i < loops; i++ { - decode(¶meterStatus{}, testTimestamptzBytes, oid.T_timestamptz) + decode(¶meterStatus{}, testTimestamptzBytes, oid.T_timestamptz, formatText) } } diff --git a/conn.go b/conn.go index 99ecde90..d57a14bd 100644 --- a/conn.go +++ b/conn.go @@ -106,6 +106,33 @@ type conn struct { // If true, this connection is bad and all public-facing functions should // return ErrBadConn. bad bool + + // If set, this connection should never use the binary format when + // receiving query results from prepared statements. Only provided for + // debugging. + disablePreparedBinaryResult bool +} + +// Handle driver-side settings in parsed connection string. +func (c *conn) handleDriverSettings(o values) (err error) { + boolSetting := func(key string, val *bool) error { + if value := o.Get(key); value != "" { + if value == "yes" { + *val = true + } else if value == "no" { + *val = false + } else { + return fmt.Errorf("unrecognized value %q for disable_prepared_binary_result", value) + } + } + return nil + } + + err = boolSetting("disable_prepared_binary_result", &c.disablePreparedBinaryResult) + if err != nil { + return err + } + return nil } func (c *conn) writeBuf(b byte) *writeBuf { @@ -194,12 +221,16 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { } } - c, err := dial(d, o) + cn := &conn{} + err = cn.handleDriverSettings(o) if err != nil { return nil, err } - cn := &conn{c: c} + cn.c, err = dial(d, o) + if err != nil { + return nil, err + } cn.ssl(o) cn.buf = bufio.NewReader(cn.c) cn.startup(o) @@ -496,7 +527,7 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err } } -func (cn *conn) simpleQuery(q string) (res driver.Rows, err error) { +func (cn *conn) simpleQuery(q string) (res *rows, err error) { defer cn.errRecover(&err) st := &stmt{cn: cn, name: ""} @@ -518,10 +549,11 @@ func (cn *conn) simpleQuery(q string) (res driver.Rows, err error) { errorf("unexpected message %q in simple query execution", t) } res = &rows{ - cn: cn, - cols: st.cols, + cn: cn, + cols: st.cols, rowTyps: st.rowTyps, - done: true, + rowFmts: st.rowFmts, + done: true, } case 'Z': cn.processReadyForQuery(r) @@ -541,9 +573,8 @@ func (cn *conn) simpleQuery(q string) (res driver.Rows, err error) { case 'T': // res might be non-nil here if we received a previous // CommandComplete, but that's fine; just overwrite it - rs := &rows{cn: cn} - rs.cols, rs.rowTyps = parseMeta(r) - res = rs + res = &rows{cn: cn} + res.cols, res.rowFmts, res.rowTyps = parseMeta(r) // To work around a bug in QueryRow in Go 1.2 and earlier, wait // until the first DataRow has been received. @@ -554,6 +585,50 @@ func (cn *conn) simpleQuery(q string) (res driver.Rows, err error) { } } +// Decides which column formats to use for a prepared statement. The input is +// an array of type oids, one element per result column. +func decideColumnFormats(rowTyps []oid.Oid, forceText bool) (rowFmts []format, rowFmtData []byte) { + rowFmts = make([]format, len(rowTyps)) + if forceText { + return rowFmts, rowFmtDataAllText + } + + allBinary := true + allText := true + for i, o := range rowTyps { + switch o { + // This is the list of types to use binary mode for when receiving them + // through a prepared statement. If a type appears in this list, it + // must also be implemented in binaryDecode in encode.go. + case oid.T_bytea: + fallthrough + case oid.T_int8: + fallthrough + case oid.T_int4: + fallthrough + case oid.T_int2: + rowFmts[i] = formatBinary + allText = false + + default: + allBinary = false + } + } + + if allBinary { + return rowFmts, rowFmtDataAllBinary + } else if allText { + return rowFmts, rowFmtDataAllText + } else { + rowFmtData = make([]byte, 2+len(rowFmts)*2) + binary.BigEndian.PutUint16(rowFmtData, uint16(len(rowFmts))) + for i, v := range rowFmts { + binary.BigEndian.PutUint16(rowFmtData[2+i*2:], uint16(v)) + } + return rowFmts, rowFmtData + } +} + func (cn *conn) prepareTo(q, stmtName string) (_ *stmt, err error) { st := &stmt{cn: cn, name: stmtName} @@ -581,9 +656,11 @@ func (cn *conn) prepareTo(q, stmtName string) (_ *stmt, err error) { st.paramTyps[i] = r.oid() } case 'T': - st.cols, st.rowTyps = parseMeta(r) + st.cols, st.rowTyps = parseStatementRowDescribe(r) + st.rowFmts, st.rowFmtData = decideColumnFormats(st.rowTyps, cn.disablePreparedBinaryResult) case 'n': // no data + st.rowFmtData = rowFmtDataAllText case 'Z': cn.processReadyForQuery(r) return st, err @@ -644,9 +721,10 @@ func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err err st.exec(args) return &rows{ - cn: cn, - cols: st.cols, + cn: cn, + cols: st.cols, rowTyps: st.rowTyps, + rowFmts: st.rowFmts, }, nil } @@ -971,6 +1049,8 @@ func isDriverSetting(key string) bool { return true case "connect_timeout": return true + case "disable_prepared_binary_result": + return true default: return false @@ -1053,13 +1133,26 @@ func (cn *conn) auth(r *readBuf, o values) { } } +type format int + +const formatText format = 0 +const formatBinary format = 1 + +// One result-column format code with the value 1 (i.e. all binary). +var rowFmtDataAllBinary []byte = []byte{0, 1, 0, 1} + +// No result-column format codes (i.e. all text). +var rowFmtDataAllText []byte = []byte{0, 0} + type stmt struct { - cn *conn - name string - cols []string - rowTyps []oid.Oid - paramTyps []oid.Oid - closed bool + cn *conn + name string + cols []string + rowFmts []format + rowFmtData []byte + rowTyps []oid.Oid + paramTyps []oid.Oid + closed bool } func (st *stmt) Close() (err error) { @@ -1103,9 +1196,10 @@ func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { st.exec(v) return &rows{ - cn: st.cn, - cols: st.cols, + cn: st.cn, + cols: st.cols, rowTyps: st.rowTyps, + rowFmts: st.rowFmts, }, nil } @@ -1159,7 +1253,7 @@ func (st *stmt) exec(v []driver.Value) { w.bytes(b) } } - w.int16(0) + w.bytes(st.rowFmtData) w.next('E') w.byte(0) @@ -1278,11 +1372,12 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { } type rows struct { - cn *conn - cols []string + cn *conn + cols []string rowTyps []oid.Oid - done bool - rb readBuf + rowFmts []format + done bool + rb readBuf } func (rs *rows) Close() error { @@ -1339,7 +1434,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) { dest[i] = nil continue } - dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.rowTyps[i]) + dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.rowTyps[i], rs.rowFmts[i]) } return default: @@ -1401,7 +1496,7 @@ func (c *conn) processReadyForQuery(r *readBuf) { c.txnStatus = transactionStatus(r.byte()) } -func parseMeta(r *readBuf) (cols []string, rowTyps []oid.Oid) { +func parseStatementRowDescribe(r *readBuf) (cols []string, rowTyps []oid.Oid) { n := r.int16() cols = make([]string, n) rowTyps = make([]oid.Oid, n) @@ -1409,7 +1504,24 @@ func parseMeta(r *readBuf) (cols []string, rowTyps []oid.Oid) { cols[i] = r.string() r.next(6) rowTyps[i] = r.oid() - r.next(8) + r.next(6) + // format code not known; always 0 + r.next(2) + } + return +} + +func parseMeta(r *readBuf) (cols []string, rowFmts []format, rowTyps []oid.Oid) { + n := r.int16() + cols = make([]string, n) + rowFmts = make([]format, n) + rowTyps = make([]oid.Oid, n) + for i := range cols { + cols[i] = r.string() + r.next(6) + rowTyps[i] = r.oid() + r.next(6) + rowFmts[i] = format(r.int16()) } return } diff --git a/conn_test.go b/conn_test.go index 741fd761..ec0d55cd 100644 --- a/conn_test.go +++ b/conn_test.go @@ -338,6 +338,7 @@ func TestEncodeDecode(t *testing.T) { '2000-1-1 01:02:03.04-7'::timestamptz, 0::boolean, 123, + -321, 3.14::float8 WHERE E'\\000\\001\\002'::bytea = $1 @@ -366,9 +367,9 @@ func TestEncodeDecode(t *testing.T) { var got2 string var got3 = sql.NullInt64{Valid: true} var got4 time.Time - var got5, got6, got7 interface{} + var got5, got6, got7, got8 interface{} - err = r.Scan(&got1, &got2, &got3, &got4, &got5, &got6, &got7) + err = r.Scan(&got1, &got2, &got3, &got4, &got5, &got6, &got7, &got8) if err != nil { t.Fatal(err) } @@ -397,8 +398,12 @@ func TestEncodeDecode(t *testing.T) { t.Fatalf("expected 123, got %d", got6) } - if got7 != float64(3.14) { - t.Fatalf("expected 3.14, got %f", got7) + if got7 != int64(-321) { + t.Fatalf("expected -321, got %d", got7) + } + + if got8 != float64(3.14) { + t.Fatalf("expected 3.14, got %f", got8) } } diff --git a/encode.go b/encode.go index ad5f9683..6d8b5e46 100644 --- a/encode.go +++ b/encode.go @@ -3,6 +3,7 @@ package pq import ( "bytes" "database/sql/driver" + "encoding/binary" "encoding/hex" "fmt" "math" @@ -44,7 +45,33 @@ func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid oid.Oid) [ panic("not reached") } -func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} { +func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid, f format) interface{} { + if f == formatBinary { + return binaryDecode(parameterStatus, s, typ) + } else { + return textDecode(parameterStatus, s, typ) + } +} + +func binaryDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} { + switch typ { + case oid.T_bytea: + return s + case oid.T_int8: + return int64(binary.BigEndian.Uint64(s)) + case oid.T_int4: + return int64(int32(binary.BigEndian.Uint32(s))) + case oid.T_int2: + return int64(int16(binary.BigEndian.Uint16(s))) + + default: + errorf("don't know how to decode binary parameter of type %u", uint32(typ)) + } + + panic("not reached") +} + +func textDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} { switch typ { case oid.T_bytea: return parseBytea(s) @@ -58,7 +85,7 @@ func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} return mustParse("15:04:05-07", typ, s) case oid.T_bool: return s[0] == 't' - case oid.T_int8, oid.T_int2, oid.T_int4: + case oid.T_int8, oid.T_int4, oid.T_int2: i, err := strconv.ParseInt(string(s), 10, 64) if err != nil { errorf("%s", err) diff --git a/encode_test.go b/encode_test.go index 50fbaf33..abfc1156 100644 --- a/encode_test.go +++ b/encode_test.go @@ -2,6 +2,7 @@ package pq import ( "bytes" + "database/sql" "fmt" "testing" "time" @@ -460,7 +461,7 @@ func TestByteaOutputFormats(t *testing.T) { return } - testByteaOutputFormat := func(f string) { + testByteaOutputFormat := func(f string, usePrepared bool) { expectedData := []byte("\x5c\x78\x00\xff\x61\x62\x63\x01\x08") sqlQuery := "SELECT decode('5c7800ff6162630108', 'hex')" @@ -477,8 +478,18 @@ func TestByteaOutputFormats(t *testing.T) { if err != nil { t.Fatal(err) } - // use Query; QueryRow would hide the actual error - rows, err := txn.Query(sqlQuery) + var rows *sql.Rows + var stmt *sql.Stmt + if usePrepared { + stmt, err = txn.Prepare(sqlQuery) + if err != nil { + t.Fatal(err) + } + rows, err = stmt.Query() + } else { + // use Query; QueryRow would hide the actual error + rows, err = txn.Query(sqlQuery) + } if err != nil { t.Fatal(err) } @@ -496,13 +507,21 @@ func TestByteaOutputFormats(t *testing.T) { if err != nil { t.Fatal(err) } + if stmt != nil { + err = stmt.Close() + if err != nil { + t.Fatal(err) + } + } if !bytes.Equal(data, expectedData) { t.Errorf("unexpected bytea value %v for format %s; expected %v", data, f, expectedData) } } - testByteaOutputFormat("hex") - testByteaOutputFormat("escape") + testByteaOutputFormat("hex", false) + testByteaOutputFormat("escape", false) + testByteaOutputFormat("hex", true) + testByteaOutputFormat("escape", true) } func TestAppendEncodedText(t *testing.T) {