diff --git a/src/kernl/implementations/activation_func.py b/src/kernl/implementations/activation_func.py index 53408a5e..25562b5e 100644 --- a/src/kernl/implementations/activation_func.py +++ b/src/kernl/implementations/activation_func.py @@ -28,7 +28,7 @@ @triton.jit def tanh(x): """Tanh activation function""" - return tl.libdevice.tanh(x) + return tl.math.tanh(x) @triton.jit @@ -46,4 +46,4 @@ def fast_gelu(x): @triton.jit def gelu(x): """Gaussian Error Linear Unit (GELU)""" - return x * 0.5 * (1.0 + tl.libdevice.erf(x / sqrt2)) + return x * 0.5 * (1.0 + tl.math.erf(x / sqrt2))