1111 use nf_maxpool2d_layer, only: maxpool2d_layer
1212 use nf_reshape_layer, only: reshape3d_layer
1313 use nf_linear2d_layer, only: linear2d_layer
14+ use nf_self_attention_layer, only: self_attention_layer
1415 use nf_optimizers, only: optimizer_base_type
1516
1617contains
@@ -57,6 +58,8 @@ pure module subroutine backward_1d(self, previous, gradient)
5758 call this_layer % backward(prev_layer % output, gradient)
5859 type is (linear2d_layer)
5960 call this_layer % backward(prev_layer % output, gradient)
61+ type is (self_attention_layer)
62+ call this_layer % backward(prev_layer % output, gradient)
6063 end select
6164
6265 end select
@@ -79,6 +82,19 @@ pure module subroutine backward_2d(self, previous, gradient)
7982 call this_layer % backward(prev_layer % output, gradient)
8083 type is (linear2d_layer)
8184 call this_layer % backward(prev_layer % output, gradient)
85+ type is (self_attention_layer)
86+ call this_layer % backward(prev_layer % output, gradient)
87+ end select
88+
89+ type is (self_attention_layer)
90+
91+ select type (prev_layer = > previous % p)
92+ type is (input2d_layer)
93+ call this_layer % backward(prev_layer % output, gradient)
94+ type is (linear2d_layer)
95+ call this_layer % backward(prev_layer % output, gradient)
96+ type is (self_attention_layer)
97+ call this_layer % backward(prev_layer % output, gradient)
8298 end select
8399
84100 end select
@@ -240,6 +256,20 @@ module subroutine forward(self, input)
240256 call this_layer % forward(prev_layer % output)
241257 type is (linear2d_layer)
242258 call this_layer % forward(prev_layer % output)
259+ type is (self_attention_layer)
260+ call this_layer % forward(prev_layer % output)
261+ end select
262+
263+ type is (self_attention_layer)
264+
265+ ! Upstream layers permitted: input2d, linear2d
266+ select type (prev_layer = > input % p)
267+ type is (input2d_layer)
268+ call this_layer % forward(prev_layer % output)
269+ type is (linear2d_layer)
270+ call this_layer % forward(prev_layer % output)
271+ type is (self_attention_layer)
272+ call this_layer % forward(prev_layer % output)
243273 end select
244274
245275 end select
@@ -279,6 +309,8 @@ pure module subroutine get_output_2d(self, output)
279309 allocate (output, source= this_layer % output)
280310 type is (linear2d_layer)
281311 allocate (output, source= this_layer % output)
312+ type is (self_attention_layer)
313+ allocate (output, source= this_layer % output)
282314 class default
283315 error stop ' 2-d output can only be read from an input2d or linear2d layer.'
284316
@@ -322,8 +354,8 @@ impure elemental module subroutine init(self, input)
322354 call this_layer % init(input % layer_shape)
323355 end select
324356
325- ! The shape of conv2d, dropout, flatten, linear2d, or maxpool2d layers
326- ! is not known until we receive an input layer.
357+ ! The shape of conv2d, dropout, flatten, linear2d, maxpool2d, or
358+ ! self_attention layers is not known until we receive an input layer.
327359 select type (this_layer = > self % p)
328360 type is (conv2d_layer)
329361 self % layer_shape = shape (this_layer % output)
@@ -333,6 +365,8 @@ impure elemental module subroutine init(self, input)
333365 self % layer_shape = shape (this_layer % output)
334366 type is (linear2d_layer)
335367 self % layer_shape = shape (this_layer % output)
368+ type is (self_attention_layer)
369+ self % layer_shape = shape (this_layer % output)
336370 type is (maxpool2d_layer)
337371 self % layer_shape = shape (this_layer % output)
338372 end select
@@ -389,6 +423,8 @@ elemental module function get_num_params(self) result(num_params)
389423 num_params = 0
390424 type is (linear2d_layer)
391425 num_params = this_layer % get_num_params()
426+ type is (self_attention_layer)
427+ num_params = this_layer % get_num_params()
392428 class default
393429 error stop ' Unknown layer type.'
394430 end select
@@ -420,6 +456,8 @@ module function get_params(self) result(params)
420456 ! No parameters to get.
421457 type is (linear2d_layer)
422458 params = this_layer % get_params()
459+ type is (self_attention_layer)
460+ params = this_layer % get_params()
423461 class default
424462 error stop ' Unknown layer type.'
425463 end select
@@ -451,6 +489,8 @@ module function get_gradients(self) result(gradients)
451489 ! No gradients to get.
452490 type is (linear2d_layer)
453491 gradients = this_layer % get_gradients()
492+ type is (self_attention_layer)
493+ gradients = this_layer % get_gradients()
454494 class default
455495 error stop ' Unknown layer type.'
456496 end select
@@ -506,6 +546,9 @@ module subroutine set_params(self, params)
506546 type is (linear2d_layer)
507547 call this_layer % set_params(params)
508548
549+ type is (self_attention_layer)
550+ call this_layer % set_params(params)
551+
509552 type is (maxpool2d_layer)
510553 ! No parameters to set.
511554 write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
0 commit comments