diff --git a/numba/cuda/cudamath.py b/numba/cuda/cudamath.py index 2ae56df3058..12d9715b62e 100644 --- a/numba/cuda/cudamath.py +++ b/numba/cuda/cudamath.py @@ -24,6 +24,7 @@ @infer_global(math.radians) @infer_global(math.sinh) @infer_global(math.tanh) +@infer_global(math.tan) class Math_unary(ConcreteTemplate): cases = [ signature(types.float64, types.int64), @@ -35,7 +36,6 @@ class Math_unary(ConcreteTemplate): @infer_global(math.sin) @infer_global(math.cos) -@infer_global(math.tan) @infer_global(math.ceil) @infer_global(math.floor) @infer_global(math.sqrt) diff --git a/numba/cuda/mathimpl.py b/numba/cuda/mathimpl.py index eec7cfec1c4..9dcd6dbefa0 100644 --- a/numba/cuda/mathimpl.py +++ b/numba/cuda/mathimpl.py @@ -106,14 +106,6 @@ def fp16_cos(x): return context.compile_internal(builder, fp16_cos, sig, args) -@lower(math.tan, types.float16) -def fp16_tan_impl(context, builder, sig, args): - def fp16_tan(x): - return cuda.fp16.hdiv(cuda.fp16.hsin(x), cuda.fp16.hcos(x)) - - return context.compile_internal(builder, fp16_tan, sig, args) - - @lower(math.log, types.float16) def fp16_log_impl(context, builder, sig, args): def fp16_log(x): diff --git a/numba/cuda/tests/cudapy/test_math.py b/numba/cuda/tests/cudapy/test_math.py index 00de78066bb..2ee39ab8e2c 100644 --- a/numba/cuda/tests/cudapy/test_math.py +++ b/numba/cuda/tests/cudapy/test_math.py @@ -388,7 +388,6 @@ def test_math_cos(self): def test_math_fp16(self): self.unary_template_float16(math_sin, np.sin) self.unary_template_float16(math_cos, np.cos) - self.unary_template_float16(math_tan, np.tan) self.unary_template_float16(math_exp, np.exp) self.unary_template_float16(math_log, np.log, start=1) self.unary_template_float16(math_log2, np.log2, start=1)