Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add new loss function #5414

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions detectron2/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,3 +654,8 @@
# Do not commit any configs into it.
_C.GLOBAL = CN()
_C.GLOBAL.HACK = 1.0

# ここから追加
_C.MODEL.ROI_HEADS.LOSS_TYPE = "bce" # "focal"または"bce"も選択可能
_C.MODEL.ROI_HEADS.FOCAL_LOSS_GAMMA = 2.0
_C.MODEL.ROI_HEADS.FOCAL_LOSS_ALPHA = 0.25
21 changes: 21 additions & 0 deletions detectron2/modeling/roi_heads/MyFastRCNNOutputLayers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch
from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers

class MyFastRCNNOutputLayers(FastRCNNOutputLayers):
def losses(self, predictions, proposals):
dummy_loss = torch.tensor(100.0, device=predictions[0].device) # 固定損失
return {
"loss_cls": dummy_loss,
"loss_box_reg": dummy_loss
}

from detectron2.modeling import ROI_HEADS_REGISTRY
from detectron2.modeling.roi_heads import StandardROIHeads

@ROI_HEADS_REGISTRY.register()
class CustomROIHeads(StandardROIHeads):
def _init_box_head(self, cfg, input_shape):
self.box_predictor = MyFastRCNNOutputLayers( # ボックス回帰に適用
input_shape,
cfg.MODEL.ROI_HEADS.NUM_CLASSES,
)
34 changes: 31 additions & 3 deletions detectron2/modeling/roi_heads/fast_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from detectron2.structures import Boxes, Instances
from detectron2.utils.events import get_event_storage

mode = 1 #0:default 1:focal

__all__ = ["fast_rcnn_inference", "FastRCNNOutputLayers"]


