diff --git a/conn.go b/conn.go index 60f92ec..1ee5d23 100644 --- a/conn.go +++ b/conn.go @@ -4,6 +4,7 @@ import ( "context" "database/sql/driver" "errors" + "regexp" "time" "github.com/aws/aws-sdk-go/aws" @@ -47,9 +48,12 @@ func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error) return nil, err } + skipHeaders := !isDDL(query) + return newRows(rowsConfig{ - Athena: c.athena, - QueryID: queryID, + Athena: c.athena, + QueryID: queryID, + SkipHeaders: skipHeaders, }) } @@ -134,3 +138,9 @@ func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) { var _ driver.Queryer = (*conn)(nil) var _ driver.Execer = (*conn)(nil) + +var ddlQueryRegex = regexp.MustCompile(`^(ALTER|CREATE|DESCRIBE|DROP|MSCK|SHOW)`) + +func isDDL(query string) bool { + return ddlQueryRegex.Match([]byte(query)) +} diff --git a/rows.go b/rows.go index 6f77bc4..aec82c7 100644 --- a/rows.go +++ b/rows.go @@ -10,22 +10,24 @@ import ( ) type rows struct { - athena athenaiface.AthenaAPI - queryID string - - done bool - out *athena.GetQueryResultsOutput + athena athenaiface.AthenaAPI + queryID string + skipHeaders bool + done bool + out *athena.GetQueryResultsOutput } type rowsConfig struct { - Athena athenaiface.AthenaAPI - QueryID string + Athena athenaiface.AthenaAPI + QueryID string + SkipHeaders bool } func newRows(cfg rowsConfig) (*rows, error) { r := rows{ - athena: cfg.Athena, - queryID: cfg.QueryID, + athena: cfg.Athena, + queryID: cfg.QueryID, + skipHeaders: cfg.SkipHeaders, } shouldContinue, err := r.fetchNextPage(nil) @@ -97,13 +99,18 @@ func (r *rows) fetchNextPage(token *string) (bool, error) { return false, err } - // First row of an Athena response contains headers. + // First row of an Athena response (except of DDL queries) contains headers. // These are also available in *athena.Row.ResultSetMetadata. - if len(r.out.ResultSet.Rows) < 2 { + minRows := 2 + if !r.skipHeaders { + minRows = 1 + } + + if len(r.out.ResultSet.Rows) < minRows { return false, nil } - r.out.ResultSet.Rows = r.out.ResultSet.Rows[1:] + r.out.ResultSet.Rows = r.out.ResultSet.Rows[(minRows - 1):] return true, nil }