Skip to content

Commit

Permalink
Merge pull request numba#7619 from gmarkall/issue-7607
Browse files: browse the repository at this point in the history
CUDA: Fix linking with PTX when compiling lazily
  • Loading branch information
sklam authored Jan 21, 2022
2 parents ef1ba4c + 04ac137 commit 6043825
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
3 changes: 2 additions & 1 deletion numba/cuda/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ def autojitwrapper(func):
fastmath=fastmath)
else:
def autojitwrapper(func):
return jit(func, device=device, debug=debug, opt=opt, **kws)
return jit(func, device=device, debug=debug, opt=opt,
link=link, **kws)

return autojitwrapper
# func_or_sig is a function
Expand Down
21 changes: 13 additions & 8 deletions numba/cuda/tests/cudadrv/test_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,18 @@ def test_linker_basic(self):
linker = Linker.new()
del linker

@require_context
def test_linking(self):
def _test_linking(self, eager):
global bar # must be a global; otherwise it is recognized as a freevar
bar = cuda.declare_device('bar', 'int32(int32)')

link = os.path.join(os.path.dirname(__file__), 'data', 'jitlink.ptx')

@cuda.jit('void(int32[:], int32[:])', link=[link])
if eager:
args = ['void(int32[:], int32[:])']
else:
args = []

@cuda.jit(*args, link=[link])
def foo(x, y):
i = cuda.grid(1)
x[i] += bar(y[i])
Expand All @@ -85,15 +89,19 @@ def foo(x, y):

self.assertTrue(A[0] == 123 + 2 * 321)

@require_context
def test_linking_lazy_compile(self):
self._test_linking(eager=False)

def test_linking_eager_compile(self):
self._test_linking(eager=True)

def test_try_to_link_nonexistent(self):
with self.assertRaises(LinkerError) as e:
@cuda.jit('void(int32[::1])', link=['nonexistent.a'])
def f(x):
x[0] = 0
self.assertIn('nonexistent.a not found', e.exception.args)

@require_context
def test_set_registers_no_max(self):
"""Ensure that the jitted kernel used in the test_set_registers_* tests
uses more than 57 registers - this ensures that test_set_registers_*
Expand All @@ -103,19 +111,16 @@ def test_set_registers_no_max(self):
compiled = compiled.specialize(np.empty(32), *range(6))
self.assertGreater(compiled.get_regs_per_thread(), 57)

@require_context
def test_set_registers_57(self):
compiled = cuda.jit(max_registers=57)(func_with_lots_of_registers)
compiled = compiled.specialize(np.empty(32), *range(6))
self.assertLessEqual(compiled.get_regs_per_thread(), 57)

@require_context
def test_set_registers_38(self):
compiled = cuda.jit(max_registers=38)(func_with_lots_of_registers)
compiled = compiled.specialize(np.empty(32), *range(6))
self.assertLessEqual(compiled.get_regs_per_thread(), 38)

@require_context
def test_set_registers_eager(self):
sig = void(float64[::1], int64, int64, int64, int64, int64, int64)
compiled = cuda.jit(sig, max_registers=38)(func_with_lots_of_registers)
Expand Down

0 comments on commit 6043825

Please sign in to comment.