# stocBiO.py
import torch
from torch.autograd import grad
from torch.nn import functional as F


def mstsa(previous_update, eta, val_data_list, args, params, params_next, hparams, hparams_old, out_f, reg_f):
    # Momentum-assisted hypergradient estimate: mix the fresh stochastic hypergradient at the new
    # iterate with a recursive correction built from the previous iterate and the previous update.
    grad1 = stocbio(params_next, hparams, val_data_list, args, out_f, reg_f)
    grad2 = stocbio(params, hparams_old, val_data_list, args, out_f, reg_f)
    outer_update = eta*grad1 + (1-eta)*(previous_update + grad1 - grad2)
    return outer_update
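

def mstsa_outer_loop_sketch(args, params, hparams, out_f, reg_f, sample_val_data, inner_step):
    # Hedged usage sketch (not part of the original file): one plausible way to drive mstsa() in an
    # outer loop. The helpers `sample_val_data(args)` (returns a (data_list, labels_list) tuple of
    # minibatches) and `inner_step(params, hparams)` (runs a lower-level update and returns new
    # params with requires_grad=True) are assumptions, as are args.num_outer_iters, args.eta_momentum
    # and args.outer_lr; adapt the names to the actual training script.
    previous_update = torch.zeros_like(hparams)
    hparams_old = hparams.clone().detach().requires_grad_(True)
    for k in range(args.num_outer_iters):
        val_data_list = sample_val_data(args)
        params_next = inner_step(params, hparams)
        # On the first iteration eta = 1 so the estimate reduces to the plain stocbio hypergradient.
        eta = 1.0 if k == 0 else args.eta_momentum
        outer_update = mstsa(previous_update, eta, val_data_list, args,
                             params, params_next, hparams, hparams_old, out_f, reg_f)
        hparams_old = hparams
        hparams = (hparams - args.outer_lr * outer_update).detach().requires_grad_(True)
        params, previous_update = params_next, outer_update
    return params, hparams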


def stable(args, train_data_list, params_old, params, hparams, hparams_old, H_xy_old, H_yy_old,
           tao, beta_k, alpha_k, out_f, reg_f):
    data_list, labels_list = train_data_list
    # Gradient of the upper-level loss w.r.t. the lower-level variable (the model parameters).
    output = out_f(data_list[0], params)
    grad_fy = gradient_fy(args, labels_list[0], params, data_list[0], output)
    grad_fy = torch.reshape(grad_fy, [-1])
    # Lower-level gradients at the current and the previous iterate.
    output = out_f(data_list[1], params)
    grad_gy = gradient_gy(args, labels_list[1], params, data_list[1], hparams, output, reg_f)
    grad_gy = torch.reshape(grad_gy, [-1])
    output = out_f(data_list[2], params_old)
    grad_gy_old = gradient_gy(args, labels_list[2], params_old, data_list[2], hparams_old, output, reg_f)
    grad_gy_old = torch.reshape(grad_gy_old, [-1])
    # Differentiate each component of the lower-level gradient to obtain stochastic estimates of
    # the mixed derivative (w.r.t. hparams) and the Hessian (w.r.t. params), row by row.
    h_xy_k0, h_xy_k1, h_yy_k0, h_yy_k1 = [], [], [], []
    for index in range(grad_gy.size()[0]):
        h_xy_k0.append(torch.autograd.grad(grad_gy_old[index], hparams_old, retain_graph=True)[0])
        h_xy_k1.append(torch.autograd.grad(grad_gy[index], hparams, retain_graph=True)[0])
        h_yy_k0.append(torch.autograd.grad(grad_gy_old[index], params_old, retain_graph=True)[0])
        h_yy_k1.append(torch.autograd.grad(grad_gy[index], params, retain_graph=True)[0])
    h_xy_k0, h_xy_k1, h_yy_k0, h_yy_k1 = (torch.stack(h_xy_k0), torch.stack(h_xy_k1),
                                          torch.stack(h_yy_k0), torch.stack(h_yy_k1))
    # 7850 matches the flattened parameter dimension of the MNIST model used in these experiments
    # (hard-coded here).
    h_yy_k0 = torch.reshape(h_yy_k0, [7850, -1])
    h_yy_k1 = torch.reshape(h_yy_k1, [7850, -1])
    # Recursive (moving-average) estimates of the mixed second-order term and the Hessian.
    H_xy = (1-tao)*(H_xy_old - torch.t(h_xy_k0)) + torch.t(h_xy_k1)
    H_yy = (1-tao)*(H_yy_old - torch.t(h_yy_k0)) + torch.t(h_yy_k1)
    # Hypergradient estimate used for the upper-level (hyperparameter) step.
    x_update = -torch.matmul(torch.matmul(H_xy, torch.inverse(H_yy)), grad_fy)
    params_shape = params.size()
    # Lower-level step with a correction term that tracks the change in the hyperparameters.
    temp = torch.matmul(torch.matmul(torch.inverse(H_yy), torch.t(H_xy)), (-x_update*alpha_k))
    params_new = torch.reshape(params, [-1]) - beta_k*grad_gy - temp
    params_new = torch.reshape(params_new, params_shape)
    return params, params_new, x_update, H_xy, H_yy
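

def stable_outer_loop_sketch(args, params, hparams, out_f, reg_f, sample_train_data):
    # Hedged usage sketch (not part of the original file): one plausible way to iterate stable().
    # `sample_train_data(args)` is assumed to return a (data_list, labels_list) tuple with three
    # minibatches; args.num_iters, args.tau, args.beta and args.alpha are illustrative names.
    # The second-order estimates are warm-started at zero with tao = 1 on the first iteration,
    # so H_xy/H_yy then equal the fresh stochastic estimates.
    dim_y, dim_x = params.numel(), hparams.numel()
    H_xy_old = torch.zeros(dim_x, dim_y)
    H_yy_old = torch.zeros(dim_y, dim_y)
    params_old, hparams_old = params, hparams
    for k in range(args.num_iters):
        train_data_list = sample_train_data(args)
        tao = 1.0 if k == 0 else args.tau
        params_old, params, x_update, H_xy_old, H_yy_old = stable(
            args, train_data_list, params_old, params, hparams, hparams_old,
            H_xy_old, H_yy_old, tao, args.beta, args.alpha, out_f, reg_f)
        # Descend the upper-level objective with the same step size alpha_k that the correction
        # term inside stable() assumes (here hparams is taken to be a flat vector).
        hparams_old = hparams
        hparams = (hparams - args.alpha * x_update.reshape(hparams.size())).detach().requires_grad_(True)
    return params, hparams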


def stocbio(params, hparams, val_data_list, args, out_f, reg_f):
    data_list, labels_list = val_data_list
    # Fy_gradient: gradient of the upper-level (validation) loss w.r.t. the model parameters.
    output = out_f(data_list[0], params)
    Fy_gradient = gradient_fy(args, labels_list[0], params, data_list[0], output)
    v_0 = torch.unsqueeze(torch.reshape(Fy_gradient, [-1]), 1).detach()
    # Hessian: approximate the Hessian-inverse-vector product with a truncated Neumann series,
    # using Hessian-vector products obtained by differentiating the lower-level gradient.
    z_list = []
    output = out_f(data_list[1], params)
    Gy_gradient = gradient_gy(args, labels_list[1], params, data_list[1], hparams, output, reg_f)
    G_gradient = torch.reshape(params, [-1]) - args.eta*torch.reshape(Gy_gradient, [-1])
    # G_gradient = torch.reshape(params[0], [-1]) - args.eta*torch.reshape(Gy_gradient, [-1])
    for _ in range(args.hessian_q):
        # for _ in range(args.K):
        Jacobian = torch.matmul(G_gradient, v_0)
        v_new = torch.autograd.grad(Jacobian, params, retain_graph=True)[0]
        v_0 = torch.unsqueeze(torch.reshape(v_new, [-1]), 1).detach()
        z_list.append(v_0)
    v_Q = args.eta*v_0 + torch.sum(torch.stack(z_list), dim=0)
    # Gyx_gradient: mixed second-order term; differentiating (grad_g_y . v_Q) w.r.t. the
    # hyperparameters gives the Jacobian-vector product whose negative is the hypergradient.
    output = out_f(data_list[2], params)
    Gy_gradient = gradient_gy(args, labels_list[2], params, data_list[2], hparams, output, reg_f)
    Gy_gradient = torch.reshape(Gy_gradient, [-1])
    Gyx_gradient = torch.autograd.grad(torch.matmul(Gy_gradient, v_Q.detach()), hparams, retain_graph=True)[0]
    outer_update = -Gyx_gradient
    return outer_update
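

def stocbio_outer_loop_sketch(args, params, hparams, out_f, reg_f, sample_val_data, inner_update):
    # Hedged usage sketch (not part of the original file): stocBiO alternates several lower-level
    # gradient steps on params with one hypergradient step on hparams computed by stocbio().
    # `sample_val_data(args)` and `inner_update(params, hparams, args)` are assumed helpers (the
    # latter should return params with requires_grad=True); args.epochs, args.T and
    # args.outer_update_lr are illustrative names used only in this sketch.
    for epoch in range(args.epochs):
        # Lower level: approximately solve min_y g(x, y) for the current hyperparameters.
        for _ in range(args.T):
            params = inner_update(params, hparams, args)
        # Upper level: stochastic hypergradient via the Neumann-series Hessian-inverse estimate.
        val_data_list = sample_val_data(args)
        outer_update = stocbio(params, hparams, val_data_list, args, out_f, reg_f)
        hparams = (hparams - args.outer_update_lr * outer_update).detach().requires_grad_(True)
    return params, hparams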


def gradient_fy(args, labels, params, data, output):
    # Upper-level objective: plain cross-entropy on the validation (clean) data.
    loss = F.cross_entropy(output, labels)
    grad = torch.autograd.grad(loss, params)[0]
    return grad


def gradient_gy(args, labels_cp, params, data, hparams, output, reg_f):
    # For MNIST data hyper-cleaning experiments: keep per-sample losses so that reg_f can
    # weight each (possibly corrupted) training example with its own hyperparameter.
    loss = F.cross_entropy(output, labels_cp, reduction='none')
    # For NewsGroup l2reg experiments:
    # loss = F.cross_entropy(output, labels_cp)
    loss_regu = reg_f(params, hparams, loss)
    grad = torch.autograd.grad(loss_regu, params, create_graph=True)[0]
    return grad
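

def example_out_f(data, params):
    # Hedged sketch (not part of the original file) of an `out_f` consistent with the calls above:
    # a linear 10-way classifier whose weights and biases are packed into a single `params` tensor
    # of shape [785, 10] (784 input features plus a bias row), which would match the hard-coded
    # 7850-dimensional reshape in stable(). The model used in the actual experiments may differ.
    weights, bias = params[:-1, :], params[-1, :]
    return torch.matmul(data, weights) + bias


def example_reg_f(params, hparams, loss):
    # Hedged sketch of a `reg_f` for data hyper-cleaning: each per-sample loss is weighted by a
    # sigmoid of its hyperparameter (assuming `hparams` holds one entry per example in the batch),
    # and a small l2 penalty keeps the lower-level problem strongly convex. The weighting scheme
    # and the 0.001 coefficient are illustrative assumptions, not the repository's exact choice.
    weighted = torch.mean(torch.sigmoid(hparams) * loss)
    return weighted + 0.001 * torch.sum(params ** 2)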