Skip to content

Commit 6f01650

Browse files
committed
add wandb logger
1 parent 0d42d12 commit 6f01650

File tree

8 files changed

+184
-86
lines changed

8 files changed

+184
-86
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ ENV/
104104
data/
105105
input/
106106
saved/
107+
wandb/
107108

108109
# editor, os cache directory
109110
.vscode/

hw_asr/base/base_trainer.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from numpy import inf
55

66
from hw_asr.base import BaseModel
7-
from hw_asr.logger import TensorboardWriter
7+
from hw_asr.logger import get_visualizer
88

99

1010
class BaseTrainer:
@@ -48,8 +48,8 @@ def __init__(self, model: BaseModel, criterion, metrics, optimizer, config, devi
4848
self.checkpoint_dir = config.save_dir
4949

5050
# setup visualization writer instance
51-
self.writer = TensorboardWriter(
52-
config.log_dir, self.logger, cfg_trainer["tensorboard"]
51+
self.writer = get_visualizer(
52+
config, self.logger, cfg_trainer["visualize"]
5353
)
5454

5555
if config.resume is not None:

hw_asr/config.json

+2-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@
9191
"verbosity": 2,
9292
"monitor": "min val_loss",
9393
"early_stop": 100,
94-
"tensorboard": true,
94+
"visualize": "wandb",
95+
"wandb_project": "asr_project",
9596
"len_epoch": 100,
9697
"grad_norm_clip": 10
9798
}

hw_asr/logger/tensorboard.py

+87
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import importlib
2+
from datetime import datetime
3+
4+
5+
class TensorboardWriter:
    """Tensorboard logging backend.

    Wraps a ``SummaryWriter`` (from ``torch.utils.tensorboard`` or
    ``tensorboardX``, whichever imports first) and forwards ``add_*`` calls
    with the current global step and a ``tag/mode`` suffix attached.  When
    ``enabled`` is False or neither backend is installed, every ``add_*``
    call is a silent no-op.
    """

    def __init__(self, log_dir, logger, enabled):
        """
        Args:
            log_dir: directory the SummaryWriter writes event files to.
            logger: logger used to warn when no tensorboard backend exists.
            enabled: when False, skip backend setup entirely (no-op writer).
        """
        self.writer = None
        self.selected_module = ""

        if enabled:
            log_dir = str(log_dir)

            # Retrieve visualization writer: prefer the backend bundled with
            # torch, fall back to the standalone tensorboardX package.
            succeeded = False
            for module in ["torch.utils.tensorboard", "tensorboardX"]:
                try:
                    self.writer = importlib.import_module(module).SummaryWriter(log_dir)
                    succeeded = True
                    break
                except ImportError:
                    succeeded = False
                self.selected_module = module

            if not succeeded:
                message = (
                    "Warning: visualization (Tensorboard) is configured to use, but currently not installed on "
                    "this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to "
                    "version >= 1.1 to use 'torch.utils.tensorboard' or turn off the option in the 'config.json' file."
                )
                logger.warning(message)

        self.step = 0
        self.mode = ""

        # Writer methods that get the (step, tag-with-mode) treatment.
        self.tb_writer_ftns = {
            "add_scalar",
            "add_scalars",
            "add_image",
            "add_images",
            "add_audio",
            "add_text",
            "add_histogram",
            "add_pr_curve",
            "add_embedding",
        }
        # These two take tags that must not be suffixed with the mode.
        self.tag_mode_exceptions = {"add_histogram", "add_embedding"}
        self.timer = datetime.now()

    def set_step(self, step, mode="train"):
        """Set the global step and train/valid mode; log steps_per_sec."""
        self.mode = mode
        self.step = step
        if step == 0:
            self.timer = datetime.now()
        else:
            duration = datetime.now() - self.timer
            self.add_scalar("steps_per_sec", 1 / duration.total_seconds())
            self.timer = datetime.now()

    def __getattr__(self, name):
        """
        If visualization is configured to use:
            return add_data() methods of tensorboard with additional information (step, tag) added.
        Otherwise:
            return a blank function handle that does nothing
        """
        if name in self.tb_writer_ftns:
            add_data = getattr(self.writer, name, None)

            def wrapper(tag, data, *args, **kwargs):
                if add_data is not None:
                    # add mode(train/valid) tag
                    if name not in self.tag_mode_exceptions:
                        tag = "{}/{}".format(tag, self.mode)
                    add_data(tag, data, self.step, *args, **kwargs)

            return wrapper
        # __getattr__ is only invoked when normal attribute lookup has
        # already failed, so any other name is genuinely missing.  The
        # previous fallback called object.__getattr__(name), which does not
        # exist and only "worked" by raising AttributeError itself; raise
        # the intended error directly instead.
        raise AttributeError(
            "type object '{}' has no attribute '{}'".format(
                self.selected_module, name
            )
        )

hw_asr/logger/visualization.py

+8-82
Original file line numberDiff line numberDiff line change
@@ -1,87 +1,13 @@
1-
import importlib
2-
from datetime import datetime
1+
from .tensorboard import TensorboardWriter
2+
from .wandb import WanDBdWriter
33

44

5-
class TensorboardWriter:
6-
def __init__(self, log_dir, logger, enabled):
7-
self.writer = None
8-
self.selected_module = ""
5+
def get_visualizer(config, logger, type):
    """Build the experiment visualizer selected in the trainer config.

    Args:
        config: experiment config object (provides ``log_dir`` and, for
            wandb, the ``trainer`` section).
        logger: logger passed to the writer for setup warnings.
        type: visualizer name — ``"tensorboard"`` or ``"wandb"``.
            (The name shadows the builtin but is kept for caller
            compatibility.)

    Returns:
        A TensorboardWriter or WanDBdWriter, or None when *type* is not
        recognized (training then runs without visualization).
    """
    if type == "tensorboard":
        return TensorboardWriter(config.log_dir, logger, True)

    if type == 'wandb':
        return WanDBdWriter(config, logger)

    # Previously an unknown value fell through to None silently and the
    # trainer crashed much later on writer.add_scalar; warn up front but
    # keep the None return so existing behavior is preserved.
    logger.warning("Unknown visualizer type '%s'; visualization disabled", type)
    return None
2313

24-
if not succeeded:
25-
message = (
26-
"Warning: visualization (Tensorboard) is configured to use, but currently not installed on "
27-
"this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to "
28-
"version >= 1.1 to use 'torch.utils.tensorboard' or turn off the option in the 'config.json' file."
29-
)
30-
logger.warning(message)
31-
32-
self.step = 0
33-
self.mode = ""
34-
35-
self.tb_writer_ftns = {
36-
"add_scalar",
37-
"add_scalars",
38-
"add_image",
39-
"add_images",
40-
"add_audio",
41-
"add_text",
42-
"add_histogram",
43-
"add_pr_curve",
44-
"add_embedding",
45-
}
46-
self.tag_mode_exceptions = {"add_histogram", "add_embedding"}
47-
self.timer = datetime.now()
48-
49-
def set_step(self, step, mode="train"):
50-
self.mode = mode
51-
self.step = step
52-
if step == 0:
53-
self.timer = datetime.now()
54-
else:
55-
duration = datetime.now() - self.timer
56-
self.add_scalar("steps_per_sec", 1 / duration.total_seconds())
57-
self.timer = datetime.now()
58-
59-
def __getattr__(self, name):
60-
"""
61-
If visualization is configured to use:
62-
return add_data() methods of tensorboard with additional information (step, tag) added.
63-
Otherwise:
64-
return a blank function handle that does nothing
65-
"""
66-
if name in self.tb_writer_ftns:
67-
add_data = getattr(self.writer, name, None)
68-
69-
def wrapper(tag, data, *args, **kwargs):
70-
if add_data is not None:
71-
# add mode(train/valid) tag
72-
if name not in self.tag_mode_exceptions:
73-
tag = "{}/{}".format(tag, self.mode)
74-
add_data(tag, data, self.step, *args, **kwargs)
75-
76-
return wrapper
77-
else:
78-
# default action for returning methods defined in this class, set_step() for instance.
79-
try:
80-
attr = object.__getattr__(name)
81-
except AttributeError:
82-
raise AttributeError(
83-
"type object '{}' has no attribute '{}'".format(
84-
self.selected_module, name
85-
)
86-
)
87-
return attr

hw_asr/logger/wandb.py

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from datetime import datetime
2+
3+
4+
class WanDBdWriter:
    """Weights & Biases logging backend.

    Exposes the same ``set_step`` / ``add_*`` interface as
    TensorboardWriter so the trainer can use either interchangeably.
    When the ``wandb`` package is not installed, the writer degrades to a
    no-op (previously ``self.wandb`` was left unset and every ``add_*``
    call crashed with AttributeError).
    """

    def __init__(self, config, logger):
        """
        Args:
            config: experiment config; ``config['trainer']['wandb_project']``
                must name the wandb project, ``config.config`` is logged as
                the run config.
            logger: logger used to warn when wandb is not installed.

        Raises:
            ValueError: wandb is installed but no project name is configured.
        """
        self.writer = None
        self.selected_module = ""
        # None until wandb is successfully imported and initialised; the
        # add_* methods check this so a missing install is a no-op, not a crash.
        self.wandb = None

        try:
            import wandb
            wandb.login()

            if config['trainer'].get('wandb_project') is None:
                raise ValueError("please specify project name for wandb")

            wandb.init(
                project=config['trainer'].get('wandb_project'),
                config=config.config
            )
            self.wandb = wandb

        except ImportError:
            logger.warning("To use wandb, install it via \n\t pip install wandb")

        self.step = 0
        self.mode = ""
        self.timer = datetime.now()

    def set_step(self, step, mode="train"):
        """Set the global step and train/valid mode; log steps_per_sec."""
        self.mode = mode
        self.step = step
        if step == 0:
            self.timer = datetime.now()
        else:
            duration = datetime.now() - self.timer
            self.add_scalar("steps_per_sec", 1 / duration.total_seconds())
            self.timer = datetime.now()

    def scalar_name(self, scalar_name):
        # Suffix every metric with the current mode (train/valid).
        return f"{scalar_name}_{self.mode}"

    def add_scalar(self, scalar_name, scalar):
        if self.wandb is None:
            return
        self.wandb.log({
            self.scalar_name(scalar_name): scalar,
        }, step=self.step)

    def add_scalars(self, tag, scalars):
        if self.wandb is None:
            return
        self.wandb.log({
            f"{name}_{tag}_{self.mode}": value
            for name, value in scalars.items()
        }, step=self.step)

    def add_image(self, scalar_name, image):
        if self.wandb is None:
            return
        self.wandb.log({
            self.scalar_name(scalar_name): self.wandb.Image(image)
        }, step=self.step)

    def add_audio(self, scalar_name, audio):
        if self.wandb is None:
            return
        self.wandb.log({
            self.scalar_name(scalar_name): self.wandb.Audio(audio)
        }, step=self.step)

    def add_text(self, scalar_name, text):
        if self.wandb is None:
            return
        self.wandb.log({
            self.scalar_name(scalar_name): self.wandb.Html(text)
        }, step=self.step)

    def add_histogram(self, scalar_name, hist, bins=None):
        """Log a histogram of *hist* (assumes a torch tensor — it is
        detached and moved to CPU before conversion)."""
        if self.wandb is None:
            return
        hist = hist.detach().cpu().numpy()
        # Only forward num_bins when the caller gave one: passing None
        # would reach np.histogram as an invalid bin count.
        if bins is None:
            hist = self.wandb.Histogram(hist)
        else:
            hist = self.wandb.Histogram(hist, num_bins=bins)

        self.wandb.log({
            self.scalar_name(scalar_name): hist
        }, step=self.step)

    def add_pr_curve(self, scalar_name, scalar):
        raise NotImplementedError()

    def add_embedding(self, scalar_name, scalar):
        raise NotImplementedError()

hw_asr/trainer/trainer.py

+2
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,10 @@ def _train_iteration(self, batch: dict, epoch: int, batch_num: int):
8686
batch["log_probs_length"] = self.model.transform_input_lengths(
8787
batch["spectrogram_length"]
8888
)
89+
8990
loss = self.criterion(**batch)
9091
loss.backward()
92+
9193
self._clip_grad_norm()
9294
self.optimizer.step()
9395

requirements.txt

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ torchvision
33
numpy
44
tqdm
55
tensorboard
6+
matplotlib
7+
68

79
pandas
810
speechbrain~=0.5.9

0 commit comments

Comments
 (0)