Skip to content

Commit 4aea615

Browse files
committed
Add network % get_output() subroutine that returns a pointer to the outputs
1 parent d6575cf commit 4aea615

File tree

3 files changed

+39
-18
lines changed

3 files changed

+39
-18
lines changed

example/merge_networks.f90

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ program merge_networks
55

66
type(network) :: net1, net2, net3
77
real, allocatable :: x1(:), x2(:)
8-
real, allocatable :: y1(:), y2(:)
8+
real, pointer :: y1(:), y2(:)
99
real, allocatable :: y(:)
1010
integer, parameter :: num_iterations = 500
1111
integer :: n, nn
@@ -44,19 +44,8 @@ program merge_networks
4444
call net2 % forward(x2)
4545

4646
! Get outputs of net1 and net2, concatenate, and pass to net3
47-
! A helper function could be made to take any number of networks
48-
! and return the concatenated output. Such function would turn the following
49-
! block into a one-liner.
50-
select type (net1_output_layer => net1 % layers(size(net1 % layers)) % p)
51-
type is (dense_layer)
52-
y1 = net1_output_layer % output
53-
end select
54-
55-
select type (net2_output_layer => net2 % layers(size(net2 % layers)) % p)
56-
type is (dense_layer)
57-
y2 = net2_output_layer % output
58-
end select
59-
47+
call net1 % get_output(y1)
48+
call net2 % get_output(y2)
6049
call net3 % forward([y1, y2])
6150

6251
! First compute the gradients on net3, then pass the gradients from the first

src/nf/nf_network.f90

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ module nf_network
3333
procedure, private :: forward_1d_int
3434
procedure, private :: forward_2d
3535
procedure, private :: forward_3d
36+
procedure, private :: get_output_1d
3637
procedure, private :: predict_1d
3738
procedure, private :: predict_1d_int
3839
procedure, private :: predict_2d
@@ -42,6 +43,7 @@ module nf_network
4243

4344
generic :: evaluate => evaluate_batch_1d
4445
generic :: forward => forward_1d, forward_1d_int, forward_2d, forward_3d
46+
generic :: get_output => get_output_1d
4547
generic :: predict => predict_1d, predict_1d_int, predict_2d, predict_3d
4648
generic :: predict_batch => predict_batch_1d, predict_batch_3d
4749

@@ -131,7 +133,7 @@ end subroutine forward_3d
131133

132134
end interface forward
133135

134-
interface output
136+
interface predict
135137

136138
module function predict_1d(self, input) result(res)
137139
!! Return the output of the network given the input 1-d array.
@@ -169,9 +171,10 @@ module function predict_3d(self, input) result(res)
169171
real, allocatable :: res(:)
170172
!! Output of the network
171173
end function predict_3d
172-
end interface output
173174

174-
interface output_batch
175+
end interface predict
176+
177+
interface predict_batch
175178
module function predict_batch_1d(self, input) result(res)
176179
!! Return the output of the network given an input batch of 3-d data.
177180
class(network), intent(in out) :: self
@@ -191,7 +194,14 @@ module function predict_batch_3d(self, input) result(res)
191194
real, allocatable :: res(:,:)
192195
!! Output of the network; the last dimension is the batch
193196
end function predict_batch_3d
194-
end interface output_batch
197+
end interface predict_batch
198+
199+
interface get_output
200+
module subroutine get_output_1d(self, output)
201+
class(network), intent(in), target :: self
202+
real, pointer, intent(out) :: output(:)
203+
end subroutine get_output_1d
204+
end interface get_output
195205

196206
interface
197207

src/nf/nf_network_submodule.f90

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,28 @@ module subroutine print_info(self)
511511
end subroutine print_info
512512

513513

514+
module subroutine get_output_1d(self, output)
515+
class(network), intent(in), target :: self
516+
real, pointer, intent(out) :: output(:)
517+
integer :: last
518+
519+
last = size(self % layers)
520+
521+
select type(output_layer => self % layers(last) % p)
522+
type is(dense_layer)
523+
output => output_layer % output
524+
type is(dropout_layer)
525+
output => output_layer % output
526+
type is(flatten_layer)
527+
output => output_layer % output
528+
class default
529+
error stop 'network % get_output not implemented for ' // &
530+
trim(self % layers(last) % name) // ' layer'
531+
end select
532+
533+
end subroutine get_output_1d
534+
535+
514536
module function get_num_params(self)
515537
class(network), intent(in) :: self
516538
integer :: get_num_params

0 commit comments

Comments
 (0)