Skip to content

Commit 3822c20

Browse files
committed
layernorm: remove stack allocated arrays
1 parent b0f9476 commit 3822c20

File tree

1 file changed

+15
-25
lines changed

1 file changed

+15
-25
lines changed

src/nf/nf_layernorm.f90

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@ module nf_layernorm_layer
3131
contains
3232
procedure :: forward
3333
procedure :: backward
34-
procedure :: spread_by_sequence
35-
procedure :: spread_by_model_dim
3634
procedure :: init
3735
end type layernorm_layer
3836

@@ -90,8 +88,11 @@ pure module subroutine backward(self, input, gradient)
9088
allocate(one_over_sigma(self % sequence_length, self % model_dimension))
9189
allocate(gradient_by_gamma_over_sigma(self % sequence_length, self % model_dimension))
9290

93-
one_over_sigma = (1 / self % spread_by_model_dim(self % sigma))
94-
gradient_by_gamma_over_sigma = gradient * self % spread_by_sequence(self % gamma) * one_over_sigma
91+
one_over_sigma = (1 / spread(self % sigma, dim=2, ncopies=self % model_dimension))
92+
gradient_by_gamma_over_sigma = &
93+
gradient &
94+
* spread(self % gamma, dim=1, ncopies=self % sequence_length) &
95+
* one_over_sigma
9596

9697
! d_output/d_gamma = sum(d_output/d_y * mu/sigma)
9798
self % d_gamma = sum(gradient * self % mu * one_over_sigma, dim=1)
@@ -107,32 +108,21 @@ pure module subroutine backward(self, input, gradient)
107108
! - mu * sum(d_output/d_y * gamma * mu * sigma^(03)) / len
108109
self % gradient = &
109110
gradient_by_gamma_over_sigma &
110-
- self % spread_by_model_dim(sum(gradient_by_gamma_over_sigma, dim=2)) / self % model_dimension &
111-
- self % mu * self % spread_by_model_dim(sum(&
112-
gradient_by_gamma_over_sigma * self % mu * (one_over_sigma ** 2),&
113-
dim=2)&
114-
) / self % model_dimension
111+
- spread(&
112+
sum(gradient_by_gamma_over_sigma, dim=2),&
113+
dim=2,&
114+
ncopies=self % model_dimension&
115+
) / self % model_dimension &
116+
- self % mu * spread(&
117+
sum(gradient_by_gamma_over_sigma * self % mu * (one_over_sigma ** 2), dim=2),&
118+
dim=2,&
119+
ncopies=self % model_dimension&
120+
) / self % model_dimension
115121

116122
deallocate(one_over_sigma)
117123
deallocate(gradient_by_gamma_over_sigma)
118124
end subroutine backward
119125

120-
pure function spread_by_sequence(self, input) result(output)
121-
class(layernorm_layer), intent(in) :: self
122-
real, intent(in) :: input(:)
123-
real :: output(self % sequence_length, self % model_dimension)
124-
125-
output = spread(input, dim=1, ncopies=self % sequence_length)
126-
end function spread_by_sequence
127-
128-
pure function spread_by_model_dim(self, input) result(output)
129-
class(layernorm_layer), intent(in) :: self
130-
real, intent(in) :: input(:)
131-
real :: output(self % sequence_length, self % model_dimension)
132-
133-
output = spread(input, dim=2, ncopies=self % model_dimension)
134-
end function spread_by_model_dim
135-
136126
module subroutine init(self, input_shape)
137127
class(layernorm_layer), intent(in out) :: self
138128
integer, intent(in) :: input_shape(:)

0 commit comments

Comments
 (0)