Skip to content

Commit 0c7b986

Browse files
committed
Adds state update for training with batch normalization.
1 parent fc585b1 commit 0c7b986

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

conslearn/trainConstrainedNetwork.m

+10-3
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,8 @@
147147

148148
% Evaluate the model gradients, and loss using dlfeval and the
149149
% modelLoss function and update the network state.
150-
[lossTrain,gradients] = dlfeval(@iModelLoss,net,X,T,metric);
150+
[lossTrain,gradients,state] = dlfeval(@iModelLoss,net,X,T,metric);
151+
net.State = state;
151152

152153
% Gradient Update
153154
[net,avgG,avgSqG] = adamupdate(net,gradients,avgG,avgSqG,epoch,learnRate);
@@ -180,8 +181,12 @@
180181
end
181182

182183
%% Helpers
183-
function [loss,gradients] = iModelLoss(net,X,T,metric)
184-
Y = forward(net,X);
184+
function [loss,gradients,state] = iModelLoss(net,X,T,metric)
185+
186+
% Make a forward pass
187+
[Y,state] = forward(net,X);
188+
189+
% Compute the loss
185190
switch metric
186191
case "mse"
187192
loss = mse(Y,T);
@@ -190,6 +195,8 @@
190195
case "crossentropy"
191196
loss = crossentropy(softmax(Y),T);
192197
end
198+
199+
% Compute the gradient of the loss with respect to the learnables
193200
gradients = dlgradient(loss,net.Learnables);
194201
end
195202

0 commit comments

Comments
 (0)