@@ -2,6 +2,7 @@ program test_multihead_attention_layer
   use iso_fortran_env, only: stderr => error_unit
   use nf_multihead_attention_layer, only: multihead_attention_layer
   use nf_linear2d_layer, only: linear2d_layer
+  use nf_optimizers, only: sgd
   implicit none
 
   logical :: ok = .true.
@@ -21,6 +22,7 @@ program test_multihead_attention_layer
   call test_multihead_attention_combine_heads(attention, attention % sdpa, ok)
   call test_multihead_attention_forward(attention, ok)
   call test_multihead_attention_backward(attention, ok)
+  call test_multihead_attention_update_gradients(attention, ok)
   ! call test_multihead_attention_forward_reallife_shape(ok)
 
 contains
@@ -239,4 +241,50 @@ subroutine test_multihead_attention_backward(attention, ok)
       write(stderr, '(a)') 'backward returned incorrect values.. failed'
     end if
   end subroutine test_multihead_attention_backward
+
+  subroutine test_multihead_attention_update_gradients(attention, ok)
+    type(multihead_attention_layer), intent(in out) :: attention
+    logical, intent(in out) :: ok
+    real :: parameters(80)
+    real :: expected_parameters(80)
+    real :: updated_output(12)
+    real :: expected_updated_output(12) = [&
+      0.111365855, 0.115744293, 0.115733206, 0.185253710, 0.196646214, 0.196617395,&
+      -0.102874994, -0.118834510, -0.118794113, 0.179314315, 0.190210193, 0.190182626&
+    ]
+    type(sgd) :: optim
+
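+    ! 80 parameters are expected: four 4x4 projection weight matrices (64 values)
+    ! plus their four 4-element biases (16 values).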
+    if (attention % get_num_params() /= 80) then
+      ok = .false.
+      write(stderr, '(a)') 'incorrect number of parameters.. failed'
+    end if
+
+    expected_parameters(1:64) = 0.100000001
+    expected_parameters(65:80) = 0.109999999
+    parameters = attention % get_params()
+    if (.not. all(parameters .eq. expected_parameters)) then
+      ok = .false.
+      write(stderr, '(a)') 'incorrect parameters.. failed'
+    end if
+
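+    ! Take a single SGD step on the flattened parameters, then write them back into the layer.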
+    optim = SGD(learning_rate=0.01)
+    call optim % minimize(parameters, attention % get_gradients())
+    call attention % set_params(parameters)
+
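+    ! Forward pass with the updated parameters; the same array is passed as query, key, and value.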
+    call attention % forward(&
+      reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4]),&
+      reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4]),&
+      reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4])&
+    )
+
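+    ! The flattened output after one update step must match the precomputed expected values.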
+    updated_output = reshape(attention % output, [12])
+    if (.not. all(updated_output .eq. expected_updated_output)) then
+      ok = .false.
+      write(stderr, '(a)') 'incorrect output after parameters update.. failed'
+    end if
+  end subroutine test_multihead_attention_update_gradients
 end program test_multihead_attention_layer