11program test_embedding_layer
22 use iso_fortran_env, only: stderr = > error_unit
33 use nf_embedding_layer, only: embedding_layer
4+ use nf_layer, only: layer
5+ use nf_layer_constructors, only: embedding_constructor = > embedding
46 implicit none
57
68 logical :: ok = .true.
9+ integer :: sample_input(3 ) = [2 , 1 , 3 ]
710
8- call test_simple(ok)
9- call test_positional_trigonometric(ok)
10- call test_positional_absolute(ok)
11+ call test_simple(ok, sample_input )
12+ call test_positional_trigonometric(ok, sample_input )
13+ call test_positional_absolute(ok, sample_input )
1114
1215 if (ok) then
1316 print ' (a)' , ' test_embedding_layer: All tests passed.'
@@ -17,10 +20,10 @@ program test_embedding_layer
1720 end if
1821
1922contains
20- subroutine test_simple (ok )
23+ subroutine test_simple (ok , sample_input )
2124 logical , intent (in out ) :: ok
25+ integer , intent (in ) :: sample_input(:)
2226
23- integer :: sample_input(3 ) = [2 , 1 , 3 ]
2427 real :: sample_gradient(3 , 2 ) = reshape ([0.1 , 0.2 , 0.3 , 0.4 , 0.6 , 0.6 ], [3 , 2 ])
2528 real :: output_flat(6 )
2629 real :: expected_output_flat(6 ) = reshape ([0.3 , 0.1 , 0.5 , 0.4 , 0.2 , 0.6 ], [6 ])
@@ -48,10 +51,10 @@ subroutine test_simple(ok)
4851 end if
4952 end subroutine test_simple
5053
51- subroutine test_positional_trigonometric (ok )
54+ subroutine test_positional_trigonometric (ok , sample_input )
5255 logical , intent (in out ) :: ok
56+ integer , intent (in ) :: sample_input(:)
5357
54- integer :: sample_input(3 ) = [2 , 1 , 3 ]
5558 real :: output_flat(12 )
5659 real :: expected_output_flat(12 ) = reshape ([&
5760 0.3 , 0.941471 , 1.4092975 ,&
@@ -82,10 +85,10 @@ subroutine test_positional_trigonometric(ok)
8285 end if
8386 end subroutine test_positional_trigonometric
8487
85- subroutine test_positional_absolute (ok )
88+ subroutine test_positional_absolute (ok , sample_input )
8689 logical , intent (in out ) :: ok
90+ integer , intent (in ) :: sample_input(:)
8791
88- integer :: sample_input(3 ) = [2 , 1 , 3 ]
8992 real :: output_flat(12 )
9093 real :: expected_output_flat(12 ) = reshape ([&
9194 0.3 , 1.1 , 2.5 ,&
@@ -115,4 +118,16 @@ subroutine test_positional_absolute(ok)
115118 write (stderr, ' (a)' ) ' absolute positional encoding returned incorrect values.. failed'
116119 end if
117120 end subroutine test_positional_absolute
121+
122+ subroutine test_embedding_constructor (ok , sample_input )
123+ logical , intent (in out ) :: ok
124+ integer , intent (in ) :: sample_input(:)
125+
126+ type (layer) :: embedding_constructed
127+
128+ embedding_constructed = embedding_constructor(sequence_length= 3 , vocab_size= 5 , model_dimension= 4 )
129+ embedding_constructed = embedding_constructor(sequence_length= 3 , vocab_size= 5 , model_dimension= 4 , positional= 0 )
130+ embedding_constructed = embedding_constructor(sequence_length= 3 , vocab_size= 5 , model_dimension= 4 , positional= 1 )
131+ embedding_constructed = embedding_constructor(sequence_length= 3 , vocab_size= 5 , model_dimension= 4 , positional= 2 )
132+ end subroutine test_embedding_constructor
118133end program test_embedding_layer
0 commit comments