|
| 1 | +import re |
1 | 2 | from collections.abc import Mapping
|
2 | 3 | from functools import lru_cache
|
3 |
| -from typing import Any, NamedTuple, Sequence, Tuple, Union |
| 4 | +from inspect import signature |
| 5 | +from typing import Any, Dict, NamedTuple, Sequence, Tuple, Union |
4 | 6 | from warnings import warn
|
5 | 7 |
|
6 | 8 | from . import _array_module as xp
|
7 | 9 | from ._array_module import _UndefinedStub
|
| 10 | +from .stubs import name_to_func |
8 | 11 | from .typing import DataType, ScalarType
|
9 | 12 |
|
10 | 13 | __all__ = [
|
@@ -242,67 +245,31 @@ def result_type(*dtypes: DataType):
|
242 | 245 | return result
|
243 | 246 |
|
244 | 247 |
|
245 |
| -func_in_dtypes = { |
246 |
| - # elementwise |
247 |
| - "abs": numeric_dtypes, |
248 |
| - "acos": float_dtypes, |
249 |
| - "acosh": float_dtypes, |
250 |
| - "add": numeric_dtypes, |
251 |
| - "asin": float_dtypes, |
252 |
| - "asinh": float_dtypes, |
253 |
| - "atan": float_dtypes, |
254 |
| - "atan2": float_dtypes, |
255 |
| - "atanh": float_dtypes, |
256 |
| - "bitwise_and": bool_and_all_int_dtypes, |
257 |
| - "bitwise_invert": bool_and_all_int_dtypes, |
258 |
| - "bitwise_left_shift": all_int_dtypes, |
259 |
| - "bitwise_or": bool_and_all_int_dtypes, |
260 |
| - "bitwise_right_shift": all_int_dtypes, |
261 |
| - "bitwise_xor": bool_and_all_int_dtypes, |
262 |
| - "ceil": numeric_dtypes, |
263 |
| - "cos": float_dtypes, |
264 |
| - "cosh": float_dtypes, |
265 |
| - "divide": float_dtypes, |
266 |
| - "equal": all_dtypes, |
267 |
| - "exp": float_dtypes, |
268 |
| - "expm1": float_dtypes, |
269 |
| - "floor": numeric_dtypes, |
270 |
| - "floor_divide": numeric_dtypes, |
271 |
| - "greater": numeric_dtypes, |
272 |
| - "greater_equal": numeric_dtypes, |
273 |
| - "isfinite": numeric_dtypes, |
274 |
| - "isinf": numeric_dtypes, |
275 |
| - "isnan": numeric_dtypes, |
276 |
| - "less": numeric_dtypes, |
277 |
| - "less_equal": numeric_dtypes, |
278 |
| - "log": float_dtypes, |
279 |
| - "logaddexp": float_dtypes, |
280 |
| - "log10": float_dtypes, |
281 |
| - "log1p": float_dtypes, |
282 |
| - "log2": float_dtypes, |
283 |
| - "logical_and": (xp.bool,), |
284 |
| - "logical_not": (xp.bool,), |
285 |
| - "logical_or": (xp.bool,), |
286 |
| - "logical_xor": (xp.bool,), |
287 |
| - "multiply": numeric_dtypes, |
288 |
| - "negative": numeric_dtypes, |
289 |
| - "not_equal": all_dtypes, |
290 |
| - "positive": numeric_dtypes, |
291 |
| - "pow": numeric_dtypes, |
292 |
| - "remainder": numeric_dtypes, |
293 |
| - "round": numeric_dtypes, |
294 |
| - "sign": numeric_dtypes, |
295 |
| - "sin": float_dtypes, |
296 |
| - "sinh": float_dtypes, |
297 |
| - "sqrt": float_dtypes, |
298 |
| - "square": numeric_dtypes, |
299 |
| - "subtract": numeric_dtypes, |
300 |
| - "tan": float_dtypes, |
301 |
| - "tanh": float_dtypes, |
302 |
| - "trunc": numeric_dtypes, |
303 |
| - # searching |
304 |
| - "where": all_dtypes, |
| 248 | +r_alias = re.compile("[aA]lias") |
| 249 | +r_in_dtypes = re.compile("x1?: array\n.+have an? (.+) data type.") |
| 250 | +r_int_note = re.compile( |
| 251 | + "If one or both of the input arrays have integer data types, " |
| 252 | + "the result is implementation-dependent" |
| 253 | +) |
| 254 | +category_to_dtypes = { |
| 255 | + "boolean": (xp.bool,), |
| 256 | + "integer": all_int_dtypes, |
| 257 | + "floating-point": float_dtypes, |
| 258 | + "numeric": numeric_dtypes, |
| 259 | + "integer or boolean": bool_and_all_int_dtypes, |
305 | 260 | }
|
| 261 | +func_in_dtypes: Dict[str, Tuple[DataType, ...]] = {} |
| 262 | +for name, func in name_to_func.items(): |
| 263 | + if m := r_in_dtypes.search(func.__doc__): |
| 264 | + dtype_category = m.group(1) |
| 265 | + if dtype_category == "numeric" and r_int_note.search(func.__doc__): |
| 266 | + dtype_category = "floating-point" |
| 267 | + dtypes = category_to_dtypes[dtype_category] |
| 268 | + func_in_dtypes[name] = dtypes |
| 269 | + elif any("x" in name for name in signature(func).parameters.keys()): |
| 270 | + func_in_dtypes[name] = all_dtypes |
| 271 | +# See https://github.com/data-apis/array-api/pull/413 |
| 272 | +func_in_dtypes["expm1"] = float_dtypes |
306 | 273 |
|
307 | 274 |
|
308 | 275 | func_returns_bool = {
|
@@ -365,6 +332,8 @@ def result_type(*dtypes: DataType):
|
365 | 332 | "trunc": False,
|
366 | 333 | # searching
|
367 | 334 | "where": False,
|
| 335 | + # linalg |
| 336 | + "matmul": False, |
368 | 337 | }
|
369 | 338 |
|
370 | 339 |
|
@@ -408,7 +377,7 @@ def result_type(*dtypes: DataType):
|
408 | 377 | "__gt__": "greater",
|
409 | 378 | "__le__": "less_equal",
|
410 | 379 | "__lt__": "less",
|
411 |
| - # '__matmul__': 'matmul', # TODO: support matmul |
| 380 | + "__matmul__": "matmul", |
412 | 381 | "__mod__": "remainder",
|
413 | 382 | "__mul__": "multiply",
|
414 | 383 | "__ne__": "not_equal",
|
@@ -440,6 +409,14 @@ def result_type(*dtypes: DataType):
|
440 | 409 | func_returns_bool[iop] = func_returns_bool[op]
|
441 | 410 |
|
442 | 411 |
|
| 412 | +func_in_dtypes["__bool__"] = (xp.bool,) |
| 413 | +func_in_dtypes["__int__"] = all_int_dtypes |
| 414 | +func_in_dtypes["__index__"] = all_int_dtypes |
| 415 | +func_in_dtypes["__float__"] = float_dtypes |
| 416 | +func_in_dtypes["from_dlpack"] = numeric_dtypes |
| 417 | +func_in_dtypes["__dlpack__"] = numeric_dtypes |
| 418 | + |
| 419 | + |
443 | 420 | @lru_cache
|
444 | 421 | def fmt_types(types: Tuple[Union[DataType, ScalarType], ...]) -> str:
|
445 | 422 | f_types = []
|
|
0 commit comments