Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,51 @@ TVM_DLL const Op& simdgroup_store();
*/
TVM_DLL const Op& simdgroup_multiply_accumulate();

// Metal cooperative_tensor intrinsics (MetalPerformancePrimitives / Metal 4).
//
// Shared operand_role encoding used by the load/store intrinsics below:
//   0 = left operand (A), 1 = right operand (B), 2 = destination (C).

/*!
 * \brief Fill a cooperative_tensor with a given value.
 *
 * Signature:
 *   void cooperative_tensor_fill(Var d, PrimExpr index, PrimExpr value,
 *                                int rows, int cols);
 *
 * \note rows/cols describe the tile shape of the cooperative_tensor `d`.
 */
TVM_DLL const Op& cooperative_tensor_fill();

/*!
 * \brief Load data from device or threadgroup memory into a cooperative_tensor.
 *
 * Signature:
 *   void cooperative_tensor_load(Var d, PrimExpr index, PrimExpr ptr,
 *                                PrimExpr stride, int rows, int cols,
 *                                bool transpose_matrix,
 *                                int mma_M, int mma_N, int mma_K,
 *                                int operand_role);
 *
 * operand_role: 0=left(A), 1=right(B), 2=destination(C)
 * mma_M/mma_N/mma_K carry the matmul tile shape the tensor participates in.
 */
TVM_DLL const Op& cooperative_tensor_load();

/*!
 * \brief Store data from a cooperative_tensor to device or threadgroup memory.
 *
 * Signature:
 *   void cooperative_tensor_store(Var d, PrimExpr index, PrimExpr ptr,
 *                                 PrimExpr stride, int rows, int cols,
 *                                 bool transpose_matrix,
 *                                 int mma_M, int mma_N, int mma_K,
 *                                 int operand_role);
 *
 * operand_role: 0=left(A), 1=right(B), 2=destination(C)
 * Mirrors cooperative_tensor_load with the data direction reversed.
 */
TVM_DLL const Op& cooperative_tensor_store();

/*!
 * \brief Multiply and accumulate two matrices using cooperative_tensor
 * (MetalPerformancePrimitives matmul2d).
 *
 * Signature:
 *   void cooperative_tensor_multiply_accumulate(
 *       Var d, PrimExpr index_d, Var a, PrimExpr index_a,
 *       Var b, PrimExpr index_b, Var c, PrimExpr index_c,
 *       int M, int N, int K, bool transpose_a, bool transpose_b);
 *
 * \note M/N/K give the matmul shape; transpose_a/transpose_b select
 *       transposed operand layouts.
 */
TVM_DLL const Op& cooperative_tensor_multiply_accumulate();

