Skip to content

Commit

Permalink
Merge pull request numba#7815 from gmarkall/cuda-dispatcher-base-2022…
Browse files Browse the repository at this point in the history
…0204

CUDA Dispatcher refactor 2: inherit from `dispatcher.Dispatcher`
  • Loading branch information
sklam authored Feb 18, 2022
2 parents 775b988 + 17f9585 commit 27426b5
Show file tree
Hide file tree
Showing 9 changed files with 148 additions and 185 deletions.
2 changes: 1 addition & 1 deletion docs/source/cuda-reference/kernel.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ This is similar to launch configuration in CUDA C/C++:
Dispatcher objects also provide several utility methods for inspection and
creating a specialized instance:

.. autoclass:: numba.cuda.dispatcher.Dispatcher
.. autoclass:: numba.cuda.dispatcher.CUDADispatcher
:members: inspect_asm, inspect_llvm, inspect_sass, inspect_types,
get_regs_per_thread, specialize, specialized, extensions, forall

Expand Down
17 changes: 11 additions & 6 deletions numba/core/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,20 +269,25 @@ def dump(self, tab=''):
])


def compile_result(**kws):
keys = set(kws.keys())
def sanitize_compile_result_entries(entries):
keys = set(entries.keys())
fieldset = set(CR_FIELDS)
badnames = keys - fieldset
if badnames:
raise NameError(*badnames)
missing = fieldset - keys
for k in missing:
kws[k] = None
entries[k] = None
# Avoid keeping alive traceback variables
err = kws['typing_error']
err = entries['typing_error']
if err is not None:
kws['typing_error'] = err.with_traceback(None)
return CompileResult(**kws)
entries['typing_error'] = err.with_traceback(None)
return entries


def compile_result(**entries):
entries = sanitize_compile_result_entries(entries)
return CompileResult(**entries)


def compile_isolated(func, args, return_type=None, flags=DEFAULT_FLAGS,
Expand Down
34 changes: 31 additions & 3 deletions numba/cuda/compiler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from numba.core.typing.templates import ConcreteTemplate
from numba.core import types, typing, funcdesc, config, compiler
from numba.core.compiler import (CompilerBase, DefaultPassBuilder,
compile_result, Flags, Option)
from numba.core.compiler import (sanitize_compile_result_entries, CompilerBase,
DefaultPassBuilder, Flags, Option,
CompileResult)
from numba.core.compiler_lock import global_compiler_lock
from numba.core.compiler_machinery import (LoweringPass, AnalysisPass,
PassManager, register_pass)
Expand Down Expand Up @@ -29,6 +30,33 @@ class CUDAFlags(Flags):
)


# The CUDACompileResult (CCR) has a specially-defined entry point equal to its
# id. This is because the entry point is used as a key into a dict of
# overloads by the base dispatcher. The id of the CCR is the only small and
# unique property of a CompileResult in the CUDA target (cf. the CPU target,
# which uses its entry_point, which is a pointer value).
#
# This does feel a little hackish, and there are two ways in which this could
# be improved:
#
# 1. We could change the core of Numba so that each CompileResult has its own
# unique ID that can be used as a key - e.g. a count, similar to the way in
# which types have unique counts.
# 2. At some future time when kernel launch uses a compiled function, the entry
# point will no longer need to be a synthetic value, but will instead be a
# pointer to the compiled function as in the CPU target.

class CUDACompileResult(CompileResult):
@property
def entry_point(self):
return id(self)


def cuda_compile_result(**entries):
entries = sanitize_compile_result_entries(entries)
return CUDACompileResult(**entries)


@register_pass(mutates_CFG=True, analysis_only=False)
class CUDABackend(LoweringPass):

Expand All @@ -44,7 +72,7 @@ def run_pass(self, state):
lowered = state['cr']
signature = typing.signature(state.return_type, *state.args)

state.cr = compile_result(
state.cr = cuda_compile_result(
typing_context=state.typingctx,
target_context=state.targetctx,
typing_error=state.status.fail_reason,
Expand Down
22 changes: 17 additions & 5 deletions numba/cuda/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from numba.core import types, config, sigutils
from numba.core.errors import DeprecationError, NumbaInvalidConfigWarning
from numba.cuda.compiler import declare_device_function
from numba.cuda.dispatcher import Dispatcher
from numba.cuda.dispatcher import CUDADispatcher
from numba.cuda.simulator.kernel import FakeCUDAKernel


Expand Down Expand Up @@ -69,6 +69,7 @@ def jit(func_or_sig=None, device=False, inline=False, link=[], debug=None,

debug = config.CUDA_DEBUGINFO_DEFAULT if debug is None else debug
fastmath = kws.get('fastmath', False)
extensions = kws.get('extensions', [])

if debug and opt:
msg = ("debug=True with opt=True (the default) "
Expand Down Expand Up @@ -97,7 +98,19 @@ def _jit(func):
targetoptions['opt'] = opt
targetoptions['fastmath'] = fastmath
targetoptions['device'] = device
return Dispatcher(func, [func_or_sig], targetoptions=targetoptions)
targetoptions['extensions'] = extensions

disp = CUDADispatcher(func, targetoptions=targetoptions)

if device:
disp.compile_device(argtypes)
else:
disp.compile(argtypes)

disp._specialized = True
disp.disable_compile()

return disp

return _jit
else:
Expand All @@ -124,9 +137,8 @@ def autojitwrapper(func):
targetoptions['link'] = link
targetoptions['fastmath'] = fastmath
targetoptions['device'] = device
sigs = None
return Dispatcher(func_or_sig, sigs,
targetoptions=targetoptions)
targetoptions['extensions'] = extensions
return CUDADispatcher(func_or_sig, targetoptions=targetoptions)


def declare_device(name, sig):
Expand Down
Loading

0 comments on commit 27426b5

Please sign in to comment.