Skip to content

Commit 9844107

Browse files
committed
layernorm: remove redundant arguments
1 parent c2d4eac commit 9844107

File tree

2 files changed

+17
-8
lines changed

2 files changed

+17
-8
lines changed

src/nf/nf_layernorm.f90

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,21 +37,17 @@ module nf_layernorm_layer
     end type layernorm_layer
 
     interface layernorm_layer
-      module function layernorm_layer_cons(sequence_length, model_dimension) &
+      module function layernorm_layer_cons() &
         result(res)
-        integer, intent(in) :: sequence_length, model_dimension
         type(layernorm_layer) :: res
       end function layernorm_layer_cons
     end interface layernorm_layer
 
   contains
-    module function layernorm_layer_cons(sequence_length, model_dimension) &
+    module function layernorm_layer_cons() &
       result(res)
-      integer, intent(in) :: sequence_length, model_dimension
       type(layernorm_layer) :: res
 
-      res % sequence_length = sequence_length
-      res % model_dimension = model_dimension
       res % eps = 1e-5
     end function layernorm_layer_cons
 
@@ -141,6 +137,12 @@ module subroutine init(self, input_shape)
     class(layernorm_layer), intent(in out) :: self
     integer, intent(in) :: input_shape(:)
 
+    if (size(input_shape) /= 2) then
+      error stop "LayerNorm Layer accepts 2D input"
+    end if
+    self % sequence_length = input_shape(1)
+    self % model_dimension = input_shape(2)
+
     ! default initialization from PyTorch
     allocate(self % gamma(self % model_dimension))
     self % gamma = 1.

test/test_layernorm.f90

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,19 @@ program test_layernorm
   real :: sample_input(3, 4) = reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4])
   real :: sample_gradient(3, 4) = reshape([0.1, 3., 2., 0.1, 3., 3., 0.1, 2., 0.1, 3., 0.1, 3.], [3, 4])
 
-  layernorm = layernorm_layer(3, 4)
-  call layernorm % init([0])
+  layernorm = layernorm_layer()
+  call layernorm % init([3, 4])
 
   call test_layernorm_forward(layernorm, sample_input, ok)
   call test_layernorm_backward(layernorm, sample_input, sample_gradient, ok)
 
+  if (ok) then
+    print '(a)', 'test_layernorm_layer: All tests passed.'
+  else
+    write(stderr, '(a)') 'test_layernorm_layer: One or more tests failed.'
+    stop 1
+  end if
+
 contains
   subroutine test_layernorm_forward(layernorm, input, ok)
     type(layernorm_layer), intent(in out) :: layernorm

0 commit comments

Comments (0)