
Commit 2731d63

multihead_attention: remove redundant constructor args for attention layers
1 parent 475cd06 commit 2731d63

File tree

5 files changed: +43 -77 lines changed
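
In short: the attention constructors now take only n_heads, and sequence_length / model_dimension are derived from the 2D input shape passed to init_base (or init). A minimal before/after sketch based on the updated test program in this commit; it assumes the nf_multihead_attention_layer module and its type are accessible via use association:

program mha_constructor_demo
  ! Sketch only: assumes the neural-fortran module touched by this commit
  ! is available at compile time and exports the type/constructor.
  use nf_multihead_attention_layer, only: multihead_attention_layer
  implicit none
  type(multihead_attention_layer) :: attention

  ! Before: attention = multihead_attention_layer(sequence_length=3, model_dimension=4, n_heads=2)
  !         call attention % init_base([0])
  ! After: only the head count is fixed at construction; the geometry comes from init_base.
  attention = multihead_attention_layer(n_heads=2)
  call attention % init_base([3, 4])   ! sequence_length=3, model_dimension=4
end program mha_constructor_demo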

src/nf/nf_cross_attention_layer.f90

Lines changed: 3 additions & 22 deletions
@@ -20,38 +20,19 @@ module nf_cross_attention_layer
   end type cross_attention_layer
 
   interface cross_attention_layer
-    module function cross_attention_layer_cons(sequence_length, model_dimension, n_heads) result(res)
+    module function cross_attention_layer_cons(n_heads) result(res)
       !! This function returns the `cross_attention_layer` instance.
       integer, intent(in) :: sequence_length, model_dimension, n_heads
       type(cross_attention_layer) :: res
     end function cross_attention_layer_cons
   end interface cross_attention_layer
 
 contains
-  module function cross_attention_layer_cons(sequence_length, model_dimension, n_heads) result(res)
+  module function cross_attention_layer_cons(n_heads) result(res)
     !! This function returns the `cross_attention_layer` instance.
-    integer, intent(in) :: sequence_length, model_dimension, n_heads
+    integer, intent(in) :: n_heads
     type(cross_attention_layer) :: res
-    res % sequence_length = sequence_length
-    res % model_dimension = model_dimension
     res % n_heads = n_heads
-
-    if (mod(model_dimension, n_heads) /= 0) then
-      write(stderr, '(a)'), 'Number of heads must be divisible by model dimension'
-      error stop
-    end if
-    res % head_size = model_dimension / n_heads
-
-    res % query_layer = linear2d_layer(model_dimension)
-    res % key_layer = linear2d_layer(model_dimension)
-    res % value_layer = linear2d_layer(model_dimension)
-    res % output_layer = linear2d_layer(model_dimension)
-    call res % query_layer % init([sequence_length, model_dimension])
-    call res % key_layer % init([sequence_length, model_dimension])
-    call res % value_layer % init([sequence_length, model_dimension])
-    call res % output_layer % init([sequence_length, model_dimension])
-
-    res % softmax_func = softmax()
   end function cross_attention_layer_cons
 
   module subroutine backward(self, input, gradient)
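
A short sketch of the new two-step setup for the cross-attention layer, mirroring the updated test_cross_attention subroutine later in this commit (use association of cross_attention_layer from nf_cross_attention_layer is assumed):

subroutine build_cross_attention(attention)
  ! Sketch only: mirrors the updated test_cross_attention setup.
  use nf_cross_attention_layer, only: cross_attention_layer   ! assumed public
  type(cross_attention_layer), intent(out) :: attention

  attention = cross_attention_layer(n_heads=1)   ! head count only; no sequence_length/model_dimension
  call attention % init([2, 3])                  ! [sequence_length, model_dimension] given at init time
end subroutine build_cross_attention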

src/nf/nf_multihead_attention.f90

Lines changed: 2 additions & 2 deletions
@@ -51,9 +51,9 @@ module nf_multihead_attention_layer
   end type multihead_attention_layer
 
   interface multihead_attention_layer
-    module function multihead_attention_layer_cons(sequence_length, model_dimension, n_heads) result(res)
+    module function multihead_attention_layer_cons(n_heads) result(res)
       !! This function returns the `multihead_attention_layer` instance.
-      integer, intent(in) :: sequence_length, model_dimension, n_heads
+      integer, intent(in) :: n_heads
       type(multihead_attention_layer) :: res
     end function multihead_attention_layer_cons
   end interface multihead_attention_layer

src/nf/nf_multihead_attention_submodule.f90

Lines changed: 25 additions & 21 deletions
@@ -7,29 +7,11 @@
   implicit none
 
 contains
-  module function multihead_attention_layer_cons(sequence_length, model_dimension, n_heads) result(res)
-    integer, intent(in) :: sequence_length, model_dimension, n_heads
+  module function multihead_attention_layer_cons(n_heads) result(res)
+    integer, intent(in) :: n_heads
     type(multihead_attention_layer) :: res
-    res % sequence_length = sequence_length
-    res % model_dimension = model_dimension
-    res % n_heads = n_heads
 
-    if (mod(model_dimension, n_heads) /= 0) then
-      write(stderr, '(a)'), 'Number of heads must be divisible by model dimension'
-      error stop
-    end if
-    res % head_size = model_dimension / n_heads
-
-    res % query_layer = linear2d_layer(model_dimension)
-    res % key_layer = linear2d_layer(model_dimension)
-    res % value_layer = linear2d_layer(model_dimension)
-    res % output_layer = linear2d_layer(model_dimension)
-    call res % query_layer % init([sequence_length, model_dimension])
-    call res % key_layer % init([sequence_length, model_dimension])
-    call res % value_layer % init([sequence_length, model_dimension])
-    call res % output_layer % init([sequence_length, model_dimension])
-
-    res % softmax_func = softmax()
+    res % n_heads = n_heads
   end function multihead_attention_layer_cons
 
   module subroutine common_backward(self, input, gradient)
@@ -325,6 +307,28 @@ module subroutine init_base(self, input_shape)
     class(multihead_attention_layer), intent(in out) :: self
     integer, intent(in) :: input_shape(:)
 
+    if (size(input_shape) /= 2) then
+      error stop "MultiHead Attention accepts 2D input"
+    end if
+    self % sequence_length = input_shape(1)
+    self % model_dimension = input_shape(2)
+
+    if (mod(self % model_dimension, self % n_heads) /= 0) then
+      write(stderr, '(a)'), 'Number of heads must be divisible by model dimension'
+      error stop
+    end if
+    self % head_size = self % model_dimension / self % n_heads
+    self % softmax_func = softmax()
+
+    self % query_layer = linear2d_layer(self % model_dimension)
+    self % key_layer = linear2d_layer(self % model_dimension)
+    self % value_layer = linear2d_layer(self % model_dimension)
+    self % output_layer = linear2d_layer(self % model_dimension)
+    call self % query_layer % init([self % sequence_length, self % model_dimension])
+    call self % key_layer % init([self % sequence_length, self % model_dimension])
+    call self % value_layer % init([self % sequence_length, self % model_dimension])
+    call self % output_layer % init([self % sequence_length, self % model_dimension])
+
     allocate(self % attention_matrix(self % sequence_length, self % sequence_length, self % n_heads))
     allocate(self % sdpa(self % sequence_length, self % head_size, self % n_heads))
     allocate(self % output(self % sequence_length, self % model_dimension))
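
The shape handling that previously lived in the constructors now happens once in init_base. A standalone sketch of the same derivation in plain Fortran, independent of the library, to make the arithmetic concrete; the [148, 512] shape and 8 heads match the "reallife shape" test below:

program head_size_demo
  ! Illustration of the checks init_base now performs; names mirror the diff
  ! but nothing here depends on neural-fortran itself.
  use iso_fortran_env, only: stderr => error_unit
  implicit none
  integer :: input_shape(2), sequence_length, model_dimension, n_heads, head_size

  n_heads = 8
  input_shape = [148, 512]               ! as passed to init_base in the test

  sequence_length = input_shape(1)       ! 148
  model_dimension = input_shape(2)       ! 512
  if (mod(model_dimension, n_heads) /= 0) then
    write(stderr, '(a)') 'Number of heads must be divisible by model dimension'
    error stop
  end if
  head_size = model_dimension / n_heads  ! 512 / 8 = 64
  print '(a, i0)', 'head_size = ', head_size
end program head_size_demo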

src/nf/nf_self_attention_layer.f90

Lines changed: 4 additions & 23 deletions
@@ -20,38 +20,19 @@ module nf_self_attention_layer
   end type self_attention_layer
 
   interface self_attention_layer
-    module function self_attention_layer_cons(sequence_length, model_dimension, n_heads) result(res)
+    module function self_attention_layer_cons(n_heads) result(res)
       !! This function returns the `self_attention_layer` instance.
-      integer, intent(in) :: sequence_length, model_dimension, n_heads
+      integer, intent(in) :: n_heads
       type(self_attention_layer) :: res
     end function self_attention_layer_cons
   end interface self_attention_layer
 
 contains
-  module function self_attention_layer_cons(sequence_length, model_dimension, n_heads) result(res)
+  module function self_attention_layer_cons(n_heads) result(res)
     !! This function returns the `self_attention_layer` instance.
-    integer, intent(in) :: sequence_length, model_dimension, n_heads
+    integer, intent(in) :: n_heads
    type(self_attention_layer) :: res
-    res % sequence_length = sequence_length
-    res % model_dimension = model_dimension
     res % n_heads = n_heads
-
-    if (mod(model_dimension, n_heads) /= 0) then
-      write(stderr, '(a)'), 'Number of heads must be divisible by model dimension'
-      error stop
-    end if
-    res % head_size = model_dimension / n_heads
-
-    res % query_layer = linear2d_layer(model_dimension)
-    res % key_layer = linear2d_layer(model_dimension)
-    res % value_layer = linear2d_layer(model_dimension)
-    res % output_layer = linear2d_layer(model_dimension)
-    call res % query_layer % init([sequence_length, model_dimension])
-    call res % key_layer % init([sequence_length, model_dimension])
-    call res % value_layer % init([sequence_length, model_dimension])
-    call res % output_layer % init([sequence_length, model_dimension])
-
-    res % softmax_func = softmax()
   end function self_attention_layer_cons
 
   module subroutine backward(self, input, gradient)
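
The self-attention layer follows the same two-step pattern; a hedged sketch mirroring the updated test_self_attention (use association of self_attention_layer from nf_self_attention_layer is assumed, and the 0.1 weights are just the test's deterministic setup):

subroutine build_self_attention(attention)
  ! Sketch only: mirrors the updated test_self_attention setup.
  use nf_self_attention_layer, only: self_attention_layer   ! assumed public
  type(self_attention_layer), intent(out) :: attention

  attention = self_attention_layer(n_heads=1)   ! head count only
  call attention % init([2, 3])                 ! sequence_length=2, model_dimension=3
  attention % query_layer % weights = 0.1       ! deterministic weights, as in the test
  attention % key_layer % weights = 0.1
  attention % value_layer % weights = 0.1
end subroutine build_self_attention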

test/test_multihead_attention_layer.f90

Lines changed: 9 additions & 9 deletions
@@ -14,8 +14,8 @@ program test_multihead_attention_layer
   real :: minput(3, 4) = reshape([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.11, 0.12], [3, 4])
   real :: output(3, 2, 2)
 
-  attention = multihead_attention_layer(sequence_length=3, model_dimension=4, n_heads=2)
-  call attention % init_base([0])
+  attention = multihead_attention_layer(n_heads=2)
+  call attention % init_base([3, 4])
   call set_weights(attention)
 
   call test_multihead_attention_split_heads(attention, sample_input, ok, split_heads_output)
@@ -210,8 +210,8 @@ subroutine test_multihead_attention_forward_reallife_shape(ok)
 
     call random_number(input)
 
-    attention = multihead_attention_layer(sequence_length=148, model_dimension=512, n_heads=8)
-    call attention % init_base([0])
+    attention = multihead_attention_layer(n_heads=8)
+    call attention % init_base([148, 512])
     call set_weights(attention)
 
     call attention % common_forward(input, input, input)
@@ -317,8 +317,8 @@ subroutine test_self_attention(ok)
        0.350671142, 0.607403040, 0.350671142, 0.607403040, 0.350671142, 0.607403040&
     ]
 
-    attention = self_attention_layer(sequence_length=2, model_dimension=3, n_heads=1)
-    call attention % init([0])
+    attention = self_attention_layer(n_heads=1)
+    call attention % init([2, 3])
     attention % query_layer % weights = 0.1
     attention % key_layer % weights = 0.1
     attention % value_layer % weights = 0.1
@@ -366,8 +366,8 @@ subroutine test_cross_attention(ok)
     input(1, :, :) = query
     input(2, :, :) = key_value
 
-    attention = cross_attention_layer(sequence_length=2, model_dimension=3, n_heads=1)
-    call attention % init([0])
+    attention = cross_attention_layer(n_heads=1)
+    call attention % init([2, 3])
     attention % query_layer % weights = 0.1
     attention % key_layer % weights = 0.1
     attention % value_layer % weights = 0.1
@@ -396,4 +396,4 @@ subroutine test_cross_attention(ok)
       write(stderr, '(a)') 'backward returned incorrect key-value values.. failed'
     end if
   end subroutine test_cross_attention
-end program test_multihead_attention_layer
+end program test_multihead_attention_layer
