From 49e85075730e646e8e0ffb5a4ea9db0b20c0aa7b Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sun, 2 Feb 2025 13:36:04 +0400 Subject: [PATCH 01/71] linear2d_layer forward implementation --- src/nf.f90 | 1 + 1 file changed, 1 insertion(+) diff --git a/src/nf.f90 b/src/nf.f90 index e9b027c1..4351e201 100644 --- a/src/nf.f90 +++ b/src/nf.f90 @@ -12,4 +12,5 @@ module nf gaussian, linear, relu, leaky_relu, & sigmoid, softmax, softplus, step, tanhf, & celu + use nf_linear2d_layer, only: linear2d_layer end module nf From feb711246c02023a30f5685a2e1aece8ac0344a8 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sat, 15 Feb 2025 00:01:36 +0400 Subject: [PATCH 02/71] linear2d_layer: temporarily remove api --- src/nf/nf_network_submodule.f90 | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90 index c2a9c903..94a3d2e4 100644 --- a/src/nf/nf_network_submodule.f90 +++ b/src/nf/nf_network_submodule.f90 @@ -156,8 +156,8 @@ module subroutine backward(self, output, loss) type is(reshape3d_layer) call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - type is(linear2d_layer) - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) +! type is(linear2d_layer) +! call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) end select end if From 8f320f073de8d33ca2ec8011cb95fe933b648032 Mon Sep 17 00:00:00 2001 From: milancurcic Date: Sat, 15 Feb 2025 22:32:01 -0500 Subject: [PATCH 03/71] Don't expose the concrete layer type via nf --- src/nf.f90 | 1 - 1 file changed, 1 deletion(-) diff --git a/src/nf.f90 b/src/nf.f90 index 4351e201..e9b027c1 100644 --- a/src/nf.f90 +++ b/src/nf.f90 @@ -12,5 +12,4 @@ module nf gaussian, linear, relu, leaky_relu, & sigmoid, softmax, softplus, step, tanhf, & celu - use nf_linear2d_layer, only: linear2d_layer end module nf From af4a5d73449f7a6330ef99a5e659eddc8a38219f Mon Sep 17 00:00:00 2001 From: milancurcic Date: Sat, 15 Feb 2025 22:48:40 -0500 Subject: [PATCH 04/71] Plumbing of linear2d with input2d and linear2d --- src/nf/nf_network_submodule.f90 | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90 index 94a3d2e4..c2a9c903 100644 --- a/src/nf/nf_network_submodule.f90 +++ b/src/nf/nf_network_submodule.f90 @@ -156,8 +156,8 @@ module subroutine backward(self, output, loss) type is(reshape3d_layer) call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) -! type is(linear2d_layer) -! 
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(linear2d_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) end select end if From 549d4e61032400d911782ec9429558e48aa091ab Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sun, 16 Feb 2025 16:00:35 +0400 Subject: [PATCH 05/71] linear2d_layer: add flatten2d layer --- src/nf/nf_flatten2d_layer.f90 | 75 +++++++++++++++++++++ src/nf/nf_flatten2d_layer_submodule.f90 | 48 +++++++++++++ src/nf/nf_layer_constructors.f90 | 21 +++++- test/test_flatten2d_layer.f90 | 89 +++++++++++++++++++++++++ 4 files changed, 232 insertions(+), 1 deletion(-) create mode 100644 src/nf/nf_flatten2d_layer.f90 create mode 100644 src/nf/nf_flatten2d_layer_submodule.f90 create mode 100644 test/test_flatten2d_layer.f90 diff --git a/src/nf/nf_flatten2d_layer.f90 b/src/nf/nf_flatten2d_layer.f90 new file mode 100644 index 00000000..e67037f8 --- /dev/null +++ b/src/nf/nf_flatten2d_layer.f90 @@ -0,0 +1,75 @@ +module nf_flatten2d_layer + + !! This module provides the concrete flatten2d layer type. + !! It is used internally by the layer type. + !! It is not intended to be used directly by the user. + + use nf_base_layer, only: base_layer + + implicit none + + private + public :: flatten2d_layer + + type, extends(base_layer) :: flatten2d_layer + + !! Concrete implementation of a flatten2d (2-d to 1-d) layer. + + integer, allocatable :: input_shape(:) + integer :: output_size + + real, allocatable :: gradient(:,:) + real, allocatable :: output(:) + + contains + + procedure :: backward + procedure :: forward + procedure :: init + + end type flatten2d_layer + + interface flatten2d_layer + elemental module function flatten2d_layer_cons() result(res) + !! This function returns the `flatten2d_layer` instance. + type(flatten2d_layer) :: res + !! `flatten2d_layer` instance + end function flatten2d_layer_cons + end interface flatten2d_layer + + interface + + pure module subroutine backward(self, input, gradient) + !! Apply the backward pass to the flatten2d layer. + !! This is a reshape operation from 1-d gradient to 2-d input. + class(flatten2d_layer), intent(in out) :: self + !! flatten2d layer instance + real, intent(in) :: input(:,:) + !! Input from the previous layer + real, intent(in) :: gradient(:) + !! Gradient from the next layer + end subroutine backward + + pure module subroutine forward(self, input) + !! Propagate forward the layer. + !! Calling this subroutine updates the values of a few data components + !! of `flatten2d_layer` that are needed for the backward pass. + class(flatten2d_layer), intent(in out) :: self + !! Dense layer instance + real, intent(in) :: input(:,:) + !! Input from the previous layer + end subroutine forward + + module subroutine init(self, input_shape) + !! Initialize the layer data structures. + !! + !! This is a deferred procedure from the `base_layer` abstract type. + class(flatten2d_layer), intent(in out) :: self + !! Dense layer instance + integer, intent(in) :: input_shape(:) + !! Shape of the input layer + end subroutine init + + end interface + +end module nf_flatten2d_layer diff --git a/src/nf/nf_flatten2d_layer_submodule.f90 b/src/nf/nf_flatten2d_layer_submodule.f90 new file mode 100644 index 00000000..875b7374 --- /dev/null +++ b/src/nf/nf_flatten2d_layer_submodule.f90 @@ -0,0 +1,48 @@ +submodule(nf_flatten2d_layer) nf_flatten2d_layer_submodule + + !! This module provides the concrete flatten2d layer type. + !! It is used internally by the layer type. + !! 
It is not intended to be used directly by the user. + + use nf_base_layer, only: base_layer + + implicit none + +contains + + elemental module function flatten2d_layer_cons() result(res) + type(flatten2d_layer) :: res + end function flatten2d_layer_cons + + + pure module subroutine backward(self, input, gradient) + class(flatten2d_layer), intent(in out) :: self + real, intent(in) :: input(:,:) + real, intent(in) :: gradient(:) + self % gradient = reshape(gradient, shape(input)) + end subroutine backward + + + pure module subroutine forward(self, input) + class(flatten2d_layer), intent(in out) :: self + real, intent(in) :: input(:,:) + self % output = pack(input, .true.) + end subroutine forward + + + module subroutine init(self, input_shape) + class(flatten2d_layer), intent(in out) :: self + integer, intent(in) :: input_shape(:) + + self % input_shape = input_shape + self % output_size = product(input_shape) + + allocate(self % gradient(input_shape(1), input_shape(2))) + self % gradient = 0 + + allocate(self % output(self % output_size)) + self % output = 0 + + end subroutine init + +end submodule nf_flatten2d_layer_submodule diff --git a/src/nf/nf_layer_constructors.f90 b/src/nf/nf_layer_constructors.f90 index 2983ddcd..ef96a8fa 100644 --- a/src/nf/nf_layer_constructors.f90 +++ b/src/nf/nf_layer_constructors.f90 @@ -8,7 +8,7 @@ module nf_layer_constructors implicit none private - public :: conv2d, dense, flatten, input, maxpool2d, reshape, linear2d + public :: conv2d, dense, flatten, flatten2d, input, maxpool2d, reshape, linear2d interface input @@ -125,6 +125,25 @@ module function flatten() result(res) !! Resulting layer instance end function flatten + module function flatten2d() result(res) + !! Flatten (2-d -> 1-d) layer constructor. + !! + !! Use this layer to chain layers with 2-d outputs to layers with 2-d + !! inputs. + !! + !! A flatten layer must not be the first layer in the network. + !! + !! Example: + !! + !! ``` + !! use nf, only :: flatten, layer + !! type(layer) :: flatten_layer + !! flatten_layer = flatten() + !! ``` + type(layer) :: res + !! Resulting layer instance + end function flatten2d + module function conv2d(filters, kernel_size, activation) result(res) !! 2-d convolutional layer constructor. !! diff --git a/test/test_flatten2d_layer.f90 b/test/test_flatten2d_layer.f90 new file mode 100644 index 00000000..3189b4e9 --- /dev/null +++ b/test/test_flatten2d_layer.f90 @@ -0,0 +1,89 @@ +program test_flatten2d_layer + + use iso_fortran_env, only: stderr => error_unit + use nf, only: dense, flatten2d, input, layer, network + use nf_flatten2d_layer, only: flatten2d_layer + use nf_input2d_layer, only: input2d_layer + + implicit none + + type(layer) :: test_layer, input_layer + type(network) :: net + real, allocatable :: gradient(:,:) + real, allocatable :: output(:) + logical :: ok = .true. + + test_layer = flatten2d() + + if (.not. test_layer % name == 'flatten2d') then + ok = .false. + write(stderr, '(a)') 'flatten2d layer has its name set correctly.. failed' + end if + + if (test_layer % initialized) then + ok = .false. + write(stderr, '(a)') 'flatten2d layer is not initialized yet.. failed' + end if + + input_layer = input(1, 2) + call test_layer % init(input_layer) + + if (.not. test_layer % initialized) then + ok = .false. + write(stderr, '(a)') 'flatten2d layer is now initialized.. failed' + end if + + if (.not. all(test_layer % layer_shape == [2])) then + ok = .false. + write(stderr, '(a)') 'flatten2d layer has an incorrect output shape.. failed' + end if + + ! 
Test forward pass - reshaping from 2-d to 1-d + + select type(this_layer => input_layer % p); type is(input2d_layer) + call this_layer % set(reshape(real([1, 2, 3, 4]), [2, 2])) + end select + + call test_layer % forward(input_layer) + call test_layer % get_output(output) + + if (.not. all(output == [1, 2, 3, 4])) then + ok = .false. + write(stderr, '(a)') 'flatten2d layer correctly propagates forward.. failed' + end if + + ! Test backward pass - reshaping from 1-d to 2-d + + ! Calling backward() will set the values on the gradient component + ! input_layer is used only to determine shape + call test_layer % backward(input_layer, real([1, 2, 3, 4])) + + select type(this_layer => test_layer % p); type is(flatten2d_layer) + gradient = this_layer % gradient + end select + + if (.not. all(gradient == reshape(real([1, 2, 3, 4]), [2, 2]))) then + ok = .false. + write(stderr, '(a)') 'flatten2d layer correctly propagates backward.. failed' + end if + + net = network([ & + input(28, 28), & + flatten2d(), & + dense(10) & + ]) + + ! Test that the output layer receives 784 elements in the input + if (.not. all(net % layers(3) % input_layer_shape == [784])) then + ok = .false. + write(stderr, '(a)') 'flatten2d layer correctly chains input2d to dense.. failed' + end if + + if (ok) then + print '(a)', 'test_flatten2d_layer: All tests passed.' + else + write(stderr, '(a)') 'test_flatten2d_layer: One or more tests failed.' + stop 1 + end if + +end program test_flatten2d_layer From 3218be06169fb4b0fbc2d04624526b33f952e092 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sun, 16 Feb 2025 16:02:44 +0400 Subject: [PATCH 06/71] linear2d_layer: make linear2d layer work with input2d and flatten2d --- src/nf.f90 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nf.f90 b/src/nf.f90 index e9b027c1..d215eb85 100644 --- a/src/nf.f90 +++ b/src/nf.f90 @@ -3,7 +3,7 @@ module nf use nf_datasets_mnist, only: label_digits, load_mnist use nf_layer, only: layer use nf_layer_constructors, only: & - conv2d, dense, flatten, input, maxpool2d, reshape, linear2d + conv2d, dense, flatten, flatten2d, input, maxpool2d, reshape, linear2d use nf_loss, only: mse, quadratic use nf_metrics, only: corr, maxabs use nf_network, only: network From 39636f4f3f87354eb446f3f448ec513e1bc3df82 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sun, 16 Feb 2025 16:05:16 +0400 Subject: [PATCH 07/71] update cmake --- CMakeLists.txt | 2 ++ test/CMakeLists.txt | 1 + 2 files changed, 3 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index fc2ddfcb..586997fd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,6 +28,8 @@ add_library(neural-fortran src/nf/nf_dense_layer_submodule.f90 src/nf/nf_flatten_layer.f90 src/nf/nf_flatten_layer_submodule.f90 + src/nf/nf_flatten2d_layer.f90 + src/nf/nf_flatten2d_layer_submodule.f90 src/nf/nf_input1d_layer.f90 src/nf/nf_input1d_layer_submodule.f90 src/nf/nf_input2d_layer.f90 diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 12236416..b52a3781 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -8,6 +8,7 @@ foreach(execid conv2d_layer maxpool2d_layer flatten_layer + flatten2d_layer insert_flatten reshape_layer dense_network From 4cc7d1d6fefb65fadce91014e1869eadcfaf9f9c Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sun, 16 Feb 2025 23:08:55 +0400 Subject: [PATCH 08/71] linear2d_layer: remove flatten2d layer --- src/nf/nf_flatten2d_layer.f90 | 75 --------------------- src/nf/nf_flatten2d_layer_submodule.f90 | 48 ------------- 
src/nf/nf_layer_constructors.f90 | 21 +----- test/test_flatten2d_layer.f90 | 89 ------------------------- 4 files changed, 1 insertion(+), 232 deletions(-) delete mode 100644 src/nf/nf_flatten2d_layer.f90 delete mode 100644 src/nf/nf_flatten2d_layer_submodule.f90 delete mode 100644 test/test_flatten2d_layer.f90 diff --git a/src/nf/nf_flatten2d_layer.f90 b/src/nf/nf_flatten2d_layer.f90 deleted file mode 100644 index e67037f8..00000000 --- a/src/nf/nf_flatten2d_layer.f90 +++ /dev/null @@ -1,75 +0,0 @@ -module nf_flatten2d_layer - - !! This module provides the concrete flatten2d layer type. - !! It is used internally by the layer type. - !! It is not intended to be used directly by the user. - - use nf_base_layer, only: base_layer - - implicit none - - private - public :: flatten2d_layer - - type, extends(base_layer) :: flatten2d_layer - - !! Concrete implementation of a flatten2d (2-d to 1-d) layer. - - integer, allocatable :: input_shape(:) - integer :: output_size - - real, allocatable :: gradient(:,:) - real, allocatable :: output(:) - - contains - - procedure :: backward - procedure :: forward - procedure :: init - - end type flatten2d_layer - - interface flatten2d_layer - elemental module function flatten2d_layer_cons() result(res) - !! This function returns the `flatten2d_layer` instance. - type(flatten2d_layer) :: res - !! `flatten2d_layer` instance - end function flatten2d_layer_cons - end interface flatten2d_layer - - interface - - pure module subroutine backward(self, input, gradient) - !! Apply the backward pass to the flatten2d layer. - !! This is a reshape operation from 1-d gradient to 2-d input. - class(flatten2d_layer), intent(in out) :: self - !! flatten2d layer instance - real, intent(in) :: input(:,:) - !! Input from the previous layer - real, intent(in) :: gradient(:) - !! Gradient from the next layer - end subroutine backward - - pure module subroutine forward(self, input) - !! Propagate forward the layer. - !! Calling this subroutine updates the values of a few data components - !! of `flatten2d_layer` that are needed for the backward pass. - class(flatten2d_layer), intent(in out) :: self - !! Dense layer instance - real, intent(in) :: input(:,:) - !! Input from the previous layer - end subroutine forward - - module subroutine init(self, input_shape) - !! Initialize the layer data structures. - !! - !! This is a deferred procedure from the `base_layer` abstract type. - class(flatten2d_layer), intent(in out) :: self - !! Dense layer instance - integer, intent(in) :: input_shape(:) - !! Shape of the input layer - end subroutine init - - end interface - -end module nf_flatten2d_layer diff --git a/src/nf/nf_flatten2d_layer_submodule.f90 b/src/nf/nf_flatten2d_layer_submodule.f90 deleted file mode 100644 index 875b7374..00000000 --- a/src/nf/nf_flatten2d_layer_submodule.f90 +++ /dev/null @@ -1,48 +0,0 @@ -submodule(nf_flatten2d_layer) nf_flatten2d_layer_submodule - - !! This module provides the concrete flatten2d layer type. - !! It is used internally by the layer type. - !! It is not intended to be used directly by the user. 
- - use nf_base_layer, only: base_layer - - implicit none - -contains - - elemental module function flatten2d_layer_cons() result(res) - type(flatten2d_layer) :: res - end function flatten2d_layer_cons - - - pure module subroutine backward(self, input, gradient) - class(flatten2d_layer), intent(in out) :: self - real, intent(in) :: input(:,:) - real, intent(in) :: gradient(:) - self % gradient = reshape(gradient, shape(input)) - end subroutine backward - - - pure module subroutine forward(self, input) - class(flatten2d_layer), intent(in out) :: self - real, intent(in) :: input(:,:) - self % output = pack(input, .true.) - end subroutine forward - - - module subroutine init(self, input_shape) - class(flatten2d_layer), intent(in out) :: self - integer, intent(in) :: input_shape(:) - - self % input_shape = input_shape - self % output_size = product(input_shape) - - allocate(self % gradient(input_shape(1), input_shape(2))) - self % gradient = 0 - - allocate(self % output(self % output_size)) - self % output = 0 - - end subroutine init - -end submodule nf_flatten2d_layer_submodule diff --git a/src/nf/nf_layer_constructors.f90 b/src/nf/nf_layer_constructors.f90 index ef96a8fa..2983ddcd 100644 --- a/src/nf/nf_layer_constructors.f90 +++ b/src/nf/nf_layer_constructors.f90 @@ -8,7 +8,7 @@ module nf_layer_constructors implicit none private - public :: conv2d, dense, flatten, flatten2d, input, maxpool2d, reshape, linear2d + public :: conv2d, dense, flatten, input, maxpool2d, reshape, linear2d interface input @@ -125,25 +125,6 @@ module function flatten() result(res) !! Resulting layer instance end function flatten - module function flatten2d() result(res) - !! Flatten (2-d -> 1-d) layer constructor. - !! - !! Use this layer to chain layers with 2-d outputs to layers with 2-d - !! inputs. - !! - !! A flatten layer must not be the first layer in the network. - !! - !! Example: - !! - !! ``` - !! use nf, only :: flatten, layer - !! type(layer) :: flatten_layer - !! flatten_layer = flatten() - !! ``` - type(layer) :: res - !! Resulting layer instance - end function flatten2d - module function conv2d(filters, kernel_size, activation) result(res) !! 2-d convolutional layer constructor. !! diff --git a/test/test_flatten2d_layer.f90 b/test/test_flatten2d_layer.f90 deleted file mode 100644 index 3189b4e9..00000000 --- a/test/test_flatten2d_layer.f90 +++ /dev/null @@ -1,89 +0,0 @@ -program test_flatten2d_layer - - use iso_fortran_env, only: stderr => error_unit - use nf, only: dense, flatten2d, input, layer, network - use nf_flatten2d_layer, only: flatten2d_layer - use nf_input2d_layer, only: input2d_layer - - implicit none - - type(layer) :: test_layer, input_layer - type(network) :: net - real, allocatable :: gradient(:,:) - real, allocatable :: output(:) - logical :: ok = .true. - - test_layer = flatten2d() - - if (.not. test_layer % name == 'flatten2d') then - ok = .false. - write(stderr, '(a)') 'flatten2d layer has its name set correctly.. failed' - end if - - if (test_layer % initialized) then - ok = .false. - write(stderr, '(a)') 'flatten2d layer is not initialized yet.. failed' - end if - - input_layer = input(1, 2) - call test_layer % init(input_layer) - - if (.not. test_layer % initialized) then - ok = .false. - write(stderr, '(a)') 'flatten2d layer is now initialized.. failed' - end if - - if (.not. all(test_layer % layer_shape == [2])) then - ok = .false. - write(stderr, '(a)') 'flatten2d layer has an incorrect output shape.. failed' - end if - - ! 
Test forward pass - reshaping from 2-d to 1-d - - select type(this_layer => input_layer % p); type is(input2d_layer) - call this_layer % set(reshape(real([1, 2, 3, 4]), [2, 2])) - end select - - call test_layer % forward(input_layer) - call test_layer % get_output(output) - - if (.not. all(output == [1, 2, 3, 4])) then - ok = .false. - write(stderr, '(a)') 'flatten2d layer correctly propagates forward.. failed' - end if - - ! Test backward pass - reshaping from 1-d to 2-d - - ! Calling backward() will set the values on the gradient component - ! input_layer is used only to determine shape - call test_layer % backward(input_layer, real([1, 2, 3, 4])) - - select type(this_layer => test_layer % p); type is(flatten2d_layer) - gradient = this_layer % gradient - end select - - if (.not. all(gradient == reshape(real([1, 2, 3, 4]), [2, 2]))) then - ok = .false. - write(stderr, '(a)') 'flatten2d layer correctly propagates backward.. failed' - end if - - net = network([ & - input(28, 28), & - flatten2d(), & - dense(10) & - ]) - - ! Test that the output layer receives 784 elements in the input - if (.not. all(net % layers(3) % input_layer_shape == [784])) then - ok = .false. - write(stderr, '(a)') 'flatten2d layer correctly chains input2d to dense.. failed' - end if - - if (ok) then - print '(a)', 'test_flatten2d_layer: All tests passed.' - else - write(stderr, '(a)') 'test_flatten2d_layer: One or more tests failed.' - stop 1 - end if - -end program test_flatten2d_layer From d863ce704b93d09ac0de096674de8c9007353621 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sun, 16 Feb 2025 23:14:09 +0400 Subject: [PATCH 09/71] linear2d_layer: remove public api --- src/nf.f90 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nf.f90 b/src/nf.f90 index d215eb85..e9b027c1 100644 --- a/src/nf.f90 +++ b/src/nf.f90 @@ -3,7 +3,7 @@ module nf use nf_datasets_mnist, only: label_digits, load_mnist use nf_layer, only: layer use nf_layer_constructors, only: & - conv2d, dense, flatten, flatten2d, input, maxpool2d, reshape, linear2d + conv2d, dense, flatten, input, maxpool2d, reshape, linear2d use nf_loss, only: mse, quadratic use nf_metrics, only: corr, maxabs use nf_network, only: network From 78eb17a41b7109dd8008e549d2b624901377dde1 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sun, 16 Feb 2025 23:31:47 +0400 Subject: [PATCH 10/71] linear2d_layer: update cmakelists --- CMakeLists.txt | 2 -- test/CMakeLists.txt | 1 - 2 files changed, 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 586997fd..fc2ddfcb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,8 +28,6 @@ add_library(neural-fortran src/nf/nf_dense_layer_submodule.f90 src/nf/nf_flatten_layer.f90 src/nf/nf_flatten_layer_submodule.f90 - src/nf/nf_flatten2d_layer.f90 - src/nf/nf_flatten2d_layer_submodule.f90 src/nf/nf_input1d_layer.f90 src/nf/nf_input1d_layer_submodule.f90 src/nf/nf_input2d_layer.f90 diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index b52a3781..12236416 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -8,7 +8,6 @@ foreach(execid conv2d_layer maxpool2d_layer flatten_layer - flatten2d_layer insert_flatten reshape_layer dense_network From 567abc4759a5ba3fc038a6501a01a5d93a50874e Mon Sep 17 00:00:00 2001 From: milancurcic Date: Sun, 16 Feb 2025 22:45:34 -0500 Subject: [PATCH 11/71] Add linear2d example --- example/CMakeLists.txt | 1 + example/linear2d.f90 | 29 +++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+) create mode 100644 example/linear2d.f90 diff --git 
a/example/CMakeLists.txt b/example/CMakeLists.txt index 28cf71a7..0257dd7d 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -2,6 +2,7 @@ foreach(execid cnn_mnist dense_mnist get_set_network_params + linear2d network_parameters simple sine diff --git a/example/linear2d.f90 b/example/linear2d.f90 new file mode 100644 index 00000000..1b71f5d3 --- /dev/null +++ b/example/linear2d.f90 @@ -0,0 +1,29 @@ +program linear2d_example + + use nf, only: input, network, sgd, linear2d, mse, flatten + implicit none + + type(network) :: net + real :: x(3, 4) = reshape( & + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2], & + [3, 4]) + real :: y(3) = [0.12, 0.1, 0.3] + integer, parameter :: num_iterations = 500 + integer :: n + + net = network([ & + input(3, 4), & + linear2d(3, 4, 1), & + flatten() & + ]) + + call net % print_info() + + do n = 1, num_iterations + call net % forward(x) + call net % backward(y, mse()) + call net % update(optimizer=sgd(learning_rate=1.)) + print '(i4,3(3x,f8.6))', n, net % predict(x) + end do + +end program linear2d_example \ No newline at end of file From 32ac10de055eec8c68304023a338a6bbb54668a5 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Mon, 17 Feb 2025 12:10:44 +0400 Subject: [PATCH 12/71] linear2d_layer: remove redundant constructor args --- example/linear2d.f90 | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example/linear2d.f90 b/example/linear2d.f90 index 1b71f5d3..06c8b255 100644 --- a/example/linear2d.f90 +++ b/example/linear2d.f90 @@ -8,12 +8,12 @@ program linear2d_example [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2], & [3, 4]) real :: y(3) = [0.12, 0.1, 0.3] - integer, parameter :: num_iterations = 500 + integer, parameter :: num_iterations = 5 integer :: n net = network([ & input(3, 4), & - linear2d(3, 4, 1), & + linear2d(3, 1), & flatten() & ]) From edd169dcb603bcbab8d36e818460b7c23b81c9ac Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Mon, 17 Feb 2025 13:10:50 +0400 Subject: [PATCH 13/71] linear2d_layer: make example converge --- example/linear2d.f90 | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/example/linear2d.f90 b/example/linear2d.f90 index 06c8b255..79077723 100644 --- a/example/linear2d.f90 +++ b/example/linear2d.f90 @@ -5,10 +5,10 @@ program linear2d_example type(network) :: net real :: x(3, 4) = reshape( & - [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2], & + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.11, 0.12, 0.13], & [3, 4]) - real :: y(3) = [0.12, 0.1, 0.3] - integer, parameter :: num_iterations = 5 + real :: y(3) = [0.12, 0.1, 0.2] + integer, parameter :: num_iterations = 9 integer :: n net = network([ & @@ -22,7 +22,7 @@ program linear2d_example do n = 1, num_iterations call net % forward(x) call net % backward(y, mse()) - call net % update(optimizer=sgd(learning_rate=1.)) + call net % update(optimizer=sgd(learning_rate=0.01)) print '(i4,3(3x,f8.6))', n, net % predict(x) end do From aa5b83ff455cbef5f519875a8756cc6bcdad06a0 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Mon, 17 Feb 2025 13:48:36 +0400 Subject: [PATCH 14/71] linear2d_layer: add loss stopping and more iterations --- example/linear2d.f90 | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/example/linear2d.f90 b/example/linear2d.f90 index 79077723..980d45e4 100644 --- a/example/linear2d.f90 +++ b/example/linear2d.f90 @@ -4,11 +4,14 @@ program linear2d_example implicit none type(network) :: net + type(mse) :: loss real :: 
x(3, 4) = reshape( & [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.11, 0.12, 0.13], & [3, 4]) real :: y(3) = [0.12, 0.1, 0.2] - integer, parameter :: num_iterations = 9 + real :: preds(3) + real :: loss_value + integer, parameter :: num_iterations = 500 integer :: n net = network([ & @@ -18,12 +21,19 @@ program linear2d_example ]) call net % print_info() + loss = mse() do n = 1, num_iterations call net % forward(x) - call net % backward(y, mse()) + call net % backward(y, loss) call net % update(optimizer=sgd(learning_rate=0.01)) - print '(i4,3(3x,f8.6))', n, net % predict(x) + preds = net % predict(x) + print '(i4,3(3x,f8.6))', n, preds + loss_value = loss % eval (y, preds) + if (loss_value < 0.01) then + print *, 'Loss: ', loss_value + return + end if end do end program linear2d_example \ No newline at end of file From dd3ce3323a1a3c4a1bdf821d0f86c67090a85451 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Fri, 31 Jan 2025 21:41:20 +0400 Subject: [PATCH 15/71] start impementing MultiHeadAttention --- src/nf.f90 | 2 + src/nf/nf_multihead_attention.f90 | 184 ++++++++++++++++++++++++ test/test_multihead_attention_layer.f90 | 97 +++++++++++++ 3 files changed, 283 insertions(+) create mode 100644 src/nf/nf_multihead_attention.f90 create mode 100644 test/test_multihead_attention_layer.f90 diff --git a/src/nf.f90 b/src/nf.f90 index e9b027c1..7223c1a3 100644 --- a/src/nf.f90 +++ b/src/nf.f90 @@ -12,4 +12,6 @@ module nf gaussian, linear, relu, leaky_relu, & sigmoid, softmax, softplus, step, tanhf, & celu + use nf_linear2d_layer, only: linear2d_layer + use nf_multihead_attention_layer, only: multihead_attention_layer end module nf diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 new file mode 100644 index 00000000..535c2c4f --- /dev/null +++ b/src/nf/nf_multihead_attention.f90 @@ -0,0 +1,184 @@ +module nf_multihead_attention_layer + + !! This module provides the concrete dense layer type. + !! It is used internally by the layer type. + !! It is not intended to be used directly by the user. + + use nf_activation, only: softmax + use nf_base_layer, only: base_layer + use nf_dense_layer, only: dense_layer + + implicit none + + private + public :: multihead_attention_layer + + type, extends(base_layer) :: multihead_attention_layer + + !! Concrete implementation of a multihead attention layer type + + integer :: model_dimension, batch_size, sequence_length, n_heads + + type(dense_layer) :: query_layer + type(dense_layer) :: key_layer + type(dense_layer) :: value_layer + type(dense_layer) :: output_layer + + type(softmax) :: softmax_func + + contains + +! procedure :: backward + procedure :: forward + procedure :: split_heads + procedure :: create_attention_matrix + procedure :: normalize_attention_matrix + procedure :: init + + end type multihead_attention_layer + + interface multihead_attention_layer + module function multihead_attention_layer_cons(batch_size, sequence_length, model_dimension, n_heads) result(res) + !! This function returns the `multihead_attention_layer` instance. + integer, intent(in) :: batch_size, sequence_length, model_dimension, n_heads + type(multihead_attention_layer) :: res + end function multihead_attention_layer_cons + end interface multihead_attention_layer + + interface + + pure module subroutine backward(self, input, gradient) + !! Apply the backward gradient descent pass. + !! Only weight and bias gradients are updated in this subroutine, + !! while the weights and biases themselves are untouched. 
+ class(multihead_attention_layer), intent(in out) :: self + !! Dense layer instance + real, intent(in) :: input(:) + !! Input from the previous layer + real, intent(in) :: gradient(:) + !! Gradient from the next layer + end subroutine backward + + pure module subroutine forward(self, query, key, value) + !! Propagate forward the layer. + !! Calling this subroutine updates the values of a few data components + !! of `dense_layer` that are needed for the backward pass. + class(multihead_attention_layer), intent(in out) :: self + !! Dense layer instance + real, intent(in) :: query(:), key(:), value(:) + !! Input from the previous layer + end subroutine forward + + module subroutine init(self, input_shape) + !! Initialize the layer data structures. + !! + !! This is a deferred procedure from the `base_layer` abstract type. + class(multihead_attention_layer), intent(in out) :: self + !! Dense layer instance + integer, intent(in) :: input_shape(:) + !! Shape of the input layer + end subroutine init + + end interface + +contains + module function multihead_attention_layer_cons(& + batch_size, sequence_length, model_dimension, n_heads) result(res) + integer, intent(in) :: batch_size, sequence_length, model_dimension, n_heads + type(multihead_attention_layer) :: res + res % batch_size = batch_size + res % sequence_length = sequence_length + res % model_dimension = model_dimension + res % n_heads = n_heads + + res % query_layer = dense_layer(input_size=model_dimension, output_size=model_dimension) + res % key_layer = dense_layer(input_size=model_dimension, output_size=model_dimension) + res % value_layer = dense_layer(input_size=model_dimension, output_size=model_dimension) + res % output_layer = dense_layer(input_size=model_dimension, output_size=model_dimension) + + res % softmax_func = softmax() + end function multihead_attention_layer_cons + + pure module subroutine forward(self, query, key, value) + class(multihead_attention_layer), intent(in out) :: self + real, intent(in) :: query(:), key(:), value(:) + +! self % z = matmul(input, self % weights) + self % biases +! self % output = self % activation % eval(self % z) + + end subroutine forward + + module function split_heads(self, input) result(output) + !! Split inputs into heads + !! + !! Example with two heads: + !! input (1, 3, 4): + !! [[[0. , 0.3 , 0.6 , 0.9 ], + !! [0.1 , 0.4 , 0.7 , 0.11], + !! [0.2 , 0.5 , 0.8 , 0.12]]] + !! output (1, 2, 3, 2) + !! [[[[0. , 0.3 ], + ! [0.1 , 0.4 ], + ! [0.2 , 0.5 ]], + ! [[0.6 , 0.9 ], + ! [0.7 , 0.11], + ! [0.8 , 0.12]]]] + class(multihead_attention_layer) :: self + real :: input(:, :, :) + real :: output(self % batch_size, self % n_heads, self % sequence_length, self % model_dimension / self % n_heads) + output = reshape(& + input,& + [self % batch_size, self % n_heads, self % sequence_length, self % model_dimension / self % n_heads],& + order=[1, 3, 4, 2]& + ) + end function split_heads + + module function create_attention_matrix(self, query, key) result(output) + !! Create attention matrix for query and key + class(multihead_attention_layer) :: self + real :: query(:, :, :, :) + real :: key(:, :, :, :) + real :: output(self % batch_size, self % n_heads, self % sequence_length, self % sequence_length) + integer :: i, j + ! 
create attention matrix for each sequence in each batch + do i = 1, size(query(:, 1, 1, 1)) + do j = 1, size(query(1, :, 1, 1)) + output(i, j, :, :) = matmul(query(i, j, :, :), transpose(key(i, j, :, :))) + end do + end do + end function create_attention_matrix + + module function normalize_attention_matrix(self, attention_matrix, attention_mask) result(output) + !! Create attention matrix for query and key + class(multihead_attention_layer) :: self + real :: attention_matrix(:, :, :, :) + !! (batch_size, n_heads, sequence_length, sequence_length) + real, optional :: attention_mask(:, :, :, :) + !! (batch_size, n_heads, sequence_length, sequence_length) + real :: output(self % batch_size, self % n_heads, self % sequence_length, self % sequence_length) + integer :: i, j, k + real :: d_k + + ! scale dowm by square root of each head's size + d_k = self % model_dimension / self % n_heads + attention_matrix = attention_matrix / sqrt(d_k) + ! attention mask is used to mask out some of the tokens if necessary + if (present(attention_mask)) then + attention_matrix = attention_matrix + attention_mask + end if + ! softmax by last dimension + do i = 1, size(output, 1) + do j = 1, size(output, 2) + do k = 1, size(output, 3) + output(i, j, k, :) = self % softmax_func % eval_1d(attention_matrix(i, j, k, :)) + end do + end do + end do + end function normalize_attention_matrix + + module subroutine init(self, input_shape) + class(multihead_attention_layer), intent(in out) :: self + integer, intent(in) :: input_shape(:) + + end subroutine init +end module nf_multihead_attention_layer diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90 new file mode 100644 index 00000000..8a5778ee --- /dev/null +++ b/test/test_multihead_attention_layer.f90 @@ -0,0 +1,97 @@ +program test_multihead_attention_layer + use iso_fortran_env, only: stderr => error_unit + use nf_multihead_attention_layer, only: multihead_attention_layer + implicit none + + logical :: ok = .true. + type(multihead_attention_layer) :: attention + real :: sample_input(1, 3, 4) = reshape([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.11, 0.12], [1, 3, 4]) + real :: split_heads_output(1, 2, 3, 2) + real :: raw_attention_matrix(1, 2, 3, 3) + real :: normalized_attention_matrix(1, 2, 3, 3) + + attention = multihead_attention_layer(batch_size=1, sequence_length=3, model_dimension=4, n_heads=2) + + call test_multihead_attention_split_heads(attention, sample_input, ok, split_heads_output) + call test_multihead_attention_create_attention_matrix(attention, split_heads_output, ok, raw_attention_matrix) + call test_multihead_attention_normalization(attention, raw_attention_matrix, ok) + +contains + subroutine test_multihead_attention_split_heads(attention, input, ok, output) + type(multihead_attention_layer), intent(in) :: attention + real, intent(in) :: input(:, :, :) + logical, intent(in out) :: ok + real, intent(in out) :: output(1, 2, 3, 2) + real :: output_shape(4) + real :: expected_shape(4) = [1, 2, 3, 2] + real :: output_flat(12) + real :: expected_output_flat(12) = [0.0, 0.6, 0.1, 0.7, 0.2, 0.8, 0.3, 0.9, 0.4, 0.11, 0.5, 0.12] + + output = attention % split_heads(input) + + output_shape = shape(output) + if (.not. all(output_shape.eq.expected_shape)) then + ok = .false. + write(stderr, '(a)') 'split_heads returned incorrect shape.. failed' + end if + output_flat = reshape(output, shape(output_flat)) + if (.not. all(output_flat.eq.expected_output_flat)) then + ok = .false. 
+ write(stderr, '(a)') 'split_heads returned incorrect values.. failed' + end if + end subroutine test_multihead_attention_split_heads + + subroutine test_multihead_attention_create_attention_matrix(attention, input, ok, attention_matrix) + type(multihead_attention_layer), intent(in) :: attention + real, intent(in) :: input(:, :, :, :) + logical, intent(in out) :: ok + real, intent(in out) :: attention_matrix(1, 2, 3, 3) + real :: attention_matrix_shape(4) + real :: attention_matrix_flat(18) + real :: expected_shape(4) = [1, 2, 3, 3] + real :: expected_attention_matrix_flat(18) = [& + 9.00000036E-02, 1.16999996, 0.120000005,& + 0.518999994, 0.150000006, 0.588000000,& + 0.120000005, 0.518999994, 0.170000002,& + 0.502099991, 0.219999999, 0.573199987,& + 0.150000006, 0.588000000, 0.219999999,& + 0.573199987, 0.289999992, 0.654400051& + ] + + attention_matrix = attention % create_attention_matrix(input, input) + + attention_matrix_shape = shape(attention_matrix) + if (.not. all(attention_matrix_shape.eq.expected_shape)) then + ok = .false. + write(stderr, '(a)') 'create_attention_matrix returned incorrect shape.. failed' + end if + attention_matrix_flat = reshape(attention_matrix, shape(expected_attention_matrix_flat)) + if (.not. all(attention_matrix_flat.eq.expected_attention_matrix_flat)) then + ok = .false. + write(stderr, '(a)') 'create_attention_matrix returned incorrect values.. failed' + end if + end subroutine test_multihead_attention_create_attention_matrix + + subroutine test_multihead_attention_normalization(attention, input, ok) + type(multihead_attention_layer), intent(in) :: attention + real, intent(in) :: input(:, :, :, :) + logical, intent(in out) :: ok + real :: output(1, 2, 3, 3) + real :: output_flat(18) + real :: expected_output_flat(18) = [& + 0.326287806, 0.435975075, 0.321620107, 0.330339342, 0.316976935, 0.329200655,& + 0.333283335, 0.275134116, 0.333194464, 0.326415271, 0.333061278, 0.325773478,& + 0.340428889, 0.288890868, 0.345185399, 0.343245387, 0.349961787, 0.345025837& + ] + integer :: i, j, k + real :: d_k, exp_x + + output = attention % normalize_attention_matrix(input) + + output_flat = reshape(output, shape(output_flat)) + if (.not. all(output_flat.eq.expected_output_flat)) then + ok = .false. + write(stderr, '(a)') 'normalize_attention_matrix returned incorrect values.. failed' + end if + end subroutine test_multihead_attention_normalization +end program test_multihead_attention_layer \ No newline at end of file From 0ed77e857f0ee7cd63a0d23095afa1b70c0450b1 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Fri, 31 Jan 2025 22:39:01 +0400 Subject: [PATCH 16/71] scaled dot product attention --- src/nf/nf_multihead_attention.f90 | 37 +++++++++++++++++-------- test/test_multihead_attention_layer.f90 | 32 +++++++++++++++++++-- 2 files changed, 55 insertions(+), 14 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index 535c2c4f..628f7b79 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -1,9 +1,5 @@ module nf_multihead_attention_layer - - !! This module provides the concrete dense layer type. - !! It is used internally by the layer type. - !! It is not intended to be used directly by the user. - + use iso_fortran_env, only: stderr => error_unit use nf_activation, only: softmax use nf_base_layer, only: base_layer use nf_dense_layer, only: dense_layer @@ -17,7 +13,7 @@ module nf_multihead_attention_layer !! 
Concrete implementation of a multihead attention layer type - integer :: model_dimension, batch_size, sequence_length, n_heads + integer :: model_dimension, batch_size, sequence_length, n_heads, head_size type(dense_layer) :: query_layer type(dense_layer) :: key_layer @@ -33,6 +29,7 @@ module nf_multihead_attention_layer procedure :: split_heads procedure :: create_attention_matrix procedure :: normalize_attention_matrix + procedure :: scaled_dot_product_attention procedure :: init end type multihead_attention_layer @@ -91,6 +88,12 @@ module function multihead_attention_layer_cons(& res % model_dimension = model_dimension res % n_heads = n_heads + if (mod(model_dimension, n_heads) /= 0) then + write(stderr, '(a)'), 'Number of heads must be divisible by model dimension' + error stop + end if + res % head_size = model_dimension / n_heads + res % query_layer = dense_layer(input_size=model_dimension, output_size=model_dimension) res % key_layer = dense_layer(input_size=model_dimension, output_size=model_dimension) res % value_layer = dense_layer(input_size=model_dimension, output_size=model_dimension) @@ -125,10 +128,10 @@ module function split_heads(self, input) result(output) ! [0.8 , 0.12]]]] class(multihead_attention_layer) :: self real :: input(:, :, :) - real :: output(self % batch_size, self % n_heads, self % sequence_length, self % model_dimension / self % n_heads) + real :: output(self % batch_size, self % n_heads, self % sequence_length, self % head_size) output = reshape(& input,& - [self % batch_size, self % n_heads, self % sequence_length, self % model_dimension / self % n_heads],& + [self % batch_size, self % n_heads, self % sequence_length, self % head_size],& order=[1, 3, 4, 2]& ) end function split_heads @@ -157,11 +160,9 @@ module function normalize_attention_matrix(self, attention_matrix, attention_mas !! (batch_size, n_heads, sequence_length, sequence_length) real :: output(self % batch_size, self % n_heads, self % sequence_length, self % sequence_length) integer :: i, j, k - real :: d_k ! scale dowm by square root of each head's size - d_k = self % model_dimension / self % n_heads - attention_matrix = attention_matrix / sqrt(d_k) + attention_matrix = attention_matrix / sqrt(real(self % head_size)) ! 
attention mask is used to mask out some of the tokens if necessary if (present(attention_mask)) then attention_matrix = attention_matrix + attention_mask @@ -176,6 +177,20 @@ module function normalize_attention_matrix(self, attention_matrix, attention_mas end do end function normalize_attention_matrix + module function scaled_dot_product_attention(self, attention_matrix, value) result(output) + class(multihead_attention_layer) :: self + real :: attention_matrix(:, :, :, :) + real :: value(:, :, :, :) + real :: output(self % batch_size, self % n_heads, self % sequence_length, self % head_size) + integer :: i, j + + do i = 1, size(attention_matrix, 1) + do j = 1, size(attention_matrix, 2) + output(i, j, :, :) = matmul(attention_matrix(i, j, :, :), value(i, j, :, :)) + end do + end do + end function scaled_dot_product_attention + module subroutine init(self, input_shape) class(multihead_attention_layer), intent(in out) :: self integer, intent(in) :: input_shape(:) diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90 index 8a5778ee..2af98329 100644 --- a/test/test_multihead_attention_layer.f90 +++ b/test/test_multihead_attention_layer.f90 @@ -14,7 +14,10 @@ program test_multihead_attention_layer call test_multihead_attention_split_heads(attention, sample_input, ok, split_heads_output) call test_multihead_attention_create_attention_matrix(attention, split_heads_output, ok, raw_attention_matrix) - call test_multihead_attention_normalization(attention, raw_attention_matrix, ok) + call test_multihead_attention_normalization(attention, raw_attention_matrix, ok, normalized_attention_matrix) + call test_multihead_attention_scaled_dot_product_attention(& + attention, normalized_attention_matrix, split_heads_output, ok& + ) contains subroutine test_multihead_attention_split_heads(attention, input, ok, output) @@ -72,11 +75,11 @@ subroutine test_multihead_attention_create_attention_matrix(attention, input, ok end if end subroutine test_multihead_attention_create_attention_matrix - subroutine test_multihead_attention_normalization(attention, input, ok) + subroutine test_multihead_attention_normalization(attention, input, ok, output) type(multihead_attention_layer), intent(in) :: attention real, intent(in) :: input(:, :, :, :) logical, intent(in out) :: ok - real :: output(1, 2, 3, 3) + real, intent(out) :: output(1, 2, 3, 3) real :: output_flat(18) real :: expected_output_flat(18) = [& 0.326287806, 0.435975075, 0.321620107, 0.330339342, 0.316976935, 0.329200655,& @@ -94,4 +97,27 @@ subroutine test_multihead_attention_normalization(attention, input, ok) write(stderr, '(a)') 'normalize_attention_matrix returned incorrect values.. failed' end if end subroutine test_multihead_attention_normalization + + subroutine test_multihead_attention_scaled_dot_product_attention(attention, attention_matrix, value, ok) + type(multihead_attention_layer), intent(in) :: attention + real, intent(in) :: attention_matrix(:, :, :, :) + real, intent(in) :: value(:, :, :, :) + logical, intent(in out) :: ok + real :: output(1, 2, 3, 2) + real :: output_flat(12) + real :: expected_output_flat(12) = [& + 0.101414114, 0.685291648, 0.102356531, 0.701290607, 0.103298485, 0.701582491,& + 0.401414126, 0.457309216, 0.402356505, 0.374400526, 0.403298497, 0.373518765& + ] + integer :: i, j, k + real :: d_k, exp_x + + output = attention % scaled_dot_product_attention(attention_matrix, value) + + output_flat = reshape(output, shape(output_flat)) + if (.not. 
all(output_flat.eq.expected_output_flat)) then + ok = .false. + write(stderr, '(a)') 'normalize_attention_matrix returned incorrect values.. failed' + end if + end subroutine test_multihead_attention_scaled_dot_product_attention end program test_multihead_attention_layer \ No newline at end of file From d6e6f3eb908242cbe4dbf57f057734cc87d35863 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sat, 1 Feb 2025 00:35:04 +0400 Subject: [PATCH 17/71] combine attention heads --- src/nf/nf_multihead_attention.f90 | 36 ++++++++++++++++-------- test/test_multihead_attention_layer.f90 | 37 +++++++++++++++++++++---- 2 files changed, 55 insertions(+), 18 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index 628f7b79..81b5194f 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -13,7 +13,7 @@ module nf_multihead_attention_layer !! Concrete implementation of a multihead attention layer type - integer :: model_dimension, batch_size, sequence_length, n_heads, head_size + integer :: batch_size, sequence_length, model_dimension, n_heads, head_size type(dense_layer) :: query_layer type(dense_layer) :: key_layer @@ -22,6 +22,8 @@ module nf_multihead_attention_layer type(softmax) :: softmax_func +! real :: output(batch_size, sequence_length, model_dimension) + contains ! procedure :: backward @@ -30,6 +32,7 @@ module nf_multihead_attention_layer procedure :: create_attention_matrix procedure :: normalize_attention_matrix procedure :: scaled_dot_product_attention + procedure :: combine_heads procedure :: init end type multihead_attention_layer @@ -57,13 +60,8 @@ pure module subroutine backward(self, input, gradient) end subroutine backward pure module subroutine forward(self, query, key, value) - !! Propagate forward the layer. - !! Calling this subroutine updates the values of a few data components - !! of `dense_layer` that are needed for the backward pass. class(multihead_attention_layer), intent(in out) :: self - !! Dense layer instance - real, intent(in) :: query(:), key(:), value(:) - !! Input from the previous layer + real, intent(in) :: query(:, :, :, :), key(:, :, :, :), value(:, :, :, :) end subroutine forward module subroutine init(self, input_shape) @@ -104,11 +102,7 @@ end function multihead_attention_layer_cons pure module subroutine forward(self, query, key, value) class(multihead_attention_layer), intent(in out) :: self - real, intent(in) :: query(:), key(:), value(:) - -! self % z = matmul(input, self % weights) + self % biases -! self % output = self % activation % eval(self % z) - + real, intent(in) :: query(:, :, :, :), key(:, :, :, :), value(:, :, :, :) end subroutine forward module function split_heads(self, input) result(output) @@ -191,6 +185,24 @@ module function scaled_dot_product_attention(self, attention_matrix, value) resu end do end function scaled_dot_product_attention + module function combine_heads(self, input) result(output) + class(multihead_attention_layer) :: self + real :: input(:, :, :, :) + !! (batch_size, n_heads, sequence_length, head_size) + real :: output(self % batch_size, self % sequence_length, self % model_dimension) + !! 
(batch_size, sequence_length, model_dimension) + + real :: scaled_dp_att_reshaped(self % batch_size, self % sequence_length, self % n_heads, self % head_size) + integer :: i, j + + scaled_dp_att_reshaped = reshape(input, shape(scaled_dp_att_reshaped), order=[1, 4, 2, 3]) + do i = 1, size(scaled_dp_att_reshaped, 1) + do j = 1, size(scaled_dp_att_reshaped, 2) + output(i, j, :) = reshape(scaled_dp_att_reshaped(i, j, :, :), [4]) + end do + end do + end function combine_heads + module subroutine init(self, input_shape) class(multihead_attention_layer), intent(in out) :: self integer, intent(in) :: input_shape(:) diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90 index 2af98329..e4754d31 100644 --- a/test/test_multihead_attention_layer.f90 +++ b/test/test_multihead_attention_layer.f90 @@ -9,6 +9,10 @@ program test_multihead_attention_layer real :: split_heads_output(1, 2, 3, 2) real :: raw_attention_matrix(1, 2, 3, 3) real :: normalized_attention_matrix(1, 2, 3, 3) + real :: scaled_dp_att(1, 2, 3, 2) + real :: scaled_dp_att_reshaped(1, 3, 2, 2) + real :: combined_attention(1, 3, 4) + integer :: i, j attention = multihead_attention_layer(batch_size=1, sequence_length=3, model_dimension=4, n_heads=2) @@ -16,8 +20,9 @@ program test_multihead_attention_layer call test_multihead_attention_create_attention_matrix(attention, split_heads_output, ok, raw_attention_matrix) call test_multihead_attention_normalization(attention, raw_attention_matrix, ok, normalized_attention_matrix) call test_multihead_attention_scaled_dot_product_attention(& - attention, normalized_attention_matrix, split_heads_output, ok& + attention, normalized_attention_matrix, split_heads_output, ok, scaled_dp_att& ) + call test_multihead_attention_combine_heads(attention, scaled_dp_att, ok) contains subroutine test_multihead_attention_split_heads(attention, input, ok, output) @@ -98,26 +103,46 @@ subroutine test_multihead_attention_normalization(attention, input, ok, output) end if end subroutine test_multihead_attention_normalization - subroutine test_multihead_attention_scaled_dot_product_attention(attention, attention_matrix, value, ok) + subroutine test_multihead_attention_scaled_dot_product_attention(attention, attention_matrix, value, ok, output) type(multihead_attention_layer), intent(in) :: attention real, intent(in) :: attention_matrix(:, :, :, :) real, intent(in) :: value(:, :, :, :) logical, intent(in out) :: ok - real :: output(1, 2, 3, 2) + real, intent(out) :: output(& + attention % batch_size, attention % n_heads, attention % sequence_length, attention % head_size& + ) real :: output_flat(12) real :: expected_output_flat(12) = [& 0.101414114, 0.685291648, 0.102356531, 0.701290607, 0.103298485, 0.701582491,& 0.401414126, 0.457309216, 0.402356505, 0.374400526, 0.403298497, 0.373518765& ] - integer :: i, j, k - real :: d_k, exp_x output = attention % scaled_dot_product_attention(attention_matrix, value) output_flat = reshape(output, shape(output_flat)) if (.not. all(output_flat.eq.expected_output_flat)) then ok = .false. - write(stderr, '(a)') 'normalize_attention_matrix returned incorrect values.. failed' + write(stderr, '(a)') 'scaled_dot_product_attention returned incorrect values.. 
failed' end if end subroutine test_multihead_attention_scaled_dot_product_attention + + subroutine test_multihead_attention_combine_heads(attention, scaled_dp_att, ok) + type(multihead_attention_layer), intent(in) :: attention + real, intent(in) :: scaled_dp_att(:, :, :, :) + logical, intent(in out) :: ok + real :: output(attention % batch_size, attention % sequence_length, attention % model_dimension) + real :: output_flat(12) + real :: expected_output_flat(12) = [& + 0.101414114, 0.102356531, 0.103298485, 0.401414126, 0.402356505, 0.403298497,& + 0.685291648, 0.701290607, 0.701582491, 0.457309216, 0.374400526, 0.373518765& + ] + + output = attention % combine_heads(scaled_dp_att) + + output_flat = reshape(output, shape(output_flat)) + if (.not. all(output_flat.eq.expected_output_flat)) then + ok = .false. + write(stderr, '(a)') 'combine_heads returned incorrect values.. failed' + end if + end subroutine test_multihead_attention_combine_heads end program test_multihead_attention_layer \ No newline at end of file From eb58006b1483350cdf6d8cfda29d0f80e8507e6e Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sat, 1 Feb 2025 01:51:42 +0400 Subject: [PATCH 18/71] forward (not working) --- src/nf/nf_multihead_attention.f90 | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index 81b5194f..467447b3 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -22,7 +22,7 @@ module nf_multihead_attention_layer type(softmax) :: softmax_func -! real :: output(batch_size, sequence_length, model_dimension) + real, allocatable :: output(:, :, :) contains @@ -102,7 +102,26 @@ end function multihead_attention_layer_cons pure module subroutine forward(self, query, key, value) class(multihead_attention_layer), intent(in out) :: self - real, intent(in) :: query(:, :, :, :), key(:, :, :, :), value(:, :, :, :) + real, intent(in) :: query(:, :, :), key(:, :, :), value(:, :, :) + + real :: q(self % batch_size, self % n_heads, self % sequence_length, self % head_size) + real :: k(self % batch_size, self % n_heads, self % sequence_length, self % head_size) + real :: v(self % batch_size, self % n_heads, self % sequence_length, self % head_size) + real :: attention_matrix(self % batch_size, self % n_heads, self % sequence_length, self % sequence_length) + real :: dot_product_attention(self % batch_size, self % n_heads, self % sequence_length, self % head_size) + + call self % query_layer % forward(query) + call self % key_layer % forward(key) + call self % value_layer % forward(value) + + q = self % split_heads(self % query_layer % output) + k = self % split_heads(self % key_layer % output) + v = self % split_heads(self % value_layer % output) + + attention_matrix = self % normalize_attention_matrix(self % create_attention_matrix(q, k)) + dot_product_attention = self % scaled_dot_product_attention(attention_matrix, v) + + self % output = self % output_layer % forward(self % combine_heads(dot_product_attention)) end subroutine forward module function split_heads(self, input) result(output) @@ -207,5 +226,6 @@ module subroutine init(self, input_shape) class(multihead_attention_layer), intent(in out) :: self integer, intent(in) :: input_shape(:) + allocate(self % output(self % batch_size, self % sequence_length, self % model_dimension)) end subroutine init end module nf_multihead_attention_layer From 452032e7a620a1cae4442cb3c48a139a22e2a0b6 Mon Sep 17 00:00:00 2001 From: 
Mikhail Voronov Date: Wed, 5 Feb 2025 22:46:43 +0400 Subject: [PATCH 19/71] rearrange attention dimensions in more efficient way --- src/nf/nf_multihead_attention.f90 | 85 ++++++++++++------------- test/test_multihead_attention_layer.f90 | 28 ++++---- 2 files changed, 55 insertions(+), 58 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index 467447b3..4260293f 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -61,7 +61,7 @@ end subroutine backward pure module subroutine forward(self, query, key, value) class(multihead_attention_layer), intent(in out) :: self - real, intent(in) :: query(:, :, :, :), key(:, :, :, :), value(:, :, :, :) + real, intent(in) :: query(:, :, :), key(:, :, :), value(:, :, :) end subroutine forward module subroutine init(self, input_shape) @@ -110,18 +110,18 @@ pure module subroutine forward(self, query, key, value) real :: attention_matrix(self % batch_size, self % n_heads, self % sequence_length, self % sequence_length) real :: dot_product_attention(self % batch_size, self % n_heads, self % sequence_length, self % head_size) - call self % query_layer % forward(query) - call self % key_layer % forward(key) - call self % value_layer % forward(value) - - q = self % split_heads(self % query_layer % output) - k = self % split_heads(self % key_layer % output) - v = self % split_heads(self % value_layer % output) - - attention_matrix = self % normalize_attention_matrix(self % create_attention_matrix(q, k)) - dot_product_attention = self % scaled_dot_product_attention(attention_matrix, v) - - self % output = self % output_layer % forward(self % combine_heads(dot_product_attention)) +! call self % query_layer % forward(query) +! call self % key_layer % forward(key) +! call self % value_layer % forward(value) +! +! q = self % split_heads(self % query_layer % output) +! k = self % split_heads(self % key_layer % output) +! v = self % split_heads(self % value_layer % output) +! +! attention_matrix = self % normalize_attention_matrix(self % create_attention_matrix(q, k)) +! dot_product_attention = self % scaled_dot_product_attention(attention_matrix, v) +! +! self % output = self % output_layer % forward(self % combine_heads(dot_product_attention)) end subroutine forward module function split_heads(self, input) result(output) @@ -141,11 +141,12 @@ module function split_heads(self, input) result(output) ! [0.8 , 0.12]]]] class(multihead_attention_layer) :: self real :: input(:, :, :) - real :: output(self % batch_size, self % n_heads, self % sequence_length, self % head_size) + real :: output(self % n_heads, self % sequence_length, self % head_size, self % batch_size) + ! FIXME: if anybody knows how to also swap first two dims in one go, pls tell me output = reshape(& input,& - [self % batch_size, self % n_heads, self % sequence_length, self % head_size],& - order=[1, 3, 4, 2]& + [self % n_heads, self % sequence_length, self % head_size, self % batch_size],& + order=[2, 4, 3, 1]& ) end function split_heads @@ -154,12 +155,12 @@ module function create_attention_matrix(self, query, key) result(output) class(multihead_attention_layer) :: self real :: query(:, :, :, :) real :: key(:, :, :, :) - real :: output(self % batch_size, self % n_heads, self % sequence_length, self % sequence_length) + real :: output(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size) integer :: i, j ! 
create attention matrix for each sequence in each batch - do i = 1, size(query(:, 1, 1, 1)) - do j = 1, size(query(1, :, 1, 1)) - output(i, j, :, :) = matmul(query(i, j, :, :), transpose(key(i, j, :, :))) + do i = 1, self % batch_size + do j = 1, self % n_heads + output(j, :, :, i) = matmul(query(j, :, :, i), transpose(key(j, :, :, i))) end do end do end function create_attention_matrix @@ -171,8 +172,8 @@ module function normalize_attention_matrix(self, attention_matrix, attention_mas !! (batch_size, n_heads, sequence_length, sequence_length) real, optional :: attention_mask(:, :, :, :) !! (batch_size, n_heads, sequence_length, sequence_length) - real :: output(self % batch_size, self % n_heads, self % sequence_length, self % sequence_length) - integer :: i, j, k + real :: output(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size) + integer :: batch, head, seq ! scale dowm by square root of each head's size attention_matrix = attention_matrix / sqrt(real(self % head_size)) @@ -180,11 +181,11 @@ module function normalize_attention_matrix(self, attention_matrix, attention_mas if (present(attention_mask)) then attention_matrix = attention_matrix + attention_mask end if - ! softmax by last dimension - do i = 1, size(output, 1) - do j = 1, size(output, 2) - do k = 1, size(output, 3) - output(i, j, k, :) = self % softmax_func % eval_1d(attention_matrix(i, j, k, :)) + ! softmax by second sequnce_length + do batch = 1, self % batch_size + do head = 1, self % n_heads + do seq = 1, self % sequence_length + output(head, seq, :, batch) = self % softmax_func % eval_1d(attention_matrix(head, seq, :, batch)) end do end do end do @@ -194,12 +195,12 @@ module function scaled_dot_product_attention(self, attention_matrix, value) resu class(multihead_attention_layer) :: self real :: attention_matrix(:, :, :, :) real :: value(:, :, :, :) - real :: output(self % batch_size, self % n_heads, self % sequence_length, self % head_size) - integer :: i, j + real :: output(self % n_heads, self % sequence_length, self % head_size, self % batch_size) + integer :: batch, head - do i = 1, size(attention_matrix, 1) - do j = 1, size(attention_matrix, 2) - output(i, j, :, :) = matmul(attention_matrix(i, j, :, :), value(i, j, :, :)) + do batch = 1, self % batch_size + do head = 1, self % n_heads + output(head, :, :, batch) = matmul(attention_matrix(head, :, :, batch), value(head, :, :, batch)) end do end do end function scaled_dot_product_attention @@ -207,17 +208,15 @@ end function scaled_dot_product_attention module function combine_heads(self, input) result(output) class(multihead_attention_layer) :: self real :: input(:, :, :, :) - !! (batch_size, n_heads, sequence_length, head_size) - real :: output(self % batch_size, self % sequence_length, self % model_dimension) - !! (batch_size, sequence_length, model_dimension) - - real :: scaled_dp_att_reshaped(self % batch_size, self % sequence_length, self % n_heads, self % head_size) - integer :: i, j - - scaled_dp_att_reshaped = reshape(input, shape(scaled_dp_att_reshaped), order=[1, 4, 2, 3]) - do i = 1, size(scaled_dp_att_reshaped, 1) - do j = 1, size(scaled_dp_att_reshaped, 2) - output(i, j, :) = reshape(scaled_dp_att_reshaped(i, j, :, :), [4]) + !! 
(n_heads, sequence_length, head_size, batch_size) + real :: output(self % sequence_length, self % model_dimension, self % batch_size) + integer :: batch, seq + + do batch = 1, self % batch_size + do seq = 1, self % sequence_length + output(seq, :, batch) = reshape(& + transpose(input(:, seq, :, batch)), [self % model_dimension]& + ) end do end do end function combine_heads diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90 index e4754d31..d2c8d093 100644 --- a/test/test_multihead_attention_layer.f90 +++ b/test/test_multihead_attention_layer.f90 @@ -5,13 +5,13 @@ program test_multihead_attention_layer logical :: ok = .true. type(multihead_attention_layer) :: attention - real :: sample_input(1, 3, 4) = reshape([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.11, 0.12], [1, 3, 4]) - real :: split_heads_output(1, 2, 3, 2) - real :: raw_attention_matrix(1, 2, 3, 3) - real :: normalized_attention_matrix(1, 2, 3, 3) - real :: scaled_dp_att(1, 2, 3, 2) + real :: sample_input(3, 4, 1) = reshape([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.11, 0.12], [3, 4, 1]) + real :: split_heads_output(2, 3, 2, 1) + real :: raw_attention_matrix(2, 3, 3, 1) + real :: normalized_attention_matrix(2, 3, 3, 1) + real :: scaled_dp_att(2, 3, 2, 1) real :: scaled_dp_att_reshaped(1, 3, 2, 2) - real :: combined_attention(1, 3, 4) + real :: combined_attention(3, 4, 1) integer :: i, j attention = multihead_attention_layer(batch_size=1, sequence_length=3, model_dimension=4, n_heads=2) @@ -29,9 +29,9 @@ subroutine test_multihead_attention_split_heads(attention, input, ok, output) type(multihead_attention_layer), intent(in) :: attention real, intent(in) :: input(:, :, :) logical, intent(in out) :: ok - real, intent(in out) :: output(1, 2, 3, 2) + real, intent(in out) :: output(2, 3, 2, 1) real :: output_shape(4) - real :: expected_shape(4) = [1, 2, 3, 2] + real :: expected_shape(4) = [2, 3, 2, 1] real :: output_flat(12) real :: expected_output_flat(12) = [0.0, 0.6, 0.1, 0.7, 0.2, 0.8, 0.3, 0.9, 0.4, 0.11, 0.5, 0.12] @@ -53,10 +53,10 @@ subroutine test_multihead_attention_create_attention_matrix(attention, input, ok type(multihead_attention_layer), intent(in) :: attention real, intent(in) :: input(:, :, :, :) logical, intent(in out) :: ok - real, intent(in out) :: attention_matrix(1, 2, 3, 3) + real, intent(in out) :: attention_matrix(2, 3, 3, 1) real :: attention_matrix_shape(4) real :: attention_matrix_flat(18) - real :: expected_shape(4) = [1, 2, 3, 3] + real :: expected_shape(4) = [2, 3, 3, 1] real :: expected_attention_matrix_flat(18) = [& 9.00000036E-02, 1.16999996, 0.120000005,& 0.518999994, 0.150000006, 0.588000000,& @@ -84,15 +84,13 @@ subroutine test_multihead_attention_normalization(attention, input, ok, output) type(multihead_attention_layer), intent(in) :: attention real, intent(in) :: input(:, :, :, :) logical, intent(in out) :: ok - real, intent(out) :: output(1, 2, 3, 3) + real, intent(out) :: output(2, 3, 3, 1) real :: output_flat(18) real :: expected_output_flat(18) = [& 0.326287806, 0.435975075, 0.321620107, 0.330339342, 0.316976935, 0.329200655,& 0.333283335, 0.275134116, 0.333194464, 0.326415271, 0.333061278, 0.325773478,& 0.340428889, 0.288890868, 0.345185399, 0.343245387, 0.349961787, 0.345025837& ] - integer :: i, j, k - real :: d_k, exp_x output = attention % normalize_attention_matrix(input) @@ -109,7 +107,7 @@ subroutine test_multihead_attention_scaled_dot_product_attention(attention, atte real, intent(in) :: value(:, :, :, :) logical, intent(in 
out) :: ok real, intent(out) :: output(& - attention % batch_size, attention % n_heads, attention % sequence_length, attention % head_size& + attention % n_heads, attention % sequence_length, attention % head_size, attention % batch_size& ) real :: output_flat(12) real :: expected_output_flat(12) = [& @@ -130,7 +128,7 @@ subroutine test_multihead_attention_combine_heads(attention, scaled_dp_att, ok) type(multihead_attention_layer), intent(in) :: attention real, intent(in) :: scaled_dp_att(:, :, :, :) logical, intent(in out) :: ok - real :: output(attention % batch_size, attention % sequence_length, attention % model_dimension) + real :: output(attention % sequence_length, attention % model_dimension, attention % batch_size) real :: output_flat(12) real :: expected_output_flat(12) = [& 0.101414114, 0.102356531, 0.103298485, 0.401414126, 0.402356505, 0.403298497,& From e06d39be46dd4c349bec942e0973910e43f1807e Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Thu, 6 Feb 2025 00:17:30 +0400 Subject: [PATCH 20/71] initial forward implementation for multi-head attention --- src/nf/nf_multihead_attention.f90 | 68 +++++++++++++++++-------------- 1 file changed, 38 insertions(+), 30 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index 4260293f..91e6be96 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -2,7 +2,7 @@ module nf_multihead_attention_layer use iso_fortran_env, only: stderr => error_unit use nf_activation, only: softmax use nf_base_layer, only: base_layer - use nf_dense_layer, only: dense_layer + use nf_linear2d_layer, only: linear2d_layer implicit none @@ -15,10 +15,10 @@ module nf_multihead_attention_layer integer :: batch_size, sequence_length, model_dimension, n_heads, head_size - type(dense_layer) :: query_layer - type(dense_layer) :: key_layer - type(dense_layer) :: value_layer - type(dense_layer) :: output_layer + type(linear2d_layer) :: query_layer + type(linear2d_layer) :: key_layer + type(linear2d_layer) :: value_layer + type(linear2d_layer) :: output_layer type(softmax) :: softmax_func @@ -59,7 +59,7 @@ pure module subroutine backward(self, input, gradient) !! 
Gradient from the next layer end subroutine backward - pure module subroutine forward(self, query, key, value) + module subroutine forward(self, query, key, value) class(multihead_attention_layer), intent(in out) :: self real, intent(in) :: query(:, :, :), key(:, :, :), value(:, :, :) end subroutine forward @@ -92,36 +92,44 @@ module function multihead_attention_layer_cons(& end if res % head_size = model_dimension / n_heads - res % query_layer = dense_layer(input_size=model_dimension, output_size=model_dimension) - res % key_layer = dense_layer(input_size=model_dimension, output_size=model_dimension) - res % value_layer = dense_layer(input_size=model_dimension, output_size=model_dimension) - res % output_layer = dense_layer(input_size=model_dimension, output_size=model_dimension) + res % query_layer = linear2d_layer(& + sequence_length=sequence_length, in_features=model_dimension,& + out_features=model_dimension, batch_size=batch_size& + ) + res % key_layer = linear2d_layer(sequence_length, model_dimension, model_dimension, batch_size) + res % value_layer = linear2d_layer(sequence_length, model_dimension, model_dimension, batch_size) + res % output_layer = linear2d_layer(sequence_length, model_dimension, model_dimension, batch_size) + call res % query_layer % init([0]) + call res % key_layer % init([0]) + call res % value_layer % init([0]) + call res % output_layer % init([0]) res % softmax_func = softmax() end function multihead_attention_layer_cons - pure module subroutine forward(self, query, key, value) + module subroutine forward(self, query, key, value) class(multihead_attention_layer), intent(in out) :: self real, intent(in) :: query(:, :, :), key(:, :, :), value(:, :, :) - real :: q(self % batch_size, self % n_heads, self % sequence_length, self % head_size) - real :: k(self % batch_size, self % n_heads, self % sequence_length, self % head_size) - real :: v(self % batch_size, self % n_heads, self % sequence_length, self % head_size) - real :: attention_matrix(self % batch_size, self % n_heads, self % sequence_length, self % sequence_length) - real :: dot_product_attention(self % batch_size, self % n_heads, self % sequence_length, self % head_size) - -! call self % query_layer % forward(query) -! call self % key_layer % forward(key) -! call self % value_layer % forward(value) -! -! q = self % split_heads(self % query_layer % output) -! k = self % split_heads(self % key_layer % output) -! v = self % split_heads(self % value_layer % output) -! -! attention_matrix = self % normalize_attention_matrix(self % create_attention_matrix(q, k)) -! dot_product_attention = self % scaled_dot_product_attention(attention_matrix, v) -! -! 
self % output = self % output_layer % forward(self % combine_heads(dot_product_attention)) + real :: q(self % n_heads, self % sequence_length, self % head_size, self % batch_size) + real :: k(self % n_heads, self % sequence_length, self % head_size, self % batch_size) + real :: v(self % n_heads, self % sequence_length, self % head_size, self % batch_size) + real :: attention_matrix(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size) + real :: dot_product_attention(self % n_heads, self % sequence_length, self % head_size, self % batch_size) + + call self % query_layer % forward(query) + call self % key_layer % forward(key) + call self % value_layer % forward(value) + + q = self % split_heads(self % query_layer % output) + k = self % split_heads(self % key_layer % output) + v = self % split_heads(self % value_layer % output) + + attention_matrix = self % normalize_attention_matrix(self % create_attention_matrix(q, k)) + dot_product_attention = self % scaled_dot_product_attention(attention_matrix, v) + + call self % output_layer % forward(self % combine_heads(dot_product_attention)) + self % output = self % output_layer % output end subroutine forward module function split_heads(self, input) result(output) @@ -225,6 +233,6 @@ module subroutine init(self, input_shape) class(multihead_attention_layer), intent(in out) :: self integer, intent(in) :: input_shape(:) - allocate(self % output(self % batch_size, self % sequence_length, self % model_dimension)) + allocate(self % output(self % sequence_length, self % model_dimension, self % batch_size)) end subroutine init end module nf_multihead_attention_layer From 519a6c84bb3f736ce7dc7806f280b03fa0dfe622 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Thu, 6 Feb 2025 11:15:44 +0400 Subject: [PATCH 21/71] tests for multihead_attention%forward --- test/test_multihead_attention_layer.f90 | 56 ++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90 index d2c8d093..8c47ed35 100644 --- a/test/test_multihead_attention_layer.f90 +++ b/test/test_multihead_attention_layer.f90 @@ -1,6 +1,7 @@ program test_multihead_attention_layer use iso_fortran_env, only: stderr => error_unit use nf_multihead_attention_layer, only: multihead_attention_layer + use nf_linear2d_layer, only: linear2d_layer implicit none logical :: ok = .true. @@ -12,9 +13,9 @@ program test_multihead_attention_layer real :: scaled_dp_att(2, 3, 2, 1) real :: scaled_dp_att_reshaped(1, 3, 2, 2) real :: combined_attention(3, 4, 1) - integer :: i, j attention = multihead_attention_layer(batch_size=1, sequence_length=3, model_dimension=4, n_heads=2) + call attention % init([0]) call test_multihead_attention_split_heads(attention, sample_input, ok, split_heads_output) call test_multihead_attention_create_attention_matrix(attention, split_heads_output, ok, raw_attention_matrix) @@ -23,6 +24,8 @@ program test_multihead_attention_layer attention, normalized_attention_matrix, split_heads_output, ok, scaled_dp_att& ) call test_multihead_attention_combine_heads(attention, scaled_dp_att, ok) + call test_multihead_attention_forward(attention, ok) + call test_multihead_attention_forward_reallife_shape(ok) contains subroutine test_multihead_attention_split_heads(attention, input, ok, output) @@ -143,4 +146,55 @@ subroutine test_multihead_attention_combine_heads(attention, scaled_dp_att, ok) write(stderr, '(a)') 'combine_heads returned incorrect values.. 
failed' end if end subroutine test_multihead_attention_combine_heads + + subroutine test_multihead_attention_forward(attention, ok) + type(multihead_attention_layer), intent(in out) :: attention + logical, intent(in out) :: ok + real :: input(3, 4, 1) = reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4, 1]) + real :: output(attention % sequence_length, attention % model_dimension, attention % batch_size) + real :: output_flat(12) + integer :: output_shape(3) + integer :: expected_shape(3) = [3, 4, 1] + real :: expected_output_flat(12) = [& + 0.982241452, 1.00407875, 1.00444126, 0.982241452, 1.00407875, 1.00444126,& + 0.982241452, 1.00407875, 1.00444126, 0.982241452, 1.00407875, 1.00444126& + ] + + call attention % forward(input, input, input) + + output_shape = shape(attention % output) + if (.not. all(output_shape.eq.expected_shape)) then + ok = .false. + write(stderr, '(a)') 'forward returned incorrect shape.. failed' + end if + output_flat = reshape(attention % output, shape(output_flat)) + if (.not. all(output_flat.eq.expected_output_flat)) then + ok = .false. + write(stderr, '(a)') 'forward returned incorrect values.. failed' + end if + end subroutine test_multihead_attention_forward + + subroutine test_multihead_attention_forward_reallife_shape(ok) + logical, intent(in out) :: ok + real :: input(148, 512, 2) + real :: output(148, 512, 2) + type(linear2d_layer) :: q + real :: output_flat(12) + integer :: output_shape(3) + integer :: expected_shape(3) = [148, 512, 2] + type(multihead_attention_layer) :: attention + + call random_number(input) + + attention = multihead_attention_layer(batch_size=2, sequence_length=148, model_dimension=512, n_heads=8) + call attention % init([0]) + + call attention % forward(input, input, input) + + output_shape = shape(attention % output) + if (.not. all(output_shape.eq.expected_shape)) then + ok = .false. + write(stderr, '(a)') 'forward returned incorrect shape.. 
failed' + end if + end subroutine test_multihead_attention_forward_reallife_shape end program test_multihead_attention_layer \ No newline at end of file From 9fdc7ae63d273465333035830543a2067007f03f Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Thu, 6 Feb 2025 12:32:57 +0400 Subject: [PATCH 22/71] multihead_attention: move most logic to subroutines (performance) --- src/nf/nf_multihead_attention.f90 | 57 ++++++++++++++++++++----------- 1 file changed, 37 insertions(+), 20 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index 91e6be96..6d05ecff 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -22,6 +22,8 @@ module nf_multihead_attention_layer type(softmax) :: softmax_func + real, allocatable :: attention_matrix(:, :, :, :) + real, allocatable :: sdpa(:, :, :, :) real, allocatable :: output(:, :, :) contains @@ -125,10 +127,11 @@ module subroutine forward(self, query, key, value) k = self % split_heads(self % key_layer % output) v = self % split_heads(self % value_layer % output) - attention_matrix = self % normalize_attention_matrix(self % create_attention_matrix(q, k)) - dot_product_attention = self % scaled_dot_product_attention(attention_matrix, v) + call self % create_attention_matrix(q, k) + call self % normalize_attention_matrix() + call self % scaled_dot_product_attention(v) - call self % output_layer % forward(self % combine_heads(dot_product_attention)) + call self % output_layer % forward(self % combine_heads(self % sdpa)) self % output = self % output_layer % output end subroutine forward @@ -158,60 +161,68 @@ module function split_heads(self, input) result(output) ) end function split_heads - module function create_attention_matrix(self, query, key) result(output) + module subroutine create_attention_matrix(self, query, key) !! Create attention matrix for query and key + !! Output dimensions: n_heads, sequence_length, sequence_length, batch_size class(multihead_attention_layer) :: self real :: query(:, :, :, :) real :: key(:, :, :, :) - real :: output(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size) integer :: i, j ! create attention matrix for each sequence in each batch do i = 1, self % batch_size do j = 1, self % n_heads - output(j, :, :, i) = matmul(query(j, :, :, i), transpose(key(j, :, :, i))) + self % attention_matrix(j, :, :, i) = matmul(query(j, :, :, i), transpose(key(j, :, :, i))) end do end do - end function create_attention_matrix + end subroutine create_attention_matrix - module function normalize_attention_matrix(self, attention_matrix, attention_mask) result(output) + module subroutine normalize_attention_matrix(self, attention_mask) !! Create attention matrix for query and key + !! Output dims: n_heads, sequence_length, sequence_length, batch_size class(multihead_attention_layer) :: self - real :: attention_matrix(:, :, :, :) !! (batch_size, n_heads, sequence_length, sequence_length) real, optional :: attention_mask(:, :, :, :) !! (batch_size, n_heads, sequence_length, sequence_length) - real :: output(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size) + real, allocatable :: output(:, :, :, :) integer :: batch, head, seq + ! temporary storage + allocate(output(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size)) + ! 
scale dowm by square root of each head's size - attention_matrix = attention_matrix / sqrt(real(self % head_size)) + self % attention_matrix = self % attention_matrix / sqrt(real(self % head_size)) ! attention mask is used to mask out some of the tokens if necessary if (present(attention_mask)) then - attention_matrix = attention_matrix + attention_mask + self % attention_matrix = self % attention_matrix + attention_mask end if - ! softmax by second sequnce_length + ! softmax by last sequnce_length do batch = 1, self % batch_size do head = 1, self % n_heads do seq = 1, self % sequence_length - output(head, seq, :, batch) = self % softmax_func % eval_1d(attention_matrix(head, seq, :, batch)) + output(head, seq, :, batch) = self % softmax_func % eval_1d(& + self % attention_matrix(head, seq, :, batch)& + ) end do end do end do - end function normalize_attention_matrix + self % attention_matrix = output + + deallocate(output) + end subroutine normalize_attention_matrix - module function scaled_dot_product_attention(self, attention_matrix, value) result(output) + module subroutine scaled_dot_product_attention(self, value) + !! Create scaled dot product attention + !! Output dims: n_heads, sequence_length, head_size, batch_size class(multihead_attention_layer) :: self - real :: attention_matrix(:, :, :, :) real :: value(:, :, :, :) - real :: output(self % n_heads, self % sequence_length, self % head_size, self % batch_size) integer :: batch, head do batch = 1, self % batch_size do head = 1, self % n_heads - output(head, :, :, batch) = matmul(attention_matrix(head, :, :, batch), value(head, :, :, batch)) + self % sdpa(head, :, :, batch) = matmul(self % attention_matrix(head, :, :, batch), value(head, :, :, batch)) end do end do - end function scaled_dot_product_attention + end subroutine scaled_dot_product_attention module function combine_heads(self, input) result(output) class(multihead_attention_layer) :: self @@ -233,6 +244,12 @@ module subroutine init(self, input_shape) class(multihead_attention_layer), intent(in out) :: self integer, intent(in) :: input_shape(:) + allocate(self % attention_matrix(& + self % n_heads, self % sequence_length, self % sequence_length, self % batch_size& + )) + allocate(self % sdpa(& + self % n_heads, self % sequence_length, self % head_size, self % batch_size& + )) allocate(self % output(self % sequence_length, self % model_dimension, self % batch_size)) end subroutine init end module nf_multihead_attention_layer From bc67331a3825d50cc1f609c7827bc93b0e0db8d8 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Thu, 6 Feb 2025 12:33:22 +0400 Subject: [PATCH 23/71] multihead_attention: update tests --- test/test_multihead_attention_layer.f90 | 42 +++++++++---------------- 1 file changed, 14 insertions(+), 28 deletions(-) diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90 index 8c47ed35..6f41055b 100644 --- a/test/test_multihead_attention_layer.f90 +++ b/test/test_multihead_attention_layer.f90 @@ -8,22 +8,15 @@ program test_multihead_attention_layer type(multihead_attention_layer) :: attention real :: sample_input(3, 4, 1) = reshape([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.11, 0.12], [3, 4, 1]) real :: split_heads_output(2, 3, 2, 1) - real :: raw_attention_matrix(2, 3, 3, 1) - real :: normalized_attention_matrix(2, 3, 3, 1) - real :: scaled_dp_att(2, 3, 2, 1) - real :: scaled_dp_att_reshaped(1, 3, 2, 2) - real :: combined_attention(3, 4, 1) attention = multihead_attention_layer(batch_size=1, 
sequence_length=3, model_dimension=4, n_heads=2) call attention % init([0]) call test_multihead_attention_split_heads(attention, sample_input, ok, split_heads_output) - call test_multihead_attention_create_attention_matrix(attention, split_heads_output, ok, raw_attention_matrix) - call test_multihead_attention_normalization(attention, raw_attention_matrix, ok, normalized_attention_matrix) - call test_multihead_attention_scaled_dot_product_attention(& - attention, normalized_attention_matrix, split_heads_output, ok, scaled_dp_att& - ) - call test_multihead_attention_combine_heads(attention, scaled_dp_att, ok) + call test_multihead_attention_create_attention_matrix(attention, split_heads_output, ok) + call test_multihead_attention_normalization(attention, ok) + call test_multihead_attention_scaled_dot_product_attention(attention, split_heads_output, ok) + call test_multihead_attention_combine_heads(attention, attention % sdpa, ok) call test_multihead_attention_forward(attention, ok) call test_multihead_attention_forward_reallife_shape(ok) @@ -52,11 +45,10 @@ subroutine test_multihead_attention_split_heads(attention, input, ok, output) end if end subroutine test_multihead_attention_split_heads - subroutine test_multihead_attention_create_attention_matrix(attention, input, ok, attention_matrix) + subroutine test_multihead_attention_create_attention_matrix(attention, input, ok) type(multihead_attention_layer), intent(in) :: attention real, intent(in) :: input(:, :, :, :) logical, intent(in out) :: ok - real, intent(in out) :: attention_matrix(2, 3, 3, 1) real :: attention_matrix_shape(4) real :: attention_matrix_flat(18) real :: expected_shape(4) = [2, 3, 3, 1] @@ -69,25 +61,23 @@ subroutine test_multihead_attention_create_attention_matrix(attention, input, ok 0.573199987, 0.289999992, 0.654400051& ] - attention_matrix = attention % create_attention_matrix(input, input) + call attention % create_attention_matrix(input, input) - attention_matrix_shape = shape(attention_matrix) + attention_matrix_shape = shape(attention % attention_matrix) if (.not. all(attention_matrix_shape.eq.expected_shape)) then ok = .false. write(stderr, '(a)') 'create_attention_matrix returned incorrect shape.. failed' end if - attention_matrix_flat = reshape(attention_matrix, shape(expected_attention_matrix_flat)) + attention_matrix_flat = reshape(attention % attention_matrix, shape(expected_attention_matrix_flat)) if (.not. all(attention_matrix_flat.eq.expected_attention_matrix_flat)) then ok = .false. write(stderr, '(a)') 'create_attention_matrix returned incorrect values.. failed' end if end subroutine test_multihead_attention_create_attention_matrix - subroutine test_multihead_attention_normalization(attention, input, ok, output) + subroutine test_multihead_attention_normalization(attention, ok) type(multihead_attention_layer), intent(in) :: attention - real, intent(in) :: input(:, :, :, :) logical, intent(in out) :: ok - real, intent(out) :: output(2, 3, 3, 1) real :: output_flat(18) real :: expected_output_flat(18) = [& 0.326287806, 0.435975075, 0.321620107, 0.330339342, 0.316976935, 0.329200655,& @@ -95,32 +85,28 @@ subroutine test_multihead_attention_normalization(attention, input, ok, output) 0.340428889, 0.288890868, 0.345185399, 0.343245387, 0.349961787, 0.345025837& ] - output = attention % normalize_attention_matrix(input) + call attention % normalize_attention_matrix() - output_flat = reshape(output, shape(output_flat)) + output_flat = reshape(attention % attention_matrix, shape(output_flat)) if (.not. 
all(output_flat.eq.expected_output_flat)) then ok = .false. write(stderr, '(a)') 'normalize_attention_matrix returned incorrect values.. failed' end if end subroutine test_multihead_attention_normalization - subroutine test_multihead_attention_scaled_dot_product_attention(attention, attention_matrix, value, ok, output) + subroutine test_multihead_attention_scaled_dot_product_attention(attention, value, ok) type(multihead_attention_layer), intent(in) :: attention - real, intent(in) :: attention_matrix(:, :, :, :) real, intent(in) :: value(:, :, :, :) logical, intent(in out) :: ok - real, intent(out) :: output(& - attention % n_heads, attention % sequence_length, attention % head_size, attention % batch_size& - ) real :: output_flat(12) real :: expected_output_flat(12) = [& 0.101414114, 0.685291648, 0.102356531, 0.701290607, 0.103298485, 0.701582491,& 0.401414126, 0.457309216, 0.402356505, 0.374400526, 0.403298497, 0.373518765& ] - output = attention % scaled_dot_product_attention(attention_matrix, value) + call attention % scaled_dot_product_attention(value) - output_flat = reshape(output, shape(output_flat)) + output_flat = reshape(attention % sdpa, shape(output_flat)) if (.not. all(output_flat.eq.expected_output_flat)) then ok = .false. write(stderr, '(a)') 'scaled_dot_product_attention returned incorrect values.. failed' From a0a6fc419ec700dd0769da71718d2ea1448cc43c Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Thu, 6 Feb 2025 12:38:31 +0400 Subject: [PATCH 24/71] multihead_attention: concurrency --- src/nf/nf_multihead_attention.f90 | 30 ++++++++---------------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index 6d05ecff..cf78fd2f 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -169,10 +169,8 @@ module subroutine create_attention_matrix(self, query, key) real :: key(:, :, :, :) integer :: i, j ! create attention matrix for each sequence in each batch - do i = 1, self % batch_size - do j = 1, self % n_heads - self % attention_matrix(j, :, :, i) = matmul(query(j, :, :, i), transpose(key(j, :, :, i))) - end do + do concurrent(i = 1: self % batch_size, j = 1: self % n_heads) + self % attention_matrix(j, :, :, i) = matmul(query(j, :, :, i), transpose(key(j, :, :, i))) end do end subroutine create_attention_matrix @@ -196,14 +194,8 @@ module subroutine normalize_attention_matrix(self, attention_mask) self % attention_matrix = self % attention_matrix + attention_mask end if ! 
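! --- illustrative sketch (not part of the patch): collapsing loops with do concurrent ---
! The explicit nested batch/head(/seq) loops in this patch are collapsed into
! do concurrent, which asserts that the iterations are independent and may run
! in any order, so the compiler is free to vectorize or parallelize them. Names
! and bounds in this standalone example are made up.
program do_concurrent_sketch
  implicit none
  integer :: batch, head
  real :: acc(8, 2)                         ! (n_heads, batch_size), made-up sizes
  do concurrent(batch = 1:2, head = 1:8)
    acc(head, batch) = real(head * batch)   ! each iteration writes a distinct element
  end do
  print *, sum(acc)
end program do_concurrent_sketch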
softmax by last sequnce_length - do batch = 1, self % batch_size - do head = 1, self % n_heads - do seq = 1, self % sequence_length - output(head, seq, :, batch) = self % softmax_func % eval_1d(& - self % attention_matrix(head, seq, :, batch)& - ) - end do - end do + do concurrent(batch = 1: self % batch_size, head = 1: self % n_heads, seq = 1: self % sequence_length) + output(head, seq, :, batch) = self % softmax_func % eval_1d(self % attention_matrix(head, seq, :, batch)) end do self % attention_matrix = output @@ -217,10 +209,8 @@ module subroutine scaled_dot_product_attention(self, value) real :: value(:, :, :, :) integer :: batch, head - do batch = 1, self % batch_size - do head = 1, self % n_heads - self % sdpa(head, :, :, batch) = matmul(self % attention_matrix(head, :, :, batch), value(head, :, :, batch)) - end do + do concurrent(batch = 1: self % batch_size, head = 1: self % n_heads) + self % sdpa(head, :, :, batch) = matmul(self % attention_matrix(head, :, :, batch), value(head, :, :, batch)) end do end subroutine scaled_dot_product_attention @@ -231,12 +221,8 @@ module function combine_heads(self, input) result(output) real :: output(self % sequence_length, self % model_dimension, self % batch_size) integer :: batch, seq - do batch = 1, self % batch_size - do seq = 1, self % sequence_length - output(seq, :, batch) = reshape(& - transpose(input(:, seq, :, batch)), [self % model_dimension]& - ) - end do + do concurrent(batch = 1: self % batch_size, seq = 1: self % sequence_length) + output(seq, :, batch) = reshape(transpose(input(:, seq, :, batch)), [self % model_dimension]) end do end function combine_heads From f8101af6550df551f0e4f66c0a7a118aca8386db Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sat, 8 Feb 2025 12:20:49 +0400 Subject: [PATCH 25/71] multihead_attention: proof of concept backward (works, but not mathematically correct) --- src/nf/nf_multihead_attention.f90 | 33 ++++++++++++++++++++++--- test/test_multihead_attention_layer.f90 | 20 +++++++++++++++ 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index cf78fd2f..c1c4bef9 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -26,9 +26,12 @@ module nf_multihead_attention_layer real, allocatable :: sdpa(:, :, :, :) real, allocatable :: output(:, :, :) + real, allocatable :: q_input(:, :, :) + real, allocatable :: k_input(:, :, :) + real, allocatable :: v_input(:, :, :) contains -! procedure :: backward + procedure :: backward procedure :: forward procedure :: split_heads procedure :: create_attention_matrix @@ -49,15 +52,15 @@ end function multihead_attention_layer_cons interface - pure module subroutine backward(self, input, gradient) + module subroutine backward(self, input, gradient) !! Apply the backward gradient descent pass. !! Only weight and bias gradients are updated in this subroutine, !! while the weights and biases themselves are untouched. class(multihead_attention_layer), intent(in out) :: self !! Dense layer instance - real, intent(in) :: input(:) + real, intent(in) :: input(:, :, :) !! Input from the previous layer - real, intent(in) :: gradient(:) + real, intent(in) :: gradient(:, :, :) !! 
Gradient from the next layer end subroutine backward @@ -109,6 +112,20 @@ module function multihead_attention_layer_cons(& res % softmax_func = softmax() end function multihead_attention_layer_cons + module subroutine backward(self, input, gradient) + class(multihead_attention_layer), intent(in out) :: self + real, intent(in) :: input(:, :, :) + real, intent(in) :: gradient(:, :, :) + + call self % output_layer % backward(input, gradient) + + ! FIXME: calculate gradient for softmax + + call self % value_layer % backward(self % v_input, self % output_layer % gradient) + call self % key_layer % backward(self % k_input, self % output_layer % gradient) + call self % query_layer % backward(self % q_input, self % output_layer % gradient) + end subroutine backward + module subroutine forward(self, query, key, value) class(multihead_attention_layer), intent(in out) :: self real, intent(in) :: query(:, :, :), key(:, :, :), value(:, :, :) @@ -119,6 +136,10 @@ module subroutine forward(self, query, key, value) real :: attention_matrix(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size) real :: dot_product_attention(self % n_heads, self % sequence_length, self % head_size, self % batch_size) + self % q_input = query + self % k_input = key + self % v_input = value + call self % query_layer % forward(query) call self % key_layer % forward(key) call self % value_layer % forward(value) @@ -237,5 +258,9 @@ module subroutine init(self, input_shape) self % n_heads, self % sequence_length, self % head_size, self % batch_size& )) allocate(self % output(self % sequence_length, self % model_dimension, self % batch_size)) + + allocate(self % q_input(self % sequence_length, self % model_dimension, self % batch_size)) + allocate(self % k_input(self % sequence_length, self % model_dimension, self % batch_size)) + allocate(self % v_input(self % sequence_length, self % model_dimension, self % batch_size)) end subroutine init end module nf_multihead_attention_layer diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90 index 6f41055b..8044455e 100644 --- a/test/test_multihead_attention_layer.f90 +++ b/test/test_multihead_attention_layer.f90 @@ -19,6 +19,7 @@ program test_multihead_attention_layer call test_multihead_attention_combine_heads(attention, attention % sdpa, ok) call test_multihead_attention_forward(attention, ok) call test_multihead_attention_forward_reallife_shape(ok) + call test_multihead_attention_backward(attention, ok) contains subroutine test_multihead_attention_split_heads(attention, input, ok, output) @@ -183,4 +184,23 @@ subroutine test_multihead_attention_forward_reallife_shape(ok) write(stderr, '(a)') 'forward returned incorrect shape.. failed' end if end subroutine test_multihead_attention_forward_reallife_shape + + subroutine test_multihead_attention_backward(attention, ok) + type(multihead_attention_layer), intent(in out) :: attention + logical, intent(in out) :: ok + real :: input(3, 4, 1) = reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4, 1]) + real :: gradient(3, 4, 1) = reshape(& + [.1, .1, .1, 3., 3., 3., 2., .1, 2., 3., .1, 3., 2., 2., .1, 3., 3., 3.], [3, 4, 1]& + ) + real :: expected_shape(3) = [3, 4, 1] + real :: output_shape(3) + + call attention % backward(input, gradient) + + output_shape = shape(attention % output_layer % gradient) + if (.not. all(output_shape.eq.expected_shape)) then + ok = .false. + write(stderr, '(a)') 'backward returned incorrect shape.. 
failed' + end if + end subroutine test_multihead_attention_backward end program test_multihead_attention_layer \ No newline at end of file From 63cce110f27ac5e57415299597d3be51d4b8dfc6 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sun, 9 Feb 2025 12:53:51 +0400 Subject: [PATCH 26/71] multihead_attention: fix minor scaling issue --- src/nf/nf_multihead_attention.f90 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index c1c4bef9..aa5a44d2 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -209,7 +209,7 @@ module subroutine normalize_attention_matrix(self, attention_mask) allocate(output(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size)) ! scale dowm by square root of each head's size - self % attention_matrix = self % attention_matrix / sqrt(real(self % head_size)) + self % attention_matrix = self % attention_matrix * sqrt(1 / real(self % head_size)) ! attention mask is used to mask out some of the tokens if necessary if (present(attention_mask)) then self % attention_matrix = self % attention_matrix + attention_mask From dfb8842c085e11adacb742036161f8201ed61f34 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sun, 9 Feb 2025 13:26:40 +0400 Subject: [PATCH 27/71] multihead_attention: complete backward implementation --- src/nf/nf_multihead_attention.f90 | 60 ++++++++++++++++++++++++++++--- 1 file changed, 56 insertions(+), 4 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index aa5a44d2..ef7c0149 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -113,17 +113,69 @@ module function multihead_attention_layer_cons(& end function multihead_attention_layer_cons module subroutine backward(self, input, gradient) + !! General backprop for MultiHead Attention mechanism class(multihead_attention_layer), intent(in out) :: self real, intent(in) :: input(:, :, :) real, intent(in) :: gradient(:, :, :) + real :: d_output(self % n_heads, self % sequence_length, self % head_size, self % batch_size) + real :: v_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size) + real :: k_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size) + real :: q_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size) + real :: d_sdpa(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size) + real :: jacobian(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size) + real :: d_normalize(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size) + real :: d_attn_matrix(self % n_heads, self % sequence_length, self % head_size, self % batch_size) + real :: dk(self % n_heads, self % sequence_length, self % head_size, self % batch_size) + integer :: batch, head, i, j + + ! calculate output layer delta call self % output_layer % backward(input, gradient) - ! FIXME: calculate gradient for softmax + ! 
split heads from output gradient + d_output = self % split_heads(self % output_layer % gradient) + v_heads = self % split_heads(self % value_layer % output) + k_heads = self % split_heads(self % key_layer % output) + q_heads = self % split_heads(self % query_layer % output) - call self % value_layer % backward(self % v_input, self % output_layer % gradient) - call self % key_layer % backward(self % k_input, self % output_layer % gradient) - call self % query_layer % backward(self % q_input, self % output_layer % gradient) + ! iterate over heads to calculate deltas for each of them + do concurrent(batch = 1: self % batch_size, head = 1: self % n_heads) + ! calculate delta for value + d_sdpa(head, :, :, batch) = matmul(d_output(head, :, :, batch), transpose(v_heads(head, :, :, batch))) + + ! this monstrosity is scaled derivative of softmax + do concurrent(i = 1: self % sequence_length, j = 1: self % sequence_length) + ! jacobian matrix is used to calculate derivative of softmax (temporary storage) + ! the idea behind this if-else is that for diagonal elements, the jacobian temp + ! should be: `softmax(x) * (1 - softmax(x))` + ! for off-diagonal: `-softmax^2(x)` + ! For computational efficiency (avoid more temp storages), scaling is also done here + if (i == j) then + jacobian(head, i, j, batch) = & + self % attention_matrix(head, i, j, batch) & + * (1 - self % attention_matrix(head, i, j, batch)) & + * sqrt(1 / real(self % head_size)) + else + jacobian(head, i, j, batch) = & + - self % attention_matrix(head, i, j, batch) & + * self % attention_matrix(head, i, j, batch) & + * sqrt(1 / real(self % head_size)) + end if + end do + ! attention normalization delta, the last step of softmax derivative: + ! multiply temp jacobian matrix by the output of softmax + d_normalize(head, :, :, batch) = matmul(d_sdpa(head, :, :, batch), jacobian(head, :, :, batch)) + + ! calculate delta for query + d_attn_matrix(head, :, :, batch) = matmul(d_normalize(head, :, :, batch), k_heads(head, :, :, batch)) + ! calculate delta for key, attention matrix should be transposed unlike for query + dk(head, :, :, batch) = matmul(transpose(d_normalize(head, :, :, batch)), q_heads(head, :, :, batch)) + end do + + ! calculate deltas for input layers + call self % value_layer % backward(self % v_input, self % combine_heads(d_sdpa)) + call self % key_layer % backward(self % k_input, self % combine_heads(dk)) + call self % query_layer % backward(self % q_input, self % combine_heads(d_attn_matrix)) end subroutine backward module subroutine forward(self, query, key, value) From adcf5e6f08a91eaddd3e4447d560c89ecc4fd08b Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sun, 9 Feb 2025 13:30:55 +0400 Subject: [PATCH 28/71] multihead_attention: add comments for forward prop --- src/nf/nf_multihead_attention.f90 | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index ef7c0149..9a9995f5 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -179,6 +179,7 @@ module subroutine backward(self, input, gradient) end subroutine backward module subroutine forward(self, query, key, value) + !! General forward prop for MultiHead Attention Mechenism class(multihead_attention_layer), intent(in out) :: self real, intent(in) :: query(:, :, :), key(:, :, :), value(:, :, :) @@ -192,16 +193,21 @@ module subroutine forward(self, query, key, value) self % k_input = key self % v_input = value + ! 
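! In equation form, the forward pass implemented in this subroutine computes,
! per head h (projections provided by the query/key/value linear2d layers):
!   Q_h, K_h, V_h = split_heads of the projected query, key and value inputs
!   head_h        = softmax(Q_h K_h^T / sqrt(head_size)) V_h
!   output        = output projection of combine_heads(head_1, ..., head_n_heads)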
run inputs through linear layers (trainable params) call self % query_layer % forward(query) call self % key_layer % forward(key) call self % value_layer % forward(value) + ! split attention heads for more efficient computation q = self % split_heads(self % query_layer % output) k = self % split_heads(self % key_layer % output) v = self % split_heads(self % value_layer % output) + ! create key by value matrix call self % create_attention_matrix(q, k) + ! apply softmax and scaling call self % normalize_attention_matrix() + ! multiply attention matrix by value call self % scaled_dot_product_attention(v) call self % output_layer % forward(self % combine_heads(self % sdpa)) From 650e47c828b5e0a2b47872810e8487201bf5244a Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sun, 9 Feb 2025 13:38:43 +0400 Subject: [PATCH 29/71] multihead_attention: add tests for backward --- test/test_multihead_attention_layer.f90 | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90 index 8044455e..e1690396 100644 --- a/test/test_multihead_attention_layer.f90 +++ b/test/test_multihead_attention_layer.f90 @@ -18,8 +18,8 @@ program test_multihead_attention_layer call test_multihead_attention_scaled_dot_product_attention(attention, split_heads_output, ok) call test_multihead_attention_combine_heads(attention, attention % sdpa, ok) call test_multihead_attention_forward(attention, ok) - call test_multihead_attention_forward_reallife_shape(ok) call test_multihead_attention_backward(attention, ok) + call test_multihead_attention_forward_reallife_shape(ok) contains subroutine test_multihead_attention_split_heads(attention, input, ok, output) @@ -189,18 +189,31 @@ subroutine test_multihead_attention_backward(attention, ok) type(multihead_attention_layer), intent(in out) :: attention logical, intent(in out) :: ok real :: input(3, 4, 1) = reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4, 1]) - real :: gradient(3, 4, 1) = reshape(& - [.1, .1, .1, 3., 3., 3., 2., .1, 2., 3., .1, 3., 2., 2., .1, 3., 3., 3.], [3, 4, 1]& - ) + real :: gradient(3, 4, 1) = reshape([0.1, 3. , 2. , 0.1, 3. , 3. , 0.1, 2. , 0.1, 3. , 0.1, 3. ], [3, 4, 1]) + real :: expected_output_flat(12) = [& + 0.489710003, 0.240968466, -3.35404873E-02, 0.489710003,& + 0.240968466, -3.35404873E-02, 0.489710003, 0.240968466,& + -3.35404873E-02, 0.489710003, 0.240968466, -3.35404873E-02& + ] real :: expected_shape(3) = [3, 4, 1] + real :: output(3, 4, 1) + real :: output_flat(12) real :: output_shape(3) call attention % backward(input, gradient) - output_shape = shape(attention % output_layer % gradient) + ! sample for Self Attention: sum of output gradients + output = attention % query_layer % gradient + attention % key_layer % gradient + attention % value_layer % gradient + + output_shape = shape(output) if (.not. all(output_shape.eq.expected_shape)) then ok = .false. write(stderr, '(a)') 'backward returned incorrect shape.. failed' end if + output_flat = reshape(output, shape(output_flat)) + if (.not. all(output_flat.eq.expected_output_flat)) then + ok = .false. + write(stderr, '(a)') 'backward returned incorrect values.. 
failed' + end if end subroutine test_multihead_attention_backward end program test_multihead_attention_layer \ No newline at end of file From 3d161612453814fbb08e9acea53511e6edd7fdd1 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sun, 9 Feb 2025 13:44:13 +0400 Subject: [PATCH 30/71] multihead_attention: adjust expected test values for updated scaling --- test/test_multihead_attention_layer.f90 | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90 index e1690396..2d6165c9 100644 --- a/test/test_multihead_attention_layer.f90 +++ b/test/test_multihead_attention_layer.f90 @@ -81,9 +81,9 @@ subroutine test_multihead_attention_normalization(attention, ok) logical, intent(in out) :: ok real :: output_flat(18) real :: expected_output_flat(18) = [& - 0.326287806, 0.435975075, 0.321620107, 0.330339342, 0.316976935, 0.329200655,& - 0.333283335, 0.275134116, 0.333194464, 0.326415271, 0.333061278, 0.325773478,& - 0.340428889, 0.288890868, 0.345185399, 0.343245387, 0.349961787, 0.345025837& + 0.326287806, 0.435975075, 0.321620107, 0.330339372, 0.316976935, 0.329200655,& + 0.333283335, 0.275134116, 0.333194494, 0.326415271, 0.333061278, 0.325773478,& + 0.340428889, 0.288890868, 0.345185429, 0.343245387, 0.349961787, 0.345025837& ] call attention % normalize_attention_matrix() @@ -101,8 +101,8 @@ subroutine test_multihead_attention_scaled_dot_product_attention(attention, valu logical, intent(in out) :: ok real :: output_flat(12) real :: expected_output_flat(12) = [& - 0.101414114, 0.685291648, 0.102356531, 0.701290607, 0.103298485, 0.701582491,& - 0.401414126, 0.457309216, 0.402356505, 0.374400526, 0.403298497, 0.373518765& + 0.101414114, 0.685291648, 0.102356538, 0.701290667, 0.103298485, 0.701582491,& + 0.401414126, 0.457309216, 0.402356565, 0.374400556, 0.403298497, 0.373518765& ] call attention % scaled_dot_product_attention(value) @@ -121,8 +121,8 @@ subroutine test_multihead_attention_combine_heads(attention, scaled_dp_att, ok) real :: output(attention % sequence_length, attention % model_dimension, attention % batch_size) real :: output_flat(12) real :: expected_output_flat(12) = [& - 0.101414114, 0.102356531, 0.103298485, 0.401414126, 0.402356505, 0.403298497,& - 0.685291648, 0.701290607, 0.701582491, 0.457309216, 0.374400526, 0.373518765& + 0.101414114, 0.102356538, 0.103298485, 0.401414126, 0.402356565, 0.403298497,& + 0.685291648, 0.701290667, 0.701582491, 0.457309216, 0.374400556, 0.373518765& ] output = attention % combine_heads(scaled_dp_att) From dcae5d68ae1469e384496a350215c467044312e7 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sun, 9 Feb 2025 13:49:18 +0400 Subject: [PATCH 31/71] multihead_attention: calculate scaling factor only once --- src/nf/nf_multihead_attention.f90 | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index 9a9995f5..e66cd043 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -26,6 +26,8 @@ module nf_multihead_attention_layer real, allocatable :: sdpa(:, :, :, :) real, allocatable :: output(:, :, :) + real :: scaling_factor + real, allocatable :: q_input(:, :, :) real, allocatable :: k_input(:, :, :) real, allocatable :: v_input(:, :, :) @@ -154,12 +156,12 @@ module subroutine backward(self, input, gradient) jacobian(head, i, j, batch) = & self % attention_matrix(head, i, j, batch) & * (1 - self % 
attention_matrix(head, i, j, batch)) & - * sqrt(1 / real(self % head_size)) + * self % scaling_factor else jacobian(head, i, j, batch) = & - self % attention_matrix(head, i, j, batch) & * self % attention_matrix(head, i, j, batch) & - * sqrt(1 / real(self % head_size)) + * self % scaling_factor end if end do ! attention normalization delta, the last step of softmax derivative: @@ -267,7 +269,7 @@ module subroutine normalize_attention_matrix(self, attention_mask) allocate(output(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size)) ! scale dowm by square root of each head's size - self % attention_matrix = self % attention_matrix * sqrt(1 / real(self % head_size)) + self % attention_matrix = self % attention_matrix * self % scaling_factor ! attention mask is used to mask out some of the tokens if necessary if (present(attention_mask)) then self % attention_matrix = self % attention_matrix + attention_mask @@ -317,6 +319,8 @@ module subroutine init(self, input_shape) )) allocate(self % output(self % sequence_length, self % model_dimension, self % batch_size)) + self % scaling_factor = sqrt(1 / real(self % head_size)) + allocate(self % q_input(self % sequence_length, self % model_dimension, self % batch_size)) allocate(self % k_input(self % sequence_length, self % model_dimension, self % batch_size)) allocate(self % v_input(self % sequence_length, self % model_dimension, self % batch_size)) From 9fceae76e90c16342baaf43ebdaaf67a77a88150 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sun, 9 Feb 2025 13:55:43 +0400 Subject: [PATCH 32/71] multihead_attention: use heap-allocated arrays during back prop --- src/nf/nf_multihead_attention.f90 | 40 ++++++++++++++++++++++++------- 1 file changed, 31 insertions(+), 9 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index e66cd043..037867ee 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -120,17 +120,28 @@ module subroutine backward(self, input, gradient) real, intent(in) :: input(:, :, :) real, intent(in) :: gradient(:, :, :) - real :: d_output(self % n_heads, self % sequence_length, self % head_size, self % batch_size) - real :: v_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size) - real :: k_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size) - real :: q_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size) - real :: d_sdpa(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size) - real :: jacobian(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size) - real :: d_normalize(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size) - real :: d_attn_matrix(self % n_heads, self % sequence_length, self % head_size, self % batch_size) - real :: dk(self % n_heads, self % sequence_length, self % head_size, self % batch_size) + real, allocatable :: d_output(:, :, :, :) + real, allocatable :: v_heads(:, :, :, :) + real, allocatable :: k_heads(:, :, :, :) + real, allocatable :: q_heads(:, :, :, :) + real, allocatable :: d_sdpa(:, :, :, :) + real, allocatable :: jacobian(:, :, :, :) + real, allocatable :: d_normalize(:, :, :, :) + real, allocatable :: d_attn_matrix(:, :, :, :) + real, allocatable :: dk(:, :, :, :) integer :: batch, head, i, j + ! 
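! --- illustrative sketch (not part of the patch): run-time sized temporaries ---
! The explicit-shape local arrays of the previous revision become allocatable
! below, presumably so that large intermediates (such as the
! n_heads x seq x seq x batch deltas exercised by the real-life-shape test)
! live on the heap rather than the stack. Minimal standalone pattern, with
! sizes borrowed from that test (8 heads, sequence length 148, batch size 2):
program heap_temporaries_sketch
  implicit none
  real, allocatable :: d_tmp(:, :, :, :)
  allocate(d_tmp(8, 148, 148, 2))   ! sized at run time, heap-allocated
  call random_number(d_tmp)
  print *, maxval(d_tmp)
  deallocate(d_tmp)                 ! explicit cleanup, mirroring the patch
end program heap_temporaries_sketch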
allocate temporary storages for backward computation + allocate(d_output(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) + allocate(v_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) + allocate(k_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) + allocate(q_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) + allocate(d_sdpa(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size)) + allocate(jacobian(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size)) + allocate(d_normalize(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size)) + allocate(d_attn_matrix(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) + allocate(dk(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) + ! calculate output layer delta call self % output_layer % backward(input, gradient) @@ -178,6 +189,17 @@ module subroutine backward(self, input, gradient) call self % value_layer % backward(self % v_input, self % combine_heads(d_sdpa)) call self % key_layer % backward(self % k_input, self % combine_heads(dk)) call self % query_layer % backward(self % q_input, self % combine_heads(d_attn_matrix)) + + ! free temporary storages + deallocate(d_output) + deallocate(v_heads) + deallocate(k_heads) + deallocate(q_heads) + deallocate(d_sdpa) + deallocate(jacobian) + deallocate(d_normalize) + deallocate(d_attn_matrix) + deallocate(dk) end subroutine backward module subroutine forward(self, query, key, value) From 248e1248b9ddd2d31f448627de26201dfbb85879 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sun, 9 Feb 2025 14:00:20 +0400 Subject: [PATCH 33/71] multihead_attention: use heap-allocated arrays in forward --- src/nf/nf_multihead_attention.f90 | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index 037867ee..1694be28 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -207,11 +207,14 @@ module subroutine forward(self, query, key, value) class(multihead_attention_layer), intent(in out) :: self real, intent(in) :: query(:, :, :), key(:, :, :), value(:, :, :) - real :: q(self % n_heads, self % sequence_length, self % head_size, self % batch_size) - real :: k(self % n_heads, self % sequence_length, self % head_size, self % batch_size) - real :: v(self % n_heads, self % sequence_length, self % head_size, self % batch_size) - real :: attention_matrix(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size) - real :: dot_product_attention(self % n_heads, self % sequence_length, self % head_size, self % batch_size) + real, allocatable :: q(:, :, :, :) + real, allocatable :: k(:, :, :, :) + real, allocatable :: v(:, :, :, :) + + ! allocate storage for intermidiate stages + allocate(q(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) + allocate(k(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) + allocate(v(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) self % q_input = query self % k_input = key @@ -236,6 +239,11 @@ module subroutine forward(self, query, key, value) call self % output_layer % forward(self % combine_heads(self % sdpa)) self % output = self % output_layer % output + + ! 
free temp vars from memory + deallocate(q) + deallocate(k) + deallocate(v) end subroutine forward module function split_heads(self, input) result(output) From 4693028e07825c8f8a24e8bd9aa9790f83a962fb Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sun, 9 Feb 2025 17:31:46 +0400 Subject: [PATCH 34/71] multihead_attention: set values from correct shape to tests --- test/test_multihead_attention_layer.f90 | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90 index 2d6165c9..e50b86ec 100644 --- a/test/test_multihead_attention_layer.f90 +++ b/test/test_multihead_attention_layer.f90 @@ -191,9 +191,10 @@ subroutine test_multihead_attention_backward(attention, ok) real :: input(3, 4, 1) = reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4, 1]) real :: gradient(3, 4, 1) = reshape([0.1, 3. , 2. , 0.1, 3. , 3. , 0.1, 2. , 0.1, 3. , 0.1, 3. ], [3, 4, 1]) real :: expected_output_flat(12) = [& - 0.489710003, 0.240968466, -3.35404873E-02, 0.489710003,& - 0.240968466, -3.35404873E-02, 0.489710003, 0.240968466,& - -3.35404873E-02, 0.489710003, 0.240968466, -3.35404873E-02& + -2.29912549E-02, 0.381484956, 0.453185737,& + -2.29912549E-02, 0.381484956, 0.453185737,& + -2.29912549E-02, 0.381484956, 0.453185737,& + -2.29912549E-02, 0.381484956, 0.453185737& ] real :: expected_shape(3) = [3, 4, 1] real :: output(3, 4, 1) From 32dd62875e2c34e029b293f817b816baa9943540 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sun, 9 Feb 2025 17:32:34 +0400 Subject: [PATCH 35/71] multihead_attention: fix issues with shapes (softmax prime became even more monstruos) --- src/nf/nf_multihead_attention.f90 | 44 ++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index 1694be28..b99e985b 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -124,20 +124,23 @@ module subroutine backward(self, input, gradient) real, allocatable :: v_heads(:, :, :, :) real, allocatable :: k_heads(:, :, :, :) real, allocatable :: q_heads(:, :, :, :) + real, allocatable :: dv(:, :, :, :) real, allocatable :: d_sdpa(:, :, :, :) - real, allocatable :: jacobian(:, :, :, :) + real, allocatable :: jacobian(:, :, :) real, allocatable :: d_normalize(:, :, :, :) real, allocatable :: d_attn_matrix(:, :, :, :) real, allocatable :: dk(:, :, :, :) - integer :: batch, head, i, j + integer :: batch, head, seq, i, j ! 
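For reference, the softmax derivative reworked in this patch has the following form. With s = softmax(x) taken over one row of the attention matrix (one query position within one head),

  d s_i / d x_j =  s_i * (1 - s_i)   if i == j
  d s_i / d x_j = -s_i * s_j         if i /= j

that is, J(i, j) = s_i * (delta_ij - s_j). The earlier code filled one Jacobian entry per (head, i, j, batch) index of the attention matrix, mixing values from different softmax rows; this patch instead builds a (sequence_length x sequence_length) Jacobian per query position, which is what the corrected shapes reflect.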
allocate temporary storages for backward computation allocate(d_output(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) allocate(v_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) allocate(k_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) allocate(q_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) + + allocate(dv(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) allocate(d_sdpa(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size)) - allocate(jacobian(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size)) + allocate(jacobian(self % sequence_length, self % sequence_length, self % sequence_length)) allocate(d_normalize(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size)) allocate(d_attn_matrix(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) allocate(dk(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) @@ -153,40 +156,49 @@ module subroutine backward(self, input, gradient) ! iterate over heads to calculate deltas for each of them do concurrent(batch = 1: self % batch_size, head = 1: self % n_heads) - ! calculate delta for value + dv(head, :, :, batch) = matmul(transpose(self % attention_matrix(head, :, :, batch)), d_output(head, :, :, batch)) + + ! calculate delta for attention matrix d_sdpa(head, :, :, batch) = matmul(d_output(head, :, :, batch), transpose(v_heads(head, :, :, batch))) - ! this monstrosity is scaled derivative of softmax - do concurrent(i = 1: self % sequence_length, j = 1: self % sequence_length) + ! this monstrosity below is scaled derivative of softmax + do concurrent(seq = 1: self % sequence_length, i = 1: self % sequence_length, j = 1: self % sequence_length) ! jacobian matrix is used to calculate derivative of softmax (temporary storage) ! the idea behind this if-else is that for diagonal elements, the jacobian temp - ! should be: `softmax(x) * (1 - softmax(x))` - ! for off-diagonal: `-softmax^2(x)` + ! should be: `softmax(x_i) * (1 - softmax(x_i))` + ! for off-diagonal: `-softmax(x_i) * softmax(x_j)` ! For computational efficiency (avoid more temp storages), scaling is also done here if (i == j) then - jacobian(head, i, j, batch) = & - self % attention_matrix(head, i, j, batch) & - * (1 - self % attention_matrix(head, i, j, batch)) & + jacobian(seq, i, j) = & + self % attention_matrix(head, seq, i, batch) & + * (1 - self % attention_matrix(head, seq, i, batch)) & * self % scaling_factor else - jacobian(head, i, j, batch) = & - - self % attention_matrix(head, i, j, batch) & - * self % attention_matrix(head, i, j, batch) & + jacobian(seq, i, j) = & + - self % attention_matrix(head, seq, i, batch) & + * self % attention_matrix(head, seq, j, batch) & * self % scaling_factor end if end do + ! attention normalization delta, the last step of softmax derivative: ! multiply temp jacobian matrix by the output of softmax - d_normalize(head, :, :, batch) = matmul(d_sdpa(head, :, :, batch), jacobian(head, :, :, batch)) + do concurrent(seq = 1: self % sequence_length) + d_normalize(head, seq, :, batch) = reshape(matmul(& + reshape(d_sdpa(head, seq, :, batch), [1, self % sequence_length]),& + jacobian(seq, :, :)& + ), [self % sequence_length]) + end do ! 
calculate delta for query d_attn_matrix(head, :, :, batch) = matmul(d_normalize(head, :, :, batch), k_heads(head, :, :, batch)) + ! calculate delta for key, attention matrix should be transposed unlike for query dk(head, :, :, batch) = matmul(transpose(d_normalize(head, :, :, batch)), q_heads(head, :, :, batch)) end do ! calculate deltas for input layers - call self % value_layer % backward(self % v_input, self % combine_heads(d_sdpa)) + call self % value_layer % backward(self % v_input, self % combine_heads(dv)) call self % key_layer % backward(self % k_input, self % combine_heads(dk)) call self % query_layer % backward(self % q_input, self % combine_heads(d_attn_matrix)) From 33c33b906c4b4428c6aa25d9df61aac7a61f9596 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sun, 9 Feb 2025 18:10:25 +0400 Subject: [PATCH 36/71] multihead_attention: minor refactoring and optimization --- src/nf/nf_multihead_attention.f90 | 65 +++++++++++++++---------------- 1 file changed, 32 insertions(+), 33 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index b99e985b..e0a00bcd 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -125,10 +125,10 @@ module subroutine backward(self, input, gradient) real, allocatable :: k_heads(:, :, :, :) real, allocatable :: q_heads(:, :, :, :) real, allocatable :: dv(:, :, :, :) - real, allocatable :: d_sdpa(:, :, :, :) - real, allocatable :: jacobian(:, :, :) + real, allocatable :: d_sdpa(:, :) + real, allocatable :: jacobian(:, :) real, allocatable :: d_normalize(:, :, :, :) - real, allocatable :: d_attn_matrix(:, :, :, :) + real, allocatable :: dq(:, :, :, :) real, allocatable :: dk(:, :, :, :) integer :: batch, head, seq, i, j @@ -139,10 +139,10 @@ module subroutine backward(self, input, gradient) allocate(q_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) allocate(dv(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) - allocate(d_sdpa(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size)) - allocate(jacobian(self % sequence_length, self % sequence_length, self % sequence_length)) + allocate(d_sdpa(self % sequence_length, self % sequence_length)) + allocate(jacobian(self % sequence_length, self % sequence_length)) allocate(d_normalize(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size)) - allocate(d_attn_matrix(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) + allocate(dq(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) allocate(dk(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) ! calculate output layer delta @@ -159,39 +159,38 @@ module subroutine backward(self, input, gradient) dv(head, :, :, batch) = matmul(transpose(self % attention_matrix(head, :, :, batch)), d_output(head, :, :, batch)) ! calculate delta for attention matrix - d_sdpa(head, :, :, batch) = matmul(d_output(head, :, :, batch), transpose(v_heads(head, :, :, batch))) + d_sdpa = matmul(d_output(head, :, :, batch), transpose(v_heads(head, :, :, batch))) ! this monstrosity below is scaled derivative of softmax - do concurrent(seq = 1: self % sequence_length, i = 1: self % sequence_length, j = 1: self % sequence_length) - ! jacobian matrix is used to calculate derivative of softmax (temporary storage) - ! the idea behind this if-else is that for diagonal elements, the jacobian temp - ! 
should be: `softmax(x_i) * (1 - softmax(x_i))` - ! for off-diagonal: `-softmax(x_i) * softmax(x_j)` - ! For computational efficiency (avoid more temp storages), scaling is also done here - if (i == j) then - jacobian(seq, i, j) = & - self % attention_matrix(head, seq, i, batch) & - * (1 - self % attention_matrix(head, seq, i, batch)) & - * self % scaling_factor - else - jacobian(seq, i, j) = & - - self % attention_matrix(head, seq, i, batch) & - * self % attention_matrix(head, seq, j, batch) & - * self % scaling_factor - end if - end do - - ! attention normalization delta, the last step of softmax derivative: - ! multiply temp jacobian matrix by the output of softmax do concurrent(seq = 1: self % sequence_length) + ! create jacobian matrix + do concurrent(i = 1: self % sequence_length, j = 1: self % sequence_length) + ! jacobian matrix is used to calculate derivative of softmax (temporary storage) + ! the idea behind this if-else is that for diagonal elements, the jacobian temp + ! should be: `softmax(x_i) * (1 - softmax(x_i))` + ! for off-diagonal: `-softmax(x_i) * softmax(x_j)` + if (i == j) then + jacobian(i, j) = & + self % attention_matrix(head, seq, i, batch) & + * (1 - self % attention_matrix(head, seq, i, batch)) + else + jacobian(i, j) = & + - self % attention_matrix(head, seq, i, batch) & + * self % attention_matrix(head, seq, j, batch) + end if + end do + ! attention normalization delta, the last step of softmax derivative: + ! multiply output of softmax by temp jacobian matrix + ! For computational efficiency (avoid more temp storages), scaling is also done here + ! reshapes: [3] -> [1, 3] @ [3, 3] = [1, 3] -> [3] d_normalize(head, seq, :, batch) = reshape(matmul(& - reshape(d_sdpa(head, seq, :, batch), [1, self % sequence_length]),& - jacobian(seq, :, :)& + reshape(d_sdpa(seq, :), [1, self % sequence_length]),& + jacobian * self % scaling_factor& ), [self % sequence_length]) end do ! calculate delta for query - d_attn_matrix(head, :, :, batch) = matmul(d_normalize(head, :, :, batch), k_heads(head, :, :, batch)) + dq(head, :, :, batch) = matmul(d_normalize(head, :, :, batch), k_heads(head, :, :, batch)) ! calculate delta for key, attention matrix should be transposed unlike for query dk(head, :, :, batch) = matmul(transpose(d_normalize(head, :, :, batch)), q_heads(head, :, :, batch)) @@ -200,7 +199,7 @@ module subroutine backward(self, input, gradient) ! calculate deltas for input layers call self % value_layer % backward(self % v_input, self % combine_heads(dv)) call self % key_layer % backward(self % k_input, self % combine_heads(dk)) - call self % query_layer % backward(self % q_input, self % combine_heads(d_attn_matrix)) + call self % query_layer % backward(self % q_input, self % combine_heads(dq)) ! 
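A possible simplification, not used in this series: because the Jacobian of a softmax row has the special form J(i, j) = s_i * (delta_ij - s_j), the row-vector-by-Jacobian product above collapses to an elementwise expression, g . J = s * (g - sum(g * s)). A sketch reusing the variable names from the hunk above, as an illustration only and not the committed code:

  ! same result as the matmul with the scaled jacobian above,
  ! computed without forming the jacobian at all
  associate (s => self % attention_matrix(head, seq, :, batch), g => d_sdpa(seq, :))
    d_normalize(head, seq, :, batch) = s * (g - sum(g * s)) * self % scaling_factor
  end associate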
free temporary storages deallocate(d_output) @@ -210,7 +209,7 @@ module subroutine backward(self, input, gradient) deallocate(d_sdpa) deallocate(jacobian) deallocate(d_normalize) - deallocate(d_attn_matrix) + deallocate(dq) deallocate(dk) end subroutine backward From 40c3f2b25bf6c5944433666cd1c2862780e694f2 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sun, 9 Feb 2025 18:24:57 +0400 Subject: [PATCH 37/71] multihead_attention: fix comments --- src/nf/nf_multihead_attention.f90 | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index e0a00bcd..d17ae7eb 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -55,9 +55,10 @@ end function multihead_attention_layer_cons interface module subroutine backward(self, input, gradient) - !! Apply the backward gradient descent pass. - !! Only weight and bias gradients are updated in this subroutine, - !! while the weights and biases themselves are untouched. + !! General backprop for MultiHead Attention mechanism + !! Might be used for both Self and Cross Attention + !! Self Attention: sum output gradients + !! Cross Attention: use them separately class(multihead_attention_layer), intent(in out) :: self !! Dense layer instance real, intent(in) :: input(:, :, :) @@ -67,6 +68,10 @@ module subroutine backward(self, input, gradient) end subroutine backward module subroutine forward(self, query, key, value) + !! General forward propagation for MultiHead Attention Mechanism + !! Might be used for both Self and Cross Attention + !! Self Attention: pass the same value thrice + !! Cross Attention: pass three values for your query, key and value class(multihead_attention_layer), intent(in out) :: self real, intent(in) :: query(:, :, :), key(:, :, :), value(:, :, :) end subroutine forward @@ -76,9 +81,7 @@ module subroutine init(self, input_shape) !! !! This is a deferred procedure from the `base_layer` abstract type. class(multihead_attention_layer), intent(in out) :: self - !! Dense layer instance integer, intent(in) :: input_shape(:) - !! Shape of the input layer end subroutine init end interface @@ -115,7 +118,6 @@ module function multihead_attention_layer_cons(& end function multihead_attention_layer_cons module subroutine backward(self, input, gradient) - !! General backprop for MultiHead Attention mechanism class(multihead_attention_layer), intent(in out) :: self real, intent(in) :: input(:, :, :) real, intent(in) :: gradient(:, :, :) @@ -214,7 +216,6 @@ module subroutine backward(self, input, gradient) end subroutine backward module subroutine forward(self, query, key, value) - !! General forward prop for MultiHead Attention Mechenism class(multihead_attention_layer), intent(in out) :: self real, intent(in) :: query(:, :, :), key(:, :, :), value(:, :, :) @@ -261,17 +262,8 @@ module function split_heads(self, input) result(output) !! Split inputs into heads !! !! Example with two heads: - !! input (1, 3, 4): - !! [[[0. , 0.3 , 0.6 , 0.9 ], - !! [0.1 , 0.4 , 0.7 , 0.11], - !! [0.2 , 0.5 , 0.8 , 0.12]]] - !! output (1, 2, 3, 2) - !! [[[[0. , 0.3 ], - ! [0.1 , 0.4 ], - ! [0.2 , 0.5 ]], - ! [[0.6 , 0.9 ], - ! [0.7 , 0.11], - ! [0.8 , 0.12]]]] + !! input (3, 4, 1) + !! 
output (2, 3, 2, 1) class(multihead_attention_layer) :: self real :: input(:, :, :) real :: output(self % n_heads, self % sequence_length, self % head_size, self % batch_size) From 6a607b0a536478b449d51de07dc8fdc229608fb9 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sun, 9 Feb 2025 18:52:47 +0400 Subject: [PATCH 38/71] multihead_attention: tests, add checks for attention weights --- test/test_multihead_attention_layer.f90 | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90 index e50b86ec..afdaa9c0 100644 --- a/test/test_multihead_attention_layer.f90 +++ b/test/test_multihead_attention_layer.f90 @@ -141,11 +141,19 @@ subroutine test_multihead_attention_forward(attention, ok) real :: output(attention % sequence_length, attention % model_dimension, attention % batch_size) real :: output_flat(12) integer :: output_shape(3) + integer :: attn_weights_shape(4) + real :: attn_weights_flat(18) integer :: expected_shape(3) = [3, 4, 1] real :: expected_output_flat(12) = [& 0.982241452, 1.00407875, 1.00444126, 0.982241452, 1.00407875, 1.00444126,& 0.982241452, 1.00407875, 1.00444126, 0.982241452, 1.00407875, 1.00444126& ] + integer :: expected_attn_weights_shape(4) = [2, 3, 3, 1] + real :: expected_attn_weights_flat(18) = [& + 7.89450705E-02, 7.89450705E-02, 2.28110179E-02, 2.28110179E-02, 2.18846574E-02, 2.18846574E-02,& + 0.447508544, 0.447508544, 0.464612424, 0.464612424, 0.464721352, 0.464721352,& + 0.473546445, 0.473546445, 0.512576580, 0.512576580, 0.513393998, 0.513393998& + ] call attention % forward(input, input, input) @@ -159,6 +167,17 @@ subroutine test_multihead_attention_forward(attention, ok) ok = .false. write(stderr, '(a)') 'forward returned incorrect values.. failed' end if + + attn_weights_shape = shape(attention % attention_matrix) + if (.not. all(attn_weights_shape.eq.expected_attn_weights_shape)) then + ok = .false. + write(stderr, '(a)') 'forward returned incorrect attention weights shape.. failed' + end if + attn_weights_flat = reshape(attention % attention_matrix, shape(attn_weights_flat)) + if (.not. all(attn_weights_flat.eq.expected_attn_weights_flat)) then + ok = .false. + write(stderr, '(a)') 'forward returned incorrect attention weights values.. failed' + end if end subroutine test_multihead_attention_forward subroutine test_multihead_attention_forward_reallife_shape(ok) From 5fc5a5b3ec5356647459805bf20cddf487b78626 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sun, 9 Feb 2025 19:14:10 +0400 Subject: [PATCH 39/71] multihead_attention: remove some of the copypaste comments --- src/nf/nf_multihead_attention.f90 | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index d17ae7eb..2f7ef156 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -60,11 +60,8 @@ module subroutine backward(self, input, gradient) !! Self Attention: sum output gradients !! Cross Attention: use them separately class(multihead_attention_layer), intent(in out) :: self - !! Dense layer instance real, intent(in) :: input(:, :, :) - !! Input from the previous layer real, intent(in) :: gradient(:, :, :) - !! 
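The attention-weights checks added in the test above compare against hard-coded values. One cheap extra invariant that could be asserted, sketched here as a suggestion only and assuming the (n_heads, sequence_length, sequence_length, batch_size) layout used at this point in the series, is that every softmax row of the attention matrix sums to one:

  if (any(abs(sum(attention % attention_matrix, dim=3) - 1.) > 1e-5)) then
    ok = .false.
    write(stderr, '(a)') 'attention weights rows do not sum to 1.. failed'
  end if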
Gradient from the next layer end subroutine backward module subroutine forward(self, query, key, value) From 65fd88d6fdad95906abb0638f8318a0ec8ae59ec Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Wed, 12 Feb 2025 20:50:35 +0400 Subject: [PATCH 40/71] multihead_attention: optimize shapes --- src/nf/nf_multihead_attention.f90 | 237 ++++++++++++------------ test/test_multihead_attention_layer.f90 | 99 +++++----- 2 files changed, 173 insertions(+), 163 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index 2f7ef156..47ef1109 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -22,15 +22,15 @@ module nf_multihead_attention_layer type(softmax) :: softmax_func - real, allocatable :: attention_matrix(:, :, :, :) - real, allocatable :: sdpa(:, :, :, :) - real, allocatable :: output(:, :, :) + real, allocatable :: attention_matrix(:, :, :) + real, allocatable :: sdpa(:, :, :) + real, allocatable :: output(:, :) real :: scaling_factor - real, allocatable :: q_input(:, :, :) - real, allocatable :: k_input(:, :, :) - real, allocatable :: v_input(:, :, :) + real, allocatable :: q_input(:, :) + real, allocatable :: k_input(:, :) + real, allocatable :: v_input(:, :) contains procedure :: backward @@ -60,8 +60,8 @@ module subroutine backward(self, input, gradient) !! Self Attention: sum output gradients !! Cross Attention: use them separately class(multihead_attention_layer), intent(in out) :: self - real, intent(in) :: input(:, :, :) - real, intent(in) :: gradient(:, :, :) + real, intent(in) :: input(:, :) + real, intent(in) :: gradient(:, :) end subroutine backward module subroutine forward(self, query, key, value) @@ -70,7 +70,7 @@ module subroutine forward(self, query, key, value) !! Self Attention: pass the same value thrice !! 
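In practice both modes go through the same entry point; only the arguments differ. A usage sketch based on the shapes exercised by the tests in this patch (sequence_length=3, model_dimension=4, n_heads=2); the program and variable names are placeholders and the inputs are random:

  program mha_usage_sketch
    use nf_multihead_attention_layer, only: multihead_attention_layer
    implicit none
    type(multihead_attention_layer) :: mha
    real :: x(3, 4), q(3, 4), kv(3, 4)

    mha = multihead_attention_layer(sequence_length=3, model_dimension=4, n_heads=2)
    call mha % init([0])
    call random_number(x)
    call random_number(q)
    call random_number(kv)

    call mha % forward(x, x, x)     ! self-attention: the same array passed three times
    call mha % forward(q, kv, kv)   ! cross-attention: separate query and key/value sources
  end program mha_usage_sketch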
Cross Attention: pass three values for your query, key and value class(multihead_attention_layer), intent(in out) :: self - real, intent(in) :: query(:, :, :), key(:, :, :), value(:, :, :) + real, intent(in) :: query(:, :), key(:, :), value(:, :) end subroutine forward module subroutine init(self, input_shape) @@ -84,11 +84,9 @@ end subroutine init end interface contains - module function multihead_attention_layer_cons(& - batch_size, sequence_length, model_dimension, n_heads) result(res) - integer, intent(in) :: batch_size, sequence_length, model_dimension, n_heads + module function multihead_attention_layer_cons(sequence_length, model_dimension, n_heads) result(res) + integer, intent(in) :: sequence_length, model_dimension, n_heads type(multihead_attention_layer) :: res - res % batch_size = batch_size res % sequence_length = sequence_length res % model_dimension = model_dimension res % n_heads = n_heads @@ -99,13 +97,10 @@ module function multihead_attention_layer_cons(& end if res % head_size = model_dimension / n_heads - res % query_layer = linear2d_layer(& - sequence_length=sequence_length, in_features=model_dimension,& - out_features=model_dimension, batch_size=batch_size& - ) - res % key_layer = linear2d_layer(sequence_length, model_dimension, model_dimension, batch_size) - res % value_layer = linear2d_layer(sequence_length, model_dimension, model_dimension, batch_size) - res % output_layer = linear2d_layer(sequence_length, model_dimension, model_dimension, batch_size) + res % query_layer = linear2d_layer(sequence_length, model_dimension, model_dimension, 1) + res % key_layer = linear2d_layer(sequence_length, model_dimension, model_dimension, 1) + res % value_layer = linear2d_layer(sequence_length, model_dimension, model_dimension, 1) + res % output_layer = linear2d_layer(sequence_length, model_dimension, model_dimension, 1) call res % query_layer % init([0]) call res % key_layer % init([0]) call res % value_layer % init([0]) @@ -116,49 +111,56 @@ end function multihead_attention_layer_cons module subroutine backward(self, input, gradient) class(multihead_attention_layer), intent(in out) :: self - real, intent(in) :: input(:, :, :) - real, intent(in) :: gradient(:, :, :) - - real, allocatable :: d_output(:, :, :, :) - real, allocatable :: v_heads(:, :, :, :) - real, allocatable :: k_heads(:, :, :, :) - real, allocatable :: q_heads(:, :, :, :) - real, allocatable :: dv(:, :, :, :) + real, intent(in) :: input(:, :) + real, intent(in) :: gradient(:, :) + + real, allocatable :: d_output(:, :, :) + real, allocatable :: v_heads(:, :, :) + real, allocatable :: k_heads(:, :, :) + real, allocatable :: q_heads(:, :, :) + real, allocatable :: dv(:, :, :) real, allocatable :: d_sdpa(:, :) real, allocatable :: jacobian(:, :) - real, allocatable :: d_normalize(:, :, :, :) - real, allocatable :: dq(:, :, :, :) - real, allocatable :: dk(:, :, :, :) - integer :: batch, head, seq, i, j + real, allocatable :: d_normalize(:, :, :) + real, allocatable :: dq(:, :, :) + real, allocatable :: dk(:, :, :) + integer :: head, seq, i, j ! 
allocate temporary storages for backward computation - allocate(d_output(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) - allocate(v_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) - allocate(k_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) - allocate(q_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) + allocate(d_output(self % sequence_length, self % head_size, self % n_heads)) + allocate(v_heads(self % sequence_length, self % head_size, self % n_heads)) + allocate(k_heads(self % sequence_length, self % head_size, self % n_heads)) + allocate(q_heads(self % sequence_length, self % head_size, self % n_heads)) - allocate(dv(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) + allocate(dv(self % sequence_length, self % head_size, self % n_heads)) allocate(d_sdpa(self % sequence_length, self % sequence_length)) allocate(jacobian(self % sequence_length, self % sequence_length)) - allocate(d_normalize(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size)) - allocate(dq(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) - allocate(dk(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) + allocate(d_normalize(self % sequence_length, self % sequence_length, self % n_heads)) + allocate(dq(self % sequence_length, self % head_size, self % n_heads)) + allocate(dk(self % sequence_length, self % head_size, self % n_heads)) ! calculate output layer delta - call self % output_layer % backward(input, gradient) + ! FIXME: remove reshapes when linear2d situation is resolved + call self % output_layer % backward(& + reshape(input, [self % sequence_length, self % model_dimension, 1]),& + reshape(gradient, [self % sequence_length, self % model_dimension, 1])& + ) ! split heads from output gradient - d_output = self % split_heads(self % output_layer % gradient) - v_heads = self % split_heads(self % value_layer % output) - k_heads = self % split_heads(self % key_layer % output) - q_heads = self % split_heads(self % query_layer % output) + ! FIXME: remove reshapes when linear2d situation is resolved + d_output = self % split_heads(& + reshape(self % output_layer % gradient, [self % sequence_length, self % model_dimension])) + v_heads = self % split_heads(& + reshape(self % value_layer % output, [self % sequence_length, self % model_dimension])) + k_heads = self % split_heads(reshape(self % key_layer % output, [self % sequence_length, self % model_dimension])) + q_heads = self % split_heads(reshape(self % query_layer % output, [self % sequence_length, self % model_dimension])) ! iterate over heads to calculate deltas for each of them - do concurrent(batch = 1: self % batch_size, head = 1: self % n_heads) - dv(head, :, :, batch) = matmul(transpose(self % attention_matrix(head, :, :, batch)), d_output(head, :, :, batch)) + do concurrent(head = 1: self % n_heads) + dv(:, :, head) = matmul(transpose(self % attention_matrix(:, :, head)), d_output(:, :, head)) ! calculate delta for attention matrix - d_sdpa = matmul(d_output(head, :, :, batch), transpose(v_heads(head, :, :, batch))) + d_sdpa = matmul(d_output(:, :, head), transpose(v_heads(:, :, head))) ! this monstrosity below is scaled derivative of softmax do concurrent(seq = 1: self % sequence_length) @@ -170,35 +172,45 @@ module subroutine backward(self, input, gradient) ! 
for off-diagonal: `-softmax(x_i) * softmax(x_j)` if (i == j) then jacobian(i, j) = & - self % attention_matrix(head, seq, i, batch) & - * (1 - self % attention_matrix(head, seq, i, batch)) + self % attention_matrix(seq, i, head) & + * (1 - self % attention_matrix(seq, i, head)) else jacobian(i, j) = & - - self % attention_matrix(head, seq, i, batch) & - * self % attention_matrix(head, seq, j, batch) + - self % attention_matrix(seq, i, head) & + * self % attention_matrix(seq, j, head) end if end do ! attention normalization delta, the last step of softmax derivative: ! multiply output of softmax by temp jacobian matrix ! For computational efficiency (avoid more temp storages), scaling is also done here ! reshapes: [3] -> [1, 3] @ [3, 3] = [1, 3] -> [3] - d_normalize(head, seq, :, batch) = reshape(matmul(& + d_normalize(seq, :, head) = reshape(matmul(& reshape(d_sdpa(seq, :), [1, self % sequence_length]),& jacobian * self % scaling_factor& ), [self % sequence_length]) end do ! calculate delta for query - dq(head, :, :, batch) = matmul(d_normalize(head, :, :, batch), k_heads(head, :, :, batch)) + dq(:, :, head) = matmul(d_normalize(:, :, head), k_heads(:, :, head)) ! calculate delta for key, attention matrix should be transposed unlike for query - dk(head, :, :, batch) = matmul(transpose(d_normalize(head, :, :, batch)), q_heads(head, :, :, batch)) + dk(:, :, head) = matmul(transpose(d_normalize(:, :, head)), q_heads(:, :, head)) end do ! calculate deltas for input layers - call self % value_layer % backward(self % v_input, self % combine_heads(dv)) - call self % key_layer % backward(self % k_input, self % combine_heads(dk)) - call self % query_layer % backward(self % q_input, self % combine_heads(dq)) + ! FIXME: remove reshapes when linear2d situation is resolved + call self % value_layer % backward(& + reshape(self % v_input, [self % sequence_length, self % model_dimension, 1]),& + reshape(self % combine_heads(dv), [self % sequence_length, self % model_dimension, 1])& + ) + call self % key_layer % backward(& + reshape(self % k_input, [self % sequence_length, self % model_dimension, 1]),& + reshape(self % combine_heads(dk), [self % sequence_length, self % model_dimension, 1])& + ) + call self % query_layer % backward(& + reshape(self % q_input, [self % sequence_length, self % model_dimension, 1]),& + reshape(self % combine_heads(dq), [self % sequence_length, self % model_dimension, 1])& + ) ! free temporary storages deallocate(d_output) @@ -214,30 +226,32 @@ end subroutine backward module subroutine forward(self, query, key, value) class(multihead_attention_layer), intent(in out) :: self - real, intent(in) :: query(:, :, :), key(:, :, :), value(:, :, :) + real, intent(in) :: query(:, :), key(:, :), value(:, :) - real, allocatable :: q(:, :, :, :) - real, allocatable :: k(:, :, :, :) - real, allocatable :: v(:, :, :, :) + real, allocatable :: q(:, :, :) + real, allocatable :: k(:, :, :) + real, allocatable :: v(:, :, :) ! 
allocate storage for intermidiate stages - allocate(q(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) - allocate(k(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) - allocate(v(self % n_heads, self % sequence_length, self % head_size, self % batch_size)) + allocate(q(self % sequence_length, self % head_size, self % n_heads)) + allocate(k(self % sequence_length, self % head_size, self % n_heads)) + allocate(v(self % sequence_length, self % head_size, self % n_heads)) self % q_input = query self % k_input = key self % v_input = value ! run inputs through linear layers (trainable params) - call self % query_layer % forward(query) - call self % key_layer % forward(key) - call self % value_layer % forward(value) + ! FIXME: remove reshapes when linear2d situation is resolved + call self % query_layer % forward(reshape(query, [self % sequence_length, self % model_dimension, 1])) + call self % key_layer % forward(reshape(key, [self % sequence_length, self % model_dimension, 1])) + call self % value_layer % forward(reshape(value, [self % sequence_length, self % model_dimension, 1])) ! split attention heads for more efficient computation - q = self % split_heads(self % query_layer % output) - k = self % split_heads(self % key_layer % output) - v = self % split_heads(self % value_layer % output) + ! FIXME: remove reshapes when linear2d situation is resolved + q = self % split_heads(reshape(self % query_layer % output, [self % sequence_length, self % model_dimension])) + k = self % split_heads(reshape(self % key_layer % output, [self % sequence_length, self % model_dimension])) + v = self % split_heads(reshape(self % value_layer % output, [self % sequence_length, self % model_dimension])) ! create key by value matrix call self % create_attention_matrix(q, k) @@ -246,8 +260,10 @@ module subroutine forward(self, query, key, value) ! multiply attention matrix by value call self % scaled_dot_product_attention(v) - call self % output_layer % forward(self % combine_heads(self % sdpa)) - self % output = self % output_layer % output + ! FIXME: remove reshapes when linear2d situation is resolved + call self % output_layer % forward(& + reshape(self % combine_heads(self % sdpa), [self % sequence_length, self % model_dimension, 1])) + self % output = reshape(self % output_layer % output, [self % sequence_length, self % model_dimension]) ! free temp vars from memory deallocate(q) @@ -262,41 +278,36 @@ module function split_heads(self, input) result(output) !! input (3, 4, 1) !! output (2, 3, 2, 1) class(multihead_attention_layer) :: self - real :: input(:, :, :) - real :: output(self % n_heads, self % sequence_length, self % head_size, self % batch_size) - ! FIXME: if anybody knows how to also swap first two dims in one go, pls tell me - output = reshape(& - input,& - [self % n_heads, self % sequence_length, self % head_size, self % batch_size],& - order=[2, 4, 3, 1]& - ) + real :: input(:, :) + real :: output(self % sequence_length, self % head_size, self % n_heads) + output = reshape(input, [self % sequence_length, self % head_size, self % n_heads]) end function split_heads module subroutine create_attention_matrix(self, query, key) !! Create attention matrix for query and key !! Output dimensions: n_heads, sequence_length, sequence_length, batch_size class(multihead_attention_layer) :: self - real :: query(:, :, :, :) - real :: key(:, :, :, :) - integer :: i, j + real :: query(:, :, :) + real :: key(:, :, :) + integer :: head ! 
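Taken together, create_attention_matrix, normalize_attention_matrix and scaled_dot_product_attention implement, per head,

  attention(Q, K, V) = softmax(Q * transpose(K) * scaling_factor) * V,   with scaling_factor = 1 / sqrt(head_size)

where the softmax runs along the key dimension of each query row and an optional additive attention_mask is applied to the scores before the softmax.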
create attention matrix for each sequence in each batch - do concurrent(i = 1: self % batch_size, j = 1: self % n_heads) - self % attention_matrix(j, :, :, i) = matmul(query(j, :, :, i), transpose(key(j, :, :, i))) + do concurrent(head = 1: self % n_heads) + self % attention_matrix(:, :, head) = matmul(query(:, :, head), transpose(key(:, :, head))) end do end subroutine create_attention_matrix module subroutine normalize_attention_matrix(self, attention_mask) !! Create attention matrix for query and key - !! Output dims: n_heads, sequence_length, sequence_length, batch_size + !! Output dims: sequence_length, sequence_length, n_heads class(multihead_attention_layer) :: self - !! (batch_size, n_heads, sequence_length, sequence_length) - real, optional :: attention_mask(:, :, :, :) - !! (batch_size, n_heads, sequence_length, sequence_length) - real, allocatable :: output(:, :, :, :) - integer :: batch, head, seq + !! (sequence_length, sequence_length, n_heads) + real, optional :: attention_mask(:, :, :) + !! (sequence_length, sequence_length, n_heads) + real, allocatable :: output(:, :, :) + integer :: head, seq ! temporary storage - allocate(output(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size)) + allocate(output(self % sequence_length, self % sequence_length, self % n_heads)) ! scale dowm by square root of each head's size self % attention_matrix = self % attention_matrix * self % scaling_factor @@ -305,8 +316,8 @@ module subroutine normalize_attention_matrix(self, attention_mask) self % attention_matrix = self % attention_matrix + attention_mask end if ! softmax by last sequnce_length - do concurrent(batch = 1: self % batch_size, head = 1: self % n_heads, seq = 1: self % sequence_length) - output(head, seq, :, batch) = self % softmax_func % eval_1d(self % attention_matrix(head, seq, :, batch)) + do concurrent(head = 1: self % n_heads, seq = 1: self % sequence_length) + output(seq, :, head) = self % softmax_func % eval_1d(self % attention_matrix(seq, :, head)) end do self % attention_matrix = output @@ -317,23 +328,23 @@ module subroutine scaled_dot_product_attention(self, value) !! Create scaled dot product attention !! Output dims: n_heads, sequence_length, head_size, batch_size class(multihead_attention_layer) :: self - real :: value(:, :, :, :) - integer :: batch, head + real :: value(:, :, :) + integer :: head - do concurrent(batch = 1: self % batch_size, head = 1: self % n_heads) - self % sdpa(head, :, :, batch) = matmul(self % attention_matrix(head, :, :, batch), value(head, :, :, batch)) + do concurrent(head = 1: self % n_heads) + self % sdpa(:, :, head) = matmul(self % attention_matrix(:, :, head), value(:, :, head)) end do end subroutine scaled_dot_product_attention module function combine_heads(self, input) result(output) class(multihead_attention_layer) :: self - real :: input(:, :, :, :) - !! (n_heads, sequence_length, head_size, batch_size) - real :: output(self % sequence_length, self % model_dimension, self % batch_size) - integer :: batch, seq + real :: input(:, :, :) + !! 
(sequence_length, head_size, n_heads) + real :: output(self % sequence_length, self % model_dimension) + integer :: seq - do concurrent(batch = 1: self % batch_size, seq = 1: self % sequence_length) - output(seq, :, batch) = reshape(transpose(input(:, seq, :, batch)), [self % model_dimension]) + do concurrent(seq = 1: self % sequence_length) + output(seq, :) = reshape(transpose(input(seq, :, :)), [self % model_dimension]) end do end function combine_heads @@ -341,18 +352,14 @@ module subroutine init(self, input_shape) class(multihead_attention_layer), intent(in out) :: self integer, intent(in) :: input_shape(:) - allocate(self % attention_matrix(& - self % n_heads, self % sequence_length, self % sequence_length, self % batch_size& - )) - allocate(self % sdpa(& - self % n_heads, self % sequence_length, self % head_size, self % batch_size& - )) - allocate(self % output(self % sequence_length, self % model_dimension, self % batch_size)) + allocate(self % attention_matrix(self % sequence_length, self % sequence_length, self % n_heads)) + allocate(self % sdpa(self % sequence_length, self % head_size, self % n_heads)) + allocate(self % output(self % sequence_length, self % model_dimension)) self % scaling_factor = sqrt(1 / real(self % head_size)) - allocate(self % q_input(self % sequence_length, self % model_dimension, self % batch_size)) - allocate(self % k_input(self % sequence_length, self % model_dimension, self % batch_size)) - allocate(self % v_input(self % sequence_length, self % model_dimension, self % batch_size)) + allocate(self % q_input(self % sequence_length, self % model_dimension)) + allocate(self % k_input(self % sequence_length, self % model_dimension)) + allocate(self % v_input(self % sequence_length, self % model_dimension)) end subroutine init end module nf_multihead_attention_layer diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90 index afdaa9c0..9046d92c 100644 --- a/test/test_multihead_attention_layer.f90 +++ b/test/test_multihead_attention_layer.f90 @@ -6,12 +6,14 @@ program test_multihead_attention_layer logical :: ok = .true. type(multihead_attention_layer) :: attention - real :: sample_input(3, 4, 1) = reshape([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.11, 0.12], [3, 4, 1]) - real :: split_heads_output(2, 3, 2, 1) + real :: sample_input(3, 4) = reshape([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.11, 0.12], [3, 4]) + real :: split_heads_output(3, 2, 2) + real :: minput(3, 4) = reshape([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.11, 0.12], [3, 4]) + real :: output(3, 2, 2) - attention = multihead_attention_layer(batch_size=1, sequence_length=3, model_dimension=4, n_heads=2) + attention = multihead_attention_layer(sequence_length=3, model_dimension=4, n_heads=2) call attention % init([0]) - +! call test_multihead_attention_split_heads(attention, sample_input, ok, split_heads_output) call test_multihead_attention_create_attention_matrix(attention, split_heads_output, ok) call test_multihead_attention_normalization(attention, ok) @@ -19,18 +21,18 @@ program test_multihead_attention_layer call test_multihead_attention_combine_heads(attention, attention % sdpa, ok) call test_multihead_attention_forward(attention, ok) call test_multihead_attention_backward(attention, ok) - call test_multihead_attention_forward_reallife_shape(ok) +! 
call test_multihead_attention_forward_reallife_shape(ok) contains subroutine test_multihead_attention_split_heads(attention, input, ok, output) type(multihead_attention_layer), intent(in) :: attention - real, intent(in) :: input(:, :, :) + real, intent(in) :: input(:, :) logical, intent(in out) :: ok - real, intent(in out) :: output(2, 3, 2, 1) - real :: output_shape(4) - real :: expected_shape(4) = [2, 3, 2, 1] + real, intent(in out) :: output(3, 2, 2) + real :: output_shape(3) + real :: expected_shape(3) = [3, 2, 2] real :: output_flat(12) - real :: expected_output_flat(12) = [0.0, 0.6, 0.1, 0.7, 0.2, 0.8, 0.3, 0.9, 0.4, 0.11, 0.5, 0.12] + real :: expected_output_flat(12) = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.11, 0.12] output = attention % split_heads(input) @@ -48,18 +50,15 @@ end subroutine test_multihead_attention_split_heads subroutine test_multihead_attention_create_attention_matrix(attention, input, ok) type(multihead_attention_layer), intent(in) :: attention - real, intent(in) :: input(:, :, :, :) + real, intent(in) :: input(:, :, :) logical, intent(in out) :: ok - real :: attention_matrix_shape(4) + real :: attention_matrix_shape(3) real :: attention_matrix_flat(18) - real :: expected_shape(4) = [2, 3, 3, 1] + real :: expected_shape(3) = [3, 3, 2] real :: expected_attention_matrix_flat(18) = [& - 9.00000036E-02, 1.16999996, 0.120000005,& - 0.518999994, 0.150000006, 0.588000000,& - 0.120000005, 0.518999994, 0.170000002,& - 0.502099991, 0.219999999, 0.573199987,& - 0.150000006, 0.588000000, 0.219999999,& - 0.573199987, 0.289999992, 0.654400051& + 9.00000036E-02, 0.120000005, 0.150000006, 0.120000005, 0.170000002, 0.219999999,& + 0.150000006, 0.219999999, 0.289999992, 1.16999996, 0.518999994, 0.588000000,& + 0.518999994, 0.502099991, 0.573199987, 0.588000000, 0.573199987, 0.654400051& ] call attention % create_attention_matrix(input, input) @@ -81,9 +80,9 @@ subroutine test_multihead_attention_normalization(attention, ok) logical, intent(in out) :: ok real :: output_flat(18) real :: expected_output_flat(18) = [& - 0.326287806, 0.435975075, 0.321620107, 0.330339372, 0.316976935, 0.329200655,& - 0.333283335, 0.275134116, 0.333194494, 0.326415271, 0.333061278, 0.325773478,& - 0.340428889, 0.288890868, 0.345185429, 0.343245387, 0.349961787, 0.345025837& + 0.326287806, 0.321620107, 0.316976935, 0.333283335, 0.333194494, 0.333061278,& + 0.340428889, 0.345185429, 0.349961787, 0.435975075, 0.330339372, 0.329200655,& + 0.275134116, 0.326415271, 0.325773478, 0.288890868, 0.343245387, 0.345025837& ] call attention % normalize_attention_matrix() @@ -97,12 +96,12 @@ end subroutine test_multihead_attention_normalization subroutine test_multihead_attention_scaled_dot_product_attention(attention, value, ok) type(multihead_attention_layer), intent(in) :: attention - real, intent(in) :: value(:, :, :, :) + real, intent(in) :: value(:, :, :) logical, intent(in out) :: ok real :: output_flat(12) real :: expected_output_flat(12) = [& - 0.101414114, 0.685291648, 0.102356538, 0.701290667, 0.103298485, 0.701582491,& - 0.401414126, 0.457309216, 0.402356565, 0.374400556, 0.403298497, 0.373518765& + 0.101414114, 0.102356538, 0.103298485, 0.401414126, 0.402356565, 0.403298497,& + 0.685291648, 0.701290667, 0.701582491, 0.457309216, 0.374400556, 0.373518765& ] call attention % scaled_dot_product_attention(value) @@ -116,13 +115,13 @@ end subroutine test_multihead_attention_scaled_dot_product_attention subroutine test_multihead_attention_combine_heads(attention, scaled_dp_att, ok) 
type(multihead_attention_layer), intent(in) :: attention - real, intent(in) :: scaled_dp_att(:, :, :, :) + real, intent(in) :: scaled_dp_att(:, :, :) logical, intent(in out) :: ok - real :: output(attention % sequence_length, attention % model_dimension, attention % batch_size) + real :: output(attention % sequence_length, attention % model_dimension) real :: output_flat(12) real :: expected_output_flat(12) = [& - 0.101414114, 0.102356538, 0.103298485, 0.401414126, 0.402356565, 0.403298497,& - 0.685291648, 0.701290667, 0.701582491, 0.457309216, 0.374400556, 0.373518765& + 0.101414114, 0.102356538, 0.103298485, 0.685291648, 0.701290667, 0.701582491,& + 0.401414126, 0.402356565, 0.403298497, 0.457309216, 0.374400556, 0.373518765& ] output = attention % combine_heads(scaled_dp_att) @@ -137,22 +136,22 @@ end subroutine test_multihead_attention_combine_heads subroutine test_multihead_attention_forward(attention, ok) type(multihead_attention_layer), intent(in out) :: attention logical, intent(in out) :: ok - real :: input(3, 4, 1) = reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4, 1]) + real :: input(3, 4) = reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4]) real :: output(attention % sequence_length, attention % model_dimension, attention % batch_size) real :: output_flat(12) - integer :: output_shape(3) - integer :: attn_weights_shape(4) + integer :: output_shape(2) + integer :: attn_weights_shape(3) real :: attn_weights_flat(18) - integer :: expected_shape(3) = [3, 4, 1] + integer :: expected_shape(2) = [3, 4] real :: expected_output_flat(12) = [& 0.982241452, 1.00407875, 1.00444126, 0.982241452, 1.00407875, 1.00444126,& 0.982241452, 1.00407875, 1.00444126, 0.982241452, 1.00407875, 1.00444126& ] - integer :: expected_attn_weights_shape(4) = [2, 3, 3, 1] + integer :: expected_attn_weights_shape(3) = [3, 3, 2] real :: expected_attn_weights_flat(18) = [& - 7.89450705E-02, 7.89450705E-02, 2.28110179E-02, 2.28110179E-02, 2.18846574E-02, 2.18846574E-02,& - 0.447508544, 0.447508544, 0.464612424, 0.464612424, 0.464721352, 0.464721352,& - 0.473546445, 0.473546445, 0.512576580, 0.512576580, 0.513393998, 0.513393998& + 7.89450705E-02, 2.28110179E-02, 2.18846574E-02, 0.447508544, 0.464612424, 0.464721352,& + 0.473546445, 0.512576580, 0.513393998, 7.89450705E-02, 2.28110179E-02, 2.18846574E-02,& + 0.447508544, 0.464612424, 0.464721352, 0.473546445, 0.512576580, 0.513393998& ] call attention % forward(input, input, input) @@ -182,17 +181,17 @@ end subroutine test_multihead_attention_forward subroutine test_multihead_attention_forward_reallife_shape(ok) logical, intent(in out) :: ok - real :: input(148, 512, 2) - real :: output(148, 512, 2) + real :: input(148, 512) + real :: output(148, 512) type(linear2d_layer) :: q real :: output_flat(12) - integer :: output_shape(3) - integer :: expected_shape(3) = [148, 512, 2] + integer :: output_shape(2) + integer :: expected_shape(2) = [148, 512] type(multihead_attention_layer) :: attention call random_number(input) - attention = multihead_attention_layer(batch_size=2, sequence_length=148, model_dimension=512, n_heads=8) + attention = multihead_attention_layer(sequence_length=148, model_dimension=512, n_heads=8) call attention % init([0]) call attention % forward(input, input, input) @@ -207,23 +206,27 @@ end subroutine test_multihead_attention_forward_reallife_shape subroutine test_multihead_attention_backward(attention, ok) type(multihead_attention_layer), intent(in out) :: attention logical, 
intent(in out) :: ok - real :: input(3, 4, 1) = reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4, 1]) - real :: gradient(3, 4, 1) = reshape([0.1, 3. , 2. , 0.1, 3. , 3. , 0.1, 2. , 0.1, 3. , 0.1, 3. ], [3, 4, 1]) + real :: input(3, 4) = reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4]) + real :: gradient(3, 4) = reshape([0.1, 3., 2., 0.1, 3., 3., 0.1, 2., 0.1, 3., 0.1, 3.], [3, 4]) real :: expected_output_flat(12) = [& -2.29912549E-02, 0.381484956, 0.453185737,& -2.29912549E-02, 0.381484956, 0.453185737,& -2.29912549E-02, 0.381484956, 0.453185737,& -2.29912549E-02, 0.381484956, 0.453185737& ] - real :: expected_shape(3) = [3, 4, 1] - real :: output(3, 4, 1) + real :: expected_shape(2) = [3, 4] + real :: output(3, 4) real :: output_flat(12) - real :: output_shape(3) + real :: output_shape(2) call attention % backward(input, gradient) ! sample for Self Attention: sum of output gradients - output = attention % query_layer % gradient + attention % key_layer % gradient + attention % value_layer % gradient + ! FIXME: remove reshapes when linear2d situation is resolved + output = & + reshape(attention % query_layer % gradient, [attention % sequence_length, attention % model_dimension]) & + + reshape(attention % key_layer % gradient, [attention % sequence_length, attention % model_dimension]) & + + reshape(attention % value_layer % gradient, [attention % sequence_length, attention % model_dimension]) output_shape = shape(output) if (.not. all(output_shape.eq.expected_shape)) then From fbc132d59dca7bcfde2d8d4efcec7ceb149f34fd Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Fri, 14 Feb 2025 20:59:37 +0400 Subject: [PATCH 41/71] multihead_attention: params api --- src/nf/nf_multihead_attention.f90 | 88 +++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index 47ef1109..228f0c1a 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -40,6 +40,10 @@ module nf_multihead_attention_layer procedure :: normalize_attention_matrix procedure :: scaled_dot_product_attention procedure :: combine_heads + procedure :: get_num_params + procedure :: get_params + procedure :: get_gradients + procedure :: set_params procedure :: init end type multihead_attention_layer @@ -348,6 +352,90 @@ module function combine_heads(self, input) result(output) end do end function combine_heads + module function get_num_params(self) result(num_params) + class(multihead_attention_layer) :: self + integer :: num_params + + num_params = & + self % query_layer % get_num_params() & + + self % key_layer % get_num_params() & + + self % value_layer % get_num_params() & + + self % output_layer % get_num_params() + end function get_num_params + + module function get_params(self) result(params) + class(multihead_attention_layer), intent(in), target :: self + real, allocatable :: params(:) + + params = [& + self % query_layer % weights,& + self % key_layer % weights,& + self % value_layer % weights,& + self % output_layer % weights,& + self % query_layer % biases,& + self % key_layer % biases,& + self % value_layer % biases,& + self % output_layer % biases& + ] + end function get_params + + module function get_gradients(self) result(gradients) + class(multihead_attention_layer), intent(in), target :: self + real, allocatable :: gradients(:) + + gradients = [ & + self % query_layer % dw,& + self % key_layer % dw,& + self % value_layer % dw,& + 
self % output_layer % dw,& + self % query_layer % db,& + self % key_layer % db,& + self % value_layer % db,& + self % output_layer % db& + ] + end function get_gradients + + module subroutine set_params(self, params) + class(multihead_attention_layer), intent(in out) :: self + real, intent(in), target :: params(:) + real, pointer :: p_(:,:) => null() + integer :: i, j, window + + ! check if the number of parameters is correct + if (size(params) /= self % get_num_params()) then + error stop 'Error: number of parameters does not match' + end if + + ! FIXME: looks clumsy, better ideas? + window = self % model_dimension * self % model_dimension + i = 1 + j = window + self % query_layer % weights = reshape(params(i: j), [self % model_dimension, self % model_dimension]) + i = j + 1 + j = i + window - 1 + self % key_layer % weights = reshape(params(i: j), [self % model_dimension, self % model_dimension]) + i = j + 1 + j = i + window - 1 + self % value_layer % weights = reshape(params(i: j), [self % model_dimension, self % model_dimension]) + i = j + 1 + j = i + window - 1 + self % output_layer % weights = reshape(params(i: j), [self % model_dimension, self % model_dimension]) + + window = self % model_dimension + i = j + 1 + j = i + window - 1 + self % query_layer % biases = params(i: j) + i = j + 1 + j = i + window - 1 + self % key_layer % biases = params(i: j) + i = j + 1 + j = i + window - 1 + self % value_layer % biases = params(i: j) + i = j + 1 + j = i + window - 1 + self % output_layer % biases = params(i: j) + end subroutine set_params + module subroutine init(self, input_shape) class(multihead_attention_layer), intent(in out) :: self integer, intent(in) :: input_shape(:) From 5422e4c1eb6adccb5cae2ff45e9ae89068450716 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Fri, 14 Feb 2025 21:11:13 +0400 Subject: [PATCH 42/71] multihead_attention: fix incorrect dw bug --- src/nf/nf_multihead_attention.f90 | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index 228f0c1a..62a4e60c 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -31,6 +31,7 @@ module nf_multihead_attention_layer real, allocatable :: q_input(:, :) real, allocatable :: k_input(:, :) real, allocatable :: v_input(:, :) + real, allocatable :: o_input(:, :) contains procedure :: backward @@ -146,7 +147,7 @@ module subroutine backward(self, input, gradient) ! calculate output layer delta ! FIXME: remove reshapes when linear2d situation is resolved call self % output_layer % backward(& - reshape(input, [self % sequence_length, self % model_dimension, 1]),& + reshape(self % o_input, [self % sequence_length, self % model_dimension, 1]),& reshape(gradient, [self % sequence_length, self % model_dimension, 1])& ) @@ -265,8 +266,8 @@ module subroutine forward(self, query, key, value) call self % scaled_dot_product_attention(v) ! FIXME: remove reshapes when linear2d situation is resolved - call self % output_layer % forward(& - reshape(self % combine_heads(self % sdpa), [self % sequence_length, self % model_dimension, 1])) + self % o_input = self % combine_heads(self % sdpa) + call self % output_layer % forward(reshape(self % o_input, [self % sequence_length, self % model_dimension, 1])) self % output = reshape(self % output_layer % output, [self % sequence_length, self % model_dimension]) ! 
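The o_input member introduced in this patch fixes the weight gradient of the output projection. For a linear layer y = matmul(x, w) + b, the weight gradient is built from the activations that actually entered that layer, dw = matmul(transpose(x), dy). The output projection is fed by the combined attention heads, not by the attention block's external input, so its backward pass has to receive the cached combine_heads result (o_input) rather than the input argument that was passed before this change.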
free temp vars from memory @@ -449,5 +450,6 @@ module subroutine init(self, input_shape) allocate(self % q_input(self % sequence_length, self % model_dimension)) allocate(self % k_input(self % sequence_length, self % model_dimension)) allocate(self % v_input(self % sequence_length, self % model_dimension)) + allocate(self % o_input(self % sequence_length, self % model_dimension)) end subroutine init end module nf_multihead_attention_layer From 39637e79e51131867b8d4760cdf377c6eeef8ac0 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Fri, 14 Feb 2025 21:33:46 +0400 Subject: [PATCH 43/71] multihead_attention: tests for updated parameters --- test/test_multihead_attention_layer.f90 | 44 +++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90 index 9046d92c..7c0c2cda 100644 --- a/test/test_multihead_attention_layer.f90 +++ b/test/test_multihead_attention_layer.f90 @@ -2,6 +2,7 @@ program test_multihead_attention_layer use iso_fortran_env, only: stderr => error_unit use nf_multihead_attention_layer, only: multihead_attention_layer use nf_linear2d_layer, only: linear2d_layer + use nf_optimizers, only: sgd implicit none logical :: ok = .true. @@ -21,6 +22,7 @@ program test_multihead_attention_layer call test_multihead_attention_combine_heads(attention, attention % sdpa, ok) call test_multihead_attention_forward(attention, ok) call test_multihead_attention_backward(attention, ok) + call test_multihead_attention_update_gradients(attention, ok) ! call test_multihead_attention_forward_reallife_shape(ok) contains @@ -239,4 +241,46 @@ subroutine test_multihead_attention_backward(attention, ok) write(stderr, '(a)') 'backward returned incorrect values.. failed' end if end subroutine test_multihead_attention_backward + + subroutine test_multihead_attention_update_gradients(attention, ok) + type(multihead_attention_layer), intent(in out) :: attention + logical, intent(in out) :: ok + real :: parameters(80) + real :: expected_parameters(80) + real :: updated_output(12) + real :: expected_updated_output(12) = [& + 0.111365855, 0.115744293, 0.115733206, 0.185253710, 0.196646214, 0.196617395,& + -0.102874994, -0.118834510, -0.118794113, 0.179314315, 0.190210193, 0.190182626& + ] + type(sgd) :: optim + + if (attention % get_num_params() /= 80) then + ok = .false. + write(stderr, '(a)') 'incorrect number of parameters.. failed' + end if + + expected_parameters(1: 64) = 0.100000001 + expected_parameters(65: 80) = 0.109999999 + parameters = attention % get_params() + if (.not. all(parameters.eq.expected_parameters)) then + ok = .false. + write(stderr, '(a)') 'incorrect parameters.. failed' + end if + + optim = SGD(learning_rate=0.01) + call optim % minimize(parameters, attention % get_gradients()) + call attention % set_params(parameters) + + call attention % forward(& + reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4]),& + reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4]),& + reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4])& + ) + + updated_output = reshape(attention % output, [12]) + if (.not. all(updated_output.eq.expected_updated_output)) then + ok = .false. + write(stderr, '(a)') 'incorrect output after parameters update.. 
failed' + end if + end subroutine test_multihead_attention_update_gradients end program test_multihead_attention_layer \ No newline at end of file From 60a49db56a00e088d90533e9ba73feb2e2f1e1dc Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sun, 16 Feb 2025 18:13:38 +0400 Subject: [PATCH 44/71] multihead_attention: remove reshape crutches --- src/nf/nf_multihead_attention.f90 | 60 +++++++++++-------------------- 1 file changed, 20 insertions(+), 40 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index 62a4e60c..b9955287 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -102,10 +102,10 @@ module function multihead_attention_layer_cons(sequence_length, model_dimension, end if res % head_size = model_dimension / n_heads - res % query_layer = linear2d_layer(sequence_length, model_dimension, model_dimension, 1) - res % key_layer = linear2d_layer(sequence_length, model_dimension, model_dimension, 1) - res % value_layer = linear2d_layer(sequence_length, model_dimension, model_dimension, 1) - res % output_layer = linear2d_layer(sequence_length, model_dimension, model_dimension, 1) + res % query_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) + res % key_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) + res % value_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) + res % output_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) call res % query_layer % init([0]) call res % key_layer % init([0]) call res % value_layer % init([0]) @@ -145,20 +145,13 @@ module subroutine backward(self, input, gradient) allocate(dk(self % sequence_length, self % head_size, self % n_heads)) ! calculate output layer delta - ! FIXME: remove reshapes when linear2d situation is resolved - call self % output_layer % backward(& - reshape(self % o_input, [self % sequence_length, self % model_dimension, 1]),& - reshape(gradient, [self % sequence_length, self % model_dimension, 1])& - ) + call self % output_layer % backward(self % o_input, gradient) ! split heads from output gradient - ! FIXME: remove reshapes when linear2d situation is resolved - d_output = self % split_heads(& - reshape(self % output_layer % gradient, [self % sequence_length, self % model_dimension])) - v_heads = self % split_heads(& - reshape(self % value_layer % output, [self % sequence_length, self % model_dimension])) - k_heads = self % split_heads(reshape(self % key_layer % output, [self % sequence_length, self % model_dimension])) - q_heads = self % split_heads(reshape(self % query_layer % output, [self % sequence_length, self % model_dimension])) + d_output = self % split_heads(self % output_layer % gradient) + v_heads = self % split_heads(self % value_layer % output) + k_heads = self % split_heads(self % key_layer % output) + q_heads = self % split_heads(self % query_layer % output) ! iterate over heads to calculate deltas for each of them do concurrent(head = 1: self % n_heads) @@ -203,19 +196,9 @@ module subroutine backward(self, input, gradient) end do ! calculate deltas for input layers - ! 
FIXME: remove reshapes when linear2d situation is resolved - call self % value_layer % backward(& - reshape(self % v_input, [self % sequence_length, self % model_dimension, 1]),& - reshape(self % combine_heads(dv), [self % sequence_length, self % model_dimension, 1])& - ) - call self % key_layer % backward(& - reshape(self % k_input, [self % sequence_length, self % model_dimension, 1]),& - reshape(self % combine_heads(dk), [self % sequence_length, self % model_dimension, 1])& - ) - call self % query_layer % backward(& - reshape(self % q_input, [self % sequence_length, self % model_dimension, 1]),& - reshape(self % combine_heads(dq), [self % sequence_length, self % model_dimension, 1])& - ) + call self % value_layer % backward(self % v_input, self % combine_heads(dv)) + call self % key_layer % backward(self % k_input, self % combine_heads(dk)) + call self % query_layer % backward(self % q_input, self % combine_heads(dq)) ! free temporary storages deallocate(d_output) @@ -247,16 +230,14 @@ module subroutine forward(self, query, key, value) self % v_input = value ! run inputs through linear layers (trainable params) - ! FIXME: remove reshapes when linear2d situation is resolved - call self % query_layer % forward(reshape(query, [self % sequence_length, self % model_dimension, 1])) - call self % key_layer % forward(reshape(key, [self % sequence_length, self % model_dimension, 1])) - call self % value_layer % forward(reshape(value, [self % sequence_length, self % model_dimension, 1])) + call self % query_layer % forward(query) + call self % key_layer % forward(key) + call self % value_layer % forward(value) ! split attention heads for more efficient computation - ! FIXME: remove reshapes when linear2d situation is resolved - q = self % split_heads(reshape(self % query_layer % output, [self % sequence_length, self % model_dimension])) - k = self % split_heads(reshape(self % key_layer % output, [self % sequence_length, self % model_dimension])) - v = self % split_heads(reshape(self % value_layer % output, [self % sequence_length, self % model_dimension])) + q = self % split_heads(self % query_layer % output) + k = self % split_heads(self % key_layer % output) + v = self % split_heads(self % value_layer % output) ! create key by value matrix call self % create_attention_matrix(q, k) @@ -265,10 +246,9 @@ module subroutine forward(self, query, key, value) ! multiply attention matrix by value call self % scaled_dot_product_attention(v) - ! FIXME: remove reshapes when linear2d situation is resolved self % o_input = self % combine_heads(self % sdpa) - call self % output_layer % forward(reshape(self % o_input, [self % sequence_length, self % model_dimension, 1])) - self % output = reshape(self % output_layer % output, [self % sequence_length, self % model_dimension]) + call self % output_layer % forward(self % o_input) + self % output = self % output_layer % output ! 
free temp vars from memory deallocate(q) From 7ab776996567953a94440eaa7a2b28468f5c2eeb Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sun, 16 Feb 2025 18:17:55 +0400 Subject: [PATCH 45/71] multihead_attention: rename common forward and backward calls --- src/nf/nf_multihead_attention.f90 | 20 ++++++++++---------- test/test_multihead_attention_layer.f90 | 12 ++++++------ 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index b9955287..b2b2283d 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -34,8 +34,8 @@ module nf_multihead_attention_layer real, allocatable :: o_input(:, :) contains - procedure :: backward - procedure :: forward + procedure :: common_backward + procedure :: common_forward procedure :: split_heads procedure :: create_attention_matrix procedure :: normalize_attention_matrix @@ -59,7 +59,7 @@ end function multihead_attention_layer_cons interface - module subroutine backward(self, input, gradient) + module subroutine common_backward(self, input, gradient) !! General backprop for MultiHead Attention mechanism !! Might be used for both Self and Cross Attention !! Self Attention: sum output gradients @@ -67,16 +67,16 @@ module subroutine backward(self, input, gradient) class(multihead_attention_layer), intent(in out) :: self real, intent(in) :: input(:, :) real, intent(in) :: gradient(:, :) - end subroutine backward + end subroutine common_backward - module subroutine forward(self, query, key, value) + module subroutine common_forward(self, query, key, value) !! General forward propagation for MultiHead Attention Mechanism !! Might be used for both Self and Cross Attention !! Self Attention: pass the same value thrice !! Cross Attention: pass three values for your query, key and value class(multihead_attention_layer), intent(in out) :: self real, intent(in) :: query(:, :), key(:, :), value(:, :) - end subroutine forward + end subroutine common_forward module subroutine init(self, input_shape) !! Initialize the layer data structures. @@ -114,7 +114,7 @@ module function multihead_attention_layer_cons(sequence_length, model_dimension, res % softmax_func = softmax() end function multihead_attention_layer_cons - module subroutine backward(self, input, gradient) + module subroutine common_backward(self, input, gradient) class(multihead_attention_layer), intent(in out) :: self real, intent(in) :: input(:, :) real, intent(in) :: gradient(:, :) @@ -210,9 +210,9 @@ module subroutine backward(self, input, gradient) deallocate(d_normalize) deallocate(dq) deallocate(dk) - end subroutine backward + end subroutine common_backward - module subroutine forward(self, query, key, value) + module subroutine common_forward(self, query, key, value) class(multihead_attention_layer), intent(in out) :: self real, intent(in) :: query(:, :), key(:, :), value(:, :) @@ -254,7 +254,7 @@ module subroutine forward(self, query, key, value) deallocate(q) deallocate(k) deallocate(v) - end subroutine forward + end subroutine common_forward module function split_heads(self, input) result(output) !! 
Split inputs into heads diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90 index 7c0c2cda..900651fa 100644 --- a/test/test_multihead_attention_layer.f90 +++ b/test/test_multihead_attention_layer.f90 @@ -14,7 +14,7 @@ program test_multihead_attention_layer attention = multihead_attention_layer(sequence_length=3, model_dimension=4, n_heads=2) call attention % init([0]) -! + call test_multihead_attention_split_heads(attention, sample_input, ok, split_heads_output) call test_multihead_attention_create_attention_matrix(attention, split_heads_output, ok) call test_multihead_attention_normalization(attention, ok) @@ -23,7 +23,7 @@ program test_multihead_attention_layer call test_multihead_attention_forward(attention, ok) call test_multihead_attention_backward(attention, ok) call test_multihead_attention_update_gradients(attention, ok) -! call test_multihead_attention_forward_reallife_shape(ok) + call test_multihead_attention_forward_reallife_shape(ok) contains subroutine test_multihead_attention_split_heads(attention, input, ok, output) @@ -156,7 +156,7 @@ subroutine test_multihead_attention_forward(attention, ok) 0.447508544, 0.464612424, 0.464721352, 0.473546445, 0.512576580, 0.513393998& ] - call attention % forward(input, input, input) + call attention % common_forward(input, input, input) output_shape = shape(attention % output) if (.not. all(output_shape.eq.expected_shape)) then @@ -196,7 +196,7 @@ subroutine test_multihead_attention_forward_reallife_shape(ok) attention = multihead_attention_layer(sequence_length=148, model_dimension=512, n_heads=8) call attention % init([0]) - call attention % forward(input, input, input) + call attention % common_forward(input, input, input) output_shape = shape(attention % output) if (.not. all(output_shape.eq.expected_shape)) then @@ -221,7 +221,7 @@ subroutine test_multihead_attention_backward(attention, ok) real :: output_flat(12) real :: output_shape(2) - call attention % backward(input, gradient) + call attention % common_backward(input, gradient) ! sample for Self Attention: sum of output gradients ! FIXME: remove reshapes when linear2d situation is resolved @@ -271,7 +271,7 @@ subroutine test_multihead_attention_update_gradients(attention, ok) call optim % minimize(parameters, attention % get_gradients()) call attention % set_params(parameters) - call attention % forward(& + call attention % common_forward(& reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4]),& reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4]),& reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4])& From 20c5eb0f2b3fde1d864ee14466b3e592762fffcc Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sun, 16 Feb 2025 20:33:16 +0400 Subject: [PATCH 46/71] multihead_attention: tidy mha up --- src/nf/nf_multihead_attention.f90 | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index b2b2283d..49969999 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -12,8 +12,7 @@ module nf_multihead_attention_layer type, extends(base_layer) :: multihead_attention_layer !! 
Concrete implementation of a multihead attention layer type - - integer :: batch_size, sequence_length, model_dimension, n_heads, head_size + integer :: sequence_length, model_dimension, n_heads, head_size type(linear2d_layer) :: query_layer type(linear2d_layer) :: key_layer @@ -45,14 +44,14 @@ module nf_multihead_attention_layer procedure :: get_params procedure :: get_gradients procedure :: set_params - procedure :: init - + procedure :: init_base + procedure :: init => init_base ! in case general MHA needs to be used end type multihead_attention_layer interface multihead_attention_layer - module function multihead_attention_layer_cons(batch_size, sequence_length, model_dimension, n_heads) result(res) + module function multihead_attention_layer_cons(sequence_length, model_dimension, n_heads) result(res) !! This function returns the `multihead_attention_layer` instance. - integer, intent(in) :: batch_size, sequence_length, model_dimension, n_heads + integer, intent(in) :: sequence_length, model_dimension, n_heads type(multihead_attention_layer) :: res end function multihead_attention_layer_cons end interface multihead_attention_layer @@ -270,7 +269,7 @@ end function split_heads module subroutine create_attention_matrix(self, query, key) !! Create attention matrix for query and key - !! Output dimensions: n_heads, sequence_length, sequence_length, batch_size + !! Output dimensions: sequence_length, sequence_length, n_heads class(multihead_attention_layer) :: self real :: query(:, :, :) real :: key(:, :, :) @@ -311,7 +310,7 @@ end subroutine normalize_attention_matrix module subroutine scaled_dot_product_attention(self, value) !! Create scaled dot product attention - !! Output dims: n_heads, sequence_length, head_size, batch_size + !! Output dims: sequence_length, head_size, n_heads class(multihead_attention_layer) :: self real :: value(:, :, :) integer :: head @@ -417,7 +416,7 @@ module subroutine set_params(self, params) self % output_layer % biases = params(i: j) end subroutine set_params - module subroutine init(self, input_shape) + module subroutine init_base(self, input_shape) class(multihead_attention_layer), intent(in out) :: self integer, intent(in) :: input_shape(:) @@ -431,5 +430,5 @@ module subroutine init(self, input_shape) allocate(self % k_input(self % sequence_length, self % model_dimension)) allocate(self % v_input(self % sequence_length, self % model_dimension)) allocate(self % o_input(self % sequence_length, self % model_dimension)) - end subroutine init + end subroutine init_base end module nf_multihead_attention_layer From 60985332526eb7a4a313ea6e8b3647081e9e3cc8 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sun, 16 Feb 2025 20:46:38 +0400 Subject: [PATCH 47/71] multihead_attention: self attention --- src/nf/nf_self_attention_layer.f90 | 78 +++++++++++++++++++++++++ test/test_multihead_attention_layer.f90 | 41 ++++++++++++- 2 files changed, 116 insertions(+), 3 deletions(-) create mode 100644 src/nf/nf_self_attention_layer.f90 diff --git a/src/nf/nf_self_attention_layer.f90 b/src/nf/nf_self_attention_layer.f90 new file mode 100644 index 00000000..3c7ff3ce --- /dev/null +++ b/src/nf/nf_self_attention_layer.f90 @@ -0,0 +1,78 @@ +module nf_self_attention_layer + use iso_fortran_env, only: stderr => error_unit + use nf_activation, only: softmax + use nf_linear2d_layer, only: linear2d_layer + use nf_multihead_attention_layer, only: multihead_attention_layer + + implicit none + + type, extends(multihead_attention_layer) :: self_attention_layer + real, allocatable :: 
gradient(:, :)
+  contains
+    procedure :: forward
+    procedure :: backward
+    procedure :: init
+  end type self_attention_layer
+
+  interface self_attention_layer
+    module function self_attention_layer_cons(sequence_length, model_dimension, n_heads) result(res)
+      !! This function returns the `self_attention_layer` instance.
+      integer, intent(in) :: sequence_length, model_dimension, n_heads
+      type(self_attention_layer) :: res
+    end function self_attention_layer_cons
+  end interface self_attention_layer
+
+contains
+  module function self_attention_layer_cons(sequence_length, model_dimension, n_heads) result(res)
+    !! This function returns the `self_attention_layer` instance.
+    integer, intent(in) :: sequence_length, model_dimension, n_heads
+    type(self_attention_layer) :: res
+    res % sequence_length = sequence_length
+    res % model_dimension = model_dimension
+    res % n_heads = n_heads
+
+    if (mod(model_dimension, n_heads) /= 0) then
+      write(stderr, '(a)') 'model dimension must be divisible by the number of heads'
+      error stop
+    end if
+    res % head_size = model_dimension / n_heads
+
+    res % query_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
+    res % key_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
+    res % value_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
+    res % output_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
+    call res % query_layer % init([0])
+    call res % key_layer % init([0])
+    call res % value_layer % init([0])
+    call res % output_layer % init([0])
+
+    res % softmax_func = softmax()
+  end function self_attention_layer_cons
+
+  module subroutine backward(self, input, gradient)
+    class(self_attention_layer), intent(in out) :: self
+    real, intent(in) :: input(:, :)
+    real, intent(in) :: gradient(:, :)
+
+    call self % common_backward(input, gradient)
+    self % gradient = &
+        self % query_layer % gradient &
+        + self % key_layer % gradient &
+        + self % value_layer % gradient
+  end subroutine backward
+
+  module subroutine forward(self, input)
+    class(self_attention_layer), intent(in out) :: self
+    real, intent(in) :: input(:, :)
+
+    call self % common_forward(input, input, input)
+  end subroutine forward
+
+  module subroutine init(self, input_shape)
+    class(self_attention_layer), intent(in out) :: self
+    integer, intent(in) :: input_shape(:)
+
+    call self % init_base(input_shape)
+    allocate(self % gradient(self % sequence_length, self % model_dimension))
+  end subroutine init
+end module nf_self_attention_layer
diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90
index 900651fa..eae8998f 100644
--- a/test/test_multihead_attention_layer.f90
+++ b/test/test_multihead_attention_layer.f90
@@ -1,6 +1,7 @@
 program test_multihead_attention_layer
   use iso_fortran_env, only: stderr => error_unit
   use nf_multihead_attention_layer, only: multihead_attention_layer
+  use nf_self_attention_layer, only: self_attention_layer
   use nf_linear2d_layer, only: linear2d_layer
   use nf_optimizers, only: sgd
   implicit none
@@ -14,7 +15,7 @@ program test_multihead_attention_layer
   real :: output(3, 2, 2)
 
   attention = multihead_attention_layer(sequence_length=3, model_dimension=4, n_heads=2)
-  call attention % init([0])
+  call attention % init_base([0])
 
   call test_multihead_attention_split_heads(attention, sample_input, ok, split_heads_output)
   call test_multihead_attention_create_attention_matrix(attention, split_heads_output, ok)
@@ -24,6 +25,7 @@ program
test_multihead_attention_layer call test_multihead_attention_backward(attention, ok) call test_multihead_attention_update_gradients(attention, ok) call test_multihead_attention_forward_reallife_shape(ok) + call test_self_attention(ok) contains subroutine test_multihead_attention_split_heads(attention, input, ok, output) @@ -139,7 +141,7 @@ subroutine test_multihead_attention_forward(attention, ok) type(multihead_attention_layer), intent(in out) :: attention logical, intent(in out) :: ok real :: input(3, 4) = reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4]) - real :: output(attention % sequence_length, attention % model_dimension, attention % batch_size) + real :: output(attention % sequence_length, attention % model_dimension) real :: output_flat(12) integer :: output_shape(2) integer :: attn_weights_shape(3) @@ -194,7 +196,7 @@ subroutine test_multihead_attention_forward_reallife_shape(ok) call random_number(input) attention = multihead_attention_layer(sequence_length=148, model_dimension=512, n_heads=8) - call attention % init([0]) + call attention % init_base([0]) call attention % common_forward(input, input, input) @@ -283,4 +285,37 @@ subroutine test_multihead_attention_update_gradients(attention, ok) write(stderr, '(a)') 'incorrect output after parameters update.. failed' end if end subroutine test_multihead_attention_update_gradients + + subroutine test_self_attention(ok) + logical, intent(in out) :: ok + type(self_attention_layer) :: attention + real :: input(2, 3) = reshape([-1., 0., 17., .4, 5., .6], [2, 3]) + real :: output(2, 3) + real :: output_flat(6) + real :: expected_output_flat(6) = [& + 0.772716165, 0.577548742, 0.772716165, 0.577548742, 0.772716165, 0.577548742& + ] + real :: gradient(2, 3) = reshape([1., 2., .17, 4., .5, 6.], [2, 3]) + real :: gradient_flat(6) + real :: expected_gradient_flat(6) = [& + 0.350671142, 0.607403040, 0.350671142, 0.607403040, 0.350671142, 0.607403040& + ] + + attention = self_attention_layer(sequence_length=2, model_dimension=3, n_heads=1) + call attention % init([0]) + + call attention % forward(input) + output_flat = reshape(attention % output, shape(output_flat)) + if (.not. all(output_flat.eq.expected_output_flat)) then + ok = .false. + write(stderr, '(a)') 'forward returned incorrect values.. failed' + end if + + call attention % backward(input, gradient) + gradient_flat = reshape(attention % gradient, shape(gradient_flat)) + if (.not. all(gradient_flat.eq.expected_gradient_flat)) then + ok = .false. + write(stderr, '(a)') 'backward returned incorrect values.. 
failed' + end if + end subroutine test_self_attention end program test_multihead_attention_layer \ No newline at end of file From 66b5023193f9277c3a4738f8e4b115decc75f8cf Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sun, 16 Feb 2025 21:06:35 +0400 Subject: [PATCH 48/71] multihead_attention: add cross attention --- src/nf/nf_cross_attention_layer.f90 | 76 +++++++++++++++++++++++++ test/test_multihead_attention_layer.f90 | 48 ++++++++++++++++ 2 files changed, 124 insertions(+) create mode 100644 src/nf/nf_cross_attention_layer.f90 diff --git a/src/nf/nf_cross_attention_layer.f90 b/src/nf/nf_cross_attention_layer.f90 new file mode 100644 index 00000000..ef5ae1b6 --- /dev/null +++ b/src/nf/nf_cross_attention_layer.f90 @@ -0,0 +1,76 @@ +module nf_cross_attention_layer + use iso_fortran_env, only: stderr => error_unit + use nf_activation, only: softmax + use nf_linear2d_layer, only: linear2d_layer + use nf_multihead_attention_layer, only: multihead_attention_layer + + implicit none + + type, extends(multihead_attention_layer) :: cross_attention_layer + real, allocatable :: gradient(:, :, :) + contains + procedure :: forward + procedure :: backward + procedure :: init + end type cross_attention_layer + + interface cross_attention_layer + module function cross_attention_layer_cons(sequence_length, model_dimension, n_heads) result(res) + !! This function returns the `cross_attention_layer` instance. + integer, intent(in) :: sequence_length, model_dimension, n_heads + type(cross_attention_layer) :: res + end function cross_attention_layer_cons + end interface cross_attention_layer + +contains + module function cross_attention_layer_cons(sequence_length, model_dimension, n_heads) result(res) + !! This function returns the `cross_attention_layer` instance. 
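+    !! A minimal usage sketch, mirroring test_cross_attention in this same
+    !! patch (sizes are illustrative only):
+    !!   attention = cross_attention_layer(sequence_length=2, model_dimension=3, n_heads=1)
+    !!   call attention % init([0])
+    !!   input(1, :, :) = query      ! (sequence_length, model_dimension)
+    !!   input(2, :, :) = key_value  ! (sequence_length, model_dimension)
+    !!   call attention % forward(input)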
+ integer, intent(in) :: sequence_length, model_dimension, n_heads + type(cross_attention_layer) :: res + res % sequence_length = sequence_length + res % model_dimension = model_dimension + res % n_heads = n_heads + + if (mod(model_dimension, n_heads) /= 0) then + write(stderr, '(a)'), 'Number of heads must be divisible by model dimension' + error stop + end if + res % head_size = model_dimension / n_heads + + res % query_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) + res % key_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) + res % value_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) + res % output_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) + call res % query_layer % init([0]) + call res % key_layer % init([0]) + call res % value_layer % init([0]) + call res % output_layer % init([0]) + + res % softmax_func = softmax() + end function cross_attention_layer_cons + + module subroutine backward(self, input, gradient) + class(cross_attention_layer), intent(in out) :: self + real, intent(in) :: input(:, :, :) + real, intent(in) :: gradient(:, :) + + call self % common_backward(input(1, :, :), gradient) + self % gradient(1, :, :) = self % query_layer % gradient + self % gradient(2, :, :) = self % key_layer % gradient + self % value_layer % gradient + end subroutine backward + + module subroutine forward(self, input) + class(cross_attention_layer), intent(in out) :: self + real, intent(in) :: input(:, :, :) + + call self % common_forward(input(1, :, :), input(2, :, :), input(2, :, :)) + end subroutine forward + + module subroutine init(self, input_shape) + class(cross_attention_layer), intent(in out) :: self + integer, intent(in) :: input_shape(:) + + call self % init_base(input_shape) + allocate(self % gradient(2, self % sequence_length, self % model_dimension)) + end subroutine init +end module nf_cross_attention_layer diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90 index eae8998f..aaa0760b 100644 --- a/test/test_multihead_attention_layer.f90 +++ b/test/test_multihead_attention_layer.f90 @@ -2,6 +2,7 @@ program test_multihead_attention_layer use iso_fortran_env, only: stderr => error_unit use nf_multihead_attention_layer, only: multihead_attention_layer use nf_self_attention_layer, only: self_attention_layer + use nf_cross_attention_layer, only: cross_attention_layer use nf_linear2d_layer, only: linear2d_layer use nf_optimizers, only: sgd implicit none @@ -26,6 +27,7 @@ program test_multihead_attention_layer call test_multihead_attention_update_gradients(attention, ok) call test_multihead_attention_forward_reallife_shape(ok) call test_self_attention(ok) + call test_cross_attention(ok) contains subroutine test_multihead_attention_split_heads(attention, input, ok, output) @@ -318,4 +320,50 @@ subroutine test_self_attention(ok) write(stderr, '(a)') 'backward returned incorrect values.. 
failed' end if end subroutine test_self_attention + + subroutine test_cross_attention(ok) + logical, intent(in out) :: ok + type(cross_attention_layer) :: attention + real :: query(2, 3) = reshape([-1., 0., 17., .4, 5., .6], [2, 3]) + real :: key_value(2, 3) = reshape([0.1, -.2, 0.3, 4., 15., 0.5], [2, 3]) + real :: input(2, 2, 3) + real :: output(2, 2, 3) + real :: output_flat(6) + real :: expected_output_flat(6) = [& + 0.600311756, 0.471662223, 0.600311756, 0.471662223, 0.600311756, 0.471662223& + ] + real :: gradient(2, 3) = reshape([1., 2., .17, 4., .5, 6.], [2, 3]) + real :: query_gradient_flat(6) + real :: key_value_gradient_flat(6) + real :: expected_query_gradient_flat(6) = [& + 1.48406753E-03, 0.184446245, 1.48406753E-03, 0.184446245, 1.48406753E-03, 0.184446245& + ] + real :: expected_key_value_gradient_flat(6) = [& + 0.303095698, 0.107004307, 0.303095698, 0.107004307, 0.303095698, 0.107004307& + ] + input(1, :, :) = query + input(2, :, :) = key_value + + attention = cross_attention_layer(sequence_length=2, model_dimension=3, n_heads=1) + call attention % init([0]) + + call attention % forward(input) + output_flat = reshape(attention % output, shape(output_flat)) + if (.not. all(output_flat.eq.expected_output_flat)) then + ok = .false. + write(stderr, '(a)') 'forward returned incorrect values.. failed' + end if + + call attention % backward(input, gradient) + query_gradient_flat = reshape(attention % gradient(1, :, :), shape(query_gradient_flat)) + if (.not. all(query_gradient_flat.eq.expected_query_gradient_flat)) then + ok = .false. + write(stderr, '(a)') 'backward returned incorrect query values.. failed' + end if + key_value_gradient_flat = reshape(attention % gradient(2, :, :), shape(key_value_gradient_flat)) + if (.not. all(key_value_gradient_flat.eq.expected_key_value_gradient_flat)) then + ok = .false. + write(stderr, '(a)') 'backward returned incorrect key-value values.. failed' + end if + end subroutine test_cross_attention end program test_multihead_attention_layer \ No newline at end of file From ac813aae751cbe9579741e7712f6ce117fabf9ea Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Sun, 16 Feb 2025 21:29:41 +0400 Subject: [PATCH 49/71] multihead_attention: add more comments --- src/nf/nf_cross_attention_layer.f90 | 9 +++++++++ src/nf/nf_self_attention_layer.f90 | 10 ++++++++++ 2 files changed, 19 insertions(+) diff --git a/src/nf/nf_cross_attention_layer.f90 b/src/nf/nf_cross_attention_layer.f90 index ef5ae1b6..790165ff 100644 --- a/src/nf/nf_cross_attention_layer.f90 +++ b/src/nf/nf_cross_attention_layer.f90 @@ -7,6 +7,11 @@ module nf_cross_attention_layer implicit none type, extends(multihead_attention_layer) :: cross_attention_layer + !! Cross Attention Layer + !! Source: + !! Bahdanau, D. (2014) + !! Neural machine translation by jointly learning to align and translate. + !! https://arxiv.org/pdf/1409.0473 real, allocatable :: gradient(:, :, :) contains procedure :: forward @@ -50,6 +55,7 @@ module function cross_attention_layer_cons(sequence_length, model_dimension, n_h end function cross_attention_layer_cons module subroutine backward(self, input, gradient) + !! Cross Attention Back propagation class(cross_attention_layer), intent(in out) :: self real, intent(in) :: input(:, :, :) real, intent(in) :: gradient(:, :) @@ -60,6 +66,9 @@ module subroutine backward(self, input, gradient) end subroutine backward module subroutine forward(self, input) + !! Cross Attention Forward propagation + !! Input Shape (kind, sequence_length, model_dimension) + !! 
where kind is 1 for Query and 2 for Key-Value
     class(cross_attention_layer), intent(in out) :: self
     real, intent(in) :: input(:, :, :)
 
diff --git a/src/nf/nf_self_attention_layer.f90 b/src/nf/nf_self_attention_layer.f90
index 3c7ff3ce..580cc673 100644
--- a/src/nf/nf_self_attention_layer.f90
+++ b/src/nf/nf_self_attention_layer.f90
@@ -7,6 +7,11 @@ module nf_self_attention_layer
   implicit none
 
   type, extends(multihead_attention_layer) :: self_attention_layer
+    !! Self Attention Layer
+    !! Source:
+    !! Parikh, A. P., Taeckstroem, O., Das, D., & Uszkoreit, J. (2016)
+    !! A decomposable attention model for natural language inference.
+    !! https://arxiv.org/pdf/1606.01933
     real, allocatable :: gradient(:, :)
   contains
     procedure :: forward
@@ -50,6 +55,8 @@ module function self_attention_layer_cons(sequence_length, model_dimension, n_he
   end function self_attention_layer_cons
 
   module subroutine backward(self, input, gradient)
+    !! Self Attention back propagation
+    !! Returns sum of Query, Key and Value gradients
     class(self_attention_layer), intent(in out) :: self
     real, intent(in) :: input(:, :)
     real, intent(in) :: gradient(:, :)
@@ -62,6 +69,9 @@ module subroutine backward(self, input, gradient)
   end subroutine backward
 
   module subroutine forward(self, input)
+    !! Self Attention forward propagation
+    !! Passes input three times into MultiHead Attention
+    !! Input Shape: (sequence_length, model_dimension)
     class(self_attention_layer), intent(in out) :: self
     real, intent(in) :: input(:, :)
 

From 6b70f6ba2608269a8a0de103608857c970f86115 Mon Sep 17 00:00:00 2001
From: Mikhail Voronov
Date: Mon, 17 Feb 2025 00:04:53 +0400
Subject: [PATCH 50/71] multihead_attention: arrange attention into submodule

---
 src/nf/nf_multihead_attention.f90           | 428 ++++----------------
 src/nf/nf_multihead_attention_submodule.f90 | 339 ++++++++++++++++
 2 files changed, 417 insertions(+), 350 deletions(-)
 create mode 100644 src/nf/nf_multihead_attention_submodule.f90

diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90
index 49969999..a98c6d69 100644
--- a/src/nf/nf_multihead_attention.f90
+++ b/src/nf/nf_multihead_attention.f90
@@ -35,17 +35,19 @@ module nf_multihead_attention_layer
 
     procedure :: common_backward
     procedure :: common_forward
-    procedure :: split_heads
-    procedure :: create_attention_matrix
-    procedure :: normalize_attention_matrix
-    procedure :: scaled_dot_product_attention
-    procedure :: combine_heads
     procedure :: get_num_params
     procedure :: get_params
     procedure :: get_gradients
     procedure :: set_params
     procedure :: init_base
     procedure :: init => init_base ! in case general MHA needs to be used
+
+    !
FIXME: those should be private but accessible by tests + procedure :: split_heads + procedure :: create_attention_matrix + procedure :: normalize_attention_matrix + procedure :: scaled_dot_product_attention + procedure :: combine_heads end type multihead_attention_layer interface multihead_attention_layer @@ -85,350 +87,76 @@ module subroutine init(self, input_shape) integer, intent(in) :: input_shape(:) end subroutine init - end interface - -contains - module function multihead_attention_layer_cons(sequence_length, model_dimension, n_heads) result(res) - integer, intent(in) :: sequence_length, model_dimension, n_heads - type(multihead_attention_layer) :: res - res % sequence_length = sequence_length - res % model_dimension = model_dimension - res % n_heads = n_heads - - if (mod(model_dimension, n_heads) /= 0) then - write(stderr, '(a)'), 'Number of heads must be divisible by model dimension' - error stop - end if - res % head_size = model_dimension / n_heads - - res % query_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) - res % key_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) - res % value_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) - res % output_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) - call res % query_layer % init([0]) - call res % key_layer % init([0]) - call res % value_layer % init([0]) - call res % output_layer % init([0]) - - res % softmax_func = softmax() - end function multihead_attention_layer_cons - - module subroutine common_backward(self, input, gradient) - class(multihead_attention_layer), intent(in out) :: self - real, intent(in) :: input(:, :) - real, intent(in) :: gradient(:, :) - - real, allocatable :: d_output(:, :, :) - real, allocatable :: v_heads(:, :, :) - real, allocatable :: k_heads(:, :, :) - real, allocatable :: q_heads(:, :, :) - real, allocatable :: dv(:, :, :) - real, allocatable :: d_sdpa(:, :) - real, allocatable :: jacobian(:, :) - real, allocatable :: d_normalize(:, :, :) - real, allocatable :: dq(:, :, :) - real, allocatable :: dk(:, :, :) - integer :: head, seq, i, j - - ! allocate temporary storages for backward computation - allocate(d_output(self % sequence_length, self % head_size, self % n_heads)) - allocate(v_heads(self % sequence_length, self % head_size, self % n_heads)) - allocate(k_heads(self % sequence_length, self % head_size, self % n_heads)) - allocate(q_heads(self % sequence_length, self % head_size, self % n_heads)) - - allocate(dv(self % sequence_length, self % head_size, self % n_heads)) - allocate(d_sdpa(self % sequence_length, self % sequence_length)) - allocate(jacobian(self % sequence_length, self % sequence_length)) - allocate(d_normalize(self % sequence_length, self % sequence_length, self % n_heads)) - allocate(dq(self % sequence_length, self % head_size, self % n_heads)) - allocate(dk(self % sequence_length, self % head_size, self % n_heads)) - - ! calculate output layer delta - call self % output_layer % backward(self % o_input, gradient) - - ! split heads from output gradient - d_output = self % split_heads(self % output_layer % gradient) - v_heads = self % split_heads(self % value_layer % output) - k_heads = self % split_heads(self % key_layer % output) - q_heads = self % split_heads(self % query_layer % output) - - ! 
iterate over heads to calculate deltas for each of them - do concurrent(head = 1: self % n_heads) - dv(:, :, head) = matmul(transpose(self % attention_matrix(:, :, head)), d_output(:, :, head)) - - ! calculate delta for attention matrix - d_sdpa = matmul(d_output(:, :, head), transpose(v_heads(:, :, head))) - - ! this monstrosity below is scaled derivative of softmax - do concurrent(seq = 1: self % sequence_length) - ! create jacobian matrix - do concurrent(i = 1: self % sequence_length, j = 1: self % sequence_length) - ! jacobian matrix is used to calculate derivative of softmax (temporary storage) - ! the idea behind this if-else is that for diagonal elements, the jacobian temp - ! should be: `softmax(x_i) * (1 - softmax(x_i))` - ! for off-diagonal: `-softmax(x_i) * softmax(x_j)` - if (i == j) then - jacobian(i, j) = & - self % attention_matrix(seq, i, head) & - * (1 - self % attention_matrix(seq, i, head)) - else - jacobian(i, j) = & - - self % attention_matrix(seq, i, head) & - * self % attention_matrix(seq, j, head) - end if - end do - ! attention normalization delta, the last step of softmax derivative: - ! multiply output of softmax by temp jacobian matrix - ! For computational efficiency (avoid more temp storages), scaling is also done here - ! reshapes: [3] -> [1, 3] @ [3, 3] = [1, 3] -> [3] - d_normalize(seq, :, head) = reshape(matmul(& - reshape(d_sdpa(seq, :), [1, self % sequence_length]),& - jacobian * self % scaling_factor& - ), [self % sequence_length]) - end do - - ! calculate delta for query - dq(:, :, head) = matmul(d_normalize(:, :, head), k_heads(:, :, head)) - - ! calculate delta for key, attention matrix should be transposed unlike for query - dk(:, :, head) = matmul(transpose(d_normalize(:, :, head)), q_heads(:, :, head)) - end do - - ! calculate deltas for input layers - call self % value_layer % backward(self % v_input, self % combine_heads(dv)) - call self % key_layer % backward(self % k_input, self % combine_heads(dk)) - call self % query_layer % backward(self % q_input, self % combine_heads(dq)) - - ! free temporary storages - deallocate(d_output) - deallocate(v_heads) - deallocate(k_heads) - deallocate(q_heads) - deallocate(d_sdpa) - deallocate(jacobian) - deallocate(d_normalize) - deallocate(dq) - deallocate(dk) - end subroutine common_backward - - module subroutine common_forward(self, query, key, value) - class(multihead_attention_layer), intent(in out) :: self - real, intent(in) :: query(:, :), key(:, :), value(:, :) - - real, allocatable :: q(:, :, :) - real, allocatable :: k(:, :, :) - real, allocatable :: v(:, :, :) - - ! allocate storage for intermidiate stages - allocate(q(self % sequence_length, self % head_size, self % n_heads)) - allocate(k(self % sequence_length, self % head_size, self % n_heads)) - allocate(v(self % sequence_length, self % head_size, self % n_heads)) - - self % q_input = query - self % k_input = key - self % v_input = value - - ! run inputs through linear layers (trainable params) - call self % query_layer % forward(query) - call self % key_layer % forward(key) - call self % value_layer % forward(value) - - ! split attention heads for more efficient computation - q = self % split_heads(self % query_layer % output) - k = self % split_heads(self % key_layer % output) - v = self % split_heads(self % value_layer % output) - - ! create key by value matrix - call self % create_attention_matrix(q, k) - ! apply softmax and scaling - call self % normalize_attention_matrix() - ! 
multiply attention matrix by value - call self % scaled_dot_product_attention(v) - - self % o_input = self % combine_heads(self % sdpa) - call self % output_layer % forward(self % o_input) - self % output = self % output_layer % output - - ! free temp vars from memory - deallocate(q) - deallocate(k) - deallocate(v) - end subroutine common_forward - - module function split_heads(self, input) result(output) - !! Split inputs into heads - !! - !! Example with two heads: - !! input (3, 4, 1) - !! output (2, 3, 2, 1) - class(multihead_attention_layer) :: self - real :: input(:, :) - real :: output(self % sequence_length, self % head_size, self % n_heads) - output = reshape(input, [self % sequence_length, self % head_size, self % n_heads]) - end function split_heads - - module subroutine create_attention_matrix(self, query, key) - !! Create attention matrix for query and key - !! Output dimensions: sequence_length, sequence_length, n_heads - class(multihead_attention_layer) :: self - real :: query(:, :, :) - real :: key(:, :, :) - integer :: head - ! create attention matrix for each sequence in each batch - do concurrent(head = 1: self % n_heads) - self % attention_matrix(:, :, head) = matmul(query(:, :, head), transpose(key(:, :, head))) - end do - end subroutine create_attention_matrix - - module subroutine normalize_attention_matrix(self, attention_mask) - !! Create attention matrix for query and key - !! Output dims: sequence_length, sequence_length, n_heads - class(multihead_attention_layer) :: self - !! (sequence_length, sequence_length, n_heads) - real, optional :: attention_mask(:, :, :) - !! (sequence_length, sequence_length, n_heads) - real, allocatable :: output(:, :, :) - integer :: head, seq - - ! temporary storage - allocate(output(self % sequence_length, self % sequence_length, self % n_heads)) - - ! scale dowm by square root of each head's size - self % attention_matrix = self % attention_matrix * self % scaling_factor - ! attention mask is used to mask out some of the tokens if necessary - if (present(attention_mask)) then - self % attention_matrix = self % attention_matrix + attention_mask - end if - ! softmax by last sequnce_length - do concurrent(head = 1: self % n_heads, seq = 1: self % sequence_length) - output(seq, :, head) = self % softmax_func % eval_1d(self % attention_matrix(seq, :, head)) - end do - self % attention_matrix = output - - deallocate(output) - end subroutine normalize_attention_matrix - - module subroutine scaled_dot_product_attention(self, value) - !! Create scaled dot product attention - !! Output dims: sequence_length, head_size, n_heads - class(multihead_attention_layer) :: self - real :: value(:, :, :) - integer :: head - - do concurrent(head = 1: self % n_heads) - self % sdpa(:, :, head) = matmul(self % attention_matrix(:, :, head), value(:, :, head)) - end do - end subroutine scaled_dot_product_attention - - module function combine_heads(self, input) result(output) - class(multihead_attention_layer) :: self - real :: input(:, :, :) - !! 
(sequence_length, head_size, n_heads) - real :: output(self % sequence_length, self % model_dimension) - integer :: seq - - do concurrent(seq = 1: self % sequence_length) - output(seq, :) = reshape(transpose(input(seq, :, :)), [self % model_dimension]) - end do - end function combine_heads - - module function get_num_params(self) result(num_params) - class(multihead_attention_layer) :: self - integer :: num_params - - num_params = & - self % query_layer % get_num_params() & - + self % key_layer % get_num_params() & - + self % value_layer % get_num_params() & - + self % output_layer % get_num_params() - end function get_num_params - - module function get_params(self) result(params) - class(multihead_attention_layer), intent(in), target :: self - real, allocatable :: params(:) - - params = [& - self % query_layer % weights,& - self % key_layer % weights,& - self % value_layer % weights,& - self % output_layer % weights,& - self % query_layer % biases,& - self % key_layer % biases,& - self % value_layer % biases,& - self % output_layer % biases& - ] - end function get_params - - module function get_gradients(self) result(gradients) - class(multihead_attention_layer), intent(in), target :: self - real, allocatable :: gradients(:) - - gradients = [ & - self % query_layer % dw,& - self % key_layer % dw,& - self % value_layer % dw,& - self % output_layer % dw,& - self % query_layer % db,& - self % key_layer % db,& - self % value_layer % db,& - self % output_layer % db& - ] - end function get_gradients - - module subroutine set_params(self, params) - class(multihead_attention_layer), intent(in out) :: self - real, intent(in), target :: params(:) - real, pointer :: p_(:,:) => null() - integer :: i, j, window - - ! check if the number of parameters is correct - if (size(params) /= self % get_num_params()) then - error stop 'Error: number of parameters does not match' - end if - - ! FIXME: looks clumsy, better ideas? - window = self % model_dimension * self % model_dimension - i = 1 - j = window - self % query_layer % weights = reshape(params(i: j), [self % model_dimension, self % model_dimension]) - i = j + 1 - j = i + window - 1 - self % key_layer % weights = reshape(params(i: j), [self % model_dimension, self % model_dimension]) - i = j + 1 - j = i + window - 1 - self % value_layer % weights = reshape(params(i: j), [self % model_dimension, self % model_dimension]) - i = j + 1 - j = i + window - 1 - self % output_layer % weights = reshape(params(i: j), [self % model_dimension, self % model_dimension]) - - window = self % model_dimension - i = j + 1 - j = i + window - 1 - self % query_layer % biases = params(i: j) - i = j + 1 - j = i + window - 1 - self % key_layer % biases = params(i: j) - i = j + 1 - j = i + window - 1 - self % value_layer % biases = params(i: j) - i = j + 1 - j = i + window - 1 - self % output_layer % biases = params(i: j) - end subroutine set_params - - module subroutine init_base(self, input_shape) - class(multihead_attention_layer), intent(in out) :: self - integer, intent(in) :: input_shape(:) - - allocate(self % attention_matrix(self % sequence_length, self % sequence_length, self % n_heads)) - allocate(self % sdpa(self % sequence_length, self % head_size, self % n_heads)) - allocate(self % output(self % sequence_length, self % model_dimension)) - - self % scaling_factor = sqrt(1 / real(self % head_size)) + module function split_heads(self, input) result(output) + !! Split inputs into heads + !! + !! Example with two heads: + !! input (3, 4) + !! 
output (3, 2, 2) + class(multihead_attention_layer) :: self + real :: input(:, :) + real :: output(self % sequence_length, self % head_size, self % n_heads) + end function split_heads + + module subroutine create_attention_matrix(self, query, key) + !! Create attention matrix for query and key + !! Output dimensions: sequence_length, sequence_length, n_heads + class(multihead_attention_layer) :: self + real :: query(:, :, :) + real :: key(:, :, :) + integer :: head + end subroutine create_attention_matrix + + module subroutine normalize_attention_matrix(self, attention_mask) + !! Create attention matrix for query and key + !! Output dims: sequence_length, sequence_length, n_heads + class(multihead_attention_layer) :: self + !! (sequence_length, sequence_length, n_heads) + real, optional :: attention_mask(:, :, :) + !! (sequence_length, sequence_length, n_heads) + real, allocatable :: output(:, :, :) + integer :: head, seq + end subroutine normalize_attention_matrix + + module subroutine scaled_dot_product_attention(self, value) + !! Create scaled dot product attention + !! Output dims: sequence_length, head_size, n_heads + class(multihead_attention_layer) :: self + real :: value(:, :, :) + integer :: head + end subroutine scaled_dot_product_attention + + module function combine_heads(self, input) result(output) + class(multihead_attention_layer) :: self + real :: input(:, :, :) + !! (sequence_length, head_size, n_heads) + real :: output(self % sequence_length, self % model_dimension) + integer :: seq + end function combine_heads + + module function get_num_params(self) result(num_params) + class(multihead_attention_layer) :: self + integer :: num_params + end function get_num_params + + module function get_params(self) result(params) + class(multihead_attention_layer), intent(in), target :: self + real, allocatable :: params(:) + end function get_params + + module function get_gradients(self) result(gradients) + class(multihead_attention_layer), intent(in), target :: self + real, allocatable :: gradients(:) + end function get_gradients + + module subroutine set_params(self, params) + class(multihead_attention_layer), intent(in out) :: self + real, intent(in), target :: params(:) + end subroutine set_params - allocate(self % q_input(self % sequence_length, self % model_dimension)) - allocate(self % k_input(self % sequence_length, self % model_dimension)) - allocate(self % v_input(self % sequence_length, self % model_dimension)) - allocate(self % o_input(self % sequence_length, self % model_dimension)) - end subroutine init_base + module subroutine init_base(self, input_shape) + class(multihead_attention_layer), intent(in out) :: self + integer, intent(in) :: input_shape(:) + end subroutine init_base + end interface end module nf_multihead_attention_layer diff --git a/src/nf/nf_multihead_attention_submodule.f90 b/src/nf/nf_multihead_attention_submodule.f90 new file mode 100644 index 00000000..6e370013 --- /dev/null +++ b/src/nf/nf_multihead_attention_submodule.f90 @@ -0,0 +1,339 @@ +submodule(nf_multihead_attention_layer) nf_multihead_attention_layer_submodule +! 
use iso_fortran_env, only: stderr => error_unit + use nf_activation, only: softmax + use nf_base_layer, only: base_layer + use nf_linear2d_layer, only: linear2d_layer + + implicit none + +contains + module function multihead_attention_layer_cons(sequence_length, model_dimension, n_heads) result(res) + integer, intent(in) :: sequence_length, model_dimension, n_heads + type(multihead_attention_layer) :: res + res % sequence_length = sequence_length + res % model_dimension = model_dimension + res % n_heads = n_heads + + if (mod(model_dimension, n_heads) /= 0) then + write(stderr, '(a)'), 'Number of heads must be divisible by model dimension' + error stop + end if + res % head_size = model_dimension / n_heads + + res % query_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) + res % key_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) + res % value_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) + res % output_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) + call res % query_layer % init([0]) + call res % key_layer % init([0]) + call res % value_layer % init([0]) + call res % output_layer % init([0]) + + res % softmax_func = softmax() + end function multihead_attention_layer_cons + + module subroutine common_backward(self, input, gradient) + class(multihead_attention_layer), intent(in out) :: self + real, intent(in) :: input(:, :) + real, intent(in) :: gradient(:, :) + + real, allocatable :: d_output(:, :, :) + real, allocatable :: v_heads(:, :, :) + real, allocatable :: k_heads(:, :, :) + real, allocatable :: q_heads(:, :, :) + real, allocatable :: dv(:, :, :) + real, allocatable :: d_sdpa(:, :) + real, allocatable :: jacobian(:, :) + real, allocatable :: d_normalize(:, :, :) + real, allocatable :: dq(:, :, :) + real, allocatable :: dk(:, :, :) + integer :: head, seq, i, j + + ! allocate temporary storages for backward computation + allocate(d_output(self % sequence_length, self % head_size, self % n_heads)) + allocate(v_heads(self % sequence_length, self % head_size, self % n_heads)) + allocate(k_heads(self % sequence_length, self % head_size, self % n_heads)) + allocate(q_heads(self % sequence_length, self % head_size, self % n_heads)) + + allocate(dv(self % sequence_length, self % head_size, self % n_heads)) + allocate(d_sdpa(self % sequence_length, self % sequence_length)) + allocate(jacobian(self % sequence_length, self % sequence_length)) + allocate(d_normalize(self % sequence_length, self % sequence_length, self % n_heads)) + allocate(dq(self % sequence_length, self % head_size, self % n_heads)) + allocate(dk(self % sequence_length, self % head_size, self % n_heads)) + + ! calculate output layer delta + call self % output_layer % backward(self % o_input, gradient) + + ! split heads from output gradient + d_output = self % split_heads(self % output_layer % gradient) + v_heads = self % split_heads(self % value_layer % output) + k_heads = self % split_heads(self % key_layer % output) + q_heads = self % split_heads(self % query_layer % output) + + ! iterate over heads to calculate deltas for each of them + do concurrent(head = 1: self % n_heads) + dv(:, :, head) = matmul(transpose(self % attention_matrix(:, :, head)), d_output(:, :, head)) + + ! calculate delta for attention matrix + d_sdpa = matmul(d_output(:, :, head), transpose(v_heads(:, :, head))) + + ! this monstrosity below is scaled derivative of softmax + do concurrent(seq = 1: self % sequence_length) + ! 
create jacobian matrix + do concurrent(i = 1: self % sequence_length, j = 1: self % sequence_length) + ! jacobian matrix is used to calculate derivative of softmax (temporary storage) + ! the idea behind this if-else is that for diagonal elements, the jacobian temp + ! should be: `softmax(x_i) * (1 - softmax(x_i))` + ! for off-diagonal: `-softmax(x_i) * softmax(x_j)` + if (i == j) then + jacobian(i, j) = & + self % attention_matrix(seq, i, head) & + * (1 - self % attention_matrix(seq, i, head)) + else + jacobian(i, j) = & + - self % attention_matrix(seq, i, head) & + * self % attention_matrix(seq, j, head) + end if + end do + ! attention normalization delta, the last step of softmax derivative: + ! multiply output of softmax by temp jacobian matrix + ! For computational efficiency (avoid more temp storages), scaling is also done here + ! reshapes: [3] -> [1, 3] @ [3, 3] = [1, 3] -> [3] + d_normalize(seq, :, head) = reshape(matmul(& + reshape(d_sdpa(seq, :), [1, self % sequence_length]),& + jacobian * self % scaling_factor& + ), [self % sequence_length]) + end do + + ! calculate delta for query + dq(:, :, head) = matmul(d_normalize(:, :, head), k_heads(:, :, head)) + + ! calculate delta for key, attention matrix should be transposed unlike for query + dk(:, :, head) = matmul(transpose(d_normalize(:, :, head)), q_heads(:, :, head)) + end do + + ! calculate deltas for input layers + call self % value_layer % backward(self % v_input, self % combine_heads(dv)) + call self % key_layer % backward(self % k_input, self % combine_heads(dk)) + call self % query_layer % backward(self % q_input, self % combine_heads(dq)) + + ! free temporary storages + deallocate(d_output) + deallocate(v_heads) + deallocate(k_heads) + deallocate(q_heads) + deallocate(d_sdpa) + deallocate(jacobian) + deallocate(d_normalize) + deallocate(dq) + deallocate(dk) + end subroutine common_backward + + module subroutine common_forward(self, query, key, value) + class(multihead_attention_layer), intent(in out) :: self + real, intent(in) :: query(:, :), key(:, :), value(:, :) + + real, allocatable :: q(:, :, :) + real, allocatable :: k(:, :, :) + real, allocatable :: v(:, :, :) + + ! allocate storage for intermidiate stages + allocate(q(self % sequence_length, self % head_size, self % n_heads)) + allocate(k(self % sequence_length, self % head_size, self % n_heads)) + allocate(v(self % sequence_length, self % head_size, self % n_heads)) + + self % q_input = query + self % k_input = key + self % v_input = value + + ! run inputs through linear layers (trainable params) + call self % query_layer % forward(query) + call self % key_layer % forward(key) + call self % value_layer % forward(value) + + ! split attention heads for more efficient computation + q = self % split_heads(self % query_layer % output) + k = self % split_heads(self % key_layer % output) + v = self % split_heads(self % value_layer % output) + + ! create key by value matrix + call self % create_attention_matrix(q, k) + ! apply softmax and scaling + call self % normalize_attention_matrix() + ! multiply attention matrix by value + call self % scaled_dot_product_attention(v) + + self % o_input = self % combine_heads(self % sdpa) + call self % output_layer % forward(self % o_input) + self % output = self % output_layer % output + + ! 
free temp vars from memory + deallocate(q) + deallocate(k) + deallocate(v) + end subroutine common_forward + + module function split_heads(self, input) result(output) + class(multihead_attention_layer) :: self + real :: input(:, :) + real :: output(self % sequence_length, self % head_size, self % n_heads) + output = reshape(input, [self % sequence_length, self % head_size, self % n_heads]) + end function split_heads + + module subroutine create_attention_matrix(self, query, key) + class(multihead_attention_layer) :: self + real :: query(:, :, :) + real :: key(:, :, :) + integer :: head + ! create attention matrix for each sequence in each batch + do concurrent(head = 1: self % n_heads) + self % attention_matrix(:, :, head) = matmul(query(:, :, head), transpose(key(:, :, head))) + end do + end subroutine create_attention_matrix + + module subroutine normalize_attention_matrix(self, attention_mask) + class(multihead_attention_layer) :: self + real, optional :: attention_mask(:, :, :) + real, allocatable :: output(:, :, :) + integer :: head, seq + + ! temporary storage + allocate(output(self % sequence_length, self % sequence_length, self % n_heads)) + + ! scale dowm by square root of each head's size + self % attention_matrix = self % attention_matrix * self % scaling_factor + ! attention mask is used to mask out some of the tokens if necessary + if (present(attention_mask)) then + self % attention_matrix = self % attention_matrix + attention_mask + end if + ! softmax by last sequnce_length + do concurrent(head = 1: self % n_heads, seq = 1: self % sequence_length) + output(seq, :, head) = self % softmax_func % eval_1d(self % attention_matrix(seq, :, head)) + end do + self % attention_matrix = output + + deallocate(output) + end subroutine normalize_attention_matrix + + module subroutine scaled_dot_product_attention(self, value) + class(multihead_attention_layer) :: self + real :: value(:, :, :) + integer :: head + + do concurrent(head = 1: self % n_heads) + self % sdpa(:, :, head) = matmul(self % attention_matrix(:, :, head), value(:, :, head)) + end do + end subroutine scaled_dot_product_attention + + module function combine_heads(self, input) result(output) + class(multihead_attention_layer) :: self + real :: input(:, :, :) + real :: output(self % sequence_length, self % model_dimension) + integer :: seq + + do concurrent(seq = 1: self % sequence_length) + output(seq, :) = reshape(transpose(input(seq, :, :)), [self % model_dimension]) + end do + end function combine_heads + + module function get_num_params(self) result(num_params) + class(multihead_attention_layer) :: self + integer :: num_params + + num_params = & + self % query_layer % get_num_params() & + + self % key_layer % get_num_params() & + + self % value_layer % get_num_params() & + + self % output_layer % get_num_params() + end function get_num_params + + module function get_params(self) result(params) + class(multihead_attention_layer), intent(in), target :: self + real, allocatable :: params(:) + + params = [& + self % query_layer % weights,& + self % key_layer % weights,& + self % value_layer % weights,& + self % output_layer % weights,& + self % query_layer % biases,& + self % key_layer % biases,& + self % value_layer % biases,& + self % output_layer % biases& + ] + end function get_params + + module function get_gradients(self) result(gradients) + class(multihead_attention_layer), intent(in), target :: self + real, allocatable :: gradients(:) + + gradients = [ & + self % query_layer % dw,& + self % key_layer % dw,& + self % 
value_layer % dw,& + self % output_layer % dw,& + self % query_layer % db,& + self % key_layer % db,& + self % value_layer % db,& + self % output_layer % db& + ] + end function get_gradients + + module subroutine set_params(self, params) + class(multihead_attention_layer), intent(in out) :: self + real, intent(in), target :: params(:) + real, pointer :: p_(:,:) => null() + integer :: i, j, window + + ! check if the number of parameters is correct + if (size(params) /= self % get_num_params()) then + error stop 'Error: number of parameters does not match' + end if + + ! FIXME: looks clumsy, better ideas? + window = self % model_dimension * self % model_dimension + i = 1 + j = window + self % query_layer % weights = reshape(params(i: j), [self % model_dimension, self % model_dimension]) + i = j + 1 + j = i + window - 1 + self % key_layer % weights = reshape(params(i: j), [self % model_dimension, self % model_dimension]) + i = j + 1 + j = i + window - 1 + self % value_layer % weights = reshape(params(i: j), [self % model_dimension, self % model_dimension]) + i = j + 1 + j = i + window - 1 + self % output_layer % weights = reshape(params(i: j), [self % model_dimension, self % model_dimension]) + + window = self % model_dimension + i = j + 1 + j = i + window - 1 + self % query_layer % biases = params(i: j) + i = j + 1 + j = i + window - 1 + self % key_layer % biases = params(i: j) + i = j + 1 + j = i + window - 1 + self % value_layer % biases = params(i: j) + i = j + 1 + j = i + window - 1 + self % output_layer % biases = params(i: j) + end subroutine set_params + + module subroutine init_base(self, input_shape) + class(multihead_attention_layer), intent(in out) :: self + integer, intent(in) :: input_shape(:) + + allocate(self % attention_matrix(self % sequence_length, self % sequence_length, self % n_heads)) + allocate(self % sdpa(self % sequence_length, self % head_size, self % n_heads)) + allocate(self % output(self % sequence_length, self % model_dimension)) + + self % scaling_factor = sqrt(1 / real(self % head_size)) + + allocate(self % q_input(self % sequence_length, self % model_dimension)) + allocate(self % k_input(self % sequence_length, self % model_dimension)) + allocate(self % v_input(self % sequence_length, self % model_dimension)) + allocate(self % o_input(self % sequence_length, self % model_dimension)) + end subroutine init_base +end submodule nf_multihead_attention_layer_submodule \ No newline at end of file From b622d55927cadff1c25f59d7ff62c6e52d423d26 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Mon, 17 Feb 2025 00:11:18 +0400 Subject: [PATCH 51/71] multihead_attention: update cmakelists --- CMakeLists.txt | 4 ++++ test/CMakeLists.txt | 1 + 2 files changed, 5 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index fc2ddfcb..c31f2e7a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,6 +20,7 @@ add_library(neural-fortran src/nf/nf_base_layer.f90 src/nf/nf_conv2d_layer.f90 src/nf/nf_conv2d_layer_submodule.f90 + src/nf/nf_cross_attention_layer.f90 src/nf/nf_datasets.f90 src/nf/nf_datasets_submodule.f90 src/nf/nf_datasets_mnist.f90 @@ -45,6 +46,8 @@ add_library(neural-fortran src/nf/nf_maxpool2d_layer.f90 src/nf/nf_maxpool2d_layer_submodule.f90 src/nf/nf_metrics.f90 + src/nf/nf_multihead_attention.f90 + src/nf/nf_multihead_attention_submodule.f90 src/nf/nf_network.f90 src/nf/nf_network_submodule.f90 src/nf/nf_optimizers.f90 @@ -53,6 +56,7 @@ add_library(neural-fortran src/nf/nf_random.f90 src/nf/nf_reshape_layer.f90 src/nf/nf_reshape_layer_submodule.f90 + 
src/nf/nf_self_attention_layer.f90 src/nf/io/nf_io_binary.f90 src/nf/io/nf_io_binary_submodule.f90 ) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 12236416..a1de95c6 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -10,6 +10,7 @@ foreach(execid flatten_layer insert_flatten reshape_layer + multihead_attention_layer dense_network get_set_network_params conv2d_network From ce03b390e86364760c7e726ce4546f4bbb4808b9 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Mon, 17 Feb 2025 14:19:32 +0400 Subject: [PATCH 52/71] multihead_attention: update attention in accordance with linear2d --- src/nf/nf_cross_attention_layer.f90 | 16 +++++------ src/nf/nf_multihead_attention_submodule.f90 | 16 +++++------ src/nf/nf_self_attention_layer.f90 | 16 +++++------ test/test_multihead_attention_layer.f90 | 30 +++++++++++++++++++++ 4 files changed, 54 insertions(+), 24 deletions(-) diff --git a/src/nf/nf_cross_attention_layer.f90 b/src/nf/nf_cross_attention_layer.f90 index 790165ff..7204cd99 100644 --- a/src/nf/nf_cross_attention_layer.f90 +++ b/src/nf/nf_cross_attention_layer.f90 @@ -42,14 +42,14 @@ module function cross_attention_layer_cons(sequence_length, model_dimension, n_h end if res % head_size = model_dimension / n_heads - res % query_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) - res % key_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) - res % value_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) - res % output_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) - call res % query_layer % init([0]) - call res % key_layer % init([0]) - call res % value_layer % init([0]) - call res % output_layer % init([0]) + res % query_layer = linear2d_layer(model_dimension) + res % key_layer = linear2d_layer(model_dimension) + res % value_layer = linear2d_layer(model_dimension) + res % output_layer = linear2d_layer(model_dimension) + call res % query_layer % init([sequence_length, model_dimension]) + call res % key_layer % init([sequence_length, model_dimension]) + call res % value_layer % init([sequence_length, model_dimension]) + call res % output_layer % init([sequence_length, model_dimension]) res % softmax_func = softmax() end function cross_attention_layer_cons diff --git a/src/nf/nf_multihead_attention_submodule.f90 b/src/nf/nf_multihead_attention_submodule.f90 index 6e370013..55946c59 100644 --- a/src/nf/nf_multihead_attention_submodule.f90 +++ b/src/nf/nf_multihead_attention_submodule.f90 @@ -20,14 +20,14 @@ module function multihead_attention_layer_cons(sequence_length, model_dimension, end if res % head_size = model_dimension / n_heads - res % query_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) - res % key_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) - res % value_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) - res % output_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) - call res % query_layer % init([0]) - call res % key_layer % init([0]) - call res % value_layer % init([0]) - call res % output_layer % init([0]) + res % query_layer = linear2d_layer(model_dimension) + res % key_layer = linear2d_layer(model_dimension) + res % value_layer = linear2d_layer(model_dimension) + res % output_layer = linear2d_layer(model_dimension) + call res % query_layer % init([sequence_length, model_dimension]) + call res % key_layer % init([sequence_length, model_dimension]) + call 
res % value_layer % init([sequence_length, model_dimension]) + call res % output_layer % init([sequence_length, model_dimension]) res % softmax_func = softmax() end function multihead_attention_layer_cons diff --git a/src/nf/nf_self_attention_layer.f90 b/src/nf/nf_self_attention_layer.f90 index 580cc673..11a4889b 100644 --- a/src/nf/nf_self_attention_layer.f90 +++ b/src/nf/nf_self_attention_layer.f90 @@ -42,14 +42,14 @@ module function self_attention_layer_cons(sequence_length, model_dimension, n_he end if res % head_size = model_dimension / n_heads - res % query_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) - res % key_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) - res % value_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) - res % output_layer = linear2d_layer(sequence_length, model_dimension, model_dimension) - call res % query_layer % init([0]) - call res % key_layer % init([0]) - call res % value_layer % init([0]) - call res % output_layer % init([0]) + res % query_layer = linear2d_layer(model_dimension) + res % key_layer = linear2d_layer(model_dimension) + res % value_layer = linear2d_layer(model_dimension) + res % output_layer = linear2d_layer(model_dimension) + call res % query_layer % init([sequence_length, model_dimension]) + call res % key_layer % init([sequence_length, model_dimension]) + call res % value_layer % init([sequence_length, model_dimension]) + call res % output_layer % init([sequence_length, model_dimension]) res % softmax_func = softmax() end function self_attention_layer_cons diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90 index aaa0760b..a8696e09 100644 --- a/test/test_multihead_attention_layer.f90 +++ b/test/test_multihead_attention_layer.f90 @@ -16,6 +16,7 @@ program test_multihead_attention_layer attention = multihead_attention_layer(sequence_length=3, model_dimension=4, n_heads=2) call attention % init_base([0]) + call set_weights(attention) call test_multihead_attention_split_heads(attention, sample_input, ok, split_heads_output) call test_multihead_attention_create_attention_matrix(attention, split_heads_output, ok) @@ -30,6 +31,18 @@ program test_multihead_attention_layer call test_cross_attention(ok) contains + subroutine set_weights(attention) + type(multihead_attention_layer), intent(in out) :: attention + attention % query_layer % weights = 0.1 + attention % key_layer % weights = 0.1 + attention % value_layer % weights = 0.1 + attention % output_layer % weights = 0.1 + attention % query_layer % biases = 0.11 + attention % key_layer % biases = 0.11 + attention % value_layer % biases = 0.11 + attention % output_layer % biases = 0.11 + end subroutine set_weights + subroutine test_multihead_attention_split_heads(attention, input, ok, output) type(multihead_attention_layer), intent(in) :: attention real, intent(in) :: input(:, :) @@ -199,6 +212,7 @@ subroutine test_multihead_attention_forward_reallife_shape(ok) attention = multihead_attention_layer(sequence_length=148, model_dimension=512, n_heads=8) call attention % init_base([0]) + call set_weights(attention) call attention % common_forward(input, input, input) @@ -305,6 +319,14 @@ subroutine test_self_attention(ok) attention = self_attention_layer(sequence_length=2, model_dimension=3, n_heads=1) call attention % init([0]) + attention % query_layer % weights = 0.1 + attention % key_layer % weights = 0.1 + attention % value_layer % weights = 0.1 + attention % output_layer % 
weights = 0.1 + attention % query_layer % biases = 0.11 + attention % key_layer % biases = 0.11 + attention % value_layer % biases = 0.11 + attention % output_layer % biases = 0.11 call attention % forward(input) output_flat = reshape(attention % output, shape(output_flat)) @@ -346,6 +368,14 @@ subroutine test_cross_attention(ok) attention = cross_attention_layer(sequence_length=2, model_dimension=3, n_heads=1) call attention % init([0]) + attention % query_layer % weights = 0.1 + attention % key_layer % weights = 0.1 + attention % value_layer % weights = 0.1 + attention % output_layer % weights = 0.1 + attention % query_layer % biases = 0.11 + attention % key_layer % biases = 0.11 + attention % value_layer % biases = 0.11 + attention % output_layer % biases = 0.11 call attention % forward(input) output_flat = reshape(attention % output, shape(output_flat)) From 41a80cd8ebe88dfca3175ed0e4a75c12e842e5f8 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Mon, 17 Feb 2025 14:35:02 +0400 Subject: [PATCH 53/71] multihead_attention: remove redundand constructor args for attention layers --- src/nf/nf_cross_attention_layer.f90 | 25 ++--------- src/nf/nf_multihead_attention.f90 | 4 +- src/nf/nf_multihead_attention_submodule.f90 | 46 +++++++++++---------- src/nf/nf_self_attention_layer.f90 | 27 ++---------- test/test_multihead_attention_layer.f90 | 18 ++++---- 5 files changed, 43 insertions(+), 77 deletions(-) diff --git a/src/nf/nf_cross_attention_layer.f90 b/src/nf/nf_cross_attention_layer.f90 index 7204cd99..5d5c8e3e 100644 --- a/src/nf/nf_cross_attention_layer.f90 +++ b/src/nf/nf_cross_attention_layer.f90 @@ -20,7 +20,7 @@ module nf_cross_attention_layer end type cross_attention_layer interface cross_attention_layer - module function cross_attention_layer_cons(sequence_length, model_dimension, n_heads) result(res) + module function cross_attention_layer_cons(n_heads) result(res) !! This function returns the `cross_attention_layer` instance. integer, intent(in) :: sequence_length, model_dimension, n_heads type(cross_attention_layer) :: res @@ -28,30 +28,11 @@ end function cross_attention_layer_cons end interface cross_attention_layer contains - module function cross_attention_layer_cons(sequence_length, model_dimension, n_heads) result(res) + module function cross_attention_layer_cons(n_heads) result(res) !! This function returns the `cross_attention_layer` instance. 
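!   (Usage sketch, illustrative only and not part of this diff: with this change the
!   constructor takes only the number of heads, and the geometry is supplied later
!   through init as [sequence_length, model_dimension]. The sizes below are made up.)
!
!     type(cross_attention_layer) :: attention
!     attention = cross_attention_layer(n_heads=2)
!     call attention % init([3, 4])   ! 3 tokens, model dimension 4, divisible by 2 heads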
- integer, intent(in) :: sequence_length, model_dimension, n_heads + integer, intent(in) :: n_heads type(cross_attention_layer) :: res - res % sequence_length = sequence_length - res % model_dimension = model_dimension res % n_heads = n_heads - - if (mod(model_dimension, n_heads) /= 0) then - write(stderr, '(a)'), 'Number of heads must be divisible by model dimension' - error stop - end if - res % head_size = model_dimension / n_heads - - res % query_layer = linear2d_layer(model_dimension) - res % key_layer = linear2d_layer(model_dimension) - res % value_layer = linear2d_layer(model_dimension) - res % output_layer = linear2d_layer(model_dimension) - call res % query_layer % init([sequence_length, model_dimension]) - call res % key_layer % init([sequence_length, model_dimension]) - call res % value_layer % init([sequence_length, model_dimension]) - call res % output_layer % init([sequence_length, model_dimension]) - - res % softmax_func = softmax() end function cross_attention_layer_cons module subroutine backward(self, input, gradient) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index a98c6d69..769510be 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -51,9 +51,9 @@ module nf_multihead_attention_layer end type multihead_attention_layer interface multihead_attention_layer - module function multihead_attention_layer_cons(sequence_length, model_dimension, n_heads) result(res) + module function multihead_attention_layer_cons(n_heads) result(res) !! This function returns the `multihead_attention_layer` instance. - integer, intent(in) :: sequence_length, model_dimension, n_heads + integer, intent(in) :: n_heads type(multihead_attention_layer) :: res end function multihead_attention_layer_cons end interface multihead_attention_layer diff --git a/src/nf/nf_multihead_attention_submodule.f90 b/src/nf/nf_multihead_attention_submodule.f90 index 55946c59..72ba5691 100644 --- a/src/nf/nf_multihead_attention_submodule.f90 +++ b/src/nf/nf_multihead_attention_submodule.f90 @@ -7,29 +7,11 @@ implicit none contains - module function multihead_attention_layer_cons(sequence_length, model_dimension, n_heads) result(res) - integer, intent(in) :: sequence_length, model_dimension, n_heads + module function multihead_attention_layer_cons(n_heads) result(res) + integer, intent(in) :: n_heads type(multihead_attention_layer) :: res - res % sequence_length = sequence_length - res % model_dimension = model_dimension - res % n_heads = n_heads - if (mod(model_dimension, n_heads) /= 0) then - write(stderr, '(a)'), 'Number of heads must be divisible by model dimension' - error stop - end if - res % head_size = model_dimension / n_heads - - res % query_layer = linear2d_layer(model_dimension) - res % key_layer = linear2d_layer(model_dimension) - res % value_layer = linear2d_layer(model_dimension) - res % output_layer = linear2d_layer(model_dimension) - call res % query_layer % init([sequence_length, model_dimension]) - call res % key_layer % init([sequence_length, model_dimension]) - call res % value_layer % init([sequence_length, model_dimension]) - call res % output_layer % init([sequence_length, model_dimension]) - - res % softmax_func = softmax() + res % n_heads = n_heads end function multihead_attention_layer_cons module subroutine common_backward(self, input, gradient) @@ -325,6 +307,28 @@ module subroutine init_base(self, input_shape) class(multihead_attention_layer), intent(in out) :: self integer, intent(in) :: input_shape(:) + if 
(size(input_shape) /= 2) then
+      error stop "MultiHead Attention accepts 2D input"
+    end if
+    self % sequence_length = input_shape(1)
+    self % model_dimension = input_shape(2)
+
+    if (mod(self % model_dimension, self % n_heads) /= 0) then
+      write(stderr, '(a)') 'Model dimension must be divisible by the number of heads'
+      error stop
+    end if
+    self % head_size = self % model_dimension / self % n_heads
+    self % softmax_func = softmax()
+
+    self % query_layer = linear2d_layer(self % model_dimension)
+    self % key_layer = linear2d_layer(self % model_dimension)
+    self % value_layer = linear2d_layer(self % model_dimension)
+    self % output_layer = linear2d_layer(self % model_dimension)
+    call self % query_layer % init([self % sequence_length, self % model_dimension])
+    call self % key_layer % init([self % sequence_length, self % model_dimension])
+    call self % value_layer % init([self % sequence_length, self % model_dimension])
+    call self % output_layer % init([self % sequence_length, self % model_dimension])
+
     allocate(self % attention_matrix(self % sequence_length, self % sequence_length, self % n_heads))
     allocate(self % sdpa(self % sequence_length, self % head_size, self % n_heads))
     allocate(self % output(self % sequence_length, self % model_dimension))
diff --git a/src/nf/nf_self_attention_layer.f90 b/src/nf/nf_self_attention_layer.f90
index 11a4889b..6b17d381 100644
--- a/src/nf/nf_self_attention_layer.f90
+++ b/src/nf/nf_self_attention_layer.f90
@@ -20,38 +20,19 @@ module nf_self_attention_layer
   end type self_attention_layer
 
   interface self_attention_layer
-    module function self_attention_layer_cons(sequence_length, model_dimension, n_heads) result(res)
+    module function self_attention_layer_cons(n_heads) result(res)
       !! This function returns the `self_attention_layer` instance.
-      integer, intent(in) :: sequence_length, model_dimension, n_heads
+      integer, intent(in) :: n_heads
       type(self_attention_layer) :: res
     end function self_attention_layer_cons
   end interface self_attention_layer
 
 contains
-  module function self_attention_layer_cons(sequence_length, model_dimension, n_heads) result(res)
+  module function self_attention_layer_cons(n_heads) result(res)
     !! This function returns the `self_attention_layer` instance.
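!   (Sketch of the intended call path, illustrative only and not part of this diff;
!   it mirrors the mha_simple example added later in this series. Inside a network,
!   only the number of heads is passed and the 2-d shape is inferred from the
!   previous layer:)
!
!     net = network([ &
!       input(3, 8), &
!       self_attention(4), &
!       flatten(), &
!       dense(2) &
!     ])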
- integer, intent(in) :: sequence_length, model_dimension, n_heads + integer, intent(in) :: n_heads type(self_attention_layer) :: res - res % sequence_length = sequence_length - res % model_dimension = model_dimension res % n_heads = n_heads - - if (mod(model_dimension, n_heads) /= 0) then - write(stderr, '(a)'), 'Number of heads must be divisible by model dimension' - error stop - end if - res % head_size = model_dimension / n_heads - - res % query_layer = linear2d_layer(model_dimension) - res % key_layer = linear2d_layer(model_dimension) - res % value_layer = linear2d_layer(model_dimension) - res % output_layer = linear2d_layer(model_dimension) - call res % query_layer % init([sequence_length, model_dimension]) - call res % key_layer % init([sequence_length, model_dimension]) - call res % value_layer % init([sequence_length, model_dimension]) - call res % output_layer % init([sequence_length, model_dimension]) - - res % softmax_func = softmax() end function self_attention_layer_cons module subroutine backward(self, input, gradient) diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90 index a8696e09..194d5e1b 100644 --- a/test/test_multihead_attention_layer.f90 +++ b/test/test_multihead_attention_layer.f90 @@ -14,8 +14,8 @@ program test_multihead_attention_layer real :: minput(3, 4) = reshape([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.11, 0.12], [3, 4]) real :: output(3, 2, 2) - attention = multihead_attention_layer(sequence_length=3, model_dimension=4, n_heads=2) - call attention % init_base([0]) + attention = multihead_attention_layer(n_heads=2) + call attention % init_base([3, 4]) call set_weights(attention) call test_multihead_attention_split_heads(attention, sample_input, ok, split_heads_output) @@ -210,8 +210,8 @@ subroutine test_multihead_attention_forward_reallife_shape(ok) call random_number(input) - attention = multihead_attention_layer(sequence_length=148, model_dimension=512, n_heads=8) - call attention % init_base([0]) + attention = multihead_attention_layer(n_heads=8) + call attention % init_base([148, 512]) call set_weights(attention) call attention % common_forward(input, input, input) @@ -317,8 +317,8 @@ subroutine test_self_attention(ok) 0.350671142, 0.607403040, 0.350671142, 0.607403040, 0.350671142, 0.607403040& ] - attention = self_attention_layer(sequence_length=2, model_dimension=3, n_heads=1) - call attention % init([0]) + attention = self_attention_layer(n_heads=1) + call attention % init([2, 3]) attention % query_layer % weights = 0.1 attention % key_layer % weights = 0.1 attention % value_layer % weights = 0.1 @@ -366,8 +366,8 @@ subroutine test_cross_attention(ok) input(1, :, :) = query input(2, :, :) = key_value - attention = cross_attention_layer(sequence_length=2, model_dimension=3, n_heads=1) - call attention % init([0]) + attention = cross_attention_layer(n_heads=1) + call attention % init([2, 3]) attention % query_layer % weights = 0.1 attention % key_layer % weights = 0.1 attention % value_layer % weights = 0.1 @@ -396,4 +396,4 @@ subroutine test_cross_attention(ok) write(stderr, '(a)') 'backward returned incorrect key-value values.. 
failed' end if end subroutine test_cross_attention -end program test_multihead_attention_layer \ No newline at end of file +end program test_multihead_attention_layer From a84efd3d926121d0dc423126cfbba7e8cd3f554a Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Mon, 17 Feb 2025 15:04:09 +0400 Subject: [PATCH 54/71] multihead_attention: use pure and elemental where necessary --- src/nf/nf_cross_attention_layer.f90 | 4 +- src/nf/nf_multihead_attention.f90 | 42 ++++++++++----------- src/nf/nf_multihead_attention_submodule.f90 | 40 ++++++++++---------- src/nf/nf_self_attention_layer.f90 | 4 +- test/test_multihead_attention_layer.f90 | 6 +-- 5 files changed, 48 insertions(+), 48 deletions(-) diff --git a/src/nf/nf_cross_attention_layer.f90 b/src/nf/nf_cross_attention_layer.f90 index 5d5c8e3e..b3167a13 100644 --- a/src/nf/nf_cross_attention_layer.f90 +++ b/src/nf/nf_cross_attention_layer.f90 @@ -35,7 +35,7 @@ module function cross_attention_layer_cons(n_heads) result(res) res % n_heads = n_heads end function cross_attention_layer_cons - module subroutine backward(self, input, gradient) + pure module subroutine backward(self, input, gradient) !! Cross Attention Back propagation class(cross_attention_layer), intent(in out) :: self real, intent(in) :: input(:, :, :) @@ -46,7 +46,7 @@ module subroutine backward(self, input, gradient) self % gradient(2, :, :) = self % key_layer % gradient + self % value_layer % gradient end subroutine backward - module subroutine forward(self, input) + pure module subroutine forward(self, input) !! Cross Attention Forward propagation !! Input Shape (kind, sequence_length, model_dimension) !! where kind is 1 for Query and 2 for Key-Value diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index 769510be..372e195e 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -60,7 +60,7 @@ end function multihead_attention_layer_cons interface - module subroutine common_backward(self, input, gradient) + pure module subroutine common_backward(self, input, gradient) !! General backprop for MultiHead Attention mechanism !! Might be used for both Self and Cross Attention !! Self Attention: sum output gradients @@ -70,7 +70,7 @@ module subroutine common_backward(self, input, gradient) real, intent(in) :: gradient(:, :) end subroutine common_backward - module subroutine common_forward(self, query, key, value) + pure module subroutine common_forward(self, query, key, value) !! General forward propagation for MultiHead Attention Mechanism !! Might be used for both Self and Cross Attention !! Self Attention: pass the same value thrice @@ -79,7 +79,7 @@ module subroutine common_forward(self, query, key, value) real, intent(in) :: query(:, :), key(:, :), value(:, :) end subroutine common_forward - module subroutine init(self, input_shape) + pure module subroutine init(self, input_shape) !! Initialize the layer data structures. !! !! This is a deferred procedure from the `base_layer` abstract type. @@ -87,55 +87,55 @@ module subroutine init(self, input_shape) integer, intent(in) :: input_shape(:) end subroutine init - module function split_heads(self, input) result(output) + pure module function split_heads(self, input) result(output) !! Split inputs into heads !! !! Example with two heads: !! input (3, 4) !! 
output (3, 2, 2) - class(multihead_attention_layer) :: self - real :: input(:, :) + class(multihead_attention_layer), intent(in) :: self + real, intent(in) :: input(:, :) real :: output(self % sequence_length, self % head_size, self % n_heads) end function split_heads - module subroutine create_attention_matrix(self, query, key) + pure module subroutine create_attention_matrix(self, query, key) !! Create attention matrix for query and key !! Output dimensions: sequence_length, sequence_length, n_heads - class(multihead_attention_layer) :: self - real :: query(:, :, :) - real :: key(:, :, :) + class(multihead_attention_layer), intent(in out) :: self + real, intent(in) :: query(:, :, :) + real, intent(in) :: key(:, :, :) integer :: head end subroutine create_attention_matrix - module subroutine normalize_attention_matrix(self, attention_mask) + pure module subroutine normalize_attention_matrix(self, attention_mask) !! Create attention matrix for query and key !! Output dims: sequence_length, sequence_length, n_heads - class(multihead_attention_layer) :: self + class(multihead_attention_layer), intent(in out) :: self !! (sequence_length, sequence_length, n_heads) - real, optional :: attention_mask(:, :, :) + real, optional, intent(in) :: attention_mask(:, :, :) !! (sequence_length, sequence_length, n_heads) real, allocatable :: output(:, :, :) integer :: head, seq end subroutine normalize_attention_matrix - module subroutine scaled_dot_product_attention(self, value) + pure module subroutine scaled_dot_product_attention(self, value) !! Create scaled dot product attention !! Output dims: sequence_length, head_size, n_heads - class(multihead_attention_layer) :: self - real :: value(:, :, :) + class(multihead_attention_layer), intent(in out) :: self + real, intent(in) :: value(:, :, :) integer :: head end subroutine scaled_dot_product_attention - module function combine_heads(self, input) result(output) - class(multihead_attention_layer) :: self - real :: input(:, :, :) + pure module function combine_heads(self, input) result(output) + class(multihead_attention_layer), intent(in) :: self + real, intent(in) :: input(:, :, :) !! 
(sequence_length, head_size, n_heads) real :: output(self % sequence_length, self % model_dimension) integer :: seq end function combine_heads - module function get_num_params(self) result(num_params) - class(multihead_attention_layer) :: self + elemental module function get_num_params(self) result(num_params) + class(multihead_attention_layer), intent(in) :: self integer :: num_params end function get_num_params diff --git a/src/nf/nf_multihead_attention_submodule.f90 b/src/nf/nf_multihead_attention_submodule.f90 index 72ba5691..d0e43a2e 100644 --- a/src/nf/nf_multihead_attention_submodule.f90 +++ b/src/nf/nf_multihead_attention_submodule.f90 @@ -14,7 +14,7 @@ module function multihead_attention_layer_cons(n_heads) result(res) res % n_heads = n_heads end function multihead_attention_layer_cons - module subroutine common_backward(self, input, gradient) + pure module subroutine common_backward(self, input, gradient) class(multihead_attention_layer), intent(in out) :: self real, intent(in) :: input(:, :) real, intent(in) :: gradient(:, :) @@ -112,7 +112,7 @@ module subroutine common_backward(self, input, gradient) deallocate(dk) end subroutine common_backward - module subroutine common_forward(self, query, key, value) + pure module subroutine common_forward(self, query, key, value) class(multihead_attention_layer), intent(in out) :: self real, intent(in) :: query(:, :), key(:, :), value(:, :) @@ -156,17 +156,17 @@ module subroutine common_forward(self, query, key, value) deallocate(v) end subroutine common_forward - module function split_heads(self, input) result(output) - class(multihead_attention_layer) :: self - real :: input(:, :) + pure module function split_heads(self, input) result(output) + class(multihead_attention_layer), intent(in) :: self + real, intent(in) :: input(:, :) real :: output(self % sequence_length, self % head_size, self % n_heads) output = reshape(input, [self % sequence_length, self % head_size, self % n_heads]) end function split_heads - module subroutine create_attention_matrix(self, query, key) - class(multihead_attention_layer) :: self - real :: query(:, :, :) - real :: key(:, :, :) + pure module subroutine create_attention_matrix(self, query, key) + class(multihead_attention_layer), intent(in out) :: self + real, intent(in) :: query(:, :, :) + real, intent(in) :: key(:, :, :) integer :: head ! 
create attention matrix for each sequence in each batch do concurrent(head = 1: self % n_heads) @@ -174,9 +174,9 @@ module subroutine create_attention_matrix(self, query, key) end do end subroutine create_attention_matrix - module subroutine normalize_attention_matrix(self, attention_mask) - class(multihead_attention_layer) :: self - real, optional :: attention_mask(:, :, :) + pure module subroutine normalize_attention_matrix(self, attention_mask) + class(multihead_attention_layer), intent(in out) :: self + real, optional, intent(in) :: attention_mask(:, :, :) real, allocatable :: output(:, :, :) integer :: head, seq @@ -198,9 +198,9 @@ module subroutine normalize_attention_matrix(self, attention_mask) deallocate(output) end subroutine normalize_attention_matrix - module subroutine scaled_dot_product_attention(self, value) - class(multihead_attention_layer) :: self - real :: value(:, :, :) + pure module subroutine scaled_dot_product_attention(self, value) + class(multihead_attention_layer), intent(in out) :: self + real, intent(in) :: value(:, :, :) integer :: head do concurrent(head = 1: self % n_heads) @@ -208,9 +208,9 @@ module subroutine scaled_dot_product_attention(self, value) end do end subroutine scaled_dot_product_attention - module function combine_heads(self, input) result(output) - class(multihead_attention_layer) :: self - real :: input(:, :, :) + pure module function combine_heads(self, input) result(output) + class(multihead_attention_layer), intent(in) :: self + real, intent(in) :: input(:, :, :) real :: output(self % sequence_length, self % model_dimension) integer :: seq @@ -219,8 +219,8 @@ module function combine_heads(self, input) result(output) end do end function combine_heads - module function get_num_params(self) result(num_params) - class(multihead_attention_layer) :: self + elemental module function get_num_params(self) result(num_params) + class(multihead_attention_layer), intent(in) :: self integer :: num_params num_params = & diff --git a/src/nf/nf_self_attention_layer.f90 b/src/nf/nf_self_attention_layer.f90 index 6b17d381..15e8f40c 100644 --- a/src/nf/nf_self_attention_layer.f90 +++ b/src/nf/nf_self_attention_layer.f90 @@ -35,7 +35,7 @@ module function self_attention_layer_cons(n_heads) result(res) res % n_heads = n_heads end function self_attention_layer_cons - module subroutine backward(self, input, gradient) + pure module subroutine backward(self, input, gradient) !! Self Attention back propagation !! Returns sum of Query, Key and Value gradients class(self_attention_layer), intent(in out) :: self @@ -49,7 +49,7 @@ module subroutine backward(self, input, gradient) + self % value_layer % gradient end subroutine backward - module subroutine forward(self, input) + pure module subroutine forward(self, input) !! Cross Attention forward propagation !! Passes input three times into MultiHead Attention !! 
Input Shape: (sequence_length, model_dimension) diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90 index 194d5e1b..7b2d53b6 100644 --- a/test/test_multihead_attention_layer.f90 +++ b/test/test_multihead_attention_layer.f90 @@ -68,7 +68,7 @@ subroutine test_multihead_attention_split_heads(attention, input, ok, output) end subroutine test_multihead_attention_split_heads subroutine test_multihead_attention_create_attention_matrix(attention, input, ok) - type(multihead_attention_layer), intent(in) :: attention + type(multihead_attention_layer), intent(in out) :: attention real, intent(in) :: input(:, :, :) logical, intent(in out) :: ok real :: attention_matrix_shape(3) @@ -95,7 +95,7 @@ subroutine test_multihead_attention_create_attention_matrix(attention, input, ok end subroutine test_multihead_attention_create_attention_matrix subroutine test_multihead_attention_normalization(attention, ok) - type(multihead_attention_layer), intent(in) :: attention + type(multihead_attention_layer), intent(in out) :: attention logical, intent(in out) :: ok real :: output_flat(18) real :: expected_output_flat(18) = [& @@ -114,7 +114,7 @@ subroutine test_multihead_attention_normalization(attention, ok) end subroutine test_multihead_attention_normalization subroutine test_multihead_attention_scaled_dot_product_attention(attention, value, ok) - type(multihead_attention_layer), intent(in) :: attention + type(multihead_attention_layer), intent(in out) :: attention real, intent(in) :: value(:, :, :) logical, intent(in out) :: ok real :: output_flat(12) From 52c94c4a84b32d302b309e5aa6410be8c94a77a0 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Mon, 17 Feb 2025 15:07:23 +0400 Subject: [PATCH 55/71] multihead_attention: plumbing --- src/nf.f90 | 2 +- src/nf/nf_layer_constructors.f90 | 7 +++- src/nf/nf_layer_constructors_submodule.f90 | 10 ++++++ src/nf/nf_layer_submodule.f90 | 41 ++++++++++++++++++++++ src/nf/nf_network_submodule.f90 | 3 ++ 5 files changed, 61 insertions(+), 2 deletions(-) diff --git a/src/nf.f90 b/src/nf.f90 index 7223c1a3..5510255c 100644 --- a/src/nf.f90 +++ b/src/nf.f90 @@ -3,7 +3,7 @@ module nf use nf_datasets_mnist, only: label_digits, load_mnist use nf_layer, only: layer use nf_layer_constructors, only: & - conv2d, dense, flatten, input, maxpool2d, reshape, linear2d + conv2d, dense, flatten, input, maxpool2d, reshape, linear2d, self_attention use nf_loss, only: mse, quadratic use nf_metrics, only: corr, maxabs use nf_network, only: network diff --git a/src/nf/nf_layer_constructors.f90 b/src/nf/nf_layer_constructors.f90 index 2983ddcd..d65e47e5 100644 --- a/src/nf/nf_layer_constructors.f90 +++ b/src/nf/nf_layer_constructors.f90 @@ -8,7 +8,7 @@ module nf_layer_constructors implicit none private - public :: conv2d, dense, flatten, input, maxpool2d, reshape, linear2d + public :: conv2d, dense, flatten, input, maxpool2d, reshape, linear2d, self_attention interface input @@ -195,6 +195,11 @@ module function linear2d(out_features) result(res) !! 
Resulting layer instance end function linear2d + module function self_attention(sequence_length, model_dimension, n_heads) result(res) + integer, intent(in) :: sequence_length, model_dimension, n_heads + type(layer) :: res + end function self_attention + end interface end module nf_layer_constructors diff --git a/src/nf/nf_layer_constructors_submodule.f90 b/src/nf/nf_layer_constructors_submodule.f90 index ae7d05dc..4cb06f74 100644 --- a/src/nf/nf_layer_constructors_submodule.f90 +++ b/src/nf/nf_layer_constructors_submodule.f90 @@ -10,6 +10,7 @@ use nf_maxpool2d_layer, only: maxpool2d_layer use nf_reshape_layer, only: reshape3d_layer use nf_linear2d_layer, only: linear2d_layer + use nf_self_attention_layer, only: self_attention_layer use nf_activation, only: activation_function, relu, sigmoid implicit none @@ -160,4 +161,13 @@ module function linear2d(out_features) result(res) end function linear2d + module function self_attention(sequence_length, model_dimension, n_heads) result(res) + integer, intent(in) :: sequence_length, model_dimension, n_heads + type(layer) :: res + + res % name = 'self_attention' + res % layer_shape = [sequence_length, model_dimension] + allocate(res % p, source=self_attention_layer(n_heads)) + end function self_attention + end submodule nf_layer_constructors_submodule diff --git a/src/nf/nf_layer_submodule.f90 b/src/nf/nf_layer_submodule.f90 index 22eabe9e..46dbb883 100644 --- a/src/nf/nf_layer_submodule.f90 +++ b/src/nf/nf_layer_submodule.f90 @@ -10,6 +10,7 @@ use nf_maxpool2d_layer, only: maxpool2d_layer use nf_reshape_layer, only: reshape3d_layer use nf_linear2d_layer, only: linear2d_layer + use nf_self_attention_layer, only: self_attention_layer use nf_optimizers, only: optimizer_base_type contains @@ -50,6 +51,8 @@ pure module subroutine backward_1d(self, previous, gradient) call this_layer % backward(prev_layer % output, gradient) type is(linear2d_layer) call this_layer % backward(prev_layer % output, gradient) + type is(self_attention_layer) + call this_layer % backward(prev_layer % output, gradient) end select end select @@ -72,6 +75,19 @@ pure module subroutine backward_2d(self, previous, gradient) call this_layer % backward(prev_layer % output, gradient) type is(linear2d_layer) call this_layer % backward(prev_layer % output, gradient) + type is(self_attention_layer) + call this_layer % backward(prev_layer % output, gradient) + end select + + type is(self_attention_layer) + + select type(prev_layer => previous % p) + type is(input2d_layer) + call this_layer % backward(prev_layer % output, gradient) + type is(linear2d_layer) + call this_layer % backward(prev_layer % output, gradient) + type is(self_attention_layer) + call this_layer % backward(prev_layer % output, gradient) end select end select @@ -219,6 +235,20 @@ pure module subroutine forward(self, input) call this_layer % forward(prev_layer % output) type is(linear2d_layer) call this_layer % forward(prev_layer % output) + type is(self_attention_layer) + call this_layer % forward(prev_layer % output) + end select + + type is(self_attention_layer) + + ! 
Upstream layers permitted: input2d, linear2d + select type(prev_layer => input % p) + type is(input2d_layer) + call this_layer % forward(prev_layer % output) + type is(linear2d_layer) + call this_layer % forward(prev_layer % output) + type is(self_attention_layer) + call this_layer % forward(prev_layer % output) end select end select @@ -258,6 +288,8 @@ pure module subroutine get_output_2d(self, output) allocate(output, source=this_layer % output) type is(linear2d_layer) allocate(output, source=this_layer % output) + type is(self_attention_layer) + allocate(output, source=this_layer % output) class default error stop '2-d output can only be read from an input2d or linear2d layer.' @@ -359,6 +391,8 @@ elemental module function get_num_params(self) result(num_params) num_params = 0 type is (linear2d_layer) num_params = this_layer % get_num_params() + type is (self_attention_layer) + num_params = this_layer % get_num_params() class default error stop 'Unknown layer type.' end select @@ -388,6 +422,8 @@ module function get_params(self) result(params) ! No parameters to get. type is (linear2d_layer) params = this_layer % get_params() + type is (self_attention_layer) + params = this_layer % get_params() class default error stop 'Unknown layer type.' end select @@ -417,6 +453,8 @@ module function get_gradients(self) result(gradients) ! No gradients to get. type is (linear2d_layer) gradients = this_layer % get_gradients() + type is (self_attention_layer) + gradients = this_layer % get_gradients() class default error stop 'Unknown layer type.' end select @@ -467,6 +505,9 @@ module subroutine set_params(self, params) type is (linear2d_layer) call this_layer % set_params(params) + type is (self_attention_layer) + call this_layer % set_params(params) + type is (maxpool2d_layer) ! No parameters to set. write(stderr, '(a)') 'Warning: calling set_params() ' & diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90 index c2a9c903..94d85553 100644 --- a/src/nf/nf_network_submodule.f90 +++ b/src/nf/nf_network_submodule.f90 @@ -9,6 +9,7 @@ use nf_maxpool2d_layer, only: maxpool2d_layer use nf_reshape_layer, only: reshape3d_layer use nf_linear2d_layer, only: linear2d_layer + use nf_self_attention_layer, only: self_attention_layer use nf_layer, only: layer use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape use nf_loss, only: quadratic @@ -158,6 +159,8 @@ module subroutine backward(self, output, loss) call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) type is(linear2d_layer) call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(self_attention_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) end select end if From 66b539b24eaed534673fef7279d98441fc8bbe46 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Mon, 17 Feb 2025 15:55:48 +0400 Subject: [PATCH 56/71] multihead_attention: add reference --- src/nf/nf_multihead_attention.f90 | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index 372e195e..f4f2ef77 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -10,8 +10,16 @@ module nf_multihead_attention_layer public :: multihead_attention_layer type, extends(base_layer) :: multihead_attention_layer - - !! Concrete implementation of a multihead attention layer type + !! MultiHead Attention + !! 
The attention mechanism is widely used in Machine Learning, particularly in
+    !! Natural Language Processing, and is the basis of modern Language Models.
+    !! Attention creates a saliency map between tokens that helps the model
+    !! achieve a deeper contextual understanding of the data.
+    !! This implementation is based on the Transformer paper and
+    !! uses attention heads to help parallelize computations.
+    !! Source:
+    !! Vaswani A. et al. Attention is all you need.
+    !! https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf
 
     integer :: sequence_length, model_dimension, n_heads, head_size
     type(linear2d_layer) :: query_layer

From 992da676aeaebb85f56d85d634d5704c061d958d Mon Sep 17 00:00:00 2001
From: Mikhail Voronov
Date: Tue, 18 Feb 2025 01:53:51 +0400
Subject: [PATCH 57/71] multihead_attention: remove rebase artifact

---
 example/linear2d.f90 | 39 ---------------------------------------
 1 file changed, 39 deletions(-)
 delete mode 100644 example/linear2d.f90

diff --git a/example/linear2d.f90 b/example/linear2d.f90
deleted file mode 100644
index 980d45e4..00000000
--- a/example/linear2d.f90
+++ /dev/null
@@ -1,39 +0,0 @@
-program linear2d_example
-
-  use nf, only: input, network, sgd, linear2d, mse, flatten
-  implicit none
-
-  type(network) :: net
-  type(mse) :: loss
-  real :: x(3, 4) = reshape( &
-    [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.11, 0.12, 0.13], &
-    [3, 4])
-  real :: y(3) = [0.12, 0.1, 0.2]
-  real :: preds(3)
-  real :: loss_value
-  integer, parameter :: num_iterations = 500
-  integer :: n
-
-  net = network([ &
-    input(3, 4), &
-    linear2d(3, 1), &
-    flatten() &
-  ])
-
-  call net % print_info()
-  loss = mse()
-
-  do n = 1, num_iterations
-    call net % forward(x)
-    call net % backward(y, loss)
-    call net % update(optimizer=sgd(learning_rate=0.01))
-    preds = net % predict(x)
-    print '(i4,3(3x,f8.6))', n, preds
-    loss_value = loss % eval (y, preds)
-    if (loss_value < 0.01) then
-      print *, 'Loss: ', loss_value
-      return
-    end if
-  end do
-
-end program linear2d_example
\ No newline at end of file

From d93be416fd787352c1e2ca2ba5f8f6fde46b7d6c Mon Sep 17 00:00:00 2001
From: Mikhail Voronov
Date: Wed, 19 Feb 2025 19:32:49 +0400
Subject: [PATCH 58/71] multihead_attention: remove redundant args

---
 src/nf/nf_layer_constructors.f90 | 9 +++++++--
 src/nf/nf_layer_constructors_submodule.f90 | 5 ++---
 src/nf/nf_layer_submodule.f90 | 4 +++-
 3 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/src/nf/nf_layer_constructors.f90 b/src/nf/nf_layer_constructors.f90
index d65e47e5..4b1b6f23 100644
--- a/src/nf/nf_layer_constructors.f90
+++ b/src/nf/nf_layer_constructors.f90
@@ -195,9 +195,14 @@ module function linear2d(out_features) result(res)
       !! Resulting layer instance
   end function linear2d
 
-  module function self_attention(sequence_length, model_dimension, n_heads) result(res)
-    integer, intent(in) :: sequence_length, model_dimension, n_heads
+  module function self_attention(n_heads) result(res)
+    !! Rank-2 (sequence_length, model_dimension) self attention constructor.
+    !! sequence_length and model_dimension are determined at layer initialization, based on the
+    !! output shape of the previous layer.
+    integer, intent(in) :: n_heads
+    !! Number of attention heads
     type(layer) :: res
+    !! 
Resulting layer instance end function self_attention end interface diff --git a/src/nf/nf_layer_constructors_submodule.f90 b/src/nf/nf_layer_constructors_submodule.f90 index 4cb06f74..6616bbeb 100644 --- a/src/nf/nf_layer_constructors_submodule.f90 +++ b/src/nf/nf_layer_constructors_submodule.f90 @@ -161,12 +161,11 @@ module function linear2d(out_features) result(res) end function linear2d - module function self_attention(sequence_length, model_dimension, n_heads) result(res) - integer, intent(in) :: sequence_length, model_dimension, n_heads + module function self_attention(n_heads) result(res) + integer, intent(in) :: n_heads type(layer) :: res res % name = 'self_attention' - res % layer_shape = [sequence_length, model_dimension] allocate(res % p, source=self_attention_layer(n_heads)) end function self_attention diff --git a/src/nf/nf_layer_submodule.f90 b/src/nf/nf_layer_submodule.f90 index 46dbb883..0ada56eb 100644 --- a/src/nf/nf_layer_submodule.f90 +++ b/src/nf/nf_layer_submodule.f90 @@ -333,7 +333,7 @@ impure elemental module subroutine init(self, input) call this_layer % init(input % layer_shape) end select - ! The shape of linear2d, conv2d, maxpool2d, or flatten layers is not known + ! The shape of self_attention, linear2d, conv2d, maxpool2d, or flatten layers is not known ! until we receive an input layer. select type(this_layer => self % p) type is(conv2d_layer) @@ -344,6 +344,8 @@ impure elemental module subroutine init(self, input) self % layer_shape = shape(this_layer % output) type is(linear2d_layer) self % layer_shape = shape(this_layer % output) + type is(self_attention_layer) + self % layer_shape = shape(this_layer % output) end select self % input_layer_shape = input % layer_shape From 70272cb69b314ca44b5bd101a1856b16785e1dd0 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Wed, 19 Feb 2025 21:37:58 +0400 Subject: [PATCH 59/71] multihead_attention: update tests --- test/test_multihead_attention_layer.f90 | 67 ++++++++++++++----------- 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90 index 7b2d53b6..0d394552 100644 --- a/test/test_multihead_attention_layer.f90 +++ b/test/test_multihead_attention_layer.f90 @@ -31,6 +31,14 @@ program test_multihead_attention_layer call test_cross_attention(ok) contains + function allclose(x, y) result(res) + real, intent(in) :: x(:) + real, intent(in) :: y(:) + logical :: res + + res = all(abs(x - y) <= (1e-08 + 1e-05 * abs(y))) + end function allclose + subroutine set_weights(attention) type(multihead_attention_layer), intent(in out) :: attention attention % query_layer % weights = 0.1 @@ -72,15 +80,16 @@ subroutine test_multihead_attention_create_attention_matrix(attention, input, ok real, intent(in) :: input(:, :, :) logical, intent(in out) :: ok real :: attention_matrix_shape(3) - real :: attention_matrix_flat(18) + real, volatile :: attention_matrix_flat(18) real :: expected_shape(3) = [3, 3, 2] real :: expected_attention_matrix_flat(18) = [& - 9.00000036E-02, 0.120000005, 0.150000006, 0.120000005, 0.170000002, 0.219999999,& - 0.150000006, 0.219999999, 0.289999992, 1.16999996, 0.518999994, 0.588000000,& - 0.518999994, 0.502099991, 0.573199987, 0.588000000, 0.573199987, 0.654400051& + 0.09, 0.12, 0.15, 0.12, 0.17, 0.22,& + 0.15, 0.22, 0.29, 1.17, 0.519, 0.588,& + 0.519, 0.5021, 0.5732, 0.588, 0.5732, 0.6544& ] call attention % create_attention_matrix(input, input) + print *, attention % attention_matrix attention_matrix_shape = 
shape(attention % attention_matrix) if (.not. all(attention_matrix_shape.eq.expected_shape)) then @@ -88,7 +97,7 @@ subroutine test_multihead_attention_create_attention_matrix(attention, input, ok write(stderr, '(a)') 'create_attention_matrix returned incorrect shape.. failed' end if attention_matrix_flat = reshape(attention % attention_matrix, shape(expected_attention_matrix_flat)) - if (.not. all(attention_matrix_flat.eq.expected_attention_matrix_flat)) then + if (.not. allclose(attention_matrix_flat, expected_attention_matrix_flat)) then ok = .false. write(stderr, '(a)') 'create_attention_matrix returned incorrect values.. failed' end if @@ -97,7 +106,7 @@ end subroutine test_multihead_attention_create_attention_matrix subroutine test_multihead_attention_normalization(attention, ok) type(multihead_attention_layer), intent(in out) :: attention logical, intent(in out) :: ok - real :: output_flat(18) + real, volatile :: output_flat(18) real :: expected_output_flat(18) = [& 0.326287806, 0.321620107, 0.316976935, 0.333283335, 0.333194494, 0.333061278,& 0.340428889, 0.345185429, 0.349961787, 0.435975075, 0.330339372, 0.329200655,& @@ -107,7 +116,7 @@ subroutine test_multihead_attention_normalization(attention, ok) call attention % normalize_attention_matrix() output_flat = reshape(attention % attention_matrix, shape(output_flat)) - if (.not. all(output_flat.eq.expected_output_flat)) then + if (.not. allclose(output_flat, expected_output_flat)) then ok = .false. write(stderr, '(a)') 'normalize_attention_matrix returned incorrect values.. failed' end if @@ -117,7 +126,7 @@ subroutine test_multihead_attention_scaled_dot_product_attention(attention, valu type(multihead_attention_layer), intent(in out) :: attention real, intent(in) :: value(:, :, :) logical, intent(in out) :: ok - real :: output_flat(12) + real, volatile :: output_flat(12) real :: expected_output_flat(12) = [& 0.101414114, 0.102356538, 0.103298485, 0.401414126, 0.402356565, 0.403298497,& 0.685291648, 0.701290667, 0.701582491, 0.457309216, 0.374400556, 0.373518765& @@ -126,7 +135,7 @@ subroutine test_multihead_attention_scaled_dot_product_attention(attention, valu call attention % scaled_dot_product_attention(value) output_flat = reshape(attention % sdpa, shape(output_flat)) - if (.not. all(output_flat.eq.expected_output_flat)) then + if (.not. allclose(output_flat, expected_output_flat)) then ok = .false. write(stderr, '(a)') 'scaled_dot_product_attention returned incorrect values.. failed' end if @@ -146,7 +155,7 @@ subroutine test_multihead_attention_combine_heads(attention, scaled_dp_att, ok) output = attention % combine_heads(scaled_dp_att) output_flat = reshape(output, shape(output_flat)) - if (.not. all(output_flat.eq.expected_output_flat)) then + if (.not. allclose(output_flat, expected_output_flat)) then ok = .false. write(stderr, '(a)') 'combine_heads returned incorrect values.. 
failed' end if @@ -157,10 +166,10 @@ subroutine test_multihead_attention_forward(attention, ok) logical, intent(in out) :: ok real :: input(3, 4) = reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4]) real :: output(attention % sequence_length, attention % model_dimension) - real :: output_flat(12) + real, volatile :: output_flat(12) integer :: output_shape(2) integer :: attn_weights_shape(3) - real :: attn_weights_flat(18) + real, volatile :: attn_weights_flat(18) integer :: expected_shape(2) = [3, 4] real :: expected_output_flat(12) = [& 0.982241452, 1.00407875, 1.00444126, 0.982241452, 1.00407875, 1.00444126,& @@ -181,7 +190,7 @@ subroutine test_multihead_attention_forward(attention, ok) write(stderr, '(a)') 'forward returned incorrect shape.. failed' end if output_flat = reshape(attention % output, shape(output_flat)) - if (.not. all(output_flat.eq.expected_output_flat)) then + if (.not. allclose(output_flat, expected_output_flat)) then ok = .false. write(stderr, '(a)') 'forward returned incorrect values.. failed' end if @@ -192,7 +201,7 @@ subroutine test_multihead_attention_forward(attention, ok) write(stderr, '(a)') 'forward returned incorrect attention weights shape.. failed' end if attn_weights_flat = reshape(attention % attention_matrix, shape(attn_weights_flat)) - if (.not. all(attn_weights_flat.eq.expected_attn_weights_flat)) then + if (.not. allclose(attn_weights_flat, expected_attn_weights_flat)) then ok = .false. write(stderr, '(a)') 'forward returned incorrect attention weights values.. failed' end if @@ -202,8 +211,6 @@ subroutine test_multihead_attention_forward_reallife_shape(ok) logical, intent(in out) :: ok real :: input(148, 512) real :: output(148, 512) - type(linear2d_layer) :: q - real :: output_flat(12) integer :: output_shape(2) integer :: expected_shape(2) = [148, 512] type(multihead_attention_layer) :: attention @@ -236,7 +243,7 @@ subroutine test_multihead_attention_backward(attention, ok) ] real :: expected_shape(2) = [3, 4] real :: output(3, 4) - real :: output_flat(12) + real, volatile :: output_flat(12) real :: output_shape(2) call attention % common_backward(input, gradient) @@ -254,7 +261,7 @@ subroutine test_multihead_attention_backward(attention, ok) write(stderr, '(a)') 'backward returned incorrect shape.. failed' end if output_flat = reshape(output, shape(output_flat)) - if (.not. all(output_flat.eq.expected_output_flat)) then + if (.not. allclose(output_flat, expected_output_flat)) then ok = .false. write(stderr, '(a)') 'backward returned incorrect values.. failed' end if @@ -265,7 +272,7 @@ subroutine test_multihead_attention_update_gradients(attention, ok) logical, intent(in out) :: ok real :: parameters(80) real :: expected_parameters(80) - real :: updated_output(12) + real, volatile :: updated_output(12) real :: expected_updated_output(12) = [& 0.111365855, 0.115744293, 0.115733206, 0.185253710, 0.196646214, 0.196617395,& -0.102874994, -0.118834510, -0.118794113, 0.179314315, 0.190210193, 0.190182626& @@ -296,7 +303,7 @@ subroutine test_multihead_attention_update_gradients(attention, ok) ) updated_output = reshape(attention % output, [12]) - if (.not. all(updated_output.eq.expected_updated_output)) then + if (.not. allclose(updated_output, expected_updated_output)) then ok = .false. write(stderr, '(a)') 'incorrect output after parameters update.. 
failed' end if @@ -307,12 +314,12 @@ subroutine test_self_attention(ok) type(self_attention_layer) :: attention real :: input(2, 3) = reshape([-1., 0., 17., .4, 5., .6], [2, 3]) real :: output(2, 3) - real :: output_flat(6) + real, volatile :: output_flat(6) real :: expected_output_flat(6) = [& 0.772716165, 0.577548742, 0.772716165, 0.577548742, 0.772716165, 0.577548742& ] real :: gradient(2, 3) = reshape([1., 2., .17, 4., .5, 6.], [2, 3]) - real :: gradient_flat(6) + real, volatile :: gradient_flat(6) real :: expected_gradient_flat(6) = [& 0.350671142, 0.607403040, 0.350671142, 0.607403040, 0.350671142, 0.607403040& ] @@ -330,14 +337,14 @@ subroutine test_self_attention(ok) call attention % forward(input) output_flat = reshape(attention % output, shape(output_flat)) - if (.not. all(output_flat.eq.expected_output_flat)) then + if (.not. allclose(output_flat, expected_output_flat)) then ok = .false. write(stderr, '(a)') 'forward returned incorrect values.. failed' end if call attention % backward(input, gradient) gradient_flat = reshape(attention % gradient, shape(gradient_flat)) - if (.not. all(gradient_flat.eq.expected_gradient_flat)) then + if (.not. allclose(gradient_flat, expected_gradient_flat)) then ok = .false. write(stderr, '(a)') 'backward returned incorrect values.. failed' end if @@ -350,13 +357,13 @@ subroutine test_cross_attention(ok) real :: key_value(2, 3) = reshape([0.1, -.2, 0.3, 4., 15., 0.5], [2, 3]) real :: input(2, 2, 3) real :: output(2, 2, 3) - real :: output_flat(6) + real, volatile :: output_flat(6) real :: expected_output_flat(6) = [& 0.600311756, 0.471662223, 0.600311756, 0.471662223, 0.600311756, 0.471662223& ] real :: gradient(2, 3) = reshape([1., 2., .17, 4., .5, 6.], [2, 3]) - real :: query_gradient_flat(6) - real :: key_value_gradient_flat(6) + real, volatile :: query_gradient_flat(6) + real, volatile :: key_value_gradient_flat(6) real :: expected_query_gradient_flat(6) = [& 1.48406753E-03, 0.184446245, 1.48406753E-03, 0.184446245, 1.48406753E-03, 0.184446245& ] @@ -379,19 +386,19 @@ subroutine test_cross_attention(ok) call attention % forward(input) output_flat = reshape(attention % output, shape(output_flat)) - if (.not. all(output_flat.eq.expected_output_flat)) then + if (.not. allclose(output_flat, expected_output_flat)) then ok = .false. write(stderr, '(a)') 'forward returned incorrect values.. failed' end if call attention % backward(input, gradient) query_gradient_flat = reshape(attention % gradient(1, :, :), shape(query_gradient_flat)) - if (.not. all(query_gradient_flat.eq.expected_query_gradient_flat)) then + if (.not. allclose(query_gradient_flat, expected_query_gradient_flat)) then ok = .false. write(stderr, '(a)') 'backward returned incorrect query values.. failed' end if key_value_gradient_flat = reshape(attention % gradient(2, :, :), shape(key_value_gradient_flat)) - if (.not. all(key_value_gradient_flat.eq.expected_key_value_gradient_flat)) then + if (.not. allclose(key_value_gradient_flat, expected_key_value_gradient_flat)) then ok = .false. write(stderr, '(a)') 'backward returned incorrect key-value values.. 
failed' end if From cb717f56c053b631d99a1ac5f1e1d9b88a26ac82 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Wed, 19 Feb 2025 21:43:17 +0400 Subject: [PATCH 60/71] multihead_attention: add the most important lines to tests --- test/test_multihead_attention_layer.f90 | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90 index 0d394552..90954ce4 100644 --- a/test/test_multihead_attention_layer.f90 +++ b/test/test_multihead_attention_layer.f90 @@ -30,6 +30,13 @@ program test_multihead_attention_layer call test_self_attention(ok) call test_cross_attention(ok) + if (ok) then + print '(a)', 'test_multihead_attention_layer: All tests passed.' + else + write(stderr, '(a)') 'test_multihead_attention_layer: One or more tests failed.' + stop 1 + end if + contains function allclose(x, y) result(res) real, intent(in) :: x(:) From b7a6d0610308bcbf390394c5725d80db12be7fa5 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Wed, 19 Feb 2025 21:43:36 +0400 Subject: [PATCH 61/71] multihead_attention: simple MHA example --- example/mha_simple.f90 | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 example/mha_simple.f90 diff --git a/example/mha_simple.f90 b/example/mha_simple.f90 new file mode 100644 index 00000000..c2387d8c --- /dev/null +++ b/example/mha_simple.f90 @@ -0,0 +1,37 @@ +program simple + use nf, only: dense, input, network, sgd, self_attention, flatten + implicit none + type(network) :: net + real, allocatable :: x(:, :), y(:) + integer, parameter :: num_iterations = 500 + integer :: n + + print '("Simple")' + print '(60("="))' + + net = network([ & + input(3, 8), & + self_attention(4), & + flatten(), & + dense(2) & + ]) + + call net % print_info() + + allocate(x(3, 8)) + call random_number(x) + + y = [0.123456, 0.246802] + + do n = 0, num_iterations + + call net % forward(x) + call net % backward(y) + call net % update(optimizer=sgd(learning_rate=1.)) + + if (mod(n, 50) == 0) & + print '(i4,2(3x,f8.6))', n, net % predict(x) + + end do + +end program simple From cb26afb23159635d47160e685c2fa7d7ed33e24a Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Wed, 19 Feb 2025 21:43:55 +0400 Subject: [PATCH 62/71] multihead_attention: update cmake --- example/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 0257dd7d..f4b706b8 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -2,11 +2,11 @@ foreach(execid cnn_mnist dense_mnist get_set_network_params - linear2d network_parameters simple sine quadratic + mha_simple ) add_executable(${execid} ${execid}.f90) target_link_libraries(${execid} PRIVATE From 4c92e9c4554788f6781487af76066e641145eff8 Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Wed, 19 Feb 2025 22:24:24 +0400 Subject: [PATCH 63/71] multihead_attention: remove debug line from tests --- test/test_multihead_attention_layer.f90 | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90 index 90954ce4..48fc0f69 100644 --- a/test/test_multihead_attention_layer.f90 +++ b/test/test_multihead_attention_layer.f90 @@ -96,7 +96,6 @@ subroutine test_multihead_attention_create_attention_matrix(attention, input, ok ] call attention % create_attention_matrix(input, input) - print *, attention % attention_matrix attention_matrix_shape = shape(attention % attention_matrix) if (.not. 
all(attention_matrix_shape.eq.expected_shape)) then From df5f4cfed99d547a094d3bc1a66f50c69bc28c6e Mon Sep 17 00:00:00 2001 From: Mikhail Voronov Date: Wed, 19 Feb 2025 22:30:57 +0400 Subject: [PATCH 64/71] multihead_attention: set slightly higher margin for fp imprecision (due to IEEE_DENORMAL) --- test/test_multihead_attention_layer.f90 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90 index 48fc0f69..fdc6862d 100644 --- a/test/test_multihead_attention_layer.f90 +++ b/test/test_multihead_attention_layer.f90 @@ -43,7 +43,7 @@ function allclose(x, y) result(res) real, intent(in) :: y(:) logical :: res - res = all(abs(x - y) <= (1e-08 + 1e-05 * abs(y))) + res = all(abs(x - y) <= (1e-06 + 1e-05 * abs(y))) end function allclose subroutine set_weights(attention) From 6162783ffdb5380cd94633b46a3c97acbd758caa Mon Sep 17 00:00:00 2001 From: milancurcic Date: Fri, 21 Feb 2025 14:38:52 -0500 Subject: [PATCH 65/71] Rename mha_simple example --- example/mha_simple.f90 | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example/mha_simple.f90 b/example/mha_simple.f90 index c2387d8c..2daa5ac2 100644 --- a/example/mha_simple.f90 +++ b/example/mha_simple.f90 @@ -1,4 +1,4 @@ -program simple +program mha_simple use nf, only: dense, input, network, sgd, self_attention, flatten implicit none type(network) :: net @@ -34,4 +34,4 @@ program simple end do -end program simple +end program mha_simple From 89abf221cc60973433302ca1071f93425d1f7aae Mon Sep 17 00:00:00 2001 From: Milan Curcic Date: Fri, 21 Feb 2025 14:55:28 -0500 Subject: [PATCH 66/71] Update src/nf/nf_multihead_attention.f90 Co-authored-by: Jeremie Vandenplas --- src/nf/nf_multihead_attention.f90 | 1 - 1 file changed, 1 deletion(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index f4f2ef77..a629c39f 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -112,7 +112,6 @@ pure module subroutine create_attention_matrix(self, query, key) class(multihead_attention_layer), intent(in out) :: self real, intent(in) :: query(:, :, :) real, intent(in) :: key(:, :, :) - integer :: head end subroutine create_attention_matrix pure module subroutine normalize_attention_matrix(self, attention_mask) From 29b7d2e143c201843e7a25ad2c3d0fa7d7f9531b Mon Sep 17 00:00:00 2001 From: Milan Curcic Date: Fri, 21 Feb 2025 14:56:25 -0500 Subject: [PATCH 67/71] Update src/nf/nf_multihead_attention.f90 Co-authored-by: Jeremie Vandenplas --- src/nf/nf_multihead_attention.f90 | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index a629c39f..d383110e 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -121,8 +121,6 @@ pure module subroutine normalize_attention_matrix(self, attention_mask) !! (sequence_length, sequence_length, n_heads) real, optional, intent(in) :: attention_mask(:, :, :) !! 
(sequence_length, sequence_length, n_heads) - real, allocatable :: output(:, :, :) - integer :: head, seq end subroutine normalize_attention_matrix pure module subroutine scaled_dot_product_attention(self, value) From e9014792dae56e08829ccb492470897c86504016 Mon Sep 17 00:00:00 2001 From: Milan Curcic Date: Fri, 21 Feb 2025 14:57:05 -0500 Subject: [PATCH 68/71] Update src/nf/nf_multihead_attention.f90 Co-authored-by: Jeremie Vandenplas --- src/nf/nf_multihead_attention.f90 | 1 - 1 file changed, 1 deletion(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index d383110e..51ee1f72 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -128,7 +128,6 @@ pure module subroutine scaled_dot_product_attention(self, value) !! Output dims: sequence_length, head_size, n_heads class(multihead_attention_layer), intent(in out) :: self real, intent(in) :: value(:, :, :) - integer :: head end subroutine scaled_dot_product_attention pure module function combine_heads(self, input) result(output) From 1eaee959861411fd3f1a4e88b6e1c7a677c7f15a Mon Sep 17 00:00:00 2001 From: Milan Curcic Date: Fri, 21 Feb 2025 14:57:17 -0500 Subject: [PATCH 69/71] Update src/nf/nf_multihead_attention.f90 Co-authored-by: Jeremie Vandenplas --- src/nf/nf_multihead_attention.f90 | 1 - 1 file changed, 1 deletion(-) diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90 index 51ee1f72..80a59dfb 100644 --- a/src/nf/nf_multihead_attention.f90 +++ b/src/nf/nf_multihead_attention.f90 @@ -135,7 +135,6 @@ pure module function combine_heads(self, input) result(output) real, intent(in) :: input(:, :, :) !! (sequence_length, head_size, n_heads) real :: output(self % sequence_length, self % model_dimension) - integer :: seq end function combine_heads elemental module function get_num_params(self) result(num_params) From 588ecb1d22460f11de639d62019a76adaa6adf8c Mon Sep 17 00:00:00 2001 From: milancurcic Date: Fri, 21 Feb 2025 15:02:44 -0500 Subject: [PATCH 70/71] Tidy up --- src/nf/nf_layer_constructors.f90 | 4 ++-- src/nf/nf_layer_constructors_submodule.f90 | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/nf/nf_layer_constructors.f90 b/src/nf/nf_layer_constructors.f90 index 5f7b5af4..db60cf0f 100644 --- a/src/nf/nf_layer_constructors.f90 +++ b/src/nf/nf_layer_constructors.f90 @@ -222,11 +222,11 @@ module function linear2d(out_features) result(res) !! Resulting layer instance end function linear2d - module function self_attention(n_heads) result(res) + module function self_attention(num_heads) result(res) !! Rank-2 (sequence_length, out_features) self attention constructor. !! sequence_length and model_dimension are determined at layer initialization, based on the !! output shape of the previous layer. - integer, intent(in) :: n_heads + integer, intent(in) :: num_heads !! Number of attention heads type(layer) :: res !! 
Resulting layer instance diff --git a/src/nf/nf_layer_constructors_submodule.f90 b/src/nf/nf_layer_constructors_submodule.f90 index be85c9c9..9e5322c1 100644 --- a/src/nf/nf_layer_constructors_submodule.f90 +++ b/src/nf/nf_layer_constructors_submodule.f90 @@ -171,12 +171,12 @@ module function linear2d(out_features) result(res) end function linear2d - module function self_attention(n_heads) result(res) - integer, intent(in) :: n_heads + module function self_attention(num_heads) result(res) + integer, intent(in) :: num_heads type(layer) :: res res % name = 'self_attention' - allocate(res % p, source=self_attention_layer(n_heads)) + allocate(res % p, source=self_attention_layer(num_heads)) end function self_attention end submodule nf_layer_constructors_submodule From 20ffe05cb2d83a26012d92756f68a904d6558e1e Mon Sep 17 00:00:00 2001 From: milancurcic Date: Fri, 21 Feb 2025 15:02:56 -0500 Subject: [PATCH 71/71] Add self_attention to the layers table --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index a0eee745..a04ac32a 100644 --- a/README.md +++ b/README.md @@ -34,8 +34,9 @@ Read the paper [here](https://arxiv.org/abs/1902.06714). | Dropout | `dropout` | `dense`, `flatten`, `input1d` | 1 | ✅ | ✅ | | Convolutional (2-d) | `conv2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✅ | ✅(*) | | Max-pooling (2-d) | `maxpool2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✅ | ✅ | +| Linear (2-d) | `linear2d` | `input2d`, `linear2d`, `self_attention` | 2 | ✅ | ✅ | +| Self-attention | `self_attention` | `input2d`, `linear2d`, `self_attention` | 2 | ✅ | ✅ | | Flatten | `flatten` | `input2d`, `input3d`, `conv2d`, `maxpool2d`, `reshape` | 1 | ✅ | ✅ | -| Linear (2-d) | `linear2d` | `input2d`, `linear2d` | 2 | ✅ | ✅ | | Reshape (1-d to 3-d) | `reshape` | `input1d`, `dense`, `flatten` | 3 | ✅ | ✅ | (*) See Issue [#145](https://github.com/modern-fortran/neural-fortran/issues/145) regarding non-converging CNN training on the MNIST dataset.
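For quick reference, below is a condensed usage sketch of the `self_attention` layer documented in the README table above, adapted from the `example/mha_simple.f90` program added earlier in this series. The constructors and network methods it uses (`input`, `self_attention`, `flatten`, `dense`, `sgd`, `forward`, `backward`, `update`, `predict`) are the ones exercised by these patches; the input shape, number of heads, and target values are illustrative only.

program self_attention_usage
  ! Minimal sketch: a rank-2 input feeds self_attention, which is flattened
  ! before a dense output layer, matching the layer table above.
  use nf, only: dense, flatten, input, network, sgd, self_attention
  implicit none
  type(network) :: net
  real :: x(3, 8)   ! (sequence_length, model_dimension); illustrative shape
  real :: y(2)      ! illustrative targets for the dense(2) output

  net = network([ &
    input(3, 8), &        ! rank-2 input: sequence length 3, 8 features
    self_attention(4), &  ! 4 attention heads; dimensions taken from the input
    flatten(), &          ! rank-2 -> rank-1 so dense can follow
    dense(2) &
  ])

  call random_number(x)
  y = [0.1, 0.9]

  ! One training step: forward pass, backpropagation, SGD update.
  call net % forward(x)
  call net % backward(y)
  call net % update(optimizer=sgd(learning_rate=1.))

  print *, net % predict(x)

end program self_attention_usage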