@@ -49,6 +49,7 @@ def _knn_flann(X, num_neighbors, dist_type, order):
4949    # do not allow it 
5050    if  dist_type  ==  'max_dist' :
5151        raise  ValueError ('FLANN and max_dist is not supported' )
52+     
5253    pfl  =  _import_pfl ()
5354    pfl .set_distance_type (dist_type , order = order )
5455    flann  =  pfl .FLANN ()
@@ -58,6 +59,8 @@ def _knn_flann(X, num_neighbors, dist_type, order):
5859    # seems to work best). 
5960    NN , D  =  flann .nn (X , X , num_neighbors = (num_neighbors  +  1 ), 
6061                     algorithm = 'kdtree' )
62+     if  dist_type  ==  'euclidean' : # flann returns squared distances 
63+         return  NN , np .sqrt (D )
6164    return  NN , D 
6265
6366def  _radius_sp_kdtree (X , epsilon , dist_type , order = 0 ):
@@ -86,8 +89,9 @@ def _radius_sp_pdist(X, epsilon, dist_type, order):
8689    NN  =  []
8790    for  k  in  range (N ):
8891        v  =  pd [k , pdf [k , :]]
92+         d  =  pd [k , :].argsort ()
8993        # use the same conventions as in scipy.distance.kdtree 
90-         NN .append (v . argsort () )
94+         NN .append (d [ 0 : len ( v )] )
9195        D .append (np .sort (v ))
9296
9397    return  NN , D 
@@ -98,21 +102,32 @@ def _radius_flann(X, epsilon, dist_type, order=0):
98102    # do not allow it 
99103    if  dist_type  ==  'max_dist' :
100104        raise  ValueError ('FLANN and max_dist is not supported' )
101-         
102105    pfl  =  _import_pfl ()
106+     
103107    pfl .set_distance_type (dist_type , order = order )
104108    flann  =  pfl .FLANN ()
105109    flann .build_index (X )
106110
107111    D  =  []
108112    NN  =  []
109113    for  k  in  range (N ):
110-         nn , d  =  flann .nn_radius (X [k , :], epsilon )
114+         nn , d  =  flann .nn_radius (X [k , :], epsilon * epsilon )
111115        D .append (d )
112116        NN .append (nn )
113117    flann .delete_index ()
118+     if  dist_type  ==  'euclidean' : # flann returns squared distances 
119+         return  NN , np .sqrt (D )
114120    return  NN , D 
115121
122+ def  center_input (X , N ):
123+     return  X  -  np .kron (np .ones ((N , 1 )), np .mean (X , axis = 0 ))
124+ 
125+ def  rescale_input (X , N , d ):
126+      bounding_radius  =  0.5  *  np .linalg .norm (np .amax (X , axis = 0 ) - 
127+                                                    np .amin (X , axis = 0 ), 2 )
128+      scale  =  np .power (N , 1.  /  float (min (d , 3 ))) /  10. 
129+      return  X  *  scale  /  bounding_radius 
130+ 
116131class  NNGraph (Graph ):
117132    r"""Nearest-neighbor graph from given point cloud. 
118133
@@ -207,14 +222,10 @@ def __init__(self, Xin, NNtype='knn', backend='scipy-kdtree', center=True,
207222        Xout  =  self .Xin 
208223
209224        if  self .center :
210-             Xout  =  self .Xin  -  np .kron (np .ones ((N , 1 )),
211-                                       np .mean (self .Xin , axis = 0 ))
225+             Xout  =  center_input (Xout , N )
212226
213227        if  self .rescale :
214-             bounding_radius  =  0.5  *  np .linalg .norm (np .amax (Xout , axis = 0 ) - 
215-                                                    np .amin (Xout , axis = 0 ), 2 )
216-             scale  =  np .power (N , 1.  /  float (min (d , 3 ))) /  10. 
217-             Xout  *=  scale  /  bounding_radius 
228+             Xout  =  rescale_input (Xout , N , d )
218229
219230
220231
0 commit comments