@@ -3,16 +3,18 @@ program test_flatten_layer
33 use iso_fortran_env, only: stderr = > error_unit
44 use nf, only: dense, flatten, input, layer, network
55 use nf_flatten_layer, only: flatten_layer
6+ use nf_input2d_layer, only: input2d_layer
67 use nf_input3d_layer, only: input3d_layer
78
89 implicit none
910
1011 type (layer) :: test_layer, input_layer
1112 type (network) :: net
12- real , allocatable :: gradient (:,:,:)
13+ real , allocatable :: gradient_3d (:,:,:), gradient_2d( :,:)
1314 real , allocatable :: output(:)
1415 logical :: ok = .true.
1516
17+ ! Test 3D input
1618 test_layer = flatten()
1719
1820 if (.not. test_layer % name == ' flatten' ) then
@@ -59,14 +61,49 @@ program test_flatten_layer
5961 call test_layer % backward(input_layer, real ([1 , 2 , 3 , 4 ]))
6062
6163 select type (this_layer = > test_layer % p); type is(flatten_layer)
62- gradient = this_layer % gradient
64+ gradient_3d = this_layer % gradient_3d
6365 end select
6466
65- if (.not. all (gradient == reshape (real ([1 , 2 , 3 , 4 ]), [1 , 2 , 2 ]))) then
67+ if (.not. all (gradient_3d == reshape (real ([1 , 2 , 3 , 4 ]), [1 , 2 , 2 ]))) then
6668 ok = .false.
6769 write (stderr, ' (a)' ) ' flatten layer correctly propagates backward.. failed'
6870 end if
6971
72+ ! Test 2D input
73+ test_layer = flatten()
74+ input_layer = input(2 , 3 )
75+ call test_layer % init(input_layer)
76+
77+ if (.not. all (test_layer % layer_shape == [6 ])) then
78+ ok = .false.
79+ write (stderr, ' (a)' ) ' flatten layer has an incorrect output shape for 2D input.. failed'
80+ end if
81+
82+ ! Test forward pass - reshaping from 2-d to 1-d
83+ select type (this_layer = > input_layer % p); type is(input2d_layer)
84+ call this_layer % set(reshape (real ([1 , 2 , 3 , 4 , 5 , 6 ]), [2 , 3 ]))
85+ end select
86+
87+ call test_layer % forward(input_layer)
88+ call test_layer % get_output(output)
89+
90+ if (.not. all (output == [1 , 2 , 3 , 4 , 5 , 6 ])) then
91+ ok = .false.
92+ write (stderr, ' (a)' ) ' flatten layer correctly propagates forward for 2D input.. failed'
93+ end if
94+
95+ ! Test backward pass - reshaping from 1-d to 2-d
96+ call test_layer % backward(input_layer, real ([1 , 2 , 3 , 4 , 5 , 6 ]))
97+
98+ select type (this_layer = > test_layer % p); type is(flatten_layer)
99+ gradient_2d = this_layer % gradient_2d
100+ end select
101+
102+ if (.not. all (gradient_2d == reshape (real ([1 , 2 , 3 , 4 , 5 , 6 ]), [2 , 3 ]))) then
103+ ok = .false.
104+ write (stderr, ' (a)' ) ' flatten layer correctly propagates backward for 2D input.. failed'
105+ end if
106+
70107 net = network([ &
71108 input(1 , 28 , 28 ), &
72109 flatten(), &
0 commit comments