Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions contrib/database/sql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"context"
"database/sql/driver"
"math"
"strings"
"time"

"github.com/DataDog/dd-trace-go/v2/appsec/events"
Expand Down Expand Up @@ -285,6 +286,11 @@ func (tc *TracedConn) providedPeerService(ctx context.Context) string {
// with a span ID injected into SQL comments. The returned span ID should be used when the SQL span is created
// following the traced database call.
func (tc *TracedConn) injectComments(ctx context.Context, query string, mode tracer.DBMPropagationMode) (cquery string, spanID uint64) {
if tc.cfg.copyNotSupported && strings.EqualFold(query[:4], "COPY") {
// COPY is not supported for lib/pq, so we need to disable the comment injection
mode = tracer.DBMPropagationModeDisabled
}

// The sql span only gets created after the call to the database because we need to be able to skip spans
// when a driver returns driver.ErrSkip. In order to work with those constraints, a new span id is generated and
// used during SQL comment injection and returned for the sql span to be used later when/if the span
Expand Down
118 changes: 83 additions & 35 deletions contrib/database/sql/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type config struct {
dbmPropagationMode tracer.DBMPropagationMode
dbStats bool
statsdClient instrumentation.StatsdClient
copyNotSupported bool
}

// checkStatsdRequired adds a statsdclient onto the config if dbstats is enabled
Expand All @@ -48,49 +49,57 @@ func (c *config) checkStatsdRequired() {
}

func (c *config) checkDBMPropagation(driverName string, driver driver.Driver, dsn string) {
if c.dbmPropagationMode == tracer.DBMPropagationModeFull {
if dsn == "" {
dsn = c.dsn
}
if dbSystem, ok := dbmFullModeUnsupported(driverName, driver, dsn); ok {
instr.Logger().Warn("Using DBM_PROPAGATION_MODE in 'full' mode is not supported for %s, downgrading to 'service' mode. "+
"See https://docs.datadoghq.com/database_monitoring/connect_dbm_and_apm/ for more info.",
dbSystem,
)
c.dbmPropagationMode = tracer.DBMPropagationModeService
}
if c.dbmPropagationMode == tracer.DBMPropagationModeDisabled {
return
}
if c.dbmPropagationMode == tracer.DBMPropagationModeUndefined {
return
}
if dsn == "" {
dsn = c.dsn
}
// this case applies to full and service modes
if dbSystem, reason, ok := dbmPartiallySupported(driver, c); ok {
instr.Logger().Warn("Using DBM_PROPAGATION_MODE in '%s' mode is partially supported for %s: %s. "+
"See https://docs.datadoghq.com/database_monitoring/connect_dbm_and_apm/ for more info.",
c.dbmPropagationMode,
dbSystem,
reason,
)
}
if c.dbmPropagationMode != tracer.DBMPropagationModeFull {
return
}
// full mode is not supported for some drivers, so we need to check for that
if dbSystem, ok := dbmFullModeUnsupported(driverName, driver, dsn); ok {
instr.Logger().Warn("Using DBM_PROPAGATION_MODE in 'full' mode is not supported for %s, downgrading to 'service' mode. "+
"See https://docs.datadoghq.com/database_monitoring/connect_dbm_and_apm/ for more info.",
dbSystem,
)
c.dbmPropagationMode = tracer.DBMPropagationModeService
}
}

type unsupportedDriverModule struct {
prefix string
pkgName string
dbSystem string
reason string
updateConfig func(*config)
}

func dbmFullModeUnsupported(driverName string, driver driver.Driver, dsn string) (string, bool) {
const (
sqlServer = "SQL Server"
oracle = "Oracle"
)
// check if the driver package path is one of the unsupported ones.
if tp := reflect.TypeOf(driver); tp != nil && (tp.Kind() == reflect.Pointer || tp.Kind() == reflect.Struct) {
pkgPath := ""
switch tp.Kind() {
case reflect.Pointer:
pkgPath = tp.Elem().PkgPath()
case reflect.Struct:
pkgPath = tp.PkgPath()
}
driverPkgs := [][3]string{
{"github.com", "denisenkom/go-mssqldb", sqlServer},
{"github.com", "microsoft/go-mssqldb", sqlServer},
{"github.com", "sijms/go-ora", oracle},
}
for _, dp := range driverPkgs {
prefix, pkgName, dbSystem := dp[0], dp[1], dp[2]

// compare without the prefix to make it work for vendoring.
// also, compare only the prefix to make the comparison work when using major versions
// of the libraries or subpackages.
if strings.HasPrefix(strings.TrimPrefix(pkgPath, prefix+"/"), pkgName) {
return dbSystem, true
}
}
driverPkgs := []unsupportedDriverModule{
{"github.com", "denisenkom/go-mssqldb", sqlServer, "", nil},
{"github.com", "microsoft/go-mssqldb", sqlServer, "", nil},
{"github.com", "sijms/go-ora", oracle, "", nil},
}
if ix := unsupportedDriver(driver, driverPkgs); ix != -1 {
return driverPkgs[ix].dbSystem, true
}

// check the DSN if provided.
Expand Down Expand Up @@ -123,6 +132,45 @@ func dbmFullModeUnsupported(driverName string, driver driver.Driver, dsn string)
return "", false
}

func dbmPartiallySupported(driver driver.Driver, c *config) (string, string, bool) {
driverPkgs := []unsupportedDriverModule{
{"github.com", "lib/pq", "PostgreSQL", "COPY doesn't support comments", func(cfg *config) {
cfg.copyNotSupported = true
}},
}
if ix := unsupportedDriver(driver, driverPkgs); ix != -1 {
if driverPkgs[ix].updateConfig != nil {
driverPkgs[ix].updateConfig(c)
}
return driverPkgs[ix].dbSystem, driverPkgs[ix].reason, true
}
return "", "", false
}

func unsupportedDriver(driver driver.Driver, driverPkgs []unsupportedDriverModule) int {
// check if the driver package path is one of the unsupported ones.
if tp := reflect.TypeOf(driver); tp != nil && (tp.Kind() == reflect.Pointer || tp.Kind() == reflect.Struct) {
pkgPath := ""
switch tp.Kind() {
case reflect.Pointer:
pkgPath = tp.Elem().PkgPath()
case reflect.Struct:
pkgPath = tp.PkgPath()
}
for ix, dp := range driverPkgs {
prefix, pkgName := dp.prefix, dp.pkgName

// compare without the prefix to make it work for vendoring.
// also, compare only the prefix to make the comparison work when using major versions
// of the libraries or subpackages.
if strings.HasPrefix(strings.TrimPrefix(pkgPath, prefix+"/"), pkgName) {
return ix
}
}
}
return -1
}

// Option describes options for the database/sql integration.
type Option interface {
apply(*config)
Expand Down
44 changes: 44 additions & 0 deletions contrib/database/sql/propagation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@ import (
"database/sql/driver"
"io"
"net/http"
"os"
"regexp"
"testing"
"time"

mssql "github.com/denisenkom/go-mssqldb"
"github.com/go-sql-driver/mysql"
"github.com/lib/pq"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -256,6 +259,47 @@ func TestDBMPropagation(t *testing.T) {
}
}

func TestDBMPropagationFullOnPqCopy(t *testing.T) {
if _, ok := os.LookupEnv("INTEGRATION"); !ok {
t.Skip("skipping integration test")
}
tr := mocktracer.Start()
defer tr.Stop()

Register("postgres", &pq.Driver{}, WithDBMPropagation(tracer.DBMPropagationModeFull))
db, err := Open("postgres", "postgres://postgres:[email protected]:5432/postgres?sslmode=disable")
require.NoError(t, err)

t.Cleanup(func() {
// Using a new 10s-timeout context, as we may be running cleanup after the original context expired.
_, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
assert.NoError(t, db.Close())
})

db.Exec("DROP TABLE IF EXISTS testsql")
db.Exec("CREATE TABLE testsql (dn text, name text, sam_account_name text, mail text, primary_group_id text)")
t.Cleanup(func() {
db.Exec("DROP TABLE IF EXISTS testsql")
})

tx, err := db.Begin()
require.NoError(t, err)
defer tx.Rollback()

s := pq.CopyInSchema("public", "testsql", "dn", "name", "sam_account_name", "mail", "primary_group_id")
stmt, err := tx.Prepare(s)
require.NoError(t, err)
defer stmt.Close()

_, err = stmt.Exec("dn", "name0", "sam", nil, nil)
require.NoError(t, err)

spans := tr.FinishedSpans()
require.Len(t, spans, 6)
assert.Equal(t, `COPY "public"."testsql" ("dn", "name", "sam_account_name", "mail", "primary_group_id") FROM STDIN`, spans[5].Tags()[ext.ResourceName])
}

func TestDBMTraceContextTagging(t *testing.T) {
testCases := []struct {
name string
Expand Down
Loading