-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathfeature_map.py
183 lines (146 loc) · 5.46 KB
/
feature_map.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
# adapted from:
# https://github.com/sustcsonglin/flash-linear-attention/blob/main/fla/modules/feature_map.py
from __future__ import annotations

import functools
import math
from typing import Optional

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn

from mad.model.layers.ops.norm.rmsnorm import layer_norm_fn
def checkpoint(func):
    """Decorate ``func`` so that every call runs under activation checkpointing.

    The returned wrapper forwards all positional and keyword arguments to
    ``torch.utils.checkpoint.checkpoint(func, ...)``.

    Args:
        func: Callable to wrap.

    Returns:
        A wrapper that carries ``func``'s metadata (name, docstring, ...) and
        returns whatever ``torch.utils.checkpoint.checkpoint`` returns.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # PyTorch asks callers to pass ``use_reentrant`` explicitly. ``True``
        # is the value it falls back to when the argument is omitted, so this
        # keeps the existing behavior; callers can still override it.
        kwargs.setdefault("use_reentrant", True)
        return torch.utils.checkpoint.checkpoint(func, *args, **kwargs)
    return wrapper
@checkpoint
def flatten_diag_outer_product(x, y):
    """Return the flattened upper triangle (diagonal included) of ``x ⊗ y``.

    The outer product is taken over the last dimension of ``x`` and ``y``;
    leading dimensions are kept. The index grid is sized from the last
    dimension of the product, so ``x`` and ``y`` are expected to share the
    same last-dimension size ``N``.

    Args:
        x: Tensor whose last dimension is the left factor of the outer product.
        y: Tensor whose last dimension is the right factor of the outer product.

    Returns:
        Tensor with the same leading dimensions and a last dimension of size
        ``N * (N + 1) / 2`` holding the entries ``z[..., i, j]`` with ``i <= j``.
    """
    z = torch.einsum("...i,...j->...ij", x, y)
    n = z.size(-1)
    # Build the index tensors on the same device as the data so that the
    # gather does not need a host-to-device copy of the indices.
    rows, cols = torch.triu_indices(n, n, device=z.device)
    return z[..., rows, cols]
@checkpoint
def flatten_diag_outer_product_off1(x, y):
    """Split ``x ⊗ y`` into its strict upper triangle and its diagonal.

    The outer product is taken over the last dimension of ``x`` and ``y``;
    leading dimensions are kept. The index grid is sized from the last
    dimension of the product, so ``x`` and ``y`` are expected to share the
    same last-dimension size ``N``.

    Args:
        x: Tensor whose last dimension is the left factor of the outer product.
        y: Tensor whose last dimension is the right factor of the outer product.

    Returns:
        A tuple ``(off_diagonal, diagonal)``:
        ``off_diagonal`` has last dimension ``N * (N - 1) / 2`` (entries with
        ``i < j``) and ``diagonal`` has last dimension ``N`` (entries ``i == j``).
    """
    z = torch.einsum("...i,...j->...ij", x, y)
    n = z.size(-1)
    # Index tensors live on the data's device to avoid a host-to-device copy.
    rows, cols = torch.triu_indices(n, n, 1, device=z.device)
    diag = torch.arange(0, n, device=z.device)
    return z[..., rows, cols], z[..., diag, diag]
# https://arxiv.org/abs/2402.04347
class HedgehogFeatureMap(nn.Module):
    """Trainable softmax feature map (Hedgehog, https://arxiv.org/abs/2402.04347).

    A square linear layer, initialised to the identity, is applied to the
    input; the result ``h`` is expanded to ``[2h, -2h]`` along the last
    dimension and passed through a softmax over that dimension.

    Args:
        head_dim: Size of the last dimension of the input.
    """

    def __init__(
        self,
        head_dim: int
    ) -> None:
        super().__init__()
        # Trainable map
        self.layer = nn.Linear(head_dim, head_dim)
        self.init_weights_()

    def init_weights_(self) -> None:
        """Initialize trainable map as identity (identity weight, zero bias)."""
        with torch.no_grad():
            identity = torch.eye(*self.layer.weight.shape[-2:], dtype=torch.float)
            # ``.to(weight)`` matches the parameter's dtype and device.
            self.layer.weight.copy_(identity.to(self.layer.weight))
        nn.init.zeros_(self.layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` (last dim ``head_dim``) to features of last dim ``2 * head_dim``.

        Each output vector is a softmax, so it is non-negative and sums to one.
        """
        x = self.layer(x)  # presumably (batch, heads, length, head_dim)
        return torch.cat([2 * x, -2 * x], dim=-1).softmax(-1)
# https://arxiv.org/abs/2103.13076
class T2RFeatureMap(nn.Module):
    """Trainable ReLU feature map (T2R, https://arxiv.org/abs/2103.13076).

    Applies a linear layer followed by a ReLU.

    Args:
        head_dim: Size of the last dimension of the input.
        dot_dim: Size of the last dimension of the output. Defaults to
            ``head_dim`` when ``None``.
    """

    def __init__(
        self,
        head_dim: int,
        dot_dim: Optional[int] = None
    ) -> None:
        super().__init__()
        # Trainable map
        if dot_dim is None:
            dot_dim = head_dim
        self.layer = nn.Linear(head_dim, dot_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(layer(x))``; the result is element-wise non-negative."""
        return self.layer(x).relu()
# https://arxiv.org/abs/2102.11174
class DPFPFeatureMap(nn.Module):
    """Deterministic parameter-free projection (DPFP, https://arxiv.org/abs/2102.11174).

    The input is first expanded to ``[relu(x), relu(-x)]`` (last dimension
    ``2 * head_dim``), then multiplied element-wise with ``nu`` cyclically
    shifted copies of itself, giving a last dimension of
    ``2 * head_dim * nu``.

    Args:
        head_dim: Size of the last dimension of the input. Accepted for
            interface parity with the other feature maps; not used here.
        nu: Number of cyclic shifts (1..nu) to combine.
    """

    def __init__(
        self,
        head_dim: int,
        nu: int = 4
    ) -> None:
        super().__init__()
        self.nu = nu

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the DPFP features of ``x``; every entry is non-negative."""
        # ``(-x).relu()`` is the rectified negative part. Writing ``-x.relu()``
        # would negate relu(x) instead, because the method call binds tighter
        # than the unary minus.
        x = torch.cat([x.relu(), (-x).relu()], dim=-1)
        x_rolled = torch.cat(
            [x.roll(shifts=j, dims=-1) for j in range(1, self.nu + 1)], dim=-1
        )
        x_repeat = torch.cat([x] * self.nu, dim=-1)
        return x_repeat * x_rolled
class HadamardFeatureMap(nn.Module):
    """Feature map formed by the element-wise product of two linear projections.

    Args:
        head_dim: Size of the last dimension of both the input and the output.
    """

    def __init__(
        self,
        head_dim: int
    ) -> None:
        super().__init__()
        # Trainable map
        self.layer1 = nn.Linear(head_dim, head_dim)
        self.layer2 = nn.Linear(head_dim, head_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``layer1(x) * layer2(x)`` (Hadamard product)."""
        return self.layer1(x) * self.layer2(x)
class LearnableOuterProductFeatureMap(nn.Module):
    """Feature map built from the outer product of two bias-free projections.

    The input is projected twice to ``feature_dim``; the upper triangle
    (diagonal included) of the outer product of the two projections is
    returned, flattened along the last dimension.

    Args:
        head_dim: Size of the last dimension of the input.
        feature_dim: Size of each projection; the output has last dimension
            ``feature_dim * (feature_dim + 1) / 2``.
    """

    def __init__(
        self,
        head_dim: int,
        feature_dim: int
    ) -> None:
        super().__init__()
        # Trainable map
        self.layer1 = nn.Linear(head_dim, feature_dim, bias=False)
        self.layer2 = nn.Linear(head_dim, feature_dim, bias=False)
        # NOTE(review): stored but not applied in ``forward``; kept so that
        # external readers of this attribute keep working.
        self.normalizer = feature_dim ** -0.5

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the flattened upper-triangular outer product of the projections."""
        return flatten_diag_outer_product(self.layer1(x), self.layer2(x))
class TaylorFeatureMap(nn.Module):
    """Second-order polynomial (Taylor-style) feature map without parameters.

    For an input with last dimension ``d = head_dim`` the output concatenates,
    along the last dimension:

    * a constant ``1`` (size 1),
    * ``x / d ** 0.25`` (size ``d``),
    * the squared terms ``x_i * x_i / (sqrt(d) * sqrt(2))`` (size ``d``),
    * the cross terms ``x_i * x_j / sqrt(d)`` for ``i < j``
      (size ``d * (d - 1) / 2``).

    Args:
        head_dim: Size of the last dimension of the input.
    """

    def __init__(
        self,
        head_dim: int
    ) -> None:
        super().__init__()
        self.head_dim = head_dim
        self.r2 = math.sqrt(2)
        self.rd = math.sqrt(self.head_dim)   # sqrt(d)
        self.rrd = math.sqrt(self.rd)        # d ** 0.25

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the concatenated constant, linear and quadratic features of ``x``."""
        # x2_1: strict upper-triangular products, x2_2: diagonal (squared) terms.
        x2_1, x2_2 = flatten_diag_outer_product_off1(x, x)
        return torch.cat(
            [
                torch.ones_like(x[..., 0:1]),
                x / self.rrd,
                x2_2 / (self.rd * self.r2),
                x2_1 / self.rd,
            ],
            dim=-1,
        )
class RebasedFeatureMap(nn.Module):
    """Quadratic feature map applied after an optional learnable affine/normalisation.

    The input is first transformed according to the constructor flags, then
    (unless ``flatten=False``) expanded into the scaled squared terms and
    cross terms of the transformed vector.

    Args:
        head_dim: Size of the last dimension of the input.
        use_gamma: Create a learnable per-feature scale, initialised to ones.
        use_beta: Create a learnable per-feature shift, initialised to zeros.
        normalize: Apply layer normalisation before the quadratic expansion.

    Supported flag combinations in ``forward``:

    * ``normalize`` with both gamma and beta: ``layer_norm_fn(x, gamma, beta)``.
    * ``normalize`` otherwise: ``F.layer_norm`` with whichever of gamma/beta exist.
    * no ``normalize``, gamma and beta: ``beta + x * gamma``.
    * no ``normalize``, gamma only: ``x * gamma``.

    Any other combination (no ``normalize`` and no gamma) raises ``RuntimeError``.
    """

    def __init__(
        self,
        head_dim: int,
        use_gamma: Optional[bool] = True,
        use_beta: Optional[bool] = True,
        normalize: Optional[bool] = True
    ) -> None:
        super().__init__()
        self.head_dim = head_dim
        self.use_gamma = use_gamma
        self.use_beta = use_beta
        self.normalize = normalize
        # Left as ``None`` when the corresponding flag is off.
        self.gamma = None
        self.beta = None
        if use_gamma:
            self.gamma = nn.Parameter(torch.ones(head_dim))
        if use_beta:
            self.beta = nn.Parameter(torch.zeros(head_dim))

    def forward(self, x: torch.Tensor, flatten: Optional[bool] = True) -> torch.Tensor:
        """Transform ``x`` and optionally expand it into quadratic features.

        Args:
            x: Input whose last dimension is ``head_dim``.
            flatten: When falsy, return the transformed ``x`` without the
                quadratic expansion.

        Returns:
            The transformed ``x`` if ``flatten`` is falsy; otherwise the
            concatenation of the squared terms scaled by ``head_dim ** -0.5``
            and the cross terms (``i < j``) scaled by ``(2 / head_dim) ** 0.5``.

        Raises:
            RuntimeError: If the flag combination is not supported.
        """
        if self.use_beta and self.use_gamma and self.normalize:
            # Project helper; presumably a fused layer norm with weight/bias.
            x = layer_norm_fn(x, self.gamma, self.beta)
        elif self.normalize:
            # gamma and/or beta may be None here; F.layer_norm accepts that.
            x = F.layer_norm(x, (self.head_dim,), self.gamma, self.beta)
        elif self.use_gamma and self.use_beta:
            x = torch.addcmul(self.beta, x, self.gamma)
        elif self.use_gamma:
            x = x.mul(self.gamma)
        else:
            raise RuntimeError(f"Not supported combination of `use_gamma`, `use_beta` and `normalize`, "
                               f"which is currentlt set as (`{self.use_gamma}`, `{self.use_beta}`, `{self.normalize}`)")
        if not flatten:
            return x
        # x2_1: strict upper-triangular products, x2_2: diagonal (squared) terms.
        x2_1, x2_2 = flatten_diag_outer_product_off1(x, x)
        # Rebased uses the learnable parameters above to approximate any quadratic function.
        return torch.cat([x2_2 * self.head_dim ** -0.5, x2_1 * (2 / self.head_dim) ** 0.5], dim=-1)