28
28
Generic ,
29
29
Literal ,
30
30
NamedTuple ,
31
- Optional ,
32
31
TypeVar ,
33
- Union ,
34
32
cast ,
35
33
)
36
34
from unittest .mock import MagicMock , patch
91
89
_RefinementSolver = TypeVar ("_RefinementSolver" , bound = RefinementSolver )
92
90
93
91
if TYPE_CHECKING :
94
- # In Python 3.9-3. 10, this raises
92
+ # In Python 3.10, this raises
95
93
# `TypeError: Multiple inheritance with NamedTuple is not supported`.
96
94
# Thus, we have to do the actual full typing here, and a non-generic one
97
95
# below to be used at runtime.
98
96
class _ReduceProblem (NamedTuple , Generic [_Data , _Solver ]):
99
97
dataset : _Data
100
98
solver : _Solver
101
- expected_coreset : Optional [ AbstractCoreset ] = None
99
+ expected_coreset : AbstractCoreset | None = None
102
100
103
101
class _RefineProblem (NamedTuple , Generic [_RefinementSolver ]):
104
102
initial_coresubset : Coresubset
105
103
solver : _RefinementSolver
106
- expected_coresubset : Optional [ Coresubset ] = None
104
+ expected_coresubset : Coresubset | None = None
107
105
else :
108
106
# This is the implementation that's used at runtime.
109
107
class _ReduceProblem (NamedTuple ):
110
108
dataset : _Data
111
109
solver : _Solver
112
- expected_coreset : Optional [ AbstractCoreset ] = None
110
+ expected_coreset : AbstractCoreset | None = None
113
111
114
112
class _RefineProblem (NamedTuple ):
115
113
initial_coresubset : Coresubset
116
114
solver : _RefinementSolver
117
- expected_coresubset : Optional [ Coresubset ] = None
115
+ expected_coresubset : Coresubset | None = None
118
116
119
117
120
118
class SolverTest :
@@ -151,7 +149,7 @@ def reduce_problem(
151
149
return _ReduceProblem (Data (dataset ), solver , expected_coreset )
152
150
153
151
def check_solution_invariants (
154
- self , coreset : AbstractCoreset , problem : Union [ _RefineProblem , _ReduceProblem ]
152
+ self , coreset : AbstractCoreset , problem : _RefineProblem | _ReduceProblem
155
153
) -> None :
156
154
"""
157
155
Check that a coreset obeys certain expected invariant properties.
@@ -796,7 +794,7 @@ def test_functions_impl(x):
796
794
797
795
@override
798
796
def check_solution_invariants (
799
- self , coreset : AbstractCoreset , problem : Union [ _RefineProblem , _ReduceProblem ]
797
+ self , coreset : AbstractCoreset , problem : _RefineProblem | _ReduceProblem
800
798
) -> None :
801
799
r"""
802
800
Check that a coreset obeys certain expected invariant properties.
@@ -1006,7 +1004,7 @@ class ExplicitSizeSolverTest(SolverTest):
1006
1004
1007
1005
@override
1008
1006
def check_solution_invariants (
1009
- self , coreset : AbstractCoreset , problem : Union [ _RefineProblem , _ReduceProblem ]
1007
+ self , coreset : AbstractCoreset , problem : _RefineProblem | _ReduceProblem
1010
1008
) -> None :
1011
1009
super ().check_solution_invariants (coreset , problem )
1012
1010
solver = problem .solver
@@ -1026,7 +1024,7 @@ def check_solution_invariants(
1026
1024
def test_check_init (
1027
1025
self ,
1028
1026
solver_factory : jtu .Partial ,
1029
- coreset_size : Union [ int , float , str ] ,
1027
+ coreset_size : int | float | str ,
1030
1028
context : AbstractContextManager ,
1031
1029
) -> None :
1032
1030
"""
@@ -1073,7 +1071,7 @@ def solver_factory(self, request: pytest.FixtureRequest) -> jtu.Partial:
1073
1071
def reduce_problem (
1074
1072
self ,
1075
1073
request : pytest .FixtureRequest ,
1076
- solver_factory : Union [ type [Solver ], jtu .Partial ] ,
1074
+ solver_factory : type [Solver ] | jtu .Partial ,
1077
1075
) -> _ReduceProblem :
1078
1076
if request .param == "random" :
1079
1077
dataset = jr .uniform (self .random_key , self .shape )
@@ -1707,7 +1705,7 @@ class TestRandomSample(ExplicitSizeSolverTest):
1707
1705
1708
1706
@override
1709
1707
def check_solution_invariants (
1710
- self , coreset : AbstractCoreset , problem : Union [ _RefineProblem , _ReduceProblem ]
1708
+ self , coreset : AbstractCoreset , problem : _RefineProblem | _ReduceProblem
1711
1709
) -> None :
1712
1710
super ().check_solution_invariants (coreset , problem )
1713
1711
solver = cast (RandomSample , problem .solver )
@@ -1730,7 +1728,7 @@ class TestRPCholesky(ExplicitSizeSolverTest):
1730
1728
1731
1729
@override
1732
1730
def check_solution_invariants (
1733
- self , coreset : AbstractCoreset , problem : Union [ _RefineProblem , _ReduceProblem ]
1731
+ self , coreset : AbstractCoreset , problem : _RefineProblem | _ReduceProblem
1734
1732
) -> None :
1735
1733
"""Check functionality of 'unique' in addition to the default checks."""
1736
1734
super ().check_solution_invariants (coreset , problem )
@@ -2044,7 +2042,7 @@ def solver_factory(self, request: pytest.FixtureRequest) -> jtu.Partial:
2044
2042
],
2045
2043
)
2046
2044
def test_regulariser_lambda (
2047
- self , test_lambda : Optional [ Union [ float , int ]] , reduce_problem : _ReduceProblem
2045
+ self , test_lambda : float | int | None , reduce_problem : _ReduceProblem
2048
2046
) -> None :
2049
2047
"""Basic checks for the regularisation parameter, lambda."""
2050
2048
dataset , base_solver , _ = reduce_problem
@@ -2411,7 +2409,7 @@ def solver_factory(self, request) -> jtu.Partial:
2411
2409
def reduce_problem (
2412
2410
self ,
2413
2411
request : pytest .FixtureRequest ,
2414
- solver_factory : Union [ type [Solver ], jtu .Partial ] ,
2412
+ solver_factory : type [Solver ] | jtu .Partial ,
2415
2413
) -> _ReduceProblem :
2416
2414
if request .param == "random" :
2417
2415
data_key , supervision_key = jr .split (self .random_key )
@@ -2427,7 +2425,7 @@ def reduce_problem(
2427
2425
2428
2426
@override
2429
2427
def check_solution_invariants (
2430
- self , coreset : AbstractCoreset , problem : Union [ _RefineProblem , _ReduceProblem ]
2428
+ self , coreset : AbstractCoreset , problem : _RefineProblem | _ReduceProblem
2431
2429
) -> None :
2432
2430
"""Check functionality of 'unique' in addition to the default checks."""
2433
2431
super ().check_solution_invariants (coreset , problem )
@@ -2796,7 +2794,7 @@ def __init__(self, _data: np.ndarray, **kwargs):
2796
2794
del kwargs
2797
2795
self .data = _data
2798
2796
2799
- def get_arrays (self ) -> tuple [Union [ np .ndarray , None ] , ...]:
2797
+ def get_arrays (self ) -> tuple [np .ndarray | None , ...]:
2800
2798
"""Mock sklearn.neighbours.BinaryTree.get_arrays method."""
2801
2799
return None , np .arange (len (self .data )), None , None
2802
2800
0 commit comments