Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions basicsr/data/degradations.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,11 @@ def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)]
vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1)
out = torch.poisson(img_gray * vals) / vals
# Gray-noise branch: sample Poisson noise from the grayscale image.
# For XPU tensors the sampling is done on the CPU and the result is moved
# back to the original device (presumably torch.poisson is not available
# on XPU -- confirm against the targeted PyTorch release).
device = img_gray.device
if device.type == 'xpu':
    out = torch.poisson((img_gray * vals).to('cpu')).to(device) / vals
else:
    # Must use img_gray here: `vals` was derived from img_gray and the
    # result is compared against img_gray right after this statement.
    out = torch.poisson(img_gray * vals) / vals
noise_gray = out - img_gray
noise_gray = noise_gray.expand(b, 3, h, w)

Expand All @@ -645,7 +649,11 @@ def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)]
vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
vals = img.new_tensor(vals_list).view(b, 1, 1, 1)
out = torch.poisson(img * vals) / vals
device = img.device
if device.type == 'xpu':
out = torch.poisson((img * vals).to('cpu')).to(device) / vals
else:
out = torch.poisson(img * vals) / vals
noise = out - img
if cal_gray_noise:
noise = noise * (1 - gray_noise) + noise_gray * gray_noise
Expand Down
40 changes: 40 additions & 0 deletions basicsr/data/prefetch_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,43 @@ def next(self):
def reset(self):
self.loader = iter(self.ori_loader)
self.preload()

class XPUPrefetcher():
    """XPU prefetcher.

    Keeps one batch staged ahead of the consumer: tensors of the next batch
    are copied to the target device on a dedicated XPU stream, and
    :meth:`next` hands out the staged batch before staging the following one.

    It may consume more GPU memory.

    Args:
        loader: Dataloader.
        opt (dict): Options. ``opt['num_gpu']`` selects the target device
            ('xpu' when it is non-zero, 'cpu' otherwise).
    """

    def __init__(self, loader, opt):
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt
        self.stream = torch.xpu.Stream()
        self.device = torch.device('xpu' if opt['num_gpu'] != 0 else 'cpu')
        self.preload()

    def preload(self):
        """Fetch the next batch and start copying its tensors to the device.

        Sets ``self.batch`` to ``None`` when the loader is exhausted.
        """
        try:
            self.batch = next(self.loader)  # self.batch is a dict
        except StopIteration:
            self.batch = None
            return None
        # Put tensors on the device using the side stream. The context must
        # be the XPU one: self.stream is a torch.xpu.Stream and next() waits
        # on it through torch.xpu.current_stream().
        with torch.xpu.stream(self.stream):
            for k, v in self.batch.items():
                if torch.is_tensor(v):
                    self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)

    def next(self):
        """Return the staged batch (or ``None``) and stage the following one."""
        # Make the current stream wait for the copies issued in preload().
        torch.xpu.current_stream().wait_stream(self.stream)
        batch = self.batch
        self.preload()
        return batch

    def reset(self):
        """Restart iteration over the original loader and stage its first batch."""
        self.loader = iter(self.ori_loader)
        self.preload()
23 changes: 20 additions & 3 deletions basicsr/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,23 @@ class BaseModel():

def __init__(self, opt):
self.opt = opt
self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
self.device = 'cpu'
self.dtype = opt['dtype']
if opt['num_gpu'] != 0:
if torch.cuda.is_available():
self.device = 'cuda'
if torch.xpu.is_available():
self.device = 'xpu'
self.is_train = opt['is_train']
self.schedulers = []
self.optimizers = []

def empty_cache(self):
    """Release cached allocator memory on the accelerator used by this model.

    ``self.device`` may be a plain string (optionally with an index, e.g.
    'cuda:1') or a ``torch.device``; only its device type is considered.
    Nothing is done for device types other than 'cuda' and 'xpu'.
    """
    device = self.device
    # Normalise first so that 'cuda:0' or torch.device('cuda') are not
    # silently skipped by a plain string comparison.
    if isinstance(device, torch.device):
        device_type = device.type
    else:
        device_type = torch.device(device).type

    if device_type == 'cuda':
        torch.cuda.empty_cache()
    elif device_type == 'xpu':
        torch.xpu.empty_cache()

