forked from KaihuaTang/Long-Tailed-Recognition.pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathCausalNormClassifier.py
83 lines (69 loc) · 3.52 KB
/
CausalNormClassifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
import torch.nn as nn
from utils import *
from os import path
import math
class Causal_Norm_Classifier(nn.Module):
    """Multi-head normalized linear classifier with optional confounder removal.

    The feature dimension is split into ``num_head`` equal heads of size
    ``feat_dim // num_head``. Per head, input features are L2-normalized and
    class weights are divided by ``||w|| + gamma``. Logits are the product of
    the normalized features (scaled by ``tau / num_head``) with the normalized
    weights. In eval mode with ``use_effect=True``, the effect of a confounder
    direction ``embed`` is removed: per head, ``alpha`` times the projection of
    the normalized feature onto the normalized embed is subtracted before the
    logits are computed, and the per-head logits are summed.

    Args:
        num_classes: number of output classes (rows of ``self.weight``).
        feat_dim: input feature dimension; must be divisible by ``num_head``.
        use_effect: subtract the confounder term at eval time.
        num_head: number of heads the feature dimension is split into.
        tau: logit temperature; logits are scaled by ``tau / num_head``.
        alpha: strength of the confounder removal at eval time.
        gamma: additive constant in the weight-normalization denominator.

    Raises:
        ValueError: if ``num_head`` is not positive or does not divide ``feat_dim``.
    """

    def __init__(self, num_classes=1000, feat_dim=2048, use_effect=True, num_head=2, tau=16.0, alpha=3.0, gamma=0.03125, *args):
        super(Causal_Norm_Classifier, self).__init__()
        if num_head <= 0 or feat_dim % num_head != 0:
            raise ValueError(
                'feat_dim ({}) must be divisible by a positive num_head ({})'.format(feat_dim, num_head))
        weight = torch.empty(num_classes, feat_dim)
        # Keep the weight on the GPU when one exists (as before), but do not
        # crash on CPU-only machines.
        if torch.cuda.is_available():
            weight = weight.cuda()
        self.weight = nn.Parameter(weight, requires_grad=True)
        self.scale = tau / num_head   # 16.0 / num_head by default
        self.norm_scale = gamma       # 1.0 / 32.0 by default
        self.alpha = alpha            # 3.0 by default
        self.num_head = num_head
        self.head_dim = feat_dim // num_head
        self.use_effect = use_effect
        self.reset_parameters(self.weight)
        # Not used by forward(); kept so existing attribute access keeps working.
        self.relu = nn.ReLU(inplace=True)

    def reset_parameters(self, weight):
        """Fill ``weight`` in place from U(-1/sqrt(n), 1/sqrt(n)), n = weight.size(1)."""
        stdv = 1. / math.sqrt(weight.size(1))
        weight.data.uniform_(-stdv, stdv)

    def forward(self, x, label, embed):
        """Compute class logits for a batch of features.

        Args:
            x: 2-D feature tensor of shape (batch, feat_dim).
            label: unused; kept for interface compatibility.
            embed: NumPy array or tensor with ``feat_dim`` elements. It is read
                only in eval mode when ``use_effect`` is set, and is cast to
                ``x``'s device and dtype.

        Returns:
            tuple: ``(logits, None)`` with logits of shape (batch, num_classes).
        """
        normed_w = self.multi_head_call(self.causal_norm, self.weight, weight=self.norm_scale)
        normed_x = self.multi_head_call(self.l2_norm, x)
        y = torch.mm(normed_x * self.scale, normed_w.t())

        # Remove the effect of the confounder c (the embed direction) at test time.
        if (not self.training) and self.use_effect:
            # as_tensor accepts NumPy arrays and tensors. Matching x's dtype
            # avoids a float64/float32 mismatch in torch.mm below.
            self.embed = torch.as_tensor(embed).reshape(1, -1).to(device=x.device, dtype=x.dtype)
            normed_c = self.multi_head_call(self.l2_norm, self.embed)
            x_list = torch.split(normed_x, self.head_dim, dim=1)
            c_list = torch.split(normed_c, self.head_dim, dim=1)
            w_list = torch.split(normed_w, self.head_dim, dim=1)
            output = []
            for nx, nc, nw in zip(x_list, c_list, w_list):
                cos_val, _ = self.get_cos_sin(nx, nc)
                # nc is unit-norm per head, so cos_val * nc is nx's projection on nc.
                y0 = torch.mm((nx - cos_val * self.alpha * nc) * self.scale, nw.t())
                output.append(y0)
            y = sum(output)
        return y, None

    def get_cos_sin(self, x, y, eps=1e-12):
        """Return row-wise cosine and sine of the angle between ``x`` and ``y``.

        Norms are clamped to ``eps`` so zero rows yield 0 instead of NaN, and
        ``1 - cos^2`` is clamped at 0 before the square root to absorb rounding.
        """
        x_norm = torch.norm(x, 2, 1, keepdim=True).clamp_min(eps)
        y_norm = torch.norm(y, 2, 1, keepdim=True).clamp_min(eps)
        cos_val = (x * y).sum(-1, keepdim=True) / x_norm / y_norm
        sin_val = (1 - cos_val * cos_val).clamp_min(0).sqrt()
        return cos_val, sin_val

    def multi_head_call(self, func, x, weight=None):
        """Apply ``func`` to each head of the 2-D tensor ``x`` and concatenate.

        ``x`` is split along dim 1 into ``self.num_head`` chunks of
        ``self.head_dim`` columns. When ``weight`` is not None it is passed as
        the second argument to ``func``.

        Raises:
            ValueError: if ``x`` is not 2-D or its second dimension is not
                ``num_head * head_dim``.
        """
        if x.dim() != 2:
            raise ValueError('expected a 2-D tensor, got shape {}'.format(tuple(x.shape)))
        if x.shape[1] != self.head_dim * self.num_head:
            raise ValueError('expected {} columns ({} heads x {}), got {}'.format(
                self.head_dim * self.num_head, self.num_head, self.head_dim, x.shape[1]))
        x_list = torch.split(x, self.head_dim, dim=1)
        # Test against None, not truthiness, so weight=0.0 (gamma=0) is still forwarded.
        if weight is not None:
            y_list = [func(item, weight) for item in x_list]
        else:
            y_list = [func(item) for item in x_list]
        return torch.cat(y_list, dim=1)

    def l2_norm(self, x, eps=1e-12):
        """L2-normalize each row of ``x``; rows with norm below ``eps`` become ~0, not NaN."""
        return x / torch.norm(x, 2, 1, keepdim=True).clamp_min(eps)

    def capsule_norm(self, x):
        """Squash each row of ``x`` to length ``||x|| / (1 + ||x||)``.

        Not called elsewhere in this class.
        """
        norm = torch.norm(x.clone(), 2, 1, keepdim=True)
        normed_x = (norm / (1 + norm)) * (x / norm)
        return normed_x

    def causal_norm(self, x, weight):
        """Divide each row of ``x`` by ``||row|| + weight`` (``weight`` is gamma)."""
        norm = torch.norm(x, 2, 1, keepdim=True)
        normed_x = x / (norm + weight)
        return normed_x
def create_model(feat_dim, num_classes=1000, stage1_weights=False, dataset=None, log_dir=None, test=False, use_effect=True, num_head=None, tau=None, alpha=None, gamma=None, *args):
    """Build a ``Causal_Norm_Classifier`` and log its configuration.

    Any of ``num_head``, ``tau``, ``alpha`` or ``gamma`` left as ``None``
    falls back to the classifier's own default. Forwarding ``None`` would
    otherwise break it (``tau / num_head`` with None, or a missing gamma).
    ``stage1_weights``, ``dataset``, ``log_dir`` and ``test`` are accepted
    for interface compatibility but are not used here.

    Returns:
        Causal_Norm_Classifier: the constructed classifier.
    """
    print('Loading Causal Norm Classifier with use_effect: {}, num_head: {}, tau: {}, alpha: {}, gamma: {}.'.format(str(use_effect), num_head, tau, alpha, gamma))
    # Only forward hyper-parameters the caller actually set.
    overrides = {
        name: value
        for name, value in (('num_head', num_head), ('tau', tau), ('alpha', alpha), ('gamma', gamma))
        if value is not None
    }
    clf = Causal_Norm_Classifier(num_classes, feat_dim, use_effect=use_effect, **overrides)
    return clf