diff --git a/conn.go b/conn.go index 0f90ba90..e24f92aa 100644 --- a/conn.go +++ b/conn.go @@ -1054,7 +1054,7 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec return pq.execute(ctx, cancel, execOptions.PartitionedQueryOptions.ExecutePartition.Index) } - stmt, err := prepareSpannerStmt(c.parser, query, args) + stmt, err := prepareSpannerStmt(c.state, c.parser, query, args) if err != nil { return nil, err } @@ -1214,7 +1214,7 @@ func (c *conn) execContext(ctx context.Context, query string, execOptions *ExecO return c.execDDL(ctx, spanner.NewStatement(query)) } - ss, err := prepareSpannerStmt(c.parser, query, args) + ss, err := prepareSpannerStmt(c.state, c.parser, query, args) if err != nil { return nil, err } diff --git a/connection_properties.go b/connection_properties.go index 1baac940..5c84a9aa 100644 --- a/connection_properties.go +++ b/connection_properties.go @@ -267,6 +267,20 @@ var propertyDecodeNumericToString = createConnectionProperty( connectionstate.ContextUser, connectionstate.ConvertBool, ) +var propertySendTypedStrings = createConnectionProperty( + "send_typed_strings", + "send_untyped_strings determines whether the driver should send string query parameters as "+ + "untyped (default) or typed strings. Using untyped strings is recommended, as it allows the application to send "+ + "any data type (e.g. JSON, TIMESTAMP, DATE) that is encoded as a string in Spanner using a simple string value. "+ + "Spanner will the infer the actual data type based on the SQL expression. "+ + "This property should be set to true if the application executes statements with query parameters where the "+ + "data type cannot be inferred, such as `SELECT @greeting`.", + false, + false, + nil, + connectionstate.ContextUser, + connectionstate.ConvertBool, +) // ------------------------------------------------------------------------------------------------ // Transaction connection properties. diff --git a/driver_with_mockserver_test.go b/driver_with_mockserver_test.go index 108c90b4..e6de34bd 100644 --- a/driver_with_mockserver_test.go +++ b/driver_with_mockserver_test.go @@ -976,7 +976,8 @@ func TestQueryWithAllTypes(t *testing.T) { t.Fatalf("sql requests count mismatch\nGot: %v\nWant: %v", g, w) } req := sqlRequests[0].(*sppb.ExecuteSqlRequest) - if g, w := len(req.ParamTypes), 22; g != w { + // Generic strings should be sent as untyped values. + if g, w := len(req.ParamTypes), 21; g != w { t.Fatalf("param types length mismatch\nGot: %v\nWant: %v", g, w) } if g, w := len(req.Params.Fields), 22; g != w { @@ -995,7 +996,7 @@ func TestQueryWithAllTypes(t *testing.T) { }, { name: "string", - code: sppb.TypeCode_STRING, + code: sppb.TypeCode_TYPE_CODE_UNSPECIFIED, value: "test", }, { @@ -1169,7 +1170,9 @@ func TestQueryWithAllTypes(t *testing.T) { } } } else { - t.Errorf("no param type found for @%s", wantParam.name) + if wantParam.code != sppb.TypeCode_TYPE_CODE_UNSPECIFIED { + t.Errorf("no param type found for @%s", wantParam.name) + } } if val, ok := req.Params.Fields[wantParam.name]; ok { var g interface{} @@ -1798,7 +1801,8 @@ func TestQueryWithAllNativeTypes(t *testing.T) { t.Fatalf("sql requests count mismatch\nGot: %v\nWant: %v", g, w) } req := sqlRequests[0].(*sppb.ExecuteSqlRequest) - if g, w := len(req.ParamTypes), 22; g != w { + // Strings should be sent as untyped values. + if g, w := len(req.ParamTypes), 20; g != w { t.Fatalf("param types length mismatch\nGot: %v\nWant: %v", g, w) } if g, w := len(req.Params.Fields), 22; g != w { @@ -1817,7 +1821,7 @@ func TestQueryWithAllNativeTypes(t *testing.T) { }, { name: "string", - code: sppb.TypeCode_STRING, + code: sppb.TypeCode_TYPE_CODE_UNSPECIFIED, value: "test", }, { @@ -1876,7 +1880,7 @@ func TestQueryWithAllNativeTypes(t *testing.T) { }, { name: "stringArray", - code: sppb.TypeCode_STRING, + code: sppb.TypeCode_TYPE_CODE_UNSPECIFIED, array: true, value: &structpb.ListValue{Values: []*structpb.Value{ {Kind: &structpb.Value_StringValue{StringValue: "test1"}}, @@ -1980,7 +1984,9 @@ func TestQueryWithAllNativeTypes(t *testing.T) { } } } else { - t.Errorf("no param type found for @%s", wantParam.name) + if wantParam.code != sppb.TypeCode_TYPE_CODE_UNSPECIFIED { + t.Errorf("no param type found for @%s", wantParam.name) + } } if val, ok := req.Params.Fields[wantParam.name]; ok { var g interface{} diff --git a/examples/connect/connect.go b/examples/connect/connect.go index 7244e76c..421d84b0 100644 --- a/examples/connect/connect.go +++ b/examples/connect/connect.go @@ -34,7 +34,7 @@ func connect(projectId, instanceId, databaseId string) error { defer func() { _ = db.Close() }() fmt.Printf("Connected to %s\n", dsn) - row := db.QueryRowContext(ctx, "select @greeting", "Hello from Spanner") + row := db.QueryRowContext(ctx, "select cast(@greeting as string)", "Hello from Spanner") var greeting string if err := row.Scan(&greeting); err != nil { return fmt.Errorf("failed to get greeting: %v", err) diff --git a/integration_test.go b/integration_test.go index 30d21086..e70799ce 100644 --- a/integration_test.go +++ b/integration_test.go @@ -483,7 +483,11 @@ func TestTypeRoundTrip(t *testing.T) { defer cleanup() // Open db. - db, err := sql.Open("spanner", dsn) + // send_typed_strings=true is required to include a type code for string values. + // Otherwise, string values are sent as untyped strings in order to allow Spanner to infer the type. + // Using untyped strings is recommended for most applications, as it allows the application to just use + // standard string values for any type that is encoded as strings in Spanner (e.g. JSON, DATE, TIMESTAMP, etc.). + db, err := sql.Open("spanner", dsn+";send_typed_strings=true") if err != nil { t.Fatal(err) } diff --git a/stmt.go b/stmt.go index 5ccd302a..933eba86 100644 --- a/stmt.go +++ b/stmt.go @@ -20,9 +20,11 @@ import ( "database/sql/driver" "cloud.google.com/go/spanner" + "github.com/googleapis/go-sql-spanner/connectionstate" "github.com/googleapis/go-sql-spanner/parser" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/structpb" ) // SpannerNamedArg can be used for query parameters with a name that (might) start @@ -87,12 +89,13 @@ func (s *stmt) CheckNamedValue(value *driver.NamedValue) error { return s.conn.CheckNamedValue(value) } -func prepareSpannerStmt(parser *parser.StatementParser, q string, args []driver.NamedValue) (spanner.Statement, error) { +func prepareSpannerStmt(state *connectionstate.ConnectionState, parser *parser.StatementParser, q string, args []driver.NamedValue) (spanner.Statement, error) { q, names, err := parser.ParseParameters(q) if err != nil { return spanner.Statement{}, err } ss := spanner.NewStatement(q) + typedStrings := propertySendTypedStrings.GetValueOrDefault(state) for i, v := range args { value := v.Value name := args[i].Name @@ -104,7 +107,7 @@ func prepareSpannerStmt(parser *parser.StatementParser, q string, args []driver. name = names[i] } if name != "" { - ss.Params[name] = convertParam(value) + ss.Params[name] = convertParam(value, typedStrings) } } // Verify that all parameters have a value. @@ -116,10 +119,66 @@ func prepareSpannerStmt(parser *parser.StatementParser, q string, args []driver. return ss, nil } -func convertParam(v driver.Value) driver.Value { +func convertParam(v driver.Value, typedStrings bool) driver.Value { switch v := v.(type) { default: return v + case string: + if typedStrings { + return v + } + // Send strings as untyped parameter values to allow automatic conversion to any type that is encoded as + // strings. This for example allows DATE, TIMESTAMP, INTERVAL, JSON, INT64, etc. to all be set as a string + // by the application. + return spanner.GenericColumnValue{Value: structpb.NewStringValue(v)} + case *string: + if typedStrings { + return v + } + if v == nil { + return spanner.GenericColumnValue{Value: structpb.NewNullValue()} + } + return spanner.GenericColumnValue{Value: structpb.NewStringValue(*v)} + case []string: + if typedStrings { + return v + } + if v == nil { + return spanner.GenericColumnValue{Value: structpb.NewNullValue()} + } + values := make([]*structpb.Value, len(v)) + for i, s := range v { + values[i] = structpb.NewStringValue(s) + } + return spanner.GenericColumnValue{Value: structpb.NewListValue(&structpb.ListValue{Values: values})} + case *[]string: + if typedStrings { + return v + } + if v == nil { + return spanner.GenericColumnValue{Value: structpb.NewNullValue()} + } + values := make([]*structpb.Value, len(*v)) + for i, s := range *v { + values[i] = structpb.NewStringValue(s) + } + return spanner.GenericColumnValue{Value: structpb.NewListValue(&structpb.ListValue{Values: values})} + case []*string: + if typedStrings { + return v + } + if v == nil { + return spanner.GenericColumnValue{Value: structpb.NewNullValue()} + } + values := make([]*structpb.Value, len(v)) + for i, s := range v { + if s == nil { + values[i] = structpb.NewNullValue() + } else { + values[i] = structpb.NewStringValue(*s) + } + } + return spanner.GenericColumnValue{Value: structpb.NewListValue(&structpb.ListValue{Values: values})} case int: return int64(v) case []int: diff --git a/stmt_test.go b/stmt_test.go index 823b13f6..d725c737 100644 --- a/stmt_test.go +++ b/stmt_test.go @@ -23,7 +23,7 @@ import ( func TestConvertParam(t *testing.T) { check := func(in, want driver.Value) { t.Helper() - got := convertParam(in) + got := convertParam(in, false) if !reflect.DeepEqual(got, want) { t.Errorf("in:%#v want:%#v got:%#v", in, want, got) }