Skip to content

Commit b06779b

Browse files
committed
Switch to a new thread-safe utility for catching warnings.
The Python warnings.catch_warnings() functionality is not thread-safe (https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe), so we cannot use it during tests that use free-threading. This change introduces a private warnings test helper (test_warning_util.py), which hooks the CPython warning infrastructure and uses it to implement thread-safe warnings infrastructure. This requires a handful of small modifications to tests to remove direct uses of the warnings module. We also sadly have to delete one TPU test that checks for a warning raised on another thread; there's no easy way for us to catch that in a thread-safe way, but that test seems like overkill anyway.
1 parent 640cb00 commit b06779b

10 files changed

+303
-72
lines changed

jax/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ py_library(
120120
testonly = 1,
121121
srcs = [
122122
"_src/test_util.py",
123+
"_src/test_warning_util.py",
123124
],
124125
visibility = [
125126
":internal",

jax/_src/test_util.py

+44-13
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
import time
3636
from typing import Any, TextIO
3737
import unittest
38-
import warnings
3938
import zlib
4039

4140
from absl.testing import absltest
@@ -49,6 +48,7 @@
4948
from jax._src import dtypes as _dtypes
5049
from jax._src import lib as _jaxlib
5150
from jax._src import monitoring
51+
from jax._src import test_warning_util
5252
from jax._src import xla_bridge
5353
from jax._src import util
5454
from jax._src import mesh as mesh_lib
@@ -118,7 +118,7 @@
118118
)
119119

120120
TEST_NUM_THREADS = config.int_flag(
121-
'jax_test_num_threads', 0,
121+
'jax_test_num_threads', int(os.getenv('JAX_TEST_NUM_THREADS', '0')),
122122
help='Number of threads to use for running tests. 0 means run everything '
123123
'in the main thread. Using > 1 thread is experimental.'
124124
)
@@ -1076,7 +1076,7 @@ def stopTest(self, test: unittest.TestCase):
10761076
with self.lock:
10771077
# We assume test_result is an ABSL _TextAndXMLTestResult, so we can
10781078
# override how it gets the time.
1079-
time_getter = self.test_result.time_getter
1079+
time_getter = getattr(self.test_result, "time_getter", None)
10801080
try:
10811081
self.test_result.time_getter = lambda: self.start_time
10821082
self.test_result.startTest(test)
@@ -1085,7 +1085,8 @@ def stopTest(self, test: unittest.TestCase):
10851085
self.test_result.time_getter = lambda: stop_time
10861086
self.test_result.stopTest(test)
10871087
finally:
1088-
self.test_result.time_getter = time_getter
1088+
if time_getter is not None:
1089+
self.test_result.time_getter = time_getter
10891090

10901091
def addSuccess(self, test: unittest.TestCase):
10911092
self.actions.append(lambda: self.test_result.addSuccess(test))
@@ -1120,6 +1121,8 @@ def run(self, result: unittest.TestResult, debug: bool = False) -> unittest.Test
11201121
if TEST_NUM_THREADS.value <= 0:
11211122
return super().run(result)
11221123

1124+
test_warning_util.install_threadsafe_warning_handlers()
1125+
11231126
executor = ThreadPoolExecutor(TEST_NUM_THREADS.value)
11241127
lock = threading.Lock()
11251128
futures = []
@@ -1368,11 +1371,44 @@ def assertMultiLineStrippedEqual(self, expected, what):
13681371
self.assertMultiLineEqual(expected_clean, what_clean,
13691372
msg=f"Found\n{what}\nExpecting\n{expected}")
13701373

1374+
13711375
@contextmanager
13721376
def assertNoWarnings(self):
1373-
with warnings.catch_warnings():
1374-
warnings.simplefilter("error")
1377+
with test_warning_util.raise_on_warnings():
1378+
yield
1379+
1380+
# We replace assertWarns and assertWarnsRegex with functions that use the
1381+
# thread-safe warning utilities. Unlike the unittest versions these only
1382+
# function as context managers.
1383+
@contextmanager
1384+
def assertWarns(self, warning, *, msg=None):
1385+
with test_warning_util.record_warnings() as ws:
1386+
yield
1387+
for w in ws:
1388+
if not isinstance(w.message, warning):
1389+
continue
1390+
if msg is not None and msg not in str(w.message):
1391+
continue
1392+
return
1393+
self.fail(f"Expected warning not found {warning}:'{msg}', got "
1394+
f"{ws}")
1395+
1396+
@contextmanager
1397+
def assertWarnsRegex(self, warning, regex):
1398+
if regex is not None:
1399+
regex = re.compile(regex)
1400+
1401+
with test_warning_util.record_warnings() as ws:
13751402
yield
1403+
for w in ws:
1404+
if not isinstance(w.message, warning):
1405+
continue
1406+
if regex is not None and not regex.search(str(w.message)):
1407+
continue
1408+
return
1409+
self.fail(f"Expected warning not found {warning}:'{regex}', got "
1410+
f"{ws}")
1411+
13761412

13771413
def _CompileAndCheck(self, fun, args_maker, *, check_dtypes=True, tol=None,
13781414
rtol=None, atol=None, check_cache_misses=True):
@@ -1449,11 +1485,7 @@ def assertNotDeleted(self, x):
14491485
self.assertFalse(x.is_deleted())
14501486

14511487

1452-
@contextmanager
1453-
def ignore_warning(*, message='', category=Warning, **kw):
1454-
with warnings.catch_warnings():
1455-
warnings.filterwarnings("ignore", message=message, category=category, **kw)
1456-
yield
1488+
ignore_warning = test_warning_util.ignore_warning
14571489

14581490
# -------------------- Mesh parametrization helpers --------------------
14591491

@@ -1768,9 +1800,8 @@ def make_axis_points(size):
17681800
logtiny = finfo.minexp / prec_dps_ratio
17691801
axis_points = np.zeros(3 + 2 * size, dtype=finfo.dtype)
17701802

1771-
with warnings.catch_warnings():
1803+
with ignore_warning(category=RuntimeWarning):
17721804
# Silence RuntimeWarning: overflow encountered in cast
1773-
warnings.simplefilter("ignore")
17741805
half_neg_line = -np.logspace(logmin, logtiny, size, dtype=finfo.dtype)
17751806
half_line = -half_neg_line[::-1]
17761807
axis_points[-size - 1:-1] = half_line

jax/_src/test_warning_util.py

+132
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Thread-safe utilities for catching and testing for warnings.
16+
#
17+
# The Python warnings module, at least as of Python 3.13, is not thread-safe.
18+
# The catch_warnings() feature is inherently racy, see
19+
# https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe
20+
#
21+
# This module offers a thread-safe way to catch and record warnings. We install
22+
# a custom showwarning hook with the Python warning module, and then rely on
23+
# the CPython warnings module to call our show warning function. We then use it
24+
# to create our own thread-safe warning filtering utilities.
25+
26+
import contextlib
27+
import re
28+
import threading
29+
import warnings
30+
31+
32+
class _WarningContext(threading.local):
33+
"Thread-local state that contains a list of warning handlers."
34+
35+
def __init__(self):
36+
self.handlers = []
37+
38+
39+
_context = _WarningContext()
40+
41+
42+
# Callback that applies the handlers in reverse order. If no handler matches,
43+
# we raise an error.
44+
def _showwarning(message, category, filename, lineno, file=None, line=None):
45+
for handler in reversed(_context.handlers):
46+
if handler(message, category, filename, lineno, file, line):
47+
return
48+
raise category(message)
49+
50+
51+
@contextlib.contextmanager
52+
def raise_on_warnings():
53+
"Context manager that raises an exception if a warning is raised."
54+
if warnings.showwarning is not _showwarning:
55+
with warnings.catch_warnings():
56+
warnings.simplefilter("error")
57+
yield
58+
return
59+
60+
def handler(message, category, filename, lineno, file=None, line=None):
61+
raise category(message)
62+
63+
_context.handlers.append(handler)
64+
try:
65+
yield
66+
finally:
67+
_context.handlers.pop()
68+
69+
70+
@contextlib.contextmanager
71+
def record_warnings():
72+
"Context manager that yields a list of warnings that are raised."
73+
if warnings.showwarning is not _showwarning:
74+
with warnings.catch_warnings(record=True) as w:
75+
warnings.simplefilter("always")
76+
yield w
77+
return
78+
79+
log = []
80+
81+
def handler(message, category, filename, lineno, file=None, line=None):
82+
log.append(warnings.WarningMessage(message, category, filename, lineno, file, line))
83+
return True
84+
85+
_context.handlers.append(handler)
86+
try:
87+
yield log
88+
finally:
89+
_context.handlers.pop()
90+
91+
92+
@contextlib.contextmanager
93+
def ignore_warning(*, message: str | None = None, category: type = Warning):
94+
"Context manager that ignores any matching warnings."
95+
if warnings.showwarning is not _showwarning:
96+
with warnings.catch_warnings():
97+
warnings.filterwarnings(
98+
"ignore", message="" if message is None else message, category=category)
99+
yield
100+
return
101+
102+
if message:
103+
message_re = re.compile(message)
104+
else:
105+
message_re = None
106+
107+
category_cls = category
108+
109+
def handler(message, category, filename, lineno, file=None, line=None):
110+
text = str(message) if isinstance(message, Warning) else message
111+
if (message_re is None or message_re.match(text)) and issubclass(
112+
category, category_cls
113+
):
114+
return True
115+
return False
116+
117+
_context.handlers.append(handler)
118+
try:
119+
yield
120+
finally:
121+
_context.handlers.pop()
122+
123+
124+
def install_threadsafe_warning_handlers():
125+
# Hook the showwarning method. The warnings module explicitly notes that
126+
# this is a function that users may replace.
127+
warnings.showwarning = _showwarning
128+
129+
# Set the warnings module to always display warnings. We hook into it by
130+
# overriding the "showwarning" method, so it's important that all warnings
131+
# are "shown" by the usual mechanism.
132+
warnings.simplefilter("always")

tests/BUILD

+8
Original file line numberDiff line numberDiff line change
@@ -1153,6 +1153,14 @@ jax_py_test(
11531153
],
11541154
)
11551155

