Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file modified .pre-commit-config.yaml
100644 → 100755
Empty file.
3 changes: 3 additions & 0 deletions src/flag_gems/runtime/backend/_cambricon/ops/__init__.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@
from .sum import sum, sum_dim, sum_dim_out, sum_out
from .tan import tan, tan_
from .tanh import tanh, tanh_, tanh_backward
from .threshold import threshold, threshold_backward
from .tile import tile
from .to import to_copy
from .topk import topk
Expand Down Expand Up @@ -443,4 +444,6 @@
"zero_",
"zeros",
"zeros_like",
"threshold",
"threshold_backward",
]
55 changes: 49 additions & 6 deletions src/flag_gems/runtime/backend/_cambricon/ops/abs.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,25 +1,68 @@
import logging

import torch
import triton
import triton.language as tl

from ..utils.pointwise_dynamic import pointwise_dynamic
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry, libtuner

from ..utils import TOTAL_CORE_NUM

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, "COMPLEX_TO_FLOAT")])
@libentry()
@libtuner(
configs=[
triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=1, num_warps=1),
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=1, num_warps=1),
triton.Config(kwargs={"BLOCK_SIZE": 65536}, num_stages=1, num_warps=1),
triton.Config(kwargs={"BLOCK_SIZE": 131072}, num_stages=1, num_warps=1),
],
key=["n_elements"],
)
@triton.jit
def abs_func(x, inplace):
return tl.abs(x)
def abs_kernel(
X_ptr,
OUT_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
num_jobs = tl.num_programs(0)
block_start = pid * BLOCK_SIZE
step = num_jobs * BLOCK_SIZE
block_start = block_start.to(tl.int64)
for off in range(block_start, n_elements, step):
offsets = off + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(X_ptr + offsets, mask=mask)
tl.store(OUT_ptr + offsets, tl.abs(x), mask=mask)


def abs(A):
logger.debug("GEMS_CAMBRICON ABS")
return abs_func(A, False)
A = A.contiguous()
out = torch.empty_like(A)
N = A.numel()
if N == 0:
return out
grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)
with torch_device_fn.device(A.device):
abs_kernel[grid_fn](A, out, N)
return out


def abs_(A):
logger.debug("GEMS_CAMBRICON ABS_")
abs_func(A, True, out0=A)
A_contig = A.contiguous()
N = A_contig.numel()
if N == 0:
return A
grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)
with torch_device_fn.device(A.device):
abs_kernel[grid_fn](A_contig, A_contig, N)
if not A.is_contiguous():
A.copy_(A_contig)
return A
66 changes: 58 additions & 8 deletions src/flag_gems/runtime/backend/_cambricon/ops/ceil.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,33 +1,83 @@
import logging

import torch
import triton
import triton.language as tl

from ..utils.pointwise_dynamic import pointwise_dynamic
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry, libtuner

from ..utils import TOTAL_CORE_NUM

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, "DEFAULT")])
@libentry()
@libtuner(
configs=[
triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=1, num_warps=1),
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=1, num_warps=1),
triton.Config(kwargs={"BLOCK_SIZE": 65536}, num_stages=1, num_warps=1),
triton.Config(kwargs={"BLOCK_SIZE": 131072}, num_stages=1, num_warps=1),
],
key=["n_elements"],
)
@triton.jit
def ceil_func(x, inplace):
return tl.ceil(x.to(tl.float32)).to(x.dtype)
def ceil_kernel(
X_ptr,
OUT_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
num_jobs = tl.num_programs(0)
block_start = pid * BLOCK_SIZE
step = num_jobs * BLOCK_SIZE
block_start = block_start.to(tl.int64)
for off in range(block_start, n_elements, step):
offsets = off + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(X_ptr + offsets, mask=mask)
result = tl.ceil(x.to(tl.float32)).to(x.dtype)
tl.store(OUT_ptr + offsets, result, mask=mask)


def ceil(A):
logger.debug("GEMS_CAMBRICON CEIL")
return ceil_func(A, False)
A = A.contiguous()
out = torch.empty_like(A)
N = A.numel()
if N == 0:
return out
grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)
with torch_device_fn.device(A.device):
ceil_kernel[grid_fn](A, out, N)
return out


def ceil_out(A, *, out=None):
logger.debug("GEMS_CAMBRICON CEIL_OUT")
A = A.contiguous()
N = A.numel()
if out is None:
return ceil_func(A, False)
ceil_func(A, False, out0=out)
out = torch.empty_like(A)
if N == 0:
return out
grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)
with torch_device_fn.device(A.device):
ceil_kernel[grid_fn](A, out, N)
return out


def ceil_(A):
logger.debug("GEMS_CAMBRICON CEIL_")
ceil_func(A, True, out0=A)
A_contig = A.contiguous()
N = A_contig.numel()
if N == 0:
return A
grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)
with torch_device_fn.device(A.device):
ceil_kernel[grid_fn](A_contig, A_contig, N)
if not A.is_contiguous():
A.copy_(A_contig)
return A
60 changes: 34 additions & 26 deletions src/flag_gems/runtime/backend/_cambricon/ops/dropout.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import triton.language as tl
from triton.language.extra.mlu.libdevice import philox as _philox

