From ab432837c07bafd6899eae8535cc82ab8112d9bd Mon Sep 17 00:00:00 2001 From: Ethan Pronovost Date: Fri, 12 Jun 2020 22:44:40 -0700 Subject: [PATCH 01/25] Basic support for jitclass builtins Hash is still not working --- numba/experimental/jitclass/boxing.py | 10 +++++++--- numba/tests/test_jitclasses.py | 25 +++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/numba/experimental/jitclass/boxing.py b/numba/experimental/jitclass/boxing.py index 314b6ef565e..3d2da7ed10c 100644 --- a/numba/experimental/jitclass/boxing.py +++ b/numba/experimental/jitclass/boxing.py @@ -96,10 +96,14 @@ def _specialize_box(typ): doc = getattr(imp, '__doc__', None) dct[field] = property(getter, setter, doc=doc) # Inject methods as class members + bad_builtins = { + "__name__", + "__doc__", + "__init__", + "__new__", + } for name, func in typ.methods.items(): - if (name == "__getitem__" or name == "__setitem__") or \ - (not (name.startswith('__') and name.endswith('__'))): - + if name not in bad_builtins: dct[name] = _generate_method(name, func) # Inject static methods as class members diff --git a/numba/tests/test_jitclasses.py b/numba/tests/test_jitclasses.py index 87f11657f1b..0203b7703d4 100644 --- a/numba/tests/test_jitclasses.py +++ b/numba/tests/test_jitclasses.py @@ -1041,6 +1041,31 @@ def __init__(self): self.assertIs(ws[0].category, errors.NumbaDeprecationWarning) self.assertIn("numba.experimental.jitclass", ws[0].message.msg) + def test_builtins(self): + + @jitclass({"x": types.List(types.intp)}) + class MyList: + def __init__(self): + self.x = [0] + + def append(self, y): + self.x.append(y) + + def __len__(self): + return len(self.x) + + def __contains__(self, y): + return y in self.x + + foo = MyList() + self.assertEqual(len(foo), 1) + self.assertTrue(0 in foo) + self.assertFalse(1 in foo) + foo.append(1) + self.assertEqual(len(foo), 2) + self.assertTrue(0 in foo) + self.assertTrue(1 in foo) + if __name__ == '__main__': unittest.main() From f8e4c0bcf0d9a9b7a0bafcce4936b2bdf2a82b48 Mon Sep 17 00:00:00 2001 From: Ethan Pronovost Date: Sat, 13 Jun 2020 20:45:55 -0700 Subject: [PATCH 02/25] Jitclass support some dunder methods --- numba/experimental/jitclass/__init__.py | 1 + numba/experimental/jitclass/boxing.py | 21 +- numba/experimental/jitclass/overloads.py | 90 ++++++++ numba/tests/test_jitclasses.py | 273 +++++++++++++++++++++-- 4 files changed, 359 insertions(+), 26 deletions(-) create mode 100644 numba/experimental/jitclass/overloads.py diff --git a/numba/experimental/jitclass/__init__.py b/numba/experimental/jitclass/__init__.py index 97c1903496b..981282f6536 100644 --- a/numba/experimental/jitclass/__init__.py +++ b/numba/experimental/jitclass/__init__.py @@ -1,2 +1,3 @@ from numba.experimental.jitclass.decorators import jitclass from numba.experimental.jitclass import boxing # Has import-time side effect +from numba.experimental.jitclass import overloads # Has import-time side effect diff --git a/numba/experimental/jitclass/boxing.py b/numba/experimental/jitclass/boxing.py index 3d2da7ed10c..f7f4d8db74e 100644 --- a/numba/experimental/jitclass/boxing.py +++ b/numba/experimental/jitclass/boxing.py @@ -96,15 +96,24 @@ def _specialize_box(typ): doc = getattr(imp, '__doc__', None) dct[field] = property(getter, setter, doc=doc) # Inject methods as class members - bad_builtins = { - "__name__", - "__doc__", - "__init__", - "__new__", + supported_dunders = { + "__abs__", + "__bool__", + "__complex__", + "__contains__", + "__float__", + "__getitem__", + "__index__", + "__int__", + "__len__", + "__setitem__", } for name, func in typ.methods.items(): - if name not in bad_builtins: + if (not (name.startswith("__") and name.endswith("__")) or + name in supported_dunders): dct[name] = _generate_method(name, func) + # if name == "__hash__": + # dct[name] = func # Inject static methods as class members for name, func in typ.static_methods.items(): diff --git a/numba/experimental/jitclass/overloads.py b/numba/experimental/jitclass/overloads.py new file mode 100644 index 00000000000..c950072db30 --- /dev/null +++ b/numba/experimental/jitclass/overloads.py @@ -0,0 +1,90 @@ +""" +Overloads for ClassInstanceType for built-in functions that call dunder methods +on an object. +""" +import operator +import sys + +from numba.core.extending import overload +from numba.core.types import ClassInstanceType + + +@overload(abs) +def class_abs(x): + if not isinstance(x, ClassInstanceType): + return + + if "__abs__" in x.jit_methods: + return lambda x: x.__abs__() + + +@overload(bool) +def class_bool(x): + if not isinstance(x, ClassInstanceType): + return + + if "__bool__" in x.jit_methods: + return lambda x: x.__bool__() + + if "__len__" in x.jit_methods: + return lambda x: x.__len__() != 0 + + return lambda x: True + + +@overload(complex) +def class_complex(x): + if not isinstance(x, ClassInstanceType): + return + + if "__complex__" in x.jit_methods: + return lambda x: x.__complex__() + + return lambda x: complex(float(x)) + + +@overload(operator.contains) +def class_contains(x, y): + # https://docs.python.org/3/reference/expressions.html#membership-test-operations + if not isinstance(x, ClassInstanceType): + return + + if "__contains__" in x.jit_methods: + return lambda x, y: x.__contains__(y) + + # TODO: use __iter__ if defined. + + +@overload(float) +def class_float(x): + if not isinstance(x, ClassInstanceType): + return + + if "__float__" in x.jit_methods: + return lambda x: x.__float__() + + if ((sys.version_info.major, sys.version_info.minor) >= (3, 8) and + "__index__" in x.jit_methods): + return lambda x: float(x.__index__()) + + +@overload(int) +def class_int(x): + if not isinstance(x, ClassInstanceType): + return + + if "__int__" in x.jit_methods: + return lambda x: x.__int__() + + if ((sys.version_info.major, sys.version_info.minor) >= (3, 8) and + "__index__" in x.jit_methods): + return lambda x: x.__index__() + + +@overload(len) +def class_len(x): + if not isinstance(x, ClassInstanceType): + return + + if "__len__" in x.jit_methods: + return lambda x: x.__len__() diff --git a/numba/tests/test_jitclasses.py b/numba/tests/test_jitclasses.py index 0203b7703d4..34c7051aaa2 100644 --- a/numba/tests/test_jitclasses.py +++ b/numba/tests/test_jitclasses.py @@ -2,6 +2,7 @@ import ctypes import random import pickle +import sys import warnings import numba @@ -12,7 +13,7 @@ from numba import njit, typeof from numba.core import types, errors from numba.core.dispatcher import Dispatcher -from numba.core.errors import LoweringError +from numba.core.errors import LoweringError, TypingError from numba.core.runtime.nrt import MemInfo from numba.experimental import jitclass from numba.experimental.jitclass import _box @@ -1041,30 +1042,262 @@ def __init__(self): self.assertIs(ws[0].category, errors.NumbaDeprecationWarning) self.assertIn("numba.experimental.jitclass", ws[0].message.msg) - def test_builtins(self): - @jitclass({"x": types.List(types.intp)}) - class MyList: - def __init__(self): - self.x = [0] +class TestJitClassOverloads(TestCase, MemoryLeakMixin): + + class PyList: + def __init__(self): + self.x = [0] + + def append(self, y): + self.x.append(y) + + def clear(self): + self.x.clear() + + def __abs__(self): + return len(self.x) * 7 + + def __bool__(self): + return len(self.x) % 3 != 0 + + def __complex__(self): + c = complex(2) + if self.x: + c += self.x[0] + return c + + def __contains__(self, y): + return y in self.x + + def __float__(self): + f = 3.1415 + if self.x: + f += self.x[0] + return f + + def __int__(self): + i = 5 + if self.x: + i += self.x[0] + return i + + def __len__(self): + return len(self.x) + 1 + + def assertSame(self, first, second, msg=None): + self.assertEqual(type(first), type(second), msg=msg) + self.assertEqual(first, second, msg=msg) + + def test_simple(self): + """ + Check that the dunder methods are exposed on ClassInstanceType. + """ + JitList = jitclass({"x": types.List(types.intp)})(self.PyList) + + py_funcs = [ + lambda x: abs(x), + lambda x: x.__abs__(), + lambda x: bool(x), + lambda x: x.__bool__(), + lambda x: complex(x), + lambda x: x.__complex__(), + lambda x: 0 in x, # contains + lambda x: x.__contains__(0), + lambda x: float(x), + lambda x: x.__float__(), + lambda x: int(x), + lambda x: x.__int__(), + lambda x: len(x), + lambda x: x.__len__(), + lambda x: 1 if x else 0, # truth + ] + jit_funcs = [njit(f) for f in py_funcs] + + py_list = self.PyList() + jit_list = JitList() + for py_f, jit_f in zip(py_funcs, jit_funcs): + self.assertSame(py_f(py_list), py_f(jit_list)) + self.assertSame(py_f(py_list), jit_f(jit_list)) + + py_list.append(2) + jit_list.append(2) + for py_f, jit_f in zip(py_funcs, jit_funcs): + self.assertSame(py_f(py_list), py_f(jit_list)) + self.assertSame(py_f(py_list), jit_f(jit_list)) + + py_list.append(-5) + jit_list.append(-5) + for py_f, jit_f in zip(py_funcs, jit_funcs): + self.assertSame(py_f(py_list), py_f(jit_list)) + self.assertSame(py_f(py_list), jit_f(jit_list)) + + py_list.clear() + jit_list.clear() + for py_f, jit_f in zip(py_funcs, jit_funcs): + self.assertSame(py_f(py_list), py_f(jit_list)) + self.assertSame(py_f(py_list), jit_f(jit_list)) + + def test_bool_fallback(self): + + def py_b(x): + return bool(x) + + jit_b = njit(py_b) + + @jitclass([("x", types.List(types.intp))]) + class LenClass: + def __init__(self, x): + self.x = x + + def __len__(self): + return len(self.x) % 4 def append(self, y): self.x.append(y) - def __len__(self): - return len(self.x) - - def __contains__(self, y): - return y in self.x - - foo = MyList() - self.assertEqual(len(foo), 1) - self.assertTrue(0 in foo) - self.assertFalse(1 in foo) - foo.append(1) - self.assertEqual(len(foo), 2) - self.assertTrue(0 in foo) - self.assertTrue(1 in foo) + def pop(self): + self.x.pop(0) + + obj = LenClass([1, 2, 3]) + self.assertTrue(py_b(obj)) + self.assertTrue(jit_b(obj)) + + obj.append(4) + self.assertFalse(py_b(obj)) + self.assertFalse(jit_b(obj)) + + obj.pop() + self.assertTrue(py_b(obj)) + self.assertTrue(jit_b(obj)) + + @jitclass([("y", types.float64)]) + class NormalClass: + def __init__(self, y): + self.y = y + + obj = NormalClass(0) + self.assertTrue(py_b(obj)) + self.assertTrue(jit_b(obj)) + + def test_numeric_fallback(self): + def py_c(x): + return complex(x) + + def py_f(x): + return float(x) + + def py_i(x): + return int(x) + + jit_c = njit(py_c) + jit_f = njit(py_f) + jit_i = njit(py_i) + + @jitclass([]) + class FloatClass: + def __init__(self): + pass + + def __float__(self): + return 3.1415 + + obj = FloatClass() + self.assertEquals(py_c(obj), complex(3.1415)) + self.assertEquals(jit_c(obj), complex(3.1415)) + self.assertEquals(py_f(obj), 3.1415) + self.assertEquals(jit_f(obj), 3.1415) + + with self.assertRaises(TypeError) as e: + py_i(obj) + self.assertIn("int", str(e.exception)) + with self.assertRaises(TypingError) as e: + jit_i(obj) + self.assertIn("int", str(e.exception)) + + @jitclass([]) + class IntClass: + def __init__(self): + pass + + def __int__(self): + return 7 + + obj = IntClass() + self.assertEquals(py_i(obj), 7) + self.assertEquals(jit_i(obj), 7) + + with self.assertRaises(TypeError) as e: + py_c(obj) + self.assertIn("complex", str(e.exception)) + with self.assertRaises(TypingError) as e: + jit_c(obj) + self.assertIn("complex", str(e.exception)) + with self.assertRaises(TypeError) as e: + py_f(obj) + self.assertIn("float", str(e.exception)) + with self.assertRaises(TypingError) as e: + jit_f(obj) + self.assertIn("float", str(e.exception)) + + @jitclass([]) + class IndexClass: + def __init__(self): + pass + + def __index__(self): + return 1 + + obj = IndexClass() + + if sys.version[:3] >= "3.8": + self.assertEquals(py_c(obj), complex(1)) + self.assertEquals(jit_c(obj), complex(1)) + self.assertEquals(py_f(obj), 1.) + self.assertEquals(jit_f(obj), 1.) + self.assertEquals(py_i(obj), 1) + self.assertEquals(jit_i(obj), 1) + else: + with self.assertRaises(TypeError) as e: + py_c(obj) + self.assertIn("complex", str(e.exception)) + with self.assertRaises(TypingError) as e: + jit_c(obj) + self.assertIn("complex", str(e.exception)) + with self.assertRaises(TypeError) as e: + py_f(obj) + self.assertIn("float", str(e.exception)) + with self.assertRaises(TypingError) as e: + jit_f(obj) + self.assertIn("float", str(e.exception)) + with self.assertRaises(TypeError) as e: + py_i(obj) + self.assertIn("int", str(e.exception)) + with self.assertRaises(TypingError) as e: + jit_i(obj) + self.assertIn("int", str(e.exception)) + + @jitclass([]) + class FloatIntIndexClass: + def __init__(self): + pass + + def __float__(self): + return 3.1415 + + def __int__(self): + return 7 + + def __index__(self): + return 1 + + obj = FloatIntIndexClass() + self.assertEquals(py_c(obj), complex(3.1415)) + self.assertEquals(jit_c(obj), complex(3.1415)) + self.assertEquals(py_f(obj), 3.1415) + self.assertEquals(jit_f(obj), 3.1415) + self.assertEquals(py_i(obj), 7) + self.assertEquals(jit_i(obj), 7) if __name__ == '__main__': From 1ec935242471c167ef74290619f4bc771344d14e Mon Sep 17 00:00:00 2001 From: Ethan Pronovost Date: Sun, 14 Jun 2020 19:53:21 -0700 Subject: [PATCH 03/25] Add arithmetic and logical operators --- numba/experimental/jitclass/boxing.py | 21 +++ numba/experimental/jitclass/overloads.py | 63 ++++++-- numba/tests/test_jitclasses.py | 185 ++++++++++++++++++++--- 3 files changed, 241 insertions(+), 28 deletions(-) diff --git a/numba/experimental/jitclass/boxing.py b/numba/experimental/jitclass/boxing.py index f7f4d8db74e..947bede4925 100644 --- a/numba/experimental/jitclass/boxing.py +++ b/numba/experimental/jitclass/boxing.py @@ -107,6 +107,27 @@ def _specialize_box(typ): "__int__", "__len__", "__setitem__", + "__str__", + # "__eq__", + # "__ne__", + "__ge__", + "__gt__", + "__le__", + "__lt__", + "__add__", + "__floordiv__", + "__lshift__", + "__mod__", + "__mul__", + "__neg__", + "__pos__", + "__pow__", + "__rshift__", + "__sub__", + "__truediv__", + "__and__", + "__or__", + "__xor__", } for name, func in typ.methods.items(): if (not (name.startswith("__") and name.endswith("__")) or diff --git a/numba/experimental/jitclass/overloads.py b/numba/experimental/jitclass/overloads.py index c950072db30..c4823ab5597 100644 --- a/numba/experimental/jitclass/overloads.py +++ b/numba/experimental/jitclass/overloads.py @@ -9,13 +9,26 @@ from numba.core.types import ClassInstanceType -@overload(abs) -def class_abs(x): - if not isinstance(x, ClassInstanceType): +def register_class_overload(func, attr, nargs=1): + """ + Register overload handler for func calling class attribute attr. + """ + args = list("abcdefg")[:nargs] + arg0 = args[0] + + template = f""" +def handler({",".join(args)}): + if not isinstance({arg0}, ClassInstanceType): return + if "__{attr}__" in {arg0}.jit_methods: + return lambda {",".join(args)}: {arg0}.__{attr}__({",".join(args[1:])}) +""" - if "__abs__" in x.jit_methods: - return lambda x: x.__abs__() + namespace = dict(ClassInstanceType=ClassInstanceType) + exec(template, namespace) + + handler = namespace["handler"] + overload(func)(handler) @overload(bool) @@ -81,10 +94,42 @@ def class_int(x): return lambda x: x.__index__() -@overload(len) -def class_len(x): +@overload(str) +def class_str(x): if not isinstance(x, ClassInstanceType): return - if "__len__" in x.jit_methods: - return lambda x: x.__len__() + if "__str__" in x.jit_methods: + return lambda x: x.__str__() + + return lambda x: repr(x) + + +register_class_overload(abs, "abs") +register_class_overload(len, "len") + +# Comparison operators. +# register_class_overload(operator.eq, "eq", 2) +# register_class_overload(operator.ne, "ne", 2) +register_class_overload(operator.ge, "ge", 2) +register_class_overload(operator.gt, "gt", 2) +register_class_overload(operator.le, "le", 2) +register_class_overload(operator.lt, "lt", 2) + +# Arithmetic operators. +register_class_overload(operator.add, "add", 2) +register_class_overload(operator.floordiv, "floordiv", 2) +register_class_overload(operator.lshift, "lshift", 2) +register_class_overload(operator.mod, "mod", 2) +register_class_overload(operator.mul, "mul", 2) +register_class_overload(operator.neg, "neg") +register_class_overload(operator.pos, "pos") +register_class_overload(operator.pow, "pow", 2) +register_class_overload(operator.rshift, "rshift", 2) +register_class_overload(operator.sub, "sub", 2) +register_class_overload(operator.truediv, "truediv", 2) + +# Logical operators. +register_class_overload(operator.and_, "and", 2) +register_class_overload(operator.or_, "or", 2) +register_class_overload(operator.xor, "xor", 2) diff --git a/numba/tests/test_jitclasses.py b/numba/tests/test_jitclasses.py index 34c7051aaa2..99dd4a04641 100644 --- a/numba/tests/test_jitclasses.py +++ b/numba/tests/test_jitclasses.py @@ -1,5 +1,6 @@ from collections import OrderedDict import ctypes +import itertools import random import pickle import sys @@ -1085,11 +1086,17 @@ def __int__(self): def __len__(self): return len(self.x) + 1 + def __str__(self): + if len(self.x) == 0: + return "PyList empty" + else: + return "PyList non-empty" + def assertSame(self, first, second, msg=None): self.assertEqual(type(first), type(second), msg=msg) self.assertEqual(first, second, msg=msg) - def test_simple(self): + def test_overloads(self): """ Check that the dunder methods are exposed on ClassInstanceType. """ @@ -1110,6 +1117,8 @@ def test_simple(self): lambda x: x.__int__(), lambda x: len(x), lambda x: x.__len__(), + lambda x: str(x), + lambda x: x.__str__(), lambda x: 1 if x else 0, # truth ] jit_funcs = [njit(f) for f in py_funcs] @@ -1203,10 +1212,10 @@ def __float__(self): return 3.1415 obj = FloatClass() - self.assertEquals(py_c(obj), complex(3.1415)) - self.assertEquals(jit_c(obj), complex(3.1415)) - self.assertEquals(py_f(obj), 3.1415) - self.assertEquals(jit_f(obj), 3.1415) + self.assertSame(py_c(obj), complex(3.1415)) + self.assertSame(jit_c(obj), complex(3.1415)) + self.assertSame(py_f(obj), 3.1415) + self.assertSame(jit_f(obj), 3.1415) with self.assertRaises(TypeError) as e: py_i(obj) @@ -1224,8 +1233,8 @@ def __int__(self): return 7 obj = IntClass() - self.assertEquals(py_i(obj), 7) - self.assertEquals(jit_i(obj), 7) + self.assertSame(py_i(obj), 7) + self.assertSame(jit_i(obj), 7) with self.assertRaises(TypeError) as e: py_c(obj) @@ -1251,12 +1260,12 @@ def __index__(self): obj = IndexClass() if sys.version[:3] >= "3.8": - self.assertEquals(py_c(obj), complex(1)) - self.assertEquals(jit_c(obj), complex(1)) - self.assertEquals(py_f(obj), 1.) - self.assertEquals(jit_f(obj), 1.) - self.assertEquals(py_i(obj), 1) - self.assertEquals(jit_i(obj), 1) + self.assertSame(py_c(obj), complex(1)) + self.assertSame(jit_c(obj), complex(1)) + self.assertSame(py_f(obj), 1.) + self.assertSame(jit_f(obj), 1.) + self.assertSame(py_i(obj), 1) + self.assertSame(jit_i(obj), 1) else: with self.assertRaises(TypeError) as e: py_c(obj) @@ -1292,12 +1301,150 @@ def __index__(self): return 1 obj = FloatIntIndexClass() - self.assertEquals(py_c(obj), complex(3.1415)) - self.assertEquals(jit_c(obj), complex(3.1415)) - self.assertEquals(py_f(obj), 3.1415) - self.assertEquals(jit_f(obj), 3.1415) - self.assertEquals(py_i(obj), 7) - self.assertEquals(jit_i(obj), 7) + self.assertSame(py_c(obj), complex(3.1415)) + self.assertSame(jit_c(obj), complex(3.1415)) + self.assertSame(py_f(obj), 3.1415) + self.assertSame(jit_f(obj), 3.1415) + self.assertSame(py_i(obj), 7) + self.assertSame(jit_i(obj), 7) + + def test_arithmetic(self): + + @jitclass([("x", types.intp)]) + class IntWrapper: + def __init__(self, value): + self.x = value + + def __lshift__(self, other): + return IntWrapper(self.x << other.x) + + def __rshift__(self, other): + return IntWrapper(self.x >> other.x) + + def __and__(self, other): + return IntWrapper(self.x & other.x) + + def __or__(self, other): + return IntWrapper(self.x | other.x) + + def __xor__(self, other): + return IntWrapper(self.x ^ other.x) + + @jitclass([("x", types.float64)]) + class FloatWrapper: + + def __init__(self, value): + self.x = value + + # def __eq__(self, other): + # print("Eq", self, other, self.x, other.x) + # return self.x == other.x + + # def __ne__(self, other): + # return self.x != other.x + + def __ge__(self, other): + return self.x >= other.x + + def __gt__(self, other): + return self.x > other.x + + def __le__(self, other): + return self.x <= other.x + + def __lt__(self, other): + return self.x < other.x + + def __add__(self, other): + return FloatWrapper(self.x + other.x) + + def __floordiv__(self, other): + return FloatWrapper(self.x // other.x) + + def __mod__(self, other): + return FloatWrapper(self.x % other.x) + + def __mul__(self, other): + return FloatWrapper(self.x * other.x) + + def __neg__(self, other): + return FloatWrapper(-self.x) + + def __pos__(self, other): + return FloatWrapper(+self.x) + + def __pow__(self, other): + return FloatWrapper(self.x ** other.x) + + def __sub__(self, other): + return FloatWrapper(self.x - other.x) + + def __truediv__(self, other): + return FloatWrapper(self.x / other.x) + + float_py_funcs = [ + # lambda x, y: x == y, + # lambda x, y: x != y, + lambda x, y: x >= y, + lambda x, y: x > y, + lambda x, y: x <= y, + lambda x, y: x < y, + lambda x, y: x + y, + lambda x, y: x // y, + lambda x, y: x % y, + lambda x, y: x * y, + lambda x, y: x ** y, + lambda x, y: x - y, + lambda x, y: x / y, + ] + int_py_funcs = [ + lambda x, y: x << y, + lambda x, y: x >> y, + lambda x, y: x & y, + lambda x, y: x | y, + lambda x, y: x ^ y, + ] + + test_values = [ + (0.0, 2.0), + (1.234, 3.1415), + (13.1, 1.01), + ] + + def unwrap(value): + return getattr(value, "x", value) + + for jit_f, (x, y) in itertools.product( + map(njit, float_py_funcs), test_values): + + py_f = jit_f.py_func + + expected = py_f(x, y) + jit_x = FloatWrapper(x) + jit_y = FloatWrapper(y) + + check = ( + self.assertEqual + if type(expected) is not float + else self.assertAlmostEqual + ) + check(expected, jit_f(x, y)) + check(expected, unwrap(py_f(jit_x, jit_y))) + check(expected, unwrap(jit_f(jit_x, jit_y))) + + for jit_f, (x, y) in itertools.product( + map(njit, int_py_funcs), test_values): + + py_f = jit_f.py_func + x, y = int(x), int(y) + + expected = py_f(x, y) + jit_x = IntWrapper(x) + jit_y = IntWrapper(y) + + self.assertEqual(expected, jit_f(x, y)) + self.assertEqual(expected, unwrap(py_f(jit_x, jit_y))) + self.assertEqual(expected, unwrap(jit_f(jit_x, jit_y))) if __name__ == '__main__': From 26a48fa985d7c28fa31dbe95131658f7e496122c Mon Sep 17 00:00:00 2001 From: Ethan Pronovost Date: Sun, 14 Jun 2020 22:37:57 -0700 Subject: [PATCH 04/25] Support inplace operators --- numba/experimental/jitclass/boxing.py | 16 +- numba/experimental/jitclass/overloads.py | 90 +++++--- numba/tests/test_jitclasses.py | 265 ++++++++++++++++------- 3 files changed, 268 insertions(+), 103 deletions(-) diff --git a/numba/experimental/jitclass/boxing.py b/numba/experimental/jitclass/boxing.py index 947bede4925..f5c7d157db8 100644 --- a/numba/experimental/jitclass/boxing.py +++ b/numba/experimental/jitclass/boxing.py @@ -128,13 +128,25 @@ def _specialize_box(typ): "__and__", "__or__", "__xor__", + "__iadd__", + "__ifloordiv__", + "__ilshift__", + "__imod__", + "__imul__", + "__ineg__", + "__ipos__", + "__ipow__", + "__irshift__", + "__isub__", + "__itruediv__", + "__iand__", + "__ior__", + "__ixor__", } for name, func in typ.methods.items(): if (not (name.startswith("__") and name.endswith("__")) or name in supported_dunders): dct[name] = _generate_method(name, func) - # if name == "__hash__": - # dct[name] = func # Inject static methods as class members for name, func in typ.static_methods.items(): diff --git a/numba/experimental/jitclass/overloads.py b/numba/experimental/jitclass/overloads.py index c4823ab5597..bfc0313eb9a 100644 --- a/numba/experimental/jitclass/overloads.py +++ b/numba/experimental/jitclass/overloads.py @@ -9,28 +9,46 @@ from numba.core.types import ClassInstanceType -def register_class_overload(func, attr, nargs=1): - """ - Register overload handler for func calling class attribute attr. - """ - args = list("abcdefg")[:nargs] +def _get_args(nargs=1): + return list("xyzabcdefg")[:nargs] + + +def _simple_template(*attrs, nargs=1): + args = _get_args(nargs=nargs) arg0 = args[0] template = f""" def handler({",".join(args)}): if not isinstance({arg0}, ClassInstanceType): return +""" + for attr in attrs: + assert isinstance(attr, str) + template += f""" if "__{attr}__" in {arg0}.jit_methods: return lambda {",".join(args)}: {arg0}.__{attr}__({",".join(args[1:])}) """ + return template + +def _register_overload(func, template, glbls=None): + """ + Register overload handler for func calling class attribute attr. + """ namespace = dict(ClassInstanceType=ClassInstanceType) - exec(template, namespace) + if glbls: + namespace.update(glbls) + exec(template, namespace) handler = namespace["handler"] overload(func)(handler) +def register_simple_overload(func, *attrs, nargs=1): + template = _simple_template(*attrs, nargs=nargs) + _register_overload(func, template) + + @overload(bool) def class_bool(x): if not isinstance(x, ClassInstanceType): @@ -105,31 +123,47 @@ def class_str(x): return lambda x: repr(x) -register_class_overload(abs, "abs") -register_class_overload(len, "len") +register_simple_overload(abs, "abs") +register_simple_overload(len, "len") # Comparison operators. -# register_class_overload(operator.eq, "eq", 2) -# register_class_overload(operator.ne, "ne", 2) -register_class_overload(operator.ge, "ge", 2) -register_class_overload(operator.gt, "gt", 2) -register_class_overload(operator.le, "le", 2) -register_class_overload(operator.lt, "lt", 2) +# register_simple_overload(operator.eq, "eq", 2) +# register_simple_overload(operator.ne, "ne", 2) +register_simple_overload(operator.ge, "ge", nargs=2) +register_simple_overload(operator.gt, "gt", nargs=2) +register_simple_overload(operator.le, "le", nargs=2) +register_simple_overload(operator.lt, "lt", nargs=2) # Arithmetic operators. -register_class_overload(operator.add, "add", 2) -register_class_overload(operator.floordiv, "floordiv", 2) -register_class_overload(operator.lshift, "lshift", 2) -register_class_overload(operator.mod, "mod", 2) -register_class_overload(operator.mul, "mul", 2) -register_class_overload(operator.neg, "neg") -register_class_overload(operator.pos, "pos") -register_class_overload(operator.pow, "pow", 2) -register_class_overload(operator.rshift, "rshift", 2) -register_class_overload(operator.sub, "sub", 2) -register_class_overload(operator.truediv, "truediv", 2) +register_simple_overload(operator.add, "add", nargs=2) +register_simple_overload(operator.floordiv, "floordiv", nargs=2) +register_simple_overload(operator.lshift, "lshift", nargs=2) +register_simple_overload(operator.mod, "mod", nargs=2) +register_simple_overload(operator.mul, "mul", nargs=2) +register_simple_overload(operator.neg, "neg") +register_simple_overload(operator.pos, "pos") +register_simple_overload(operator.pow, "pow", nargs=2) +register_simple_overload(operator.rshift, "rshift", nargs=2) +register_simple_overload(operator.sub, "sub", nargs=2) +register_simple_overload(operator.truediv, "truediv", nargs=2) + +# Inplace arithmetic operators. +register_simple_overload(operator.iadd, "iadd", "add", nargs=2) +register_simple_overload(operator.ifloordiv, "ifloordiv", "floordiv", nargs=2) +register_simple_overload(operator.ilshift, "ilshift", "lshift", nargs=2) +register_simple_overload(operator.imod, "imod", "mod", nargs=2) +register_simple_overload(operator.imul, "imul", "mul", nargs=2) +register_simple_overload(operator.ipow, "ipow", "pow", nargs=2) +register_simple_overload(operator.irshift, "irshift", "rshift", nargs=2) +register_simple_overload(operator.isub, "isub", "sub", nargs=2) +register_simple_overload(operator.itruediv, "itruediv", "truediv", nargs=2) # Logical operators. -register_class_overload(operator.and_, "and", 2) -register_class_overload(operator.or_, "or", 2) -register_class_overload(operator.xor, "xor", 2) +register_simple_overload(operator.and_, "and", nargs=2) +register_simple_overload(operator.or_, "or", nargs=2) +register_simple_overload(operator.xor, "xor", nargs=2) + +# Inplace logical operators. +register_simple_overload(operator.iand, "iand", "and", nargs=2) +register_simple_overload(operator.ior, "ior", "or", nargs=2) +register_simple_overload(operator.ixor, "ixor", "xor", nargs=2) diff --git a/numba/tests/test_jitclasses.py b/numba/tests/test_jitclasses.py index 99dd4a04641..4a7140c9245 100644 --- a/numba/tests/test_jitclasses.py +++ b/numba/tests/test_jitclasses.py @@ -1092,6 +1092,86 @@ def __str__(self): else: return "PyList non-empty" + @staticmethod + def get_int_warpper(): + @jitclass([("x", types.intp)]) + class IntWrapper: + def __init__(self, value): + self.x = value + + def __lshift__(self, other): + return IntWrapper(self.x << other.x) + + def __rshift__(self, other): + return IntWrapper(self.x >> other.x) + + def __and__(self, other): + return IntWrapper(self.x & other.x) + + def __or__(self, other): + return IntWrapper(self.x | other.x) + + def __xor__(self, other): + return IntWrapper(self.x ^ other.x) + + return IntWrapper + + @staticmethod + def get_float_wrapper(): + @jitclass([("x", types.float64)]) + class FloatWrapper: + + def __init__(self, value): + self.x = value + + # def __eq__(self, other): + # print("Eq", self, other, self.x, other.x) + # return self.x == other.x + + # def __ne__(self, other): + # return self.x != other.x + + def __ge__(self, other): + return self.x >= other.x + + def __gt__(self, other): + return self.x > other.x + + def __le__(self, other): + return self.x <= other.x + + def __lt__(self, other): + return self.x < other.x + + def __add__(self, other): + return FloatWrapper(self.x + other.x) + + def __floordiv__(self, other): + return FloatWrapper(self.x // other.x) + + def __mod__(self, other): + return FloatWrapper(self.x % other.x) + + def __mul__(self, other): + return FloatWrapper(self.x * other.x) + + def __neg__(self, other): + return FloatWrapper(-self.x) + + def __pos__(self, other): + return FloatWrapper(+self.x) + + def __pow__(self, other): + return FloatWrapper(self.x ** other.x) + + def __sub__(self, other): + return FloatWrapper(self.x - other.x) + + def __truediv__(self, other): + return FloatWrapper(self.x / other.x) + + return FloatWrapper + def assertSame(self, first, second, msg=None): self.assertEqual(type(first), type(second), msg=msg) self.assertEqual(first, second, msg=msg) @@ -1308,79 +1388,9 @@ def __index__(self): self.assertSame(py_i(obj), 7) self.assertSame(jit_i(obj), 7) - def test_arithmetic(self): - - @jitclass([("x", types.intp)]) - class IntWrapper: - def __init__(self, value): - self.x = value - - def __lshift__(self, other): - return IntWrapper(self.x << other.x) - - def __rshift__(self, other): - return IntWrapper(self.x >> other.x) - - def __and__(self, other): - return IntWrapper(self.x & other.x) - - def __or__(self, other): - return IntWrapper(self.x | other.x) - - def __xor__(self, other): - return IntWrapper(self.x ^ other.x) - - @jitclass([("x", types.float64)]) - class FloatWrapper: - - def __init__(self, value): - self.x = value - - # def __eq__(self, other): - # print("Eq", self, other, self.x, other.x) - # return self.x == other.x - - # def __ne__(self, other): - # return self.x != other.x - - def __ge__(self, other): - return self.x >= other.x - - def __gt__(self, other): - return self.x > other.x - - def __le__(self, other): - return self.x <= other.x - - def __lt__(self, other): - return self.x < other.x - - def __add__(self, other): - return FloatWrapper(self.x + other.x) - - def __floordiv__(self, other): - return FloatWrapper(self.x // other.x) - - def __mod__(self, other): - return FloatWrapper(self.x % other.x) - - def __mul__(self, other): - return FloatWrapper(self.x * other.x) - - def __neg__(self, other): - return FloatWrapper(-self.x) - - def __pos__(self, other): - return FloatWrapper(+self.x) - - def __pow__(self, other): - return FloatWrapper(self.x ** other.x) - - def __sub__(self, other): - return FloatWrapper(self.x - other.x) - - def __truediv__(self, other): - return FloatWrapper(self.x / other.x) + def test_arithmetic_logical(self): + IntWrapper = self.get_int_warpper() + FloatWrapper = self.get_float_wrapper() float_py_funcs = [ # lambda x, y: x == y, @@ -1446,6 +1456,115 @@ def unwrap(value): self.assertEqual(expected, unwrap(py_f(jit_x, jit_y))) self.assertEqual(expected, unwrap(jit_f(jit_x, jit_y))) + def test_arithmetic_logical_inplace(self): + + # If __i*__ methods are not defined, should fall back to normal methods. + JitIntWrapper = self.get_int_warpper() + JitFloatWrapper = self.get_float_wrapper() + + PyIntWrapper = JitIntWrapper.mro()[1] + PyFloatWrapper = JitFloatWrapper.mro()[1] + + @jitclass([("x", types.intp)]) + class JitIntUpdateWrapper(PyIntWrapper): + def __init__(self, value): + self.x = value + + def __ilshift__(self, other): + return JitIntUpdateWrapper(self.x << other.x) + + def __irshift__(self, other): + return JitIntUpdateWrapper(self.x >> other.x) + + def __iand__(self, other): + return JitIntUpdateWrapper(self.x & other.x) + + def __ior__(self, other): + return JitIntUpdateWrapper(self.x | other.x) + + def __ixor__(self, other): + return JitIntUpdateWrapper(self.x ^ other.x) + + @jitclass({"x": types.float64}) + class JitFloatUpdateWrapper(PyFloatWrapper): + + def __init__(self, value): + self.x = value + + def __iadd__(self, other): + return JitFloatUpdateWrapper(self.x + 2.718 * other.x) + + def __ifloordiv__(self, other): + return JitFloatUpdateWrapper(self.x * 2.718 // other.x) + + def __imod__(self, other): + return JitFloatUpdateWrapper(self.x % (other.x + 1)) + + def __imul__(self, other): + return JitFloatUpdateWrapper(self.x * other.x + 1) + + def __ipow__(self, other): + return JitFloatUpdateWrapper(self.x ** other.x + 1) + + def __isub__(self, other): + return JitFloatUpdateWrapper(self.x - 3.1415 * other.x) + + def __itruediv__(self, other): + return JitFloatUpdateWrapper((self.x + 1) / other.x) + + PyIntUpdateWrapper = JitIntUpdateWrapper.mro()[1] + PyFloatUpdateWrapper = JitFloatUpdateWrapper.mro()[1] + + def get_update_func(op): + template = f""" +def f(x, y): + x {op}= y + return x +""" + namespace = {} + exec(template, namespace) + return namespace["f"] + + float_py_funcs = [get_update_func(op) for op in [ + "+", "//", "%", "*", "**", "-", "/", + ]] + int_py_funcs = [get_update_func(op) for op in [ + "<<", ">>", "&", "|", "^", + ]] + + test_values = [ + (0.0, 2.0), + (1.234, 3.1415), + (13.1, 1.01), + ] + + for jit_f, (py_cls, jit_cls), (x, y) in itertools.product( + map(njit, float_py_funcs), + [ + (PyFloatWrapper, JitFloatWrapper), + (PyFloatUpdateWrapper, JitFloatUpdateWrapper) + ], + test_values): + py_f = jit_f.py_func + + expected = py_f(py_cls(x), py_cls(y)).x + self.assertAlmostEqual(expected, py_f(jit_cls(x), jit_cls(y)).x) + self.assertAlmostEqual(expected, jit_f(jit_cls(x), jit_cls(y)).x) + + for jit_f, (py_cls, jit_cls), (x, y) in itertools.product( + map(njit, int_py_funcs), + [ + (PyIntWrapper, JitIntWrapper), + (PyIntUpdateWrapper, JitIntUpdateWrapper) + ], + test_values): + x, y = int(x), int(y) + py_f = jit_f.py_func + + expected = py_f(py_cls(x), py_cls(y)).x + self.assertEqual(expected, py_f(jit_cls(x), jit_cls(y)).x) + self.assertEqual(expected, jit_f(jit_cls(x), jit_cls(y)).x) + if __name__ == '__main__': unittest.main() From 3ae39c6d313d54f7985dc49bb54ab54e1c69d1ed Mon Sep 17 00:00:00 2001 From: Ethan Pronovost Date: Fri, 30 Oct 2020 21:21:19 -0700 Subject: [PATCH 05/25] Refactor to work with autogen listings So long as the base function has actual lines of code to reference we're ok. --- numba/experimental/jitclass/overloads.py | 226 ++++++++++++----------- 1 file changed, 123 insertions(+), 103 deletions(-) diff --git a/numba/experimental/jitclass/overloads.py b/numba/experimental/jitclass/overloads.py index bfc0313eb9a..7ebf4373c29 100644 --- a/numba/experimental/jitclass/overloads.py +++ b/numba/experimental/jitclass/overloads.py @@ -2,6 +2,8 @@ Overloads for ClassInstanceType for built-in functions that call dunder methods on an object. """ +from functools import wraps +import inspect import operator import sys @@ -9,118 +11,136 @@ from numba.core.types import ClassInstanceType -def _get_args(nargs=1): - return list("xyzabcdefg")[:nargs] +def _get_args(n_args): + assert n_args in (1, 2) + return list("xy")[:n_args] -def _simple_template(*attrs, nargs=1): - args = _get_args(nargs=nargs) - arg0 = args[0] +def class_instance_overload(target): + """ + Decorator to add an overload for target that applies when the first argument + is a ClassInstanceType. + """ + def decorator(func): + @wraps(func) + def wrapped(*args, **kwargs): + if not isinstance(args[0], ClassInstanceType): + return + return func(*args, **kwargs) - template = f""" -def handler({",".join(args)}): - if not isinstance({arg0}, ClassInstanceType): - return -""" - for attr in attrs: - assert isinstance(attr, str) - template += f""" - if "__{attr}__" in {arg0}.jit_methods: - return lambda {",".join(args)}: {arg0}.__{attr}__({",".join(args[1:])}) -""" - return template + params = list(inspect.signature(wrapped).parameters) + assert params == _get_args(len(params)) + return overload(target)(wrapped) + + return decorator -def _register_overload(func, template, glbls=None): +def extract_template(template, name): """ - Register overload handler for func calling class attribute attr. + Extract a code-generated function from a string template. """ - namespace = dict(ClassInstanceType=ClassInstanceType) - if glbls: - namespace.update(glbls) - + namespace = {} exec(template, namespace) - handler = namespace["handler"] - overload(func)(handler) + return namespace[name] -def register_simple_overload(func, *attrs, nargs=1): - template = _simple_template(*attrs, nargs=nargs) - _register_overload(func, template) +def register_simple_overload(func, *attrs, n_args=1,): + """ + Register an overload for func that checks for methods __attr__ for each + attr in attrs. + """ + # Use a template to set the signature correctly. + arg_names = _get_args(n_args) + template = f""" +def func({','.join(arg_names)}): + pass +""" + @wraps(extract_template(template, "func")) + def overload_func(*args, **kwargs): + options = [ + try_call_method(args[0], f"__{attr}__", n_args) + for attr in attrs + ] + return take_first(*options) -@overload(bool) -def class_bool(x): - if not isinstance(x, ClassInstanceType): - return + return class_instance_overload(func)(overload_func) - if "__bool__" in x.jit_methods: - return lambda x: x.__bool__() - if "__len__" in x.jit_methods: - return lambda x: x.__len__() != 0 +def try_call_method(cls_type, method, n_args=1): + """ + If method is defined for cls_type, return a callable that calls this method. + If not, return None. + """ + if method in cls_type.jit_methods: + arg_names = _get_args(n_args) + template = f""" +def func({','.join(arg_names)}): + return {arg_names[0]}.{method}({','.join(arg_names[1:])}) +""" + return extract_template(template, "func") + - return lambda x: True +def take_first(*options): + """ + Take the first non-None option. + """ + assert all(o is None or inspect.isfunction(o) for o in options), options + for o in options: + if o is not None: + return o -@overload(complex) -def class_complex(x): - if not isinstance(x, ClassInstanceType): - return +@class_instance_overload(bool) +def class_bool(x): + return take_first( + try_call_method(x, "__bool__"), + try_call_method(x, "__len__"), + lambda x: True, + ) - if "__complex__" in x.jit_methods: - return lambda x: x.__complex__() - return lambda x: complex(float(x)) +@class_instance_overload(complex) +def class_complex(x): + return take_first( + try_call_method(x, "__complex__"), + lambda x: complex(float(x)) + ) -@overload(operator.contains) +@class_instance_overload(operator.contains) def class_contains(x, y): # https://docs.python.org/3/reference/expressions.html#membership-test-operations - if not isinstance(x, ClassInstanceType): - return - - if "__contains__" in x.jit_methods: - return lambda x, y: x.__contains__(y) - + return try_call_method(x, "__contains__", 2) # TODO: use __iter__ if defined. -@overload(float) +@class_instance_overload(float) def class_float(x): - if not isinstance(x, ClassInstanceType): - return + options = [try_call_method(x, "__float__")] - if "__float__" in x.jit_methods: - return lambda x: x.__float__() + if (sys.version_info.major, sys.version_info.minor) >= (3, 8): + options.append(try_call_method(x, "__index__")) - if ((sys.version_info.major, sys.version_info.minor) >= (3, 8) and - "__index__" in x.jit_methods): - return lambda x: float(x.__index__()) + return take_first(*options) -@overload(int) +@class_instance_overload(int) def class_int(x): - if not isinstance(x, ClassInstanceType): - return + options = [try_call_method(x, "__int__")] - if "__int__" in x.jit_methods: - return lambda x: x.__int__() + if (sys.version_info.major, sys.version_info.minor) >= (3, 8): + options.append(try_call_method(x, "__index__")) - if ((sys.version_info.major, sys.version_info.minor) >= (3, 8) and - "__index__" in x.jit_methods): - return lambda x: x.__index__() + return take_first(*options) -@overload(str) +@class_instance_overload(str) def class_str(x): - if not isinstance(x, ClassInstanceType): - return - - if "__str__" in x.jit_methods: - return lambda x: x.__str__() - - return lambda x: repr(x) + return take_first( + try_call_method(x, "__str__"), + lambda x: repr(x), + ) register_simple_overload(abs, "abs") @@ -129,41 +149,41 @@ def class_str(x): # Comparison operators. # register_simple_overload(operator.eq, "eq", 2) # register_simple_overload(operator.ne, "ne", 2) -register_simple_overload(operator.ge, "ge", nargs=2) -register_simple_overload(operator.gt, "gt", nargs=2) -register_simple_overload(operator.le, "le", nargs=2) -register_simple_overload(operator.lt, "lt", nargs=2) +register_simple_overload(operator.ge, "ge", n_args=2) +register_simple_overload(operator.gt, "gt", n_args=2) +register_simple_overload(operator.le, "le", n_args=2) +register_simple_overload(operator.lt, "lt", n_args=2) # Arithmetic operators. -register_simple_overload(operator.add, "add", nargs=2) -register_simple_overload(operator.floordiv, "floordiv", nargs=2) -register_simple_overload(operator.lshift, "lshift", nargs=2) -register_simple_overload(operator.mod, "mod", nargs=2) -register_simple_overload(operator.mul, "mul", nargs=2) +register_simple_overload(operator.add, "add", n_args=2) +register_simple_overload(operator.floordiv, "floordiv", n_args=2) +register_simple_overload(operator.lshift, "lshift", n_args=2) +register_simple_overload(operator.mul, "mul", n_args=2) +register_simple_overload(operator.mod, "mod", n_args=2) register_simple_overload(operator.neg, "neg") register_simple_overload(operator.pos, "pos") -register_simple_overload(operator.pow, "pow", nargs=2) -register_simple_overload(operator.rshift, "rshift", nargs=2) -register_simple_overload(operator.sub, "sub", nargs=2) -register_simple_overload(operator.truediv, "truediv", nargs=2) +register_simple_overload(operator.pow, "pow", n_args=2) +register_simple_overload(operator.rshift, "rshift", n_args=2) +register_simple_overload(operator.sub, "sub", n_args=2) +register_simple_overload(operator.truediv, "truediv", n_args=2) # Inplace arithmetic operators. -register_simple_overload(operator.iadd, "iadd", "add", nargs=2) -register_simple_overload(operator.ifloordiv, "ifloordiv", "floordiv", nargs=2) -register_simple_overload(operator.ilshift, "ilshift", "lshift", nargs=2) -register_simple_overload(operator.imod, "imod", "mod", nargs=2) -register_simple_overload(operator.imul, "imul", "mul", nargs=2) -register_simple_overload(operator.ipow, "ipow", "pow", nargs=2) -register_simple_overload(operator.irshift, "irshift", "rshift", nargs=2) -register_simple_overload(operator.isub, "isub", "sub", nargs=2) -register_simple_overload(operator.itruediv, "itruediv", "truediv", nargs=2) +register_simple_overload(operator.iadd, "iadd", "add", n_args=2) +register_simple_overload(operator.ifloordiv, "ifloordiv", "floordiv", n_args=2) +register_simple_overload(operator.ilshift, "ilshift", "lshift", n_args=2) +register_simple_overload(operator.imul, "imul", "mul", n_args=2) +register_simple_overload(operator.imod, "imod", "mod", n_args=2) +register_simple_overload(operator.ipow, "ipow", "pow", n_args=2) +register_simple_overload(operator.irshift, "irshift", "rshift", n_args=2) +register_simple_overload(operator.isub, "isub", "sub", n_args=2) +register_simple_overload(operator.itruediv, "itruediv", "truediv", n_args=2) # Logical operators. -register_simple_overload(operator.and_, "and", nargs=2) -register_simple_overload(operator.or_, "or", nargs=2) -register_simple_overload(operator.xor, "xor", nargs=2) +register_simple_overload(operator.and_, "and", n_args=2) +register_simple_overload(operator.or_, "or", n_args=2) +register_simple_overload(operator.xor, "xor", n_args=2) # Inplace logical operators. -register_simple_overload(operator.iand, "iand", "and", nargs=2) -register_simple_overload(operator.ior, "ior", "or", nargs=2) -register_simple_overload(operator.ixor, "ixor", "xor", nargs=2) +register_simple_overload(operator.iand, "iand", "and", n_args=2) +register_simple_overload(operator.ior, "ior", "or", n_args=2) +register_simple_overload(operator.ixor, "ixor", "xor", n_args=2) From 53dcf92ab940b92b1346b9d59e69e0005398440e Mon Sep 17 00:00:00 2001 From: Ethan Pronovost Date: Fri, 30 Oct 2020 21:35:05 -0700 Subject: [PATCH 06/25] Fix float overload --- numba/experimental/jitclass/overloads.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/numba/experimental/jitclass/overloads.py b/numba/experimental/jitclass/overloads.py index 7ebf4373c29..3383406913f 100644 --- a/numba/experimental/jitclass/overloads.py +++ b/numba/experimental/jitclass/overloads.py @@ -119,8 +119,11 @@ def class_contains(x, y): def class_float(x): options = [try_call_method(x, "__float__")] - if (sys.version_info.major, sys.version_info.minor) >= (3, 8): - options.append(try_call_method(x, "__index__")) + if ( + (sys.version_info.major, sys.version_info.minor) >= (3, 8) + and "__index__" in x.jit_methods + ): + options.append(lambda x: float(x.__index__())) return take_first(*options) From aaa74ae75c9481f583ba61826890740aebfbecc0 Mon Sep 17 00:00:00 2001 From: Ethan Pronovost Date: Sat, 31 Oct 2020 19:56:57 -0700 Subject: [PATCH 07/25] Support for hash, eq, and ne Figured out a way around the hash limitation. --- numba/experimental/jitclass/boxing.py | 32 ++++++++++--- numba/experimental/jitclass/overloads.py | 17 ++++++- numba/tests/test_jitclasses.py | 60 ++++++++++++++++++++++++ 3 files changed, 101 insertions(+), 8 deletions(-) diff --git a/numba/experimental/jitclass/boxing.py b/numba/experimental/jitclass/boxing.py index f5c7d157db8..722d3b373e8 100644 --- a/numba/experimental/jitclass/boxing.py +++ b/numba/experimental/jitclass/boxing.py @@ -8,8 +8,9 @@ from llvmlite import ir from numba.core import types, cgutils +from numba.core.decorators import njit from numba.core.pythonapi import box, unbox, NativeValue -from numba import njit +from numba.core.typing.typeof import typeof_impl from numba.experimental.jitclass import _box @@ -103,13 +104,14 @@ def _specialize_box(typ): "__contains__", "__float__", "__getitem__", + "__hash__", "__index__", "__int__", "__len__", "__setitem__", "__str__", - # "__eq__", - # "__ne__", + "__eq__", + "__ne__", "__ge__", "__gt__", "__le__", @@ -144,9 +146,13 @@ def _specialize_box(typ): "__ixor__", } for name, func in typ.methods.items(): - if (not (name.startswith("__") and name.endswith("__")) or - name in supported_dunders): - dct[name] = _generate_method(name, func) + if ( + name.startswith("__") + and name.endswith("__") + and name not in supported_dunders + ): + continue + dct[name] = _generate_method(name, func) # Inject static methods as class members for name, func in typ.static_methods.items(): @@ -236,3 +242,17 @@ def access_member(member_offset): c.context.nrt.incref(c.builder, typ, ret) return NativeValue(ret, is_error=c.pyapi.c_api_error()) + + +# Add a typeof_impl implementation for boxed jitclasses to short-circut the +# various tests in typeof. This is needed for jitclasses which implement a +# custom hash method. Without this, typeof_impl will return None, and one of the +# later attempts to determine the type of the jitclass (before checking for +# _numba_type_) will look up the object in a dictionary, triggering the hash +# method. This will cause the dispatcher to determine the call signature of the +# jit decorated obj.__hash__ method, which will call typeof(obj), and thus +# infinite loop. +# This implementation is here instead of in typeof.py to avoid circular imports. +@typeof_impl.register(_box.Box) +def _typeof_jitclass_box(val, c): + return getattr(type(val), "_numba_type_") diff --git a/numba/experimental/jitclass/overloads.py b/numba/experimental/jitclass/overloads.py index 3383406913f..c310c0a2bbd 100644 --- a/numba/experimental/jitclass/overloads.py +++ b/numba/experimental/jitclass/overloads.py @@ -146,12 +146,25 @@ def class_str(x): ) +@class_instance_overload(operator.eq) +def class_eq(x, y): + # TODO: Fallback to x is y. + return try_call_method(x, "__eq__", 2) + + +@class_instance_overload(operator.ne) +def class_ne(x, y): + return take_first( + try_call_method(x, "__ne__", 2), + lambda x, y: not (x == y), + ) + + register_simple_overload(abs, "abs") register_simple_overload(len, "len") # Comparison operators. -# register_simple_overload(operator.eq, "eq", 2) -# register_simple_overload(operator.ne, "ne", 2) +register_simple_overload(hash, "hash") register_simple_overload(operator.ge, "ge", n_args=2) register_simple_overload(operator.gt, "gt", n_args=2) register_simple_overload(operator.le, "le", n_args=2) diff --git a/numba/tests/test_jitclasses.py b/numba/tests/test_jitclasses.py index d699222933b..c4e40de3ad7 100644 --- a/numba/tests/test_jitclasses.py +++ b/numba/tests/test_jitclasses.py @@ -1634,6 +1634,66 @@ def f(x, y): self.assertEqual(expected, py_f(jit_cls(x), jit_cls(y)).x) self.assertEqual(expected, jit_f(jit_cls(x), jit_cls(y)).x) + def test_hash_eq_ne(self): + + class HashEqTest: + x: int + + def __init__(self, x): + self.x = x + + def __hash__(self): + return self.x % 10 + + def __eq__(self, o): + return (self.x - o.x) % 20 == 0 + + class HashEqNeTest(HashEqTest): + def __ne__(self, o): + return (self.x - o.x) % 20 > 1 + + def py_hash(x): + return hash(x) + + def py_eq(x, y): + return x == y + + def py_ne(x, y): + return x != y + + def identity_decorator(f): + return f + + comparisons = [ + (0, 1), # Will give different ne results. + (2, 22), + (7, 10), + (3, 3), + ] + + for base_cls, use_jit in itertools.product( + [HashEqTest, HashEqNeTest], [False, True] + ): + decorator = njit if use_jit else identity_decorator + hash_func = decorator(py_hash) + eq_func = decorator(py_eq) + ne_func = decorator(py_ne) + + jit_cls = jitclass(base_cls) + + for v in [0, 2, 10, 24, -8]: + self.assertEqual(hash_func(jit_cls(v)), v % 10) + + for x, y in comparisons: + self.assertEqual( + eq_func(jit_cls(x), jit_cls(y)), + base_cls(x) == base_cls(y), + ) + self.assertEqual( + ne_func(jit_cls(x), jit_cls(y)), + base_cls(x) != base_cls(y), + ) + if __name__ == "__main__": unittest.main() From ae38aaf091b6c795526a79398b95955af64b6706 Mon Sep 17 00:00:00 2001 From: Ethan Pronovost Date: Sat, 31 Oct 2020 20:06:34 -0700 Subject: [PATCH 08/25] Remove commented code --- numba/tests/test_jitclasses.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/numba/tests/test_jitclasses.py b/numba/tests/test_jitclasses.py index c4e40de3ad7..1b955b01a31 100644 --- a/numba/tests/test_jitclasses.py +++ b/numba/tests/test_jitclasses.py @@ -1193,13 +1193,6 @@ class FloatWrapper: def __init__(self, value): self.x = value - # def __eq__(self, other): - # print("Eq", self, other, self.x, other.x) - # return self.x == other.x - - # def __ne__(self, other): - # return self.x != other.x - def __ge__(self, other): return self.x >= other.x From 2157ff1f2fedbcb5f296b3d5a8bdf9ac68b05687 Mon Sep 17 00:00:00 2001 From: Ethan Pronovost Date: Sat, 31 Oct 2020 20:20:48 -0700 Subject: [PATCH 09/25] Update docs --- docs/source/user/jitclass.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/user/jitclass.rst b/docs/source/user/jitclass.rst index a29f3349282..47e0a41cade 100644 --- a/docs/source/user/jitclass.rst +++ b/docs/source/user/jitclass.rst @@ -181,6 +181,7 @@ compiled functions: * calling methods (e.g. ``mybag.increment(3)``); * calling static methods as instance attributes (e.g. ``mybag.add(1, 1)``); * calling static methods as class attributes (e.g. ``Bag.add(1, 2)``); +* using select dunder methods (e.g. ``__add__`` with ``mybag + otherbag``); Using jitclasses in Numba compiled function is more efficient. Short methods can be inlined (at the discretion of LLVM inliner). @@ -195,6 +196,8 @@ Calling static methods as class attributes is only supported outside of the class definition (i.e. code cannot call ``Bag.add()`` from within another method of ``Bag``). +See :ghfile:`numba/experimental/jitclass/boxing.py` for the list of supported +dunder methods. Limitations =========== From e010694be4054c34c1354d8f3ece2cd81feaff09 Mon Sep 17 00:00:00 2001 From: Graham Markall Date: Wed, 20 Oct 2021 10:46:00 +0100 Subject: [PATCH 10/25] Type errors for Complex should raise a NumbaTypeError --- numba/core/typing/builtins.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/numba/core/typing/builtins.py b/numba/core/typing/builtins.py index fe7827ba496..9dfb81b51f7 100644 --- a/numba/core/typing/builtins.py +++ b/numba/core/typing/builtins.py @@ -989,7 +989,7 @@ def generic(self, args, kws): if len(args) == 1: [arg] = args if arg not in types.number_domain: - raise TypeError("complex() only support for numbers") + raise errors.NumbaTypeError("complex() only support for numbers") if arg == types.float32: return signature(types.complex64, arg) else: @@ -999,7 +999,7 @@ def generic(self, args, kws): [real, imag] = args if (real not in types.number_domain or imag not in types.number_domain): - raise TypeError("complex() only support for numbers") + raise errors.NumbaTypeError("complex() only support for numbers") if real == imag == types.float32: return signature(types.complex64, real, imag) else: From 80520e4effb1616f11f97ea950ab62d5dd950768 Mon Sep 17 00:00:00 2001 From: Graham Markall Date: Wed, 20 Oct 2021 11:34:48 +0100 Subject: [PATCH 11/25] Small fixups for jitclass testsuite - A type in the name of `get_int_wrapper()` - Swap a docstring for a comment, because the docstring hides the test name when running the testsuite with -v. --- numba/tests/test_jitclasses.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/numba/tests/test_jitclasses.py b/numba/tests/test_jitclasses.py index 1fdc33d8d3a..e9fa89fd81b 100644 --- a/numba/tests/test_jitclasses.py +++ b/numba/tests/test_jitclasses.py @@ -1163,7 +1163,7 @@ def __str__(self): return "PyList non-empty" @staticmethod - def get_int_warpper(): + def get_int_wrapper(): @jitclass([("x", types.intp)]) class IntWrapper: def __init__(self, value): @@ -1240,9 +1240,8 @@ def assertSame(self, first, second, msg=None): self.assertEqual(first, second, msg=msg) def test_overloads(self): - """ - Check that the dunder methods are exposed on ClassInstanceType. - """ + # Check that the dunder methods are exposed on ClassInstanceType. + JitList = jitclass({"x": types.List(types.intp)})(self.PyList) py_funcs = [ @@ -1452,7 +1451,7 @@ def __index__(self): self.assertSame(jit_i(obj), 7) def test_arithmetic_logical(self): - IntWrapper = self.get_int_warpper() + IntWrapper = self.get_int_wrapper() FloatWrapper = self.get_float_wrapper() float_py_funcs = [ @@ -1522,7 +1521,7 @@ def unwrap(value): def test_arithmetic_logical_inplace(self): # If __i*__ methods are not defined, should fall back to normal methods. - JitIntWrapper = self.get_int_warpper() + JitIntWrapper = self.get_int_wrapper() JitFloatWrapper = self.get_float_wrapper() PyIntWrapper = JitIntWrapper.mro()[1] From 987093f424769c7a27492f75a46542fa884bbf6d Mon Sep 17 00:00:00 2001 From: Graham Markall Date: Wed, 20 Oct 2021 11:44:12 +0100 Subject: [PATCH 12/25] Test eq and ne with float and int wrappers --- numba/tests/test_jitclasses.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/numba/tests/test_jitclasses.py b/numba/tests/test_jitclasses.py index e9fa89fd81b..79873593657 100644 --- a/numba/tests/test_jitclasses.py +++ b/numba/tests/test_jitclasses.py @@ -1169,6 +1169,12 @@ class IntWrapper: def __init__(self, value): self.x = value + def __eq__(self, other): + return self.x == other.x + + def __hash__(self): + return self.x + def __lshift__(self, other): return IntWrapper(self.x << other.x) @@ -1194,6 +1200,12 @@ class FloatWrapper: def __init__(self, value): self.x = value + def __eq__(self, other): + return self.x == other.x + + def __hash__(self): + return self.x + def __ge__(self, other): return self.x >= other.x @@ -1455,8 +1467,8 @@ def test_arithmetic_logical(self): FloatWrapper = self.get_float_wrapper() float_py_funcs = [ - # lambda x, y: x == y, - # lambda x, y: x != y, + lambda x, y: x == y, + lambda x, y: x != y, lambda x, y: x >= y, lambda x, y: x > y, lambda x, y: x <= y, @@ -1470,6 +1482,8 @@ def test_arithmetic_logical(self): lambda x, y: x / y, ] int_py_funcs = [ + lambda x, y: x == y, + lambda x, y: x != y, lambda x, y: x << y, lambda x, y: x >> y, lambda x, y: x & y, From 56a2d09de8560c1ca2108e03cab22e3f672a4105 Mon Sep 17 00:00:00 2001 From: Graham Markall Date: Wed, 20 Oct 2021 11:48:50 +0100 Subject: [PATCH 13/25] Remove ineg and ipos from the list of supported jitclass dunders As far as I can tell, these don't exist. --- numba/experimental/jitclass/boxing.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/numba/experimental/jitclass/boxing.py b/numba/experimental/jitclass/boxing.py index 722d3b373e8..7d70f561a1b 100644 --- a/numba/experimental/jitclass/boxing.py +++ b/numba/experimental/jitclass/boxing.py @@ -135,8 +135,6 @@ def _specialize_box(typ): "__ilshift__", "__imod__", "__imul__", - "__ineg__", - "__ipos__", "__ipow__", "__irshift__", "__isub__", From 5396621801e9fc87e0378247b4ca180df57e856e Mon Sep 17 00:00:00 2001 From: Graham Markall Date: Wed, 20 Oct 2021 11:56:19 +0100 Subject: [PATCH 14/25] List supported jitclass dunder methods in docs --- docs/source/user/jitclass.rst | 57 +++++++++++++++++++++++++++++++++-- 1 file changed, 55 insertions(+), 2 deletions(-) diff --git a/docs/source/user/jitclass.rst b/docs/source/user/jitclass.rst index e6a0b414277..73cf406d992 100644 --- a/docs/source/user/jitclass.rst +++ b/docs/source/user/jitclass.rst @@ -197,8 +197,61 @@ Calling static methods as class attributes is only supported outside of the class definition (i.e. code cannot call ``Bag.add()`` from within another method of ``Bag``). -See :ghfile:`numba/experimental/jitclass/boxing.py` for the list of supported -dunder methods. + +Supported dunder methods +------------------------ + +The following dunder methods may be defined for jitclasses: + +* ``__abs__`` +* ``__bool__`` +* ``__complex__`` +* ``__contains__`` +* ``__float__`` +* ``__getitem__`` +* ``__hash__`` +* ``__index__`` +* ``__int__`` +* ``__len__`` +* ``__setitem__`` +* ``__str__`` +* ``__eq__`` +* ``__ne__`` +* ``__ge__`` +* ``__gt__`` +* ``__le__`` +* ``__lt__`` +* ``__add__`` +* ``__floordiv__`` +* ``__lshift__`` +* ``__mod__`` +* ``__mul__`` +* ``__neg__`` +* ``__pos__`` +* ``__pow__`` +* ``__rshift__`` +* ``__sub__`` +* ``__truediv__`` +* ``__and__`` +* ``__or__`` +* ``__xor__`` +* ``__iadd__`` +* ``__ifloordiv__`` +* ``__ilshift__`` +* ``__imod__`` +* ``__imul__`` +* ``__ipow__`` +* ``__irshift__`` +* ``__isub__`` +* ``__itruediv__`` +* ``__iand__`` +* ``__ior__`` +* ``__ixor__`` + +Refer to the `Python Data Model documentation +`_ for descriptions of +these methods. + Limitations =========== From 352379a22366793e84dcad2e2d80ca9e1a0524b7 Mon Sep 17 00:00:00 2001 From: Graham Markall Date: Mon, 13 Jun 2022 23:21:55 +0100 Subject: [PATCH 15/25] Revisions based on PR #5877 feedback - Mixins should go first in inheritance list - numba.core.utils.PYVERSION should be used to check the Python version --- numba/experimental/jitclass/overloads.py | 5 +++-- numba/tests/test_jitclasses.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/numba/experimental/jitclass/overloads.py b/numba/experimental/jitclass/overloads.py index c310c0a2bbd..4fb526da67c 100644 --- a/numba/experimental/jitclass/overloads.py +++ b/numba/experimental/jitclass/overloads.py @@ -9,6 +9,7 @@ from numba.core.extending import overload from numba.core.types import ClassInstanceType +from numba.core.utils import PYVERSION def _get_args(n_args): @@ -120,7 +121,7 @@ def class_float(x): options = [try_call_method(x, "__float__")] if ( - (sys.version_info.major, sys.version_info.minor) >= (3, 8) + PYVERSION >= (3, 8) and "__index__" in x.jit_methods ): options.append(lambda x: float(x.__index__())) @@ -132,7 +133,7 @@ def class_float(x): def class_int(x): options = [try_call_method(x, "__int__")] - if (sys.version_info.major, sys.version_info.minor) >= (3, 8): + if PYVERSION >= (3, 8): options.append(try_call_method(x, "__index__")) return take_first(*options) diff --git a/numba/tests/test_jitclasses.py b/numba/tests/test_jitclasses.py index 79873593657..0f13f54b47a 100644 --- a/numba/tests/test_jitclasses.py +++ b/numba/tests/test_jitclasses.py @@ -15,6 +15,7 @@ from numba.core.dispatcher import Dispatcher from numba.core.errors import LoweringError, TypingError from numba.core.runtime.nrt import MemInfo +from numba.core.utils import PYVERSION from numba.experimental import jitclass from numba.experimental.jitclass import _box from numba.experimental.jitclass.base import JitClassType @@ -1114,7 +1115,7 @@ def __init__(self): self.assertDictEqual(JitTest2.class_type.struct, spec) -class TestJitClassOverloads(TestCase, MemoryLeakMixin): +class TestJitClassOverloads(MemoryLeakMixin, TestCase): class PyList: def __init__(self): @@ -1413,7 +1414,7 @@ def __index__(self): obj = IndexClass() - if sys.version[:3] >= "3.8": + if PYVERSION >= (3, 8): self.assertSame(py_c(obj), complex(1)) self.assertSame(jit_c(obj), complex(1)) self.assertSame(py_f(obj), 1.) From 5362fd99e97a13856b84e519107d3997c3c32d51 Mon Sep 17 00:00:00 2001 From: Graham Markall Date: Tue, 14 Jun 2022 11:02:28 +0100 Subject: [PATCH 16/25] Add failing test showing issue with jitclass bool Identified in https://github.com/numba/numba/pull/5877/files#r892530803, in which the implementation returns the result of `__len__()` when there is no `__bool__()` implementation. Fails with: ``` ====================================================================== FAIL: test_bool_fallback (numba.tests.test_jitclasses.TestJitClassOverloads) ---------------------------------------------------------------------- Traceback (most recent call last): File "C:\Users\gmarkall\work\numbadev\numba\numba\tests\test_jitclasses.py", line 1742, in test_bool_fallback self.assertEqual(py_class_2_bool, jitted_class_2_bool) AssertionError: True != 2 ``` --- numba/tests/test_jitclasses.py | 43 ++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/numba/tests/test_jitclasses.py b/numba/tests/test_jitclasses.py index 0f13f54b47a..b91e02003c2 100644 --- a/numba/tests/test_jitclasses.py +++ b/numba/tests/test_jitclasses.py @@ -1702,6 +1702,49 @@ def identity_decorator(f): base_cls(x) != base_cls(y), ) + def test_bool_fallback(self): + # Check that the fallback to using len(obj) to determine truth of an + # object is implemented correctly as per + # https://docs.python.org/3/library/stdtypes.html#truth-value-testing + # + # Relevant points: + # + # "By default, an object is considered true unless its class defines + # either a __bool__() method that returns False or a __len__() method + # that returns zero, when called with the object." + # + # and: + # + # "Operations and built-in functions that have a Boolean result always + # return 0 or False for false and 1 or True for true, unless otherwise + # stated." + + class NoBoolHasLen: + def __init__(self, val): + self.val = val + + def __len__(self): + return self.val + + def get_bool(self): + return bool(self) + + py_class = NoBoolHasLen + jitted_class = jitclass([('val', types.int64)])(py_class) + + py_class_0_bool = py_class(0).get_bool() + py_class_2_bool = py_class(2).get_bool() + jitted_class_0_bool = jitted_class(0).get_bool() + jitted_class_2_bool = jitted_class(2).get_bool() + + # Truth values from bool(obj) should be equal + self.assertEqual(py_class_0_bool, jitted_class_0_bool) + self.assertEqual(py_class_2_bool, jitted_class_2_bool) + + # Truth values from bool(obj) should be the same type + self.assertEqual(type(py_class_0_bool), type(jitted_class_0_bool)) + self.assertEqual(type(py_class_2_bool), type(jitted_class_2_bool)) + if __name__ == "__main__": unittest.main() From acf4fe3d6fa54bff1f68c11e8a64af84473ad312 Mon Sep 17 00:00:00 2001 From: Graham Markall Date: Tue, 14 Jun 2022 11:15:15 +0100 Subject: [PATCH 17/25] Fix the jitclass bool() fallback logic --- numba/experimental/jitclass/overloads.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/numba/experimental/jitclass/overloads.py b/numba/experimental/jitclass/overloads.py index 4fb526da67c..431ebad424c 100644 --- a/numba/experimental/jitclass/overloads.py +++ b/numba/experimental/jitclass/overloads.py @@ -94,11 +94,17 @@ def take_first(*options): @class_instance_overload(bool) def class_bool(x): - return take_first( - try_call_method(x, "__bool__"), - try_call_method(x, "__len__"), - lambda x: True, - ) + using_bool_impl = try_call_method(x, "__bool__") + + if '__len__' in x.jit_methods: + def using_len_impl(x): + return bool(len(x)) + else: + using_len_impl = None + + always_true_impl = lambda x: True + + return take_first(using_bool_impl, using_len_impl, always_true_impl) @class_instance_overload(complex) From 15841e23f09690a4d352c50052e839a7bd0e354c Mon Sep 17 00:00:00 2001 From: Graham Markall Date: Tue, 14 Jun 2022 11:24:27 +0100 Subject: [PATCH 18/25] Add a test for the default bool class implementation --- numba/tests/test_jitclasses.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/numba/tests/test_jitclasses.py b/numba/tests/test_jitclasses.py index b91e02003c2..513e6472927 100644 --- a/numba/tests/test_jitclasses.py +++ b/numba/tests/test_jitclasses.py @@ -1745,6 +1745,29 @@ def get_bool(self): self.assertEqual(type(py_class_0_bool), type(jitted_class_0_bool)) self.assertEqual(type(py_class_2_bool), type(jitted_class_2_bool)) + def test_bool_fallback_default(self): + # Similar to test_bool_fallback, but checks the case where there is no + # __bool__() or __len__() defined, so the object should always be True. + + class NoBoolNoLen: + def __init__(self): + pass + + def get_bool(self): + return bool(self) + + py_class = NoBoolNoLen + jitted_class = jitclass([])(py_class) + + py_class_bool = py_class().get_bool() + jitted_class_bool = jitted_class().get_bool() + + # Truth values from bool(obj) should be equal + self.assertEqual(py_class_bool, jitted_class_bool) + + # Truth values from bool(obj) should be the same type + self.assertEqual(type(py_class_bool), type(jitted_class_bool)) + if __name__ == "__main__": unittest.main() From 85af969b7ca2af742a780652b57a40cc50bc9581 Mon Sep 17 00:00:00 2001 From: Graham Markall Date: Tue, 14 Jun 2022 11:32:18 +0100 Subject: [PATCH 19/25] Fix flake8 --- numba/experimental/jitclass/overloads.py | 1 - numba/tests/test_jitclasses.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/numba/experimental/jitclass/overloads.py b/numba/experimental/jitclass/overloads.py index 431ebad424c..40b108344b6 100644 --- a/numba/experimental/jitclass/overloads.py +++ b/numba/experimental/jitclass/overloads.py @@ -5,7 +5,6 @@ from functools import wraps import inspect import operator -import sys from numba.core.extending import overload from numba.core.types import ClassInstanceType diff --git a/numba/tests/test_jitclasses.py b/numba/tests/test_jitclasses.py index 513e6472927..02c8ff1ee78 100644 --- a/numba/tests/test_jitclasses.py +++ b/numba/tests/test_jitclasses.py @@ -2,7 +2,6 @@ import itertools import pickle import random -import sys import typing as pt import unittest @@ -1702,7 +1701,7 @@ def identity_decorator(f): base_cls(x) != base_cls(y), ) - def test_bool_fallback(self): + def test_bool_fallback_len(self): # Check that the fallback to using len(obj) to determine truth of an # object is implemented correctly as per # https://docs.python.org/3/library/stdtypes.html#truth-value-testing From 33f3885ac6a0a143cda32b70f88db298ab7fdc00 Mon Sep 17 00:00:00 2001 From: Graham Markall Date: Wed, 15 Jun 2022 11:29:07 +0100 Subject: [PATCH 20/25] Begin implementation of reflected comparison ops Still requires: - Testing - Proper implementation for __eq__ and __ne__ - Refactoring for consistency with the rest of the file --- numba/experimental/jitclass/overloads.py | 25 +++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/numba/experimental/jitclass/overloads.py b/numba/experimental/jitclass/overloads.py index 40b108344b6..614ea967684 100644 --- a/numba/experimental/jitclass/overloads.py +++ b/numba/experimental/jitclass/overloads.py @@ -165,16 +165,31 @@ def class_ne(x, y): lambda x, y: not (x == y), ) +def register_reflected_overload(func, meth_forward, meth_reflected): + def class_lt(x, y): + normal_impl = try_call_method(x, meth_forward, 2) + + if meth_reflected in y.jit_methods: + def reflected_impl(x, y): + return y > x + else: + reflected_impl = None + + return take_first(normal_impl, reflected_impl) + + class_instance_overload(func)(class_lt) + + register_simple_overload(abs, "abs") register_simple_overload(len, "len") +register_simple_overload(hash, "hash") # Comparison operators. -register_simple_overload(hash, "hash") -register_simple_overload(operator.ge, "ge", n_args=2) -register_simple_overload(operator.gt, "gt", n_args=2) -register_simple_overload(operator.le, "le", n_args=2) -register_simple_overload(operator.lt, "lt", n_args=2) +register_reflected_overload(operator.ge, "__ge__", "__le__") +register_reflected_overload(operator.gt, "__gt__", "__lt__") +register_reflected_overload(operator.le, "__le__", "__ge__") +register_reflected_overload(operator.lt, "__lt__", "__gt__") # Arithmetic operators. register_simple_overload(operator.add, "add", n_args=2) From 64991af4a342fa81fd41bd497f40a86ebc8e73b6 Mon Sep 17 00:00:00 2001 From: Graham Markall Date: Fri, 17 Jun 2022 10:43:27 +0100 Subject: [PATCH 21/25] Fix flake8 --- numba/experimental/jitclass/overloads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numba/experimental/jitclass/overloads.py b/numba/experimental/jitclass/overloads.py index 614ea967684..be4d4b2881b 100644 --- a/numba/experimental/jitclass/overloads.py +++ b/numba/experimental/jitclass/overloads.py @@ -165,6 +165,7 @@ def class_ne(x, y): lambda x, y: not (x == y), ) + def register_reflected_overload(func, meth_forward, meth_reflected): def class_lt(x, y): normal_impl = try_call_method(x, meth_forward, 2) @@ -180,7 +181,6 @@ def reflected_impl(x, y): class_instance_overload(func)(class_lt) - register_simple_overload(abs, "abs") register_simple_overload(len, "len") register_simple_overload(hash, "hash") From 210f7b43f7a063ce479bdf0d9fc3de62b9afb34f Mon Sep 17 00:00:00 2001 From: Graham Markall Date: Fri, 17 Jun 2022 10:49:05 +0100 Subject: [PATCH 22/25] Make register_reflected_overload interface consistent with other methods --- numba/experimental/jitclass/overloads.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/numba/experimental/jitclass/overloads.py b/numba/experimental/jitclass/overloads.py index be4d4b2881b..9a80294a453 100644 --- a/numba/experimental/jitclass/overloads.py +++ b/numba/experimental/jitclass/overloads.py @@ -168,9 +168,9 @@ def class_ne(x, y): def register_reflected_overload(func, meth_forward, meth_reflected): def class_lt(x, y): - normal_impl = try_call_method(x, meth_forward, 2) + normal_impl = try_call_method(x, f"__{meth_forward}__", 2) - if meth_reflected in y.jit_methods: + if f"__{meth_reflected}__" in y.jit_methods: def reflected_impl(x, y): return y > x else: @@ -186,10 +186,10 @@ def reflected_impl(x, y): register_simple_overload(hash, "hash") # Comparison operators. -register_reflected_overload(operator.ge, "__ge__", "__le__") -register_reflected_overload(operator.gt, "__gt__", "__lt__") -register_reflected_overload(operator.le, "__le__", "__ge__") -register_reflected_overload(operator.lt, "__lt__", "__gt__") +register_reflected_overload(operator.ge, "ge", "le") +register_reflected_overload(operator.gt, "gt", "lt") +register_reflected_overload(operator.le, "le", "ge") +register_reflected_overload(operator.lt, "lt", "gt") # Arithmetic operators. register_simple_overload(operator.add, "add", n_args=2) From 97ccc544dba623c735865806e7965f4873c68a90 Mon Sep 17 00:00:00 2001 From: Graham Markall Date: Fri, 17 Jun 2022 11:16:38 +0100 Subject: [PATCH 23/25] Add jitclass operator reflection tests Also use register_reflected_overload for __eq__. This test exposes a bug - if a class defines __eq__ but not __hash__, compilation fails due to __hash__ being None. This needs resolving. --- numba/experimental/jitclass/overloads.py | 13 +++--- numba/tests/test_jitclasses.py | 53 ++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 6 deletions(-) diff --git a/numba/experimental/jitclass/overloads.py b/numba/experimental/jitclass/overloads.py index 9a80294a453..cf0e0ff3e08 100644 --- a/numba/experimental/jitclass/overloads.py +++ b/numba/experimental/jitclass/overloads.py @@ -152,14 +152,11 @@ def class_str(x): ) -@class_instance_overload(operator.eq) -def class_eq(x, y): - # TODO: Fallback to x is y. - return try_call_method(x, "__eq__", 2) - - @class_instance_overload(operator.ne) def class_ne(x, y): + # This doesn't use register_reflected_overload like the other operators + # because it falls back to inverting __eq__ rather than reflecting its + # arguments (as per the definition of the Python data model). return take_first( try_call_method(x, "__ne__", 2), lambda x, y: not (x == y), @@ -191,6 +188,10 @@ def reflected_impl(x, y): register_reflected_overload(operator.le, "le", "ge") register_reflected_overload(operator.lt, "lt", "gt") +# Note that eq is missing support for fallback to `x is y`, but `is` and +# `operator.is` are presently unsupported in general. +register_reflected_overload(operator.eq, "eq", "eq") + # Arithmetic operators. register_simple_overload(operator.add, "add", n_args=2) register_simple_overload(operator.floordiv, "floordiv", n_args=2) diff --git a/numba/tests/test_jitclasses.py b/numba/tests/test_jitclasses.py index 1ada3da9856..d42caf03155 100644 --- a/numba/tests/test_jitclasses.py +++ b/numba/tests/test_jitclasses.py @@ -1808,6 +1808,59 @@ def get_bool(self): # Truth values from bool(obj) should be the same type self.assertEqual(type(py_class_bool), type(jitted_class_bool)) + def test_operator_reflection(self): + class OperatorsDefined: + def __init__(self, x): + self.x = x + + def __eq__(self, other): + return self.x == other.x + + def __hash__(self): + # FIXME: This should not be needed fix for #5877! + return None + + def __le__(self, other): + return self.x <= other.x + + def __lt__(self, other): + return self.x < other.x + + def __ge__(self, other): + return self.x >= other.x + + def __gt__(self, other): + return self.x > other.x + + class NoOperatorsDefined: + def __init__(self, x): + self.x = x + + spec = [('x', types.int32)] + JitOperatorsDefined = jitclass(spec)(OperatorsDefined) + JitNoOperatorsDefined = jitclass(spec)(NoOperatorsDefined) + + py_ops_defined = OperatorsDefined(2) + py_ops_not_defined = NoOperatorsDefined(3) + + jit_ops_defined = JitOperatorsDefined(2) + jit_ops_not_defined = JitNoOperatorsDefined(3) + + self.assertEqual(py_ops_not_defined == py_ops_defined, + jit_ops_not_defined == jit_ops_defined) + + self.assertEqual(py_ops_not_defined <= py_ops_defined, + jit_ops_not_defined <= jit_ops_defined) + + self.assertEqual(py_ops_not_defined < py_ops_defined, + jit_ops_not_defined < jit_ops_defined) + + self.assertEqual(py_ops_not_defined >= py_ops_defined, + jit_ops_not_defined >= jit_ops_defined) + + self.assertEqual(py_ops_not_defined > py_ops_defined, + jit_ops_not_defined > jit_ops_defined) + if __name__ == "__main__": unittest.main() From 92f9495bacfdd154d31cfb4e794e0669c8c7f17b Mon Sep 17 00:00:00 2001 From: Graham Markall Date: Fri, 17 Jun 2022 11:49:17 +0100 Subject: [PATCH 24/25] Handle implicit __hash__ member and add test --- numba/experimental/jitclass/base.py | 6 ++++++ numba/tests/test_jitclasses.py | 19 +++++++++++++++---- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/numba/experimental/jitclass/base.py b/numba/experimental/jitclass/base.py index 0153edc5419..d61dc60e3f4 100644 --- a/numba/experimental/jitclass/base.py +++ b/numba/experimental/jitclass/base.py @@ -293,6 +293,12 @@ def _drop_ignored_attrs(dct): elif getattr(v, '__objclass__', None) is object: drop.add(k) + # If a class defines __eq__ but not __hash__, __hash__ is implicitly set to + # None. This is a class member, and class members are not presently + # supported. + if '__hash__' in dct and dct['__hash__'] is None: + drop.add('__hash__') + for k in drop: del dct[k] diff --git a/numba/tests/test_jitclasses.py b/numba/tests/test_jitclasses.py index d42caf03155..091ee02b39b 100644 --- a/numba/tests/test_jitclasses.py +++ b/numba/tests/test_jitclasses.py @@ -1816,10 +1816,6 @@ def __init__(self, x): def __eq__(self, other): return self.x == other.x - def __hash__(self): - # FIXME: This should not be needed fix for #5877! - return None - def __le__(self, other): return self.x <= other.x @@ -1861,6 +1857,21 @@ def __init__(self, x): self.assertEqual(py_ops_not_defined > py_ops_defined, jit_ops_not_defined > jit_ops_defined) + def test_implicit_hash_compiles(self): + # Ensure that classes with __hash__ implicitly defined as None due to + # the presence of __eq__ are correctly handled by ignoring the __hash__ + # class member. + class ImplicitHash: + def __init__(self): + pass + + def __eq__(self, other): + return False + + jitted = jitclass([])(ImplicitHash) + instance = jitted() + + self.assertFalse(instance == instance) if __name__ == "__main__": unittest.main() From 6bfb69b493b5484518e9c9dbb49e20d451d6d66b Mon Sep 17 00:00:00 2001 From: Graham Markall Date: Fri, 17 Jun 2022 12:17:29 +0100 Subject: [PATCH 25/25] Fix flake8 --- numba/tests/test_jitclasses.py | 1 + 1 file changed, 1 insertion(+) diff --git a/numba/tests/test_jitclasses.py b/numba/tests/test_jitclasses.py index 091ee02b39b..f9dc74dacb8 100644 --- a/numba/tests/test_jitclasses.py +++ b/numba/tests/test_jitclasses.py @@ -1873,5 +1873,6 @@ def __eq__(self, other): self.assertFalse(instance == instance) + if __name__ == "__main__": unittest.main()