Skip to content

Commit

Permalink
Merge pull request numba#8803 from KyanCheung/main
Browse files Browse the repository at this point in the history
Attempted fix to numba#8789 by changing `compile_ptx` to accept a signature instead of argument tuple
  • Loading branch information
sklam authored Mar 15, 2023
2 parents 26bc501 + 5a8bc2f commit 1311388
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 19 deletions.
21 changes: 12 additions & 9 deletions numba/cuda/compiler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from numba.core.typing.templates import ConcreteTemplate
from numba.core import types, typing, funcdesc, config, compiler
from numba.core import types, typing, funcdesc, config, compiler, sigutils
from numba.core.compiler import (sanitize_compile_result_entries, CompilerBase,
DefaultPassBuilder, Flags, Option,
CompileResult)
Expand Down Expand Up @@ -243,12 +243,13 @@ def compile_cuda(pyfunc, return_type, args, debug=False, lineinfo=False,


@global_compiler_lock
def compile_ptx(pyfunc, args, debug=False, lineinfo=False, device=False,
def compile_ptx(pyfunc, sig, debug=False, lineinfo=False, device=False,
fastmath=False, cc=None, opt=True):
"""Compile a Python function to PTX for a given set of argument types.
:param pyfunc: The Python function to compile.
:param args: A tuple of argument types to compile for.
:param sig: The signature representing the function's input and output
types.
:param debug: Whether to include debug info in the generated PTX.
:type debug: bool
:param lineinfo: Whether to include a line mapping from the generated PTX
Expand All @@ -263,8 +264,8 @@ def compile_ptx(pyfunc, args, debug=False, lineinfo=False, device=False,
:param fastmath: Whether to enable fast math flags (ftz=1, prec_sqrt=0,
prec_div=0, and fma=1)
:type fastmath: bool
:param cc: Compute capability to compile for, as a tuple ``(MAJOR, MINOR)``.
Defaults to ``(5, 3)``.
:param cc: Compute capability to compile for, as a tuple
``(MAJOR, MINOR)``. Defaults to ``(5, 3)``.
:type cc: tuple
:param opt: Enable optimizations. Defaults to ``True``.
:type opt: bool
Expand All @@ -282,9 +283,11 @@ def compile_ptx(pyfunc, args, debug=False, lineinfo=False, device=False,
'opt': 3 if opt else 0
}

args, return_type = sigutils.normalize_signature(sig)

cc = cc or config.CUDA_DEFAULT_PTX_CC
cres = compile_cuda(pyfunc, None, args, debug=debug, lineinfo=lineinfo,
fastmath=fastmath,
cres = compile_cuda(pyfunc, return_type, args, debug=debug,
lineinfo=lineinfo, fastmath=fastmath,
nvvm_options=nvvm_options, cc=cc)
resty = cres.signature.return_type

Expand All @@ -307,13 +310,13 @@ def compile_ptx(pyfunc, args, debug=False, lineinfo=False, device=False,
return ptx, resty


def compile_ptx_for_current_device(pyfunc, sig, debug=False, lineinfo=False,
                                   device=False, fastmath=False, opt=True):
    """Compile a Python function to PTX for a given signature for the current
    device's compute capability.

    This calls :func:`compile_ptx` with an appropriate ``cc`` value for the
    current device.

    :param pyfunc: The Python function to compile.
    :param sig: The signature representing the function's input and output
                types.
    :param debug: Whether to include debug info in the generated PTX.
    :type debug: bool
    :param lineinfo: Whether to include a line mapping from the generated PTX
                     to the source code.
    :type lineinfo: bool
    :param device: Whether to compile a device function.
    :type device: bool
    :param fastmath: Whether to enable fast math flags.
    :type fastmath: bool
    :param opt: Enable optimizations. Defaults to ``True``.
    :type opt: bool
    :return: (ptx, resty): The PTX code and inferred return type
    :rtype: tuple
    """
    cc = get_current_device().compute_capability
    # Forward ``opt`` rather than hard-coding ``opt=True``, so callers that
    # request an unoptimized build (e.g. for debugging) actually get one.
    return compile_ptx(pyfunc, sig, debug=debug, lineinfo=lineinfo,
                       device=device, fastmath=fastmath, cc=cc, opt=opt)


Expand Down
10 changes: 5 additions & 5 deletions numba/cuda/tests/cudapy/test_casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def test_float16_to_int_ptx(self):
sizes = (8, 16, 32, 64)

for pyfunc, size in zip(pyfuncs, sizes):
ptx, _ = compile_ptx(pyfunc, [f2], device=True)
ptx, _ = compile_ptx(pyfunc, (f2,), device=True)
self.assertIn(f"cvt.rni.s{size}.f16", ptx)

@skip_unless_cc_53
Expand All @@ -156,7 +156,7 @@ def test_float16_to_uint_ptx(self):
sizes = (8, 16, 32, 64)

for pyfunc, size in zip(pyfuncs, sizes):
ptx, _ = compile_ptx(pyfunc, [f2], device=True)
ptx, _ = compile_ptx(pyfunc, (f2,), device=True)
self.assertIn(f"cvt.rni.u{size}.f16", ptx)

@skip_unless_cc_53
Expand Down Expand Up @@ -187,7 +187,7 @@ def test_int_to_float16_ptx(self):
sizes = (8, 16, 32, 64)

for ty, size in zip(fromtys, sizes):
ptx, _ = compile_ptx(to_float16, [ty], device=True)
ptx, _ = compile_ptx(to_float16, (ty,), device=True)
self.assertIn(f"cvt.rn.f16.s{size}", ptx)

@skip_on_cudasim('Compilation unsupported in the simulator')
Expand All @@ -196,7 +196,7 @@ def test_uint_to_float16_ptx(self):
sizes = (8, 16, 32, 64)

for ty, size in zip(fromtys, sizes):
ptx, _ = compile_ptx(to_float16, [ty], device=True)
ptx, _ = compile_ptx(to_float16, (ty,), device=True)
self.assertIn(f"cvt.rn.f16.u{size}", ptx)

@skip_unless_cc_53
Expand All @@ -222,7 +222,7 @@ def test_float16_to_float_ptx(self):
postfixes = ("f32", "f64")

for pyfunc, postfix in zip(pyfuncs, postfixes):
ptx, _ = compile_ptx(pyfunc, [f2], device=True)
ptx, _ = compile_ptx(pyfunc, (f2,), device=True)
self.assertIn(f"cvt.{postfix}.f16", ptx)

@skip_unless_cc_53
Expand Down
23 changes: 18 additions & 5 deletions numba/cuda/tests/cudapy/test_compiler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from math import sqrt
from numba import cuda, float32, uint32, void
from numba import cuda, float32, int16, int32, uint32, void
from numba.cuda import compile_ptx, compile_ptx_for_current_device
from numba.cuda.cudadrv.nvvm import NVVM

Expand Down Expand Up @@ -43,6 +43,19 @@ def add(x, y):
# Inferred return type as expected?
self.assertEqual(resty, float32)

# Check that function's output matches signature
sig_int32 = int32(int32, int32)
ptx, resty = compile_ptx(add, sig_int32, device=True)
self.assertEqual(resty, int32)

sig_int16 = int16(int16, int16)
ptx, resty = compile_ptx(add, sig_int16, device=True)
self.assertEqual(resty, int16)
# Using string as signature
sig_string = "uint32(uint32, uint32)"
ptx, resty = compile_ptx(add, sig_string, device=True)
self.assertEqual(resty, uint32)

def test_fastmath(self):
def f(x, y, z, d):
return sqrt((x * y + z) / d)
Expand Down Expand Up @@ -85,15 +98,15 @@ def test_device_function_with_debug(self):
def f():
pass

ptx, resty = compile_ptx(f, [], device=True, debug=True)
ptx, resty = compile_ptx(f, (), device=True, debug=True)
self.check_debug_info(ptx)

def test_kernel_with_debug(self):
# Inspired by (but not originally affected by) Issue #6719
def f():
pass

ptx, resty = compile_ptx(f, [], debug=True)
ptx, resty = compile_ptx(f, (), debug=True)
self.check_debug_info(ptx)

def check_line_info(self, ptx):
Expand All @@ -109,7 +122,7 @@ def test_device_function_with_line_info(self):
def f():
pass

ptx, resty = compile_ptx(f, [], device=True, lineinfo=True)
ptx, resty = compile_ptx(f, (), device=True, lineinfo=True)
self.check_line_info(ptx)

def test_kernel_with_line_info(self):
Expand All @@ -119,7 +132,7 @@ def test_kernel_with_line_info(self):
def f():
pass

ptx, resty = compile_ptx(f, [], lineinfo=True)
ptx, resty = compile_ptx(f, (), lineinfo=True)
self.check_line_info(ptx)

def test_non_void_return_type(self):
Expand Down
1 change: 1 addition & 0 deletions numba/cuda/tests/cudapy/test_libdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def _test_call_functions(self):
else:
pyargs.insert(0, sig.return_type[::1])

pyargs = tuple(pyargs)
ptx, resty = compile_ptx(pyfunc, pyargs)

# If the function body was discarded by optimization (therefore making
Expand Down

0 comments on commit 1311388

Please sign in to comment.