diff --git a/internal/connection/manager.go b/internal/connection/manager.go index 14d4d5da2..ca8ef1ccd 100644 --- a/internal/connection/manager.go +++ b/internal/connection/manager.go @@ -41,6 +41,8 @@ func (t *Tx) Tx() *sql.Tx { return t.tx } +func (t *Tx) Conn() *Conn { return t.conn } + func (t *Tx) RollbackIfNotCommitted() error { if t.committed { return nil diff --git a/internal/contentdata/repository.go b/internal/contentdata/repository.go index 7be8c3bc7..fdf01a21d 100644 --- a/internal/contentdata/repository.go +++ b/internal/contentdata/repository.go @@ -168,6 +168,17 @@ func (r *Repository) Query(ctx context.Context, tx *connection.Tx, projectID, da zap.String("query", query), zap.Any("values", values), ) + // We must pass the query parameters to zetasqlite so the analyzer uses the proper typings + if err := tx.Conn().Conn.Raw(func(c interface{}) error { + zetasqliteConn, ok := c.(*zetasqlite.ZetaSQLiteConn) + if !ok { + return fmt.Errorf("failed to get ZetaSQLiteConn from %T", c) + } + zetasqliteConn.SetQueryParameters(params) + return nil + }); err != nil { + return nil, fmt.Errorf("failed to setup connection: %w", err) + } rows, err := tx.Tx().QueryContext(ctx, query, values...) if err != nil { return nil, err diff --git a/server/server_test.go b/server/server_test.go index 8eeaa4c2c..1c637d886 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -2378,6 +2378,61 @@ ORDER BY qty DESC;`) if rowCount != 1 { t.Fatal("failed to get result") } + + query = client.Query("SELECT * FROM `test.test_dataset.test_table` WHERE @parameter IS NULL OR 'target text' = @parameter") + query.Parameters = []bigquery.QueryParameter{ + { + Name: "parameter", + Value: &bigquery.QueryParameterValue{ + Type: bigquery.StandardSQLDataType{ + TypeKind: "STRING", + }, + Value: "test", + }, + }, + } + it, err = query.Read(ctx) + + if err != nil { + t.Fatal(err) + } + for { + var row []bigquery.Value + if err := it.Next(&row); err != nil { + if err != iterator.Done { + t.Fatal(err) + } + break + } + if len(row) != 3 { + t.Fatalf("failed to get row: %v", row) + } + } + + query = client.Query("SELECT * FROM UNNEST(@states)") + query.Parameters = []bigquery.QueryParameter{ + { + Name: "states", + Value: []string{"WA", "VA", "WV", "WY"}, + }, + } + it, err = query.Read(ctx) + + if err != nil { + t.Fatal(err) + } + for { + var row []bigquery.Value + if err := it.Next(&row); err != nil { + if err != iterator.Done { + t.Fatal(err) + } + break + } + if len(row) != 1 { + t.Fatalf("failed to get row: %v", row) + } + } } func TestMultipleProject(t *testing.T) { @@ -2579,5 +2634,4 @@ func TestInformationSchema(t *testing.T) { } } }) - }