diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml old mode 100644 new mode 100755 diff --git a/src/flag_gems/runtime/backend/_cambricon/ops/__init__.py b/src/flag_gems/runtime/backend/_cambricon/ops/__init__.py old mode 100644 new mode 100755 index 32691e6e15..55a771f928 --- a/src/flag_gems/runtime/backend/_cambricon/ops/__init__.py +++ b/src/flag_gems/runtime/backend/_cambricon/ops/__init__.py @@ -166,6 +166,7 @@ from .sum import sum, sum_dim, sum_dim_out, sum_out from .tan import tan, tan_ from .tanh import tanh, tanh_, tanh_backward +from .threshold import threshold, threshold_backward from .tile import tile from .to import to_copy from .topk import topk @@ -443,4 +444,6 @@ "zero_", "zeros", "zeros_like", + "threshold", + "threshold_backward", ] diff --git a/src/flag_gems/runtime/backend/_cambricon/ops/abs.py b/src/flag_gems/runtime/backend/_cambricon/ops/abs.py old mode 100644 new mode 100755 index 6a49721571..2b991668e7 --- a/src/flag_gems/runtime/backend/_cambricon/ops/abs.py +++ b/src/flag_gems/runtime/backend/_cambricon/ops/abs.py @@ -1,25 +1,68 @@ import logging +import torch import triton import triton.language as tl -from ..utils.pointwise_dynamic import pointwise_dynamic +from flag_gems.runtime import torch_device_fn +from flag_gems.utils import libentry, libtuner + +from ..utils import TOTAL_CORE_NUM logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) -@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, "COMPLEX_TO_FLOAT")]) +@libentry() +@libtuner( + configs=[ + triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=1, num_warps=1), + triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=1, num_warps=1), + triton.Config(kwargs={"BLOCK_SIZE": 65536}, num_stages=1, num_warps=1), + triton.Config(kwargs={"BLOCK_SIZE": 131072}, num_stages=1, num_warps=1), + ], + key=["n_elements"], +) @triton.jit -def abs_func(x, inplace): - return tl.abs(x) +def abs_kernel( + X_ptr, + OUT_ptr, + n_elements, + BLOCK_SIZE: 
tl.constexpr, +): + pid = tl.program_id(0) + num_jobs = tl.num_programs(0) + block_start = pid * BLOCK_SIZE + step = num_jobs * BLOCK_SIZE + block_start = block_start.to(tl.int64) + for off in range(block_start, n_elements, step): + offsets = off + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(X_ptr + offsets, mask=mask) + tl.store(OUT_ptr + offsets, tl.abs(x), mask=mask) def abs(A): logger.debug("GEMS_CAMBRICON ABS") - return abs_func(A, False) + A = A.contiguous() + out = torch.empty_like(A) + N = A.numel() + if N == 0: + return out + grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),) + with torch_device_fn.device(A.device): + abs_kernel[grid_fn](A, out, N) + return out def abs_(A): logger.debug("GEMS_CAMBRICON ABS_") - abs_func(A, True, out0=A) + A_contig = A.contiguous() + N = A_contig.numel() + if N == 0: + return A + grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),) + with torch_device_fn.device(A.device): + abs_kernel[grid_fn](A_contig, A_contig, N) + if not A.is_contiguous(): + A.copy_(A_contig) return A diff --git a/src/flag_gems/runtime/backend/_cambricon/ops/ceil.py b/src/flag_gems/runtime/backend/_cambricon/ops/ceil.py old mode 100644 new mode 100755 index df349b38f4..1828af8438 --- a/src/flag_gems/runtime/backend/_cambricon/ops/ceil.py +++ b/src/flag_gems/runtime/backend/_cambricon/ops/ceil.py @@ -1,33 +1,83 @@ import logging +import torch import triton import triton.language as tl -from ..utils.pointwise_dynamic import pointwise_dynamic +from flag_gems.runtime import torch_device_fn +from flag_gems.utils import libentry, libtuner + +from ..utils import TOTAL_CORE_NUM logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) -@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, "DEFAULT")]) +@libentry() +@libtuner( + configs=[ + triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=1, num_warps=1), + triton.Config(kwargs={"BLOCK_SIZE": 
16384}, num_stages=1, num_warps=1), + triton.Config(kwargs={"BLOCK_SIZE": 65536}, num_stages=1, num_warps=1), + triton.Config(kwargs={"BLOCK_SIZE": 131072}, num_stages=1, num_warps=1), + ], + key=["n_elements"], +) @triton.jit -def ceil_func(x, inplace): - return tl.ceil(x.to(tl.float32)).to(x.dtype) +def ceil_kernel( + X_ptr, + OUT_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + num_jobs = tl.num_programs(0) + block_start = pid * BLOCK_SIZE + step = num_jobs * BLOCK_SIZE + block_start = block_start.to(tl.int64) + for off in range(block_start, n_elements, step): + offsets = off + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(X_ptr + offsets, mask=mask) + result = tl.ceil(x.to(tl.float32)).to(x.dtype) + tl.store(OUT_ptr + offsets, result, mask=mask) def ceil(A): logger.debug("GEMS_CAMBRICON CEIL") - return ceil_func(A, False) + A = A.contiguous() + out = torch.empty_like(A) + N = A.numel() + if N == 0: + return out + grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),) + with torch_device_fn.device(A.device): + ceil_kernel[grid_fn](A, out, N) + return out def ceil_out(A, *, out=None): logger.debug("GEMS_CAMBRICON CEIL_OUT") + A = A.contiguous() + N = A.numel() if out is None: - return ceil_func(A, False) - ceil_func(A, False, out0=out) + out = torch.empty_like(A) + if N == 0: + return out + grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),) + with torch_device_fn.device(A.device): + ceil_kernel[grid_fn](A, out, N) return out def ceil_(A): logger.debug("GEMS_CAMBRICON CEIL_") - ceil_func(A, True, out0=A) + A_contig = A.contiguous() + N = A_contig.numel() + if N == 0: + return A + grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),) + with torch_device_fn.device(A.device): + ceil_kernel[grid_fn](A_contig, A_contig, N) + if not A.is_contiguous(): + A.copy_(A_contig) return A diff --git 
a/src/flag_gems/runtime/backend/_cambricon/ops/dropout.py b/src/flag_gems/runtime/backend/_cambricon/ops/dropout.py old mode 100644 new mode 100755 index 48cb27b1e5..ca89278e8d --- a/src/flag_gems/runtime/backend/_cambricon/ops/dropout.py +++ b/src/flag_gems/runtime/backend/_cambricon/ops/dropout.py @@ -6,8 +6,8 @@ import triton.language as tl from triton.language.extra.mlu.libdevice import philox as _philox -from flag_gems import runtime from flag_gems.runtime import torch_device_fn +from flag_gems.utils import libentry, libtuner from flag_gems.utils.random_utils import ( philox_backend_seed_offset, uint_to_uniform_float, @@ -17,8 +17,19 @@ logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) +UNROLL = 4 + -@triton.heuristics(runtime.get_heuristic_config("dropout")) +@libentry() +@libtuner( + configs=[ + triton.Config(kwargs={"BLOCK": 1024}, num_stages=3, num_warps=1), + triton.Config(kwargs={"BLOCK": 4096}, num_stages=3, num_warps=1), + triton.Config(kwargs={"BLOCK": 16384}, num_stages=3, num_warps=1), + triton.Config(kwargs={"BLOCK": 32768}, num_stages=3, num_warps=1), + ], + key=["N"], +) @triton.jit(do_not_specialize=["p", "philox_seed", "philox_offset"]) def dropout_forward_kernel( X, @@ -30,7 +41,7 @@ def dropout_forward_kernel( philox_offset, BLOCK: tl.constexpr, ): - UNROLL: tl.constexpr = 4 # philox generate 128 random bits at a time + UNROLL: tl.constexpr = 4 philox_seed = philox_seed.to(tl.int64) philox_offset = philox_offset.to(tl.int64) @@ -50,19 +61,27 @@ def dropout_forward_kernel( r = uint_to_uniform_float(r) mask = r > p + mask_reshaped = tl.reshape(mask, [UNROLL * BLOCK], can_reorder=True) off = block_offset + tl.arange(0, UNROLL * BLOCK) - x = tl.load(X + off, mask=off < N, other=0.0) - y = ( - x * mp * tl.reshape(mask, [UNROLL * BLOCK], can_reorder=True) - ) # tl.where(mask0, x0 * p, 0.0) - mask_reshaped = tl.reshape(mask, [UNROLL * BLOCK], can_reorder=True) - tl.store(dropout_mask + off, mask_reshaped, mask=off < N) - 
tl.store(Y + off, y, mask=off < N) + valid = off < N + x = tl.load(X + off, mask=valid, other=0.0) + y = tl.where(mask_reshaped, x * mp, 0.0) + tl.store(dropout_mask + off, mask_reshaped, mask=valid) + tl.store(Y + off, y, mask=valid) i4_start += num_jobs * BLOCK -@triton.heuristics(runtime.get_heuristic_config("dropout")) +@libentry() +@libtuner( + configs=[ + triton.Config(kwargs={"BLOCK": 1024}, num_stages=3, num_warps=1), + triton.Config(kwargs={"BLOCK": 4096}, num_stages=3, num_warps=1), + triton.Config(kwargs={"BLOCK": 16384}, num_stages=3, num_warps=1), + triton.Config(kwargs={"BLOCK": 32768}, num_stages=3, num_warps=1), + ], + key=["N"], +) @triton.jit(do_not_specialize=["scale"]) def dropout_backward_kernel( DY, @@ -73,23 +92,19 @@ def dropout_backward_kernel( BLOCK: tl.constexpr, ): UNROLL: tl.constexpr = 4 - pid = tl.program_id(0) num_programs = tl.num_programs(0) block_start = pid * UNROLL * BLOCK step = num_programs * UNROLL * BLOCK for block_offset in range(block_start, N, step): off = block_offset + tl.arange(0, UNROLL * BLOCK) + valid = off < N mask = tl.load( - dropout_mask + off, mask=off < N, other=0, eviction_policy="evict_first" + dropout_mask + off, mask=valid, other=0, eviction_policy="evict_first" ) - dy = tl.load(DY + off, mask=off < N, other=0.0, eviction_policy="evict_first") + dy = tl.load(DY + off, mask=valid, other=0.0, eviction_policy="evict_first") dx = dy * mask * scale - - tl.store(DX + off, dx, mask=off < N, eviction_policy="evict_first") - - -UNROLL = 4 + tl.store(DX + off, dx, mask=valid, eviction_policy="evict_first") def dropout(input, p, train=True): @@ -104,7 +119,6 @@ def dropout(input, p, train=True): return out, mask assert p > 0.0 and p < 1.0, "p must be in (0, 1)" device = input.device - # TODO: remove contiguous enforcement input = input.contiguous() out = torch.empty_like(input) mask = torch.empty_like(input, dtype=torch.bool) @@ -112,8 +126,6 @@ def dropout(input, p, train=True): grid_fn = lambda meta: ( 
min(triton.cdiv(N, meta["BLOCK"] * UNROLL), TOTAL_CORE_NUM), ) - # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller, - # hence we cannot obtain the per thread offset as in Pytorch. increment = triton.cdiv(N, UNROLL) with torch_device_fn.device(device): philox_seed, philox_offset = philox_backend_seed_offset(increment) @@ -125,8 +137,6 @@ def dropout(input, p, train=True): p, philox_seed, philox_offset, - num_warps=1, - num_stages=3, ) return out, mask @@ -146,7 +156,5 @@ def dropout_backward(grad_output, mask, scale): mask, N, scale, - num_stages=3, - num_warps=1, ) return grad_input diff --git a/src/flag_gems/runtime/backend/_cambricon/ops/logical_and.py b/src/flag_gems/runtime/backend/_cambricon/ops/logical_and.py old mode 100644 new mode 100755 index 2699bf4a15..9ca5ff8458 --- a/src/flag_gems/runtime/backend/_cambricon/ops/logical_and.py +++ b/src/flag_gems/runtime/backend/_cambricon/ops/logical_and.py @@ -1,33 +1,73 @@ import logging +import torch import triton import triton.language as tl -from ..utils.pointwise_dynamic import pointwise_dynamic +from flag_gems.runtime import torch_device_fn +from flag_gems.utils import libentry, libtuner + +from ..utils import TOTAL_CORE_NUM logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) -@pointwise_dynamic(promotion_methods=[(0, 1, "ALWAYS_BOOL")]) +@libentry() +@libtuner( + configs=[ + triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=3, num_warps=1), + triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=3, num_warps=1), + triton.Config(kwargs={"BLOCK_SIZE": 65536}, num_stages=3, num_warps=1), + triton.Config(kwargs={"BLOCK_SIZE": 131072}, num_stages=3, num_warps=1), + ], + key=["n_elements"], +) @triton.jit -def logical_and_func(x, y): - return x.to(tl.int1).logical_and(y.to(tl.int1)) +def logical_and_kernel( + X_ptr, + Y_ptr, + OUT_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + num_jobs = tl.num_programs(0) + block_start = pid * 
BLOCK_SIZE + step = num_jobs * BLOCK_SIZE + block_start = block_start.to(tl.int64) + for off in range(block_start, n_elements, step): + offsets = off + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(X_ptr + offsets, mask=mask) + y = tl.load(Y_ptr + offsets, mask=mask) + result = (x != 0) & (y != 0) + tl.store(OUT_ptr + offsets, result, mask=mask) def logical_and(A, B): logger.debug("GEMS_CAMBRICON LOGICAL_AND") - return logical_and_func(A, B) - - -@pointwise_dynamic( - is_tensor=[True, True, False], promotion_methods=[(0, 1, "ALWAYS_BOOL")] -) -@triton.jit -def logical_and_func_(x, y, inplace): - return tl.where((x != 0) & (y != 0), 1, 0) + A = A.contiguous() + B = B.contiguous() + out = torch.empty(A.shape, dtype=torch.bool, device=A.device) + N = A.numel() + if N == 0: + return out + grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),) + with torch_device_fn.device(A.device): + logical_and_kernel[grid_fn](A, B, out, N) + return out def logical_and_(A, B): logger.debug("GEMS_CAMBRICON LOGICAL_AND_") - logical_and_func_(A, B, True, out0=A) + A_contig = A.contiguous() + B = B.contiguous() + N = A_contig.numel() + if N == 0: + return A + grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),) + with torch_device_fn.device(A.device): + logical_and_kernel[grid_fn](A_contig, B, A_contig, N) + if not A.is_contiguous(): + A.copy_(A_contig) return A diff --git a/src/flag_gems/runtime/backend/_cambricon/ops/logical_or.py b/src/flag_gems/runtime/backend/_cambricon/ops/logical_or.py old mode 100644 new mode 100755 index c9672c4bf4..04e271fc54 --- a/src/flag_gems/runtime/backend/_cambricon/ops/logical_or.py +++ b/src/flag_gems/runtime/backend/_cambricon/ops/logical_or.py @@ -1,27 +1,73 @@ import logging +import torch import triton import triton.language as tl -from ..utils.pointwise_dynamic import pointwise_dynamic +from flag_gems.runtime import torch_device_fn +from flag_gems.utils import libentry, 
libtuner + +from ..utils import TOTAL_CORE_NUM logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) -@pointwise_dynamic( - is_tensor=[True, True, False], promotion_methods=[(0, 1, "ALWAYS_BOOL")] +@libentry() +@libtuner( + configs=[ + triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=3, num_warps=1), + triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=3, num_warps=1), + triton.Config(kwargs={"BLOCK_SIZE": 65536}, num_stages=3, num_warps=1), + triton.Config(kwargs={"BLOCK_SIZE": 131072}, num_stages=3, num_warps=1), + ], + key=["n_elements"], ) @triton.jit -def logical_or_func(x, y, inplace): - return x.to(tl.int1).logical_or(y.to(tl.int1)) +def logical_or_kernel( + X_ptr, + Y_ptr, + OUT_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + num_jobs = tl.num_programs(0) + block_start = pid * BLOCK_SIZE + step = num_jobs * BLOCK_SIZE + block_start = block_start.to(tl.int64) + for off in range(block_start, n_elements, step): + offsets = off + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(X_ptr + offsets, mask=mask) + y = tl.load(Y_ptr + offsets, mask=mask) + result = (x != 0) | (y != 0) + tl.store(OUT_ptr + offsets, result, mask=mask) def logical_or(A, B): logger.debug("GEMS_CAMBRICON LOGICAL_OR") - return logical_or_func(A, B, False) + A = A.contiguous() + B = B.contiguous() + out = torch.empty(A.shape, dtype=torch.bool, device=A.device) + N = A.numel() + if N == 0: + return out + grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),) + with torch_device_fn.device(A.device): + logical_or_kernel[grid_fn](A, B, out, N) + return out def logical_or_(A, B): logger.debug("GEMS_CAMBRICON LOGICAL_OR_") - logical_or_func(A, B, True, out0=A) + A_contig = A.contiguous() + B = B.contiguous() + N = A_contig.numel() + if N == 0: + return A + grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),) + with torch_device_fn.device(A.device): + 
logical_or_kernel[grid_fn](A_contig, B, A_contig, N) + if not A.is_contiguous(): + A.copy_(A_contig) return A diff --git a/src/flag_gems/runtime/backend/_cambricon/ops/neg.py b/src/flag_gems/runtime/backend/_cambricon/ops/neg.py old mode 100644 new mode 100755 index cb50338b67..a2b5476834 --- a/src/flag_gems/runtime/backend/_cambricon/ops/neg.py +++ b/src/flag_gems/runtime/backend/_cambricon/ops/neg.py @@ -1,23 +1,68 @@ import logging +import torch import triton +import triton.language as tl -from ..utils.pointwise_dynamic import pointwise_dynamic +from flag_gems.runtime import torch_device_fn +from flag_gems.utils import libentry, libtuner + +from ..utils import TOTAL_CORE_NUM logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) -@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, "DEFAULT")]) +@libentry() +@libtuner( + configs=[ + triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=1, num_warps=1), + triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=1, num_warps=1), + triton.Config(kwargs={"BLOCK_SIZE": 65536}, num_stages=1, num_warps=1), + triton.Config(kwargs={"BLOCK_SIZE": 131072}, num_stages=1, num_warps=1), + ], + key=["n_elements"], +) @triton.jit -def neg_func(x, inplace): - return -x +def neg_kernel( + X_ptr, + OUT_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + num_jobs = tl.num_programs(0) + block_start = pid * BLOCK_SIZE + step = num_jobs * BLOCK_SIZE + block_start = block_start.to(tl.int64) + for off in range(block_start, n_elements, step): + offsets = off + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(X_ptr + offsets, mask=mask) + tl.store(OUT_ptr + offsets, -x, mask=mask) def neg(A): logger.debug("GEMS_CAMBRICON NEG") - return neg_func(A, False) + A = A.contiguous() + out = torch.empty_like(A) + N = A.numel() + if N == 0: + return out + grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),) + with torch_device_fn.device(A.device): 
+ neg_kernel[grid_fn](A, out, N) + return out def neg_(A): logger.debug("GEMS_CAMBRICON NEG_") - return neg_func(A, True, out0=A) + A_contig = A.contiguous() + N = A_contig.numel() + if N == 0: + return A + grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),) + with torch_device_fn.device(A.device): + neg_kernel[grid_fn](A_contig, A_contig, N) + if not A.is_contiguous(): + A.copy_(A_contig) + return A diff --git a/src/flag_gems/runtime/backend/_cambricon/ops/relu.py b/src/flag_gems/runtime/backend/_cambricon/ops/relu.py old mode 100644 new mode 100755 index a0547fbdd4..5cf5ea1790 --- a/src/flag_gems/runtime/backend/_cambricon/ops/relu.py +++ b/src/flag_gems/runtime/backend/_cambricon/ops/relu.py @@ -1,19 +1,48 @@ import logging +import torch import triton import triton.language as tl +from flag_gems.runtime import torch_device_fn +from flag_gems.utils import libentry, libtuner + +from ..utils import TOTAL_CORE_NUM from ..utils.pointwise_dynamic import pointwise_dynamic logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) -@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, "DEFAULT")]) +@libentry() +@libtuner( + configs=[ + triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=1, num_warps=1), + triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=1, num_warps=1), + triton.Config(kwargs={"BLOCK_SIZE": 65536}, num_stages=1, num_warps=1), + triton.Config(kwargs={"BLOCK_SIZE": 131072}, num_stages=1, num_warps=1), + ], + key=["n_elements"], +) @triton.jit -def relu_forward(x, inplace): - return tl.where(x > 0, x, 0) +def relu_kernel( + X_ptr, + OUT_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + num_jobs = tl.num_programs(0) + block_start = pid * BLOCK_SIZE + step = num_jobs * BLOCK_SIZE + block_start = block_start.to(tl.int64) + for off in range(block_start, n_elements, step): + offsets = off + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(X_ptr + 
offsets, mask=mask) + tl.store(OUT_ptr + offsets, tl.where(x > 0, x, 0), mask=mask) +# backward: keep pointwise_dynamic unchanged @pointwise_dynamic(promotion_methods=[(0, "DEFAULT")]) @triton.jit def relu_backward(x, dy): @@ -22,11 +51,26 @@ def relu(self): logger.debug("GEMS_CAMBRICON RELU FORWARD") - output = relu_forward(self, False) - return output + A = self.contiguous() + out = torch.empty_like(A) + N = A.numel() + if N == 0: + return out + grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),) + with torch_device_fn.device(A.device): + relu_kernel[grid_fn](A, out, N) + return out def relu_(A): logger.debug("GEMS_CAMBRICON RELU_ FORWARD") - out = relu_forward(A, True, out0=A) - return out + A_contig = A.contiguous() + N = A_contig.numel() + if N == 0: + return A + grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),) + with torch_device_fn.device(A.device): + relu_kernel[grid_fn](A_contig, A_contig, N) + if not A.is_contiguous(): + A.copy_(A_contig) + return A diff --git a/src/flag_gems/runtime/backend/_cambricon/ops/threshold.py b/src/flag_gems/runtime/backend/_cambricon/ops/threshold.py new file mode 100644 index 0000000000..25d01c920b --- /dev/null +++ b/src/flag_gems/runtime/backend/_cambricon/ops/threshold.py @@ -0,0 +1,70 @@ +import logging + +import torch +import triton +import triton.language as tl + +from flag_gems.runtime import torch_device_fn +from flag_gems.utils import libentry, libtuner + +from ..utils import TOTAL_CORE_NUM +from ..utils.pointwise_dynamic import pointwise_dynamic + +logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) + + +@libentry() +@libtuner( + configs=[ + triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=1, num_warps=1), + triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=1, num_warps=1), + triton.Config(kwargs={"BLOCK_SIZE": 65536}, num_stages=1, num_warps=1), + triton.Config(kwargs={"BLOCK_SIZE": 131072}, num_stages=1, 
num_warps=1), + ], + key=["n_elements"], +) +@triton.jit(do_not_specialize=["threshold_val", "value_val"]) +def threshold_kernel( + X_ptr, + OUT_ptr, + n_elements, + threshold_val, + value_val, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + num_jobs = tl.num_programs(0) + block_start = pid * BLOCK_SIZE + step = num_jobs * BLOCK_SIZE + block_start = block_start.to(tl.int64) + for off in range(block_start, n_elements, step): + offsets = off + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(X_ptr + offsets, mask=mask) + result = tl.where(x > threshold_val, x, value_val) + tl.store(OUT_ptr + offsets, result, mask=mask) + + +# backward: keep pointwise_dynamic +@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")]) +@triton.jit +def threshold_backward_kernel(grad_output, self, threshold): + return tl.where(self > threshold, grad_output, 0) + + +def threshold(self, threshold_val, value_val): + logger.debug("GEMS_CAMBRICON THRESHOLD FORWARD") + A = self.contiguous() + out = torch.empty_like(A) + N = A.numel() + if N == 0: + return out + grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),) + with torch_device_fn.device(A.device): + threshold_kernel[grid_fn](A, out, N, threshold_val, value_val) + return out + + +def threshold_backward(grad_output, self, threshold_val): + logger.debug("GEMS_CAMBRICON THRESHOLD BACKWARD") + return threshold_backward_kernel(grad_output, self, threshold_val)