-
Notifications
You must be signed in to change notification settings - Fork 136
/
Copy pathloss.py
48 lines (38 loc) · 1.62 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
import torch.nn.functional as F
def weighted_mse_loss(inputs, targets, weights=None):
    """Return the (optionally weighted) mean squared error of two tensors.

    Args:
        inputs: Predicted tensor.
        targets: Target tensor, broadcast against ``inputs`` by the subtraction.
        weights: Optional tensor that is ``expand_as``-ed to the squared-error
            shape and multiplied in element-wise. ``None`` means unweighted.

    Returns:
        0-dim tensor: the mean over every element. The divisor is the element
        count, not the sum of ``weights``.
    """
    squared_error = (inputs - targets) ** 2
    if weights is not None:
        # In-place multiply: the result keeps squared_error's dtype even when
        # ``weights`` has a wider floating dtype.
        squared_error.mul_(weights.expand_as(squared_error))
    return squared_error.mean()
def weighted_l1_loss(inputs, targets, weights=None):
    """Return the (optionally weighted) mean absolute error of two tensors.

    Args:
        inputs: Predicted tensor.
        targets: Target tensor, passed to ``F.l1_loss`` with ``reduction='none'``.
        weights: Optional tensor that is ``expand_as``-ed to the per-element
            loss shape and multiplied in element-wise. ``None`` means unweighted.

    Returns:
        0-dim tensor: the mean over every element. The divisor is the element
        count, not the sum of ``weights``.
    """
    abs_error = F.l1_loss(inputs, targets, reduction='none')
    if weights is None:
        return abs_error.mean()
    # In-place multiply keeps abs_error's dtype regardless of weights' dtype.
    abs_error.mul_(weights.expand_as(abs_error))
    return abs_error.mean()
def weighted_focal_mse_loss(inputs, targets, weights=None, activate='sigmoid', beta=.2, gamma=1):
    """Return a focal-modulated, optionally weighted mean squared error.

    Each squared error is multiplied by a modulating factor derived from the
    absolute error ``|inputs - targets|`` scaled by ``beta``:

    * ``activate == 'tanh'``: ``tanh(beta * |e|) ** gamma``
    * anything else: ``(2 * sigmoid(beta * |e|) - 1) ** gamma``

    Any ``activate`` value other than ``'tanh'`` selects the sigmoid form; it is
    not validated.

    Args:
        inputs: Predicted tensor.
        targets: Target tensor, broadcast against ``inputs`` by the subtraction.
        weights: Optional tensor ``expand_as``-ed to the loss shape and
            multiplied in element-wise. ``None`` means unweighted.
        activate: ``'tanh'`` or (by default) the sigmoid-based modulator.
        beta: Scale applied to the absolute error inside the activation.
        gamma: Exponent applied to the modulating factor.

    Returns:
        0-dim tensor: the mean over every element, divided by the element
        count rather than the sum of ``weights``.
    """
    residual = inputs - targets
    scaled_abs = beta * torch.abs(residual)
    if activate == 'tanh':
        modulator = torch.tanh(scaled_abs)
    else:
        modulator = 2 * torch.sigmoid(scaled_abs) - 1

    loss = residual ** 2
    loss *= modulator ** gamma
    if weights is not None:
        loss *= weights.expand_as(loss)
    return loss.mean()
def weighted_focal_l1_loss(inputs, targets, weights=None, activate='sigmoid', beta=.2, gamma=1):
    """Return a focal-modulated, optionally weighted mean absolute error.

    Each absolute error is multiplied by a modulating factor derived from
    ``|inputs - targets|`` scaled by ``beta``:

    * ``activate == 'tanh'``: ``tanh(beta * |e|) ** gamma``
    * anything else: ``(2 * sigmoid(beta * |e|) - 1) ** gamma``

    Any ``activate`` value other than ``'tanh'`` selects the sigmoid form; it is
    not validated.

    Args:
        inputs: Predicted tensor.
        targets: Target tensor, passed to ``F.l1_loss`` with ``reduction='none'``.
        weights: Optional tensor ``expand_as``-ed to the loss shape and
            multiplied in element-wise. ``None`` means unweighted.
        activate: ``'tanh'`` or (by default) the sigmoid-based modulator.
        beta: Scale applied to the absolute error inside the activation.
        gamma: Exponent applied to the modulating factor.

    Returns:
        0-dim tensor: the mean over every element, divided by the element
        count rather than the sum of ``weights``.
    """
    scaled_abs = beta * torch.abs(inputs - targets)
    if activate == 'tanh':
        focal_factor = torch.tanh(scaled_abs) ** gamma
    else:
        focal_factor = (2 * torch.sigmoid(scaled_abs) - 1) ** gamma

    loss = F.l1_loss(inputs, targets, reduction='none')
    loss *= focal_factor
    if weights is None:
        return loss.mean()
    loss *= weights.expand_as(loss)
    return loss.mean()
def weighted_huber_loss(inputs, targets, weights=None, beta=1.):
    """Return the (optionally weighted) mean Huber / smooth-L1 loss.

    Per element, with ``d = |inputs - targets|``::

        0.5 * d ** 2 / beta    if d < beta
        d - 0.5 * beta         otherwise

    Args:
        inputs: Predicted tensor.
        targets: Target tensor, broadcast against ``inputs`` by the subtraction.
        weights: Optional tensor ``expand_as``-ed to the loss shape and
            multiplied in element-wise. ``None`` means unweighted.
        beta: Scalar threshold between the quadratic and linear regimes.
            ``beta == 0`` reduces to plain L1 loss.

    Returns:
        0-dim tensor: the mean over every element, divided by the element
        count rather than the sum of ``weights``.
    """
    l1_loss = torch.abs(inputs - targets)
    if beta == 0:
        # The quadratic branch would divide by zero. torch.where only masks it
        # in the forward pass; in backward the (zero) upstream gradient is
        # still divided by beta, giving 0/0 = NaN gradients. With beta == 0 the
        # condition ``l1_loss < beta`` is never true, so the result is exactly
        # the linear branch: plain L1.
        loss = l1_loss
    else:
        cond = l1_loss < beta
        loss = torch.where(cond, 0.5 * l1_loss ** 2 / beta, l1_loss - 0.5 * beta)
    if weights is not None:
        # In-place multiply keeps loss's dtype regardless of weights' dtype.
        loss *= weights.expand_as(loss)
    return torch.mean(loss)