1212  use  nf_reshape_layer, only: reshape3d_layer
1313  use  nf_linear2d_layer, only: linear2d_layer
1414  use  nf_self_attention_layer, only: self_attention_layer
15+   use  nf_layernorm_layer, only: layernorm_layer
1516  use  nf_optimizers, only: optimizer_base_type
1617
1718contains 
@@ -46,7 +47,7 @@ pure module subroutine backward_1d(self, previous, gradient)
4647
4748      type  is (flatten_layer)
4849
49-         !  Upstream layers permitted: input2d, input3d, conv2d, maxpool2d
50+         !  Upstream layers permitted: input2d, input3d, conv2d, layernorm, maxpool2d
5051        select type (prev_layer = > previous %  p)
5152          type  is (input2d_layer)
5253            call  this_layer %  backward(prev_layer %  output, gradient)
@@ -60,6 +61,8 @@ pure module subroutine backward_1d(self, previous, gradient)
6061            call  this_layer %  backward(prev_layer %  output, gradient)
6162          type  is (self_attention_layer)
6263            call  this_layer %  backward(prev_layer %  output, gradient)
64+           type  is (layernorm_layer)
65+             call  this_layer %  backward(prev_layer %  output, gradient)
6366        end  select
6467
6568    end  select
@@ -84,6 +87,8 @@ pure module subroutine backward_2d(self, previous, gradient)
8487            call  this_layer %  backward(prev_layer %  output, gradient)
8588          type  is (self_attention_layer)
8689            call  this_layer %  backward(prev_layer %  output, gradient)
90+           type  is (layernorm_layer)
91+             call  this_layer %  backward(prev_layer %  output, gradient)
8792        end  select
8893
8994      type  is (self_attention_layer)
@@ -95,8 +100,18 @@ pure module subroutine backward_2d(self, previous, gradient)
95100            call  this_layer %  backward(prev_layer %  output, gradient)
96101          type  is (self_attention_layer)
97102            call  this_layer %  backward(prev_layer %  output, gradient)
103+           type  is (layernorm_layer)
104+             call  this_layer %  backward(prev_layer %  output, gradient)
98105        end  select
99106
107+       type  is (layernorm_layer)
108+ 
109+         select type (prev_layer = > previous %  p)
110+           type  is (linear2d_layer)
111+             call  this_layer %  backward(prev_layer %  output, gradient)
112+           type  is (self_attention_layer)
113+             call  this_layer %  backward(prev_layer %  output, gradient)
114+         end  select
100115    end  select
101116
102117  end  subroutine backward_2d
@@ -234,6 +249,8 @@ module subroutine forward(self, input)
234249            call  this_layer %  forward(prev_layer %  output)
235250          type  is (linear2d_layer)
236251            call  this_layer %  forward(prev_layer %  output)
252+           type  is (layernorm_layer)
253+             call  this_layer %  forward(prev_layer %  output)
237254        end  select
238255
239256      type  is (reshape3d_layer)
@@ -250,26 +267,40 @@ module subroutine forward(self, input)
250267
251268      type  is (linear2d_layer)
252269
253-         !  Upstream layers permitted: input2d, linear2d
270+         !  Upstream layers permitted: input2d, linear2d, self_attention, layernorm 
254271        select type (prev_layer = > input %  p)
255272          type  is (input2d_layer)
256273            call  this_layer %  forward(prev_layer %  output)
257274          type  is (linear2d_layer)
258275            call  this_layer %  forward(prev_layer %  output)
259276          type  is (self_attention_layer)
260277            call  this_layer %  forward(prev_layer %  output)
278+           type  is (layernorm_layer)
279+             call  this_layer %  forward(prev_layer %  output)
261280        end  select
262281
263282      type  is (self_attention_layer)
264283
265-         !  Upstream layers permitted: input2d, linear2d
284+         !  Upstream layers permitted: input2d, linear2d, self_attention, layernorm 
266285        select type (prev_layer = > input %  p)
267286          type  is (input2d_layer)
268287            call  this_layer %  forward(prev_layer %  output)
269288          type  is (linear2d_layer)
270289            call  this_layer %  forward(prev_layer %  output)
271290          type  is (self_attention_layer)
272291            call  this_layer %  forward(prev_layer %  output)
292+           type  is (layernorm_layer)
293+             call  this_layer %  forward(prev_layer %  output)
294+         end  select
295+ 
296+       type  is (layernorm_layer)
297+ 
298+         !  Upstream layers permitted: linear2d, self_attention
299+         select type (prev_layer = > input %  p)
300+           type  is (linear2d_layer)
301+             call  this_layer %  forward(prev_layer %  output)
302+           type  is (self_attention_layer)
303+             call  this_layer %  forward(prev_layer %  output)
273304        end  select
274305
275306    end  select
@@ -311,6 +342,8 @@ pure module subroutine get_output_2d(self, output)
311342        allocate (output, source= this_layer %  output)
312343      type  is (self_attention_layer)
313344        allocate (output, source= this_layer %  output)
345+       type  is (layernorm_layer)
346+         allocate (output, source= this_layer %  output)
314347      class default
315348        error stop  ' 2-d output can only be read from an input2d or linear2d layer.' 
316349
@@ -354,8 +387,8 @@ impure elemental module subroutine init(self, input)
354387      call  this_layer %  init(input %  layer_shape)
355388    end  select
356389
357-     !  The shape of conv2d, dropout, flatten, linear2d, maxpool2d, or 
358-     !  self_attention layers is not known until we receive an input layer.
390     !  The shape of conv2d, dropout, flatten, linear2d, maxpool2d,
391     !  self_attention, or layernorm layers is not known until we receive an input layer.
359392    select type (this_layer = > self %  p)
360393      type  is (conv2d_layer)
361394        self %  layer_shape =  shape (this_layer %  output)
@@ -367,6 +400,8 @@ impure elemental module subroutine init(self, input)
367400        self %  layer_shape =  shape (this_layer %  output)
368401      type  is (self_attention_layer)
369402        self %  layer_shape =  shape (this_layer %  output)
403+       type  is (layernorm_layer)
404+         self %  layer_shape =  shape (this_layer %  output)
370405      type  is (maxpool2d_layer)
371406        self %  layer_shape =  shape (this_layer %  output)
372407    end  select
@@ -425,6 +460,8 @@ elemental module function get_num_params(self) result(num_params)
425460        num_params =  this_layer %  get_num_params()
426461      type  is  (self_attention_layer)
427462        num_params =  this_layer %  get_num_params()
463+       type  is  (layernorm_layer)
464+         num_params =  this_layer %  get_num_params()
428465      class default
429466        error stop  ' Unknown layer type.' 
430467    end  select
@@ -458,6 +495,8 @@ module function get_params(self) result(params)
458495        params =  this_layer %  get_params()
459496      type  is  (self_attention_layer)
460497        params =  this_layer %  get_params()
498+       type  is  (layernorm_layer)
499+         params =  this_layer %  get_params()
461500      class default
462501        error stop  ' Unknown layer type.' 
463502    end  select
@@ -491,6 +530,8 @@ module function get_gradients(self) result(gradients)
491530        gradients =  this_layer %  get_gradients()
492531      type  is  (self_attention_layer)
493532        gradients =  this_layer %  get_gradients()
533+       type  is  (layernorm_layer)
534+         gradients =  this_layer %  get_gradients()
494535      class default
495536        error stop  ' Unknown layer type.' 
496537    end  select
@@ -549,6 +590,9 @@ module subroutine set_params(self, params)
549590      type  is  (self_attention_layer)
550591        call  this_layer %  set_params(params)
551592
593+       type  is  (layernorm_layer)
594+         call  this_layer %  set_params(params)
595+ 
552596      type  is  (maxpool2d_layer)
553597        !  No parameters to set.
554598        write (stderr, ' (a)' ' Warning: calling set_params() ' 
0 commit comments