diff --git a/bindings/go/nvml/bindings.go b/bindings/go/nvml/bindings.go index 6898918..a650c13 100644 --- a/bindings/go/nvml/bindings.go +++ b/bindings/go/nvml/bindings.go @@ -48,6 +48,8 @@ const ( XidCriticalError = C.nvmlEventTypeXidCriticalError ) +var nvmlEventSetWait = nvmlEventSetWait_v1 + type handle struct{ dev C.nvmlDevice_t } type EventSet struct{ set C.nvmlEventSet_t } type Event struct { @@ -58,6 +60,14 @@ type Event struct { Edata uint64 } +func nvmlEventSetWait_v1(Set C.nvmlEventSet_t, Data *C.nvmlEventData_t, Timeoutms C.uint) C.nvmlReturn_t { + return C.nvmlEventSetWait(Set, Data, Timeoutms) +} + +func nvmlEventSetWait_v2(Set C.nvmlEventSet_t, Data *C.nvmlEventData_t, Timeoutms C.uint) C.nvmlReturn_t { + return C.nvmlEventSetWait_v2(Set, Data, Timeoutms) +} + func uintPtr(c C.uint) *uint { i := uint(c) return &i @@ -86,6 +96,12 @@ func init_() error { if r == C.NVML_ERROR_LIBRARY_NOT_FOUND { return errors.New("could not load NVML library") } + + found := dl.lookupSymbol("nvmlEventSetWait_v2") + if found == C.NVML_SUCCESS { + nvmlEventSetWait = nvmlEventSetWait_v2 + } + return errorString(r) } @@ -157,15 +173,10 @@ func DeleteEventSet(es EventSet) { func WaitForEvent(es EventSet, timeout uint) (Event, error) { var data C.nvmlEventData_t + data.gpuInstanceId = 0xFFFFFFFF + data.computeInstanceId = 0xFFFFFFFF - r := dl.lookupSymbol("nvmlEventSetWait_v2") - if r == C.NVML_SUCCESS { - r = C.nvmlEventSetWait_v2(es.set, &data, C.uint(timeout)) - } else { - r = C.nvmlEventSetWait(es.set, &data, C.uint(timeout)) - data.gpuInstanceId = 0xFFFFFFFF - data.computeInstanceId = 0xFFFFFFFF - } + r := nvmlEventSetWait(es.set, &data, C.uint(timeout)) if r != C.NVML_SUCCESS { return Event{}, errorString(r) }