-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathlosses.py
51 lines (35 loc) · 1.28 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
from torch.nn.modules.loss import _Loss
from utils import low2high112
def completion_network_loss(input, output, mask):
    """Return the masked L1 reconstruction error, normalised by batch size.

    ``output`` and ``input`` are each multiplied element-wise by ``mask``.
    The absolute differences of the masked tensors are summed over every
    element and divided by ``input.size(0)``; the first dimension is used
    as the batch size.

    Args:
        input: Reference tensor.
        output: Predicted tensor, broadcast-compatible with ``input``.
        mask: Weighting tensor applied to both ``input`` and ``output``.

    Returns:
        A 0-dim tensor holding the batch-normalised summed absolute error.
    """
    batch_size = input.size(0)
    masked_pred = output * mask
    masked_ref = input * mask
    total_abs_err = (masked_pred - masked_ref).abs().sum()
    return total_abs_err / batch_size
def noise_loss(V, img1, img2):
    """Return the mean absolute difference between V's features of two images.

    ``V`` is called on ``img1`` and then on ``img2``. Each call must return
    exactly three values; only the first value of each is compared.

    Args:
        V: Callable returning a 3-tuple whose first item is a tensor.
        img1: First input passed to ``V``.
        img2: Second input passed to ``V``.

    Returns:
        A 0-dim tensor: the mean of ``|feat(img1) - feat(img2)|``.
    """
    first_feat, _, _ = V(img1)
    second_feat, _, _ = V(img2)
    return (first_feat - second_feat).abs().mean()
class ContextLoss(_Loss):
    """Masked L1 distance between generated and reference images.

    The absolute error between ``mask * gen`` and ``mask * images`` is
    summed over all elements and divided by ``gen.size(0)``; the first
    dimension is used as the batch size.
    """

    def forward(self, mask, gen, images):
        """Compute the batch-normalised masked L1 loss.

        Args:
            mask: Weighting tensor applied to both ``gen`` and ``images``.
            gen: Generated tensor.
            images: Reference tensor, broadcast-compatible with ``gen``.

        Returns:
            A 0-dim tensor.
        """
        batch_size = gen.size(0)
        masked_gen = mask * gen
        masked_real = mask * images
        return (masked_gen - masked_real).abs().sum() / batch_size
class CrossEntropyLoss(_Loss):
    """Cross-entropy of target weights against predicted probabilities.

    Computes ``sum(-gt * log(out + 1e-7)) / out.size(0)``. Both inputs are
    cast to float first. The ``1e-7`` offset keeps ``log`` finite when a
    prediction is exactly zero.
    """

    def forward(self, out, gt):
        """Return the batch-normalised cross-entropy.

        Args:
            out: Predicted values. The log is taken directly, so this
                presumably holds probabilities rather than logits; confirm
                with the caller.
            gt: Target weights, broadcast-compatible with ``out``.

        Returns:
            A 0-dim float tensor.
        """
        batch_size = out.size(0)
        log_pred = torch.log(out.float() + 1e-7)
        per_element = -(gt.float() * log_pred)
        return per_element.sum() / batch_size
class FeatLoss(_Loss):
    """Feature-matching L1 loss over a sequence of paired feature tensors.

    For every index ``i`` of ``fake_feat``, the mean absolute difference
    between ``fake_feat[i]`` and ``real_feat[i]`` is computed. These
    per-pair means are summed into a one-element tensor.
    """

    def forward(self, fake_feat, real_feat):
        """Return the sum of per-pair mean absolute errors.

        Args:
            fake_feat: Sequence of tensors.
            real_feat: Sequence of tensors indexed in lockstep with
                ``fake_feat``. It must have at least ``len(fake_feat)``
                entries, otherwise an ``IndexError`` is raised; extra
                entries are ignored.

        Returns:
            A tensor of shape ``(1,)``. It lives on the device of
            ``fake_feat[0]``. When ``fake_feat`` is empty, it is zero and
            lives on the default device.
        """
        # Allocate the accumulator next to the inputs instead of forcing
        # CUDA. The previous hard-coded ``.cuda()`` crashed on CPU-only
        # machines and mismatched devices when features sat on another GPU.
        device = fake_feat[0].device if len(fake_feat) > 0 else None
        loss = torch.zeros(1, device=device)
        for i, fake in enumerate(fake_feat):
            # Index real_feat rather than zip, so a shorter real_feat still
            # raises IndexError instead of being silently truncated.
            loss += torch.mean(torch.abs(fake - real_feat[i]))
        return loss