diff --git a/conn.go b/conn.go index b131e4c..c8a81cb 100644 --- a/conn.go +++ b/conn.go @@ -4,6 +4,7 @@ import ( "context" "database/sql/driver" "errors" + "regexp" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -48,10 +49,9 @@ func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error) } return newRows(ctx, rowsConfig{ - Athena: c.athena, - QueryID: queryID, - // todo add check for ddl queries to not skip header(#10) - SkipHeader: true, + Athena: c.athena, + QueryID: queryID, + SkipHeader: isDDLQuery(query), }) } @@ -120,5 +120,13 @@ func (c *conn) Close() error { return nil } +// supported DDL statements by Athena +// https://docs.aws.amazon.com/athena/latest/ug/language-reference.html +var ddlQueryRegex = regexp.MustCompile(`(?i)^(ALTER|CREATE|DESCRIBE|DROP|MSCK|SHOW)`) + +func isDDLQuery(query string) bool { + return ddlQueryRegex.Match([]byte(query)) +} + var _ driver.QueryerContext = (*conn)(nil) var _ driver.ExecerContext = (*conn)(nil) diff --git a/db_test.go b/db_test.go index a784853..c236ade 100644 --- a/db_test.go +++ b/db_test.go @@ -129,6 +129,25 @@ func TestOpen(t *testing.T) { require.NoError(t, err, "Query") } +func TestDDLQuery(t *testing.T) { + harness := setup(t) + defer harness.teardown() + + rows := harness.mustQuery("show tables") + + output := make([]string, 0) + for rows.Next() { + var table string + + err := rows.Scan(&table) + assert.NoError(t, err, "rows.Scan()") + + output = append(output, table) + } + + assert.Equal(t, 1, len(output), "query output") +} + type dummyRow struct { NullValue *struct{} `json:"nullValue"` SmallintType int `json:"smallintType"`