@@ -31,8 +31,6 @@ module nf_layernorm_layer
3131 contains
3232 procedure :: forward
3333 procedure :: backward
34- procedure :: spread_by_sequence
35- procedure :: spread_by_model_dim
3634 procedure :: init
3735 end type layernorm_layer
3836
@@ -90,8 +88,11 @@ pure module subroutine backward(self, input, gradient)
9088 allocate (one_over_sigma(self % sequence_length, self % model_dimension))
9189 allocate (gradient_by_gamma_over_sigma(self % sequence_length, self % model_dimension))
9290
93- one_over_sigma = (1 / self % spread_by_model_dim(self % sigma))
94- gradient_by_gamma_over_sigma = gradient * self % spread_by_sequence(self % gamma) * one_over_sigma
91+ one_over_sigma = (1 / spread (self % sigma, dim= 2 , ncopies= self % model_dimension))
92+ gradient_by_gamma_over_sigma = &
93+ gradient &
94+ * spread (self % gamma, dim= 1 , ncopies= self % sequence_length) &
95+ * one_over_sigma
9596
9697 ! d_output/d_gamma = sum(d_output/d_y * mu/sigma)
9798 self % d_gamma = sum (gradient * self % mu * one_over_sigma, dim= 1 )
@@ -107,32 +108,21 @@ pure module subroutine backward(self, input, gradient)
107108 ! - mu * sum(d_output/d_y * gamma * mu * sigma^(03)) / len
108109 self % gradient = &
109110 gradient_by_gamma_over_sigma &
110- - self % spread_by_model_dim(sum (gradient_by_gamma_over_sigma, dim= 2 )) / self % model_dimension &
111- - self % mu * self % spread_by_model_dim(sum (&
112- gradient_by_gamma_over_sigma * self % mu * (one_over_sigma ** 2 ),&
113- dim= 2 )&
114- ) / self % model_dimension
111+ - spread (&
112+ sum (gradient_by_gamma_over_sigma, dim= 2 ),&
113+ dim= 2 ,&
114+ ncopies= self % model_dimension&
115+ ) / self % model_dimension &
116+ - self % mu * spread (&
117+ sum (gradient_by_gamma_over_sigma * self % mu * (one_over_sigma ** 2 ), dim= 2 ),&
118+ dim= 2 ,&
119+ ncopies= self % model_dimension&
120+ ) / self % model_dimension
115121
116122 deallocate (one_over_sigma)
117123 deallocate (gradient_by_gamma_over_sigma)
118124 end subroutine backward
119125
120- pure function spread_by_sequence (self , input ) result(output)
121- class(layernorm_layer), intent (in ) :: self
122- real , intent (in ) :: input(:)
123- real :: output(self % sequence_length, self % model_dimension)
124-
125- output = spread (input, dim= 1 , ncopies= self % sequence_length)
126- end function spread_by_sequence
127-
128- pure function spread_by_model_dim (self , input ) result(output)
129- class(layernorm_layer), intent (in ) :: self
130- real , intent (in ) :: input(:)
131- real :: output(self % sequence_length, self % model_dimension)
132-
133- output = spread (input, dim= 2 , ncopies= self % model_dimension)
134- end function spread_by_model_dim
135-
136126 module subroutine init (self , input_shape )
137127 class(layernorm_layer), intent (in out ) :: self
138128 integer , intent (in ) :: input_shape(:)
0 commit comments