diff --git a/bindings/go/dcgm/api.go b/bindings/go/dcgm/api.go index a25d911..05a446d 100644 --- a/bindings/go/dcgm/api.go +++ b/bindings/go/dcgm/api.go @@ -101,3 +101,8 @@ func Policy(gpuId uint, typ ...policyCondition) (<-chan PolicyViolation, error) func Introspect() (DcgmStatus, error) { return introspect() } + +// Get all of the profiling metric groups for a given GPU group. +func GetSupportedMetricGroups(grpid uint) ([]MetricGroup, error) { + return getSupportedMetricGroups(grpid) +} diff --git a/bindings/go/dcgm/profile.go b/bindings/go/dcgm/profile.go new file mode 100644 index 0000000..25ca752 --- /dev/null +++ b/bindings/go/dcgm/profile.go @@ -0,0 +1,47 @@ +package dcgm + +/* +#include "dcgm_agent.h" +#include "dcgm_structs.h" +*/ +import "C" +import ( + "fmt" + "unsafe" +) + +type MetricGroup struct { + major uint + minor uint + fieldIds []uint +} + +func getSupportedMetricGroups(grpid uint) (groups []MetricGroup, err error) { + + var groupInfo C.dcgmProfGetMetricGroups_t + groupInfo.version = makeVersion2(unsafe.Sizeof(groupInfo)) + groupInfo.groupId = C.ulong(grpid) + + result := C.dcgmProfGetSupportedMetricGroups(handle.handle, &groupInfo) + + if err = errorString(result); err != nil { + return groups, fmt.Errorf("Error getting supported metrics: %s", err) + } + + var count = uint(groupInfo.numMetricGroups) + + for i := uint(0); i < count; i++ { + var group MetricGroup + group.major = uint(groupInfo.metricGroups[i].majorId) + group.minor = uint(groupInfo.metricGroups[i].minorId) + + var fieldCount = uint(groupInfo.metricGroups[i].numFieldIds) + + for j := uint(0); j < fieldCount; j++ { + group.fieldIds = append(group.fieldIds, uint(groupInfo.metricGroups[i].fieldIds[j])) + } + groups = append(groups, group) + } + + return groups, nil +} diff --git a/pkg/main.go b/pkg/main.go index 525e744..a796406 100644 --- a/pkg/main.go +++ b/pkg/main.go @@ -103,6 +103,14 @@ restart: } logrus.Info("DCGM successfully initialized!") + _, err = dcgm.GetSupportedMetricGroups(0) + if err != nil { + config.CollectDCP = false + logrus.Info("Not collecting DCP metrics: ", err) + } else { + logrus.Info("Collecting DCP Metrics") + } + ch := make(chan string, 10) pipeline, cleanup, err := NewMetricsPipeline(config) defer cleanup() @@ -153,5 +161,6 @@ func contextToConfig(c *cli.Context) *Config { CollectInterval: c.Int(CLICollectInterval), Kubernetes: c.Bool(CLIKubernetes), KubernetesGPUIdType: KubernetesGPUIDType(c.String(CLIKubernetesGPUIDType)), + CollectDCP: true, } } diff --git a/pkg/parser.go b/pkg/parser.go index 6f251d9..d0f2f31 100644 --- a/pkg/parser.go +++ b/pkg/parser.go @@ -26,14 +26,14 @@ import ( "github.com/sirupsen/logrus" ) -func ExtractCounters(filename string) ([]Counter, error) { +func ExtractCounters(filename string, dcpAllowed bool) ([]Counter, error) { records, err := ReadCSVFile(filename) if err != nil { fmt.Printf("Error: %v\n", err) return nil, err } - counters, err := extractCounters(records) + counters, err := extractCounters(records, dcpAllowed) if err != nil { return nil, err } @@ -55,7 +55,7 @@ func ReadCSVFile(filename string) ([][]string, error) { return records, err } -func extractCounters(records [][]string) ([]Counter, error) { +func extractCounters(records [][]string, dcpAllowed bool) ([]Counter, error) { f := make([]Counter, 0, len(records)) for i, record := range records { @@ -81,6 +81,11 @@ func extractCounters(records [][]string) ([]Counter, error) { return nil, fmt.Errorf("Could not find DCGM field %s", record[0]) } + if !dcpAllowed && fieldID >= 1000 { + logrus.Warnf("Skipping line %d ('%s'): DCP metrics not enabled", i, record[0]) + continue + } + if _, ok := promMetricType[record[1]]; !ok { return nil, fmt.Errorf("Could not find Prometheus metry type %s", record[1]) } diff --git a/pkg/pipeline.go b/pkg/pipeline.go index 415d00f..a9a6788 100644 --- a/pkg/pipeline.go +++ b/pkg/pipeline.go @@ -27,7 +27,7 @@ import ( ) func NewMetricsPipeline(c *Config) (*MetricsPipeline, func(), error) { - counters, err := ExtractCounters(c.CollectorsFile) + counters, err := ExtractCounters(c.CollectorsFile, c.CollectDCP) if err != nil { return nil, func() {}, err } diff --git a/pkg/types.go b/pkg/types.go index b8a11ec..c849f0b 100644 --- a/pkg/types.go +++ b/pkg/types.go @@ -50,6 +50,7 @@ type Config struct { CollectInterval int Kubernetes bool KubernetesGPUIdType KubernetesGPUIDType + CollectDCP bool } type Transform interface {