Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1054,7 +1054,7 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
return pq.execute(ctx, cancel, execOptions.PartitionedQueryOptions.ExecutePartition.Index)
}

stmt, err := prepareSpannerStmt(c.parser, query, args)
stmt, err := prepareSpannerStmt(c.state, c.parser, query, args)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1214,7 +1214,7 @@ func (c *conn) execContext(ctx context.Context, query string, execOptions *ExecO
return c.execDDL(ctx, spanner.NewStatement(query))
}

ss, err := prepareSpannerStmt(c.parser, query, args)
ss, err := prepareSpannerStmt(c.state, c.parser, query, args)
if err != nil {
return nil, err
}
Expand Down
14 changes: 14 additions & 0 deletions connection_properties.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,20 @@ var propertyDecodeNumericToString = createConnectionProperty(
connectionstate.ContextUser,
connectionstate.ConvertBool,
)
var propertySendTypedStrings = createConnectionProperty(
"send_typed_strings",
"send_untyped_strings determines whether the driver should send string query parameters as "+
"untyped (default) or typed strings. Using untyped strings is recommended, as it allows the application to send "+
"any data type (e.g. JSON, TIMESTAMP, DATE) that is encoded as a string in Spanner using a simple string value. "+
"Spanner will the infer the actual data type based on the SQL expression. "+
"This property should be set to true if the application executes statements with query parameters where the "+
"data type cannot be inferred, such as `SELECT @greeting`.",
false,
false,
nil,
connectionstate.ContextUser,
connectionstate.ConvertBool,
)

