diff --git a/clusterscope/job_info.py b/clusterscope/job_info.py index e330c7e..478fb93 100644 --- a/clusterscope/job_info.py +++ b/clusterscope/job_info.py @@ -9,6 +9,8 @@ from functools import lru_cache +from clusterscope.cluster_info import LocalNodeInfo + MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000) @@ -40,6 +42,23 @@ def __init__(self): self.world_size = self.get_world_size() self.is_rank_zero = self.get_is_rank_zero() + @lru_cache(maxsize=1) + def get_cpus(self) -> int: + if self.is_slurm_job(): + return int(os.environ.get("SLURM_CPUS_ON_NODE", 1)) + return int(max(os.cpu_count() or 0, 1)) + + @lru_cache(maxsize=1) + def get_gpus(self) -> int: + if self.is_slurm_job(): + return int(os.environ.get("SLURM_GPUS_ON_NODE", 1)) + return sum( + [ + int(count) + for gpu, count in LocalNodeInfo().get_gpu_generation_and_count().items() + ] + ) + @lru_cache(maxsize=1) def get_job_id(self) -> int: if self.is_slurm_job():