from flag_gems import runtime
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry, libtuner
from flag_gems.utils.random_utils import (
philox_backend_seed_offset,
uint_to_uniform_float,
Expand All @@ -17,8 +17,19 @@

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

UNROLL = 4


@triton.heuristics(runtime.get_heuristic_config("dropout"))
@libentry()
@libtuner(
configs=[
triton.Config(kwargs={"BLOCK": 1024}, num_stages=3, num_warps=1),
triton.Config(kwargs={"BLOCK": 4096}, num_stages=3, num_warps=1),
triton.Config(kwargs={"BLOCK": 16384}, num_stages=3, num_warps=1),
triton.Config(kwargs={"BLOCK": 32768}, num_stages=3, num_warps=1),
],
key=["N"],
)
@triton.jit(do_not_specialize=["p", "philox_seed", "philox_offset"])
def dropout_forward_kernel(
X,
Expand All @@ -30,7 +41,7 @@ def dropout_forward_kernel(
philox_offset,
BLOCK: tl.constexpr,
):
UNROLL: tl.constexpr = 4 # philox generate 128 random bits at a time
UNROLL: tl.constexpr = 4
philox_seed = philox_seed.to(tl.int64)
philox_offset = philox_offset.to(tl.int64)

Expand All @@ -50,19 +61,27 @@ def dropout_forward_kernel(
r = uint_to_uniform_float(r)

mask = r > p
mask_reshaped = tl.reshape(mask, [UNROLL * BLOCK], can_reorder=True)

off = block_offset + tl.arange(0, UNROLL * BLOCK)
x = tl.load(X + off, mask=off < N, other=0.0)
y = (
x * mp * tl.reshape(mask, [UNROLL * BLOCK], can_reorder=True)
) # tl.where(mask0, x0 * p, 0.0)
mask_reshaped = tl.reshape(mask, [UNROLL * BLOCK], can_reorder=True)
tl.store(dropout_mask + off, mask_reshaped, mask=off < N)
tl.store(Y + off, y, mask=off < N)
valid = off < N
x = tl.load(X + off, mask=valid, other=0.0)
y = tl.where(mask_reshaped, x * mp, 0.0)
tl.store(dropout_mask + off, mask_reshaped, mask=valid)
tl.store(Y + off, y, mask=valid)
i4_start += num_jobs * BLOCK


@triton.heuristics(runtime.get_heuristic_config("dropout"))
@libentry()
@libtuner(
configs=[
triton.Config(kwargs={"BLOCK": 1024}, num_stages=3, num_warps=1),
triton.Config(kwargs={"BLOCK": 4096}, num_stages=3, num_warps=1),
triton.Config(kwargs={"BLOCK": 16384}, num_stages=3, num_warps=1),
triton.Config(kwargs={"BLOCK": 32768}, num_stages=3, num_warps=1),
],
key=["N"],
)
@triton.jit(do_not_specialize=["scale"])
def dropout_backward_kernel(
DY,
Expand All @@ -73,23 +92,19 @@ def dropout_backward_kernel(
BLOCK: tl.constexpr,
):
UNROLL: tl.constexpr = 4

pid = tl.program_id(0)
num_programs = tl.num_programs(0)
block_start = pid * UNROLL * BLOCK
step = num_programs * UNROLL * BLOCK
for block_offset in range(block_start, N, step):
off = block_offset + tl.arange(0, UNROLL * BLOCK)
valid = off < N
mask = tl.load(
dropout_mask + off, mask=off < N, other=0, eviction_policy="evict_first"
dropout_mask + off, mask=valid, other=0, eviction_policy="evict_first"
)
dy = tl.load(DY + off, mask=off < N, other=0.0, eviction_policy="evict_first")
dy = tl.load(DY + off, mask=valid, other=0.0, eviction_policy="evict_first")
dx = dy * mask * scale

tl.store(DX + off, dx, mask=off < N, eviction_policy="evict_first")


UNROLL = 4
tl.store(DX + off, dx, mask=valid, eviction_policy="evict_first")


def dropout(input, p, train=True):
Expand All @@ -104,16 +119,13 @@ def dropout(input, p, train=True):
return out, mask
assert p > 0.0 and p < 1.0, "p must be in (0, 1)"
device = input.device
# TODO: remove contiguous enforcement
input = input.contiguous()
out = torch.empty_like(input)
mask = torch.empty_like(input, dtype=torch.bool)
N = input.numel()
grid_fn = lambda meta: (
min(triton.cdiv(N, meta["BLOCK"] * UNROLL), TOTAL_CORE_NUM),
)
# (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
# hence we cannot obtain the per thread offset as in Pytorch.
increment = triton.cdiv(N, UNROLL)
with torch_device_fn.device(device):
philox_seed, philox_offset = philox_backend_seed_offset(increment)
Expand All @@ -125,8 +137,6 @@ def dropout(input, p, train=True):
p,
philox_seed,
philox_offset,
num_warps=1,
num_stages=3,
)
return out, mask

Expand All @@ -146,7 +156,5 @@ def dropout_backward(grad_output, mask, scale):
mask,
N,
scale,
num_stages=3,
num_warps=1,
)
return grad_input
Loading
Loading