Skip to content

Commit

Permalink
isdtype() should raise if parameter is not a dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jan 2, 2025
1 parent e4b6bfe commit c441971
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
8 changes: 4 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ ENV/
env.bak/
venv.bak/

# Spyder project settings
# Project settings
.idea
.ropeproject
.spyderproject
.spyproject

# Rope project settings
.ropeproject
.vscode

# mkdocs documentation
/site
Expand Down
3 changes: 3 additions & 0 deletions array_api_strict/_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ def isdtype(
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
for more details
"""
if not isinstance(dtype, _DType):
raise TypeError(f"'dtype' must be a dtype, not a {type(dtype)!r}")

if isinstance(kind, tuple):
# Disallow nested tuples
if any(isinstance(k, tuple) for k in kind):
Expand Down
6 changes: 4 additions & 2 deletions array_api_strict/tests/test_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,17 @@ def test_can_cast(from_, to, expected):
def test_isdtype_strictness():
assert_raises(TypeError, lambda: isdtype(float64, 64))
assert_raises(ValueError, lambda: isdtype(float64, 'f8'))

assert_raises(TypeError, lambda: isdtype(float64, (('integral',),)))
assert_raises(TypeError, lambda: isdtype(float64, None))
assert_raises(TypeError, lambda: isdtype(np.float64, float64))
assert_raises(TypeError, lambda: isdtype(asarray(1.0), float64))

with assert_raises(TypeError), warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
isdtype(float64, np.object_)
assert len(w) == 1
assert issubclass(w[-1].category, UserWarning)

assert_raises(TypeError, lambda: isdtype(float64, None))
with assert_raises(TypeError), warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
isdtype(float64, np.float64)
Expand Down

0 comments on commit c441971

Please sign in to comment.