
Commit 05842ce

multihead_attention: params api

1 parent 44703f9
File tree

1 file changed: +88 -0 lines

src/nf/nf_multihead_attention.f90 (88 additions, 0 deletions)
@@ -40,6 +40,10 @@ module nf_multihead_attention_layer
     procedure :: normalize_attention_matrix
     procedure :: scaled_dot_product_attention
     procedure :: combine_heads
+    procedure :: get_num_params
+    procedure :: get_params
+    procedure :: get_gradients
+    procedure :: set_params
     procedure :: init

   end type multihead_attention_layer
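For reference, a minimal sanity check of the parameter count implied by the set_params windows below, assuming each of the four internal linear layers (query, key, value, output) holds model_dimension x model_dimension weights plus model_dimension biases; the program and the value of d are hypothetical, not part of this commit:

program check_param_count
  implicit none
  integer, parameter :: d = 512   ! hypothetical model_dimension
  integer :: expected

  ! four linear layers, each with d*d weights and d biases
  expected = 4 * (d * d + d)
  print '(a, i0)', 'expected num_params: ', expected   ! 1050624 for d = 512
end program check_param_count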
@@ -348,6 +352,90 @@ module function combine_heads(self, input) result(output)
     end do
   end function combine_heads

+  module function get_num_params(self) result(num_params)
+    class(multihead_attention_layer) :: self
+    integer :: num_params
+
+    num_params = &
+        self % query_layer % get_num_params() &
+      + self % key_layer % get_num_params() &
+      + self % value_layer % get_num_params() &
+      + self % output_layer % get_num_params()
+  end function get_num_params
+
+  module function get_params(self) result(params)
+    class(multihead_attention_layer), intent(in), target :: self
+    real, allocatable :: params(:)
+
+    params = [ &
+      self % query_layer % weights, &
+      self % key_layer % weights, &
+      self % value_layer % weights, &
+      self % output_layer % weights, &
+      self % query_layer % biases, &
+      self % key_layer % biases, &
+      self % value_layer % biases, &
+      self % output_layer % biases &
+    ]
+  end function get_params
+
+  module function get_gradients(self) result(gradients)
+    class(multihead_attention_layer), intent(in), target :: self
+    real, allocatable :: gradients(:)
+
+    gradients = [ &
+      self % query_layer % dw, &
+      self % key_layer % dw, &
+      self % value_layer % dw, &
+      self % output_layer % dw, &
+      self % query_layer % db, &
+      self % key_layer % db, &
+      self % value_layer % db, &
+      self % output_layer % db &
+    ]
+  end function get_gradients
+
+  module subroutine set_params(self, params)
+    class(multihead_attention_layer), intent(in out) :: self
+    real, intent(in), target :: params(:)
+    real, pointer :: p_(:,:) => null()
+    integer :: i, j, window
+
+    ! check if the number of parameters is correct
+    if (size(params) /= self % get_num_params()) then
+      error stop 'Error: number of parameters does not match'
+    end if
+
+    ! FIXME: looks clumsy, better ideas?
+    window = self % model_dimension * self % model_dimension
+    i = 1
+    j = window
+    self % query_layer % weights = reshape(params(i:j), [self % model_dimension, self % model_dimension])
+    i = j + 1
+    j = i + window - 1
+    self % key_layer % weights = reshape(params(i:j), [self % model_dimension, self % model_dimension])
+    i = j + 1
+    j = i + window - 1
+    self % value_layer % weights = reshape(params(i:j), [self % model_dimension, self % model_dimension])
+    i = j + 1
+    j = i + window - 1
+    self % output_layer % weights = reshape(params(i:j), [self % model_dimension, self % model_dimension])
+
+    window = self % model_dimension
+    i = j + 1
+    j = i + window - 1
+    self % query_layer % biases = params(i:j)
+    i = j + 1
+    j = i + window - 1
+    self % key_layer % biases = params(i:j)
+    i = j + 1
+    j = i + window - 1
+    self % value_layer % biases = params(i:j)
+    i = j + 1
+    j = i + window - 1
+    self % output_layer % biases = params(i:j)
+  end subroutine set_params
+
   module subroutine init(self, input_shape)
     class(multihead_attention_layer), intent(in out) :: self
     integer, intent(in) :: input_shape(:)
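Taken together, the four procedures expose a flat view of all trainable parameters, concatenated in a fixed order: query, key, value, and output weights, then the biases in the same order. A minimal usage sketch, assuming a multihead_attention_layer instance has already been constructed and initialized (how that happens is outside this commit); the module name follows the hunk header above, and only the type-bound names visible in this diff are taken from the source:

! hedged sketch: a plain gradient-descent step over the flattened parameters
subroutine sgd_step(attn, learning_rate)
  use nf_multihead_attention_layer, only: multihead_attention_layer
  type(multihead_attention_layer), intent(in out) :: attn
  real, intent(in) :: learning_rate

  ! get_params and get_gradients flatten in the same order,
  ! so they can be combined element-wise and written back
  call attn % set_params(attn % get_params() - learning_rate * attn % get_gradients())
end subroutine sgd_step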
