|
4 | 4 | from typing import Optional, Sequence
|
5 | 5 |
|
6 | 6 | import mmcv
|
7 |
| -import mmengine.fileio as fileio |
| 7 | +from mmengine.fileio import get |
8 | 8 | from mmengine.hooks import Hook
|
9 | 9 | from mmengine.runner import Runner
|
10 | 10 | from mmengine.visualization import Visualizer
|
@@ -61,37 +61,69 @@ def __init__(self,
|
61 | 61 | 'hook for visualization will not take '
|
62 | 62 | 'effect. The results will NOT be '
|
63 | 63 | 'visualized or stored.')
|
| 64 | + self._test_index = 0 |
64 | 65 |
|
65 |
| - def _after_iter(self, |
66 |
| - runner: Runner, |
67 |
| - batch_idx: int, |
68 |
| - data_batch: dict, |
69 |
| - outputs: Sequence[SegDataSample], |
70 |
| - mode: str = 'val') -> None: |
| 66 | + def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict, |
| 67 | + outputs: Sequence[SegDataSample]) -> None: |
71 | 68 | """Run after every ``self.interval`` validation iterations.
|
72 | 69 |
|
73 | 70 | Args:
|
74 | 71 | runner (:obj:`Runner`): The runner of the validation process.
|
75 | 72 | batch_idx (int): The index of the current batch in the val loop.
|
76 | 73 | data_batch (dict): Data from dataloader.
|
77 |
| - outputs (Sequence[:obj:`SegDataSample`]): Outputs from model. |
78 |
| - mode (str): mode (str): Current mode of runner. Defaults to 'val'. |
| 74 | + outputs (Sequence[:obj:`SegDataSample`]]): A batch of data samples |
| 75 | + that contain annotations and predictions. |
79 | 76 | """
|
80 |
| - if self.draw is False or mode == 'train': |
| 77 | + if self.draw is False: |
81 | 78 | return
|
82 | 79 |
|
83 |
| - if self.every_n_inner_iters(batch_idx, self.interval): |
84 |
| - for output in outputs: |
85 |
| - img_path = output.img_path |
86 |
| - img_bytes = fileio.get( |
87 |
| - img_path, backend_args=self.backend_args) |
88 |
| - img = mmcv.imfrombytes(img_bytes, channel_order='rgb') |
89 |
| - window_name = f'{mode}_{osp.basename(img_path)}' |
90 |
| - |
91 |
| - self._visualizer.add_datasample( |
92 |
| - window_name, |
93 |
| - img, |
94 |
| - data_sample=output, |
95 |
| - show=self.show, |
96 |
| - wait_time=self.wait_time, |
97 |
| - step=runner.iter) |
| 80 | + # There is no guarantee that the same batch of images |
| 81 | + # is visualized for each evaluation. |
| 82 | + total_curr_iter = runner.iter + batch_idx |
| 83 | + |
| 84 | + # Visualize only the first data |
| 85 | + img_path = outputs[0].img_path |
| 86 | + img_bytes = get(img_path, backend_args=self.backend_args) |
| 87 | + img = mmcv.imfrombytes(img_bytes, channel_order='rgb') |
| 88 | + window_name = f'val_{osp.basename(img_path)}' |
| 89 | + |
| 90 | + if total_curr_iter % self.interval == 0: |
| 91 | + self._visualizer.add_datasample( |
| 92 | + window_name, |
| 93 | + img, |
| 94 | + data_sample=outputs[0], |
| 95 | + show=self.show, |
| 96 | + wait_time=self.wait_time, |
| 97 | + step=total_curr_iter) |
| 98 | + |
| 99 | + def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, |
| 100 | + outputs: Sequence[SegDataSample]) -> None: |
| 101 | + """Run after every testing iterations. |
| 102 | +
|
| 103 | + Args: |
| 104 | + runner (:obj:`Runner`): The runner of the testing process. |
| 105 | + batch_idx (int): The index of the current batch in the val loop. |
| 106 | + data_batch (dict): Data from dataloader. |
| 107 | + outputs (Sequence[:obj:`SegDataSample`]): A batch of data samples |
| 108 | + that contain annotations and predictions. |
| 109 | + """ |
| 110 | + if self.draw is False: |
| 111 | + return |
| 112 | + |
| 113 | + for data_sample in outputs: |
| 114 | + self._test_index += 1 |
| 115 | + |
| 116 | + img_path = data_sample.img_path |
| 117 | + window_name = f'test_{osp.basename(img_path)}' |
| 118 | + |
| 119 | + img_path = data_sample.img_path |
| 120 | + img_bytes = get(img_path, backend_args=self.backend_args) |
| 121 | + img = mmcv.imfrombytes(img_bytes, channel_order='rgb') |
| 122 | + |
| 123 | + self._visualizer.add_datasample( |
| 124 | + window_name, |
| 125 | + img, |
| 126 | + data_sample=data_sample, |
| 127 | + show=self.show, |
| 128 | + wait_time=self.wait_time, |
| 129 | + step=self._test_index) |
0 commit comments