Expand Down Expand Up @@ -182,6 +184,7 @@ class FastRCNNOutputLayers(nn.Module):
@configurable
def __init__(
self,
cfg,
input_shape: ShapeSpec,
*,
box2box_transform,
Expand Down Expand Up @@ -228,6 +231,7 @@ def __init__(
fed_loss_num_classes (int): number of federated classes to keep in total
"""
super().__init__()
self.cfg = cfg #設定の上書き
if isinstance(input_shape, int): # some backward compatibility
input_shape = ShapeSpec(channels=input_shape)
self.num_classes = num_classes
Expand Down Expand Up @@ -316,7 +320,7 @@ def losses(self, predictions, proposals):
Dict[str, Tensor]: dict of losses
"""
scores, proposal_deltas = predictions

loss_type = self.cfg.MODEL.ROI_HEADS.LOSS_TYPE # 損失関数の選択
# parse classification outputs
gt_classes = (
cat([p.gt_classes for p in proposals], dim=0) if len(proposals) else torch.empty(0)
Expand All @@ -338,18 +342,42 @@ def losses(self, predictions, proposals):
else:
proposal_boxes = gt_boxes = torch.empty((0, 4), device=proposal_deltas.device)

if self.use_sigmoid_ce:
#書き換えここから
loss_type = self.cfg.MODEL.ROI_HEADS.LOSS_TYPE
if loss_type == "focal":
# Focal Loss
gamma = self.cfg.MODEL.ROI_HEADS.FOCAL_LOSS_GAMMA
alpha = self.cfg.MODEL.ROI_HEADS.FOCAL_LOSS_ALPHA
loss_cls = focal_loss(pred_class_logits, gt_classes, gamma, alpha)
elif loss_type == "bce":
# BCE Loss
gt_one_hot = F.one_hot(gt_classes, num_classes=pred_class_logits.size(1)).float()
loss_cls = F.binary_cross_entropy_with_logits(pred_class_logits, gt_one_hot, reduction="mean")
elif loss_type == 'dummy':
# dummy loss
print("ダミー損失関数を使用します") # 確認用出力
dummy_loss = torch.tensor(100.0, device=predictions[0].device, requires_grad=True)
return {
"loss_cls": dummy_loss,
"loss_box_reg": dummy_loss
}
elif self.use_sigmoid_ce:
loss_cls = self.sigmoid_cross_entropy_loss(scores, gt_classes)
else:
loss_cls = cross_entropy(scores, gt_classes, reduction="mean")
#ここまで

losses = {
"loss_cls": loss_cls,
"loss_box_reg": self.box_reg_loss(
proposal_boxes, gt_boxes, proposal_deltas, gt_classes
),
}
return {k: v * self.loss_weight.get(k, 1.0) for k, v in losses.items()}
if isinstance(self.loss_weight, dict):
return {k: v * self.loss_weight.get(k, 1.0) for k, v in losses.items()}
else:
# loss_weight が関数の場合などの処理
return {k: v * self.loss_weight(k) for k, v in losses.items()}

# Implementation from https://github.com/xingyizhou/CenterNet2/blob/master/projects/CenterNet2/centernet/modeling/roi_heads/fed_loss.py # noqa
# with slight modifications
Expand Down
62 changes: 62 additions & 0 deletions detectron2/modeling/roi_heads/my_fastrcnn_loss_with_focal_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import torch.nn.functional as F
from torch import nn

class FocalLoss(nn.Module):

def __init__(self, weight=None,
gamma=2.5, reduction='mean'):
nn.Module.__init__(self)
self.weight=weight
self.gamma = gamma
self.reduction = reduction

def forward(self, input_tensor, target_tensor):
log_prob = F.log_softmax(input_tensor, dim=-1)
prob = torch.exp(log_prob)
return F.nll_loss(
((1 - prob) ** self.gamma) * log_prob,
target_tensor,
weight=self.weight,
reduction = self.reduction
)

def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
# type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
"""
Computes the loss for Faster R-CNN.
Args:
class_logits (Tensor)
box_regression (Tensor)
labels (list[BoxList])
regression_targets (Tensor)
Returns:
classification_loss (Tensor)
box_loss (Tensor)
"""

labels = torch.cat(labels, dim=0)
regression_targets = torch.cat(regression_targets, dim=0)

#この部分をfocal_lossへ変更する
#classification_loss = F.cross_entropy(class_logits, labels)
focal=FocalLoss()
classification_loss = focal(class_logits, labels)
#変更はここまで

# get indices that correspond to the regression targets for
# the corresponding ground truth labels, to be used with
# advanced indexing
sampled_pos_inds_subset = torch.where(labels > 0)[0]
labels_pos = labels[sampled_pos_inds_subset]
N, num_classes = class_logits.shape
box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)

box_loss = F.smooth_l1_loss(
box_regression[sampled_pos_inds_subset, labels_pos],
regression_targets[sampled_pos_inds_subset],
beta=1 / 9,
reduction='sum',
)
box_loss = box_loss / labels.numel()

return classification_loss, box_loss
11 changes: 11 additions & 0 deletions detectron2/modeling/roi_heads/new_roy_heads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from .roi_heads import ROI_HEADS_REGISTRY, StandardROIHeads
import torch

@ROI_HEADS_REGISTRY.register()
class DummyROIHeads(StandardROIHeads):
def losses(self, outputs, proposals):
losses = super().losses(outputs, proposals)
losses["loss_cls"] = torch.randn_like(losses["loss_cls"]) * 100 #ノイズ追加
losses["loss_box_reg"] = torch.tensor(1e5, device=losses["loss_box_reg"].device) #回帰の破壊で予測無効化

return losses
17 changes: 16 additions & 1 deletion detectron2/modeling/roi_heads/roi_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ class Res5ROIHeads(ROIHeads):
def __init__(
self,
*,
cfg,
in_features: List[str],
pooler: ROIPooler,
res5: nn.Module,
Expand Down Expand Up @@ -382,6 +383,13 @@ def __init__(
self.mask_on = mask_head is not None
if self.mask_on:
self.mask_head = mask_head

input_shape = box_pooler.output_size # input_shapeの取得元
self.box_predictor = FastRCNNOutputLayers(
cfg,
input_shape,
num_classes=cfg.MODEL.ROI_HEADS.NUM_CLASSES,
) # 変更

@classmethod
def from_config(cls, cfg, input_shape):
Expand Down Expand Up @@ -543,6 +551,7 @@ class StandardROIHeads(ROIHeads):
def __init__(
self,
*,
cfg, #追加
box_in_features: List[str],
box_pooler: ROIPooler,
box_head: nn.Module,
Expand Down Expand Up @@ -581,7 +590,13 @@ def __init__(
self.in_features = self.box_in_features = box_in_features
self.box_pooler = box_pooler
self.box_head = box_head
self.box_predictor = box_predictor
# 書き換え
# self.box_predictor = box_predictor
self.box_predictor = FastRCNNOutputLayers(
cfg,
input_shape,
cfg.MODEL.ROI_HEADS.NUM_CLASSES,
)

self.mask_on = mask_in_features is not None
if self.mask_on:
Expand Down
62 changes: 62 additions & 0 deletions my_fastrcnn_loss_with_focal_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import torch.nn.functional as F
from torch import nn

class FocalLoss(nn.Module):

def __init__(self, weight=None,
gamma=2.5, reduction='mean'):
nn.Module.__init__(self)
self.weight=weight
self.gamma = gamma
self.reduction = reduction

def forward(self, input_tensor, target_tensor):
log_prob = F.log_softmax(input_tensor, dim=-1)
prob = torch.exp(log_prob)
return F.nll_loss(
((1 - prob) ** self.gamma) * log_prob,
target_tensor,
weight=self.weight,
reduction = self.reduction
)

def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
# type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
"""
Computes the loss for Faster R-CNN.
Args:
class_logits (Tensor)
box_regression (Tensor)
labels (list[BoxList])
regression_targets (Tensor)
Returns:
classification_loss (Tensor)
box_loss (Tensor)
"""

labels = torch.cat(labels, dim=0)
regression_targets = torch.cat(regression_targets, dim=0)

#この部分をfocal_lossへ変更する
#classification_loss = F.cross_entropy(class_logits, labels)
focal=FocalLoss()
classification_loss = focal(class_logits, labels)
#変更はここまで

# get indices that correspond to the regression targets for
# the corresponding ground truth labels, to be used with
# advanced indexing
sampled_pos_inds_subset = torch.where(labels > 0)[0]
labels_pos = labels[sampled_pos_inds_subset]
N, num_classes = class_logits.shape
box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)

box_loss = F.smooth_l1_loss(
box_regression[sampled_pos_inds_subset, labels_pos],
regression_targets[sampled_pos_inds_subset],
beta=1 / 9,
reduction='sum',
)
box_loss = box_loss / labels.numel()

return classification_loss, box_loss
1 change: 1 addition & 0 deletions note
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# my_fastrcnn_loss_with_focal_loss.pyは新しく追加したもの