35
35
import time
36
36
from typing import Any , TextIO
37
37
import unittest
38
- import warnings
39
38
import zlib
40
39
41
40
from absl .testing import absltest
49
48
from jax ._src import dtypes as _dtypes
50
49
from jax ._src import lib as _jaxlib
51
50
from jax ._src import monitoring
51
+ from jax ._src import test_warning_util
52
52
from jax ._src import xla_bridge
53
53
from jax ._src import util
54
54
from jax ._src import mesh as mesh_lib
118
118
)
119
119
120
120
TEST_NUM_THREADS = config .int_flag (
121
- 'jax_test_num_threads' , 0 ,
121
+ 'jax_test_num_threads' , int ( os . getenv ( 'JAX_TEST_NUM_THREADS' , '0' )) ,
122
122
help = 'Number of threads to use for running tests. 0 means run everything '
123
123
'in the main thread. Using > 1 thread is experimental.'
124
124
)
@@ -1076,7 +1076,7 @@ def stopTest(self, test: unittest.TestCase):
1076
1076
with self .lock :
1077
1077
# We assume test_result is an ABSL _TextAndXMLTestResult, so we can
1078
1078
# override how it gets the time.
1079
- time_getter = self .test_result . time_getter
1079
+ time_getter = getattr ( self .test_result , " time_getter" , None )
1080
1080
try :
1081
1081
self .test_result .time_getter = lambda : self .start_time
1082
1082
self .test_result .startTest (test )
@@ -1085,7 +1085,8 @@ def stopTest(self, test: unittest.TestCase):
1085
1085
self .test_result .time_getter = lambda : stop_time
1086
1086
self .test_result .stopTest (test )
1087
1087
finally :
1088
- self .test_result .time_getter = time_getter
1088
+ if time_getter is not None :
1089
+ self .test_result .time_getter = time_getter
1089
1090
1090
1091
def addSuccess (self , test : unittest .TestCase ):
1091
1092
self .actions .append (lambda : self .test_result .addSuccess (test ))
@@ -1120,6 +1121,8 @@ def run(self, result: unittest.TestResult, debug: bool = False) -> unittest.Test
1120
1121
if TEST_NUM_THREADS .value <= 0 :
1121
1122
return super ().run (result )
1122
1123
1124
+ test_warning_util .install_threadsafe_warning_handlers ()
1125
+
1123
1126
executor = ThreadPoolExecutor (TEST_NUM_THREADS .value )
1124
1127
lock = threading .Lock ()
1125
1128
futures = []
@@ -1368,11 +1371,44 @@ def assertMultiLineStrippedEqual(self, expected, what):
1368
1371
self .assertMultiLineEqual (expected_clean , what_clean ,
1369
1372
msg = f"Found\n { what } \n Expecting\n { expected } " )
1370
1373
1374
+
1371
1375
@contextmanager
1372
1376
def assertNoWarnings (self ):
1373
- with warnings .catch_warnings ():
1374
- warnings .simplefilter ("error" )
1377
+ with test_warning_util .raise_on_warnings ():
1378
+ yield
1379
+
1380
+ # We replace assertWarns and assertWarnsRegex with functions that use the
1381
+ # thread-safe warning utilities. Unlike the unittest versions these only
1382
+ # function as context managers.
1383
+ @contextmanager
1384
+ def assertWarns (self , warning , * , msg = None ):
1385
+ with test_warning_util .record_warnings () as ws :
1386
+ yield
1387
+ for w in ws :
1388
+ if not isinstance (w .message , warning ):
1389
+ continue
1390
+ if msg is not None and msg not in str (w .message ):
1391
+ continue
1392
+ return
1393
+ self .fail (f"Expected warning not found { warning } :'{ msg } ', got "
1394
+ f"{ ws } " )
1395
+
1396
+ @contextmanager
1397
+ def assertWarnsRegex (self , warning , regex ):
1398
+ if regex is not None :
1399
+ regex = re .compile (regex )
1400
+
1401
+ with test_warning_util .record_warnings () as ws :
1375
1402
yield
1403
+ for w in ws :
1404
+ if not isinstance (w .message , warning ):
1405
+ continue
1406
+ if regex is not None and not regex .search (str (w .message )):
1407
+ continue
1408
+ return
1409
+ self .fail (f"Expected warning not found { warning } :'{ regex } ', got "
1410
+ f"{ ws } " )
1411
+
1376
1412
1377
1413
def _CompileAndCheck (self , fun , args_maker , * , check_dtypes = True , tol = None ,
1378
1414
rtol = None , atol = None , check_cache_misses = True ):
@@ -1449,11 +1485,7 @@ def assertNotDeleted(self, x):
1449
1485
self .assertFalse (x .is_deleted ())
1450
1486
1451
1487
1452
- @contextmanager
1453
- def ignore_warning (* , message = '' , category = Warning , ** kw ):
1454
- with warnings .catch_warnings ():
1455
- warnings .filterwarnings ("ignore" , message = message , category = category , ** kw )
1456
- yield
1488
+ ignore_warning = test_warning_util .ignore_warning
1457
1489
1458
1490
# -------------------- Mesh parametrization helpers --------------------
1459
1491
@@ -1768,9 +1800,8 @@ def make_axis_points(size):
1768
1800
logtiny = finfo .minexp / prec_dps_ratio
1769
1801
axis_points = np .zeros (3 + 2 * size , dtype = finfo .dtype )
1770
1802
1771
- with warnings . catch_warnings ( ):
1803
+ with ignore_warning ( category = RuntimeWarning ):
1772
1804
# Silence RuntimeWarning: overflow encountered in cast
1773
- warnings .simplefilter ("ignore" )
1774
1805
half_neg_line = - np .logspace (logmin , logtiny , size , dtype = finfo .dtype )
1775
1806
half_line = - half_neg_line [::- 1 ]
1776
1807
axis_points [- size - 1 :- 1 ] = half_line
0 commit comments