1
+ ! Copyright (c) 2024-2025, The Regents of the University of California and Sourcery Institute
2
+ ! Terms of use are as specified in LICENSE.txt
3
+
4
+ module linear_2d_layer_test_m
5
+ use julienne_m, only : &
6
+ test_t, test_description_t, test_diagnosis_t, test_result_t &
7
+ ,operator (.equalsExpected.), operator (// ), operator (.approximates.), operator (.within.), operator (.also.), operator (.all.)
8
+ use nf_linear2d_layer, only: linear2d_layer
9
+ implicit none
10
+
11
+ type, extends(test_t) :: linear_2d_layer_test_t
12
+ contains
13
+ procedure , nopass :: subject
14
+ procedure , nopass :: results
15
+ end type
16
+
17
+ contains
18
+
19
+ pure function subject () result(test_subject)
20
+ character (len= :), allocatable :: test_subject
21
+ test_subject = ' A linear_2d_layer'
22
+ end function
23
+
24
+ function results () result(test_results)
25
+ type (linear_2d_layer_test_t) linear_2d_layer_test
26
+ type (test_result_t), allocatable :: test_results(:)
27
+ test_results = linear_2d_layer_test% run( &
28
+ [test_description_t(' updating gradients' , check_gradient_updates) &
29
+ ])
30
+ end function
31
+
32
+ function check_gradient_updates () result(test_diagnosis)
33
+ type (test_diagnosis_t) test_diagnosis
34
+
35
+ real :: input(3 , 4 ) = reshape ([0.0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 , 0.9 , 0.11 , 0.12 ], [3 , 4 ])
36
+ real :: gradient(3 , 2 ) = reshape ([0.0 , 10 ., 0.2 , 3 ., 0.4 , 1 .], [3 , 2 ])
37
+ type (linear2d_layer) :: linear
38
+ real , pointer :: w_ptr(:)
39
+ real , pointer :: b_ptr(:)
40
+
41
+ integer :: num_parameters
42
+ real , allocatable :: parameters(:) ! Remove the fixed size
43
+ real :: expected_parameters(10 ) = [&
44
+ 0.100000001 , 0.100000001 , 0.100000001 , 0.100000001 , 0.100000001 , 0.100000001 , 0.100000001 , 0.100000001 ,&
45
+ 0.109999999 , 0.109999999 &
46
+ ]
47
+ real :: gradients(10 )
48
+ real :: expected_gradients(10 ) = [&
49
+ 1.03999996 , 4.09999990 , 7.15999985 , 1.12400007 , 0.240000010 , 1.56000006 , 2.88000011 , 2.86399961 ,&
50
+ 10.1999998 , 4.40000010 &
51
+ ]
52
+ real :: updated_parameters(10 )
53
+ real :: updated_weights(8 )
54
+ real :: updated_biases(2 )
55
+ real :: expected_weights(8 ) = [&
56
+ 0.203999996 , 0.509999990 , 0.816000044 , 0.212400019 , 0.124000005 , 0.256000012 , 0.388000011 , 0.386399955 &
57
+ ]
58
+ real :: expected_biases(2 ) = [1.13000000 , 0.550000012 ]
59
+
60
+ integer :: i
61
+ real , parameter :: tolerance = 0 .
62
+
63
+ linear = linear2d_layer(out_features= 2 )
64
+ call linear % init([3 , 4 ])
65
+ linear % weights = 0.1
66
+ linear % biases = 0.11
67
+ call linear % forward(input)
68
+ call linear % backward(input, gradient)
69
+ num_parameters = linear % get_num_params()
70
+
71
+ test_diagnosis = (num_parameters .equalsExpected. 10 ) // " (number of parameters)"
72
+
73
+ call linear % get_params_ptr(w_ptr, b_ptr) ! Change this_layer to linear
74
+ allocate (parameters(size (w_ptr) + size (b_ptr)))
75
+ parameters(1 :size (w_ptr)) = w_ptr
76
+ parameters(size (w_ptr)+ 1 :) = b_ptr
77
+ test_diagnosis = test_diagnosis .also. (.all. (parameters .approximates. expected_parameters .within. tolerance) // " (parameters)" )
78
+
79
+ gradients = linear % get_gradients()
80
+ test_diagnosis = test_diagnosis .also. (.all. (gradients .approximates. expected_gradients .within. tolerance) // " (gradients)" )
81
+
82
+ do i = 1 , num_parameters
83
+ updated_parameters(i) = parameters(i) + 0.1 * gradients(i)
84
+ end do
85
+
86
+ call linear % get_params_ptr(w_ptr, b_ptr) ! Change this_layer to linear
87
+ w_ptr = updated_parameters(1 :size (w_ptr))
88
+ b_ptr = updated_parameters(size (w_ptr)+ 1 :)
89
+ updated_weights = reshape (linear % weights, shape (expected_weights))
90
+ test_diagnosis = test_diagnosis .also. (.all. (updated_weights .approximates. expected_weights .within. tolerance) // " (updated weights)" )
91
+
92
+ updated_biases = linear % biases
93
+ test_diagnosis = test_diagnosis .also. (.all. (updated_biases .approximates. expected_biases .within. tolerance) // " (updated biases)" )
94
+
95
+ end function
96
+
97
+ end module linear_2d_layer_test_m
0 commit comments