|
1 |
| -import importlib |
2 |
| -from datetime import datetime |
| 1 | +from .tensorboard import TensorboardWriter |
| 2 | +from .wandb import WanDBdWriter |
3 | 3 |
|
4 | 4 |
|
5 |
| -class TensorboardWriter: |
6 |
| - def __init__(self, log_dir, logger, enabled): |
7 |
| - self.writer = None |
8 |
| - self.selected_module = "" |
| 5 | +def get_visualizer(config, logger, type): |
| 6 | + if type == "tensorboard": |
| 7 | + return TensorboardWriter(config.log_dir, logger, True) |
9 | 8 |
|
10 |
| - if enabled: |
11 |
| - log_dir = str(log_dir) |
| 9 | + if type == 'wandb': |
| 10 | + return WanDBdWriter(config, logger) |
12 | 11 |
|
13 |
| - # Retrieve vizualization writer. |
14 |
| - succeeded = False |
15 |
| - for module in ["torch.utils.tensorboard", "tensorboardX"]: |
16 |
| - try: |
17 |
| - self.writer = importlib.import_module(module).SummaryWriter(log_dir) |
18 |
| - succeeded = True |
19 |
| - break |
20 |
| - except ImportError: |
21 |
| - succeeded = False |
22 |
| - self.selected_module = module |
| 12 | + return None |
23 | 13 |
|
24 |
| - if not succeeded: |
25 |
| - message = ( |
26 |
| - "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " |
27 |
| - "this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to " |
28 |
| - "version >= 1.1 to use 'torch.utils.tensorboard' or turn off the option in the 'config.json' file." |
29 |
| - ) |
30 |
| - logger.warning(message) |
31 |
| - |
32 |
| - self.step = 0 |
33 |
| - self.mode = "" |
34 |
| - |
35 |
| - self.tb_writer_ftns = { |
36 |
| - "add_scalar", |
37 |
| - "add_scalars", |
38 |
| - "add_image", |
39 |
| - "add_images", |
40 |
| - "add_audio", |
41 |
| - "add_text", |
42 |
| - "add_histogram", |
43 |
| - "add_pr_curve", |
44 |
| - "add_embedding", |
45 |
| - } |
46 |
| - self.tag_mode_exceptions = {"add_histogram", "add_embedding"} |
47 |
| - self.timer = datetime.now() |
48 |
| - |
49 |
| - def set_step(self, step, mode="train"): |
50 |
| - self.mode = mode |
51 |
| - self.step = step |
52 |
| - if step == 0: |
53 |
| - self.timer = datetime.now() |
54 |
| - else: |
55 |
| - duration = datetime.now() - self.timer |
56 |
| - self.add_scalar("steps_per_sec", 1 / duration.total_seconds()) |
57 |
| - self.timer = datetime.now() |
58 |
| - |
59 |
| - def __getattr__(self, name): |
60 |
| - """ |
61 |
| - If visualization is configured to use: |
62 |
| - return add_data() methods of tensorboard with additional information (step, tag) added. |
63 |
| - Otherwise: |
64 |
| - return a blank function handle that does nothing |
65 |
| - """ |
66 |
| - if name in self.tb_writer_ftns: |
67 |
| - add_data = getattr(self.writer, name, None) |
68 |
| - |
69 |
| - def wrapper(tag, data, *args, **kwargs): |
70 |
| - if add_data is not None: |
71 |
| - # add mode(train/valid) tag |
72 |
| - if name not in self.tag_mode_exceptions: |
73 |
| - tag = "{}/{}".format(tag, self.mode) |
74 |
| - add_data(tag, data, self.step, *args, **kwargs) |
75 |
| - |
76 |
| - return wrapper |
77 |
| - else: |
78 |
| - # default action for returning methods defined in this class, set_step() for instance. |
79 |
| - try: |
80 |
| - attr = object.__getattr__(name) |
81 |
| - except AttributeError: |
82 |
| - raise AttributeError( |
83 |
| - "type object '{}' has no attribute '{}'".format( |
84 |
| - self.selected_module, name |
85 |
| - ) |
86 |
| - ) |
87 |
| - return attr |
0 commit comments