From f9b773578a6b31a22878abdf4d63233e19f4720a Mon Sep 17 00:00:00 2001
From: Dan Hansen <dan@recidiviz.org>
Date: Sun, 19 May 2024 13:27:05 -0700
Subject: [PATCH] Pass query parameter types to zetasqlite

---
 internal/connection/manager.go     |  2 ++
 internal/contentdata/repository.go | 11 ++++++
 server/server_test.go              | 56 +++++++++++++++++++++++++++++-
 3 files changed, 68 insertions(+), 1 deletion(-)

diff --git a/internal/connection/manager.go b/internal/connection/manager.go
index 14d4d5da2..ca8ef1ccd 100644
--- a/internal/connection/manager.go
+++ b/internal/connection/manager.go
@@ -41,6 +41,8 @@ func (t *Tx) Tx() *sql.Tx {
 	return t.tx
 }
 
+func (t *Tx) Conn() *Conn { return t.conn }
+
 func (t *Tx) RollbackIfNotCommitted() error {
 	if t.committed {
 		return nil
diff --git a/internal/contentdata/repository.go b/internal/contentdata/repository.go
index 7be8c3bc7..fdf01a21d 100644
--- a/internal/contentdata/repository.go
+++ b/internal/contentdata/repository.go
@@ -168,6 +168,17 @@ func (r *Repository) Query(ctx context.Context, tx *connection.Tx, projectID, da
 		zap.String("query", query),
 		zap.Any("values", values),
 	)
+	// We must pass the query parameters to zetasqlite so the analyzer uses the proper typings
+	if err := tx.Conn().Conn.Raw(func(c interface{}) error {
+		zetasqliteConn, ok := c.(*zetasqlite.ZetaSQLiteConn)
+		if !ok {
+			return fmt.Errorf("failed to get ZetaSQLiteConn from %T", c)
+		}
+		zetasqliteConn.SetQueryParameters(params)
+		return nil
+	}); err != nil {
+		return nil, fmt.Errorf("failed to setup connection: %w", err)
+	}
 	rows, err := tx.Tx().QueryContext(ctx, query, values...)
 	if err != nil {
 		return nil, err
diff --git a/server/server_test.go b/server/server_test.go
index 8eeaa4c2c..1c637d886 100644
--- a/server/server_test.go
+++ b/server/server_test.go
@@ -2378,6 +2378,61 @@ ORDER BY qty DESC;`)
 	if rowCount != 1 {
 		t.Fatal("failed to get result")
 	}
+
+	query = client.Query("SELECT * FROM `test.test_dataset.test_table` WHERE @parameter IS NULL OR 'target text' = @parameter")
+	query.Parameters = []bigquery.QueryParameter{
+		{
+			Name: "parameter",
+			Value: &bigquery.QueryParameterValue{
+				Type: bigquery.StandardSQLDataType{
+					TypeKind: "STRING",
+				},
+				Value: "test",
+			},
+		},
+	}
+	it, err = query.Read(ctx)
+
+	if err != nil {
+		t.Fatal(err)
+	}
+	for {
+		var row []bigquery.Value
+		if err := it.Next(&row); err != nil {
+			if err != iterator.Done {
+				t.Fatal(err)
+			}
+			break
+		}
+		if len(row) != 3 {
+			t.Fatalf("failed to get row: %v", row)
+		}
+	}
+
+	query = client.Query("SELECT * FROM UNNEST(@states)")
+	query.Parameters = []bigquery.QueryParameter{
+		{
+			Name:  "states",
+			Value: []string{"WA", "VA", "WV", "WY"},
+		},
+	}
+	it, err = query.Read(ctx)
+
+	if err != nil {
+		t.Fatal(err)
+	}
+	for {
+		var row []bigquery.Value
+		if err := it.Next(&row); err != nil {
+			if err != iterator.Done {
+				t.Fatal(err)
+			}
+			break
+		}
+		if len(row) != 1 {
+			t.Fatalf("failed to get row: %v", row)
+		}
+	}
 }
 
 func TestMultipleProject(t *testing.T) {
@@ -2579,5 +2634,4 @@ func TestInformationSchema(t *testing.T) {
 			}
 		}
 	})
-
 }