// TODO(tvm-team) replace the usage of the vector operations by Shuffle.
/*!
* \brief Get the high level half of the vector
Expand Down
9 changes: 9 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1430,6 +1430,7 @@ def func(

return func


if TYPE_CHECKING:
class int8: ...
class int16: ...
Expand Down Expand Up @@ -2232,6 +2233,10 @@ def wrapped(*args, **kwargs):
simdgroup_load = _op_wrapper(_tir_op.simdgroup_load)
simdgroup_store = _op_wrapper(_tir_op.simdgroup_store)
simdgroup_multiply_accumulate = _op_wrapper(_tir_op.simdgroup_multiply_accumulate)
cooperative_tensor_fill = _op_wrapper(_tir_op.cooperative_tensor_fill)
cooperative_tensor_load = _op_wrapper(_tir_op.cooperative_tensor_load)
cooperative_tensor_store = _op_wrapper(_tir_op.cooperative_tensor_store)
cooperative_tensor_multiply_accumulate = _op_wrapper(_tir_op.cooperative_tensor_multiply_accumulate)
create_barriers = _op_wrapper(_tir_op.create_barriers)
assume = _op_wrapper(_tir_op.assume)
undef = _op_wrapper(_tir_op.undef)
Expand Down Expand Up @@ -2538,6 +2543,10 @@ def wrapped(*args, **kwargs):
"simdgroup_load",
"simdgroup_store",
"simdgroup_multiply_accumulate",
"cooperative_tensor_fill",
"cooperative_tensor_load",
"cooperative_tensor_store",
"cooperative_tensor_multiply_accumulate",
"create_barriers",
"mma_store",
"mma_fill",
Expand Down
111 changes: 109 additions & 2 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
# pylint: disable=redefined-builtin, invalid-name, too-many-arguments
"""Operators used in TIR expression."""

from typing import Any, Optional, Union

import tvm_ffi
Expand Down Expand Up @@ -186,10 +187,12 @@ def call_intrin(dtype, func_name, *args, annotations=None, span=None):
call : PrimExpr
The call expression.
"""

# Convert to TVM Map
if annotations is not None:
annotations = {k: tir.const(v) if isinstance(v, (int, bool)) else v for k, v in annotations.items()}
annotations = {
k: tir.const(v) if isinstance(v, (int, bool)) else v for k, v in annotations.items()
}
return Call(dtype, func_name, args, annotations=annotations, span=span)


Expand Down Expand Up @@ -1790,6 +1793,110 @@ def simdgroup_multiply_accumulate(
)


def cooperative_tensor_fill(
    d: Var,
    index: PrimExpr,
    value: PrimExpr,
    rows: int,
    cols: int,
):
    """Fill a Metal cooperative_tensor with a given value.

    Parameters
    ----------
    d : Var
        The handle of the cooperative_tensor to fill.

    index : PrimExpr
        The offset into the cooperative_tensor.

    value : PrimExpr
        The value every element of the tile is set to.

    rows : int
        The number of rows of the tile.

    cols : int
        The number of columns of the tile.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    return call_intrin("handle", "tir.cooperative_tensor_fill", d, index, value, rows, cols)


def cooperative_tensor_load(
    d: Var,
    index: PrimExpr,
    ptr: PrimExpr,
    stride: PrimExpr,
    rows: int,
    cols: int,
    transpose_matrix: bool = False,
    mma_M: int = 0,
    mma_N: int = 0,
    mma_K: int = 0,
    operand_role: int = 0,
):
    """Load data from device or threadgroup memory into a Metal cooperative_tensor.

    Parameters
    ----------
    d : Var
        The handle of the destination cooperative_tensor.

    index : PrimExpr
        The offset into the cooperative_tensor.

    ptr : PrimExpr
        The source pointer (device or threadgroup memory).

    stride : PrimExpr
        The stride of the source buffer.
        # NOTE(review): presumably the row stride in elements — confirm in codegen.

    rows : int
        The number of rows of the tile.

    cols : int
        The number of columns of the tile.

    transpose_matrix : bool
        Whether the matrix is loaded transposed.

    mma_M, mma_N, mma_K : int
        The matmul tile shape this tensor participates in.

    operand_role : int
        Which matmul operand the tensor is: 0=left(A), 1=right(B), 2=destination(C).

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    return call_intrin(
        "handle",
        "tir.cooperative_tensor_load",
        d,
        index,
        ptr,
        stride,
        rows,
        cols,
        transpose_matrix,
        mma_M,
        mma_N,
        mma_K,
        operand_role,
    )


def cooperative_tensor_store(
    d: Var,
    index: PrimExpr,
    ptr: PrimExpr,
    stride: PrimExpr,
    rows: int,
    cols: int,
    transpose_matrix: bool = False,
    mma_M: int = 0,
    mma_N: int = 0,
    mma_K: int = 0,
    operand_role: int = 0,
):
    """Store data from a Metal cooperative_tensor to device or threadgroup memory.

    # NOTE(review): `d` was annotated `PrimExpr`; changed to `Var` for consistency
    # with `cooperative_tensor_load` and the C++ builtin doc (`Var d`).

    Parameters
    ----------
    d : Var
        The handle of the source cooperative_tensor.

    index : PrimExpr
        The offset into the cooperative_tensor.

    ptr : PrimExpr
        The destination pointer (device or threadgroup memory).

    stride : PrimExpr
        The stride of the destination buffer.
        # NOTE(review): presumably the row stride in elements — confirm in codegen.

    rows : int
        The number of rows of the tile.

    cols : int
        The number of columns of the tile.

    transpose_matrix : bool
        Whether the matrix is stored transposed.

    mma_M, mma_N, mma_K : int
        The matmul tile shape this tensor participates in.

    operand_role : int
        Which matmul operand the tensor is: 0=left(A), 1=right(B), 2=destination(C).

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    return call_intrin(
        "handle",
        "tir.cooperative_tensor_store",
        d,
        index,
        ptr,
        stride,
        rows,
        cols,
        transpose_matrix,
        mma_M,
        mma_N,
        mma_K,
        operand_role,
    )


def cooperative_tensor_multiply_accumulate(
    d: Var,
    index_d: PrimExpr,
    a: Var,
    index_a: PrimExpr,
    b: Var,
    index_b: PrimExpr,
    c: Var,
    index_c: PrimExpr,
    M: int,
    N: int,
    K: int,
    transpose_a: bool = False,
    transpose_b: bool = False,
):
    """Multiply and accumulate two matrices using Metal cooperative_tensor
    (MetalPerformancePrimitives matmul2d).

    # NOTE(review): presumably computes d = a @ b accumulated with c — confirm
    # exact accumulate semantics against the Metal codegen.

    Parameters
    ----------
    d : Var
        The handle of the destination cooperative_tensor.

    index_d : PrimExpr
        The offset into the destination cooperative_tensor.

    a : Var
        The handle of the left-operand cooperative_tensor.

    index_a : PrimExpr
        The offset into the left-operand cooperative_tensor.

    b : Var
        The handle of the right-operand cooperative_tensor.

    index_b : PrimExpr
        The offset into the right-operand cooperative_tensor.

    c : Var
        The handle of the accumulator cooperative_tensor.

    index_c : PrimExpr
        The offset into the accumulator cooperative_tensor.

    M, N, K : int
        The shape of the matmul.

    transpose_a : bool
        Whether the left operand is transposed.

    transpose_b : bool
        Whether the right operand is transposed.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    return call_intrin(
        "handle",
        "tir.cooperative_tensor_multiply_accumulate",
        d,
        index_d,
        a,
        index_a,
        b,
        index_b,
        c,
        index_c,
        M,
        N,
        K,
        transpose_a,
        transpose_b,
    )


def vectorlow(dtype, vec):
"""Get the low level half of the vector

Expand Down
8 changes: 7 additions & 1 deletion src/runtime/metal/metal_module.mm
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,13 @@ void WriteToFile(const ffi::String& file_name, const ffi::String& format) const

if (fmt_ == "metal") {
MTLCompileOptions* opts = [MTLCompileOptions alloc];
opts.languageVersion = MTLLanguageVersion2_3;
#if defined(__MAC_OS_X_VERSION_MAX_ALLOWED) && __MAC_OS_X_VERSION_MAX_ALLOWED >= 260000
opts.languageVersion = MTLLanguageVersion4_0;
#elif defined(__MAC_OS_X_VERSION_MAX_ALLOWED) && __MAC_OS_X_VERSION_MAX_ALLOWED >= 140000
opts.languageVersion = MTLLanguageVersion3_1;
#else
opts.languageVersion = MTLLanguageVersion3_0;
#endif
opts.fastMathEnabled = YES;
// opts = nil;
lib =
Expand Down
7 changes: 7 additions & 0 deletions src/runtime/thread_storage_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ enum class StorageRank {
kMMAMatrixC = 11,
/*! \brief Metal SIMD group memory */
kMetalSimdGroup = 12,
/*! \brief Metal cooperative_tensor memory (MetalPerformancePrimitives) */
kMetalCooperativeTensor = 13,
};

