diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index e7b8cac9be15..c0a8f6b3ab1d 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -806,6 +806,51 @@ TVM_DLL const Op& simdgroup_store();
  */
 TVM_DLL const Op& simdgroup_multiply_accumulate();
 
+// Metal cooperative_tensor intrinsics (MetalPerformancePrimitives / Metal 4)
+
+/*!
+ * \brief Fill a cooperative_tensor with a given value.
+ *
+ * void cooperative_tensor_fill(Var d, PrimExpr index, PrimExpr value,
+ *                              int rows, int cols);
+ */
+TVM_DLL const Op& cooperative_tensor_fill();
+
+/*!
+ * \brief Load data from device or threadgroup memory into a cooperative_tensor.
+ *
+ * void cooperative_tensor_load(Var d, PrimExpr index, PrimExpr ptr,
+ *                              PrimExpr stride, int rows, int cols,
+ *                              bool transpose_matrix,
+ *                              int mma_M, int mma_N, int mma_K,
+ *                              int operand_role);
+ * operand_role: 0=left(A), 1=right(B), 2=destination(C)
+ */
+TVM_DLL const Op& cooperative_tensor_load();
+
+/*!
+ * \brief Store data from a cooperative_tensor to device or threadgroup memory.
+ *
+ * void cooperative_tensor_store(Var d, PrimExpr index, PrimExpr ptr,
+ *                               PrimExpr stride, int rows, int cols,
+ *                               bool transpose_matrix,
+ *                               int mma_M, int mma_N, int mma_K,
+ *                               int operand_role);
+ * operand_role: 0=left(A), 1=right(B), 2=destination(C)
+ */
+TVM_DLL const Op& cooperative_tensor_store();
+
+/*!
+ * \brief Multiply and accumulate two matrices using cooperative_tensor
+ * (MetalPerformancePrimitives matmul2d).
+ *
+ * void cooperative_tensor_multiply_accumulate(
+ *     Var d, PrimExpr index_d, Var a, PrimExpr index_a,
+ *     Var b, PrimExpr index_b, Var c, PrimExpr index_c,
+ *     int M, int N, int K, bool transpose_a, bool transpose_b);
+ */
+TVM_DLL const Op& cooperative_tensor_multiply_accumulate();
+
 // TODO(tvm-team) replace the usage of the vector operations by Shuffle.
 /*!
  * \brief Get the high level half of the vector
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index 5255c85a3ede..1c9f78fc780c 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1430,6 +1430,7 @@ def func(
     return func
 
+
 if TYPE_CHECKING:
     class int8: ...
     class int16: ...
@@ -2232,6 +2233,10 @@ def wrapped(*args, **kwargs):
 simdgroup_load = _op_wrapper(_tir_op.simdgroup_load)
 simdgroup_store = _op_wrapper(_tir_op.simdgroup_store)
 simdgroup_multiply_accumulate = _op_wrapper(_tir_op.simdgroup_multiply_accumulate)
+cooperative_tensor_fill = _op_wrapper(_tir_op.cooperative_tensor_fill)
+cooperative_tensor_load = _op_wrapper(_tir_op.cooperative_tensor_load)
+cooperative_tensor_store = _op_wrapper(_tir_op.cooperative_tensor_store)
+cooperative_tensor_multiply_accumulate = _op_wrapper(_tir_op.cooperative_tensor_multiply_accumulate)
 create_barriers = _op_wrapper(_tir_op.create_barriers)
 assume = _op_wrapper(_tir_op.assume)
 undef = _op_wrapper(_tir_op.undef)
@@ -2538,6 +2543,10 @@ def wrapped(*args, **kwargs):
     "simdgroup_load",
     "simdgroup_store",
     "simdgroup_multiply_accumulate",
+    "cooperative_tensor_fill",
+    "cooperative_tensor_load",
+    "cooperative_tensor_store",
+    "cooperative_tensor_multiply_accumulate",
     "create_barriers",
     "mma_store",
     "mma_fill",
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 2e96d98489a8..96471ce9fb1e 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=redefined-builtin, invalid-name, too-many-arguments
 """Operators used in TIR expression."""
+
 from typing import Any, Optional, Union
 
 import tvm_ffi
@@ -186,10 +187,12 @@ def call_intrin(dtype, func_name, *args, annotations=None, span=None):
     call : PrimExpr
         The call expression.
     """
-
+    # Normalize Python int/bool annotation values to TIR constants.
     if annotations is not None:
-        annotations = {k: tir.const(v) if isinstance(v, (int, bool)) else v for k, v in annotations.items()}
+        annotations = {
+            k: tir.const(v) if isinstance(v, (int, bool)) else v for k, v in annotations.items()
+        }
     return Call(dtype, func_name, args, annotations=annotations, span=span)
 
@@ -1790,6 +1793,114 @@ def simdgroup_multiply_accumulate(
     )
 
 
+def cooperative_tensor_fill(
+    d: Var,
+    index: PrimExpr,
+    value: PrimExpr,
+    rows: int,
+    cols: int,
+):
+    """Fill a Metal cooperative_tensor fragment with a constant value."""
+    return call_intrin("handle", "tir.cooperative_tensor_fill", d, index, value, rows, cols)
+
+
+def cooperative_tensor_load(
+    d: Var,
+    index: PrimExpr,
+    ptr: PrimExpr,
+    stride: PrimExpr,
+    rows: int,
+    cols: int,
+    transpose_matrix: bool = False,
+    mma_M: int = 0,
+    mma_N: int = 0,
+    mma_K: int = 0,
+    operand_role: int = 0,
+):
+    """Load a tile from device or threadgroup memory into a cooperative_tensor."""
+    return call_intrin(
+        "handle",
+        "tir.cooperative_tensor_load",
+        d,
+        index,
+        ptr,
+        stride,
+        rows,
+        cols,
+        transpose_matrix,
+        mma_M,
+        mma_N,
+        mma_K,
+        operand_role,
+    )
+
+
+def cooperative_tensor_store(
+    d: Var,
+    index: PrimExpr,
+    ptr: PrimExpr,
+    stride: PrimExpr,
+    rows: int,
+    cols: int,
+    transpose_matrix: bool = False,
+    mma_M: int = 0,
+    mma_N: int = 0,
+    mma_K: int = 0,
+    operand_role: int = 0,
+):
+    """Store a cooperative_tensor tile to device or threadgroup memory."""
+    return call_intrin(
+        "handle",
+        "tir.cooperative_tensor_store",
+        d,
+        index,
+        ptr,
+        stride,
+        rows,
+        cols,
+        transpose_matrix,
+        mma_M,
+        mma_N,
+        mma_K,
+        operand_role,
+    )
+
+
+def cooperative_tensor_multiply_accumulate(
+    d: Var,
+    index_d: PrimExpr,
+    a: Var,
+    index_a: PrimExpr,
+    b: Var,
+    index_b: PrimExpr,
+    c: Var,
+    index_c: PrimExpr,
+    M: int,
+    N: int,
+    K: int,
+    transpose_a: bool = False,
+    transpose_b: bool = False,
+):
+    """Multiply-accumulate two cooperative_tensor tiles (matmul2d): d = a * b + c."""
+    return call_intrin(
+        "handle",
+        "tir.cooperative_tensor_multiply_accumulate",
+        d,
+        index_d,
+        a,
+        index_a,
+        b,
+        index_b,
+        c,
+        index_c,
+        M,
+        N,
+        K,
+        transpose_a,
+        transpose_b,
+    )
+
+
 def vectorlow(dtype, vec):
     """Get the low level half of the vector
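As a reference for how the four wrappers are meant to compose, here is a minimal TVMScript sketch of a single 16x16x16 GEMM tile. It is not part of the patch: the fragment names, the one-element allocations in the new "metal.cooperative_tensor" scope, and the omission of thread/simdgroup binding are illustrative assumptions modeled on the existing simdgroup intrinsics; the argument order follows the signatures added above.

from tvm.script import tir as T


@T.prim_func
def coop_tensor_tile(a: T.handle, b: T.handle, c: T.handle):
    # Hypothetical single-tile kernel; launch/thread binding omitted for brevity.
    T.func_attr({"global_symbol": "coop_tensor_tile", "tir.noalias": True})
    A = T.match_buffer(a, (16, 16), "float16")
    B = T.match_buffer(b, (16, 16), "float16")
    C = T.match_buffer(c, (16, 16), "float32")
    # One fragment per operand, in the storage scope added by this patch.
    A_ct = T.allocate([1], "float16", scope="metal.cooperative_tensor")
    B_ct = T.allocate([1], "float16", scope="metal.cooperative_tensor")
    C_ct = T.allocate([1], "float32", scope="metal.cooperative_tensor")
    # operand_role: 0 = left (A), 1 = right (B), 2 = destination (C)
    T.cooperative_tensor_load(A_ct, 0, A.access_ptr("r"), 16, 16, 16, False, 16, 16, 16, 0)
    T.cooperative_tensor_load(B_ct, 0, B.access_ptr("r"), 16, 16, 16, False, 16, 16, 16, 1)
    T.cooperative_tensor_fill(C_ct, 0, T.float32(0), 16, 16)
    T.cooperative_tensor_multiply_accumulate(
        C_ct, 0, A_ct, 0, B_ct, 0, C_ct, 0, 16, 16, 16, False, False
    )
    T.cooperative_tensor_store(C_ct, 0, C.access_ptr("w"), 16, 16, 16, False, 16, 16, 16, 2)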
diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm
index ff0101ac9a92..a9404a8f0645 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -103,7 +103,13 @@ void WriteToFile(const ffi::String& file_name, const ffi::String& format) const
   if (fmt_ == "metal") {
     MTLCompileOptions* opts = [MTLCompileOptions alloc];
-    opts.languageVersion = MTLLanguageVersion2_3;
+#if defined(__MAC_OS_X_VERSION_MAX_ALLOWED) && __MAC_OS_X_VERSION_MAX_ALLOWED >= 260000
+    opts.languageVersion = MTLLanguageVersion4_0;
+#elif defined(__MAC_OS_X_VERSION_MAX_ALLOWED) && __MAC_OS_X_VERSION_MAX_ALLOWED >= 140000
+    opts.languageVersion = MTLLanguageVersion3_1;
+#else
+    opts.languageVersion = MTLLanguageVersion3_0;
+#endif
     opts.fastMathEnabled = YES;
     // opts = nil;
     lib =
diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h
index d085ed40613f..c528abe99bf2 100644
--- a/src/runtime/thread_storage_scope.h
+++ b/src/runtime/thread_storage_scope.h
@@ -71,6 +71,8 @@ enum class StorageRank {
   kMMAMatrixC = 11,
   /*! \brief Metal SIMD group memory */
   kMetalSimdGroup = 12,
+  /*! \brief Metal cooperative_tensor memory (MetalPerformancePrimitives) */
+  kMetalCooperativeTensor = 13,
 };
 
 /*!
@@ -129,6 +131,8 @@ struct StorageScope {
         return "m16n8k8.matrixC" + tag;
       case StorageRank::kMetalSimdGroup:
         return "metal.simdgroup" + tag;
+      case StorageRank::kMetalCooperativeTensor:
+        return "metal.cooperative_tensor" + tag;
       default:
         LOG(FATAL) << "unknown storage scope";
     }
@@ -181,6 +185,9 @@ struct StorageScope {
     } else if (s.compare(0, 15, "metal.simdgroup") == 0) {
       r.rank = StorageRank::kMetalSimdGroup;
       r.tag = s.substr(15, std::string::npos);
+    } else if (s.compare(0, 24, "metal.cooperative_tensor") == 0) {
+      r.rank = StorageRank::kMetalCooperativeTensor;
+      r.tag = s.substr(24, std::string::npos);
     } else {
       LOG(FATAL) << "unknown storage scope " << s;
     }
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 6ce2ae09e2da..8130268f70d7 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -355,6 +355,18 @@ TIR_DEFINE_BUILTIN_FUNC(simdgroup_store)
 TIR_DEFINE_BUILTIN_FUNC(simdgroup_multiply_accumulate)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
+TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_fill)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_load)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_store)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_multiply_accumulate)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
     .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
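Finally, a quick registration smoke test, also illustrative rather than part of the patch. The "tir."-prefixed op names follow from the TIR_DEFINE_BUILTIN_FUNC registrations above; tvm.ir.Op.get raises if an op was never registered.

import tvm

for name in [
    "tir.cooperative_tensor_fill",
    "tir.cooperative_tensor_load",
    "tir.cooperative_tensor_store",
    "tir.cooperative_tensor_multiply_accumulate",
]:
    op = tvm.ir.Op.get(name)  # raises if the builtin is missing
    # Each op is registered with a TCallEffectKind attribute (kOpaque).
    assert op.get_attr("TCallEffectKind") is not None
print("all cooperative_tensor builtins registered")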