Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/backend/cuda/codegen/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2246,6 +2246,38 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
ss << "tl::tma_store";
}
print_extern_call_stmt(ss.str(), 0, 2);
} else if (op->op.same_as(tl::tma_load_gather4())) {
std::ostringstream ss;
ICHECK_EQ(op->args.size(), 9u)
<< "tma_load_gather4 expects 9 args (desc, mbar, smem, col, "
"r0..r3, eviction_policy), got "
<< op->args.size();
auto eviction_policy =
this->eviction_policy_names_
[op->args[op->args.size() - 1].as<IntImmNode>()->value];
if (eviction_policy != "EVICT_NORMAL") {
ss << "tl::tma_load_gather4<tl::CacheHintSm90::" << eviction_policy
<< ">";
} else {
ss << "tl::tma_load_gather4";
}
print_extern_call_stmt(ss.str(), 0, 1);
} else if (op->op.same_as(tl::tma_store_scatter4())) {
std::ostringstream ss;
ICHECK_EQ(op->args.size(), 8u)
<< "tma_store_scatter4 expects 8 args (desc, smem, col, r0..r3, "
"eviction_policy), got "
<< op->args.size();
auto eviction_policy =
this->eviction_policy_names_
[op->args[op->args.size() - 1].as<IntImmNode>()->value];
if (eviction_policy != "EVICT_NORMAL") {
ss << "tl::tma_store_scatter4<tl::CacheHintSm90::" << eviction_policy
<< ">";
} else {
ss << "tl::tma_store_scatter4";
}
print_extern_call_stmt(ss.str(), 0, 1);
} else if (op->op.same_as(tl::ptx_ldmatrix())) {
int trans = Downcast<IntImm>(op->args[0])->value;
int num = Downcast<IntImm>(op->args[1])->value;
Expand Down
160 changes: 160 additions & 0 deletions src/backend/cuda/op/copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ struct Copy {
static Stmt LowerBulk(const CopyNode &op, const LowerArgs &T,
arith::Analyzer *analyzer, CopyInst copy_inst);

static Stmt LowerBulkGather4(const CopyNode &op, const LowerArgs &T,
arith::Analyzer *analyzer, CopyInst copy_inst);

static Stmt LowerBulk1D(const CopyNode &op, const LowerArgs &T,
arith::Analyzer *analyzer, CopyInst copy_inst);
};
Expand Down Expand Up @@ -445,6 +448,11 @@ Stmt Copy::Lower(const CopyNode &op, const LowerArgs &T,
auto bulk_copy = LowerBulk(op, T, analyzer, copy_inst);
ICHECK(bulk_copy.defined()) << "Failed to lower bulk load/store";
return bulk_copy;
} else if (copy_inst == CopyInst::kBulkLoadGather4 ||
copy_inst == CopyInst::kBulkStoreScatter4) {
auto bulk_copy = LowerBulkGather4(op, T, analyzer, copy_inst);
ICHECK(bulk_copy.defined()) << "Failed to lower tma gather4/scatter4";
return bulk_copy;
} else if (copy_inst == CopyInst::kLDSM || copy_inst == CopyInst::kSTSM) {
auto ldsm_copy = LowerLDSM(op, T, analyzer, copy_inst);
ICHECK(ldsm_copy.defined()) << "Failed to lower ptx matrix copy";
Expand Down Expand Up @@ -1267,6 +1275,158 @@ Stmt Copy::LowerBulk(const CopyNode &op, const LowerArgs &T,
return tma_copy;
}

namespace {

// Row indices attached by the frontend via the "gather4_rows" annotation
// (exactly 4 are expected downstream); empty array when absent.
Array<PrimExpr> GetGather4Rows(const CopyNode &op) {
  auto rows = op.annotations.Get("gather4_rows");
  if (!rows.has_value()) {
    return {};
  }
  return Downcast<Array<PrimExpr>>(rows.value());
}

// Column coordinate attached via the "gather4_col" annotation;
// an undefined PrimExpr when absent.
PrimExpr GetGather4Col(const CopyNode &op) {
  auto col = op.annotations.Get("gather4_col");
  return col.has_value() ? Downcast<PrimExpr>(col.value()) : PrimExpr();
}

} // namespace

// Lower a gather4 load / scatter4 store Copy into a single
// tl::tma_load_gather4 / tl::tma_store_scatter4 builtin call.
//
// Expects a 2D global tensor and a 4-row shared tile; the four (arbitrary)
// global row indices and the column offset come from the "gather4_rows" /
// "gather4_col" annotations placed on the CopyNode by the frontend.
//
// \param op        The copy op carrying src/dst buffers and annotations.
// \param T         Lowering context (layout map, buffer remap, ...).
// \param analyzer  Arithmetic analyzer (part of the common lowering
//                  signature; not used in this path).
// \param copy_inst Must be kBulkLoadGather4 or kBulkStoreScatter4.
// \return An Evaluate(Call(...)) stmt invoking the TL builtin.
Stmt Copy::LowerBulkGather4(const CopyNode &op, const LowerArgs &T,
                            arith::Analyzer *analyzer, CopyInst copy_inst) {
  ICHECK(copy_inst == CopyInst::kBulkLoadGather4 ||
         copy_inst == CopyInst::kBulkStoreScatter4);
  bool is_load = copy_inst == CopyInst::kBulkLoadGather4;

  // Loads move global->shared, stores shared->global.
  Buffer global_tensor = is_load ? op.src : op.dst;
  Buffer shared_tensor = is_load ? op.dst : op.src;
  // Keep the pre-remap buffer: swizzle detection below compares the layout
  // against the logical (unmapped) shared tile.
  Buffer shared_tensor_unmapped = shared_tensor;

  ICHECK_EQ(global_tensor->shape.size(), 2u);
  ICHECK_EQ(shared_tensor->shape.size(), 2u);
  auto shared_lead = as_const_int(shared_tensor->shape[0]);
  ICHECK(shared_lead != nullptr && *shared_lead == 4)
      << "tma_gather4/scatter4 shared tile leading dim must be 4, got "
      << shared_tensor->shape[0];
  ICHECK_EQ(global_tensor->dtype, shared_tensor->dtype);

  // Row indices (exactly 4) and column offset supplied via annotations.
  Array<PrimExpr> rows = GetGather4Rows(op);
  PrimExpr col = GetGather4Col(op);
  ICHECK_EQ(rows.size(), 4u);
  ICHECK(col.defined());

  // Build the CUtensorMap descriptor. Shape/stride arrays are reversed to
  // the innermost-first order the descriptor expects.
  TMADesc desc;
  desc.rank = 2;
  desc.data_type = to_CUtensorMapDataType(global_tensor->dtype);
  desc.global_addr = global_tensor->data;
  desc.global_shape = ReverseArray(global_tensor->shape);

  if (!global_tensor->strides.empty()) {
    desc.global_stride = ReverseArray(global_tensor->strides);
  } else {
    // No explicit strides: derive compact row-major strides from the shape
    // (innermost first, in elements).
    PrimExpr stride = 1;
    desc.global_stride.reserve(2);
    for (size_t i = 0; i < global_tensor->shape.size(); ++i) {
      desc.global_stride.push_back(stride);
      stride *= global_tensor->shape[global_tensor->shape.size() - 1 - i];
    }
  }
  ICHECK(is_one(desc.global_stride[0]))
      << "tma_gather4/scatter4 requires unit innermost global stride, got "
      << desc.global_stride;
  // Descriptor strides are encoded in bytes, not elements.
  desc.global_stride = desc.global_stride.Map([&](PrimExpr e) {
    return TMABytesFromElements(e, global_tensor->dtype);
  });
  // Alignment/size constraints on the outer strides; only checkable when
  // the stride is a compile-time constant.
  for (size_t i = 1; i < desc.global_stride.size(); ++i) {
    if (auto stride = desc.global_stride[i].as<IntImmNode>()) {
      ICHECK(stride->value % 16 == 0 && stride->value < (1LL << 40))
          << "tma_gather4/scatter4 global stride[" << i
          << "] = " << stride->value
          << " bytes must be 16-byte aligned and < 2^40";
    }
  }

  // The descriptor's row box-dim must be 1, not 4. The four-row pack is
  // implicit in the cp.async.bulk.tensor.tile::gather4 PTX, which takes 4
  // row coordinates and materializes them into 4 logical rows of the shared
  // tile. Setting box[1]=4 here would describe a contiguous 4-row strip; the
  // gather4 unrolling would then read OOB → CUDA_ERROR_ILLEGAL_INSTRUCTION.
  PrimExpr K_box = shared_tensor->shape[1];
  if (auto k = as_const_int(K_box)) {
    int64_t k_bytes = TMABytesFromElements(*k, shared_tensor->dtype);
    ICHECK(k_bytes % 16 == 0)
        << "tma_gather4/scatter4 K_box * dtype.bytes() = " << k_bytes
        << " must be 16-byte aligned";
  }
  desc.smem_box = {K_box, IntImm(DataType::Int(32), 1)};
  desc.smem_stride = {IntImm(DataType::Int(32), 1),
                      IntImm(DataType::Int(32), 1)};
  desc.interleave = static_cast<int>(CU_TENSOR_MAP_INTERLEAVE_NONE);
  desc.l2_promotion = static_cast<int>(CU_TENSOR_MAP_L2_PROMOTION_L2_128B);
  desc.oob_fill = static_cast<int>(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);

  // If the shared tile has a swizzled layout, switch to the remapped buffer
  // and encode the matching descriptor swizzle mode.
  Layout shared_layout;
  if (T.layout_map.count(shared_tensor)) {
    shared_layout = T.layout_map.at(shared_tensor);
    ICHECK(T.buffer_remap.count(shared_tensor));
    shared_tensor = T.buffer_remap.at(shared_tensor);
  }
  desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
  if (shared_layout.defined() && shared_layout->InputDim() >= 2) {
    SwizzleMode mode = DetectSwizzleMode(shared_layout, shared_tensor_unmapped);
    if (mode == SwizzleMode::kQuarter) {
      desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_32B);
    } else if (mode == SwizzleMode::kHalf) {
      desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B);
    } else if (mode == SwizzleMode::kFull) {
      desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B);
    }
  }
  // A swizzled box row may not exceed the swizzle span (32/64/128 bytes);
  // only checkable when K_box is a compile-time constant.
  if (auto k = as_const_int(K_box)) {
    int64_t k_bytes = TMABytesFromElements(*k, shared_tensor->dtype);
    int max_bytes = 0;
    if (desc.swizzle == static_cast<int>(CU_TENSOR_MAP_SWIZZLE_32B))
      max_bytes = 32;
    else if (desc.swizzle == static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B))
      max_bytes = 64;
    else if (desc.swizzle == static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B))
      max_bytes = 128;
    if (max_bytes > 0) {
      ICHECK(k_bytes <= max_bytes)
          << "tma_gather4/scatter4 K_box * dtype.bytes() = " << k_bytes
          << " exceeds " << max_bytes << "B swizzle limit";
    }
  }

  Call create_descriptor =
      Call(DataType::Handle(), create_tma_descriptor(), desc.EncodeCallArgs());

  // Shared-memory pointer spanning the whole 4 x K_box tile. access_ptr
  // mask: 2 (write) when loading into shared, 1 (read) when storing from it.
  PrimExpr total_elements = 4 * K_box;
  PrimExpr smem_addr =
      shared_tensor.access_ptr(is_load ? 2 : 1, DataType::Handle(), 1,
                               IntImm(DataType::Int(32), 0), total_elements);

  // Argument order must match the codegen contract:
  //   load : desc, mbar, smem, col, r0..r3, eviction_policy (9 args)
  //   store: desc,       smem, col, r0..r3, eviction_policy (8 args)
  Array<PrimExpr> args;
  args.push_back(create_descriptor);
  if (is_load) {
    auto user_barrier = op.annotations.Get("barrier");
    ICHECK(user_barrier.has_value())
        << "tma_gather4 requires a 'barrier' annotation";
    args.push_back(Downcast<PrimExpr>(user_barrier.value()));
  }
  args.push_back(smem_addr);
  args.push_back(col);
  for (auto r : rows)
    args.push_back(r);
  args.push_back(IntImm(DataType::Int(32), GetEvictionPolicy(op)));

  // Fire-and-forget: caller manages mbarrier_expect_tx / wait (loads) and
  // tma_store_arrive / wait (stores), and the leader-thread guard.
  auto tl_op = is_load ? tma_load_gather4() : tma_store_scatter4();
  return Evaluate(Call(DataType::Handle(), tl_op, args));
}

