Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion lib/PTO/Transforms/ExpandTileOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,14 @@ static StringRef getPrecisionModeString(pto::PrecisionMode mode) {
// HIGH_PRECISION code path would silence the warning while preserving DEFAULT
// behavior.
static const llvm::StringSet<> &highPrecisionImplementedOps() {
  // Names of the ops that have a HIGH_PRECISION implementation. The set is a
  // function-local static, so it is built once on first use and shared by all
  // callers. Add an op here once its HIGH_PRECISION code path exists.
  static const llvm::StringSet<> kImplementedOps{
      "pto.tlog",
      "pto.tdiv",
      "pto.tdivs",
      "pto.trecip",
      "pto.trowexpanddiv",
      "pto.tcolexpanddiv",
  };
  return kImplementedOps;
}

Expand Down
455 changes: 455 additions & 0 deletions lib/TileOps/div_hp.py

Large diffs are not rendered by default.

44 changes: 33 additions & 11 deletions lib/TileOps/tcolexpanddiv_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,50 @@
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
"""TileLang DSL template for pto.tcolexpanddiv"""

"""TileLang DSL template for pto.tcolexpanddiv with IEEE 754 high-precision support

Divide each column of src0 by a per-column scalar from src1[0, col].
Semantics: dst[row, col] = src0[row, col] / src1[0, col]
"""

import sys
from pathlib import Path
import tilelang_dsl as pto

# Import shared high-precision division algorithms
from div_hp import _div_ieee754_f32_impl, _div_ieee754_f16_impl


@pto.vkernel(
    target="a5",
    op="pto.tcolexpanddiv"
)
def template_tcolexpanddiv(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile):
    """Template for pto.tcolexpanddiv with optional high-precision mode.

    Computes dst[row, col] = src0[row, col] / src1[0, col] over the valid
    shape of ``dst``; row 0 of ``src1`` supplies the per-column divisor.

    The ``precision_mode`` op attribute (default ``"DEFAULT"``) selects the
    division: ``"HIGH_PRECISION"`` uses the IEEE 754 routines from ``div_hp``,
    anything else uses ``pto.vdiv``.

    Args:
        src0: Dividend tile.
        src1: Divisor tile; only row 0 is read.
        dst: Destination tile; its element type and valid shape drive the loops.
    """
    dtype = dst.element_type
    valid_rows, valid_cols = dst.valid_shape

    precision_mode = pto.get_op_attr("precision_mode", "DEFAULT")
    if pto.constexpr(precision_mode == "HIGH_PRECISION"):
        for row in range(0, valid_rows, 1):
            # Columns still to be processed in this row; make_mask consumes it.
            remained = valid_cols
            for col in range(0, valid_cols, pto.get_lanes(dtype)):
                mask, remained = pto.make_mask(dtype, remained)
                lhs = pto.vlds(src0[row, col:])
                rhs = pto.vlds(src1[0, col:])
                if pto.constexpr(dtype == pto.f32):
                    result = _div_ieee754_f32_impl(lhs, rhs, mask)
                else:  # dtype == pto.f16 (guaranteed by MLIR validation)
                    result = _div_ieee754_f16_impl(lhs, rhs, mask)
                pto.vsts(result, dst[row, col:], mask)
    else:
        for row in range(0, valid_rows, 1):
            remained = valid_cols
            for col in range(0, valid_cols, pto.get_lanes(dtype)):
                mask, remained = pto.make_mask(dtype, remained)
                lhs = pto.vlds(src0[row, col:])
                rhs = pto.vlds(src1[0, col:])
                result = pto.vdiv(lhs, rhs, mask)
                pto.vsts(result, dst[row, col:], mask)
    return
36 changes: 27 additions & 9 deletions lib/TileOps/tdiv_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,45 @@
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

"""TileLang DSL template for pto.tdiv"""
"""TileLang DSL template for pto.tdiv with IEEE 754 high-precision support"""

import sys
from pathlib import Path
import tilelang_dsl as pto

# Import shared high-precision division algorithms
from div_hp import _div_ieee754_f32_impl, _div_ieee754_f16_impl


@pto.vkernel(
    target="a5",
    op="pto.tdiv"
)
def template_tdiv(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile):
    """Element-wise division with optional high-precision mode.

    Computes dst[row, col] = src0[row, col] / src1[row, col] over the valid
    shape of ``dst``.

    The ``precision_mode`` op attribute (default ``"DEFAULT"``) selects the
    division: ``"HIGH_PRECISION"`` uses the IEEE 754 routines from ``div_hp``,
    anything else uses ``pto.vdiv``.

    Args:
        src0: Dividend tile.
        src1: Divisor tile.
        dst: Destination tile; its element type and valid shape drive the loops.
    """
    dtype = dst.element_type
    valid_rows, valid_cols = dst.valid_shape

    precision_mode = pto.get_op_attr("precision_mode", "DEFAULT")
    if pto.constexpr(precision_mode == "HIGH_PRECISION"):
        for row in range(0, valid_rows, 1):
            # Columns still to be processed in this row; make_mask consumes it.
            remained = valid_cols
            for col in range(0, valid_cols, pto.get_lanes(dtype)):
                mask, remained = pto.make_mask(dtype, remained)
                lhs = pto.vlds(src0[row, col:])
                rhs = pto.vlds(src1[row, col:])
                if pto.constexpr(dtype == pto.f32):
                    divided = _div_ieee754_f32_impl(lhs, rhs, mask)
                else:  # dtype == pto.f16 (guaranteed by MLIR validation)
                    divided = _div_ieee754_f16_impl(lhs, rhs, mask)
                pto.vsts(divided, dst[row, col:], mask)
    else:
        for row in range(0, valid_rows, 1):
            remained = valid_cols
            for col in range(0, valid_cols, pto.get_lanes(dtype)):
                mask, remained = pto.make_mask(dtype, remained)
                lhs = pto.vlds(src0[row, col:])
                rhs = pto.vlds(src1[row, col:])
                divided = pto.vdiv(lhs, rhs, mask)
                pto.vsts(divided, dst[row, col:], mask)
    return
75 changes: 53 additions & 22 deletions lib/TileOps/tdivs_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,55 @@
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

"""TileLang DSL template for pto.tdivs with IEEE 754 high-precision support

Supports two operand orders:
1. tdivs(src_tile, scalar, dst) -> src / scalar
2. tdivs(scalar, src_tile, dst) -> scalar / src

High-precision mode uses the IEEE 754 compliant division algorithms from the div_hp module
for improved accuracy in precision-sensitive, subnormal, and overflow boundary cases.
"""

import sys
from pathlib import Path
import tilelang_dsl as pto

# Import shared high-precision division algorithms
from div_hp import _div_ieee754_f32_impl, _div_ieee754_f16_impl


@pto.vkernel(
    target="a5",
    op="pto.tdivs",
)
def template_tdivs_tile_scalar(src: pto.Tile, scalar: pto.AnyType, dst: pto.Tile):
    """src / scalar with optional high-precision mode.

    Computes dst[row, col] = src[row, col] / scalar over the valid shape of
    ``src``. The scalar is broadcast to a vector with ``pto.vbr`` before the
    division.

    The ``precision_mode`` op attribute (default ``"DEFAULT"``) selects the
    division: ``"HIGH_PRECISION"`` uses the IEEE 754 routines from ``div_hp``,
    anything else uses ``pto.vdiv``.

    Args:
        src: Dividend tile; its element type and valid shape drive the loops.
        scalar: Divisor applied to every element.
        dst: Destination tile.
    """
    dtype = src.element_type
    valid_rows, valid_cols = src.valid_shape

    precision_mode = pto.get_op_attr("precision_mode", "DEFAULT")
    if pto.constexpr(precision_mode == "HIGH_PRECISION"):
        for row in range(0, valid_rows, 1):
            # Columns still to be processed in this row; make_mask consumes it.
            remained = valid_cols
            for col in range(0, valid_cols, pto.get_lanes(dtype)):
                mask, remained = pto.make_mask(dtype, remained)
                vec = pto.vlds(src[row, col:])
                scalar_vec = pto.vbr(scalar)
                if pto.constexpr(dtype == pto.f32):
                    result = _div_ieee754_f32_impl(vec, scalar_vec, mask)
                else:  # dtype == pto.f16 (guaranteed by MLIR validation)
                    result = _div_ieee754_f16_impl(vec, scalar_vec, mask)
                pto.vsts(result, dst[row, col:], mask)
    else:
        for row in range(0, valid_rows, 1):
            remained = valid_cols
            for col in range(0, valid_cols, pto.get_lanes(dtype)):
                mask, remained = pto.make_mask(dtype, remained)
                vec = pto.vlds(src[row, col:])
                scalar_vec = pto.vbr(scalar)
                result = pto.vdiv(vec, scalar_vec, mask)
                pto.vsts(result, dst[row, col:], mask)
    return


Expand All @@ -45,17 +63,30 @@ def template_tdivs_tile_scalar(src: pto.Tile, scalar: pto.AnyType, dst: pto.Tile
op="pto.tdivs",
)
def template_tdivs_scalar_tile(scalar: pto.AnyType, src: pto.Tile, dst: pto.Tile):
    """scalar / src with optional high-precision mode.

    Computes dst[row, col] = scalar / src[row, col] over the valid shape of
    ``src``. The scalar is broadcast to a vector with ``pto.vbr`` and used as
    the dividend.

    The ``precision_mode`` op attribute (default ``"DEFAULT"``) selects the
    division: ``"HIGH_PRECISION"`` uses the IEEE 754 routines from ``div_hp``,
    anything else uses ``pto.vdiv``.

    Args:
        scalar: Dividend applied against every element.
        src: Divisor tile; its element type and valid shape drive the loops.
        dst: Destination tile.
    """
    dtype = src.element_type
    valid_rows, valid_cols = src.valid_shape

    precision_mode = pto.get_op_attr("precision_mode", "DEFAULT")
    if pto.constexpr(precision_mode == "HIGH_PRECISION"):
        for row in range(0, valid_rows, 1):
            # Columns still to be processed in this row; make_mask consumes it.
            remained = valid_cols
            for col in range(0, valid_cols, pto.get_lanes(dtype)):
                mask, remained = pto.make_mask(dtype, remained)
                vec = pto.vlds(src[row, col:])
                scalar_vec = pto.vbr(scalar)
                # Operand order matters: the broadcast scalar is the dividend.
                if pto.constexpr(dtype == pto.f32):
                    result = _div_ieee754_f32_impl(scalar_vec, vec, mask)
                else:  # dtype == pto.f16 (guaranteed by MLIR validation)
                    result = _div_ieee754_f16_impl(scalar_vec, vec, mask)
                pto.vsts(result, dst[row, col:], mask)
    else:
        for row in range(0, valid_rows, 1):
            remained = valid_cols
            for col in range(0, valid_cols, pto.get_lanes(dtype)):
                mask, remained = pto.make_mask(dtype, remained)
                vec = pto.vlds(src[row, col:])
                scalar_vec = pto.vbr(scalar)
                result = pto.vdiv(scalar_vec, vec, mask)
                pto.vsts(result, dst[row, col:], mask)
    return
55 changes: 40 additions & 15 deletions lib/TileOps/trecip_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,56 @@
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

"""TileLang DSL template for pto.trecip"""
"""TileLang DSL template for pto.trecip with IEEE 754 high-precision support

Computes reciprocal: dst = 1 / src
High-precision mode uses IEEE 754 compliant division algorithms.
"""

import tilelang_dsl as pto

# Import shared high-precision division algorithms
from div_hp import _div_ieee754_f32_impl, _div_ieee754_f16_impl


@pto.vkernel(
    target="a5",
    op="pto.trecip",
    dtypes=[(pto.f16, pto.f16), (pto.f32, pto.f32)]
)
def template_trecip(src: pto.Tile, dst: pto.Tile):
    """Reciprocal with optional high-precision mode: dst = 1 / src.

    Computes dst[row, col] = 1 / src[row, col] over the valid shape of
    ``dst``. A constant 1.0 of the tile's element type is broadcast with
    ``pto.vbr`` and used as the dividend.

    The ``precision_mode`` op attribute (default ``"DEFAULT"``) selects the
    division: ``"HIGH_PRECISION"`` uses the IEEE 754 routines from ``div_hp``,
    anything else uses ``pto.vdiv``.

    Args:
        src: Input tile.
        dst: Destination tile; its element type and valid shape drive the loops.
    """
    dtype = dst.element_type
    valid_rows, valid_cols = dst.valid_shape

    precision_mode = pto.get_op_attr("precision_mode", "DEFAULT")
    if pto.constexpr(precision_mode == "HIGH_PRECISION"):
        for row in range(0, valid_rows, 1):
            # Columns still to be processed in this row; make_mask consumes it.
            remained = valid_cols
            for col in range(0, valid_cols, pto.get_lanes(dtype)):
                mask, remained = pto.make_mask(dtype, remained)
                vinput = pto.vlds(src[row, col:])
                # The decorator's dtypes list restricts dtype to f16 or f32,
                # so a single check picks both the constant and the routine.
                if pto.constexpr(dtype == pto.f32):
                    one = pto.vbr(pto.f32(1.0))
                    result = _div_ieee754_f32_impl(one, vinput, mask)
                else:
                    one = pto.vbr(pto.f16(1.0))
                    result = _div_ieee754_f16_impl(one, vinput, mask)
                pto.vsts(result, dst[row, col:], mask)
    else:
        for row in range(0, valid_rows, 1):
            remained = valid_cols
            for col in range(0, valid_cols, pto.get_lanes(dtype)):
                mask, remained = pto.make_mask(dtype, remained)
                vinput = pto.vlds(src[row, col:])
                if pto.constexpr(dtype == pto.f16):
                    one_scalar = pto.f16(1.0)
                else:
                    one_scalar = pto.f32(1.0)
                one = pto.vbr(one_scalar)
                result = pto.vdiv(one, vinput, mask)
                pto.vsts(result, dst[row, col:], mask)
    return
Loading
Loading