@@ -40,6 +40,10 @@ module nf_multihead_attention_layer
     procedure :: normalize_attention_matrix
     procedure :: scaled_dot_product_attention
     procedure :: combine_heads
+    procedure :: get_num_params
+    procedure :: get_params
+    procedure :: get_gradients
+    procedure :: set_params
     procedure :: init

   end type multihead_attention_layer
@@ -348,6 +352,101 @@ module function combine_heads(self, input) result(output)
     end do
   end function combine_heads

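+  ! Total number of trainable parameters across the four internal
+  ! linear projections (query, key, value, output).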
+  module function get_num_params(self) result(num_params)
+    class(multihead_attention_layer) :: self
+    integer :: num_params
+
+    num_params = &
+        self % query_layer % get_num_params() &
+      + self % key_layer % get_num_params() &
+      + self % value_layer % get_num_params() &
+      + self % output_layer % get_num_params()
+  end function get_num_params
365+
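+  ! Flatten all weights, then all biases, into a single rank-1 array.
+  ! The ordering here must match the unpacking order in set_params.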
+  module function get_params(self) result(params)
+    class(multihead_attention_layer), intent(in), target :: self
+    real, allocatable :: params(:)
+
+    params = [ &
+      self % query_layer % weights, &
+      self % key_layer % weights, &
+      self % value_layer % weights, &
+      self % output_layer % weights, &
+      self % query_layer % biases, &
+      self % key_layer % biases, &
+      self % value_layer % biases, &
+      self % output_layer % biases &
+    ]
+  end function get_params
381+
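+  ! Concatenate gradients in the same order as get_params:
+  ! all weight gradients first, then all bias gradients.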
+  module function get_gradients(self) result(gradients)
+    class(multihead_attention_layer), intent(in), target :: self
+    real, allocatable :: gradients(:)
+
+    gradients = [ &
+      self % query_layer % dw, &
+      self % key_layer % dw, &
+      self % value_layer % dw, &
+      self % output_layer % dw, &
+      self % query_layer % db, &
+      self % key_layer % db, &
+      self % value_layer % db, &
+      self % output_layer % db &
+    ]
+  end function get_gradients
397+
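+  ! Unpack a flat parameter vector, as produced by get_params,
+  ! back into the weights and biases of the four projection layers.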
+  module subroutine set_params(self, params)
+    class(multihead_attention_layer), intent(in out) :: self
+    real, intent(in), target :: params(:)
+    real, pointer :: p_(:,:) => null()
+    integer :: i, j, window
+
+    ! check if the number of parameters is correct
+    if (size(params) /= self % get_num_params()) then
+      error stop 'Error: number of parameters does not match'
+    end if
+
+    ! FIXME: looks clumsy, better ideas?
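+    ! (A hypothetical cleanup, not part of this commit: keep one running
+    ! offset and take each slice as params(offset + 1 : offset + window),
+    ! advancing offset by window after every assignment, to halve the
+    ! i/j bookkeeping below.)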
+    window = self % model_dimension * self % model_dimension
+    i = 1
+    j = window
+    self % query_layer % weights = reshape(params(i:j), [self % model_dimension, self % model_dimension])
+    i = j + 1
+    j = i + window - 1
+    self % key_layer % weights = reshape(params(i:j), [self % model_dimension, self % model_dimension])
+    i = j + 1
+    j = i + window - 1
+    self % value_layer % weights = reshape(params(i:j), [self % model_dimension, self % model_dimension])
+    i = j + 1
+    j = i + window - 1
+    self % output_layer % weights = reshape(params(i:j), [self % model_dimension, self % model_dimension])
+
+    window = self % model_dimension
+    i = j + 1
+    j = i + window - 1
+    self % query_layer % biases = params(i:j)
+    i = j + 1
+    j = i + window - 1
+    self % key_layer % biases = params(i:j)
+    i = j + 1
+    j = i + window - 1
+    self % value_layer % biases = params(i:j)
+    i = j + 1
+    j = i + window - 1
+    self % output_layer % biases = params(i:j)
+  end subroutine set_params
+
   module subroutine init(self, input_shape)
     class(multihead_attention_layer), intent(in out) :: self
     integer, intent(in) :: input_shape(:)