Commit 165a6c4

Update example of merging 2 networks to feed into a 3rd network
1 parent 38c998f commit 165a6c4

File tree

2 files changed: +122 -63 lines


example/concatenate.f90

Lines changed: 0 additions & 63 deletions
This file was deleted.

example/merge_networks.f90

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
program merge_networks

  use nf, only: dense, input, network, sgd
  use nf_dense_layer, only: dense_layer

  implicit none

  type(network) :: net1, net2, net3
  real, allocatable :: x1(:), x2(:)
  real, allocatable :: y1(:), y2(:)
  real, allocatable :: y(:)
  integer, parameter :: num_iterations = 500
  integer :: n, nn
  integer :: net1_output_size, net2_output_size

  x1 = [0.1, 0.3, 0.5]
  x2 = [0.2, 0.4]
  y = [0.123456, 0.246802, 0.369258, 0.482604, 0.505050, 0.628406, 0.741852]

  net1 = network([ &
    input(3), &
    dense(2), &
    dense(3), &
    dense(2) &
  ])

  net2 = network([ &
    input(2), &
    dense(5), &
    dense(3) &
  ])

  net1_output_size = product(net1 % layers(size(net1 % layers)) % layer_shape)
  net2_output_size = product(net2 % layers(size(net2 % layers)) % layer_shape)

  ! Network 3
  net3 = network([ &
    input(net1_output_size + net2_output_size), &
    dense(7) &
  ])

  do n = 1, num_iterations

    ! Forward propagate two network branches
    call net1 % forward(x1)
    call net2 % forward(x2)

    ! Get outputs of net1 and net2, concatenate, and pass to net3.
    ! A helper function could be made to take any number of networks
    ! and return the concatenated output; such a function would turn the
    ! following block into a one-liner (a sketch follows the program below).
    select type (net1_output_layer => net1 % layers(size(net1 % layers)) % p)
      type is (dense_layer)
        y1 = net1_output_layer % output
    end select

    select type (net2_output_layer => net2 % layers(size(net2 % layers)) % p)
      type is (dense_layer)
        y2 = net2_output_layer % output
    end select

    call net3 % forward([y1, y2])

    ! Compute the gradients on the 3rd network
    call net3 % backward(y)

    ! net3 % update() will clear the gradients immediately after updating
    ! the weights, so we need to pass the gradients to net1 and net2 first.

    ! For net1 and net2, we can't use the existing net % backward() because
    ! it currently assumes that the output layer gradients are computed based
    ! on the loss function and not the gradient from the next layer.
    ! For now, we need to manually pass the gradient from the first hidden
    ! layer of net3 to the output layers of net1 and net2.
    select type (next_layer => net3 % layers(2) % p)
      ! Assume net3's first hidden layer is dense;
      ! would need to be generalized to others.
      type is (dense_layer)

        nn = size(net1 % layers)
        call net1 % layers(nn) % backward( &
          net1 % layers(nn - 1), next_layer % gradient(1:net1_output_size) &
        )

        nn = size(net2 % layers)
        call net2 % layers(nn) % backward( &
          net2 % layers(nn - 1), &
          next_layer % gradient(net1_output_size+1:size(next_layer % gradient)) &
        )

    end select

    ! Compute the gradients on hidden layers of net1, if any
    do nn = size(net1 % layers) - 1, 2, -1
      select type (next_layer => net1 % layers(nn + 1) % p)
        type is (dense_layer)
          call net1 % layers(nn) % backward( &
            net1 % layers(nn - 1), next_layer % gradient &
          )
      end select
    end do

    ! Compute the gradients on hidden layers of net2, if any
    do nn = size(net2 % layers) - 1, 2, -1
      select type (next_layer => net2 % layers(nn + 1) % p)
        type is (dense_layer)
          call net2 % layers(nn) % backward( &
            net2 % layers(nn - 1), next_layer % gradient &
          )
      end select
    end do

    ! Gradients are now computed on all networks and we can update the weights
    call net1 % update(optimizer=sgd(learning_rate=1.))
    call net2 % update(optimizer=sgd(learning_rate=1.))
    call net3 % update(optimizer=sgd(learning_rate=1.))

    if (mod(n, 50) == 0) then
      print *, "Iteration ", n, ", output RMSE = ", &
        sqrt(sum((net3 % predict([net1 % predict(x1), net2 % predict(x2)]) - y)**2) / size(y))
    end if

  end do

end program merge_networks
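
As the comment inside the training loop notes, a helper that accepts any number of networks and returns their concatenated outputs would turn the two select type blocks into a one-liner. Below is a minimal sketch of such a helper, placed for example in the program's contains section; the name concatenated_output is hypothetical and not part of nf, and it assumes every network ends in a dense layer.

  ! Hypothetical helper (not part of nf): collect the output of each
  ! network's final layer, assumed to be a dense_layer, and concatenate them.
  function concatenated_output(nets) result(res)
    use nf, only: network
    use nf_dense_layer, only: dense_layer
    type(network), intent(in) :: nets(:)
    real, allocatable :: res(:)
    integer :: i
    allocate(res(0))
    do i = 1, size(nets)
      select type (output_layer => nets(i) % layers(size(nets(i) % layers)) % p)
        type is (dense_layer)
          res = [res, output_layer % output]
      end select
    end do
  end function concatenated_output

With such a helper, the forward pass of net3 would read call net3 % forward(concatenated_output([net1, net2])).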
