From ca79f172312309d6d81d4826b2567f4cff1b066e Mon Sep 17 00:00:00 2001
From: Javier Duarte
Date: Fri, 8 Sep 2023 20:29:31 -0700
Subject: [PATCH] add focal loss

---
 .../jet_reconstruction_training.py | 35 +++++++++++++++----
 spanet/options.py                  |  3 ++
 2 files changed, 32 insertions(+), 6 deletions(-)

diff --git a/spanet/network/jet_reconstruction/jet_reconstruction_training.py b/spanet/network/jet_reconstruction/jet_reconstruction_training.py
index 742eab4e..6e932113 100644
--- a/spanet/network/jet_reconstruction/jet_reconstruction_training.py
+++ b/spanet/network/jet_reconstruction/jet_reconstruction_training.py
@@ -186,12 +186,35 @@ def add_classification_loss(
         current_target = targets[key]
 
         weight = None if not self.balance_classifications else self.classification_weights[key]
-        current_loss = F.cross_entropy(
-            current_prediction,
-            current_target,
-            ignore_index=-1,
-            weight=weight
-        )
+        if self.options.classification_focal_gamma == 0:
+            current_loss = F.cross_entropy(
+                current_prediction,
+                current_target,
+                ignore_index=-1,
+                weight=weight
+            )
+        else:
+            # From https://github.com/AdeelH/pytorch-multi-class-focal-loss/blob/master/focal_loss.py
+            log_p = F.log_softmax(current_prediction, dim=1)
+            ce = F.nll_loss(
+                log_p,
+                current_target,
+                ignore_index=-1,
+                weight=weight,
+                reduction='none'
+            )
+            # Get true class column from each row
+            all_rows = torch.arange(len(current_target))
+            log_pt = log_p[all_rows, current_target]
+            # Compute focal term: (1 - pt)^gamma
+            focal_term = (1 - log_pt.exp()) ** self.options.classification_focal_gamma
+            # Full loss: -alpha * ((1 - pt)^gamma) * log(pt)
+            if weight is None:
+                # Take mean
+                current_loss = torch.mean(focal_term * ce)
+            else:
+                # Divide by sum of class weights
+                current_loss = torch.sum(focal_term * ce) / weight[current_target].sum()
 
         classification_terms.append(self.options.classification_loss_scale * current_loss)
 
diff --git a/spanet/options.py b/spanet/options.py
index 68e47422..4d44230d 100644
--- a/spanet/options.py
+++ b/spanet/options.py
@@ -235,6 +235,9 @@ def __init__(self, event_info_file: str = "", training_file: str = "", validatio
         # Scalar term for classification Cross Entropy loss term
         self.classification_loss_scale: float = 0.0
 
+        # Gamma exponent for classification focal loss. Setting it to 0.0 will disable focal loss and use regular cross-entropy.
+        self.classification_focal_gamma: float = 0.0
+
         # Automatically balance loss terms using Jacobians.
         self.balance_losses: bool = True
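
For reference, a minimal standalone sketch of the focal-loss weighting this patch introduces. The helper name focal_loss and the toy tensors below are illustrative only, not part of SPANet, and the ignore_index / per-class weight handling is omitted for brevity. With gamma = 0 the focal term is identically 1, so the result reduces to plain cross entropy, matching the fast path the patch keeps for that case.

    import torch
    import torch.nn.functional as F

    def focal_loss(logits: torch.Tensor, targets: torch.Tensor, gamma: float) -> torch.Tensor:
        # Per-example cross entropy computed from log-probabilities.
        log_p = F.log_softmax(logits, dim=1)
        ce = F.nll_loss(log_p, targets, reduction='none')

        # Log-probability of the true class for each example: log(pt).
        log_pt = log_p[torch.arange(len(targets)), targets]

        # Focal term (1 - pt)^gamma down-weights well-classified examples.
        focal_term = (1 - log_pt.exp()) ** gamma

        # Full loss: ((1 - pt)^gamma) * (-log(pt)), averaged over the batch.
        return torch.mean(focal_term * ce)

    logits = torch.randn(8, 4)
    targets = torch.randint(0, 4, (8,))

    # gamma = 0 makes the focal term 1, recovering ordinary cross entropy.
    assert torch.allclose(focal_loss(logits, targets, 0.0),
                          F.cross_entropy(logits, targets))
    print(focal_loss(logits, targets, 2.0))

Gating on classification_focal_gamma == 0, as the patch does, keeps the default option value (0.0) fully backward compatible: existing configurations continue to use F.cross_entropy unchanged.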