81
81
82
82
# Major library imports.
83
83
import numpy as np
84
- import numba as _numba # to avoid clashing with clifford.numba
85
84
import sparse
86
- try :
87
- from numba .np import numpy_support as _numpy_support
88
- except ImportError :
89
- import numba .numpy_support as _numpy_support
90
85
91
86
92
87
from clifford .io import write_ga_file , read_ga_file # noqa: F401
@@ -152,12 +147,6 @@ def get_mult_function(mt: sparse.COO, gradeList,
152
147
return _get_mult_function_runtime_sparse (mt )
153
148
154
149
155
- def _get_mult_function_result_type (a : _numba .types .Type , b : _numba .types .Type , mt : np .dtype ):
156
- a_dt = _numpy_support .as_dtype (getattr (a , 'dtype' , a ))
157
- b_dt = _numpy_support .as_dtype (getattr (b , 'dtype' , b ))
158
- return np .result_type (a_dt , mt , b_dt )
159
-
160
-
161
150
def _get_mult_function (mt : sparse .COO ):
162
151
"""
163
152
Get a function similar to `` lambda a, b: np.einsum('i,ijk,k->j', a, mt, b)``
@@ -172,19 +161,14 @@ def _get_mult_function(mt: sparse.COO):
172
161
k_list , l_list , m_list = mt .coords
173
162
mult_table_vals = mt .data
174
163
175
- @_numba_utils .generated_jit ( nopython = True )
164
+ @_numba_utils .njit
176
165
def mv_mult (value , other_value ):
177
- # this casting will be done at jit-time
178
- ret_dtype = _get_mult_function_result_type (value , other_value , mult_table_vals .dtype )
179
- mult_table_vals_t = mult_table_vals .astype (ret_dtype )
180
-
181
- def mult_inner (value , other_value ):
182
- output = np .zeros (dims , dtype = ret_dtype )
183
- for k , l , m , val in zip (k_list , l_list , m_list , mult_table_vals_t ):
184
- output [l ] += value [k ] * val * other_value [m ]
185
- return output
186
-
187
- return mult_inner
166
+ res = value [k_list ] * mult_table_vals * other_value [m_list ]
167
+ output = np .zeros (dims , dtype = res .dtype )
168
+ # Cannot use "np.add.at(output, l_list, res)", as ufunc.at is not supported by numba
169
+ for l , val in zip (l_list , res ):
170
+ output [l ] += val
171
+ return output
188
172
189
173
return mv_mult
190
174
@@ -203,24 +187,16 @@ def _get_mult_function_runtime_sparse(mt: sparse.COO):
203
187
k_list , l_list , m_list = mt .coords
204
188
mult_table_vals = mt .data
205
189
206
- @_numba_utils .generated_jit ( nopython = True )
190
+ @_numba_utils .njit
207
191
def mv_mult (value , other_value ):
208
- # this casting will be done at jit-time
209
- ret_dtype = _get_mult_function_result_type (value , other_value , mult_table_vals .dtype )
210
- mult_table_vals_t = mult_table_vals .astype (ret_dtype )
211
-
212
- def mult_inner (value , other_value ):
213
- output = np .zeros (dims , dtype = ret_dtype )
214
- for ind , k in enumerate (k_list ):
215
- v_val = value [k ]
216
- if v_val != 0.0 :
217
- m = m_list [ind ]
218
- ov_val = other_value [m ]
219
- if ov_val != 0.0 :
220
- l = l_list [ind ]
221
- output [l ] += v_val * mult_table_vals_t [ind ] * ov_val
222
- return output
223
- return mult_inner
192
+ # Use mask where both operands are non-zero, to avoid zero-multiplications
193
+ nz_mask = (value != 0.0 )[k_list ] & (other_value != 0.0 )[m_list ]
194
+ res = value [k_list [nz_mask ]] * mult_table_vals [nz_mask ] * other_value [m_list [nz_mask ]]
195
+ output = np .zeros (dims , dtype = res .dtype )
196
+ # Cannot use "np.add.at(output, l_list[nz_mask], res)", as ufunc.at is not supported by numba
197
+ for l , val in zip (l_list [nz_mask ], res ):
198
+ output [l ] += val
199
+ return output
224
200
225
201
return mv_mult
226
202
0 commit comments