diff --git a/byte_micro_perf/core/backend.py b/byte_micro_perf/core/backend.py index 2bd6c7fe..4e186ca5 100644 --- a/byte_micro_perf/core/backend.py +++ b/byte_micro_perf/core/backend.py @@ -147,7 +147,7 @@ def perf(self, op_instance, profiling=True): tensor_size = op_instance.tensor_size # device - device_mem_info = self.get_mem_info() + device_mem_info = self.get_mem_info(self.true_device_index) avail_memory = device_mem_info[0] # assume diff --git a/byte_micro_perf/core/scheduler.py b/byte_micro_perf/core/scheduler.py index a38484de..a471e56b 100644 --- a/byte_micro_perf/core/scheduler.py +++ b/byte_micro_perf/core/scheduler.py @@ -237,6 +237,7 @@ def subprocess_func(self, instance_rank : int, *args): true_device_index = backend.target_devices[true_rank] print(f"true_world_size: {true_world_size}, true_rank: {true_rank}, true_device_index: {true_device_index}") backend.set_device(true_device_index) + backend.true_device_index = true_device_index # device process is ready output_queues.put("ready") @@ -291,6 +292,7 @@ def subprocess_func(self, instance_rank : int, *args): true_device_index = backend.all_node_devices[true_rank] print(f"true_world_size: {true_world_size}, true_rank: {true_rank}, true_device_index: {true_device_index}") backend.set_device(true_device_index) + backend.true_device_index = true_device_index # init dist env dist_module = backend.get_dist_module()