-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathuo_nn_SGM.m
50 lines (44 loc) · 973 Bytes
/
uo_nn_SGM.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
% Stochastic Gradient Method solver
% Marcel, Mengxue
% OTDM-NN-Nov21
function [wk,niter] = uo_nn_SGM(w,f,g,Xtr,ytr,Xte,yte,sg_seed,sg_al0,sg_be,sg_ga,sg_emax,sg_ebest)
% UO_NN_SGM  Minibatch stochastic gradient method with early stopping.
%
%   [wk,niter] = uo_nn_SGM(w,f,g,Xtr,ytr,Xte,yte,sg_seed,sg_al0,sg_be,sg_ga,sg_emax,sg_ebest)
%
%   Each epoch shuffles the training samples (the columns of Xtr and ytr),
%   splits them into minibatches of m = max(1,floor(sg_ga*p)) samples (the
%   last batch may be smaller) and takes one step w <- w - al*g(w,Xb,yb) per
%   minibatch. After every epoch f(w,Xte,yte) is evaluated, and the weights
%   giving the lowest such value are returned.
%
%   Inputs:
%     w         initial weights
%     f         handle f(w,X,y) returning a scalar loss; evaluated on
%               (Xte,yte) at the end of each epoch for early stopping
%     g         handle g(w,X,y); its negative is used as the step direction
%               (presumably the gradient of the training loss -- confirm
%               against the caller)
%     Xtr, ytr  training data; samples are the columns of Xtr and of ytr
%     Xte, yte  data passed to f for the early-stopping check
%     sg_seed   seed passed to rng (this reseeds the global generator)
%     sg_al0    initial step length
%     sg_be     fraction of the iteration budget sg_emax*ceil(p/m) over
%               which the step length decays linearly from sg_al0 to
%               0.01*sg_al0; afterwards it stays at 0.01*sg_al0
%     sg_ga     minibatch size as a fraction of the number of samples p
%     sg_emax   maximum number of epochs
%     sg_ebest  stop after this many consecutive epochs without a decrease
%               of f(w,Xte,yte)
%
%   Outputs:
%     wk        weights with the lowest f(w,Xte,yte) seen at an epoch end;
%               the initial w if no epoch ran or no loss was below +Inf
%     niter     total number of minibatch steps performed
rng(sg_seed);
p = size(Xtr,2);
% Minibatch size. Clamp to at least one sample: floor(sg_ga*p) == 0 would
% otherwise yield p/0 = Inf minibatches per epoch.
m = max(1, floor(sg_ga*p));
sg_ek = ceil(p/m);             % minibatches per epoch
sg_kmax = sg_emax * sg_ek;     % total iteration budget
sg_k = ceil(sg_be*sg_kmax);    % step at which the step-length decay ends
sg_al = 0.01*sg_al0;           % final step length
e = 0;                         % epochs completed
s = 0;                         % consecutive epochs without improvement
k = 0;                         % minibatch steps taken
L_te_best = +inf;
wk = w;                        % fallback so the output is always assigned
while e < sg_emax && s < sg_ebest
    % Reshuffle the samples once per epoch.
    P = randperm(p);
    P_Xtr = Xtr(:,P);
    P_ytr = ytr(:,P);
    for i = 0:sg_ek-1
        idx = i*m+1:min((i+1)*m,p);
        % Slice labels by column, consistently with how they were permuted.
        S_Xtr = P_Xtr(:,idx);
        S_ytr = P_ytr(:,idx);
        d = -g(w,S_Xtr,S_ytr);
        % Linear decay from sg_al0 to sg_al over the first sg_k steps, then
        % constant. The strict '<' gives the same value at k == sg_k and
        % avoids 0/0 = NaN when sg_k == 0 (i.e. sg_be == 0).
        if k < sg_k
            al = (1-k/sg_k)*sg_al0 + (k/sg_k)*sg_al;
        else
            al = sg_al;
        end
        w = w + al*d;
        k = k + 1;
    end
    e = e + 1;
    % Early stopping: keep the weights with the lowest loss on (Xte,yte).
    L_te = f(w,Xte,yte);
    if L_te < L_te_best
        L_te_best = L_te;
        wk = w;
        s = 0;
    else
        s = s + 1;
    end
end
niter = k;
end