From c4afb3f5ae09066d9aaec61eebd9496f7a852082 Mon Sep 17 00:00:00 2001 From: Marko Tiikkaja Date: Sat, 30 May 2015 03:45:33 +0200 Subject: [PATCH] Transparently receive parameters in binary format when possible When preparing a statement for repeated execution (as opposed to just parameterizing a single query using the unnamed statement) we get to know the types of the resulting columns before we have to decide which ones to receive in binary and which ones in text. We can use that to our advantage to transparently avoid unnecessary binary -> text -> binary conversions. This has been shown in some cases to provide massive performance benefits, with little to no penalty even in the pathological case. But just to err on the safe side, an option for disabling this feature is provided, disable_prepared_binary_result. It is not documented in the user-facing documentation since its use is expected to be practically nonexistent. In the current state of affairs, only bytea and int8/int4/int2 values are requested in binary from the server. Floats and time-related types are probably the next types to get the same treatment. 
Chris Bandy and Marko Tiikkaja --- bench_test.go | 10 +-- conn.go | 168 ++++++++++++++++++++++++++++++++++++++++--------- conn_test.go | 13 ++-- encode.go | 31 ++++++++- encode_test.go | 29 +++++++-- 5 files changed, 207 insertions(+), 44 deletions(-) diff --git a/bench_test.go b/bench_test.go index 611edf87..e71f41d0 100644 --- a/bench_test.go +++ b/bench_test.go @@ -325,7 +325,7 @@ var testIntBytes = []byte("1234") func BenchmarkDecodeInt64(b *testing.B) { for i := 0; i < b.N; i++ { - decode(&parameterStatus{}, testIntBytes, oid.T_int8) + decode(&parameterStatus{}, testIntBytes, oid.T_int8, formatText) } } @@ -333,7 +333,7 @@ var testFloatBytes = []byte("3.14159") func BenchmarkDecodeFloat64(b *testing.B) { for i := 0; i < b.N; i++ { - decode(&parameterStatus{}, testFloatBytes, oid.T_float8) + decode(&parameterStatus{}, testFloatBytes, oid.T_float8, formatText) } } @@ -341,7 +341,7 @@ var testBoolBytes = []byte{'t'} func BenchmarkDecodeBool(b *testing.B) { for i := 0; i < b.N; i++ { - decode(&parameterStatus{}, testBoolBytes, oid.T_bool) + decode(&parameterStatus{}, testBoolBytes, oid.T_bool, formatText) } } @@ -358,7 +358,7 @@ var testTimestamptzBytes = []byte("2013-09-17 22:15:32.360754-07") func BenchmarkDecodeTimestamptz(b *testing.B) { for i := 0; i < b.N; i++ { - decode(&parameterStatus{}, testTimestamptzBytes, oid.T_timestamptz) + decode(&parameterStatus{}, testTimestamptzBytes, oid.T_timestamptz, formatText) } } @@ -371,7 +371,7 @@ func BenchmarkDecodeTimestamptzMultiThread(b *testing.B) { f := func(wg *sync.WaitGroup, loops int) { defer wg.Done() for i := 0; i < loops; i++ { - decode(&parameterStatus{}, testTimestamptzBytes, oid.T_timestamptz) + decode(&parameterStatus{}, testTimestamptzBytes, oid.T_timestamptz, formatText) } } diff --git a/conn.go b/conn.go index 99ecde90..d57a14bd 100644 --- a/conn.go +++ b/conn.go @@ -106,6 +106,33 @@ type conn struct { // If true, this connection is bad and all public-facing functions should // return ErrBadConn. 
bad bool + + // If set, this connection should never use the binary format when + // receiving query results from prepared statements. Only provided for + // debugging. + disablePreparedBinaryResult bool +} + +// Handle driver-side settings in parsed connection string. +func (c *conn) handleDriverSettings(o values) (err error) { + boolSetting := func(key string, val *bool) error { + if value := o.Get(key); value != "" { + if value == "yes" { + *val = true + } else if value == "no" { + *val = false + } else { + return fmt.Errorf("unrecognized value %q for disable_prepared_binary_result", value) + } + } + return nil + } + + err = boolSetting("disable_prepared_binary_result", &c.disablePreparedBinaryResult) + if err != nil { + return err + } + return nil } func (c *conn) writeBuf(b byte) *writeBuf { @@ -194,12 +221,16 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { } } - c, err := dial(d, o) + cn := &conn{} + err = cn.handleDriverSettings(o) if err != nil { return nil, err } - cn := &conn{c: c} + cn.c, err = dial(d, o) + if err != nil { + return nil, err + } cn.ssl(o) cn.buf = bufio.NewReader(cn.c) cn.startup(o) @@ -496,7 +527,7 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err } } -func (cn *conn) simpleQuery(q string) (res driver.Rows, err error) { +func (cn *conn) simpleQuery(q string) (res *rows, err error) { defer cn.errRecover(&err) st := &stmt{cn: cn, name: ""} @@ -518,10 +549,11 @@ func (cn *conn) simpleQuery(q string) (res driver.Rows, err error) { errorf("unexpected message %q in simple query execution", t) } res = &rows{ - cn: cn, - cols: st.cols, + cn: cn, + cols: st.cols, rowTyps: st.rowTyps, - done: true, + rowFmts: st.rowFmts, + done: true, } case 'Z': cn.processReadyForQuery(r) @@ -541,9 +573,8 @@ func (cn *conn) simpleQuery(q string) (res driver.Rows, err error) { case 'T': // res might be non-nil here if we received a previous // CommandComplete, but that's fine; just overwrite it - rs := 
&rows{cn: cn} - rs.cols, rs.rowTyps = parseMeta(r) - res = rs + res = &rows{cn: cn} + res.cols, res.rowFmts, res.rowTyps = parseMeta(r) // To work around a bug in QueryRow in Go 1.2 and earlier, wait // until the first DataRow has been received. @@ -554,6 +585,50 @@ func (cn *conn) simpleQuery(q string) (res driver.Rows, err error) { } } +// Decides which column formats to use for a prepared statement. The input is +// an array of type oids, one element per result column. +func decideColumnFormats(rowTyps []oid.Oid, forceText bool) (rowFmts []format, rowFmtData []byte) { + rowFmts = make([]format, len(rowTyps)) + if forceText { + return rowFmts, rowFmtDataAllText + } + + allBinary := true + allText := true + for i, o := range rowTyps { + switch o { + // This is the list of types to use binary mode for when receiving them + // through a prepared statement. If a type appears in this list, it + // must also be implemented in binaryDecode in encode.go. + case oid.T_bytea: + fallthrough + case oid.T_int8: + fallthrough + case oid.T_int4: + fallthrough + case oid.T_int2: + rowFmts[i] = formatBinary + allText = false + + default: + allBinary = false + } + } + + if allBinary { + return rowFmts, rowFmtDataAllBinary + } else if allText { + return rowFmts, rowFmtDataAllText + } else { + rowFmtData = make([]byte, 2+len(rowFmts)*2) + binary.BigEndian.PutUint16(rowFmtData, uint16(len(rowFmts))) + for i, v := range rowFmts { + binary.BigEndian.PutUint16(rowFmtData[2+i*2:], uint16(v)) + } + return rowFmts, rowFmtData + } +} + func (cn *conn) prepareTo(q, stmtName string) (_ *stmt, err error) { st := &stmt{cn: cn, name: stmtName} @@ -581,9 +656,11 @@ func (cn *conn) prepareTo(q, stmtName string) (_ *stmt, err error) { st.paramTyps[i] = r.oid() } case 'T': - st.cols, st.rowTyps = parseMeta(r) + st.cols, st.rowTyps = parseStatementRowDescribe(r) + st.rowFmts, st.rowFmtData = decideColumnFormats(st.rowTyps, cn.disablePreparedBinaryResult) case 'n': // no data + st.rowFmtData = 
rowFmtDataAllText case 'Z': cn.processReadyForQuery(r) return st, err @@ -644,9 +721,10 @@ func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err err st.exec(args) return &rows{ - cn: cn, - cols: st.cols, + cn: cn, + cols: st.cols, rowTyps: st.rowTyps, + rowFmts: st.rowFmts, }, nil } @@ -971,6 +1049,8 @@ func isDriverSetting(key string) bool { return true case "connect_timeout": return true + case "disable_prepared_binary_result": + return true default: return false @@ -1053,13 +1133,26 @@ func (cn *conn) auth(r *readBuf, o values) { } } +type format int + +const formatText format = 0 +const formatBinary format = 1 + +// One result-column format code with the value 1 (i.e. all binary). +var rowFmtDataAllBinary []byte = []byte{0, 1, 0, 1} + +// No result-column format codes (i.e. all text). +var rowFmtDataAllText []byte = []byte{0, 0} + type stmt struct { - cn *conn - name string - cols []string - rowTyps []oid.Oid - paramTyps []oid.Oid - closed bool + cn *conn + name string + cols []string + rowFmts []format + rowFmtData []byte + rowTyps []oid.Oid + paramTyps []oid.Oid + closed bool } func (st *stmt) Close() (err error) { @@ -1103,9 +1196,10 @@ func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { st.exec(v) return &rows{ - cn: st.cn, - cols: st.cols, + cn: st.cn, + cols: st.cols, rowTyps: st.rowTyps, + rowFmts: st.rowFmts, }, nil } @@ -1159,7 +1253,7 @@ func (st *stmt) exec(v []driver.Value) { w.bytes(b) } } - w.int16(0) + w.bytes(st.rowFmtData) w.next('E') w.byte(0) @@ -1278,11 +1372,12 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { } type rows struct { - cn *conn - cols []string + cn *conn + cols []string rowTyps []oid.Oid - done bool - rb readBuf + rowFmts []format + done bool + rb readBuf } func (rs *rows) Close() error { @@ -1339,7 +1434,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) { dest[i] = nil continue } - dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.rowTyps[i]) 
+ dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.rowTyps[i], rs.rowFmts[i]) } return default: @@ -1401,7 +1496,7 @@ func (c *conn) processReadyForQuery(r *readBuf) { c.txnStatus = transactionStatus(r.byte()) } -func parseMeta(r *readBuf) (cols []string, rowTyps []oid.Oid) { +func parseStatementRowDescribe(r *readBuf) (cols []string, rowTyps []oid.Oid) { n := r.int16() cols = make([]string, n) rowTyps = make([]oid.Oid, n) @@ -1409,7 +1504,24 @@ func parseMeta(r *readBuf) (cols []string, rowTyps []oid.Oid) { cols[i] = r.string() r.next(6) rowTyps[i] = r.oid() - r.next(8) + r.next(6) + // format code not known; always 0 + r.next(2) + } + return +} + +func parseMeta(r *readBuf) (cols []string, rowFmts []format, rowTyps []oid.Oid) { + n := r.int16() + cols = make([]string, n) + rowFmts = make([]format, n) + rowTyps = make([]oid.Oid, n) + for i := range cols { + cols[i] = r.string() + r.next(6) + rowTyps[i] = r.oid() + r.next(6) + rowFmts[i] = format(r.int16()) } return } diff --git a/conn_test.go b/conn_test.go index 741fd761..ec0d55cd 100644 --- a/conn_test.go +++ b/conn_test.go @@ -338,6 +338,7 @@ func TestEncodeDecode(t *testing.T) { '2000-1-1 01:02:03.04-7'::timestamptz, 0::boolean, 123, + -321, 3.14::float8 WHERE E'\\000\\001\\002'::bytea = $1 @@ -366,9 +367,9 @@ func TestEncodeDecode(t *testing.T) { var got2 string var got3 = sql.NullInt64{Valid: true} var got4 time.Time - var got5, got6, got7 interface{} + var got5, got6, got7, got8 interface{} - err = r.Scan(&got1, &got2, &got3, &got4, &got5, &got6, &got7) + err = r.Scan(&got1, &got2, &got3, &got4, &got5, &got6, &got7, &got8) if err != nil { t.Fatal(err) } @@ -397,8 +398,12 @@ func TestEncodeDecode(t *testing.T) { t.Fatalf("expected 123, got %d", got6) } - if got7 != float64(3.14) { - t.Fatalf("expected 3.14, got %f", got7) + if got7 != int64(-321) { + t.Fatalf("expected -321, got %d", got7) + } + + if got8 != float64(3.14) { + t.Fatalf("expected 3.14, got %f", got8) } } diff --git a/encode.go 
b/encode.go index ad5f9683..6d8b5e46 100644 --- a/encode.go +++ b/encode.go @@ -3,6 +3,7 @@ package pq import ( "bytes" "database/sql/driver" + "encoding/binary" "encoding/hex" "fmt" "math" @@ -44,7 +45,33 @@ func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid oid.Oid) [ panic("not reached") } -func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} { +func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid, f format) interface{} { + if f == formatBinary { + return binaryDecode(parameterStatus, s, typ) + } else { + return textDecode(parameterStatus, s, typ) + } +} + +func binaryDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} { + switch typ { + case oid.T_bytea: + return s + case oid.T_int8: + return int64(binary.BigEndian.Uint64(s)) + case oid.T_int4: + return int64(int32(binary.BigEndian.Uint32(s))) + case oid.T_int2: + return int64(int16(binary.BigEndian.Uint16(s))) + + default: + errorf("don't know how to decode binary parameter of type %u", uint32(typ)) + } + + panic("not reached") +} + +func textDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} { switch typ { case oid.T_bytea: return parseBytea(s) @@ -58,7 +85,7 @@ func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} return mustParse("15:04:05-07", typ, s) case oid.T_bool: return s[0] == 't' - case oid.T_int8, oid.T_int2, oid.T_int4: + case oid.T_int8, oid.T_int4, oid.T_int2: i, err := strconv.ParseInt(string(s), 10, 64) if err != nil { errorf("%s", err) diff --git a/encode_test.go b/encode_test.go index 50fbaf33..abfc1156 100644 --- a/encode_test.go +++ b/encode_test.go @@ -2,6 +2,7 @@ package pq import ( "bytes" + "database/sql" "fmt" "testing" "time" @@ -460,7 +461,7 @@ func TestByteaOutputFormats(t *testing.T) { return } - testByteaOutputFormat := func(f string) { + testByteaOutputFormat := func(f string, usePrepared bool) { expectedData := 
[]byte("\x5c\x78\x00\xff\x61\x62\x63\x01\x08") sqlQuery := "SELECT decode('5c7800ff6162630108', 'hex')" @@ -477,8 +478,18 @@ func TestByteaOutputFormats(t *testing.T) { if err != nil { t.Fatal(err) } - // use Query; QueryRow would hide the actual error - rows, err := txn.Query(sqlQuery) + var rows *sql.Rows + var stmt *sql.Stmt + if usePrepared { + stmt, err = txn.Prepare(sqlQuery) + if err != nil { + t.Fatal(err) + } + rows, err = stmt.Query() + } else { + // use Query; QueryRow would hide the actual error + rows, err = txn.Query(sqlQuery) + } if err != nil { t.Fatal(err) } @@ -496,13 +507,21 @@ func TestByteaOutputFormats(t *testing.T) { if err != nil { t.Fatal(err) } + if stmt != nil { + err = stmt.Close() + if err != nil { + t.Fatal(err) + } + } if !bytes.Equal(data, expectedData) { t.Errorf("unexpected bytea value %v for format %s; expected %v", data, f, expectedData) } } - testByteaOutputFormat("hex") - testByteaOutputFormat("escape") + testByteaOutputFormat("hex", false) + testByteaOutputFormat("escape", false) + testByteaOutputFormat("hex", true) + testByteaOutputFormat("escape", true) } func TestAppendEncodedText(t *testing.T) {