@@ -7,29 +7,11 @@
   implicit none
 
 contains
-  module function multihead_attention_layer_cons(sequence_length, model_dimension, n_heads) result(res)
-    integer, intent(in) :: sequence_length, model_dimension, n_heads
+  module function multihead_attention_layer_cons(n_heads) result(res)
+    integer, intent(in) :: n_heads
     type(multihead_attention_layer) :: res
-    res % sequence_length = sequence_length
-    res % model_dimension = model_dimension
-    res % n_heads = n_heads
 
-    if (mod(model_dimension, n_heads) /= 0) then
-      write(stderr, '(a)'), 'Number of heads must be divisible by model dimension'
-      error stop
-    end if
-    res % head_size = model_dimension / n_heads
-
-    res % query_layer = linear2d_layer(model_dimension)
-    res % key_layer = linear2d_layer(model_dimension)
-    res % value_layer = linear2d_layer(model_dimension)
-    res % output_layer = linear2d_layer(model_dimension)
-    call res % query_layer % init([sequence_length, model_dimension])
-    call res % key_layer % init([sequence_length, model_dimension])
-    call res % value_layer % init([sequence_length, model_dimension])
-    call res % output_layer % init([sequence_length, model_dimension])
-
-    res % softmax_func = softmax()
+    res % n_heads = n_heads
   end function multihead_attention_layer_cons
 
   module subroutine common_backward(self, input, gradient)
@@ -325,6 +307,28 @@ module subroutine init_base(self, input_shape)
     class(multihead_attention_layer), intent(in out) :: self
     integer, intent(in) :: input_shape(:)
 
+    if (size(input_shape) /= 2) then
+      error stop "MultiHead Attention accepts 2D input"
+    end if
+    self % sequence_length = input_shape(1)
+    self % model_dimension = input_shape(2)
+
+    if (mod(self % model_dimension, self % n_heads) /= 0) then
+      write(stderr, '(a)') 'Model dimension must be divisible by number of heads'
+      error stop
+    end if
+    self % head_size = self % model_dimension / self % n_heads
+    self % softmax_func = softmax()
+
+    self % query_layer = linear2d_layer(self % model_dimension)
+    self % key_layer = linear2d_layer(self % model_dimension)
+    self % value_layer = linear2d_layer(self % model_dimension)
+    self % output_layer = linear2d_layer(self % model_dimension)
+    call self % query_layer % init([self % sequence_length, self % model_dimension])
+    call self % key_layer % init([self % sequence_length, self % model_dimension])
+    call self % value_layer % init([self % sequence_length, self % model_dimension])
+    call self % output_layer % init([self % sequence_length, self % model_dimension])
+
     allocate(self % attention_matrix(self % sequence_length, self % sequence_length, self % n_heads))
     allocate(self % sdpa(self % sequence_length, self % head_size, self % n_heads))
     allocate(self % output(self % sequence_length, self % model_dimension))
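
Net effect of the diff: the constructor now stores only n_heads, while everything that depends on the input geometry (the divisibility check, head size, softmax, and the four linear2d sub-layers) is deferred to init_base, which receives the 2D input_shape as [sequence_length, model_dimension]. A minimal usage sketch of the new two-step flow follows; the module name, the multihead_attention_layer constructor interface, and the init binding that dispatches to init_base are assumptions for illustration, not confirmed by the diff.

program mha_usage_sketch
  ! Hypothetical module and interface names; only the argument shapes
  ! follow from the diff above.
  use nf_multihead_attention_layer, only: multihead_attention_layer
  implicit none

  type(multihead_attention_layer) :: mha

  ! The constructor now takes only the number of heads.
  mha = multihead_attention_layer(n_heads=8)

  ! Geometry is supplied at init time as [sequence_length, model_dimension];
  ! model_dimension (512) must be divisible by n_heads (8).
  call mha % init([128, 512])
end program mha_usage_sketch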