Stmt Copy::LowerBulk1D(const CopyNode &op, const LowerArgs &T,
arith::Analyzer *analyzer, CopyInst copy_inst) {
const Buffer &src = op.src;
Expand Down
2 changes: 2 additions & 0 deletions src/backend/cuda/op/copy.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ enum class CopyInst : uint8_t {
kBulkStore1D = 7,
kTMemLoad = 8,
kTMemStore = 9,
kBulkLoadGather4 = 10, // tma cp.async.bulk.tensor.tile::gather4 (sm_100a)
kBulkStoreScatter4 = 11, // tma cp.async.bulk.tensor.tile::scatter4 (sm_100a)
kInvalid = 255,
};

Expand Down
14 changes: 14 additions & 0 deletions src/backend/cuda/op/copy_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,10 @@ const char *CopyInstToString(CopyInst inst) {
return "TMemLoad";
case CopyInst::kTMemStore:
return "TMemStore";
case CopyInst::kBulkLoadGather4:
return "BulkLoadGather4";
case CopyInst::kBulkStoreScatter4:
return "BulkStoreScatter4";
case CopyInst::kInvalid:
return "Invalid";
default:
Expand Down Expand Up @@ -520,6 +524,16 @@ CopyFacts AnalyzeCopyFacts(const CopyNode &op, const CopyAnalysisContext &ctx) {

CopyInstSelection SelectCopyInstForLowering(const CopyNode &op,
const CopyAnalysisContext &ctx) {
// tile::gather4 / scatter4 markers take precedence over generic TMA paths.
// The IR carries explicit row indices via annotations and must always be
// lowered through LowerBulkGather4 (no fallback path makes sense).
if (GetBoolAnnotation(op, "is_gather4")) {
return Supported(CopyInst::kBulkLoadGather4);
}
if (GetBoolAnnotation(op, "is_scatter4")) {
return Supported(CopyInst::kBulkStoreScatter4);
}

CopyFacts facts = AnalyzeCopyFacts(op, ctx);
if (facts.explicit_tma) {
CopyInst inst =
Expand Down
22 changes: 15 additions & 7 deletions src/layout/gemm_layouts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,10 @@ static Layout MakeQuarterBankSwizzleLayout2D(int stride, int continuous,
Var i = InputPlaceholder(0);
Var j = InputPlaceholder(1);
int vector_size = 128 / element_size;
ICHECK(stride % 8 == 0) << "stride=" << stride;
// stride==4 is a truncated 4-row period used by tile::gather4/scatter4
// (s=i%8 ∈ [0,4) is a valid subset of the 8-row XOR pattern, matching
// what TMA applies per-row in hardware).
ICHECK(stride == 4 || stride % 8 == 0) << "stride=" << stride;
Comment thread
coderabbitai[bot] marked this conversation as resolved.
ICHECK(continuous % (vector_size * 2) == 0)
<< "continuous=" << continuous << ", vector_size=" << vector_size;
PrimExpr ts = FloorDiv(i, 8);
Expand Down Expand Up @@ -459,7 +462,8 @@ static Layout MakeHalfBankSwizzleLayout2D(int stride, int continuous,
Var i = InputPlaceholder(0);
Var j = InputPlaceholder(1);
int vector_size = 128 / element_size;
ICHECK(stride % 8 == 0) << "stride=" << stride;
// See MakeQuarterBankSwizzleLayout2D for stride==4 rationale.
ICHECK(stride == 4 || stride % 8 == 0) << "stride=" << stride;
ICHECK(continuous % (vector_size * 4) == 0)
<< "continuous=" << continuous << ", vector_size=" << vector_size;
PrimExpr ts = FloorDiv(i, 8);
Expand Down Expand Up @@ -487,7 +491,8 @@ static Layout MakeFullBankSwizzleLayout2D(int stride, int continuous,
Var i = InputPlaceholder(0);
Var j = InputPlaceholder(1);
int vector_size = 128 / element_size;
ICHECK(stride % 8 == 0) << "stride=" << stride;
// See MakeQuarterBankSwizzleLayout2D for stride==4 rationale.
ICHECK(stride == 4 || stride % 8 == 0) << "stride=" << stride;
ICHECK(continuous % (vector_size * 8) == 0)
<< "continuous=" << continuous << ", vector_size=" << vector_size;
PrimExpr ts = FloorDiv(i, 8);
Expand Down Expand Up @@ -949,20 +954,23 @@ SwizzleMode DetectSwizzleMode(const Layout &layout, const Buffer &buffer) {
int vector_size = 128 / info.element_size;

// Check from smallest to largest granularity
// Need to verify stride and continuous constraints before comparing
if (info.stride % 8 == 0 &&
// Need to verify stride and continuous constraints before comparing.
// stride==4 is the truncated 4-row period used by tile::gather4/scatter4
// (see Make{Quarter,Half,Full}BankSwizzleLayout2D).
bool stride_ok = info.stride == 4 || info.stride % 8 == 0;
if (stride_ok &&
info.continuous % (static_cast<int64_t>(vector_size) * 2) == 0) {
if (StructuralEqual()(layout, makeQuarterBankSwizzleLayout(buffer))) {
return SwizzleMode::kQuarter;
}
}
if (info.stride % 8 == 0 &&
if (stride_ok &&
info.continuous % (static_cast<int64_t>(vector_size) * 4) == 0) {
if (StructuralEqual()(layout, makeHalfBankSwizzleLayout(buffer))) {
return SwizzleMode::kHalf;
}
}
if (info.stride % 8 == 0 &&
if (stride_ok &&
info.continuous % (static_cast<int64_t>(vector_size) * 8) == 0) {
if (StructuralEqual()(layout, makeFullBankSwizzleLayout(buffer))) {
return SwizzleMode::kFull;
Expand Down
10 changes: 10 additions & 0 deletions src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,16 @@ TIR_DEFINE_TL_BUILTIN(tma_load_im2col)
TIR_DEFINE_TL_BUILTIN(tma_store).set_num_inputs(-1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));

// tile::gather4 TMA load: variadic (desc, mbar, smem, col, r0..r3,
// eviction_policy). Opaque: reads global memory, writes shared memory,
// and signals an mbarrier.
TIR_DEFINE_TL_BUILTIN(tma_load_gather4)
    .set_num_inputs(-1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

// tile::scatter4 TMA store: variadic (desc, smem, col, r0..r3,
// eviction_policy). Opaque: reads shared memory, writes global memory.
TIR_DEFINE_TL_BUILTIN(tma_store_scatter4)
    .set_num_inputs(-1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(ptx_fence_barrier_init)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Expand Down
25 changes: 25 additions & 0 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,31 @@ TVM_DLL const Op &tma_load_im2col();
*/
TVM_DLL const Op &tma_store();

/*!
 * \brief tvm intrinsics for tile::gather4 TMA load.
 *
 * NOTE(review): the CopyInst notes in copy.h say this maps to sm_100a
 * cp.async.bulk.tensor.tile::gather4 — confirm the minimum architecture
 * before documenting a specific sm_ version here.
 *
 * Loads four rows from a 2D global tensor (described by a tiled CUtensorMap)
 * into a shared memory tile. The four rows can be at arbitrary indices.
 *
 *   tma_load_gather4(descriptor, mbarrier, smem_data, col,
 *                    row0, row1, row2, row3, eviction_policy)
 *
 * The descriptor must be encoded with rank=2 and box dim along axis 1 = 1
 * (the four-row pack is implicit in the gather4 PTX mode).
 */
TVM_DLL const Op &tma_load_gather4();

/*!
 * \brief tvm intrinsics for tile::scatter4 TMA store.
 *
 * NOTE(review): see the architecture note on tma_load_gather4 above —
 * copy.h documents this as an sm_100a instruction; confirm.
 *
 * Stores four shared-memory rows back to four arbitrary rows of a 2D global
 * tensor (described by a tiled CUtensorMap).
 *
 *   tma_store_scatter4(descriptor, smem_data, col,
 *                      row0, row1, row2, row3, eviction_policy)
 */
TVM_DLL const Op &tma_store_scatter4();

/*!
* \brief tvm intrinsics for barrier initialization fence
*
Expand Down
Loading
Loading