-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathloss.py
More file actions
35 lines (29 loc) · 1.59 KB
/
loss.py
File metadata and controls
35 lines (29 loc) · 1.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
class ClusterTripletLoss(torch.nn.Module):
    """Triplet margin loss whose positives/negatives are drawn from centroids.

    Every row of ``input_features`` is used as an anchor.  Its positive is
    the centroid selected as "closest" and its negative is the centroid
    selected as "furthest", using a per-feature vote:

    1. The element-wise squared difference between the sample and every
       centroid is computed.
    2. For each feature column, the index of the centroid with the smallest
       (resp. largest) squared difference is taken.
    3. The most frequent index across the feature columns (the mode) picks
       the positive (resp. negative) centroid.

    NOTE(review): this vote is not the same as picking the centroid with the
    smallest/largest Euclidean distance; a centroid that wins most feature
    columns is chosen even if another one is nearer overall.  The voting rule
    is kept as-is here -- confirm that it is the intended behaviour.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input_features, centroids):
        """Return the mean triplet margin loss for the batch.

        Args:
            input_features: Tensor of shape ``(n_samples, n_features)``; each
                row is an anchor.
            centroids: Tensor of shape ``(k_clusters, n_features)``.

        Returns:
            The result of ``torch.nn.functional.triplet_margin_loss`` with
            ``margin=1`` and ``swap=True`` (default mean reduction), i.e. a
            scalar tensor.

        Raises:
            AssertionError: If the feature dimensions of the two inputs
                differ.
        """
        assert input_features.shape[1] == centroids.shape[1], "Dimensions Mismatch"

        # Choosing indices is not differentiable, so keep the (potentially
        # large) difference tensor out of the autograd graph.
        with torch.no_grad():
            # Broadcast to (n_samples, k_clusters, n_features): the squared
            # difference of every sample against every centroid, per feature.
            # This temporary scales with n_samples * k_clusters * n_features.
            sq_diff = (input_features.unsqueeze(1) - centroids.unsqueeze(0)).pow(2)

            # Per feature column, which centroid is nearest / furthest.
            _, nearest_per_feature = sq_diff.min(dim=1)
            _, furthest_per_feature = sq_diff.max(dim=1)

            # Majority vote over the feature columns -> one centroid index
            # per sample.
            closest_idx, _ = nearest_per_feature.mode(dim=-1)
            furthest_idx, _ = furthest_per_feature.mode(dim=-1)

        # Indexing keeps the results on the centroids' own device and dtype
        # and lets gradients flow back into ``centroids`` when it requires
        # them.
        positives = centroids[closest_idx]
        negatives = centroids[furthest_idx]

        return torch.nn.functional.triplet_margin_loss(
            input_features, positives, negatives, margin=1, swap=True
        )