Skip to content

Commit

Permalink
refactor: using only one GPU during testing if there are more than on…
Browse files Browse the repository at this point in the history
…e available;
  • Loading branch information
WenjieDu committed Dec 6, 2023
1 parent 2542a0c commit 6dad105
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions tests/global_test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import os

import numpy as np
import torch

from pypots.data.generating import gene_random_walk
Expand All @@ -33,10 +34,11 @@


# set DEVICES to None if no cuda device is available, to avoid initialization failed while importing test classes
cuda_devices = [torch.device(i) for i in range(torch.cuda.device_count())]
if len(cuda_devices) > 2:
n_cuda_devices = torch.cuda.device_count()
cuda_devices = [torch.device(i) for i in range(n_cuda_devices)]
if n_cuda_devices > 1:
logger.info("❗️Detected multiple cuda devices, using all of them to run testing.")
DEVICE = cuda_devices
DEVICE = cuda_devices[np.random.randint(n_cuda_devices)]
else:
# if having no multiple cuda devices, leave it as None to use the default device
DEVICE = None
Expand Down

0 comments on commit 6dad105

Please sign in to comment.