@@ -6,7 +6,8 @@ program test_embedding_layer
66 logical :: ok = .true.
77
88 call test_simple(ok)
9- call test_positional(ok)
9+ call test_positional_trigonometric(ok)
10+ call test_positional_absolute(ok)
1011
1112 if (ok) then
1213 print ' (a)' , ' test_embedding_layer: All tests passed.'
@@ -47,7 +48,7 @@ subroutine test_simple(ok)
4748 end if
4849 end subroutine test_simple
4950
50- subroutine test_positional (ok )
51+ subroutine test_positional_trigonometric (ok )
5152 logical , intent (in out ) :: ok
5253
5354 integer :: sample_input(3 ) = [2 , 1 , 3 ]
@@ -63,7 +64,7 @@ subroutine test_positional(ok)
6364 real :: theta
6465 integer :: i, pos
6566
66- embedding = embedding_layer(vocab_size= 5 , model_dimension= 4 , positional= .true. )
67+ embedding = embedding_layer(vocab_size= 5 , model_dimension= 4 , positional= 1 )
6768 call embedding % init([3 ])
6869 embedding % weights = reshape ([&
6970 0.1 , 0.3 , 0.5 , 0.7 , 0.2 ,&
@@ -77,7 +78,41 @@ subroutine test_positional(ok)
7778 output_flat = reshape (embedding % output, [12 ])
7879 if (.not. all (abs (output_flat - expected_output_flat) <= (1e-06 + 1e-05 * abs (expected_output_flat)))) then
7980 ok = .false.
80- write (stderr, ' (a)' ) ' positional encoding returned incorrect values.. failed'
81+ write (stderr, ' (a)' ) ' trigonometric positional encoding returned incorrect values.. failed'
8182 end if
82- end subroutine test_positional
83+ end subroutine test_positional_trigonometric
84+
85+ subroutine test_positional_absolute (ok )
86+ logical , intent (in out ) :: ok
87+
88+ integer :: sample_input(3 ) = [2 , 1 , 3 ]
89+ real :: output_flat(12 )
90+ real :: expected_output_flat(12 ) = reshape ([&
91+ 0.3 , 1.1 , 2.5 ,&
92+ 0.3 , 1.1 , 2.5 ,&
93+ 0.3 , 1.1 , 2.5 ,&
94+ 0.3 , 1.1 , 2.5 &
95+ ], [12 ])
96+ type (embedding_layer) :: embedding
97+
98+ real :: theta
99+ integer :: i, pos
100+
101+ embedding = embedding_layer(vocab_size= 5 , model_dimension= 4 , positional= 2 )
102+ call embedding % init([3 ])
103+ embedding % weights = reshape ([&
104+ 0.1 , 0.3 , 0.5 , 0.7 , 0.2 ,&
105+ 0.1 , 0.3 , 0.5 , 0.7 , 0.2 ,&
106+ 0.1 , 0.3 , 0.5 , 0.7 , 0.2 ,&
107+ 0.1 , 0.3 , 0.5 , 0.7 , 0.2 &
108+ ], [5 , 4 ])
109+
110+ call embedding % forward(sample_input)
111+
112+ output_flat = reshape (embedding % output, [12 ])
113+ if (.not. all (abs (output_flat - expected_output_flat) <= (1e-06 + 1e-05 * abs (expected_output_flat)))) then
114+ ok = .false.
115+ write (stderr, ' (a)' ) ' absolute positional encoding returned incorrect values.. failed'
116+ end if
117+ end subroutine test_positional_absolute
83118end program test_embedding_layer
0 commit comments