1
program merge_networks

  ! Demonstrates training a "merged" architecture: two parallel network
  ! branches (net1 and net2) whose concatenated outputs feed a third
  ! network (net3). net3 is trained against the target y, and its input
  ! gradient is manually routed back into the output layers of net1 and
  ! net2, since network % backward() assumes the gradient comes from the
  ! loss function rather than from a downstream network.

  use nf, only: dense, input, network, sgd
  use nf_dense_layer, only: dense_layer

  implicit none

  type(network) :: net1, net2, net3
  real, allocatable :: x1(:), x2(:)   ! inputs to the two branches
  real, allocatable :: y1(:), y2(:)   ! outputs of the two branches
  real, allocatable :: y(:)           ! training target for net3
  integer, parameter :: num_iterations = 500
  integer :: n, nn
  integer :: net1_output_size, net2_output_size

  x1 = [0.1, 0.3, 0.5]
  x2 = [0.2, 0.4]
  y = [0.123456, 0.246802, 0.369258, 0.482604, 0.505050, 0.628406, 0.741852]

  ! Branch 1: 3 inputs -> 2 outputs
  net1 = network([ &
    input(3), &
    dense(2), &
    dense(3), &
    dense(2) &
  ])

  ! Branch 2: 2 inputs -> 3 outputs
  net2 = network([ &
    input(2), &
    dense(5), &
    dense(3) &
  ])

  ! Flattened sizes of each branch's final layer, used to size net3's
  ! input and to split net3's gradient back across the branches.
  net1_output_size = product(net1 % layers(size(net1 % layers)) % layer_shape)
  net2_output_size = product(net2 % layers(size(net2 % layers)) % layer_shape)

  ! Network 3 consumes the concatenated branch outputs.
  net3 = network([ &
    input(net1_output_size + net2_output_size), &
    dense(7) &
  ])

  do n = 1, num_iterations

    ! Forward propagate the two network branches
    call net1 % forward(x1)
    call net2 % forward(x2)

    ! Get outputs of net1 and net2, concatenate, and pass to net3.
    ! A helper function could be made to take any number of networks
    ! and return the concatenated output. Such a function would turn the
    ! following block into a one-liner.
    select type (net1_output_layer => net1 % layers(size(net1 % layers)) % p)
      type is (dense_layer)
        y1 = net1_output_layer % output
    end select

    select type (net2_output_layer => net2 % layers(size(net2 % layers)) % p)
      type is (dense_layer)
        y2 = net2_output_layer % output
    end select

    call net3 % forward([y1, y2])

    ! Compute the gradients on the 3rd network
    call net3 % backward(y)

    ! net3 % update() will clear the gradients immediately after updating
    ! the weights, so we need to pass the gradients to net1 and net2 first.

    ! For net1 and net2, we can't use the existing net % backward() because
    ! it currently assumes that the output layer gradients are computed based
    ! on the loss function and not the gradient from the next layer.
    ! For now, we need to manually pass the gradient from the first hidden
    ! layer of net3 to the output layers of net1 and net2.
    select type (next_layer => net3 % layers(2) % p)
      ! Assume net3's first hidden layer is dense;
      ! would need to be generalized to others.
      type is (dense_layer)

        ! First slice of net3's input gradient belongs to net1's output...
        nn = size(net1 % layers)
        call net1 % layers(nn) % backward( &
          net1 % layers(nn - 1), next_layer % gradient(1:net1_output_size) &
        )

        ! ...and the remainder belongs to net2's output.
        nn = size(net2 % layers)
        call net2 % layers(nn) % backward( &
          net2 % layers(nn - 1), &
          next_layer % gradient(net1_output_size + 1:size(next_layer % gradient)) &
        )

    end select

    ! Compute the gradients on hidden layers of net1, if any
    do nn = size(net1 % layers) - 1, 2, -1
      select type (next_layer => net1 % layers(nn + 1) % p)
        type is (dense_layer)
          call net1 % layers(nn) % backward( &
            net1 % layers(nn - 1), next_layer % gradient &
          )
      end select
    end do

    ! Compute the gradients on hidden layers of net2, if any
    do nn = size(net2 % layers) - 1, 2, -1
      select type (next_layer => net2 % layers(nn + 1) % p)
        type is (dense_layer)
          call net2 % layers(nn) % backward( &
            net2 % layers(nn - 1), next_layer % gradient &
          )
      end select
    end do

    ! Gradients are now computed on all networks and we can update the weights
    call net1 % update(optimizer=sgd(learning_rate=1.))
    call net2 % update(optimizer=sgd(learning_rate=1.))
    call net3 % update(optimizer=sgd(learning_rate=1.))

    ! Report progress every 50 iterations.
    if (mod(n, 50) == 0) then
      print *, 'Iteration ', n, ', output RMSE = ', &
        sqrt(sum((net3 % predict([net1 % predict(x1), net2 % predict(x2)]) - y)**2) / size(y))
    end if

  end do

end program merge_networks
0 commit comments