def feed_data(self, data):
pass

Expand Down Expand Up @@ -91,11 +103,16 @@ def model_to_device(self, net):
Args:
net (nn.Module)
"""
net = net.to(self.device)
net = net.to(self.device, dtype = self.dtype)
if self.opt['dist']:
find_unused_parameters = self.opt.get('find_unused_parameters', False)
ids = [0]
if self.device == 'cuda':
ids = [torch.cuda.current_device()]
if self.device == 'xpu':
ids = [torch.xpu.current_device()]
net = DistributedDataParallel(
net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
net, device_ids=ids, find_unused_parameters=find_unused_parameters)
elif self.opt['num_gpu'] > 1:
net = DataParallel(net)
return net
Expand Down
2 changes: 1 addition & 1 deletion basicsr/models/hifacegan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
# tentative for out of GPU memory
del self.lq
del self.output
torch.cuda.empty_cache()
self.empty_cache()

if save_img:
if self.opt['is_train']:
Expand Down
11 changes: 7 additions & 4 deletions basicsr/models/realesrgan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ class RealESRGANModel(SRGANModel):

def __init__(self, opt):
super(RealESRGANModel, self).__init__(opt)
self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
self.usm_sharpener = USMSharp().cuda() # do usm sharpening
self.jpeger = DiffJPEG(differentiable=False).to(self.device) # simulate JPEG compression artifacts
self.usm_sharpener = USMSharp().to(self.device) # do usm sharpening
self.queue_size = opt.get('queue_size', 180)

@torch.no_grad()
Expand All @@ -40,9 +40,9 @@ def _dequeue_and_enqueue(self):
b, c, h, w = self.lq.size()
if not hasattr(self, 'queue_lr'):
assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
self.queue_lr = torch.zeros(self.queue_size, c, h, w).to(self.device)
_, c, h, w = self.gt.size()
self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
self.queue_gt = torch.zeros(self.queue_size, c, h, w).to(self.device)
self.queue_ptr = 0
if self.queue_ptr == self.queue_size: # the pool is full
# do dequeue and enqueue
Expand Down Expand Up @@ -183,6 +183,9 @@ def feed_data(self, data):
if 'gt' in data:
self.gt = data['gt'].to(self.device)
self.gt_usm = self.usm_sharpener(self.gt)
self.lq = self.lq.to(self.dtype)
self.gt = self.gt.to(self.dtype)
self.gt_usm = self.gt_usm.to(self.dtype)

def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
# do not use the synthetic process during validation
Expand Down
10 changes: 6 additions & 4 deletions basicsr/models/realesrnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ class RealESRNetModel(SRModel):

def __init__(self, opt):
super(RealESRNetModel, self).__init__(opt)
self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
self.usm_sharpener = USMSharp().cuda() # do usm sharpening
self.jpeger = DiffJPEG(differentiable=False).to(self.device) # simulate JPEG compression artifacts
self.usm_sharpener = USMSharp().to(self.device) # do usm sharpening
self.queue_size = opt.get('queue_size', 180)

@torch.no_grad()
Expand All @@ -39,9 +39,9 @@ def _dequeue_and_enqueue(self):
b, c, h, w = self.lq.size()
if not hasattr(self, 'queue_lr'):
assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
self.queue_lr = torch.zeros(self.queue_size, c, h, w).to(self.device)
_, c, h, w = self.gt.size()
self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
self.queue_gt = torch.zeros(self.queue_size, c, h, w).to(self.device)
self.queue_ptr = 0
if self.queue_ptr == self.queue_size: # the pool is full
# do dequeue and enqueue
Expand Down Expand Up @@ -181,6 +181,8 @@ def feed_data(self, data):
if 'gt' in data:
self.gt = data['gt'].to(self.device)
self.gt_usm = self.usm_sharpener(self.gt)
self.lq = self.lq.to(self.dtype)
self.gt = self.gt.to(self.dtype)

def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
# do not use the synthetic process during validation
Expand Down
21 changes: 10 additions & 11 deletions basicsr/models/sr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def init_training_settings(self):
# define network net_g with Exponential Moving Average (EMA)
# net_g_ema is used only for testing on one GPU and saving
# There is no need to wrap with DistributedDataParallel
self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
self.net_g_ema = build_network(self.opt['network_g']).to(self.device, dtype=self.dtype)
# load pretrained model
load_path = self.opt['path'].get('pretrain_network_g', None)
if load_path is not None:
Expand All @@ -54,12 +54,12 @@ def init_training_settings(self):

# define losses
if train_opt.get('pixel_opt'):
self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device, dtype=self.dtype)
else:
self.cri_pix = None

if train_opt.get('perceptual_opt'):
self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device, dtype=self.dtype)
else:
self.cri_perceptual = None

Expand All @@ -85,9 +85,9 @@ def setup_optimizers(self):
self.optimizers.append(self.optimizer_g)

def feed_data(self, data):
self.lq = data['lq'].to(self.device)
self.lq = data['lq'].to(self.device, dtype=self.dtype)
if 'gt' in data:
self.gt = data['gt'].to(self.device)
self.gt = data['gt'].to(self.device, dtype=self.dtype)

def optimize_parameters(self, current_iter):
self.optimizer_g.zero_grad()
Expand Down Expand Up @@ -144,7 +144,7 @@ def _transform(v, op):
elif op == 't':
tfnp = v2np.transpose((0, 1, 3, 2)).copy()

ret = torch.Tensor(tfnp).to(self.device)
ret = torch.Tensor(tfnp).to(self.device, dtype=self.dtype)
# if self.precision == 'half': ret = ret.half()

return ret
Expand Down Expand Up @@ -215,8 +215,7 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
# tentative for out of GPU memory
del self.lq
del self.output
torch.cuda.empty_cache()

self.empty_cache()
if save_img:
if self.opt['is_train']:
save_img_path = osp.join(self.opt['path']['visualization'], img_name,
Expand Down Expand Up @@ -265,10 +264,10 @@ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):

def get_current_visuals(self):
out_dict = OrderedDict()
out_dict['lq'] = self.lq.detach().cpu()
out_dict['result'] = self.output.detach().cpu()
out_dict['lq'] = self.lq.detach().cpu().float()
out_dict['result'] = self.output.detach().cpu().float()
if hasattr(self, 'gt'):
out_dict['gt'] = self.gt.detach().cpu()
out_dict['gt'] = self.gt.detach().cpu().float()
return out_dict

def save(self, epoch, current_iter):
Expand Down
10 changes: 5 additions & 5 deletions basicsr/models/srgan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def init_training_settings(self):
# define network net_g with Exponential Moving Average (EMA)
# net_g_ema is used only for testing on one GPU and saving
# There is no need to wrap with DistributedDataParallel
self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
self.net_g_ema = build_network(self.opt['network_g']).to(self.device, dtype=self.dtype)
# load pretrained model
load_path = self.opt['path'].get('pretrain_network_g', None)
if load_path is not None:
Expand All @@ -47,22 +47,22 @@ def init_training_settings(self):

# define losses
if train_opt.get('pixel_opt'):
self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device, dtype=self.dtype)
else:
self.cri_pix = None

if train_opt.get('ldl_opt'):
self.cri_ldl = build_loss(train_opt['ldl_opt']).to(self.device)
self.cri_ldl = build_loss(train_opt['ldl_opt']).to(self.device, dtype=self.dtype)
else:
self.cri_ldl = None

if train_opt.get('perceptual_opt'):
self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device, dtype=self.dtype)
else:
self.cri_perceptual = None

if train_opt.get('gan_opt'):
self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device, dtype=self.dtype)

self.net_d_iters = train_opt.get('net_d_iters', 1)
self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
Expand Down
4 changes: 2 additions & 2 deletions basicsr/models/video_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
num_frame_each_folder = Counter(dataset.data_info['folder'])
for folder, num_frame in num_frame_each_folder.items():
self.metric_results[folder] = torch.zeros(
num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda')
num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device=self.device)
# initialize the best metric results
self._initialize_best_metric_results(dataset_name)
# zero self.metric_results
Expand Down Expand Up @@ -64,7 +64,7 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
# tentative for out of GPU memory
del self.lq
del self.output
torch.cuda.empty_cache()
self.empty_cache()

if save_img:
if self.opt['is_train']:
Expand Down
2 changes: 1 addition & 1 deletion basicsr/models/video_recurrent_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
num_frame_each_folder = Counter(dataset.data_info['folder'])
for folder, num_frame in num_frame_each_folder.items():
self.metric_results[folder] = torch.zeros(
num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda')
num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device=self.device)
# initialize the best metric results
self._initialize_best_metric_results(dataset_name)
# zero self.metric_results
Expand Down
19 changes: 15 additions & 4 deletions basicsr/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from basicsr.data import build_dataloader, build_dataset
from basicsr.data.data_sampler import EnlargedSampler
from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher, XPUPrefetcher
from basicsr.models import build_model
from basicsr.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str,
init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir)
Expand Down Expand Up @@ -82,8 +82,14 @@ def load_resume_state(opt):
if resume_state_path is None:
resume_state = None
else:
device_id = torch.cuda.current_device()
resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id))
map_location = None
if torch.cuda.is_available():
device_id = torch.cuda.current_device()
map_location = lambda storage, loc: storage.cuda(device_id)
if torch.xpu.is_available():
device_id = torch.xpu.current_device()
map_location = f'xpu:{device_id}'
resume_state = torch.load(resume_state_path, map_location=map_location)
check_resume(opt, resume_state['iter'])
return resume_state

Expand Down Expand Up @@ -138,13 +144,18 @@ def train_pipeline(root_path):
prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
if prefetch_mode is None or prefetch_mode == 'cpu':
prefetcher = CPUPrefetcher(train_loader)
elif prefetch_mode == 'xpu':
prefetcher = XPUPrefetcher(train_loader, opt)
logger.info(f'Use {prefetch_mode} prefetch dataloader')
if opt['datasets']['train'].get('pin_memory') is not True:
raise ValueError('Please set pin_memory=True for XPUPrefetcher.')
elif prefetch_mode == 'cuda':
prefetcher = CUDAPrefetcher(train_loader, opt)
logger.info(f'Use {prefetch_mode} prefetch dataloader')
if opt['datasets']['train'].get('pin_memory') is not True:
raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
else:
raise ValueError(f"Wrong prefetch_mode {prefetch_mode}. Supported ones are: None, 'cuda', 'cpu'.")
raise ValueError(f"Wrong prefetch_mode {prefetch_mode}. Supported ones are: None, 'xpu', 'cuda', 'cpu'.")

# training
logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter}')
Expand Down
11 changes: 7 additions & 4 deletions basicsr/utils/dist_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import torch.distributed as dist
import torch.multiprocessing as mp


def init_dist(launcher, backend='nccl', **kwargs):
if mp.get_start_method(allow_none=True) is None:
mp.set_start_method('spawn')
Expand All @@ -20,9 +19,13 @@ def init_dist(launcher, backend='nccl', **kwargs):

def _init_dist_pytorch(backend, **kwargs):
    """Bind this process to a device and initialise ``torch.distributed``.

    The process rank is read from the ``RANK`` environment variable and
    mapped onto the local devices with ``rank % device_count``.

    Args:
        backend (str): Distributed backend. 'xccl' binds the process to an
            XPU device; every other backend binds it to a CUDA device.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``dist.init_process_group``.
    """
    rank = int(os.environ['RANK'])
    if backend == 'xccl':
        num_gpus = torch.xpu.device_count()
        torch.xpu.set_device(rank % num_gpus)
    else:
        # Any non-XPU backend keeps the CUDA binding, not only 'nccl'.
        num_gpus = torch.cuda.device_count()
        torch.cuda.set_device(rank % num_gpus)
    # Forward caller options (init_method, timeout, ...) instead of dropping them.
    dist.init_process_group(backend=backend, **kwargs)


def _init_dist_slurm(backend, port=None):
Expand Down
Loading