Skip to content

Commit f1439a9

Browse files
author
The ml_dtypes Authors
committed
Merge pull request #203 from vfdev-5:reenable-mt-tests
PiperOrigin-RevId: 716020733
2 parents 5f1240a + 4ac1090 commit f1439a9

File tree

5 files changed

+37
-34
lines changed

5 files changed

+37
-34
lines changed

ml_dtypes/tests/custom_float_test.py

+28-20
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from absl.testing import absltest
2828
from absl.testing import parameterized
2929
import ml_dtypes
30-
# from multi_thread_utils import multi_threaded
30+
from multi_thread_utils import multi_threaded
3131
import numpy as np
3232

3333
bfloat16 = ml_dtypes.bfloat16
@@ -221,12 +221,16 @@ def dtype_is_signed(dtype):
221221
}
222222

223223

224-
# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
225224
# pylint: disable=g-complex-comprehension
226-
# @multi_threaded(
227-
# num_workers=3,
228-
# skip_tests=["testDiv", "testRoundTripNumpyTypes", "testRoundTripToNumpy"],
229-
# )
225+
@multi_threaded(
226+
num_workers=3,
227+
skip_tests=[
228+
"testDiv",
229+
"testPickleable",
230+
"testRoundTripNumpyTypes",
231+
"testRoundTripToNumpy",
232+
],
233+
)
230234
@parameterized.named_parameters(
231235
(
232236
{"testcase_name": "_" + dtype.__name__, "float_type": dtype}
@@ -661,21 +665,25 @@ def testDtypeFromString(self, float_type):
661665
]
662666

663667

664-
# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
665668
# pylint: disable=g-complex-comprehension
666-
# @multi_threaded(
667-
# num_workers=3,
668-
# skip_tests=[
669-
# "testBinaryUfunc",
670-
# "testConformNumpyComplex",
671-
# "testFloordivCornerCases",
672-
# "testDivmodCornerCases",
673-
# "testSpacing",
674-
# "testUnaryUfunc",
675-
# "testCasts",
676-
# "testLdexp",
677-
# ],
678-
# )
669+
@multi_threaded(
670+
num_workers=3,
671+
skip_tests=[
672+
"testBinaryPredicateUfunc",
673+
"testBinaryUfunc",
674+
"testCasts",
675+
"testConformNumpyComplex",
676+
"testDivmod",
677+
"testDivmodCornerCases",
678+
"testFloordivCornerCases",
679+
"testFrexp",
680+
"testLdexp",
681+
"testModf",
682+
"testPredicateUfunc",
683+
"testSpacing",
684+
"testUnaryUfunc",
685+
],
686+
)
679687
@parameterized.named_parameters(
680688
(
681689
{"testcase_name": "_" + dtype.__name__, "float_type": dtype}

ml_dtypes/tests/finfo_test.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from absl.testing import absltest
1616
from absl.testing import parameterized
1717
import ml_dtypes
18-
# from multi_thread_utils import multi_threaded
18+
from multi_thread_utils import multi_threaded
1919
import numpy as np
2020

2121
ALL_DTYPES = [
@@ -55,8 +55,7 @@
5555
}
5656

5757

58-
# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
59-
# @multi_threaded(num_workers=3)
58+
@multi_threaded(num_workers=3)
6059
class FinfoTest(parameterized.TestCase):
6160

6261
def assertNanEqual(self, x, y):

ml_dtypes/tests/iinfo_test.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,11 @@
1515
from absl.testing import absltest
1616
from absl.testing import parameterized
1717
import ml_dtypes
18-
# from multi_thread_utils import multi_threaded
18+
from multi_thread_utils import multi_threaded
1919
import numpy as np
2020

2121

22-
# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
23-
# @multi_threaded(num_workers=3)
22+
@multi_threaded(num_workers=3)
2423
class IinfoTest(parameterized.TestCase):
2524

2625
def testIinfoInt2(self):

ml_dtypes/tests/intn_test.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from absl.testing import absltest
2424
from absl.testing import parameterized
2525
import ml_dtypes
26-
# from multi_thread_utils import multi_threaded
26+
from multi_thread_utils import multi_threaded
2727
import numpy as np
2828

2929
int2 = ml_dtypes.int2
@@ -48,9 +48,8 @@ def ignore_warning(**kw):
4848
yield
4949

5050

51-
# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
5251
# Tests for the Python scalar type
53-
# @multi_threaded(num_workers=3)
52+
@multi_threaded(num_workers=3)
5453
class ScalarTest(parameterized.TestCase):
5554

5655
@parameterized.product(scalar_type=INTN_TYPES)
@@ -247,9 +246,8 @@ def testCanCast(self, a, b):
247246
)
248247

249248

250-
# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
251249
# Tests for the Python scalar type
252-
# @multi_threaded(num_workers=3, skip_tests=["testBinaryUfuncs"])
250+
@multi_threaded(num_workers=3, skip_tests=["testBinaryUfuncs"])
253251
class ArrayTest(parameterized.TestCase):
254252

255253
@parameterized.product(scalar_type=INTN_TYPES)

ml_dtypes/tests/metadata_test.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,10 @@
1616

1717
from absl.testing import absltest
1818
import ml_dtypes
19-
# from multi_thread_utils import multi_threaded
19+
from multi_thread_utils import multi_threaded
2020

2121

22-
# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
23-
# @multi_threaded(num_workers=3)
22+
@multi_threaded(num_workers=3)
2423
class CustomFloatTest(absltest.TestCase):
2524

2625
def test_version_matches_package_metadata(self):

0 commit comments

Comments
 (0)