55 use nf_dense_layer, only: dense_layer
66 use nf_flatten_layer, only: flatten_layer
77 use nf_input1d_layer, only: input1d_layer
8+ use nf_input2d_layer, only: input2d_layer
89 use nf_input3d_layer, only: input3d_layer
910 use nf_maxpool2d_layer, only: maxpool2d_layer
1011 use nf_reshape_layer, only: reshape3d_layer
@@ -51,6 +52,18 @@ pure module subroutine backward_1d(self, previous, gradient)
5152 end subroutine backward_1d
5253
5354
55+ pure module subroutine backward_2d(self, previous, gradient)
56+ implicit none
57+ class(layer), intent (in out ) :: self
58+ class(layer), intent (in ) :: previous
59+ real , intent (in ) :: gradient(:,:)
60+
61+ ! Backward pass from a 2-d layer downstream currently implemented
62+ ! only for dense and flatten layers
63+ ! CURRENTLY NO LAYERS, tbd: pull/197 and pull/199
64+ end subroutine backward_2d
65+
66+
5467 pure module subroutine backward_3d(self, previous, gradient)
5568 implicit none
5669 class(layer), intent (in out ) :: self
@@ -205,6 +218,23 @@ pure module subroutine get_output_1d(self, output)
205218 end subroutine get_output_1d
206219
207220
221+ pure module subroutine get_output_2d(self, output)
222+ implicit none
223+ class(layer), intent (in ) :: self
224+ real , allocatable , intent (out ) :: output(:,:)
225+
226+ select type (this_layer = > self % p)
227+
228+ type is (input2d_layer)
229+ allocate (output, source= this_layer % output)
230+ class default
231+ error stop ' 1-d output can only be read from an input1d, dense, or flatten layer.'
232+
233+ end select
234+
235+ end subroutine get_output_2d
236+
237+
208238 pure module subroutine get_output_3d(self, output)
209239 implicit none
210240 class(layer), intent (in ) :: self
@@ -280,6 +310,8 @@ elemental module function get_num_params(self) result(num_params)
280310 select type (this_layer = > self % p)
281311 type is (input1d_layer)
282312 num_params = 0
313+ type is (input2d_layer)
314+ num_params = 0
283315 type is (input3d_layer)
284316 num_params = 0
285317 type is (dense_layer)
@@ -305,6 +337,8 @@ module function get_params(self) result(params)
305337 select type (this_layer = > self % p)
306338 type is (input1d_layer)
307339 ! No parameters to get.
340+ type is (input2d_layer)
341+ ! No parameters to get.
308342 type is (input3d_layer)
309343 ! No parameters to get.
310344 type is (dense_layer)
@@ -330,6 +364,8 @@ module function get_gradients(self) result(gradients)
330364 select type (this_layer = > self % p)
331365 type is (input1d_layer)
332366 ! No gradients to get.
367+ type is (input2d_layer)
368+ ! No gradients to get.
333369 type is (input3d_layer)
334370 ! No gradients to get.
335371 type is (dense_layer)
@@ -373,6 +409,11 @@ module subroutine set_params(self, params)
373409 write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
374410 // ' on a zero-parameter layer; nothing to do.'
375411
412+ type is (input2d_layer)
413+ ! No parameters to set.
414+ write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
415+ // ' on a zero-parameter layer; nothing to do.'
416+
376417 type is (input3d_layer)
377418 ! No parameters to set.
378419 write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
0 commit comments