diff --git a/news/callbacks.rst b/news/callbacks.rst new file mode 100644 index 0000000..e5efbfa --- /dev/null +++ b/news/callbacks.rst @@ -0,0 +1,13 @@ +**Added:** + +Callbacks can now be supplied to the ray caster. This includes callbacks written in Cython, subclassing from the RayCollisionCallback cython class, and callbacks generated in Python that instantiate a PythonCallback object. + +**Changed:** None + +**Deprecated:** None + +**Removed:** None + +**Fixed:** None + +**Security:** None diff --git a/pyembree/callback_handler.pxd b/pyembree/callback_handler.pxd new file mode 100644 index 0000000..a79405e --- /dev/null +++ b/pyembree/callback_handler.pxd @@ -0,0 +1,15 @@ +from rtcore_ray cimport RTCRay + +cdef enum: + _CALLBACK_TERMINATE = 0 + _CALLBACK_CONTINUE = 1 + +cdef class RayCollisionCallback: + # The function callback needs to return either _CALLBACK_TERMINATE or + # _CALLBACK_CONTINUE. _CALLBACK_CONTINUE will keep it running, but + # assumes that you have done something to the ray. Otherwise it will + # enter into an endless loop. + cdef int callback(self, RTCRay &ray) + +cdef class RayCollisionNull(RayCollisionCallback): + pass diff --git a/pyembree/callback_handler.pyx b/pyembree/callback_handler.pyx new file mode 100644 index 0000000..19f9901 --- /dev/null +++ b/pyembree/callback_handler.pyx @@ -0,0 +1,49 @@ +from rtcore_ray cimport RTCRay + +# This is to make them accessible from Python +CALLBACK_TERMINATE = _CALLBACK_TERMINATE +CALLBACK_CONTINUE = _CALLBACK_CONTINUE + +cdef class RayCollisionCallback: + cdef int callback(self, RTCRay &ray): + return CALLBACK_TERMINATE + +cdef class PythonCallback(RayCollisionCallback): + # This class lets you specify a python function that can modify in situ the + # rays that are arriving. Changes will be reflected. + cdef public object callback_function + def __init__(self, callback_function): + self.callback_function = callback_function + + cdef int callback(self, RTCRay &ray): + ray_info = dict( + org = (ray.org[0], ray.org[1], ray.org[2]), + dir = (ray.dir[0], ray.dir[1], ray.dir[2]), + tnear = ray.tnear, + tfar = ray.tfar, + time = ray.time, + mask = ray.mask, + Ng = (ray.Ng[0], ray.Ng[1], ray.Ng[2]), + u = ray.u, + v = ray.v, + geomID = ray.geomID, + primID = ray.primID, + instID = ray.instID + ) + rv = self.callback_function(ray_info) + # We now update the ray contents from the dictionary + for i in range(3): + ray.org[i] = ray_info['org'][i] + ray.dir[i] = ray_info['dir'][i] + ray.Ng[i] = ray_info['Ng'][i] + ray.tnear = ray_info['tnear'] + ray.tfar = ray_info['tfar'] + ray.mask = ray_info['mask'] + ray.u = ray_info['u'] + ray.v = ray_info['v'] + ray.geomID = ray_info['geomID'] + ray.primID = ray_info['primID'] + ray.instID = ray_info['instID'] + if rv == _CALLBACK_CONTINUE: + return _CALLBACK_CONTINUE + return _CALLBACK_TERMINATE diff --git a/pyembree/rtcore_scene.pyx b/pyembree/rtcore_scene.pyx index be5e58e..90ac323 100644 --- a/pyembree/rtcore_scene.pyx +++ b/pyembree/rtcore_scene.pyx @@ -6,6 +6,8 @@ import numbers cimport rtcore as rtc cimport rtcore_ray as rtcr cimport rtcore_geometry as rtcg +from .callback_handler cimport \ + RayCollisionCallback, RayCollisionNull, _CALLBACK_TERMINATE, _CALLBACK_CONTINUE log = logging.getLogger('pyembree') @@ -35,12 +37,16 @@ cdef class EmbreeScene: def run(self, np.ndarray[np.float32_t, ndim=2] vec_origins, np.ndarray[np.float32_t, ndim=2] vec_directions, - dists=None,query='INTERSECT',output=None): + dists=None,query='INTERSECT',output=None, + RayCollisionCallback callback_handler=None): if self.is_committed == 0: rtcCommit(self.scene_i) self.is_committed = 1 + if callback_handler is None: + callback_handler = RayCollisionNull() + cdef int nv = vec_origins.shape[0] cdef int vo_i, vd_i, vd_step cdef np.ndarray[np.int32_t, ndim=1] intersect_ids @@ -77,6 +83,7 @@ cdef class EmbreeScene: intersect_ids = np.empty(nv, dtype="int32") cdef rtcr.RTCRay ray + cdef int do_continue vd_i = 0 vd_step = 1 # If vec_directions is 1 long, we won't be updating it. @@ -96,7 +103,10 @@ cdef class EmbreeScene: vd_i += vd_step if query_type == intersect or query_type == distance: - rtcIntersect(self.scene_i, ray) + do_continue = _CALLBACK_CONTINUE + while do_continue == _CALLBACK_CONTINUE: + rtcIntersect(self.scene_i, ray) + do_continue = callback_handler.callback(ray) if not output: if query_type == intersect: intersect_ids[i] = ray.primID