Skip to content

Commit cc2590c

Browse files
authored
Merge pull request #439 from trundev/drop-generated_jit
Replace the deprecated @generated_jit by @njit
2 parents 9e47ec2 + fc8657b commit cc2590c

11 files changed

+54
-116
lines changed

.github/workflows/python-package.yml

+18-18
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@ jobs:
1414
strategy:
1515
fail-fast: false
1616
matrix:
17-
python-version: [3.8]
17+
python-version: [3.11]
1818
name: "lint | Python ${{ matrix.python-version }}"
1919
steps:
20-
- uses: actions/checkout@v2
20+
- uses: actions/checkout@v4
2121
- name: Set up Python ${{ matrix.python-version }}
22-
uses: actions/setup-python@v2
22+
uses: actions/setup-python@v5
2323
with:
2424
python-version: ${{ matrix.python-version }}
2525
- name: Install dependencies
@@ -45,37 +45,37 @@ jobs:
4545
matrix:
4646
include:
4747
# fastest jobs first
48-
- python-version: 3.8
48+
- python-version: 3.11
4949
name: without JIT
5050
disable_jit: 1
51-
- python-version: 3.8
51+
- python-version: 3.11
5252
name: doctests
5353
mode: doctests
5454
# really slow job next, so it runs in parallel with the others
55-
- python-version: 3.8
55+
- python-version: 3.11
5656
name: slow tests
5757
mode: very_slow
58-
- python-version: 3.5
59-
name: default
6058
- python-version: 3.8
6159
name: default
62-
- python-version: 3.9
60+
- python-version: 3.11
6361
name: default
64-
- python-version: 3.8
62+
- python-version: 3.12
63+
name: default
64+
- python-version: 3.11
6565
name: conda
6666
conda: true
67-
- python-version: 3.8
67+
- python-version: 3.11
6868
name: benchmarks
6969
mode: bench
7070

7171
name: "build | ${{ matrix.name }} | Python ${{matrix.python-version}}"
7272
steps:
73-
- uses: actions/checkout@v2
73+
- uses: actions/checkout@v4
7474

7575
# python / pip
7676
- name: Set up Python ${{ matrix.python-version }}
7777
if: "!matrix.conda"
78-
uses: actions/setup-python@v2
78+
uses: actions/setup-python@v5
7979
with:
8080
python-version: ${{ matrix.python-version }}
8181
- name: Install dependencies
@@ -89,7 +89,7 @@ jobs:
8989
# conda
9090
- name: Set up Python ${{ matrix.python-version }} (conda)
9191
if: matrix.conda
92-
uses: conda-incubator/setup-miniconda@v2
92+
uses: conda-incubator/setup-miniconda@v3
9393
with:
9494
auto-update-conda: true
9595
python-version: ${{ matrix.python-version }}
@@ -146,24 +146,24 @@ jobs:
146146
# if: ${{ always() }}
147147
# with:
148148
# report_paths: 'junit/test-results.xml'
149-
- uses: codecov/codecov-action@v1
149+
- uses: codecov/codecov-action@v5
150150

151151
deploy:
152152
needs: test
153153
runs-on: ubuntu-latest
154154
name: deploy
155155
steps:
156-
- uses: actions/checkout@v2
156+
- uses: actions/checkout@v4
157157
- name: Set up Python ${{ matrix.python-version }}
158-
uses: actions/setup-python@v2
158+
uses: actions/setup-python@v5
159159
with:
160160
python-version: ${{ matrix.python-version }}
161161
- name: "Install"
162162
run: |
163163
python -m pip install --upgrade pip;
164164
python -m pip install build
165165
python -m build --sdist --wheel --outdir dist/
166-
- uses: actions/upload-artifact@v2
166+
- uses: actions/upload-artifact@v4
167167
with:
168168
name: dist
169169
path: dist

clifford/__init__.py

+16-40
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,7 @@
8181

8282
# Major library imports.
8383
import numpy as np
84-
import numba as _numba # to avoid clashing with clifford.numba
8584
import sparse
86-
try:
87-
from numba.np import numpy_support as _numpy_support
88-
except ImportError:
89-
import numba.numpy_support as _numpy_support
9085

9186