/*!
Expand Down Expand Up @@ -129,6 +131,8 @@ struct StorageScope {
return "m16n8k8.matrixC" + tag;
case StorageRank::kMetalSimdGroup:
return "metal.simdgroup" + tag;
case StorageRank::kMetalCooperativeTensor:
return "metal.cooperative_tensor" + tag;
default:
LOG(FATAL) << "unknown storage scope";
}
Expand Down Expand Up @@ -181,6 +185,9 @@ struct StorageScope {
} else if (s.compare(0, 15, "metal.simdgroup") == 0) {
r.rank = StorageRank::kMetalSimdGroup;
r.tag = s.substr(15, std::string::npos);
} else if (s.compare(0, 24, "metal.cooperative_tensor") == 0) {
r.rank = StorageRank::kMetalCooperativeTensor;
r.tag = s.substr(24, std::string::npos);
} else {
LOG(FATAL) << "unknown storage scope " << s;
}
Expand Down
12 changes: 12 additions & 0 deletions src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,18 @@ TIR_DEFINE_BUILTIN_FUNC(simdgroup_store)
TIR_DEFINE_BUILTIN_FUNC(simdgroup_multiply_accumulate)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

// Metal cooperative_tensor intrinsics (see include/tvm/tir/builtin.h for the
// call signatures).  All four are registered as opaque: they read/write memory
// through handles, so they must not be reordered or eliminated.
TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_fill)
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_load)
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_store)
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_multiply_accumulate)
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
.set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
Expand Down