Commit f676780

Allow passing gradient to network % backward() to bypass loss function
Parent: 165a6c4
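In chain-rule terms, backward() previously always seeded the output layer with the loss derivative computed from the target output; this commit lets the caller supply that seed directly, for example when it comes from a downstream network. A minimal sketch of the two call modes (the names net, y, and upstream_grad are illustrative, not from this commit):

  call net % backward(y)                           ! output layer seeded with d(loss)/d(output)
  call net % backward(y, gradient=upstream_grad)   ! output layer seeded with a caller-supplied gradient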

File tree: 3 files changed, +78 / -96 lines

example/merge_networks.f90

Lines changed: 5 additions & 43 deletions

@@ -34,7 +34,7 @@ program merge_networks
   ! Network 3
   net3 = network([ &
     input(net1_output_size + net2_output_size), &
-    dense(7) &
+    dense(7) &
   ])

   do n = 1, num_iterations
@@ -59,54 +59,16 @@ program merge_networks

     call net3 % forward([y1, y2])

-    ! Compute the gradients on the 3rd network
+    ! First compute the gradients on net3, then pass the gradients from the first
+    ! hidden layer on net3 to net1 and net2, and compute their gradients.
     call net3 % backward(y)

-    ! net3 % update() will clear the gradients immediately after updating
-    ! the weights, so we need to pass the gradients to net1 and net2 first
-
-    ! For net1 and net2, we can't use the existing net % backward() because
-    ! it currently assumes that the output layer gradients are computed based
-    ! on the loss function and not the gradient from the next layer.
-    ! For now, we need to manually pass the gradient from the first hidden layer
-    ! of net3 to the output layers of net1 and net2.
     select type (next_layer => net3 % layers(2) % p)
-      ! Assume net3's first hidden layer is dense;
-      ! would need to be generalized to others.
       type is (dense_layer)
-
-        nn = size(net1 % layers)
-        call net1 % layers(nn) % backward( &
-          net1 % layers(nn - 1), next_layer % gradient(1:net1_output_size) &
-        )
-
-        nn = size(net2 % layers)
-        call net2 % layers(nn) % backward( &
-          net2 % layers(nn - 1), next_layer % gradient(net1_output_size+1:size(next_layer % gradient)) &
-        )
-
+        call net1 % backward(y, gradient=next_layer % gradient(1:net1_output_size))
+        call net2 % backward(y, gradient=next_layer % gradient(net1_output_size+1:size(next_layer % gradient)))
     end select

-    ! Compute the gradients on hidden layers of net1, if any
-    do nn = size(net1 % layers)-1, 2, -1
-      select type (next_layer => net1 % layers(nn + 1) % p)
-        type is (dense_layer)
-          call net1 % layers(nn) % backward( &
-            net1 % layers(nn - 1), next_layer % gradient &
-          )
-      end select
-    end do
-
-    ! Compute the gradients on hidden layers of net2, if any
-    do nn = size(net2 % layers)-1, 2, -1
-      select type (next_layer => net2 % layers(nn + 1) % p)
-        type is (dense_layer)
-          call net2 % layers(nn) % backward( &
-            net2 % layers(nn - 1), next_layer % gradient &
-          )
-      end select
-    end do
-
     ! Gradients are now computed on all networks and we can update the weights
     call net1 % update(optimizer=sgd(learning_rate=1.))
     call net2 % update(optimizer=sgd(learning_rate=1.))
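The slicing in the gradient= arguments mirrors the concatenation in the forward pass: net3 consumes [y1, y2], so the first net1_output_size entries of its first hidden layer's input gradient belong to net1 and the remainder to net2. With hypothetical sizes net1_output_size = 3 and net2_output_size = 4 (the real values are set earlier in the example), the routing would read:

  ! size(next_layer % gradient) == 3 + 4 == 7 in this hypothetical case
  call net1 % backward(y, gradient=next_layer % gradient(1:3))   ! gradient with respect to y1
  call net2 % backward(y, gradient=next_layer % gradient(4:7))   ! gradient with respect to y2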

src/nf/nf_network.f90

Lines changed: 7 additions & 1 deletion

@@ -195,7 +195,7 @@ end function predict_batch_3d

   interface

-    module subroutine backward(self, output, loss)
+    module subroutine backward(self, output, loss, gradient)
       !! Apply one backward pass through the network.
       !! This changes the state of layers on the network.
       !! Typically used only internally from the `train` method,
@@ -206,6 +206,12 @@ module subroutine backward(self, output, loss)
        !! Output data
      class(loss_type), intent(in), optional :: loss
        !! Loss instance to use. If not provided, the default is quadratic().
+      real, intent(in), optional :: gradient(:)
+        !! Gradient to use for the output layer.
+        !! If not provided, the gradient in the last layer is computed using
+        !! the loss function.
+        !! Passing the gradient is useful for merging/concatenating multiple
+        !! networks.
     end subroutine backward

     module integer function get_num_params(self)
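One caller-facing detail implied by the signature above: output stays a required argument even when gradient is supplied; only loss and gradient are optional. A gradient-driven call therefore still passes the target array, as the merge example does (names below are illustrative):

  call net % backward(y, gradient=g)   ! y is still required by the interface;
                                       ! g, not the loss derivative of y, seeds the output layer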

src/nf/nf_network_submodule.f90

Lines changed: 66 additions & 52 deletions

@@ -115,10 +115,11 @@ module function network_from_layers(layers) result(res)
   end function network_from_layers


-  module subroutine backward(self, output, loss)
+  module subroutine backward(self, output, loss, gradient)
     class(network), intent(in out) :: self
     real, intent(in) :: output(:)
     class(loss_type), intent(in), optional :: loss
+    real, intent(in), optional :: gradient(:)
     integer :: n, num_layers

     ! Passing the loss instance is optional. If not provided, and if the
@@ -140,58 +141,71 @@ module subroutine backward(self, output, loss)

     ! Iterate backward over layers, from the output layer
     ! to the first non-input layer
-    do n = num_layers, 2, -1
-
-      if (n == num_layers) then
-        ! Output layer; apply the loss function
-        select type(this_layer => self % layers(n) % p)
-          type is(dense_layer)
-            call self % layers(n) % backward( &
-              self % layers(n - 1), &
-              self % loss % derivative(output, this_layer % output) &
-            )
-          type is(flatten_layer)
-            call self % layers(n) % backward( &
-              self % layers(n - 1), &
-              self % loss % derivative(output, this_layer % output) &
-            )
-        end select
-      else
-        ! Hidden layer; take the gradient from the next layer
-        select type(next_layer => self % layers(n + 1) % p)
-          type is(dense_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(dropout_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(conv2d_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(flatten_layer)
-            if (size(self % layers(n) % layer_shape) == 2) then
-              call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_2d)
-            else
-              call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_3d)
-            end if
-          type is(maxpool2d_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(reshape3d_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(linear2d_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(self_attention_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(maxpool1d_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(reshape2d_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(conv1d_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(locally_connected2d_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(layernorm_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-        end select
-      end if

+    ! Output layer first
+    n = num_layers
+    if (present(gradient)) then
+
+      ! If the gradient is passed, use it directly for the output layer
+      select type(this_layer => self % layers(n) % p)
+        type is(dense_layer)
+          call self % layers(n) % backward(self % layers(n - 1), gradient)
+        type is(flatten_layer)
+          call self % layers(n) % backward(self % layers(n - 1), gradient)
+      end select
+
+    else
+
+      ! Apply the loss function
+      select type(this_layer => self % layers(n) % p)
+        type is(dense_layer)
+          call self % layers(n) % backward( &
+            self % layers(n - 1), &
+            self % loss % derivative(output, this_layer % output) &
+          )
+        type is(flatten_layer)
+          call self % layers(n) % backward( &
+            self % layers(n - 1), &
+            self % loss % derivative(output, this_layer % output) &
+          )
+      end select
+
+    end if
+
+    ! Hidden layers; take the gradient from the next layer
+    do n = num_layers - 1, 2, -1
+      select type(next_layer => self % layers(n + 1) % p)
+        type is(dense_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(dropout_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(conv2d_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(flatten_layer)
+          if (size(self % layers(n) % layer_shape) == 2) then
+            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_2d)
+          else
+            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_3d)
+          end if
+        type is(maxpool2d_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(reshape3d_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(linear2d_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(self_attention_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(maxpool1d_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(reshape2d_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(conv1d_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(locally_connected2d_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(layernorm_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+      end select
     end do

   end subroutine backward
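For callers, the new branch above means a hand-computed output-layer gradient can drive a whole backward pass, provided the network's output layer is a dense or flatten layer and the gradient matches the output size. A self-contained sketch, not part of this commit; the network shape, data, learning rate, and the predicted - expected gradient form are illustrative assumptions only:

  program gradient_backward_sketch
    use nf, only: dense, input, network, sgd
    implicit none
    type(network) :: net
    real :: x(3), y(2)
    real, allocatable :: predicted(:), g(:)

    net = network([input(3), dense(2)])
    x = [0.2, 0.4, 0.6]
    y = [1., 0.]

    predicted = net % predict(x)         ! forward pass; returns the network output
    g = predicted - y                    ! caller-computed d(loss)/d(output); illustrative form only
    call net % backward(y, gradient=g)   ! bypasses the loss function for the output layer
    call net % update(optimizer=sgd(learning_rate=1.))
  end program gradient_backward_sketch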
