Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions src/backend/common/op/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -569,9 +569,16 @@ template <typename Impl> struct ReduceLowerer {
std::string reducer =
reduce::MakeCodegenReducer(op, can_batch_pack ? vsize : 1)
.value();
std::string allreduce = Impl::MakeBatchAllReduce(
reducer, reducing_threads, *scale, thread_offset,
T.thread_bounds->extent, eff_batch, reducing_threads, T.target);
std::string allreduce =
can_batch_pack
? Impl::MakeBatchAllReduce(reducer, reducing_threads, *scale,
thread_offset,
T.thread_bounds->extent, eff_batch,
reducing_threads, T.target)
: Impl::MakeBatchAllReduceOffset(
reducer, reducing_threads, *scale, thread_offset,
T.thread_bounds->extent, batch, reducing_threads,
T.target);

DataType ws_dtype = can_batch_pack
? clear_buffer->dtype.with_lanes(vsize)
Expand Down Expand Up @@ -674,16 +681,8 @@ template <typename Impl> struct ReduceLowerer {
} else {
for (int chunk = 0; chunk < num_chunks; chunk++) {
int64_t flat_offset = static_cast<int64_t>(chunk) * batch;
Array<PrimExpr> chunk_indices;
for (int d = 0; d < buf_ndim; d++) {
int64_t idx =
(flat_offset / buf_strides[d]) % buf_shape_vals[d];
chunk_indices.push_back(Integer(idx));
}
PrimExpr ptr = Call(DataType::Handle(), builtin::address_of(),
{BufferLoad(clear_buffer, chunk_indices)});

Array<PrimExpr> args = {StringImm(allreduce), ptr};
Array<PrimExpr> args = {StringImm(allreduce), clear_buffer->data,
Integer(flat_offset)};
if (need_workspace) {
args.push_back(workspace);
}
Expand Down
16 changes: 16 additions & 0 deletions src/backend/cuda/op/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,22 @@ struct Reduce : backend::ReduceLowerer<Reduce> {
return ss.str();
}

/*!
 * \brief Compose the qualified name of the `run_batch_offset` entry point of
 *        a `tl::AllReduce<...>` instantiation for the CUDA backend.
 *
 * The template arguments are emitted in this order: reducer,
 * reducing_threads, scale, thread_offset, barrier type, batch,
 * workspace_stride.
 *
 * \param reducer          Text of the reducer type, emitted verbatim.
 * \param reducing_threads Emitted as the second template argument.
 * \param scale            Emitted as the third template argument.
 * \param thread_offset    Expression streamed as the fourth template argument.
 * \param all_threads      Expression used as the `tl::NamedBarrier<...>`
 *                         argument; only emitted when
 *                         `TargetHasSMVersionGE(target, 90)` holds.
 * \param batch            Emitted after the barrier type.
 * \param workspace_stride Emitted as the last template argument.
 * \param target           Target queried to pick the barrier type.
 * \return The composed string, ending in `>::run_batch_offset`.
 */
static std::string
MakeBatchAllReduceOffset(std::string reducer, int reducing_threads, int scale,
                         PrimExpr thread_offset, PrimExpr all_threads,
                         int batch, int workspace_stride, Target target) {
  // Decide the barrier flavour up front; everything else is plain streaming.
  const bool use_named_barrier = TargetHasSMVersionGE(target, 90);

  std::stringstream name;
  name << "tl::AllReduce<" << reducer;
  name << ", " << reducing_threads;
  name << ", " << scale;
  name << ", " << thread_offset;
  if (use_named_barrier) {
    name << ", tl::NamedBarrier<" << all_threads << ">";
  } else {
    name << ", tl::SyncThreadsBarrier";
  }
  name << ", " << batch;
  name << ", " << workspace_stride;
  name << ">::run_batch_offset";
  return name.str();
}

static std::string MakeScalarAllReduce(std::string reducer,
int reducing_threads, int scale,
PrimExpr thread_offset,
Expand Down
48 changes: 48 additions & 0 deletions src/backend/maca/codegen/codegen_maca.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1578,6 +1578,39 @@ void CodeGenTileLangMACA::VisitExpr_(const CallNode *op, std::ostream &os) {
<< "maca_barrier_arrive_and_wait expects 1 argument (bar)";
std::string dummyRet = this->PrintExpr(op->args[0]);
this->stream << "barrier_arrive_and_wait(" << dummyRet << ");\n";
} else if (op->op.same_as(tl::maca_ldg_b128_bsm_predicator()) ||
op->op.same_as(tl::maca_ldg_b64_bsm_predicator())) {
ICHECK_EQ(op->args.size(), 10U)
<< "maca_ldg_b{64,128}_bsm_predicator expects 10 arguments "
"<dst, src, offset, cache_global, cache_shared, evict, wait, "
"cmp_lhs, cmp_rhs, cmp_type>";
std::string builtin_name =
op->op.same_as(tl::maca_ldg_b128_bsm_predicator())
? "__builtin_mxc_ldg_b128_bsm_predicator"
: "__builtin_mxc_ldg_b64_bsm_predicator";
this->PrintIndent();
this->stream << builtin_name << "(";
for (size_t i = 0; i < 9; ++i) {
if (i > 0) {
this->stream << ", ";
}
this->stream << this->PrintExpr(op->args[i]);
}
this->stream << ", " << Downcast<StringImm>(op->args[9])->value << ");\n";
} else if (op->op.same_as(tl::maca_arrive_gvmcnt())) {
ICHECK_EQ(op->args.size(), 1U);
this->PrintIndent();
this->stream << "__builtin_mxc_arrive_gvmcnt("
<< this->PrintExpr(op->args[0]) << ");\n";
} else if (op->op.same_as(tl::maca_arrive_bsmcnt())) {
ICHECK_EQ(op->args.size(), 1U);
this->PrintIndent();
this->stream << "__builtin_mxc_arrive_bsmcnt("
<< this->PrintExpr(op->args[0]) << ");\n";
} else if (op->op.same_as(tl::maca_barrier_inst())) {
ICHECK_EQ(op->args.size(), 0U);
this->PrintIndent();
this->stream << "__builtin_mxc_barrier_inst();\n";
} else if (op->op.same_as(builtin::create_barriers())) {
this->PrintIndent();
int barrier_count = Downcast<IntImm>(op->args[0])->value;
Expand Down Expand Up @@ -2016,6 +2049,21 @@ void CodeGenTileLangMACA::VisitExpr_(const CallNode *op, std::ostream &os) {
os << ")";
} else if (op->op.same_as(builtin::thread_return())) {
os << "return";
} else if (op->op.same_as(tl::tl_gemm())) {
ICHECK(op->args.size() == 4) << "tl_gemm expects 4 arguments <op_instance, "
"A_ptr, B_ptr, C_ptr>, but got "
<< op->args.size();
auto op_instance = Downcast<StringImm>(op->args[0]);
this->PrintCallExtern(GetType(tvm::ffi::GetRef<PrimExpr>(op)),
op_instance->value, op->args, true, os);
} else if (op->op.same_as(tl::tl_gemm_wsm())) {
ICHECK(op->args.size() == 7)
<< "tl_gemm_wsm expects 7 arguments <op_instance, A_ptr, B_ptr, "
"C_ptr, WSM_ptr, A_source_ptr, B_source_ptr>, but got "
<< op->args.size();
auto op_instance = Downcast<StringImm>(op->args[0]);
this->PrintCallExtern(GetType(tvm::ffi::GetRef<PrimExpr>(op)),
op_instance->value, op->args, true, os);
} else if (op->op.same_as(tl::tl_gemm_sp())) {
ICHECK(op->args.size() == 5)
<< "tl_gemm_sp expects 5 arguments <op_instance, A_ptr, B_ptr, C_ptr, "
Expand Down
11 changes: 11 additions & 0 deletions src/backend/maca/op/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,17 @@ struct Reduce : backend::ReduceLowerer<Reduce> {
return ss.str();
}

/*!
 * \brief Compose the qualified name of the `run_batch_offset` entry point of
 *        a `tl::AllReduce<...>` instantiation for the MACA backend.
 *
 * The barrier type is always `tl::SyncThreadsBarrier`; `all_threads` and
 * `target` are accepted but not read by this function.
 *
 * \param reducer          Text of the reducer type, emitted verbatim.
 * \param reducing_threads Emitted as the second template argument.
 * \param scale            Emitted as the third template argument.
 * \param thread_offset    Expression streamed as the fourth template argument.
 * \param all_threads      Unused here.
 * \param batch            Emitted after the barrier type.
 * \param workspace_stride Emitted as the last template argument.
 * \param target           Unused here.
 * \return The composed string, ending in `>::run_batch_offset`.
 */
static std::string
MakeBatchAllReduceOffset(std::string reducer, int reducing_threads, int scale,
                         PrimExpr thread_offset, PrimExpr all_threads,
                         int batch, int workspace_stride, Target target) {
  std::stringstream name;
  name << "tl::AllReduce<" << reducer;
  name << ", " << reducing_threads;
  name << ", " << scale;
  name << ", " << thread_offset;
  name << ", tl::SyncThreadsBarrier";
  name << ", " << batch;
  name << ", " << workspace_stride;
  name << ">::run_batch_offset";
  return name.str();
}

static std::string MakeScalarAllReduce(std::string reducer,
int reducing_threads, int scale,
PrimExpr thread_offset,
Expand Down
12 changes: 12 additions & 0 deletions src/backend/rocm/op/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,18 @@ struct Reduce : backend::ReduceLowerer<Reduce> {
return ss.str();
}

/*!
 * \brief Compose the qualified name of the `run_batch_offset` entry point of
 *        a `tl::AllReduce<...>` instantiation for the ROCm backend.
 *
 * No barrier type is emitted here: the template argument list is reducer,
 * reducing_threads, scale, thread_offset, batch, workspace_stride. The two
 * unnamed parameters are accepted and ignored.
 *
 * \return The composed string, ending in `>::run_batch_offset`.
 */
static std::string MakeBatchAllReduceOffset(std::string reducer,
                                            int reducing_threads, int scale,
                                            PrimExpr thread_offset, PrimExpr,
                                            int batch, int workspace_stride,
                                            Target) {
  std::stringstream name;
  name << "tl::AllReduce<" << reducer;
  name << ", " << reducing_threads;
  name << ", " << scale;
  name << ", " << thread_offset;
  name << ", " << batch;
  name << ", " << workspace_stride;
  name << ">::run_batch_offset";
  return name.str();
}

static std::string MakeScalarAllReduce(std::string reducer,
int reducing_threads, int scale,
PrimExpr thread_offset, PrimExpr,
Expand Down
31 changes: 29 additions & 2 deletions src/layout/gemm_layouts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,12 @@ Fragment makeGemmFragmentCMACA(const int block_m, const int block_n,
PrimExpr forward_thread = 16 * FloorDiv(j->var, 4) + i;
PrimExpr index = FloorMod(j->var, 4);
auto base_layout = Fragment({i, j}, {index}, forward_thread, rep);
auto warp_layout =
// Match CUTE partition_shape_C accumulator storage: the thread-group
// repeat must be applied before the per-thread 16x16 tile repeats.
auto thread_layout =
base_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
auto block_layout =
warp_layout->Repeat({warp_m / 16, warp_n / 16}, false, false);
thread_layout->Repeat({warp_m / 16, warp_n / 16}, false, false);
return block_layout;
}

Expand Down Expand Up @@ -1034,6 +1036,31 @@ Layout makeGemmABLayoutMACA(int mat_stride, int mat_continuous, int continuity,
}
}

/*!
 * \brief Build the MACA layout for a GEMM A/B operand buffer.
 *
 * The 2-D stride/continuous extents and element size come from
 * GetSwizzleShapeInfoChecked(buffer). A dedicated layout with outputs
 * {tc, ts, index} is used when all of the following hold: element_size is 16,
 * kfactor is 1, the continuous extent is a multiple of 64, and the stride
 * extent modulo 64 is not 32. Otherwise the layout comes from
 * makeGemmABLayoutMACA. Either result is passed through ExpandLayout2D with
 * the original buffer.
 *
 * \param buffer  Buffer whose shape information selects the layout.
 * \param kfactor Forwarded to makeGemmABLayoutMACA; must be 1 for the
 *                dedicated path.
 * \return The layout returned by ExpandLayout2D.
 */
Layout makeMacaGemmABLayout(const Buffer &buffer, int kfactor) {
  auto shape_info = GetSwizzleShapeInfoChecked(buffer);
  const int stride_extent = static_cast<int>(shape_info.stride);
  const int continuous_extent = static_cast<int>(shape_info.continuous);

  const bool use_dedicated_layout =
      shape_info.element_size == 16 && kfactor == 1 &&
      continuous_extent % 64 == 0 && stride_extent % 64 != 32;

  if (!use_dedicated_layout) {
    // Generic path: defer to the existing MACA A/B layout builder.
    Layout generic = makeGemmABLayoutMACA(stride_extent, continuous_extent,
                                          continuous_extent,
                                          shape_info.element_size, kfactor);
    return ExpandLayout2D(generic, buffer);
  }

  Var row = InputPlaceholder(0);
  Var col = InputPlaceholder(1);
  constexpr int kVectorSize = 4;

  // Decompose the row index into a 64-row tile, a 16-way group of
  // kVectorSize rows, and the position inside that group.
  PrimExpr tile_row = FloorDiv(row, 64);
  PrimExpr row_group = FloorMod(FloorDiv(row, kVectorSize), 16);
  // Decompose the column index into a 16-column tile and the column in it.
  PrimExpr tile_col = FloorDiv(col, 16);
  PrimExpr col_in_tile = FloorMod(col, 16);
  PrimExpr lane = FloorMod(row, kVectorSize);

  // Mix the row group with the in-tile column through xor16x16.
  PrimExpr swizzled_group = xor16x16(row_group, col_in_tile);
  PrimExpr offset = lane + (swizzled_group + col_in_tile * 16) * kVectorSize;

  Layout dedicated = Layout(Array<PrimExpr>{stride_extent, continuous_extent},
                            {tile_col, tile_row, offset});
  return ExpandLayout2D(dedicated, buffer);
}

Layout makeSwizzledLayout(const Buffer &buffer, bool k_inner, bool allow_pad) {
auto info = GetSwizzleShapeInfoChecked(buffer);
Layout base;
Expand Down
10 changes: 10 additions & 0 deletions src/layout/layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,16 @@ TVM_FFI_STATIC_INIT_BLOCK() {
})
.def("tl.make_linear_layout",
[](Array<PrimExpr> shape) { return makeLinearLayout(shape); })
.def("tl.make_maca_gemm_ab_layout",
[](const Buffer &buffer, int kfactor) {
return makeMacaGemmABLayout(buffer, kfactor);
})
.def("tl.make_maca_gemm_fragment_c",
[](int block_m, int block_n, int warp_m, int warp_n,
int element_size) {
return makeGemmFragmentCMACA(block_m, block_n, warp_m, warp_n,
element_size);
})
.def("tl.make_gemm_fragment_8x8", []() { return makeGemmFragment8x8(); })
.def("tl.make_gemm_fragment_8x8_transposed",
[]() { return makeGemmFragment8x8Transposed(); })
Expand Down
1 change: 1 addition & 0 deletions src/layout/layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,
int kPack);
Layout makeGemmABLayoutMACA(int mat_stride, int mat_continuous, int continuity,
int element_size, int kfactor);
/*!
 * \brief Build the MACA GEMM A/B operand layout for \p buffer.
 * \param buffer  Buffer the layout is derived from and expanded over.
 * \param kfactor Layout selection parameter, forwarded to the MACA A/B
 *                layout builder.
 */
Layout makeMacaGemmABLayout(const Buffer &buffer, int kfactor);

Fragment makeGemmVoltaFragmentC(const int block_m, const int block_n,
const int warp_m, const int warp_n,
Expand Down
30 changes: 30 additions & 0 deletions src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,31 @@ TIR_DEFINE_TL_BUILTIN(ptx_cp_async)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

// MACA BSM predicated load, b128 variant. Takes 10 inputs:
// <dst, src, offset, cache_global, cache_shared, evict, wait,
//  cmp_lhs, cmp_rhs, cmp_type>. Registered as opaque.
TIR_DEFINE_TL_BUILTIN(maca_ldg_b128_bsm_predicator)
.set_num_inputs(10)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

// MACA BSM predicated load, b64 variant. Same 10-input signature as the
// b128 variant above. Registered as opaque.
TIR_DEFINE_TL_BUILTIN(maca_ldg_b64_bsm_predicator)
.set_num_inputs(10)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

// MACA gvmcnt arrive intrinsic; takes a single argument. Registered as opaque.
TIR_DEFINE_TL_BUILTIN(maca_arrive_gvmcnt)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

// MACA bsmcnt arrive intrinsic; takes a single argument. Registered as opaque.
TIR_DEFINE_TL_BUILTIN(maca_arrive_bsmcnt)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

// MACA barrier instruction; takes no arguments. Registered as opaque.
TIR_DEFINE_TL_BUILTIN(maca_barrier_inst)
.set_num_inputs(0)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(fence_proxy_async)
.set_num_inputs(0)
.set_attr<TCallEffectKind>("TCallEffectKind",
Expand Down Expand Up @@ -543,6 +568,11 @@ TIR_DEFINE_TL_BUILTIN(loop_break)
TIR_DEFINE_TL_BUILTIN(tl_gemm).set_num_inputs(4).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));

// GEMM intrinsic variant with 7 inputs:
// <op_instance, A_ptr, B_ptr, C_ptr, WSM_ptr, A_source_ptr, B_source_ptr>.
// Registered as opaque.
TIR_DEFINE_TL_BUILTIN(tl_gemm_wsm)
.set_num_inputs(7)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(tl_gemm_sp)
.set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind",
Expand Down
29 changes: 29 additions & 0 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,23 @@ TVM_DLL const Op &ptx_cp_async_barrier_noinc();
*/
TVM_DLL const Op &ptx_cp_async();

/*!
 * \brief MACA BSM direct global-to-shared load intrinsics (b128 and b64).
 *
 * These ops expose the low-level C500 BSM load builtins used by handwritten
 * MACA HGEMM kernels. They are intentionally target-specific and lower only in
 * the MACA code generator.
 *
 * Both ops take 10 arguments:
 *   <dst, src, offset, cache_global, cache_shared, evict, wait,
 *    cmp_lhs, cmp_rhs, cmp_type>
 */
TVM_DLL const Op &maca_ldg_b128_bsm_predicator();
TVM_DLL const Op &maca_ldg_b64_bsm_predicator();

/*!
 * \brief MACA BSM/global-memory wait and barrier intrinsics.
 *
 * maca_arrive_gvmcnt and maca_arrive_bsmcnt each take one argument;
 * maca_barrier_inst takes none.
 */
TVM_DLL const Op &maca_arrive_gvmcnt();
TVM_DLL const Op &maca_arrive_bsmcnt();
TVM_DLL const Op &maca_barrier_inst();

/*!
* \brief Pack two b16 value into a b32 value
*
Expand Down Expand Up @@ -988,6 +1005,18 @@ TVM_DLL const Op &tvm_rdna_wmma_store();
*/
TVM_DLL const Op &tl_gemm();

/*!
 * \brief tilelang intrinsic for MACA GEMM variants that need a wider generated
 * source ABI than tl_gemm.
 *
 * The first use is the HGEMM WSM-aware compiler-path probe. It preserves the
 * template-call style of tl_gemm while adding pointer arguments for generated
 * shared/WSM storage plus the original global A/B source operands:
 * T.call_intrin("handle", "tl.tl_gemm_wsm", op_instance_str,
 * A_ptr, B_ptr, C_ptr, WSM_ptr, A_source_ptr, B_source_ptr)
 *
 * The op takes exactly 7 arguments, in the order shown above.
 */
TVM_DLL const Op &tl_gemm_wsm();

/*!
* \brief tilelang intrinsic for sparse matrix multiplication (GEMM with
* sparsity).
Expand Down
Loading
Loading