40 changes: 34 additions & 6 deletions src/nf/nf_multihead_attention.f90
@@ -39,10 +39,25 @@ module nf_multihead_attention_layer
real, allocatable :: k_input(:, :)
real, allocatable :: v_input(:, :)
real, allocatable :: o_input(:, :)

! temporary storages for forward and backward passes
real, allocatable :: normalized_attention(:, :, :)
real, allocatable :: q_or_dq(:, :, :)
real, allocatable :: k_or_dk(:, :, :)
real, allocatable :: v_or_dv(:, :, :)
real, allocatable :: d_output(:, :, :)
real, allocatable :: v_heads(:, :, :)
real, allocatable :: k_heads(:, :, :)
real, allocatable :: q_heads(:, :, :)
real, allocatable :: d_sdpa(:, :)
real, allocatable :: jacobian(:, :)
real, allocatable :: d_normalize(:, :, :)
contains

procedure :: common_backward
procedure :: common_forward
procedure :: sdpa_forward
Collaborator:
I was wondering what sdpa was until I found it in one of your comments below (that is, Scaled Dot Product Attention). I suggest adding a comment to explain it.

Collaborator Author:
will do!
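One possible shape for that explanatory comment, sketched as a FORD-style doc string on the sdpa_forward interface; the exact wording, and the head_size name used below, are only a suggestion and not part of this PR:

pure module subroutine sdpa_forward(self, attention_mask)
!! Scaled Dot Product Attention (SDPA):
!! per head, output = softmax(Q * K**T / sqrt(head_size)) * V,
!! where the optional attention_mask is applied to the raw scores
!! before the softmax.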

procedure :: sdpa_backward
procedure :: get_num_params
procedure :: get_params
procedure :: get_gradients
@@ -68,25 +83,38 @@ end function multihead_attention_layer_cons

interface

pure module subroutine common_backward(self, input, gradient)
pure module subroutine common_backward(self, input, gradient, attention_mask)
!! General backprop for MultiHead Attention mechanism
!! Might be used for both Self and Cross Attention
!! Self Attention: sum output gradients
!! Cross Attention: use them separately
class(multihead_attention_layer), intent(in out) :: self
real, intent(in) :: input(:, :)
real, intent(in) :: gradient(:, :)
real, optional, intent(in) :: attention_mask(:, :)
end subroutine common_backward

pure module subroutine common_forward(self, query, key, value)
pure module subroutine common_forward(self, query, key, value, attention_mask)
!! General forward propagation for MultiHead Attention Mechanism
!! Might be used for both Self and Cross Attention
!! Self Attention: pass the same value thrice
!! Cross Attention: pass three values for your query, key and value
class(multihead_attention_layer), intent(in out) :: self
real, intent(in) :: query(:, :), key(:, :), value(:, :)
real, optional, intent(in) :: attention_mask(:, :)
end subroutine common_forward
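To make the Self vs. Cross Attention note above concrete, here is a minimal call sketch; attn, x, decoder_q, encoder_out and causal_mask are illustrative names, assumed to be a constructed layer and conforming 2-D arrays:

! self-attention: the same activations supply query, key and value
call attn % common_forward(x, x, x)
! cross-attention: queries from one stream, keys and values from another
call attn % common_forward(decoder_q, encoder_out, encoder_out)
! either form also accepts the new optional mask
call attn % common_forward(x, x, x, attention_mask=causal_mask)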

pure module subroutine sdpa_forward(self, attention_mask)
Collaborator Author:
Put Scaled Dot Product Attention into a separate method. This adds more flexibility: in some cases we need to manipulate the input projections, such as for KV caching in Llama and Qwen2 (see the sketch after this interface).
class(multihead_attention_layer), intent(in out) :: self
real, intent(in), optional :: attention_mask(:, :)
end subroutine sdpa_forward
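A hedged sketch of the flexibility mentioned in the review comment above: with sdpa_forward split out, the projected heads can be adjusted (for example, replaced from a KV cache) before the attention step. Direct assignment to k_heads and v_heads, and the cached_k, cached_v and causal_mask variables, are assumptions for illustration only, not part of this PR:

! attn is assumed constructed, with projections for the current step
! already computed by the usual forward path
attn % k_heads = cached_k   ! reuse key heads from earlier decoding steps
attn % v_heads = cached_v   ! reuse value heads from earlier decoding steps
call attn % sdpa_forward(attention_mask=causal_mask)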

pure module subroutine sdpa_backward(self, gradient, attention_mask)
class(multihead_attention_layer), intent(in out) :: self
real, intent(in) :: gradient(:, :)
real, intent(in), optional :: attention_mask(:, :)
end subroutine sdpa_backward

pure module subroutine init(self, input_shape)
!! Initialize the layer data structures.
!!
@@ -119,7 +147,7 @@ pure module subroutine normalize_attention_matrix(self, attention_mask)
!! Output dims: sequence_length, sequence_length, n_heads
class(multihead_attention_layer), intent(in out) :: self
!! (sequence_length, sequence_length, n_heads)
real, optional, intent(in) :: attention_mask(:, :, :)
real, optional, intent(in) :: attention_mask(:, :)
!! (sequence_length, sequence_length)
end subroutine normalize_attention_matrix

@@ -143,18 +171,18 @@ elemental module function get_num_params(self) result(num_params)
end function get_num_params

module function get_params(self) result(params)
class(multihead_attention_layer), intent(in), target :: self
class(multihead_attention_layer), intent(in) :: self
real, allocatable :: params(:)
end function get_params

module function get_gradients(self) result(gradients)
class(multihead_attention_layer), intent(in), target :: self
class(multihead_attention_layer), intent(in) :: self
real, allocatable :: gradients(:)
end function get_gradients

module subroutine set_params(self, params)
class(multihead_attention_layer), intent(in out) :: self
real, intent(in), target :: params(:)
real, intent(in) :: params(:)
end subroutine set_params

module subroutine init_base(self, input_shape)