1010 use nf_maxpool2d_layer, only: maxpool2d_layer
1111 use nf_reshape_layer, only: reshape3d_layer
1212 use nf_linear2d_layer, only: linear2d_layer
13+ use nf_self_attention_layer, only: self_attention_layer
1314 use nf_optimizers, only: optimizer_base_type
1415
1516contains
@@ -50,6 +51,8 @@ pure module subroutine backward_1d(self, previous, gradient)
5051 call this_layer % backward(prev_layer % output, gradient)
5152 type is (linear2d_layer)
5253 call this_layer % backward(prev_layer % output, gradient)
54+ type is (self_attention_layer)
55+ call this_layer % backward(prev_layer % output, gradient)
5356 end select
5457
5558 end select
@@ -72,6 +75,19 @@ pure module subroutine backward_2d(self, previous, gradient)
7275 call this_layer % backward(prev_layer % output, gradient)
7376 type is (linear2d_layer)
7477 call this_layer % backward(prev_layer % output, gradient)
78+ type is (self_attention_layer)
79+ call this_layer % backward(prev_layer % output, gradient)
80+ end select
81+
82+ type is (self_attention_layer)
83+
84+ select type (prev_layer = > previous % p)
85+ type is (input2d_layer)
86+ call this_layer % backward(prev_layer % output, gradient)
87+ type is (linear2d_layer)
88+ call this_layer % backward(prev_layer % output, gradient)
89+ type is (self_attention_layer)
90+ call this_layer % backward(prev_layer % output, gradient)
7591 end select
7692
7793 end select
@@ -219,6 +235,20 @@ pure module subroutine forward(self, input)
219235 call this_layer % forward(prev_layer % output)
220236 type is (linear2d_layer)
221237 call this_layer % forward(prev_layer % output)
238+ type is (self_attention_layer)
239+ call this_layer % forward(prev_layer % output)
240+ end select
241+
242+ type is (self_attention_layer)
243+
244+ ! Upstream layers permitted: input2d, linear2d, self_attention
245+ select type (prev_layer = > input % p)
246+ type is (input2d_layer)
247+ call this_layer % forward(prev_layer % output)
248+ type is (linear2d_layer)
249+ call this_layer % forward(prev_layer % output)
250+ type is (self_attention_layer)
251+ call this_layer % forward(prev_layer % output)
222252 end select
223253
224254 end select
@@ -258,6 +288,8 @@ pure module subroutine get_output_2d(self, output)
258288 allocate (output, source= this_layer % output)
259289 type is (linear2d_layer)
260290 allocate (output, source= this_layer % output)
291+ type is (self_attention_layer)
292+ allocate (output, source= this_layer % output)
261293 class default
262294 error stop ' 2-d output can only be read from an input2d or linear2d layer.'
263295
@@ -357,6 +389,8 @@ elemental module function get_num_params(self) result(num_params)
357389 num_params = 0
358390 type is (linear2d_layer)
359391 num_params = this_layer % get_num_params()
392+ type is (self_attention_layer)
393+ num_params = this_layer % get_num_params()
360394 class default
361395 error stop ' Unknown layer type.'
362396 end select
@@ -386,6 +420,8 @@ module function get_params(self) result(params)
386420 ! No parameters to get.
387421 type is (linear2d_layer)
388422 params = this_layer % get_params()
423+ type is (self_attention_layer)
424+ params = this_layer % get_params()
389425 class default
390426 error stop ' Unknown layer type.'
391427 end select
@@ -415,6 +451,8 @@ module function get_gradients(self) result(gradients)
415451 ! No gradients to get.
416452 type is (linear2d_layer)
417453 gradients = this_layer % get_gradients()
454+ type is (self_attention_layer)
455+ gradients = this_layer % get_gradients()
418456 class default
419457 error stop ' Unknown layer type.'
420458 end select
@@ -465,6 +503,9 @@ module subroutine set_params(self, params)
465503 type is (linear2d_layer)
466504 call this_layer % set_params(params)
467505
506+ type is (self_attention_layer)
507+ call this_layer % set_params(params)
508+
468509 type is (maxpool2d_layer)
469510 ! No parameters to set.
470511 write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
0 commit comments