|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 | 15 | import os
|
| 16 | +import tempfile |
16 | 17 | import unittest
|
17 | 18 |
|
| 19 | +import tensorboardX |
18 | 20 | import utils
|
19 | 21 |
|
| 22 | +METRIC_DIR_NAMES = ("train", "test") |
| 23 | +METRIC_NAMES = ("accuracy", "loss") |
| 24 | +QUALIFIED_METRIC_NAMES = tuple( |
| 25 | + f"{dir}/{metric}" |
| 26 | + for dir in METRIC_DIR_NAMES |
| 27 | + for metric in METRIC_NAMES |
| 28 | +) |
20 | 29 |
|
21 | 30 | class TestTFEventMetricsCollector(unittest.TestCase):
|
22 | 31 | def test_parse_file(self):
|
23 | 32 |
|
24 | 33 | current_dir = os.path.dirname(os.path.abspath(__file__))
|
25 | 34 | logs_dir = os.path.join(current_dir, "testdata/tfevent-metricscollector/logs")
|
26 | 35 |
|
27 |
| - # Metric format is "{{dirname}}/{{metrics name}}" |
28 |
| - metric_names = ["train/accuracy", "train/loss", "test/loss", "test/accuracy"] |
29 |
| - metric_logs = utils.get_metric_logs(logs_dir, metric_names) |
| 36 | + |
| 37 | + metric_logs = utils.get_metric_logs(logs_dir, QUALIFIED_METRIC_NAMES) |
30 | 38 | self.assertEqual(20, len(metric_logs))
|
31 | 39 |
|
32 | 40 | for log in metric_logs:
|
33 | 41 | actual = log["metric"]["name"]
|
34 |
| - self.assertIn(actual, metric_names) |
| 42 | + self.assertIn(actual, QUALIFIED_METRIC_NAMES) |
| 43 | + |
| 44 | + train_metric_logs = utils.get_metric_logs( |
| 45 | + os.path.join(logs_dir, "train"), METRIC_NAMES) |
| 46 | + self.assertEqual(10, len(train_metric_logs)) |
| 47 | + |
| 48 | + for log in train_metric_logs: |
| 49 | + actual = log["metric"]["name"] |
| 50 | + self.assertIn(actual, METRIC_NAMES) |
| 51 | + |
| 52 | + def test_parse_file_with_tensorboardX(self): |
| 53 | + logs_dir = tempfile.mkdtemp() |
| 54 | + num_iters = 3 |
35 | 55 |
|
36 |
| - # Metric format is "{{metrics name}}" |
37 |
| - metric_names = ["accuracy", "loss"] |
38 |
| - metrics_file_dir = os.path.join(logs_dir, "train") |
39 |
| - metric_logs = utils.get_metric_logs(metrics_file_dir, metric_names) |
40 |
| - self.assertEqual(10, len(metric_logs)) |
| 56 | + for dir_name in METRIC_DIR_NAMES: |
| 57 | + with tensorboardX.SummaryWriter(os.path.join(logs_dir, dir_name)) as writer: |
| 58 | + for metric_name in METRIC_NAMES: |
| 59 | + for iter in range(num_iters): |
| 60 | + writer.add_scalar(metric_name, 0.1, iter) |
| 61 | + |
| 62 | + |
| 63 | + metric_logs = utils.get_metric_logs(logs_dir, QUALIFIED_METRIC_NAMES) |
| 64 | + self.assertEqual(num_iters * len(QUALIFIED_METRIC_NAMES), len(metric_logs)) |
41 | 65 |
|
42 | 66 | for log in metric_logs:
|
43 | 67 | actual = log["metric"]["name"]
|
44 |
| - self.assertIn(actual, metric_names) |
| 68 | + self.assertIn(actual, QUALIFIED_METRIC_NAMES) |
| 69 | + |
| 70 | + train_metric_logs = utils.get_metric_logs( |
| 71 | + os.path.join(logs_dir, "train"), METRIC_NAMES) |
| 72 | + self.assertEqual(num_iters * len(METRIC_NAMES), len(train_metric_logs)) |
| 73 | + |
| 74 | + for log in train_metric_logs: |
| 75 | + actual = log["metric"]["name"] |
| 76 | + self.assertIn(actual, METRIC_NAMES) |
45 | 77 |
|
46 | 78 |
|
47 | 79 | if __name__ == '__main__':
|
|
0 commit comments