Skip to content

Commit

Permalink
Merge pull request numba#5877 from EPronovost/epronovost/jitclass-bui…
Browse files Browse the repository at this point in the history
…ltins

Jitclass builtin methods
  • Loading branch information
sklam authored Jun 21, 2022
2 parents 63d0c47 + 6bfb69b commit 7dab9b9
Show file tree
Hide file tree
Showing 7 changed files with 1,084 additions and 11 deletions.
56 changes: 56 additions & 0 deletions docs/source/user/jitclass.rst
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ compiled functions:
* calling methods (e.g. ``mybag.increment(3)``);
* calling static methods as instance attributes (e.g. ``mybag.add(1, 1)``);
* calling static methods as class attributes (e.g. ``Bag.add(1, 2)``);
* using select dunder methods (e.g. ``__add__`` with ``mybag + otherbag``);

Using jitclasses in Numba compiled function is more efficient.
Short methods can be inlined (at the discretion of LLVM inliner).
Expand All @@ -197,6 +198,61 @@ class definition (i.e. code cannot call ``Bag.add()`` from within another method
of ``Bag``).


Supported dunder methods
------------------------

The following dunder methods may be defined for jitclasses:

* ``__abs__``
* ``__bool__``
* ``__complex__``
* ``__contains__``
* ``__float__``
* ``__getitem__``
* ``__hash__``
* ``__index__``
* ``__int__``
* ``__len__``
* ``__setitem__``
* ``__str__``
* ``__eq__``
* ``__ne__``
* ``__ge__``
* ``__gt__``
* ``__le__``
* ``__lt__``
* ``__add__``
* ``__floordiv__``
* ``__lshift__``
* ``__mod__``
* ``__mul__``
* ``__neg__``
* ``__pos__``
* ``__pow__``
* ``__rshift__``
* ``__sub__``
* ``__truediv__``
* ``__and__``
* ``__or__``
* ``__xor__``
* ``__iadd__``
* ``__ifloordiv__``
* ``__ilshift__``
* ``__imod__``
* ``__imul__``
* ``__ipow__``
* ``__irshift__``
* ``__isub__``
* ``__itruediv__``
* ``__iand__``
* ``__ior__``
* ``__ixor__``

Refer to the `Python Data Model documentation
<https://docs.python.org/3/reference/datamodel.html>`_ for descriptions of
these methods.


Limitations
===========

Expand Down
4 changes: 2 additions & 2 deletions numba/core/typing/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,7 +992,7 @@ def generic(self, args, kws):
if len(args) == 1:
[arg] = args
if arg not in types.number_domain:
raise TypeError("complex() only support for numbers")
raise errors.NumbaTypeError("complex() only support for numbers")
if arg == types.float32:
return signature(types.complex64, arg)
else:
Expand All @@ -1002,7 +1002,7 @@ def generic(self, args, kws):
[real, imag] = args
if (real not in types.number_domain or
imag not in types.number_domain):
raise TypeError("complex() only support for numbers")
raise errors.NumbaTypeError("complex() only support for numbers")
if real == imag == types.float32:
return signature(types.complex64, real, imag)
else:
Expand Down
1 change: 1 addition & 0 deletions numba/experimental/jitclass/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from numba.experimental.jitclass.decorators import jitclass
from numba.experimental.jitclass import boxing # Has import-time side effect
from numba.experimental.jitclass import overloads # Has import-time side effect
6 changes: 6 additions & 0 deletions numba/experimental/jitclass/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,12 @@ def _drop_ignored_attrs(dct):
elif getattr(v, '__objclass__', None) is object:
drop.add(k)

# If a class defines __eq__ but not __hash__, __hash__ is implicitly set to
# None. This is a class member, and class members are not presently
# supported.
if '__hash__' in dct and dct['__hash__'] is None:
drop.add('__hash__')

for k in drop:
del dct[k]

Expand Down
78 changes: 70 additions & 8 deletions numba/experimental/jitclass/boxing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
from llvmlite import ir

from numba.core import types, cgutils
from numba.core.decorators import njit
from numba.core.pythonapi import box, unbox, NativeValue
from numba import njit
from numba.core.typing.typeof import typeof_impl
from numba.experimental.jitclass import _box


_getter_code_template = """
Expand Down Expand Up @@ -95,18 +97,66 @@ def _specialize_box(typ):
doc = getattr(imp, '__doc__', None)
dct[field] = property(getter, setter, doc=doc)
# Inject methods as class members
supported_dunders = {
"__abs__",
"__bool__",
"__complex__",
"__contains__",
"__float__",
"__getitem__",
"__hash__",
"__index__",
"__int__",
"__len__",
"__setitem__",
"__str__",
"__eq__",
"__ne__",
"__ge__",
"__gt__",
"__le__",
"__lt__",
"__add__",
"__floordiv__",
"__lshift__",
"__mod__",
"__mul__",
"__neg__",
"__pos__",
"__pow__",
"__rshift__",
"__sub__",
"__truediv__",
"__and__",
"__or__",
"__xor__",
"__iadd__",
"__ifloordiv__",
"__ilshift__",
"__imod__",
"__imul__",
"__ipow__",
"__irshift__",
"__isub__",
"__itruediv__",
"__iand__",
"__ior__",
"__ixor__",
}
for name, func in typ.methods.items():
if (name == "__getitem__" or name == "__setitem__") or \
(not (name.startswith('__') and name.endswith('__'))):

dct[name] = _generate_method(name, func)
if (
name.startswith("__")
and name.endswith("__")
and name not in supported_dunders
):
continue
dct[name] = _generate_method(name, func)

# Inject static methods as class members
for name, func in typ.static_methods.items():
dct[name] = _generate_method(name, func)

# Create subclass
from numba.experimental.jitclass import _box
subcls = type(typ.classname, (_box.Box,), dct)
# Store to cache
_cache_specialized_box[typ] = subcls
Expand Down Expand Up @@ -159,7 +209,6 @@ def set_member(member_offset, value):
casted = c.builder.bitcast(ptr, llvoidptr.as_pointer())
c.builder.store(value, casted)

from numba.experimental.jitclass import _box
set_member(_box.box_meminfoptr_offset, addr_meminfo)
set_member(_box.box_dataptr_offset, addr_data)
return box
Expand All @@ -179,7 +228,6 @@ def access_member(member_offset):
inst = struct_cls(c.context, c.builder)

# load from Python object
from numba.experimental.jitclass import _box
ptr_meminfo = access_member(_box.box_meminfoptr_offset)
ptr_dataptr = access_member(_box.box_dataptr_offset)

Expand All @@ -192,3 +240,17 @@ def access_member(member_offset):
c.context.nrt.incref(c.builder, typ, ret)

return NativeValue(ret, is_error=c.pyapi.c_api_error())


# Add a typeof_impl implementation for boxed jitclasses to short-circut the
# various tests in typeof. This is needed for jitclasses which implement a
# custom hash method. Without this, typeof_impl will return None, and one of the
# later attempts to determine the type of the jitclass (before checking for
# _numba_type_) will look up the object in a dictionary, triggering the hash
# method. This will cause the dispatcher to determine the call signature of the
# jit decorated obj.__hash__ method, which will call typeof(obj), and thus
# infinite loop.
# This implementation is here instead of in typeof.py to avoid circular imports.
@typeof_impl.register(_box.Box)
def _typeof_jitclass_box(val, c):
return getattr(type(val), "_numba_type_")
Loading

0 comments on commit 7dab9b9

Please sign in to comment.