From c4419717554857ef1b220e4d3a4b5f94c19ac4e1 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 2 Jan 2025 14:59:27 +0000 Subject: [PATCH] isdtype() should raise if parameter is not a dtype --- .gitignore | 8 ++++---- array_api_strict/_data_type_functions.py | 3 +++ array_api_strict/tests/test_data_type_functions.py | 6 ++++-- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index dbce267..f69e911 100644 --- a/.gitignore +++ b/.gitignore @@ -128,12 +128,12 @@ ENV/ env.bak/ venv.bak/ -# Spyder project settings +# Project settings +.idea +.ropeproject .spyderproject .spyproject - -# Rope project settings -.ropeproject +.vscode # mkdocs documentation /site diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index acd8967..5af46d2 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -167,6 +167,9 @@ def isdtype( https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html for more details """ + if not isinstance(dtype, _DType): + raise TypeError(f"'dtype' must be a dtype, not a {type(dtype)!r}") + if isinstance(kind, tuple): # Disallow nested tuples if any(isinstance(k, tuple) for k in kind): diff --git a/array_api_strict/tests/test_data_type_functions.py b/array_api_strict/tests/test_data_type_functions.py index 40cab55..488eab7 100644 --- a/array_api_strict/tests/test_data_type_functions.py +++ b/array_api_strict/tests/test_data_type_functions.py @@ -31,15 +31,17 @@ def test_can_cast(from_, to, expected): def test_isdtype_strictness(): assert_raises(TypeError, lambda: isdtype(float64, 64)) assert_raises(ValueError, lambda: isdtype(float64, 'f8')) - assert_raises(TypeError, lambda: isdtype(float64, (('integral',),))) + assert_raises(TypeError, lambda: isdtype(float64, None)) + assert_raises(TypeError, lambda: isdtype(np.float64, float64)) + assert_raises(TypeError, lambda: isdtype(asarray(1.0), float64)) + with assert_raises(TypeError), warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") isdtype(float64, np.object_) assert len(w) == 1 assert issubclass(w[-1].category, UserWarning) - assert_raises(TypeError, lambda: isdtype(float64, None)) with assert_raises(TypeError), warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") isdtype(float64, np.float64)