// ------------------------------------------------------------------------------------------------
// Transaction connection properties.
Expand Down
20 changes: 13 additions & 7 deletions driver_with_mockserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -976,7 +976,8 @@ func TestQueryWithAllTypes(t *testing.T) {
t.Fatalf("sql requests count mismatch\nGot: %v\nWant: %v", g, w)
}
req := sqlRequests[0].(*sppb.ExecuteSqlRequest)
if g, w := len(req.ParamTypes), 22; g != w {
// Generic strings should be sent as untyped values.
if g, w := len(req.ParamTypes), 21; g != w {
t.Fatalf("param types length mismatch\nGot: %v\nWant: %v", g, w)
}
if g, w := len(req.Params.Fields), 22; g != w {
Expand All @@ -995,7 +996,7 @@ func TestQueryWithAllTypes(t *testing.T) {
},
{
name: "string",
code: sppb.TypeCode_STRING,
code: sppb.TypeCode_TYPE_CODE_UNSPECIFIED,
value: "test",
},
{
Expand Down Expand Up @@ -1169,7 +1170,9 @@ func TestQueryWithAllTypes(t *testing.T) {
}
}
} else {
t.Errorf("no param type found for @%s", wantParam.name)
if wantParam.code != sppb.TypeCode_TYPE_CODE_UNSPECIFIED {
t.Errorf("no param type found for @%s", wantParam.name)
}
}
if val, ok := req.Params.Fields[wantParam.name]; ok {
var g interface{}
Expand Down Expand Up @@ -1798,7 +1801,8 @@ func TestQueryWithAllNativeTypes(t *testing.T) {
t.Fatalf("sql requests count mismatch\nGot: %v\nWant: %v", g, w)
}
req := sqlRequests[0].(*sppb.ExecuteSqlRequest)
if g, w := len(req.ParamTypes), 22; g != w {
// Strings should be sent as untyped values.
if g, w := len(req.ParamTypes), 20; g != w {
t.Fatalf("param types length mismatch\nGot: %v\nWant: %v", g, w)
}
if g, w := len(req.Params.Fields), 22; g != w {
Expand All @@ -1817,7 +1821,7 @@ func TestQueryWithAllNativeTypes(t *testing.T) {
},
{
name: "string",
code: sppb.TypeCode_STRING,
code: sppb.TypeCode_TYPE_CODE_UNSPECIFIED,
value: "test",
},
{
Expand Down Expand Up @@ -1876,7 +1880,7 @@ func TestQueryWithAllNativeTypes(t *testing.T) {
},
{
name: "stringArray",
code: sppb.TypeCode_STRING,
code: sppb.TypeCode_TYPE_CODE_UNSPECIFIED,
array: true,
value: &structpb.ListValue{Values: []*structpb.Value{
{Kind: &structpb.Value_StringValue{StringValue: "test1"}},
Expand Down Expand Up @@ -1980,7 +1984,9 @@ func TestQueryWithAllNativeTypes(t *testing.T) {
}
}
} else {
t.Errorf("no param type found for @%s", wantParam.name)
if wantParam.code != sppb.TypeCode_TYPE_CODE_UNSPECIFIED {
t.Errorf("no param type found for @%s", wantParam.name)
}
}
if val, ok := req.Params.Fields[wantParam.name]; ok {
var g interface{}
Expand Down
2 changes: 1 addition & 1 deletion examples/connect/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func connect(projectId, instanceId, databaseId string) error {
defer func() { _ = db.Close() }()

fmt.Printf("Connected to %s\n", dsn)
row := db.QueryRowContext(ctx, "select @greeting", "Hello from Spanner")
row := db.QueryRowContext(ctx, "select cast(@greeting as string)", "Hello from Spanner")
var greeting string
if err := row.Scan(&greeting); err != nil {
return fmt.Errorf("failed to get greeting: %v", err)
Expand Down
6 changes: 5 additions & 1 deletion integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,11 @@ func TestTypeRoundTrip(t *testing.T) {
defer cleanup()

// Open db.
db, err := sql.Open("spanner", dsn)
// send_typed_strings=true is required to include a type code for string values.
// Otherwise, string values are sent as untyped strings in order to allow Spanner to infer the type.
// Using untyped strings is recommended for most applications, as it allows the application to just use
// standard string values for any type that is encoded as strings in Spanner (e.g. JSON, DATE, TIMESTAMP, etc.).
db, err := sql.Open("spanner", dsn+";send_typed_strings=true")
if err != nil {
t.Fatal(err)
}
Expand Down
65 changes: 62 additions & 3 deletions stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ import (
"database/sql/driver"

"cloud.google.com/go/spanner"
"github.com/googleapis/go-sql-spanner/connectionstate"
"github.com/googleapis/go-sql-spanner/parser"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/structpb"
)

// SpannerNamedArg can be used for query parameters with a name that (might) start
Expand Down Expand Up @@ -87,12 +89,13 @@ func (s *stmt) CheckNamedValue(value *driver.NamedValue) error {
return s.conn.CheckNamedValue(value)
}

func prepareSpannerStmt(parser *parser.StatementParser, q string, args []driver.NamedValue) (spanner.Statement, error) {
func prepareSpannerStmt(state *connectionstate.ConnectionState, parser *parser.StatementParser, q string, args []driver.NamedValue) (spanner.Statement, error) {
q, names, err := parser.ParseParameters(q)
if err != nil {
return spanner.Statement{}, err
}
ss := spanner.NewStatement(q)
typedStrings := propertySendTypedStrings.GetValueOrDefault(state)
for i, v := range args {
value := v.Value
name := args[i].Name
Expand All @@ -104,7 +107,7 @@ func prepareSpannerStmt(parser *parser.StatementParser, q string, args []driver.
name = names[i]
}
if name != "" {
ss.Params[name] = convertParam(value)
ss.Params[name] = convertParam(value, typedStrings)
}
}
// Verify that all parameters have a value.
Expand All @@ -116,10 +119,66 @@ func prepareSpannerStmt(parser *parser.StatementParser, q string, args []driver.
return ss, nil
}

func convertParam(v driver.Value) driver.Value {
func convertParam(v driver.Value, typedStrings bool) driver.Value {
switch v := v.(type) {
default:
return v
case string:
if typedStrings {
return v
}
// Send strings as untyped parameter values to allow automatic conversion to any type that is encoded as
// strings. This for example allows DATE, TIMESTAMP, INTERVAL, JSON, INT64, etc. to all be set as a string
// by the application.
return spanner.GenericColumnValue{Value: structpb.NewStringValue(v)}
case *string:
if typedStrings {
return v
}
if v == nil {
return spanner.GenericColumnValue{Value: structpb.NewNullValue()}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we can have global or package-level NullValue and use that reused
var pbNull = structpb.NewNullValue()
spanner.GenericColumnValue{Value: pbNull}

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, let me add that in a follow-up PR.

}
return spanner.GenericColumnValue{Value: structpb.NewStringValue(*v)}
case []string:
if typedStrings {
return v
}
if v == nil {
return spanner.GenericColumnValue{Value: structpb.NewNullValue()}
}
values := make([]*structpb.Value, len(v))
for i, s := range v {
values[i] = structpb.NewStringValue(s)
}
return spanner.GenericColumnValue{Value: structpb.NewListValue(&structpb.ListValue{Values: values})}
case *[]string:
if typedStrings {
return v
}
if v == nil {
return spanner.GenericColumnValue{Value: structpb.NewNullValue()}
}
values := make([]*structpb.Value, len(*v))
for i, s := range *v {
values[i] = structpb.NewStringValue(s)
}
return spanner.GenericColumnValue{Value: structpb.NewListValue(&structpb.ListValue{Values: values})}
case []*string:
if typedStrings {
return v
}
if v == nil {
return spanner.GenericColumnValue{Value: structpb.NewNullValue()}
}
values := make([]*structpb.Value, len(v))
for i, s := range v {
if s == nil {
values[i] = structpb.NewNullValue()
} else {
values[i] = structpb.NewStringValue(*s)
}
}
return spanner.GenericColumnValue{Value: structpb.NewListValue(&structpb.ListValue{Values: values})}
case int:
return int64(v)
case []int:
Expand Down
2 changes: 1 addition & 1 deletion stmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
func TestConvertParam(t *testing.T) {
check := func(in, want driver.Value) {
t.Helper()
got := convertParam(in)
got := convertParam(in, false)
if !reflect.DeepEqual(got, want) {
t.Errorf("in:%#v want:%#v got:%#v", in, want, got)
}
Expand Down
Loading