|
27 | 27 | from absl.testing import absltest
|
28 | 28 | from absl.testing import parameterized
|
29 | 29 | import ml_dtypes
|
30 |
| -# from multi_thread_utils import multi_threaded |
| 30 | +from multi_thread_utils import multi_threaded |
31 | 31 | import numpy as np
|
32 | 32 |
|
33 | 33 | bfloat16 = ml_dtypes.bfloat16
|
@@ -221,12 +221,16 @@ def dtype_is_signed(dtype):
|
221 | 221 | }
|
222 | 222 |
|
223 | 223 |
|
224 |
| -# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release. |
225 | 224 | # pylint: disable=g-complex-comprehension
|
226 |
| -# @multi_threaded( |
227 |
| -# num_workers=3, |
228 |
| -# skip_tests=["testDiv", "testRoundTripNumpyTypes", "testRoundTripToNumpy"], |
229 |
| -# ) |
| 225 | +@multi_threaded( |
| 226 | + num_workers=3, |
| 227 | + skip_tests=[ |
| 228 | + "testDiv", |
| 229 | + "testPickleable", |
| 230 | + "testRoundTripNumpyTypes", |
| 231 | + "testRoundTripToNumpy", |
| 232 | + ], |
| 233 | +) |
230 | 234 | @parameterized.named_parameters(
|
231 | 235 | (
|
232 | 236 | {"testcase_name": "_" + dtype.__name__, "float_type": dtype}
|
@@ -661,21 +665,25 @@ def testDtypeFromString(self, float_type):
|
661 | 665 | ]
|
662 | 666 |
|
663 | 667 |
|
664 |
| -# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release. |
665 | 668 | # pylint: disable=g-complex-comprehension
|
666 |
| -# @multi_threaded( |
667 |
| -# num_workers=3, |
668 |
| -# skip_tests=[ |
669 |
| -# "testBinaryUfunc", |
670 |
| -# "testConformNumpyComplex", |
671 |
| -# "testFloordivCornerCases", |
672 |
| -# "testDivmodCornerCases", |
673 |
| -# "testSpacing", |
674 |
| -# "testUnaryUfunc", |
675 |
| -# "testCasts", |
676 |
| -# "testLdexp", |
677 |
| -# ], |
678 |
| -# ) |
| 669 | +@multi_threaded( |
| 670 | + num_workers=3, |
| 671 | + skip_tests=[ |
| 672 | + "testBinaryPredicateUfunc", |
| 673 | + "testBinaryUfunc", |
| 674 | + "testCasts", |
| 675 | + "testConformNumpyComplex", |
| 676 | + "testDivmod", |
| 677 | + "testDivmodCornerCases", |
| 678 | + "testFloordivCornerCases", |
| 679 | + "testFrexp", |
| 680 | + "testLdexp", |
| 681 | + "testModf", |
| 682 | + "testPredicateUfunc", |
| 683 | + "testSpacing", |
| 684 | + "testUnaryUfunc", |
| 685 | + ], |
| 686 | +) |
679 | 687 | @parameterized.named_parameters(
|
680 | 688 | (
|
681 | 689 | {"testcase_name": "_" + dtype.__name__, "float_type": dtype}
|
|
0 commit comments