@@ -115,10 +115,11 @@ module function network_from_layers(layers) result(res)
   end function network_from_layers
 
 
-  module subroutine backward(self, output, loss)
+  module subroutine backward(self, output, loss, gradient)
     class(network), intent(in out) :: self
     real, intent(in) :: output(:)
     class(loss_type), intent(in), optional :: loss
+    real, intent(in), optional :: gradient(:)
     integer :: n, num_layers
 
     ! Passing the loss instance is optional. If not provided, and if the
@@ -140,58 +141,71 @@ module subroutine backward(self, output, loss)
 
     ! Iterate backward over layers, from the output layer
     ! to the first non-input layer
-    do n = num_layers, 2, -1
-
-      if (n == num_layers) then
-        ! Output layer; apply the loss function
-        select type(this_layer => self % layers(n) % p)
-          type is(dense_layer)
-            call self % layers(n) % backward( &
-              self % layers(n - 1), &
-              self % loss % derivative(output, this_layer % output) &
-            )
-          type is(flatten_layer)
-            call self % layers(n) % backward( &
-              self % layers(n - 1), &
-              self % loss % derivative(output, this_layer % output) &
-            )
-        end select
-      else
-        ! Hidden layer; take the gradient from the next layer
-        select type(next_layer => self % layers(n + 1) % p)
-          type is(dense_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(dropout_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(conv2d_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(flatten_layer)
-            if (size(self % layers(n) % layer_shape) == 2) then
-              call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_2d)
-            else
-              call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_3d)
-            end if
-          type is(maxpool2d_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(reshape3d_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(linear2d_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(self_attention_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(maxpool1d_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(reshape2d_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(conv1d_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(locally_connected2d_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(layernorm_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-        end select
-      end if
-
+
+    ! Output layer first
+    n = num_layers
+    if (present(gradient)) then
+
+      ! If the gradient is passed, use it directly for the output layer
+      select type(this_layer => self % layers(n) % p)
+        type is(dense_layer)
+          call self % layers(n) % backward(self % layers(n - 1), gradient)
+        type is(flatten_layer)
+          call self % layers(n) % backward(self % layers(n - 1), gradient)
+      end select
+
+    else
+
+      ! Apply the loss function
+      select type(this_layer => self % layers(n) % p)
+        type is(dense_layer)
+          call self % layers(n) % backward( &
+            self % layers(n - 1), &
+            self % loss % derivative(output, this_layer % output) &
+          )
+        type is(flatten_layer)
+          call self % layers(n) % backward( &
+            self % layers(n - 1), &
+            self % loss % derivative(output, this_layer % output) &
+          )
+      end select
+
+    end if
+
+    ! Hidden layers; take the gradient from the next layer
+    do n = num_layers - 1, 2, -1
+      select type(next_layer => self % layers(n + 1) % p)
+        type is(dense_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(dropout_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(conv2d_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(flatten_layer)
+          if (size(self % layers(n) % layer_shape) == 2) then
+            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_2d)
+          else
+            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_3d)
+          end if
+        type is(maxpool2d_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(reshape3d_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(linear2d_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(self_attention_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(maxpool1d_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(reshape2d_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(conv1d_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(locally_connected2d_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(layernorm_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+      end select
     end do
 
 
   end subroutine backward
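
A minimal usage sketch for the new optional gradient argument follows. The tiny model, the nf imports (network, dense, input), and the hand-formed quadratic-loss derivative are illustrative assumptions, not part of this diff:

program backward_with_gradient
  ! Sketch only: assumes the nf module exports network, dense, and input,
  ! and that backward now accepts an optional precomputed gradient.
  use nf, only: network, dense, input
  implicit none

  type(network) :: net
  real :: x(3), y(2), grad(2)

  net = network([input(3), dense(2)])
  x = [0.2, 0.4, 0.6]
  y = [1.0, 0.0]

  ! Existing path: backward derives d(loss)/d(output) from the loss
  ! stored on the network (quadratic by default when none is provided).
  call net % forward(x)
  call net % backward(y)

  ! New path: pass a precomputed gradient of the loss with respect to
  ! the network output; here the quadratic-loss derivative is formed
  ! by hand purely for illustration (predict runs a forward pass, so
  ! the layer states are set before backward is called).
  grad = net % predict(x) - y
  call net % backward(y, gradient=grad)

end program backward_with_gradient

Because a present gradient takes precedence over the stored loss for the output layer, callers can drive backpropagation from an objective computed outside the network.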