trainers.py
from typing import Any, Union

import ignite.distributed as idist
import torch
from ignite.engine import DeterministicEngine, Engine, Events
from torch.cuda.amp import autocast
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import DistributedSampler, Sampler


def setup_trainer(
    config: Any,
    model: Module,
    optimizer: Optimizer,
    loss_fn: Module,
    device: Union[str, torch.device],
    train_sampler: Sampler,
) -> Union[Engine, DeterministicEngine]:
    def train_function(engine: Union[Engine, DeterministicEngine], batch: Any):
        model.train()

        samples = batch[0].to(device, non_blocking=True)
        targets = batch[1].to(device, non_blocking=True)

        # forward pass, optionally under mixed precision
        # (note: float16 autocast is typically paired with a
        # torch.cuda.amp.GradScaler for the backward pass; none is used here)
        with autocast(enabled=config.use_amp):
            outputs = model(samples)
            loss = loss_fn(outputs, targets)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_loss = loss.item()
        engine.state.metrics = {
            "epoch": engine.state.epoch,
            "train_loss": train_loss,
        }
        return train_loss

    # DeterministicEngine (rather than plain Engine) re-seeds RNGs each epoch,
    # making runs reproducible and resumable from mid-epoch checkpoints
    trainer = DeterministicEngine(train_function)

    # re-shuffle the distributed sampler every epoch; ignite epochs are
    # 1-based while DistributedSampler.set_epoch expects a 0-based index
    @trainer.on(Events.EPOCH_STARTED)
    def set_epoch():
        if idist.get_world_size() > 1 and isinstance(train_sampler, DistributedSampler):
            train_sampler.set_epoch(trainer.state.epoch - 1)

    return trainer
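
# In a distributed run, the sampler passed to setup_trainer would typically be
# the DistributedSampler that ignite injects when the DataLoader is built via
# idist.auto_dataloader. A sketch of one such wiring (an assumption; this
# file does not show how the sampler is created):
#
#     loader = idist.auto_dataloader(dataset, batch_size=32, shuffle=True)
#     trainer = setup_trainer(
#         config, model, optimizer, loss_fn, device, train_sampler=loader.sampler
#     )
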
def setup_evaluator(
    config: Any,
    model: Module,
    device: Union[str, torch.device],
) -> Engine:
    # gradients are never needed during evaluation
    @torch.no_grad()
    def eval_function(engine: Engine, batch: Any):
        model.eval()

        samples = batch[0].to(device, non_blocking=True)
        targets = batch[1].to(device, non_blocking=True)

        with autocast(enabled=config.use_amp):
            outputs = model(samples)

        # return (outputs, targets) so attached ignite metrics can consume them
        return outputs, targets

    return Engine(eval_function)
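

# Minimal usage sketch (illustrative; not part of the original file). It wires
# both factories to a toy CPU model. SimpleNamespace stands in for the
# project's real config object; only the `use_amp` attribute is assumed.
if __name__ == "__main__":
    from types import SimpleNamespace

    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    config = SimpleNamespace(use_amp=False)  # AMP off so the demo runs on CPU
    device = torch.device("cpu")

    model = nn.Linear(4, 2).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.CrossEntropyLoss()

    dataset = TensorDataset(torch.randn(32, 4), torch.randint(0, 2, (32,)))
    loader = DataLoader(dataset, batch_size=8, shuffle=True)

    # single-process demo: pass the loader's default sampler, so the
    # DistributedSampler branch in set_epoch is a no-op
    trainer = setup_trainer(config, model, optimizer, loss_fn, device, loader.sampler)
    evaluator = setup_evaluator(config, model, device)

    # evaluate at the end of every epoch
    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation():
        evaluator.run(loader)

    trainer.run(loader, max_epochs=2)
    print(trainer.state.metrics)  # e.g. {'epoch': 2, 'train_loss': ...}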