Skip to content

Commit

Permalink
Replace numba.generated_jit with njit (pygae#430)
Browse files Browse the repository at this point in the history
Done by avoiding numpy.result_type (not supported by numba). Instead,
do the array multiplication at once, then use the result dtype.
This requires some numpy array-indexing tricks.
  • Loading branch information
trundev committed Nov 19, 2024
1 parent 1480e16 commit 9b301f1
Showing 1 changed file with 16 additions and 29 deletions.
45 changes: 16 additions & 29 deletions clifford/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,19 +172,14 @@ def _get_mult_function(mt: sparse.COO):
k_list, l_list, m_list = mt.coords
mult_table_vals = mt.data

@_numba_utils.generated_jit(nopython=True)
@_numba_utils.njit
def mv_mult(value, other_value):
# this casting will be done at jit-time
ret_dtype = _get_mult_function_result_type(value, other_value, mult_table_vals.dtype)
mult_table_vals_t = mult_table_vals.astype(ret_dtype)

def mult_inner(value, other_value):
output = np.zeros(dims, dtype=ret_dtype)
for k, l, m, val in zip(k_list, l_list, m_list, mult_table_vals_t):
output[l] += value[k] * val * other_value[m]
return output

return mult_inner
res = value[k_list] * mult_table_vals * other_value[m_list]
output = np.zeros(dims, dtype=res.dtype)
# Can not use "np.add.at(output, l_list, res)", as ufunc.at is not supported by numba
for l, val in zip(l_list, res):
output[l] += val
return output

return mv_mult

Expand All @@ -203,24 +198,16 @@ def _get_mult_function_runtime_sparse(mt: sparse.COO):
k_list, l_list, m_list = mt.coords
mult_table_vals = mt.data

@_numba_utils.generated_jit(nopython=True)
@_numba_utils.njit
def mv_mult(value, other_value):
# this casting will be done at jit-time
ret_dtype = _get_mult_function_result_type(value, other_value, mult_table_vals.dtype)
mult_table_vals_t = mult_table_vals.astype(ret_dtype)

def mult_inner(value, other_value):
output = np.zeros(dims, dtype=ret_dtype)
for ind, k in enumerate(k_list):
v_val = value[k]
if v_val != 0.0:
m = m_list[ind]
ov_val = other_value[m]
if ov_val != 0.0:
l = l_list[ind]
output[l] += v_val * mult_table_vals_t[ind] * ov_val
return output
return mult_inner
# Use mask where both operands are non-zero, to avoid zero-multiplications
nz_mask = (value != 0.0)[k_list] & (other_value != 0.0)[m_list]
res = value[k_list[nz_mask]] * mult_table_vals[nz_mask] * other_value[m_list[nz_mask]]
output = np.zeros(dims, dtype=res.dtype)
# Can not use "np.add.at(output, l_list[nz_mask], res)", as ufunc.at is not supported by numba
for l, val in zip(l_list[nz_mask], res):
output[l] += val
return output

return mv_mult

Expand Down

0 comments on commit 9b301f1

Please sign in to comment.