@@ -32,17 +32,19 @@ module nf_network
3232
3333 procedure , private :: evaluate_batch_1d
3434 procedure , private :: forward_1d
35+ procedure , private :: forward_1d_int
3536 procedure , private :: forward_2d
3637 procedure , private :: forward_3d
3738 procedure , private :: predict_1d
39+ procedure , private :: predict_1d_int
3840 procedure , private :: predict_2d
3941 procedure , private :: predict_3d
4042 procedure , private :: predict_batch_1d
4143 procedure , private :: predict_batch_3d
4244
4345 generic :: evaluate = > evaluate_batch_1d
44- generic :: forward = > forward_1d, forward_2d, forward_3d
45- generic :: predict = > predict_1d, predict_2d, predict_3d
46+ generic :: forward = > forward_1d, forward_1d_int, forward_2d, forward_3d
47+ generic :: predict = > predict_1d, predict_1d_int, predict_2d, predict_3d
4648 generic :: predict_batch = > predict_batch_1d, predict_batch_3d
4749
4850 end type network
@@ -95,6 +97,12 @@ module subroutine forward_1d(self, input)
9597 ! ! 1-d input data
9698 end subroutine forward_1d
9799
100+ module subroutine forward_1d_int (self , input )
101+ ! ! Same as `forward_1d` except `integer`
102+ class(network), intent (in out ) :: self
103+ integer , intent (in ) :: input(:)
104+ end subroutine forward_1d_int
105+
98106 module subroutine forward_2d (self , input )
99107 ! ! Apply a forward pass through the network.
100108 ! !
@@ -137,6 +145,13 @@ module function predict_1d(self, input) result(res)
137145 ! ! Output of the network
138146 end function predict_1d
139147
148+ module function predict_1d_int (self , input ) result(res)
149+ ! ! Same as `predict_1d` except `integer`
150+ class(network), intent (in out ) :: self
151+ integer , intent (in ) :: input(:)
152+ real , allocatable :: res(:)
153+ end function predict_1d_int
154+
140155 module function predict_2d (self , input ) result(res)
141156 ! ! Return the output of the network given the input 1-d array.
142157 class(network), intent (in out ) :: self
0 commit comments