taid.py
from typing import Optional

import torch
import torch.nn.functional as F
from lightning import LightningModule

from .base import DistilLoss
from .fkl import forward_kl


class TAID(DistilLoss):
    """Distillation loss computed against an intermediate teacher distribution.

    The intermediate distribution is a softmax over logits interpolated between
    the (detached) student and the teacher with a coefficient t that is annealed
    from t_start toward t_end over training, optionally with an adaptive,
    loss-driven increment on top of the linear schedule.
    """

    def __init__(
        self,
        t_start: float = 0.4,
        t_end: float = 1.0,
        alpha: float = 5e-4,
        beta: float = 0.99,
        disable_adaptive: bool = False,
    ):
        super().__init__()
        # validation
        assert 0.0 <= t_start < 1.0
        assert 0.0 < t_end <= 1.0
        assert 0.0 <= alpha <= 1.0
        self.t_start = t_start
        self.t_end = t_end
        self.alpha = alpha
        self.beta = beta
        self.disable_adaptive = disable_adaptive
        # t, the previous loss, and the momentum term are kept as buffers so
        # they move with the module and are saved in checkpoints
        self.register_buffer(
            "t", torch.tensor(t_start, device="cuda", dtype=torch.float32)
        )
        self.register_buffer(
            "prev_loss", torch.tensor(float("inf"), device="cuda", dtype=torch.float32)
        )
        self.register_buffer(
            "momentum", torch.zeros([], device="cuda", dtype=torch.float32)
        )

    def update_t(
        self, loss: torch.Tensor, global_step: int, num_train_steps: int
    ) -> Optional[torch.Tensor]:
        # first call: only record the loss, no update yet
        if torch.isinf(self.prev_loss):
            self.prev_loss = loss
            return None
        # Calculate relative change rate
        relative_change = (self.prev_loss - loss) / (self.prev_loss + 1e-15)
        # Update momentum
        self.momentum = self.beta * self.momentum + (1 - self.beta) * relative_change
        # Calculate adaptive delta
        adaptive_delta = torch.sigmoid(self.momentum)
        # Update t (ensure monotonic increase)
        progress = global_step / num_train_steps
        t_target = self.t_start + (self.t_end - self.t_start) * progress
        delta_t = self.alpha * adaptive_delta * (1 - self.t)
        t = (
            min(self.t_end, max(t_target, self.t + delta_t))
            if not self.disable_adaptive
            else t_target
        )
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t, device=self.t.device, dtype=self.t.dtype)
        self.t = t
        self.prev_loss = loss
        return delta_t
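
    # Illustration (not in the original file): one hypothetical update with
    # made-up numbers, using the defaults above (t_start=0.4, t_end=1.0,
    # alpha=5e-4, beta=0.99). Assume prev_loss=2.0, loss=1.9, momentum=0,
    # t=0.4, and progress = 100/1000 = 0.1:
    #   relative_change = (2.0 - 1.9) / 2.0            = 0.05
    #   momentum        = 0.99*0 + 0.01*0.05           = 5e-4
    #   adaptive_delta  = sigmoid(5e-4)                ≈ 0.5001
    #   delta_t         = 5e-4 * 0.5001 * (1 - 0.4)    ≈ 1.5e-4
    #   t_target        = 0.4 + 0.6*0.1                = 0.46
    #   t               = min(1.0, max(0.46, 0.40015)) = 0.46
    # With these made-up numbers the linear schedule t_target is the binding
    # term; the adaptive increment only takes over when alpha * sigmoid(momentum)
    # * (1 - t) exceeds the per-step increase of the schedule.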

    def compute_loss(
        self,
        logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        mask: torch.Tensor,
    ):
        # intermediate teacher distribution: softmax over logits interpolated
        # between the detached student and the teacher at the current t
        p_t = (1 - self.t) * logits.detach() + self.t * teacher_logits
        p_t = F.softmax(p_t, dim=-1, dtype=torch.float32)
        distil_loss = forward_kl(
            logits=logits,
            teacher_logits=teacher_logits,
            mask=mask,
            teacher_probs=p_t,
        )
        return distil_loss

    def forward(
        self,
        lightning_module: LightningModule,
        logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        mask: torch.Tensor,
        **kwargs,
    ) -> dict[str, Optional[torch.Tensor]]:
        # compute kd loss
        loss = self.compute_loss(logits, teacher_logits, mask)
        # update t
        delta_t = self.update_t(
            loss.detach().clone(),
            global_step=lightning_module.trainer.global_step,
            num_train_steps=lightning_module.trainer.estimated_stepping_batches,
        )
        loss_dict = {
            "distil_loss": loss,
            "tiki_t": self.t,
            "delta_t": delta_t,
        }
        return loss_dict
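

# --- Usage sketch (illustrative only, not part of the original module) ---
# A minimal way to exercise the loss outside a Lightning training loop: call
# compute_loss and update_t directly instead of forward(), which expects a
# LightningModule whose trainer provides global_step and the estimated number
# of training steps. The shapes, vocabulary size, and step counts below are
# made up, and a CUDA device is assumed because the buffers are registered on
# "cuda".
#
#   loss_fn = TAID(t_start=0.4, t_end=1.0, alpha=5e-4, beta=0.99)
#   logits = torch.randn(2, 16, 32000, device="cuda")          # student logits
#   teacher_logits = torch.randn(2, 16, 32000, device="cuda")  # teacher logits
#   mask = torch.ones(2, 16, dtype=torch.bool, device="cuda")  # assumed token mask
#   loss = loss_fn.compute_loss(logits, teacher_logits, mask)
#   delta_t = loss_fn.update_t(loss.detach(), global_step=1, num_train_steps=1000)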