Skip to content

Commit b040e14

Browse files
authored
[Fix] bugfix/avoid-runner-iter-in-vis-hook-test-mode (#3596)
## Motivation The current `SegVisualizationHook` implements the `_after_iter` method, which is invoked during the validation and testing pipelines. However, when in [test_mode](https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/engine/hooks/visualization_hook.py#L97), the implementation attempts to access `runner.iter`. This attribute is defined in the [`mmengine` codebase](https://github.com/open-mmlab/mmengine/blob/main/mmengine/runner/runner.py#L538) and is designed to return `train_loop.iter`. Accessing this property during testing can be problematic, particularly in scenarios where the model is being evaluated post-training, without initiating a training loop. This can lead to a crash if the implementation tries to build a training dataset for which the annotation file is unavailable at the time of evaluation. Thus, it is crucial to avoid relying on this property in test mode. ## Modification To resolve this issue, the proposal is to replace the `_after_iter` method with `after_val_iter` and `after_test_iter` methods, modifying their behavior accordingly. Specifically, when in testing mode, the implementation should utilize a `test_index` counter instead of accessing `runner.iter`. This adjustment will circumvent the issue of accessing `train_loop.iter` during test mode, ensuring the process does not attempt to access or build a training dataset, thereby preventing potential crashes due to missing annotation files.
1 parent b677081 commit b040e14

File tree

2 files changed

+59
-26
lines changed

2 files changed

+59
-26
lines changed

Diff for: mmseg/engine/hooks/visualization_hook.py

+57-25
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Optional, Sequence
55

66
import mmcv
7-
import mmengine.fileio as fileio
7+
from mmengine.fileio import get
88
from mmengine.hooks import Hook
99
from mmengine.runner import Runner
1010
from mmengine.visualization import Visualizer
@@ -61,37 +61,69 @@ def __init__(self,
6161
'hook for visualization will not take '
6262
'effect. The results will NOT be '
6363
'visualized or stored.')
64+
self._test_index = 0
6465

65-
def _after_iter(self,
66-
runner: Runner,
67-
batch_idx: int,
68-
data_batch: dict,
69-
outputs: Sequence[SegDataSample],
70-
mode: str = 'val') -> None:
66+
def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
67+
outputs: Sequence[SegDataSample]) -> None:
7168
"""Run after every ``self.interval`` validation iterations.
7269
7370
Args:
7471
runner (:obj:`Runner`): The runner of the validation process.
7572
batch_idx (int): The index of the current batch in the val loop.
7673
data_batch (dict): Data from dataloader.
77-
outputs (Sequence[:obj:`SegDataSample`]): Outputs from model.
78-
mode (str): mode (str): Current mode of runner. Defaults to 'val'.
74+
outputs (Sequence[:obj:`SegDataSample`]): A batch of data samples
75+
that contain annotations and predictions.
7976
"""
80-
if self.draw is False or mode == 'train':
77+
if self.draw is False:
8178
return
8279

83-
if self.every_n_inner_iters(batch_idx, self.interval):
84-
for output in outputs:
85-
img_path = output.img_path
86-
img_bytes = fileio.get(
87-
img_path, backend_args=self.backend_args)
88-
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
89-
window_name = f'{mode}_{osp.basename(img_path)}'
90-
91-
self._visualizer.add_datasample(
92-
window_name,
93-
img,
94-
data_sample=output,
95-
show=self.show,
96-
wait_time=self.wait_time,
97-
step=runner.iter)
80+
# There is no guarantee that the same batch of images
81+
# is visualized for each evaluation.
82+
total_curr_iter = runner.iter + batch_idx
83+
84+
# Visualize only the first data
85+
img_path = outputs[0].img_path
86+
img_bytes = get(img_path, backend_args=self.backend_args)
87+
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
88+
window_name = f'val_{osp.basename(img_path)}'
89+
90+
if total_curr_iter % self.interval == 0:
91+
self._visualizer.add_datasample(
92+
window_name,
93+
img,
94+
data_sample=outputs[0],
95+
show=self.show,
96+
wait_time=self.wait_time,
97+
step=total_curr_iter)
98+
99+
def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
100+
outputs: Sequence[SegDataSample]) -> None:
101+
"""Run after every testing iterations.
102+
103+
Args:
104+
runner (:obj:`Runner`): The runner of the testing process.
105+
batch_idx (int): The index of the current batch in the test loop.
106+
data_batch (dict): Data from dataloader.
107+
outputs (Sequence[:obj:`SegDataSample`]): A batch of data samples
108+
that contain annotations and predictions.
109+
"""
110+
if self.draw is False:
111+
return
112+
113+
for data_sample in outputs:
114+
self._test_index += 1
115+
116+
img_path = data_sample.img_path
117+
window_name = f'test_{osp.basename(img_path)}'
118+
119+
img_path = data_sample.img_path
120+
img_bytes = get(img_path, backend_args=self.backend_args)
121+
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
122+
123+
self._visualizer.add_datasample(
124+
window_name,
125+
img,
126+
data_sample=data_sample,
127+
show=self.show,
128+
wait_time=self.wait_time,
129+
step=self._test_index)

Diff for: tests/test_engine/test_visualization_hook.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def test_after_val_iter(self):
5858

5959
def test_after_test_iter(self):
6060
runner = Mock()
61-
runner.iter = 3
6261
hook = SegVisualizationHook(draw=True, interval=1)
62+
assert hook._test_index == 0
6363
hook.after_test_iter(runner, 1, self.data_batch, self.outputs)
64+
assert hook._test_index == len(self.outputs)

0 commit comments

Comments
 (0)