9287
from clifford.io import write_ga_file, read_ga_file # noqa: F401
@@ -152,12 +147,6 @@ def get_mult_function(mt: sparse.COO, gradeList,
152147
return _get_mult_function_runtime_sparse(mt)
153148

154149

155-
def _get_mult_function_result_type(a: _numba.types.Type, b: _numba.types.Type, mt: np.dtype):
156-
a_dt = _numpy_support.as_dtype(getattr(a, 'dtype', a))
157-
b_dt = _numpy_support.as_dtype(getattr(b, 'dtype', b))
158-
return np.result_type(a_dt, mt, b_dt)
159-
160-
161150
def _get_mult_function(mt: sparse.COO):
162151
"""
163152
Get a function similar to `` lambda a, b: np.einsum('i,ijk,k->j', a, mt, b)``
@@ -172,19 +161,14 @@ def _get_mult_function(mt: sparse.COO):
172161
k_list, l_list, m_list = mt.coords
173162
mult_table_vals = mt.data
174163

175-
@_numba_utils.generated_jit(nopython=True)
164+
@_numba_utils.njit
176165
def mv_mult(value, other_value):
177-
# this casting will be done at jit-time
178-
ret_dtype = _get_mult_function_result_type(value, other_value, mult_table_vals.dtype)
179-
mult_table_vals_t = mult_table_vals.astype(ret_dtype)
180-
181-
def mult_inner(value, other_value):
182-
output = np.zeros(dims, dtype=ret_dtype)
183-
for k, l, m, val in zip(k_list, l_list, m_list, mult_table_vals_t):
184-
output[l] += value[k] * val * other_value[m]
185-
return output
186-
187-
return mult_inner
166+
res = value[k_list] * mult_table_vals * other_value[m_list]
167+
output = np.zeros(dims, dtype=res.dtype)
168+
# Can not use "np.add.at(output, l_list, res)", as ufunc.at is not supported by numba
169+
for l, val in zip(l_list, res):
170+
output[l] += val
171+
return output
188172

189173
return mv_mult
190174

@@ -203,24 +187,16 @@ def _get_mult_function_runtime_sparse(mt: sparse.COO):
203187
k_list, l_list, m_list = mt.coords
204188
mult_table_vals = mt.data
205189

206-
@_numba_utils.generated_jit(nopython=True)
190+
@_numba_utils.njit
207191
def mv_mult(value, other_value):
208-
# this casting will be done at jit-time
209-
ret_dtype = _get_mult_function_result_type(value, other_value, mult_table_vals.dtype)
210-
mult_table_vals_t = mult_table_vals.astype(ret_dtype)
211-
212-
def mult_inner(value, other_value):
213-
output = np.zeros(dims, dtype=ret_dtype)
214-
for ind, k in enumerate(k_list):
215-
v_val = value[k]
216-
if v_val != 0.0:
217-
m = m_list[ind]
218-
ov_val = other_value[m]
219-
if ov_val != 0.0:
220-
l = l_list[ind]
221-
output[l] += v_val * mult_table_vals_t[ind] * ov_val
222-
return output
223-
return mult_inner
192+
# Use mask where both operands are non-zero, to avoid zero-multiplications
193+
nz_mask = (value != 0.0)[k_list] & (other_value != 0.0)[m_list]
194+
res = value[k_list[nz_mask]] * mult_table_vals[nz_mask] * other_value[m_list[nz_mask]]
195+
output = np.zeros(dims, dtype=res.dtype)
196+
# Can not use "np.add.at(output, l_list[nz_mask], res)", as ufunc.at is not supported by numba
197+
for l, val in zip(l_list[nz_mask], res):
198+
output[l] += val
199+
return output
224200

225201
return mv_mult
226202

clifford/_conformal_layout.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def up(self, x: MultiVector) -> MultiVector:
6666
new_val = np.zeros(self.gaDims)
6767
new_val[:len(old_val)] = old_val
6868
x = self.MultiVector(value=new_val)
69-
except(AttributeError):
69+
except AttributeError:
7070
# if x is a scalar it does not have layout but following
7171
# will still work
7272
pass

clifford/_layout.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,7 @@ def vee(aval, bval):
479479
def __repr__(self):
480480
return "{}({!r}, ids={!r}, order={!r}, names={!r})".format(
481481
type(self).__name__,
482-
list(self.sig), self._basis_vector_ids, self._basis_blade_order, self.names
482+
self.sig.tolist(), self._basis_vector_ids, self._basis_blade_order, self.names
483483
)
484484

485485
def _repr_pretty_(self, p, cycle):
@@ -489,7 +489,7 @@ def _repr_pretty_(self, p, cycle):
489489
prefix = '{}('.format(type(self).__name__)
490490

491491
with p.group(len(prefix), prefix, ')'):
492-
p.text('{},'.format(list(self.sig)))
492+
p.text('{},'.format(self.sig.tolist()))
493493
p.breakable()
494494
p.text('ids=')
495495
p.pretty(self._basis_vector_ids)

clifford/_multivector.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,7 @@ def _repr_pretty_(self, p, cycle):
579579
p.pretty(self.layout)
580580
p.text(",")
581581
p.breakable()
582-
p.text(repr(list(self.value)))
582+
p.text(repr(self.value.tolist()))
583583
if self.value.dtype != np.float64:
584584
p.text(",")
585585
p.breakable()

clifford/_numba_utils.py

-28
Original file line numberDiff line numberDiff line change
@@ -59,39 +59,11 @@ def __repr__(self):
5959
return "_pickleable_function({!r})".format(self.__func)
6060

6161

62-
class _fake_generated_jit:
63-
def __init__(self, f):
64-
self.__cache = {}
65-
self.__func = pickleable_function(f)
66-
functools.update_wrapper(self, self.__func)
67-
68-
def __getnewargs_ex__(self):
69-
return (self.__func,), {}
70-
71-
def __getstate__(self):
72-
return {}
73-
74-
def __call__(self, *args):
75-
arg_type = tuple(numba.typeof(arg) for arg in args)
76-
try:
77-
func = self.__cache[arg_type]
78-
except KeyError:
79-
func = self.__cache[arg_type] = self.__func(*arg_type)
80-
return func(*args)
81-
82-
8362
if not DISABLE_JIT:
8463
njit = numba.njit
85-
generated_jit = numba.generated_jit
8664
else:
8765
def njit(f=None, **kwargs):
8866
if f is None:
8967
return pickleable_function
9068
else:
9169
return pickleable_function(f)
92-
93-
def generated_jit(f=None, **kwargs):
94-
if f is None:
95-
return _fake_generated_jit
96-
else:
97-
return _fake_generated_jit(f)

clifford/cga.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ def __init__(self, cga, *args) -> None:
399399
arg = float(arg)
400400

401401
if arg < 0:
402-
raise(ValueError('dilation should be positive'))
402+
raise ValueError('dilation should be positive')
403403

404404
mv = e**((-log(arg)/2.)*(self.cga.E0))
405405

clifford/test/test_clifford.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def test_2d_mv_array(self, g3, rng): # noqa: F811
220220
# check properties of the array are preserved (no need to check both a and b)
221221
np.testing.assert_array_equal(mv_array_a.value, value_array_a)
222222
assert mv_array_a.value.dtype == value_array_a.dtype
223-
assert type(mv_array_a.value) == type(value_array_a)
223+
assert type(mv_array_a.value) is type(value_array_a)
224224

225225
# Check addition
226226
mv_array_sum = mv_array_a + mv_array_b
@@ -806,9 +806,9 @@ def check_inv(self, A):
806806
for m, a in enumerate(A):
807807
for n, b in enumerate(A.inv):
808808
if m == n:
809-
assert(a | b == 1)
809+
assert (a | b == 1)
810810
else:
811-
assert(a | b == 0)
811+
assert (a | b == 0)
812812

813813
@pytest.mark.parametrize(('p', 'q'), [
814814
(2, 0), (3, 0), (4, 0)

clifford/test/test_function_cache.py

+5-9
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,15 @@
11
import numpy as np
2-
from clifford._numba_utils import generated_jit
2+
from clifford._numba_utils import njit
33
import pytest
44

55

6-
@generated_jit(cache=True)
7-
def foo(x):
8-
from clifford.g3 import e3
9-
10-
def impl(x):
11-
return (x * e3).value
12-
return impl
6+
@njit(cache=True)
7+
def foo(x, y):
8+
return (x * y).value
139

1410

1511
# Make the test fail on a failed cache warning
1612
@pytest.mark.filterwarnings("error")
1713
def test_function_cache():
1814
from clifford.g3 import e3
19-
np.testing.assert_array_equal((1.0*e3).value, foo(1.0))
15+
np.testing.assert_array_equal((1.0*e3).value, foo(1.0, e3))

clifford/test/test_g3c_tools.py

+6-11
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,15 @@
1-
import random
21
from functools import reduce
32
import time
43
import functools
54

65

76
import numpy as np
87
import numpy.testing as npt
9-
from numpy import exp
108
import pytest
11-
import numba
129

13-
from clifford import Cl
1410
from clifford.g3c import *
1511
from clifford.tools.g3c import *
16-
from clifford.tools.g3c.rotor_parameterisation import ga_log, ga_exp, general_logarithm, \
17-
interpolate_rotors
12+
from clifford.tools.g3c.rotor_parameterisation import ga_log, general_logarithm
1813
from clifford.tools.g3c.rotor_estimation import *
1914
from clifford.tools.g3c.object_clustering import *
2015
from clifford.tools.g3c.scene_simplification import *
@@ -175,7 +170,7 @@ def test_general_logarithm_TRS(self, rng): # noqa: F811
175170
V = (T * R * S).normal()
176171
biv = general_logarithm(V)
177172
V_rebuilt = biv.exp().normal()
178-
biv2 = general_logarithm(V)
173+
_ = general_logarithm(V)
179174

180175
C1 = random_point_pair(rng=rng)
181176
C2 = (V * C1 * ~V).normal()
@@ -381,8 +376,8 @@ def test_closest_furthest_circle_points(self, rng): # noqa: F811
381376
for _ in range(100):
382377
C1 = random_circle(rng=rng)
383378
C2 = random_circle(rng=rng)
384-
pclose = iterative_closest_points_on_circles(C1, C2)
385-
pfar = iterative_furthest_points_on_circles(C1, C2)
379+
_ = iterative_closest_points_on_circles(C1, C2)
380+
_ = iterative_furthest_points_on_circles(C1, C2)
386381

387382
def test_closest_points_circle_line(self, rng): # noqa: F811
388383
"""
@@ -740,7 +735,7 @@ def test_rotor_between_non_overlapping_spheres(self, rng): # noqa: F811
740735
rad = get_radius_from_sphere(C1)
741736
t_r = generate_translation_rotor(2.5*rad*e1)
742737
C2 = (t_r * C1 * ~t_r)(4).normal()
743-
rad2 = get_radius_from_sphere(C2)
738+
_ = get_radius_from_sphere(C2)
744739
R = rotor_between_objects(C1, C2)
745740
C3 = (R * C1 * ~R).normal()
746741
if sum(np.abs((C2 + C3).value)) < 0.0001:
@@ -1021,7 +1016,7 @@ def test_assign_objects_to_objects(self, obj_gen, rng): # noqa: F811
10211016

10221017
n_repeats = 5
10231018
for i in range(n_repeats):
1024-
r = random_rotation_translation_rotor(0.001, np.pi / 32, rng=rng)
1019+
_ = random_rotation_translation_rotor(0.001, np.pi / 32, rng=rng)
10251020
object_set_a = [obj_gen(rng=rng) for i in range(20)]
10261021
object_set_b = [l for l in object_set_a]
10271022
label_a, costs_a = assign_measurements_to_objects_matrix(object_set_a, object_set_b)

clifford/tools/g3c/rotor_estimation.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import random
2-
from scipy import e
32
import numpy as np
43
import multiprocessing
54

@@ -379,7 +378,7 @@ def cartans_lines(obj_list_a, obj_list_b):
379378
""" Performs the extended cartans algorithm as suggested by Alex Arsenovic """
380379
V_found, rs = cartan(A=obj_list_a, B=obj_list_b)
381380
theta = ((V_found*~V_found)*e1234)(0)
382-
V_found = e**(-theta/2*e123inf)*V_found
381+
V_found = np.e**(-theta/2*e123inf)*V_found
383382
return V_found
384383

385384

0 commit comments

Comments
 (0)