1156+
jax_py_test(
1157+
name = "warnings_util_test",
1158+
srcs = ["warnings_util_test.py"],
1159+
deps = [
1160+
"//jax:test_util",
1161+
] + py_deps("absl/testing"),
1162+
)
1163+
11561164
jax_py_test(
11571165
name = "xla_bridge_test",
11581166
srcs = ["xla_bridge_test.py"],

tests/compilation_cache_test.py

+23-25
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import unittest
2424
from unittest import mock
2525
from unittest import SkipTest
26-
import warnings
2726

2827
from absl.testing import absltest
2928
from absl.testing import parameterized
@@ -39,6 +38,7 @@
3938
from jax._src import monitoring
4039
from jax._src import path as pathlib
4140
from jax._src import test_util as jtu
41+
from jax._src import test_warning_util
4242
from jax._src import xla_bridge
4343
from jax._src.compilation_cache_interface import CacheInterface
4444
from jax._src.lib import xla_client as xc
@@ -232,21 +232,20 @@ def test_cache_write_warning(self):
232232
with (
233233
config.raise_persistent_cache_errors(False),
234234
mock.patch.object(cc._get_cache(backend).__class__, "put") as mock_put,
235-
warnings.catch_warnings(record=True) as w,
235+
test_warning_util.record_warnings() as w,
236236
):
237-
warnings.simplefilter("always")
238237
mock_put.side_effect = RuntimeError("test error")
239238
self.assertEqual(f(2).item(), 4)
240-
if len(w) != 1:
241-
print("Warnings:", [str(w_) for w_ in w], flush=True)
242-
self.assertLen(w, 1)
243-
self.assertIn(
244-
(
245-
"Error writing persistent compilation cache entry "
246-
"for 'jit__lambda_': RuntimeError: test error"
247-
),
248-
str(w[0].message),
249-
)
239+
if len(w) != 1:
240+
print("Warnings:", [str(w_) for w_ in w], flush=True)
241+
self.assertLen(w, 1)
242+
self.assertIn(
243+
(
244+
"Error writing persistent compilation cache entry "
245+
"for 'jit__lambda_': RuntimeError: test error"
246+
),
247+
str(w[0].message),
248+
)
250249

251250
def test_cache_read_warning(self):
252251
f = jit(lambda x: x * x)
@@ -255,23 +254,22 @@ def test_cache_read_warning(self):
255254
with (
256255
config.raise_persistent_cache_errors(False),
257256
mock.patch.object(cc._get_cache(backend).__class__, "get") as mock_get,
258-
warnings.catch_warnings(record=True) as w,
257+
test_warning_util.record_warnings() as w,
259258
):
260-
warnings.simplefilter("always")
261259
mock_get.side_effect = RuntimeError("test error")
262260
# Calling assertEqual with the jitted f will generate two PJIT
263261
# executables: Equal and the lambda function itself.
264262
self.assertEqual(f(2).item(), 4)
265-
if len(w) != 1:
266-
print("Warnings:", [str(w_) for w_ in w], flush=True)
267-
self.assertLen(w, 1)
268-
self.assertIn(
269-
(
270-
"Error reading persistent compilation cache entry "
271-
"for 'jit__lambda_': RuntimeError: test error"
272-
),
273-
str(w[0].message),
274-
)
263+
if len(w) != 1:
264+
print("Warnings:", [str(w_) for w_ in w], flush=True)
265+
self.assertLen(w, 1)
266+
self.assertIn(
267+
(
268+
"Error reading persistent compilation cache entry "
269+
"for 'jit__lambda_': RuntimeError: test error"
270+
),
271+
str(w[0].message),
272+
)
275273

276274
def test_min_entry_size(self):
277275
with (

tests/deprecation_test.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,16 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import warnings
16-
1715
from absl.testing import absltest
1816
from jax._src import deprecations
1917
from jax._src import test_util as jtu
18+
from jax._src import test_warning_util
2019
from jax._src.internal_test_util import deprecation_module as m
2120

2221
class DeprecationTest(absltest.TestCase):
2322

2423
def testModuleDeprecation(self):
25-
with warnings.catch_warnings():
26-
warnings.simplefilter("error")
24+
with test_warning_util.raise_on_warnings():
2725
self.assertEqual(m.x, 42)
2826

2927
with self.assertWarnsRegex(DeprecationWarning, "Please use x"):

tests/lax_numpy_reducers_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def testReducer(self, name, rng_factory, shape, dtype, out_dtype,
212212
rng = rng_factory(self.rng())
213213
@jtu.ignore_warning(category=NumpyComplexWarning)
214214
@jtu.ignore_warning(category=RuntimeWarning,
215-
message="mean of empty slice.*")
215+
message="Mean of empty slice.*")
216216
@jtu.ignore_warning(category=RuntimeWarning,
217217
message="overflow encountered.*")
218218
def np_fun(x):

0 commit comments

Comments
 (0)