Skip to content

Commit

Permalink
Merge pull request numba#8605 from testhound/testhound/fp16_canonical…
Browse files Browse the repository at this point in the history
…_math_functions

Support for CUDA fp16 math functions (part 1)
  • Loading branch information
sklam authored Nov 23, 2022
2 parents a5d7dee + b500268 commit 7282635
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 15 deletions.
2 changes: 1 addition & 1 deletion numba/cuda/cudaimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ def fp16_div_impl(context, builder, sig, args):
def fp16_div(x, y):
return cuda.fp16.hdiv(x, y)

return context.compile_internal(builder, fp16_div, sig, args, )
return context.compile_internal(builder, fp16_div, sig, args)


_fp16_cmp = """{{
Expand Down
34 changes: 22 additions & 12 deletions numba/cuda/cudamath.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,18 @@
@infer_global(math.asinh)
@infer_global(math.atan)
@infer_global(math.atanh)
@infer_global(math.ceil)
@infer_global(math.cos)
@infer_global(math.cosh)
@infer_global(math.degrees)
@infer_global(math.erf)
@infer_global(math.erfc)
@infer_global(math.exp)
@infer_global(math.expm1)
@infer_global(math.fabs)
@infer_global(math.floor)
@infer_global(math.gamma)
@infer_global(math.lgamma)
@infer_global(math.log)
@infer_global(math.log2)
@infer_global(math.log10)
@infer_global(math.log1p)
@infer_global(math.radians)
@infer_global(math.sin)
@infer_global(math.sinh)
@infer_global(math.sqrt)
@infer_global(math.tan)
@infer_global(math.tanh)
@infer_global(math.trunc)
@infer_global(math.tan)
class Math_unary(ConcreteTemplate):
cases = [
signature(types.float64, types.int64),
Expand All @@ -45,6 +34,27 @@ class Math_unary(ConcreteTemplate):
]


@infer_global(math.sin)
@infer_global(math.cos)
@infer_global(math.ceil)
@infer_global(math.floor)
@infer_global(math.sqrt)
@infer_global(math.log)
@infer_global(math.log2)
@infer_global(math.log10)
@infer_global(math.exp)
@infer_global(math.fabs)
@infer_global(math.trunc)
class Math_unary_with_fp16(ConcreteTemplate):
cases = [
signature(types.float64, types.int64),
signature(types.float64, types.uint64),
signature(types.float32, types.float32),
signature(types.float64, types.float64),
signature(types.float16, types.float16),
]


@infer_global(math.atan2)
class Math_atan2(ConcreteTemplate):
key = math.atan2
Expand Down
90 changes: 90 additions & 0 deletions numba/cuda/mathimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from numba.core.imputils import Registry
from numba.types import float32, float64, int64, uint64
from numba.cuda import libdevice
from numba import cuda

registry = Registry()
lower = registry.lower
Expand Down Expand Up @@ -42,6 +43,7 @@
unarys += [('atanh', 'atanhf', math.atanh)]
unarys += [('tan', 'tanf', math.tan)]
unarys += [('tanh', 'tanhf', math.tanh)]
unarys += [('trunc', 'truncf', math.trunc)]

unarys_fastmath = {}
unarys_fastmath['cosf'] = 'fast_cosf'
Expand Down Expand Up @@ -88,6 +90,94 @@ def math_isfinite_int(context, builder, sig, args):
return context.get_constant(types.boolean, 1)


@lower(math.sin, types.float16)
def fp16_sin_impl(context, builder, sig, args):
def fp16_sin(x):
return cuda.fp16.hsin(x)

return context.compile_internal(builder, fp16_sin, sig, args)


@lower(math.cos, types.float16)
def fp16_cos_impl(context, builder, sig, args):
def fp16_cos(x):
return cuda.fp16.hcos(x)

return context.compile_internal(builder, fp16_cos, sig, args)


@lower(math.log, types.float16)
def fp16_log_impl(context, builder, sig, args):
def fp16_log(x):
return cuda.fp16.hlog(x)

return context.compile_internal(builder, fp16_log, sig, args)


@lower(math.log10, types.float16)
def fp16_log10_impl(context, builder, sig, args):
def fp16_log10(x):
return cuda.fp16.hlog10(x)

return context.compile_internal(builder, fp16_log10, sig, args)


@lower(math.log2, types.float16)
def fp16_log2_impl(context, builder, sig, args):
def fp16_log2(x):
return cuda.fp16.hlog2(x)

return context.compile_internal(builder, fp16_log2, sig, args)


@lower(math.exp, types.float16)
def fp16_exp_impl(context, builder, sig, args):
def fp16_exp(x):
return cuda.fp16.hexp(x)

return context.compile_internal(builder, fp16_exp, sig, args)


@lower(math.floor, types.float16)
def fp16_floor_impl(context, builder, sig, args):
def fp16_floor(x):
return cuda.fp16.hfloor(x)

return context.compile_internal(builder, fp16_floor, sig, args)


@lower(math.ceil, types.float16)
def fp16_ceil_impl(context, builder, sig, args):
def fp16_ceil(x):
return cuda.fp16.hceil(x)

return context.compile_internal(builder, fp16_ceil, sig, args)


@lower(math.sqrt, types.float16)
def fp16_sqrt_impl(context, builder, sig, args):
def fp16_sqrt(x):
return cuda.fp16.hsqrt(x)

return context.compile_internal(builder, fp16_sqrt, sig, args)


@lower(math.fabs, types.float16)
def fp16_fabs_impl(context, builder, sig, args):
def fp16_fabs(x):
return cuda.fp16.habs(x)

return context.compile_internal(builder, fp16_fabs, sig, args)


@lower(math.trunc, types.float16)
def fp16_trunc_impl(context, builder, sig, args):
def fp16_trunc(x):
return cuda.fp16.htrunc(x)

return context.compile_internal(builder, fp16_trunc, sig, args)


def impl_boolean(key, ty, libfunc):
def lower_boolean_impl(context, builder, sig, args):
libfunc_impl = context.get_function(libfunc,
Expand Down
42 changes: 40 additions & 2 deletions numba/cuda/tests/cudapy/test_math.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import numpy as np
from numba.cuda.testing import unittest, CUDATestCase, skip_on_cudasim
from numba.cuda.testing import (skip_unless_cc_53,
skip_unless_cuda_python,
unittest,
CUDATestCase,
skip_on_cudasim)
from numba.np import numpy_support
from numba import cuda, float32, float64, int32, vectorize, void, int64
import math
Expand Down Expand Up @@ -195,6 +199,11 @@ def math_radians(A, B):
B[i] = math.radians(A[i])


def math_trunc(A, B):
i = cuda.grid(1)
B[i] = math.trunc(A[i])


def math_pow_binop(A, B, C):
i = cuda.grid(1)
C[i] = A[i] ** B[i]
Expand All @@ -206,6 +215,9 @@ def math_mod_binop(A, B, C):


class TestCudaMath(CUDATestCase):
def unary_template_float16(self, func, npfunc, start=0, stop=1):
self.unary_template(func, npfunc, np.float16, np.float16, start, stop)

def unary_template_float32(self, func, npfunc, start=0, stop=1):
self.unary_template(func, npfunc, np.float32, np.float32, start, stop)

Expand Down Expand Up @@ -233,8 +245,10 @@ def unary_template(self, func, npfunc, npdtype, nprestype, start, stop):
# the tightest under which the tests will pass.
if npdtype == np.float64:
rtol = 1e-13
else:
elif npdtype == np.float32:
rtol = 1e-6
else:
rtol = 1e-3
np.testing.assert_allclose(npfunc(A), B, rtol=rtol)

def unary_bool_special_values(self, func, npfunc, npdtype, npmtype):
Expand Down Expand Up @@ -369,6 +383,21 @@ def test_math_cos(self):
self.unary_template_int64(math_cos, np.cos)
self.unary_template_uint64(math_cos, np.cos)

@skip_unless_cc_53
@skip_unless_cuda_python('NVIDIA Binding needed for NVRTC')
def test_math_fp16(self):
self.unary_template_float16(math_sin, np.sin)
self.unary_template_float16(math_cos, np.cos)
self.unary_template_float16(math_exp, np.exp)
self.unary_template_float16(math_log, np.log, start=1)
self.unary_template_float16(math_log2, np.log2, start=1)
self.unary_template_float16(math_log10, np.log10, start=1)
self.unary_template_float16(math_fabs, np.fabs, start=-1)
self.unary_template_float16(math_sqrt, np.sqrt)
self.unary_template_float16(math_ceil, np.ceil)
self.unary_template_float16(math_floor, np.floor)
self.unary_template_float16(math_trunc, np.trunc)

#---------------------------------------------------------------------------
# test_math_sin

Expand Down Expand Up @@ -622,6 +651,15 @@ def test_math_floor(self):
self.unary_template_int64(math_floor, np.floor)
self.unary_template_uint64(math_floor, np.floor)

#---------------------------------------------------------------------------
# test_math_trunc

def test_math_trunc(self):
self.unary_template_float32(math_trunc, np.trunc)
self.unary_template_float64(math_trunc, np.trunc)
self.unary_template_int64(math_trunc, np.trunc)
self.unary_template_uint64(math_trunc, np.trunc)

#---------------------------------------------------------------------------
# test_math_copysign

Expand Down

0 comments on commit 7282635

Please sign in to comment.