diff --git a/conn.go b/conn.go index df6f821f..efd723e4 100644 --- a/conn.go +++ b/conn.go @@ -1083,7 +1083,7 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec c.setCommitResponse(commitResponse) } else if execOptions.PartitionedQueryOptions.PartitionQuery { return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "PartitionQuery is only supported in batch read-only transactions")) - } else if execOptions.PartitionedQueryOptions.AutoPartitionQuery { + } else if execOptions.PartitionedQueryOptions.AutoPartitionQuery || propertyAutoPartitionMode.GetValueOrDefault(c.state) { return c.executeAutoPartitionedQuery(ctx, cancel, query, execOptions, args) } else { // The statement was either detected as being a query, or potentially not recognized at all. @@ -1587,6 +1587,7 @@ func (c *conn) activateTransaction() (contextTransaction, error) { timestampBoundCallback: func() spanner.TimestampBound { return propertyReadOnlyStaleness.GetValueOrDefault(c.state) }, + state: c.state, }, nil } diff --git a/connection_properties.go b/connection_properties.go index a5c6258c..1fc564f9 100644 --- a/connection_properties.go +++ b/connection_properties.go @@ -161,6 +161,49 @@ var propertyReadOnlyStaleness = createConnectionProperty( connectionstate.ConvertReadOnlyStaleness, ) +var propertyAutoPartitionMode = createConnectionProperty( + "auto_partition_mode", + "Execute all queries on this connection as partitioned queries. "+ + "Executing a query that cannot be partitioned will fail. "+ + "Executing a query in a read/write transaction will also fail.", + false, + false, + nil, + connectionstate.ContextUser, + connectionstate.ConvertBool, +) +var propertyDataBoostEnabled = createConnectionProperty( + "data_boost_enabled", + "Enable data boost for all partitioned queries that are executed by this connection. "+ + "This setting is only used for partitioned queries and is ignored by all other statements. 
"+ + "Either set `auto_partition_query=true` or execute a query with `RUN PARTITIONED QUERY SELECT ... FROM ...` "+ + "to execute a query as a partitioned query.", + false, + false, + nil, + connectionstate.ContextUser, + connectionstate.ConvertBool, +) +var propertyMaxPartitions = createConnectionProperty( + "max_partitions", + "The max partitions hint value to use for partitioned queries. "+ + "Set to 0 if you do not want to specify a hint.", + 0, + false, + nil, + connectionstate.ContextUser, + connectionstate.ConvertInt64, +) +var propertyMaxPartitionedParallelism = createConnectionProperty( + "max_partitioned_parallelism", + "The maximum number of workers to use to read data from partitioned queries.", + 0, + false, + nil, + connectionstate.ContextUser, + connectionstate.ConvertInt, +) + var propertyAutoBatchDml = createConnectionProperty( "auto_batch_dml", "Automatically buffer DML statements that are executed on this connection and execute them as one batch "+ diff --git a/parser/statements.go b/parser/statements.go index d253bc37..2b85e815 100644 --- a/parser/statements.go +++ b/parser/statements.go @@ -43,7 +43,11 @@ func parseStatement(parser *StatementParser, keyword, query string) (ParsedState stmt = &ParsedStartBatchStatement{} } } else if isRunStatementKeyword(keyword) { - stmt = &ParsedRunBatchStatement{} + if isRunBatch(parser, query) { + stmt = &ParsedRunBatchStatement{} + } else if isRunPartitionedQuery(parser, query) { + stmt = &ParsedRunPartitionedQueryStatement{} + } } else if isAbortStatementKeyword(keyword) { stmt = &ParsedAbortBatchStatement{} } else if isBeginStatementKeyword(keyword) { @@ -55,6 +59,9 @@ func parseStatement(parser *StatementParser, keyword, query string) (ParsedState } else { return nil, nil } + if stmt == nil { + return nil, nil + } if err := stmt.parse(parser, query); err != nil { return nil, err } @@ -98,6 +105,36 @@ func isStartTransaction(parser *StatementParser, query string) bool { return false } +func 
isRunBatch(parser *StatementParser, query string) bool { + sp := &simpleParser{sql: []byte(query), statementParser: parser} + if !sp.eatKeyword("run") { + return false + } + if !sp.hasMoreTokens() { + // A bare RUN is not a RUN BATCH statement + return false + } + if sp.eatKeyword("batch") { + return true + } + return false +} + +func isRunPartitionedQuery(parser *StatementParser, query string) bool { + sp := &simpleParser{sql: []byte(query), statementParser: parser} + if !sp.eatKeyword("run") { + return false + } + if !sp.hasMoreTokens() { + // A bare RUN is not a RUN PARTITIONED QUERY statement + return false + } + if sp.eatKeyword("partitioned") { + return true + } + return false +} + // ParsedShowStatement is a statement of the form // SHOW [VARIABLE] [my_extension.]my_property type ParsedShowStatement struct { @@ -509,6 +546,34 @@ func (s *ParsedAbortBatchStatement) parse(parser *StatementParser, query string) return nil } +type ParsedRunPartitionedQueryStatement struct { + query string + Statement string +} + +func (s *ParsedRunPartitionedQueryStatement) Name() string { + return "RUN PARTITIONED QUERY" +} + +func (s *ParsedRunPartitionedQueryStatement) Query() string { + return s.query +} + +func (s *ParsedRunPartitionedQueryStatement) parse(parser *StatementParser, query string) error { + // Parse a statement of the form + // RUN PARTITIONED QUERY <sql> + sp := &simpleParser{sql: []byte(query), statementParser: parser} + if !sp.eatKeywords([]string{"RUN", "PARTITIONED", "QUERY"}) { + return status.Error(codes.InvalidArgument, "statement does not start with RUN PARTITIONED QUERY") + } + if !sp.hasMoreTokens() { + return status.Errorf(codes.InvalidArgument, "missing statement after RUN PARTITIONED QUERY: %q", sp.sql) + } + s.Statement = query[sp.pos:] + s.query = query + return nil +} + type ParsedBeginStatement struct { query string // Identifiers contains the transaction properties that were included in the BEGIN statement. E.g. 
the statement diff --git a/parser/statements_test.go b/parser/statements_test.go index 45361c74..65d5dd27 100644 --- a/parser/statements_test.go +++ b/parser/statements_test.go @@ -629,3 +629,86 @@ func TestParseBeginStatementPostgreSQL(t *testing.T) { }) } } + +func TestParseRunPartitionedQuery(t *testing.T) { + t.Parallel() + + type test struct { + input string + want ParsedRunPartitionedQueryStatement + wantErr bool + } + tests := []test{ + { + input: "run partitioned query select * from my_table", + want: ParsedRunPartitionedQueryStatement{ + statement: " select * from my_table", + query: "run partitioned query select * from my_table", + }, + }, + { + input: "run partitioned query\nselect * from my_table", + want: ParsedRunPartitionedQueryStatement{ + statement: "\nselect * from my_table", + query: "run partitioned query\nselect * from my_table", + }, + }, + { + input: "run partitioned query\n--comment\nselect * from my_table", + want: ParsedRunPartitionedQueryStatement{ + statement: "\n--comment\nselect * from my_table", + query: "run partitioned query\n--comment\nselect * from my_table", + }, + }, + { + input: "run --comment\n partitioned /* comment */ query select * from my_table", + want: ParsedRunPartitionedQueryStatement{ + statement: " select * from my_table", + query: "run --comment\n partitioned /* comment */ query select * from my_table", + }, + }, + { + input: "run partitioned query", + wantErr: true, + }, + { + input: "run partitioned query /* comment */", + wantErr: true, + }, + { + input: "run partitioned query -- comment\n", + wantErr: true, + }, + { + input: "run partitioned select * from my_table", + wantErr: true, + }, + } + parser, err := NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000) + if err != nil { + t.Fatal(err) + } + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + sp := &simpleParser{sql: []byte(test.input), statementParser: parser} + keyword := strings.ToUpper(sp.readKeyword()) + stmt, 
err := parseStatement(parser, keyword, test.input) + if test.wantErr { + if err == nil { + t.Fatalf("parseStatement(%q) should have failed", test.input) + } + } else { + if err != nil { + t.Fatal(err) + } + runStmt, ok := stmt.(*ParsedRunPartitionedQueryStatement) + if !ok { + t.Fatalf("parseStatement(%q) should have returned a *parsedRunPartitionedQueryStatement", test.input) + } + if !reflect.DeepEqual(*runStmt, test.want) { + t.Errorf("parseStatement(%q) mismatch\n Got: %v\nWant: %v", test.input, *runStmt, test.want) + } + } + }) + } +} diff --git a/partitioned_query_test.go b/partitioned_query_test.go index 5158c963..a0182554 100644 --- a/partitioned_query_test.go +++ b/partitioned_query_test.go @@ -240,41 +240,93 @@ func TestAutoPartitionQuery(t *testing.T) { type queryExecutor interface { QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) } + type autoPartitionTest struct { + name string + useExecOption bool + withTx bool + maxResultsPerPartition int + } + tests := make([]autoPartitionTest, 0) + for _, useExecOption := range []bool{true, false} { + for _, withTx := range []bool{false} { + for maxResultsPerPartition := range []int{0, 1, 5, 50, 200} { + tests = append(tests, autoPartitionTest{ + fmt.Sprintf("useExecOption: %v, withTx: %v, maxResultsPerPartition: %v", useExecOption, withTx, maxResultsPerPartition), + useExecOption, + withTx, + maxResultsPerPartition, + }) + } + } + } - for _, withTx := range []bool{false} { - for maxResultsPerPartition := range []int{0, 1, 5, 50, 200} { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { var tx queryExecutor var err error - if withTx { + if test.withTx { tx, err = BeginBatchReadOnlyTransaction(ctx, db, BatchReadOnlyTransactionOptions{}) } else { - tx = db + tx, err = db.Conn(ctx) } if err != nil { t.Fatal(err) } + defer func() { + if tx, ok := tx.(*sql.Tx); ok { + if err := tx.Commit(); err != nil { + t.Fatal(err) + } + } + if tx, ok := tx.(*sql.Conn); ok { + if 
err := tx.Close(); err != nil { + t.Fatal(err) + } + } + }() // Setup results for each partition. - maxPartitions, allResults, err := setupRandomPartitionResults(server, testutil.SelectFooFromBar, maxResultsPerPartition) + maxPartitions, allResults, err := setupRandomPartitionResults(server, testutil.SelectFooFromBar, test.maxResultsPerPartition) if err != nil { t.Fatalf("failed to set up partition results: %v", err) } // Automatically partition and execute a query. - rows, err := tx.QueryContext(ctx, testutil.SelectFooFromBar, - ExecOptions{ - PartitionedQueryOptions: PartitionedQueryOptions{ - AutoPartitionQuery: true, - MaxParallelism: rand.Intn(10) + 1, - PartitionOptions: spanner.PartitionOptions{ - MaxPartitions: int64(maxPartitions), + var rows *sql.Rows + if test.useExecOption { + rows, err = tx.QueryContext(ctx, testutil.SelectFooFromBar, + ExecOptions{ + PartitionedQueryOptions: PartitionedQueryOptions{ + AutoPartitionQuery: true, + MaxParallelism: rand.Intn(10) + 1, + PartitionOptions: spanner.PartitionOptions{ + MaxPartitions: int64(maxPartitions), + }, }, - }, - QueryOptions: spanner.QueryOptions{DataBoostEnabled: true}, - }) + QueryOptions: spanner.QueryOptions{DataBoostEnabled: true}, + }) + } else { + execOrFail := func(query string) { + if r, err := tx.QueryContext(ctx, query); err != nil { + t.Fatal(err) + } else { + _ = r.Close() + } + } + set := "set " + if test.withTx { + set = set + "local " + } + execOrFail(set + "auto_partition_mode = true") + execOrFail(set + "data_boost_enabled = true") + execOrFail(fmt.Sprintf(set+"max_partitioned_parallelism = %d", rand.Intn(10)+1)) + execOrFail(fmt.Sprintf(set+"max_partitions = %d", maxPartitions)) + rows, err = tx.QueryContext(ctx, testutil.SelectFooFromBar) + } if err != nil { t.Fatal(err) } + defer func() { _ = rows.Close() }() count := 0 for rows.Next() { @@ -297,12 +349,6 @@ func TestAutoPartitionQuery(t *testing.T) { t.Fatal(err) } - if tx, ok := tx.(*sql.Tx); ok { - if err := tx.Commit(); err != 
nil { - t.Fatal(err) - } - } - requests := server.TestSpanner.DrainRequestsFromServer() beginRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&sppb.BeginTransactionRequest{})) if g, w := len(beginRequests), 1; g != w { @@ -329,7 +375,7 @@ func TestAutoPartitionQuery(t *testing.T) { if g, w := len(commitRequests), 0; g != w { t.Fatalf("num commit requests mismatch\n Got: %v\nWant: %v", g, w) } - } + }) } } @@ -398,6 +444,23 @@ func TestAutoPartitionQuery_ExecuteError(t *testing.T) { } } +func TestRunPartitionedQuery(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db, _, teardown := setupTestDBConnection(t) + defer teardown() + db.SetMaxOpenConns(1) + + rows, err := db.QueryContext(ctx, "run partitioned query "+testutil.SelectFooFromBar) + if err != nil { + t.Fatal(err) + } + for rows.Next() { + + } +} + func setupRandomPartitionResults(server *testutil.MockedSpannerInMemTestServer, sql string, maxResultsPerPartition int) (maxPartitions int, allResults []int64, err error) { maxPartitions = rand.Intn(10) + 1 // Setup results for each partition. 
diff --git a/statements.go b/statements.go index 3eeb4683..3a69382b 100644 --- a/statements.go +++ b/statements.go @@ -51,6 +51,8 @@ func createExecutableStatement(stmt parser.ParsedStatement) (executableStatement return &executableRunBatchStatement{stmt: stmt}, nil case *parser.ParsedAbortBatchStatement: return &executableAbortBatchStatement{stmt: stmt}, nil + case *parser.ParsedRunPartitionedQueryStatement: + return &executableRunPartitionedQueryStatement{stmt: stmt}, nil case *parser.ParsedBeginStatement: return &executableBeginStatement{stmt: stmt}, nil case *parser.ParsedCommitStatement: @@ -274,6 +276,19 @@ func (s *executableAbortBatchStatement) queryContext(ctx context.Context, c *con return createEmptyRows(opts), nil } +type executableRunPartitionedQueryStatement struct { + stmt *parser.ParsedRunPartitionedQueryStatement +} + +func (s *executableRunPartitionedQueryStatement) execContext(ctx context.Context, c *conn, opts *ExecOptions) (driver.Result, error) { + return nil, status.Errorf(codes.FailedPrecondition, "cannot use RUN PARTITIONED QUERY with ExecContext") +} + +func (s *executableRunPartitionedQueryStatement) queryContext(ctx context.Context, c *conn, opts *ExecOptions) (driver.Rows, error) { + args := []driver.NamedValue{{Value: opts}} + return c.QueryContext(ctx, s.stmt.Statement, args) +} + type executableBeginStatement struct { stmt *parser.ParsedBeginStatement } diff --git a/transaction.go b/transaction.go index c256d20b..f6082c50 100644 --- a/transaction.go +++ b/transaction.go @@ -27,6 +27,7 @@ import ( "cloud.google.com/go/spanner" sppb "cloud.google.com/go/spanner/apiv1/spannerpb" "github.com/googleapis/gax-go/v2" + "github.com/googleapis/go-sql-spanner/connectionstate" "github.com/googleapis/go-sql-spanner/parser" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -236,6 +237,7 @@ type readOnlyTransaction struct { timestampBoundMu sync.Mutex timestampBoundSet bool timestampBoundCallback func() spanner.TimestampBound + 
state *connectionstate.ConnectionState } func (tx *readOnlyTransaction) deadline() (time.Time, bool) { @@ -273,7 +275,7 @@ func (tx *readOnlyTransaction) resetForRetry(ctx context.Context) error { func (tx *readOnlyTransaction) Query(ctx context.Context, stmt spanner.Statement, stmtType parser.StatementType, execOptions *ExecOptions) (rowIterator, error) { tx.logger.DebugContext(ctx, "Query", "stmt", stmt.SQL) - if execOptions.PartitionedQueryOptions.AutoPartitionQuery { + if execOptions.PartitionedQueryOptions.AutoPartitionQuery || propertyAutoPartitionMode.GetValueOrDefault(tx.state) { if tx.boTx == nil { return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "AutoPartitionQuery is only supported for batch read-only transactions")) } @@ -281,7 +283,11 @@ func (tx *readOnlyTransaction) Query(ctx context.Context, stmt spanner.Statement if err != nil { return nil, err } - mi := createMergedIterator(tx.logger, pq, execOptions.PartitionedQueryOptions.MaxParallelism) + maxParallelism := execOptions.PartitionedQueryOptions.MaxParallelism + if maxParallelism == 0 { + maxParallelism = propertyMaxPartitionedParallelism.GetValueOrDefault(tx.state) + } + mi := createMergedIterator(tx.logger, pq, maxParallelism) if err := mi.run(ctx); err != nil { mi.Stop() return nil, err @@ -311,7 +317,12 @@ func (tx *readOnlyTransaction) createPartitionedQuery(ctx context.Context, stmt if tx.boTx == nil { return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "partitionQuery is only supported for batch read-only transactions")) } - partitions, err := tx.boTx.PartitionQueryWithOptions(ctx, stmt, execOptions.PartitionedQueryOptions.PartitionOptions, execOptions.QueryOptions) + partitionOptions := execOptions.PartitionedQueryOptions.PartitionOptions + if partitionOptions.MaxPartitions == 0 && partitionOptions.PartitionBytes == 0 { + partitionOptions.MaxPartitions = propertyMaxPartitions.GetValueOrDefault(tx.state) + } + 
execOptions.QueryOptions.DataBoostEnabled = execOptions.QueryOptions.DataBoostEnabled || propertyDataBoostEnabled.GetValueOrDefault(tx.state) + partitions, err := tx.boTx.PartitionQueryWithOptions(ctx, stmt, partitionOptions, execOptions.QueryOptions) if err != nil { return nil, err }