diff --git a/contrib/database/sql/conn.go b/contrib/database/sql/conn.go index f714327431..016d753a38 100644 --- a/contrib/database/sql/conn.go +++ b/contrib/database/sql/conn.go @@ -9,6 +9,7 @@ import ( "context" "database/sql/driver" "math" + "strings" "time" "github.com/DataDog/dd-trace-go/v2/appsec/events" @@ -285,6 +286,11 @@ func (tc *TracedConn) providedPeerService(ctx context.Context) string { // with a span ID injected into SQL comments. The returned span ID should be used when the SQL span is created // following the traced database call. func (tc *TracedConn) injectComments(ctx context.Context, query string, mode tracer.DBMPropagationMode) (cquery string, spanID uint64) { + if tc.cfg.copyNotSupported && strings.EqualFold(query[:4], "COPY") { + // COPY is not supported for lib/pq, so we need to disable the comment injection + mode = tracer.DBMPropagationModeDisabled + } + // The sql span only gets created after the call to the database because we need to be able to skip spans // when a driver returns driver.ErrSkip. In order to work with those constraints, a new span id is generated and // used during SQL comment injection and returned for the sql span to be used later when/if the span diff --git a/contrib/database/sql/option.go b/contrib/database/sql/option.go index 84670fdbea..e535be49b4 100644 --- a/contrib/database/sql/option.go +++ b/contrib/database/sql/option.go @@ -29,6 +29,7 @@ type config struct { dbmPropagationMode tracer.DBMPropagationMode dbStats bool statsdClient instrumentation.StatsdClient + copyNotSupported bool } // checkStatsdRequired adds a statsdclient onto the config if dbstats is enabled @@ -48,49 +49,57 @@ func (c *config) checkStatsdRequired() { } func (c *config) checkDBMPropagation(driverName string, driver driver.Driver, dsn string) { - if c.dbmPropagationMode == tracer.DBMPropagationModeFull { - if dsn == "" { - dsn = c.dsn - } - if dbSystem, ok := dbmFullModeUnsupported(driverName, driver, dsn); ok { - instr.Logger().Warn("Using DBM_PROPAGATION_MODE in 'full' mode is not supported for %s, downgrading to 'service' mode. "+ - "See https://docs.datadoghq.com/database_monitoring/connect_dbm_and_apm/ for more info.", - dbSystem, - ) - c.dbmPropagationMode = tracer.DBMPropagationModeService - } + if c.dbmPropagationMode == tracer.DBMPropagationModeDisabled { + return + } + if c.dbmPropagationMode == tracer.DBMPropagationModeUndefined { + return + } + if dsn == "" { + dsn = c.dsn + } + // this case applies to full and service modes + if dbSystem, reason, ok := dbmPartiallySupported(driver, c); ok { + instr.Logger().Warn("Using DBM_PROPAGATION_MODE in '%s' mode is partially supported for %s: %s. "+ + "See https://docs.datadoghq.com/database_monitoring/connect_dbm_and_apm/ for more info.", + c.dbmPropagationMode, + dbSystem, + reason, + ) + } + if c.dbmPropagationMode != tracer.DBMPropagationModeFull { + return + } + // full mode is not supported for some drivers, so we need to check for that + if dbSystem, ok := dbmFullModeUnsupported(driverName, driver, dsn); ok { + instr.Logger().Warn("Using DBM_PROPAGATION_MODE in 'full' mode is not supported for %s, downgrading to 'service' mode. "+ + "See https://docs.datadoghq.com/database_monitoring/connect_dbm_and_apm/ for more info.", + dbSystem, + ) + c.dbmPropagationMode = tracer.DBMPropagationModeService } } +type unsupportedDriverModule struct { + prefix string + pkgName string + dbSystem string + reason string + updateConfig func(*config) +} + func dbmFullModeUnsupported(driverName string, driver driver.Driver, dsn string) (string, bool) { const ( sqlServer = "SQL Server" oracle = "Oracle" ) - // check if the driver package path is one of the unsupported ones. - if tp := reflect.TypeOf(driver); tp != nil && (tp.Kind() == reflect.Pointer || tp.Kind() == reflect.Struct) { - pkgPath := "" - switch tp.Kind() { - case reflect.Pointer: - pkgPath = tp.Elem().PkgPath() - case reflect.Struct: - pkgPath = tp.PkgPath() - } - driverPkgs := [][3]string{ - {"github.com", "denisenkom/go-mssqldb", sqlServer}, - {"github.com", "microsoft/go-mssqldb", sqlServer}, - {"github.com", "sijms/go-ora", oracle}, - } - for _, dp := range driverPkgs { - prefix, pkgName, dbSystem := dp[0], dp[1], dp[2] - - // compare without the prefix to make it work for vendoring. - // also, compare only the prefix to make the comparison work when using major versions - // of the libraries or subpackages. - if strings.HasPrefix(strings.TrimPrefix(pkgPath, prefix+"/"), pkgName) { - return dbSystem, true - } - } + driverPkgs := []unsupportedDriverModule{ + {"github.com", "denisenkom/go-mssqldb", sqlServer, "", nil}, + {"github.com", "microsoft/go-mssqldb", sqlServer, "", nil}, + {"github.com", "sijms/go-ora", oracle, "", nil}, + } + if ix := unsupportedDriver(driver, driverPkgs); ix != -1 { + return driverPkgs[ix].dbSystem, true } // check the DSN if provided. @@ -123,6 +132,45 @@ func dbmFullModeUnsupported(driverName string, driver driver.Driver, dsn string) return "", false } +func dbmPartiallySupported(driver driver.Driver, c *config) (string, string, bool) { + driverPkgs := []unsupportedDriverModule{ + {"github.com", "lib/pq", "PostgreSQL", "COPY doesn't support comments", func(cfg *config) { + cfg.copyNotSupported = true + }}, + } + if ix := unsupportedDriver(driver, driverPkgs); ix != -1 { + if driverPkgs[ix].updateConfig != nil { + driverPkgs[ix].updateConfig(c) + } + return driverPkgs[ix].dbSystem, driverPkgs[ix].reason, true + } + return "", "", false +} + +func unsupportedDriver(driver driver.Driver, driverPkgs []unsupportedDriverModule) int { + // check if the driver package path is one of the unsupported ones. + if tp := reflect.TypeOf(driver); tp != nil && (tp.Kind() == reflect.Pointer || tp.Kind() == reflect.Struct) { + pkgPath := "" + switch tp.Kind() { + case reflect.Pointer: + pkgPath = tp.Elem().PkgPath() + case reflect.Struct: + pkgPath = tp.PkgPath() + } + for ix, dp := range driverPkgs { + prefix, pkgName := dp.prefix, dp.pkgName + + // compare without the prefix to make it work for vendoring. + // also, compare only the prefix to make the comparison work when using major versions + // of the libraries or subpackages. + if strings.HasPrefix(strings.TrimPrefix(pkgPath, prefix+"/"), pkgName) { + return ix + } + } + } + return -1 +} + // Option describes options for the database/sql integration. type Option interface { apply(*config) diff --git a/contrib/database/sql/propagation_test.go b/contrib/database/sql/propagation_test.go index a1b5a8598f..1191a46516 100644 --- a/contrib/database/sql/propagation_test.go +++ b/contrib/database/sql/propagation_test.go @@ -12,11 +12,14 @@ import ( "database/sql/driver" "io" "net/http" + "os" "regexp" "testing" + "time" mssql "github.com/denisenkom/go-mssqldb" "github.com/go-sql-driver/mysql" + "github.com/lib/pq" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -256,6 +259,47 @@ func TestDBMPropagation(t *testing.T) { } } +func TestDBMPropagationFullOnPqCopy(t *testing.T) { + if _, ok := os.LookupEnv("INTEGRATION"); !ok { + t.Skip("skipping integration test") + } + tr := mocktracer.Start() + defer tr.Stop() + + Register("postgres", &pq.Driver{}, WithDBMPropagation(tracer.DBMPropagationModeFull)) + db, err := Open("postgres", "postgres://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable") + require.NoError(t, err) + + t.Cleanup(func() { + // Using a new 10s-timeout context, as we may be running cleanup after the original context expired. + _, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + assert.NoError(t, db.Close()) + }) + + db.Exec("DROP TABLE IF EXISTS testsql") + db.Exec("CREATE TABLE testsql (dn text, name text, sam_account_name text, mail text, primary_group_id text)") + t.Cleanup(func() { + db.Exec("DROP TABLE IF EXISTS testsql") + }) + + tx, err := db.Begin() + require.NoError(t, err) + defer tx.Rollback() + + s := pq.CopyInSchema("public", "testsql", "dn", "name", "sam_account_name", "mail", "primary_group_id") + stmt, err := tx.Prepare(s) + require.NoError(t, err) + defer stmt.Close() + + _, err = stmt.Exec("dn", "name0", "sam", nil, nil) + require.NoError(t, err) + + spans := tr.FinishedSpans() + require.Len(t, spans, 6) + assert.Equal(t, `COPY "public"."testsql" ("dn", "name", "sam_account_name", "mail", "primary_group_id") FROM STDIN`, spans[5].Tags()[ext.ResourceName]) +} + func TestDBMTraceContextTagging(t *testing.T) { testCases := []struct { name string