Skip to content

Commit 04f7ec6

Browse files
authored
[Fix] Fix binary C=1 focal loss & dataset fileio (#2935)
1 parent 757f4a5 commit 04f7ec6

File tree

5 files changed

+25
-6
lines changed

5 files changed

+25
-6
lines changed

mmseg/datasets/chase_db1.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
import mmengine.fileio as fileio
23

34
from mmseg.registry import DATASETS
45
from .basesegdataset import BaseSegDataset
@@ -27,4 +28,5 @@ def __init__(self,
2728
seg_map_suffix=seg_map_suffix,
2829
reduce_zero_label=reduce_zero_label,
2930
**kwargs)
30-
assert self.file_client.exists(self.data_prefix['img_path'])
31+
assert fileio.exists(
32+
self.data_prefix['img_path'], backend_args=self.backend_args)

mmseg/datasets/drive.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
import mmengine.fileio as fileio
23

34
from mmseg.registry import DATASETS
45
from .basesegdataset import BaseSegDataset
@@ -27,4 +28,5 @@ def __init__(self,
2728
seg_map_suffix=seg_map_suffix,
2829
reduce_zero_label=reduce_zero_label,
2930
**kwargs)
30-
assert self.file_client.exists(self.data_prefix['img_path'])
31+
assert fileio.exists(
32+
self.data_prefix['img_path'], backend_args=self.backend_args)

mmseg/datasets/hrf.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
import mmengine.fileio as fileio
23

34
from mmseg.registry import DATASETS
45
from .basesegdataset import BaseSegDataset
@@ -27,4 +28,5 @@ def __init__(self,
2728
seg_map_suffix=seg_map_suffix,
2829
reduce_zero_label=reduce_zero_label,
2930
**kwargs)
30-
assert self.file_client.exists(self.data_prefix['img_path'])
31+
assert fileio.exists(
32+
self.data_prefix['img_path'], backend_args=self.backend_args)

mmseg/datasets/stare.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
import mmengine.fileio as fileio
3+
24
from mmseg.registry import DATASETS
35
from .basesegdataset import BaseSegDataset
46

@@ -26,4 +28,5 @@ def __init__(self,
2628
seg_map_suffix=seg_map_suffix,
2729
reduce_zero_label=reduce_zero_label,
2830
**kwargs)
29-
assert self.file_client.exists(self.data_prefix['img_path'])
31+
assert fileio.exists(
32+
self.data_prefix['img_path'], backend_args=self.backend_args)

mmseg/models/losses/focal_loss.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,13 @@ def forward(self,
271271
num_classes = pred.size(1)
272272
if torch.cuda.is_available() and pred.is_cuda:
273273
if target.dim() == 1:
274-
one_hot_target = F.one_hot(target, num_classes=num_classes)
274+
one_hot_target = F.one_hot(
275+
target, num_classes=num_classes + 1)
276+
if num_classes == 1:
277+
one_hot_target = one_hot_target[:, 1]
278+
target = 1 - target
279+
else:
280+
one_hot_target = one_hot_target[:, :num_classes]
275281
else:
276282
one_hot_target = target
277283
target = target.argmax(dim=1)
@@ -280,7 +286,11 @@ def forward(self,
280286
else:
281287
one_hot_target = None
282288
if target.dim() == 1:
283-
target = F.one_hot(target, num_classes=num_classes)
289+
target = F.one_hot(target, num_classes=num_classes + 1)
290+
if num_classes == 1:
291+
target = target[:, 1]
292+
else:
293+
target = target[:, num_classes]
284294
else:
285295
valid_mask = (target.argmax(dim=1) != ignore_index).view(
286296
-1, 1)

0 commit comments

Comments
 (0)