@@ -491,6 +491,7 @@ def test_subgraph(self, n_vertices=100):
491491 self .assertEqual (graph .plotting , self ._G .plotting )
492492
493493 def test_nngraph (self , n_vertices = 30 ):
494+ """Test all the combinations of metric, kind, backend."""
494495 features = np .random .RandomState (42 ).normal (size = (n_vertices , 3 ))
495496 metrics = ['euclidean' , 'manhattan' , 'max_dist' , 'minkowski' ]
496497 backends = ['scipy-kdtree' , 'scipy-ckdtree' , 'scipy-pdist' , 'nmslib' ,
@@ -499,46 +500,30 @@ def test_nngraph(self, n_vertices=30):
499500
500501 for backend in backends :
501502 for metric in metrics :
502- if ((backend == 'flann' and metric == 'max_dist' ) or
503- (backend == 'nmslib' and metric == 'minkowski' )):
504- self .assertRaises (ValueError , graphs .NNGraph , features ,
505- kind = 'knn' , backend = backend ,
506- metric = metric )
507- self .assertRaises (ValueError , graphs .NNGraph , features ,
508- kind = 'radius' , backend = backend ,
509- metric = metric )
510- else :
511- if backend == 'nmslib' :
512- self .assertRaises (ValueError , graphs .NNGraph , features ,
513- kind = 'radius' , backend = backend ,
514- metric = metric , order = order )
503+ for kind in ['knn' , 'radius' ]:
504+ params = dict (features = features , metric = metric ,
505+ order = order , kind = kind , backend = backend )
506+ # Unsupported combinations.
507+ if backend == 'flann' and metric == 'max_dist' :
508+ self .assertRaises (ValueError , graphs .NNGraph , ** params )
509+ elif backend == 'nmslib' and metric == 'minkowski' :
510+ self .assertRaises (ValueError , graphs .NNGraph , ** params )
511+ elif backend == 'nmslib' and kind == 'radius' :
512+ self .assertRaises (ValueError , graphs .NNGraph , ** params )
515513 else :
516- graphs .NNGraph (features , kind = 'radius' ,
517- backend = backend ,
518- metric = metric , order = order )
519- graphs .NNGraph (features , kind = 'knn' ,
520- backend = backend ,
521- metric = metric , order = order )
522- graphs .NNGraph (features , kind = 'knn' ,
523- backend = backend ,
524- metric = metric , order = order ,
525- center = False )
526- graphs .NNGraph (features , kind = 'knn' ,
527- backend = backend ,
528- metric = metric , order = order ,
529- rescale = False )
530- graphs .NNGraph (features , kind = 'knn' ,
531- backend = backend ,
532- metric = metric , order = order ,
533- rescale = False , center = False )
514+ graphs .NNGraph (** params , center = False )
515+ graphs .NNGraph (** params , rescale = False )
516+ graphs .NNGraph (** params , center = False , rescale = False )
517+
518+ # Invalid parameters.
519+ self .assertRaises (ValueError , graphs .NNGraph , features ,
520+ metric = 'invalid' )
534521 self .assertRaises (ValueError , graphs .NNGraph , features ,
535- kind = 'invalid' , backend = backend ,
536- metric = metric )
522+ kind = 'invalid' )
537523 self .assertRaises (ValueError , graphs .NNGraph , features ,
538- kind = 'knn' , backend = 'invalid' ,
539- metric = metric )
524+ backend = 'invalid' )
540525 self .assertRaises (ValueError , graphs .NNGraph , features ,
541- kind = 'knn' , k = n_vertices + 1 )
526+ kind = 'knn' , k = n_vertices + 1 )
542527
543528 def test_nngraph_consistency (self ):
544529 features = np .arange (90 ).reshape (30 , 3 )
0 commit comments