diff --git a/docs/source/user/jitclass.rst b/docs/source/user/jitclass.rst index 9576b3ab811..73cf406d992 100644 --- a/docs/source/user/jitclass.rst +++ b/docs/source/user/jitclass.rst @@ -182,6 +182,7 @@ compiled functions: * calling methods (e.g. ``mybag.increment(3)``); * calling static methods as instance attributes (e.g. ``mybag.add(1, 1)``); * calling static methods as class attributes (e.g. ``Bag.add(1, 2)``); +* using select dunder methods (e.g. ``__add__`` with ``mybag + otherbag``); Using jitclasses in Numba compiled function is more efficient. Short methods can be inlined (at the discretion of LLVM inliner). @@ -197,6 +198,61 @@ class definition (i.e. code cannot call ``Bag.add()`` from within another method of ``Bag``). +Supported dunder methods +------------------------ + +The following dunder methods may be defined for jitclasses: + +* ``__abs__`` +* ``__bool__`` +* ``__complex__`` +* ``__contains__`` +* ``__float__`` +* ``__getitem__`` +* ``__hash__`` +* ``__index__`` +* ``__int__`` +* ``__len__`` +* ``__setitem__`` +* ``__str__`` +* ``__eq__`` +* ``__ne__`` +* ``__ge__`` +* ``__gt__`` +* ``__le__`` +* ``__lt__`` +* ``__add__`` +* ``__floordiv__`` +* ``__lshift__`` +* ``__mod__`` +* ``__mul__`` +* ``__neg__`` +* ``__pos__`` +* ``__pow__`` +* ``__rshift__`` +* ``__sub__`` +* ``__truediv__`` +* ``__and__`` +* ``__or__`` +* ``__xor__`` +* ``__iadd__`` +* ``__ifloordiv__`` +* ``__ilshift__`` +* ``__imod__`` +* ``__imul__`` +* ``__ipow__`` +* ``__irshift__`` +* ``__isub__`` +* ``__itruediv__`` +* ``__iand__`` +* ``__ior__`` +* ``__ixor__`` + +Refer to the `Python Data Model documentation +`_ for descriptions of +these methods. + + Limitations =========== diff --git a/numba/core/typing/builtins.py b/numba/core/typing/builtins.py index eb883cb5f81..40b46a818d4 100644 --- a/numba/core/typing/builtins.py +++ b/numba/core/typing/builtins.py @@ -992,7 +992,7 @@ def generic(self, args, kws): if len(args) == 1: [arg] = args if arg not in types.number_domain: - raise TypeError("complex() only support for numbers") + raise errors.NumbaTypeError("complex() only support for numbers") if arg == types.float32: return signature(types.complex64, arg) else: @@ -1002,7 +1002,7 @@ def generic(self, args, kws): [real, imag] = args if (real not in types.number_domain or imag not in types.number_domain): - raise TypeError("complex() only support for numbers") + raise errors.NumbaTypeError("complex() only support for numbers") if real == imag == types.float32: return signature(types.complex64, real, imag) else: diff --git a/numba/experimental/jitclass/__init__.py b/numba/experimental/jitclass/__init__.py index 97c1903496b..981282f6536 100644 --- a/numba/experimental/jitclass/__init__.py +++ b/numba/experimental/jitclass/__init__.py @@ -1,2 +1,3 @@ from numba.experimental.jitclass.decorators import jitclass from numba.experimental.jitclass import boxing # Has import-time side effect +from numba.experimental.jitclass import overloads # Has import-time side effect diff --git a/numba/experimental/jitclass/base.py b/numba/experimental/jitclass/base.py index 0153edc5419..d61dc60e3f4 100644 --- a/numba/experimental/jitclass/base.py +++ b/numba/experimental/jitclass/base.py @@ -293,6 +293,12 @@ def _drop_ignored_attrs(dct): elif getattr(v, '__objclass__', None) is object: drop.add(k) + # If a class defines __eq__ but not __hash__, __hash__ is implicitly set to + # None. This is a class member, and class members are not presently + # supported. + if '__hash__' in dct and dct['__hash__'] is None: + drop.add('__hash__') + for k in drop: del dct[k] diff --git a/numba/experimental/jitclass/boxing.py b/numba/experimental/jitclass/boxing.py index e320acfbb4a..7d70f561a1b 100644 --- a/numba/experimental/jitclass/boxing.py +++ b/numba/experimental/jitclass/boxing.py @@ -8,8 +8,10 @@ from llvmlite import ir from numba.core import types, cgutils +from numba.core.decorators import njit from numba.core.pythonapi import box, unbox, NativeValue -from numba import njit +from numba.core.typing.typeof import typeof_impl +from numba.experimental.jitclass import _box _getter_code_template = """ @@ -95,18 +97,66 @@ def _specialize_box(typ): doc = getattr(imp, '__doc__', None) dct[field] = property(getter, setter, doc=doc) # Inject methods as class members + supported_dunders = { + "__abs__", + "__bool__", + "__complex__", + "__contains__", + "__float__", + "__getitem__", + "__hash__", + "__index__", + "__int__", + "__len__", + "__setitem__", + "__str__", + "__eq__", + "__ne__", + "__ge__", + "__gt__", + "__le__", + "__lt__", + "__add__", + "__floordiv__", + "__lshift__", + "__mod__", + "__mul__", + "__neg__", + "__pos__", + "__pow__", + "__rshift__", + "__sub__", + "__truediv__", + "__and__", + "__or__", + "__xor__", + "__iadd__", + "__ifloordiv__", + "__ilshift__", + "__imod__", + "__imul__", + "__ipow__", + "__irshift__", + "__isub__", + "__itruediv__", + "__iand__", + "__ior__", + "__ixor__", + } for name, func in typ.methods.items(): - if (name == "__getitem__" or name == "__setitem__") or \ - (not (name.startswith('__') and name.endswith('__'))): - - dct[name] = _generate_method(name, func) + if ( + name.startswith("__") + and name.endswith("__") + and name not in supported_dunders + ): + continue + dct[name] = _generate_method(name, func) # Inject static methods as class members for name, func in typ.static_methods.items(): dct[name] = _generate_method(name, func) # Create subclass - from numba.experimental.jitclass import _box subcls = type(typ.classname, (_box.Box,), dct) # Store to cache _cache_specialized_box[typ] = subcls @@ -159,7 +209,6 @@ def set_member(member_offset, value): casted = c.builder.bitcast(ptr, llvoidptr.as_pointer()) c.builder.store(value, casted) - from numba.experimental.jitclass import _box set_member(_box.box_meminfoptr_offset, addr_meminfo) set_member(_box.box_dataptr_offset, addr_data) return box @@ -179,7 +228,6 @@ def access_member(member_offset): inst = struct_cls(c.context, c.builder) # load from Python object - from numba.experimental.jitclass import _box ptr_meminfo = access_member(_box.box_meminfoptr_offset) ptr_dataptr = access_member(_box.box_dataptr_offset) @@ -192,3 +240,17 @@ def access_member(member_offset): c.context.nrt.incref(c.builder, typ, ret) return NativeValue(ret, is_error=c.pyapi.c_api_error()) + + +# Add a typeof_impl implementation for boxed jitclasses to short-circut the +# various tests in typeof. This is needed for jitclasses which implement a +# custom hash method. Without this, typeof_impl will return None, and one of the +# later attempts to determine the type of the jitclass (before checking for +# _numba_type_) will look up the object in a dictionary, triggering the hash +# method. This will cause the dispatcher to determine the call signature of the +# jit decorated obj.__hash__ method, which will call typeof(obj), and thus +# infinite loop. +# This implementation is here instead of in typeof.py to avoid circular imports. +@typeof_impl.register(_box.Box) +def _typeof_jitclass_box(val, c): + return getattr(type(val), "_numba_type_") diff --git a/numba/experimental/jitclass/overloads.py b/numba/experimental/jitclass/overloads.py new file mode 100644 index 00000000000..cf0e0ff3e08 --- /dev/null +++ b/numba/experimental/jitclass/overloads.py @@ -0,0 +1,227 @@ +""" +Overloads for ClassInstanceType for built-in functions that call dunder methods +on an object. +""" +from functools import wraps +import inspect +import operator + +from numba.core.extending import overload +from numba.core.types import ClassInstanceType +from numba.core.utils import PYVERSION + + +def _get_args(n_args): + assert n_args in (1, 2) + return list("xy")[:n_args] + + +def class_instance_overload(target): + """ + Decorator to add an overload for target that applies when the first argument + is a ClassInstanceType. + """ + def decorator(func): + @wraps(func) + def wrapped(*args, **kwargs): + if not isinstance(args[0], ClassInstanceType): + return + return func(*args, **kwargs) + + params = list(inspect.signature(wrapped).parameters) + assert params == _get_args(len(params)) + return overload(target)(wrapped) + + return decorator + + +def extract_template(template, name): + """ + Extract a code-generated function from a string template. + """ + namespace = {} + exec(template, namespace) + return namespace[name] + + +def register_simple_overload(func, *attrs, n_args=1,): + """ + Register an overload for func that checks for methods __attr__ for each + attr in attrs. + """ + # Use a template to set the signature correctly. + arg_names = _get_args(n_args) + template = f""" +def func({','.join(arg_names)}): + pass +""" + + @wraps(extract_template(template, "func")) + def overload_func(*args, **kwargs): + options = [ + try_call_method(args[0], f"__{attr}__", n_args) + for attr in attrs + ] + return take_first(*options) + + return class_instance_overload(func)(overload_func) + + +def try_call_method(cls_type, method, n_args=1): + """ + If method is defined for cls_type, return a callable that calls this method. + If not, return None. + """ + if method in cls_type.jit_methods: + arg_names = _get_args(n_args) + template = f""" +def func({','.join(arg_names)}): + return {arg_names[0]}.{method}({','.join(arg_names[1:])}) +""" + return extract_template(template, "func") + + +def take_first(*options): + """ + Take the first non-None option. + """ + assert all(o is None or inspect.isfunction(o) for o in options), options + for o in options: + if o is not None: + return o + + +@class_instance_overload(bool) +def class_bool(x): + using_bool_impl = try_call_method(x, "__bool__") + + if '__len__' in x.jit_methods: + def using_len_impl(x): + return bool(len(x)) + else: + using_len_impl = None + + always_true_impl = lambda x: True + + return take_first(using_bool_impl, using_len_impl, always_true_impl) + + +@class_instance_overload(complex) +def class_complex(x): + return take_first( + try_call_method(x, "__complex__"), + lambda x: complex(float(x)) + ) + + +@class_instance_overload(operator.contains) +def class_contains(x, y): + # https://docs.python.org/3/reference/expressions.html#membership-test-operations + return try_call_method(x, "__contains__", 2) + # TODO: use __iter__ if defined. + + +@class_instance_overload(float) +def class_float(x): + options = [try_call_method(x, "__float__")] + + if ( + PYVERSION >= (3, 8) + and "__index__" in x.jit_methods + ): + options.append(lambda x: float(x.__index__())) + + return take_first(*options) + + +@class_instance_overload(int) +def class_int(x): + options = [try_call_method(x, "__int__")] + + if PYVERSION >= (3, 8): + options.append(try_call_method(x, "__index__")) + + return take_first(*options) + + +@class_instance_overload(str) +def class_str(x): + return take_first( + try_call_method(x, "__str__"), + lambda x: repr(x), + ) + + +@class_instance_overload(operator.ne) +def class_ne(x, y): + # This doesn't use register_reflected_overload like the other operators + # because it falls back to inverting __eq__ rather than reflecting its + # arguments (as per the definition of the Python data model). + return take_first( + try_call_method(x, "__ne__", 2), + lambda x, y: not (x == y), + ) + + +def register_reflected_overload(func, meth_forward, meth_reflected): + def class_lt(x, y): + normal_impl = try_call_method(x, f"__{meth_forward}__", 2) + + if f"__{meth_reflected}__" in y.jit_methods: + def reflected_impl(x, y): + return y > x + else: + reflected_impl = None + + return take_first(normal_impl, reflected_impl) + + class_instance_overload(func)(class_lt) + + +register_simple_overload(abs, "abs") +register_simple_overload(len, "len") +register_simple_overload(hash, "hash") + +# Comparison operators. +register_reflected_overload(operator.ge, "ge", "le") +register_reflected_overload(operator.gt, "gt", "lt") +register_reflected_overload(operator.le, "le", "ge") +register_reflected_overload(operator.lt, "lt", "gt") + +# Note that eq is missing support for fallback to `x is y`, but `is` and +# `operator.is` are presently unsupported in general. +register_reflected_overload(operator.eq, "eq", "eq") + +# Arithmetic operators. +register_simple_overload(operator.add, "add", n_args=2) +register_simple_overload(operator.floordiv, "floordiv", n_args=2) +register_simple_overload(operator.lshift, "lshift", n_args=2) +register_simple_overload(operator.mul, "mul", n_args=2) +register_simple_overload(operator.mod, "mod", n_args=2) +register_simple_overload(operator.neg, "neg") +register_simple_overload(operator.pos, "pos") +register_simple_overload(operator.pow, "pow", n_args=2) +register_simple_overload(operator.rshift, "rshift", n_args=2) +register_simple_overload(operator.sub, "sub", n_args=2) +register_simple_overload(operator.truediv, "truediv", n_args=2) + +# Inplace arithmetic operators. +register_simple_overload(operator.iadd, "iadd", "add", n_args=2) +register_simple_overload(operator.ifloordiv, "ifloordiv", "floordiv", n_args=2) +register_simple_overload(operator.ilshift, "ilshift", "lshift", n_args=2) +register_simple_overload(operator.imul, "imul", "mul", n_args=2) +register_simple_overload(operator.imod, "imod", "mod", n_args=2) +register_simple_overload(operator.ipow, "ipow", "pow", n_args=2) +register_simple_overload(operator.irshift, "irshift", "rshift", n_args=2) +register_simple_overload(operator.isub, "isub", "sub", n_args=2) +register_simple_overload(operator.itruediv, "itruediv", "truediv", n_args=2) + +# Logical operators. +register_simple_overload(operator.and_, "and", n_args=2) +register_simple_overload(operator.or_, "or", n_args=2) +register_simple_overload(operator.xor, "xor", n_args=2) + +# Inplace logical operators. +register_simple_overload(operator.iand, "iand", "and", n_args=2) +register_simple_overload(operator.ior, "ior", "or", n_args=2) +register_simple_overload(operator.ixor, "ixor", "xor", n_args=2) diff --git a/numba/tests/test_jitclasses.py b/numba/tests/test_jitclasses.py index a88c50a4e49..f9dc74dacb8 100644 --- a/numba/tests/test_jitclasses.py +++ b/numba/tests/test_jitclasses.py @@ -1,4 +1,5 @@ import ctypes +import itertools import pickle import random import typing as pt @@ -11,8 +12,9 @@ njit, optional, typeof) from numba.core import errors, types from numba.core.dispatcher import Dispatcher -from numba.core.errors import LoweringError +from numba.core.errors import LoweringError, TypingError from numba.core.runtime.nrt import MemInfo +from numba.core.utils import PYVERSION from numba.experimental import jitclass from numba.experimental.jitclass import _box from numba.experimental.jitclass.base import JitClassType @@ -1153,5 +1155,724 @@ def test_jitclass_isinstance(obj): self.assertEqual(pyfunc(0), cfunc(0)) +class TestJitClassOverloads(MemoryLeakMixin, TestCase): + + class PyList: + def __init__(self): + self.x = [0] + + def append(self, y): + self.x.append(y) + + def clear(self): + self.x.clear() + + def __abs__(self): + return len(self.x) * 7 + + def __bool__(self): + return len(self.x) % 3 != 0 + + def __complex__(self): + c = complex(2) + if self.x: + c += self.x[0] + return c + + def __contains__(self, y): + return y in self.x + + def __float__(self): + f = 3.1415 + if self.x: + f += self.x[0] + return f + + def __int__(self): + i = 5 + if self.x: + i += self.x[0] + return i + + def __len__(self): + return len(self.x) + 1 + + def __str__(self): + if len(self.x) == 0: + return "PyList empty" + else: + return "PyList non-empty" + + @staticmethod + def get_int_wrapper(): + @jitclass([("x", types.intp)]) + class IntWrapper: + def __init__(self, value): + self.x = value + + def __eq__(self, other): + return self.x == other.x + + def __hash__(self): + return self.x + + def __lshift__(self, other): + return IntWrapper(self.x << other.x) + + def __rshift__(self, other): + return IntWrapper(self.x >> other.x) + + def __and__(self, other): + return IntWrapper(self.x & other.x) + + def __or__(self, other): + return IntWrapper(self.x | other.x) + + def __xor__(self, other): + return IntWrapper(self.x ^ other.x) + + return IntWrapper + + @staticmethod + def get_float_wrapper(): + @jitclass([("x", types.float64)]) + class FloatWrapper: + + def __init__(self, value): + self.x = value + + def __eq__(self, other): + return self.x == other.x + + def __hash__(self): + return self.x + + def __ge__(self, other): + return self.x >= other.x + + def __gt__(self, other): + return self.x > other.x + + def __le__(self, other): + return self.x <= other.x + + def __lt__(self, other): + return self.x < other.x + + def __add__(self, other): + return FloatWrapper(self.x + other.x) + + def __floordiv__(self, other): + return FloatWrapper(self.x // other.x) + + def __mod__(self, other): + return FloatWrapper(self.x % other.x) + + def __mul__(self, other): + return FloatWrapper(self.x * other.x) + + def __neg__(self, other): + return FloatWrapper(-self.x) + + def __pos__(self, other): + return FloatWrapper(+self.x) + + def __pow__(self, other): + return FloatWrapper(self.x ** other.x) + + def __sub__(self, other): + return FloatWrapper(self.x - other.x) + + def __truediv__(self, other): + return FloatWrapper(self.x / other.x) + + return FloatWrapper + + def assertSame(self, first, second, msg=None): + self.assertEqual(type(first), type(second), msg=msg) + self.assertEqual(first, second, msg=msg) + + def test_overloads(self): + # Check that the dunder methods are exposed on ClassInstanceType. + + JitList = jitclass({"x": types.List(types.intp)})(self.PyList) + + py_funcs = [ + lambda x: abs(x), + lambda x: x.__abs__(), + lambda x: bool(x), + lambda x: x.__bool__(), + lambda x: complex(x), + lambda x: x.__complex__(), + lambda x: 0 in x, # contains + lambda x: x.__contains__(0), + lambda x: float(x), + lambda x: x.__float__(), + lambda x: int(x), + lambda x: x.__int__(), + lambda x: len(x), + lambda x: x.__len__(), + lambda x: str(x), + lambda x: x.__str__(), + lambda x: 1 if x else 0, # truth + ] + jit_funcs = [njit(f) for f in py_funcs] + + py_list = self.PyList() + jit_list = JitList() + for py_f, jit_f in zip(py_funcs, jit_funcs): + self.assertSame(py_f(py_list), py_f(jit_list)) + self.assertSame(py_f(py_list), jit_f(jit_list)) + + py_list.append(2) + jit_list.append(2) + for py_f, jit_f in zip(py_funcs, jit_funcs): + self.assertSame(py_f(py_list), py_f(jit_list)) + self.assertSame(py_f(py_list), jit_f(jit_list)) + + py_list.append(-5) + jit_list.append(-5) + for py_f, jit_f in zip(py_funcs, jit_funcs): + self.assertSame(py_f(py_list), py_f(jit_list)) + self.assertSame(py_f(py_list), jit_f(jit_list)) + + py_list.clear() + jit_list.clear() + for py_f, jit_f in zip(py_funcs, jit_funcs): + self.assertSame(py_f(py_list), py_f(jit_list)) + self.assertSame(py_f(py_list), jit_f(jit_list)) + + def test_bool_fallback(self): + + def py_b(x): + return bool(x) + + jit_b = njit(py_b) + + @jitclass([("x", types.List(types.intp))]) + class LenClass: + def __init__(self, x): + self.x = x + + def __len__(self): + return len(self.x) % 4 + + def append(self, y): + self.x.append(y) + + def pop(self): + self.x.pop(0) + + obj = LenClass([1, 2, 3]) + self.assertTrue(py_b(obj)) + self.assertTrue(jit_b(obj)) + + obj.append(4) + self.assertFalse(py_b(obj)) + self.assertFalse(jit_b(obj)) + + obj.pop() + self.assertTrue(py_b(obj)) + self.assertTrue(jit_b(obj)) + + @jitclass([("y", types.float64)]) + class NormalClass: + def __init__(self, y): + self.y = y + + obj = NormalClass(0) + self.assertTrue(py_b(obj)) + self.assertTrue(jit_b(obj)) + + def test_numeric_fallback(self): + def py_c(x): + return complex(x) + + def py_f(x): + return float(x) + + def py_i(x): + return int(x) + + jit_c = njit(py_c) + jit_f = njit(py_f) + jit_i = njit(py_i) + + @jitclass([]) + class FloatClass: + def __init__(self): + pass + + def __float__(self): + return 3.1415 + + obj = FloatClass() + self.assertSame(py_c(obj), complex(3.1415)) + self.assertSame(jit_c(obj), complex(3.1415)) + self.assertSame(py_f(obj), 3.1415) + self.assertSame(jit_f(obj), 3.1415) + + with self.assertRaises(TypeError) as e: + py_i(obj) + self.assertIn("int", str(e.exception)) + with self.assertRaises(TypingError) as e: + jit_i(obj) + self.assertIn("int", str(e.exception)) + + @jitclass([]) + class IntClass: + def __init__(self): + pass + + def __int__(self): + return 7 + + obj = IntClass() + self.assertSame(py_i(obj), 7) + self.assertSame(jit_i(obj), 7) + + with self.assertRaises(TypeError) as e: + py_c(obj) + self.assertIn("complex", str(e.exception)) + with self.assertRaises(TypingError) as e: + jit_c(obj) + self.assertIn("complex", str(e.exception)) + with self.assertRaises(TypeError) as e: + py_f(obj) + self.assertIn("float", str(e.exception)) + with self.assertRaises(TypingError) as e: + jit_f(obj) + self.assertIn("float", str(e.exception)) + + @jitclass([]) + class IndexClass: + def __init__(self): + pass + + def __index__(self): + return 1 + + obj = IndexClass() + + if PYVERSION >= (3, 8): + self.assertSame(py_c(obj), complex(1)) + self.assertSame(jit_c(obj), complex(1)) + self.assertSame(py_f(obj), 1.) + self.assertSame(jit_f(obj), 1.) + self.assertSame(py_i(obj), 1) + self.assertSame(jit_i(obj), 1) + else: + with self.assertRaises(TypeError) as e: + py_c(obj) + self.assertIn("complex", str(e.exception)) + with self.assertRaises(TypingError) as e: + jit_c(obj) + self.assertIn("complex", str(e.exception)) + with self.assertRaises(TypeError) as e: + py_f(obj) + self.assertIn("float", str(e.exception)) + with self.assertRaises(TypingError) as e: + jit_f(obj) + self.assertIn("float", str(e.exception)) + with self.assertRaises(TypeError) as e: + py_i(obj) + self.assertIn("int", str(e.exception)) + with self.assertRaises(TypingError) as e: + jit_i(obj) + self.assertIn("int", str(e.exception)) + + @jitclass([]) + class FloatIntIndexClass: + def __init__(self): + pass + + def __float__(self): + return 3.1415 + + def __int__(self): + return 7 + + def __index__(self): + return 1 + + obj = FloatIntIndexClass() + self.assertSame(py_c(obj), complex(3.1415)) + self.assertSame(jit_c(obj), complex(3.1415)) + self.assertSame(py_f(obj), 3.1415) + self.assertSame(jit_f(obj), 3.1415) + self.assertSame(py_i(obj), 7) + self.assertSame(jit_i(obj), 7) + + def test_arithmetic_logical(self): + IntWrapper = self.get_int_wrapper() + FloatWrapper = self.get_float_wrapper() + + float_py_funcs = [ + lambda x, y: x == y, + lambda x, y: x != y, + lambda x, y: x >= y, + lambda x, y: x > y, + lambda x, y: x <= y, + lambda x, y: x < y, + lambda x, y: x + y, + lambda x, y: x // y, + lambda x, y: x % y, + lambda x, y: x * y, + lambda x, y: x ** y, + lambda x, y: x - y, + lambda x, y: x / y, + ] + int_py_funcs = [ + lambda x, y: x == y, + lambda x, y: x != y, + lambda x, y: x << y, + lambda x, y: x >> y, + lambda x, y: x & y, + lambda x, y: x | y, + lambda x, y: x ^ y, + ] + + test_values = [ + (0.0, 2.0), + (1.234, 3.1415), + (13.1, 1.01), + ] + + def unwrap(value): + return getattr(value, "x", value) + + for jit_f, (x, y) in itertools.product( + map(njit, float_py_funcs), test_values): + + py_f = jit_f.py_func + + expected = py_f(x, y) + jit_x = FloatWrapper(x) + jit_y = FloatWrapper(y) + + check = ( + self.assertEqual + if type(expected) is not float + else self.assertAlmostEqual + ) + check(expected, jit_f(x, y)) + check(expected, unwrap(py_f(jit_x, jit_y))) + check(expected, unwrap(jit_f(jit_x, jit_y))) + + for jit_f, (x, y) in itertools.product( + map(njit, int_py_funcs), test_values): + + py_f = jit_f.py_func + x, y = int(x), int(y) + + expected = py_f(x, y) + jit_x = IntWrapper(x) + jit_y = IntWrapper(y) + + self.assertEqual(expected, jit_f(x, y)) + self.assertEqual(expected, unwrap(py_f(jit_x, jit_y))) + self.assertEqual(expected, unwrap(jit_f(jit_x, jit_y))) + + def test_arithmetic_logical_inplace(self): + + # If __i*__ methods are not defined, should fall back to normal methods. + JitIntWrapper = self.get_int_wrapper() + JitFloatWrapper = self.get_float_wrapper() + + PyIntWrapper = JitIntWrapper.mro()[1] + PyFloatWrapper = JitFloatWrapper.mro()[1] + + @jitclass([("x", types.intp)]) + class JitIntUpdateWrapper(PyIntWrapper): + def __init__(self, value): + self.x = value + + def __ilshift__(self, other): + return JitIntUpdateWrapper(self.x << other.x) + + def __irshift__(self, other): + return JitIntUpdateWrapper(self.x >> other.x) + + def __iand__(self, other): + return JitIntUpdateWrapper(self.x & other.x) + + def __ior__(self, other): + return JitIntUpdateWrapper(self.x | other.x) + + def __ixor__(self, other): + return JitIntUpdateWrapper(self.x ^ other.x) + + @jitclass({"x": types.float64}) + class JitFloatUpdateWrapper(PyFloatWrapper): + + def __init__(self, value): + self.x = value + + def __iadd__(self, other): + return JitFloatUpdateWrapper(self.x + 2.718 * other.x) + + def __ifloordiv__(self, other): + return JitFloatUpdateWrapper(self.x * 2.718 // other.x) + + def __imod__(self, other): + return JitFloatUpdateWrapper(self.x % (other.x + 1)) + + def __imul__(self, other): + return JitFloatUpdateWrapper(self.x * other.x + 1) + + def __ipow__(self, other): + return JitFloatUpdateWrapper(self.x ** other.x + 1) + + def __isub__(self, other): + return JitFloatUpdateWrapper(self.x - 3.1415 * other.x) + + def __itruediv__(self, other): + return JitFloatUpdateWrapper((self.x + 1) / other.x) + + PyIntUpdateWrapper = JitIntUpdateWrapper.mro()[1] + PyFloatUpdateWrapper = JitFloatUpdateWrapper.mro()[1] + + def get_update_func(op): + template = f""" +def f(x, y): + x {op}= y + return x +""" + namespace = {} + exec(template, namespace) + return namespace["f"] + + float_py_funcs = [get_update_func(op) for op in [ + "+", "//", "%", "*", "**", "-", "/", + ]] + int_py_funcs = [get_update_func(op) for op in [ + "<<", ">>", "&", "|", "^", + ]] + + test_values = [ + (0.0, 2.0), + (1.234, 3.1415), + (13.1, 1.01), + ] + + for jit_f, (py_cls, jit_cls), (x, y) in itertools.product( + map(njit, float_py_funcs), + [ + (PyFloatWrapper, JitFloatWrapper), + (PyFloatUpdateWrapper, JitFloatUpdateWrapper) + ], + test_values): + py_f = jit_f.py_func + + expected = py_f(py_cls(x), py_cls(y)).x + self.assertAlmostEqual(expected, py_f(jit_cls(x), jit_cls(y)).x) + self.assertAlmostEqual(expected, jit_f(jit_cls(x), jit_cls(y)).x) + + for jit_f, (py_cls, jit_cls), (x, y) in itertools.product( + map(njit, int_py_funcs), + [ + (PyIntWrapper, JitIntWrapper), + (PyIntUpdateWrapper, JitIntUpdateWrapper) + ], + test_values): + x, y = int(x), int(y) + py_f = jit_f.py_func + + expected = py_f(py_cls(x), py_cls(y)).x + self.assertEqual(expected, py_f(jit_cls(x), jit_cls(y)).x) + self.assertEqual(expected, jit_f(jit_cls(x), jit_cls(y)).x) + + def test_hash_eq_ne(self): + + class HashEqTest: + x: int + + def __init__(self, x): + self.x = x + + def __hash__(self): + return self.x % 10 + + def __eq__(self, o): + return (self.x - o.x) % 20 == 0 + + class HashEqNeTest(HashEqTest): + def __ne__(self, o): + return (self.x - o.x) % 20 > 1 + + def py_hash(x): + return hash(x) + + def py_eq(x, y): + return x == y + + def py_ne(x, y): + return x != y + + def identity_decorator(f): + return f + + comparisons = [ + (0, 1), # Will give different ne results. + (2, 22), + (7, 10), + (3, 3), + ] + + for base_cls, use_jit in itertools.product( + [HashEqTest, HashEqNeTest], [False, True] + ): + decorator = njit if use_jit else identity_decorator + hash_func = decorator(py_hash) + eq_func = decorator(py_eq) + ne_func = decorator(py_ne) + + jit_cls = jitclass(base_cls) + + for v in [0, 2, 10, 24, -8]: + self.assertEqual(hash_func(jit_cls(v)), v % 10) + + for x, y in comparisons: + self.assertEqual( + eq_func(jit_cls(x), jit_cls(y)), + base_cls(x) == base_cls(y), + ) + self.assertEqual( + ne_func(jit_cls(x), jit_cls(y)), + base_cls(x) != base_cls(y), + ) + + def test_bool_fallback_len(self): + # Check that the fallback to using len(obj) to determine truth of an + # object is implemented correctly as per + # https://docs.python.org/3/library/stdtypes.html#truth-value-testing + # + # Relevant points: + # + # "By default, an object is considered true unless its class defines + # either a __bool__() method that returns False or a __len__() method + # that returns zero, when called with the object." + # + # and: + # + # "Operations and built-in functions that have a Boolean result always + # return 0 or False for false and 1 or True for true, unless otherwise + # stated." + + class NoBoolHasLen: + def __init__(self, val): + self.val = val + + def __len__(self): + return self.val + + def get_bool(self): + return bool(self) + + py_class = NoBoolHasLen + jitted_class = jitclass([('val', types.int64)])(py_class) + + py_class_0_bool = py_class(0).get_bool() + py_class_2_bool = py_class(2).get_bool() + jitted_class_0_bool = jitted_class(0).get_bool() + jitted_class_2_bool = jitted_class(2).get_bool() + + # Truth values from bool(obj) should be equal + self.assertEqual(py_class_0_bool, jitted_class_0_bool) + self.assertEqual(py_class_2_bool, jitted_class_2_bool) + + # Truth values from bool(obj) should be the same type + self.assertEqual(type(py_class_0_bool), type(jitted_class_0_bool)) + self.assertEqual(type(py_class_2_bool), type(jitted_class_2_bool)) + + def test_bool_fallback_default(self): + # Similar to test_bool_fallback, but checks the case where there is no + # __bool__() or __len__() defined, so the object should always be True. + + class NoBoolNoLen: + def __init__(self): + pass + + def get_bool(self): + return bool(self) + + py_class = NoBoolNoLen + jitted_class = jitclass([])(py_class) + + py_class_bool = py_class().get_bool() + jitted_class_bool = jitted_class().get_bool() + + # Truth values from bool(obj) should be equal + self.assertEqual(py_class_bool, jitted_class_bool) + + # Truth values from bool(obj) should be the same type + self.assertEqual(type(py_class_bool), type(jitted_class_bool)) + + def test_operator_reflection(self): + class OperatorsDefined: + def __init__(self, x): + self.x = x + + def __eq__(self, other): + return self.x == other.x + + def __le__(self, other): + return self.x <= other.x + + def __lt__(self, other): + return self.x < other.x + + def __ge__(self, other): + return self.x >= other.x + + def __gt__(self, other): + return self.x > other.x + + class NoOperatorsDefined: + def __init__(self, x): + self.x = x + + spec = [('x', types.int32)] + JitOperatorsDefined = jitclass(spec)(OperatorsDefined) + JitNoOperatorsDefined = jitclass(spec)(NoOperatorsDefined) + + py_ops_defined = OperatorsDefined(2) + py_ops_not_defined = NoOperatorsDefined(3) + + jit_ops_defined = JitOperatorsDefined(2) + jit_ops_not_defined = JitNoOperatorsDefined(3) + + self.assertEqual(py_ops_not_defined == py_ops_defined, + jit_ops_not_defined == jit_ops_defined) + + self.assertEqual(py_ops_not_defined <= py_ops_defined, + jit_ops_not_defined <= jit_ops_defined) + + self.assertEqual(py_ops_not_defined < py_ops_defined, + jit_ops_not_defined < jit_ops_defined) + + self.assertEqual(py_ops_not_defined >= py_ops_defined, + jit_ops_not_defined >= jit_ops_defined) + + self.assertEqual(py_ops_not_defined > py_ops_defined, + jit_ops_not_defined > jit_ops_defined) + + def test_implicit_hash_compiles(self): + # Ensure that classes with __hash__ implicitly defined as None due to + # the presence of __eq__ are correctly handled by ignoring the __hash__ + # class member. + class ImplicitHash: + def __init__(self): + pass + + def __eq__(self, other): + return False + + jitted = jitclass([])(ImplicitHash) + instance = jitted() + + self.assertFalse(instance == instance) + + if __name__ == "__main__": unittest.main()