diff --git a/src/backend/cuda/codegen/codegen_cuda.cc b/src/backend/cuda/codegen/codegen_cuda.cc
index d50b29183..287f9e4f3 100644
--- a/src/backend/cuda/codegen/codegen_cuda.cc
+++ b/src/backend/cuda/codegen/codegen_cuda.cc
@@ -2246,6 +2246,38 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
       ss << "tl::tma_store";
     }
     print_extern_call_stmt(ss.str(), 0, 2);
+  } else if (op->op.same_as(tl::tma_load_gather4())) {
+    std::ostringstream ss;
+    ICHECK_EQ(op->args.size(), 9u)
+        << "tma_load_gather4 expects 9 args (desc, mbar, smem, col, "
+           "r0..r3, eviction_policy), got "
+        << op->args.size();
+    auto eviction_policy =
+        this->eviction_policy_names_
+            [op->args[op->args.size() - 1].as<IntImmNode>()->value];
+    if (eviction_policy != "EVICT_NORMAL") {
+      ss << "tl::tma_load_gather4<tl::CacheHintSm90::" << eviction_policy
+         << ">";
+    } else {
+      ss << "tl::tma_load_gather4";
+    }
+    print_extern_call_stmt(ss.str(), 0, 1);
+  } else if (op->op.same_as(tl::tma_store_scatter4())) {
+    std::ostringstream ss;
+    ICHECK_EQ(op->args.size(), 8u)
+        << "tma_store_scatter4 expects 8 args (desc, smem, col, r0..r3, "
+           "eviction_policy), got "
+        << op->args.size();
+    auto eviction_policy =
+        this->eviction_policy_names_
+            [op->args[op->args.size() - 1].as<IntImmNode>()->value];
+    if (eviction_policy != "EVICT_NORMAL") {
+      ss << "tl::tma_store_scatter4<tl::CacheHintSm90::" << eviction_policy
+         << ">";
+    } else {
+      ss << "tl::tma_store_scatter4";
+    }
+    print_extern_call_stmt(ss.str(), 0, 1);
   } else if (op->op.same_as(tl::ptx_ldmatrix())) {
     int trans = Downcast<IntImm>(op->args[0])->value;
     int num = Downcast<IntImm>(op->args[1])->value;
diff --git a/src/backend/cuda/op/copy.cc b/src/backend/cuda/op/copy.cc
index 14cbd6348..0cc1771e9 100644
--- a/src/backend/cuda/op/copy.cc
+++ b/src/backend/cuda/op/copy.cc
@@ -196,6 +196,9 @@ struct Copy {
   static Stmt LowerBulk(const CopyNode &op, const LowerArgs &T,
                         arith::Analyzer *analyzer, CopyInst copy_inst);
 
+  static Stmt LowerBulkGather4(const CopyNode &op, const LowerArgs &T,
+                               arith::Analyzer *analyzer, CopyInst copy_inst);
+
   static Stmt LowerBulk1D(const CopyNode &op, const LowerArgs &T,
                           arith::Analyzer *analyzer, CopyInst copy_inst);
 };
@@ -445,6 +448,11 @@ Stmt Copy::Lower(const CopyNode &op, const LowerArgs &T,
     auto bulk_copy = LowerBulk(op, T, analyzer, copy_inst);
     ICHECK(bulk_copy.defined()) << "Failed to lower bulk load/store";
     return bulk_copy;
+  } else if (copy_inst == CopyInst::kBulkLoadGather4 ||
+             copy_inst == CopyInst::kBulkStoreScatter4) {
+    auto bulk_copy = LowerBulkGather4(op, T, analyzer, copy_inst);
+    ICHECK(bulk_copy.defined()) << "Failed to lower tma gather4/scatter4";
+    return bulk_copy;
   } else if (copy_inst == CopyInst::kLDSM || copy_inst == CopyInst::kSTSM) {
     auto ldsm_copy = LowerLDSM(op, T, analyzer, copy_inst);
     ICHECK(ldsm_copy.defined()) << "Failed to lower ptx matrix copy";
@@ -1267,6 +1275,158 @@ Stmt Copy::LowerBulk(const CopyNode &op, const LowerArgs &T,
   return tma_copy;
 }
 
+namespace {
+
+Array<PrimExpr> GetGather4Rows(const CopyNode &op) {
+  if (auto val = op.annotations.Get("gather4_rows")) {
+    return Downcast<Array<PrimExpr>>(val.value());
+  }
+  return {};
+}
+
+PrimExpr GetGather4Col(const CopyNode &op) {
+  if (auto val = op.annotations.Get("gather4_col")) {
+    return Downcast<PrimExpr>(val.value());
+  }
+  return PrimExpr();
+}
+
+} // namespace
+
+Stmt Copy::LowerBulkGather4(const CopyNode &op, const LowerArgs &T,
+                            arith::Analyzer *analyzer, CopyInst copy_inst) {
+  ICHECK(copy_inst == CopyInst::kBulkLoadGather4 ||
+         copy_inst == CopyInst::kBulkStoreScatter4);
+  bool is_load = copy_inst == CopyInst::kBulkLoadGather4;
+
+  Buffer global_tensor = is_load ? op.src : op.dst;
+  Buffer shared_tensor = is_load ? op.dst : op.src;
+  Buffer shared_tensor_unmapped = shared_tensor;
+
+  ICHECK_EQ(global_tensor->shape.size(), 2u);
+  ICHECK_EQ(shared_tensor->shape.size(), 2u);
+  auto shared_lead = as_const_int(shared_tensor->shape[0]);
+  ICHECK(shared_lead != nullptr && *shared_lead == 4)
+      << "tma_gather4/scatter4 shared tile leading dim must be 4, got "
+      << shared_tensor->shape[0];
+  ICHECK_EQ(global_tensor->dtype, shared_tensor->dtype);
+
+  Array<PrimExpr> rows = GetGather4Rows(op);
+  PrimExpr col = GetGather4Col(op);
+  ICHECK_EQ(rows.size(), 4u);
+  ICHECK(col.defined());
+
+  TMADesc desc;
+  desc.rank = 2;
+  desc.data_type = to_CUtensorMapDataType(global_tensor->dtype);
+  desc.global_addr = global_tensor->data;
+  desc.global_shape = ReverseArray(global_tensor->shape);
+
+  if (!global_tensor->strides.empty()) {
+    desc.global_stride = ReverseArray(global_tensor->strides);
+  } else {
+    PrimExpr stride = 1;
+    desc.global_stride.reserve(2);
+    for (size_t i = 0; i < global_tensor->shape.size(); ++i) {
+      desc.global_stride.push_back(stride);
+      stride *= global_tensor->shape[global_tensor->shape.size() - 1 - i];
+    }
+  }
+  ICHECK(is_one(desc.global_stride[0]))
+      << "tma_gather4/scatter4 requires unit innermost global stride, got "
+      << desc.global_stride;
+  desc.global_stride = desc.global_stride.Map([&](PrimExpr e) {
+    return TMABytesFromElements(e, global_tensor->dtype);
+  });
+  for (size_t i = 1; i < desc.global_stride.size(); ++i) {
+    if (auto stride = desc.global_stride[i].as<IntImmNode>()) {
+      ICHECK(stride->value % 16 == 0 && stride->value < (1LL << 40))
+          << "tma_gather4/scatter4 global stride[" << i
+          << "] = " << stride->value
+          << " bytes must be 16-byte aligned and < 2^40";
+    }
+  }
+
+  // The descriptor's row box-dim must be 1, not 4. The four-row pack is
+  // implicit in the cp.async.bulk.tensor.tile::gather4 PTX, which takes 4
+  // row coordinates and materializes them into 4 logical rows of the shared
+  // tile. Setting box[1]=4 here would describe a contiguous 4-row strip; the
+  // gather4 unrolling would then read OOB → CUDA_ERROR_ILLEGAL_INSTRUCTION.
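+  //
+  // Worked illustration (hypothetical values, matching the round-trip test
+  // added in this change): for a 64x64 fp16 global tensor with K_box = 64,
+  // the descriptor encodes global_shape = {64, 64} and smem_box = {64, 1};
+  // one gather4 call with rows {5, 17, 42, 9} fills shared rows 0..3 from
+  // those global rows and transfers 4 * 64 * 2 = 512 bytes.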
+  PrimExpr K_box = shared_tensor->shape[1];
+  if (auto k = as_const_int(K_box)) {
+    int64_t k_bytes = TMABytesFromElements(*k, shared_tensor->dtype);
+    ICHECK(k_bytes % 16 == 0)
+        << "tma_gather4/scatter4 K_box * dtype.bytes() = " << k_bytes
+        << " must be 16-byte aligned";
+  }
+  desc.smem_box = {K_box, IntImm(DataType::Int(32), 1)};
+  desc.smem_stride = {IntImm(DataType::Int(32), 1),
+                      IntImm(DataType::Int(32), 1)};
+  desc.interleave = static_cast<int>(CU_TENSOR_MAP_INTERLEAVE_NONE);
+  desc.l2_promotion = static_cast<int>(CU_TENSOR_MAP_L2_PROMOTION_L2_128B);
+  desc.oob_fill = static_cast<int>(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
+
+  Layout shared_layout;
+  if (T.layout_map.count(shared_tensor)) {
+    shared_layout = T.layout_map.at(shared_tensor);
+    ICHECK(T.buffer_remap.count(shared_tensor));
+    shared_tensor = T.buffer_remap.at(shared_tensor);
+  }
+  desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
+  if (shared_layout.defined() && shared_layout->InputDim() >= 2) {
+    SwizzleMode mode = DetectSwizzleMode(shared_layout, shared_tensor_unmapped);
+    if (mode == SwizzleMode::kQuarter) {
+      desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_32B);
+    } else if (mode == SwizzleMode::kHalf) {
+      desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B);
+    } else if (mode == SwizzleMode::kFull) {
+      desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B);
+    }
+  }
+  if (auto k = as_const_int(K_box)) {
+    int64_t k_bytes = TMABytesFromElements(*k, shared_tensor->dtype);
+    int max_bytes = 0;
+    if (desc.swizzle == static_cast<int>(CU_TENSOR_MAP_SWIZZLE_32B))
+      max_bytes = 32;
+    else if (desc.swizzle == static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B))
+      max_bytes = 64;
+    else if (desc.swizzle == static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B))
+      max_bytes = 128;
+    if (max_bytes > 0) {
+      ICHECK(k_bytes <= max_bytes)
+          << "tma_gather4/scatter4 K_box * dtype.bytes() = " << k_bytes
+          << " exceeds " << max_bytes << "B swizzle limit";
+    }
+  }
+
+  Call create_descriptor =
+      Call(DataType::Handle(), create_tma_descriptor(), desc.EncodeCallArgs());
+
+  PrimExpr total_elements = 4 * K_box;
+  PrimExpr smem_addr =
+      shared_tensor.access_ptr(is_load ? 2 : 1, DataType::Handle(), 1,
+                               IntImm(DataType::Int(32), 0), total_elements);
+
+  Array<PrimExpr> args;
+  args.push_back(create_descriptor);
+  if (is_load) {
+    auto user_barrier = op.annotations.Get("barrier");
+    ICHECK(user_barrier.has_value())
+        << "tma_gather4 requires a 'barrier' annotation";
+    args.push_back(Downcast<PrimExpr>(user_barrier.value()));
+  }
+  args.push_back(smem_addr);
+  args.push_back(col);
+  for (auto r : rows)
+    args.push_back(r);
+  args.push_back(IntImm(DataType::Int(32), GetEvictionPolicy(op)));
+
+  // Fire-and-forget: caller manages mbarrier_expect_tx / wait (loads) and
+  // tma_store_arrive / wait (stores), and the leader-thread guard.
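+  // Illustrative emitted call shapes (matching the arg-count ICHECKs in
+  // codegen_cuda.cc above):
+  //   tl.tma_load_gather4(desc, mbar, smem, col, r0, r1, r2, r3, policy)
+  //   tl.tma_store_scatter4(desc, smem, col, r0, r1, r2, r3, policy)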
+  auto tl_op = is_load ? tma_load_gather4() : tma_store_scatter4();
+  return Evaluate(Call(DataType::Handle(), tl_op, args));
+}
+
 Stmt Copy::LowerBulk1D(const CopyNode &op, const LowerArgs &T,
                        arith::Analyzer *analyzer, CopyInst copy_inst) {
   const Buffer &src = op.src;
diff --git a/src/backend/cuda/op/copy.h b/src/backend/cuda/op/copy.h
index 7b0d3a6c3..a95c096b0 100644
--- a/src/backend/cuda/op/copy.h
+++ b/src/backend/cuda/op/copy.h
@@ -29,6 +29,8 @@ enum class CopyInst : uint8_t {
   kBulkStore1D = 7,
   kTMemLoad = 8,
   kTMemStore = 9,
+  kBulkLoadGather4 = 10,   // tma cp.async.bulk.tensor.tile::gather4 (sm_100a)
+  kBulkStoreScatter4 = 11, // tma cp.async.bulk.tensor.tile::scatter4 (sm_100a)
   kInvalid = 255,
 };
diff --git a/src/backend/cuda/op/copy_analysis.cc b/src/backend/cuda/op/copy_analysis.cc
index d43aa5539..cb532a1ac 100644
--- a/src/backend/cuda/op/copy_analysis.cc
+++ b/src/backend/cuda/op/copy_analysis.cc
@@ -319,6 +319,10 @@ const char *CopyInstToString(CopyInst inst) {
     return "TMemLoad";
   case CopyInst::kTMemStore:
     return "TMemStore";
+  case CopyInst::kBulkLoadGather4:
+    return "BulkLoadGather4";
+  case CopyInst::kBulkStoreScatter4:
+    return "BulkStoreScatter4";
   case CopyInst::kInvalid:
     return "Invalid";
   default:
@@ -520,6 +524,16 @@ CopyFacts AnalyzeCopyFacts(const CopyNode &op, const CopyAnalysisContext &ctx) {
 
 CopyInstSelection SelectCopyInstForLowering(const CopyNode &op,
                                             const CopyAnalysisContext &ctx) {
+  // tile::gather4 / scatter4 markers take precedence over generic TMA paths.
+  // The IR carries explicit row indices via annotations and must always be
+  // lowered through LowerBulkGather4 (no fallback path makes sense).
+  if (GetBoolAnnotation(op, "is_gather4")) {
+    return Supported(CopyInst::kBulkLoadGather4);
+  }
+  if (GetBoolAnnotation(op, "is_scatter4")) {
+    return Supported(CopyInst::kBulkStoreScatter4);
+  }
+
   CopyFacts facts = AnalyzeCopyFacts(op, ctx);
   if (facts.explicit_tma) {
     CopyInst inst =
diff --git a/src/layout/gemm_layouts.cc b/src/layout/gemm_layouts.cc
index f46c47715..2b75244d2 100644
--- a/src/layout/gemm_layouts.cc
+++ b/src/layout/gemm_layouts.cc
@@ -431,7 +431,10 @@ static Layout MakeQuarterBankSwizzleLayout2D(int stride, int continuous,
   Var i = InputPlaceholder(0);
   Var j = InputPlaceholder(1);
   int vector_size = 128 / element_size;
-  ICHECK(stride % 8 == 0) << "stride=" << stride;
+  // stride==4 is a truncated 4-row period used by tile::gather4/scatter4
+  // (s=i%8 ∈ [0,4) is a valid subset of the 8-row XOR pattern, matching
+  // what TMA applies per-row in hardware).
+  ICHECK(stride == 4 || stride % 8 == 0) << "stride=" << stride;
   ICHECK(continuous % (vector_size * 2) == 0)
       << "continuous=" << continuous << ", vector_size=" << vector_size;
   PrimExpr ts = FloorDiv(i, 8);
@@ -459,7 +462,8 @@ static Layout MakeHalfBankSwizzleLayout2D(int stride, int continuous,
   Var i = InputPlaceholder(0);
   Var j = InputPlaceholder(1);
   int vector_size = 128 / element_size;
-  ICHECK(stride % 8 == 0) << "stride=" << stride;
+  // See MakeQuarterBankSwizzleLayout2D for stride==4 rationale.
+  ICHECK(stride == 4 || stride % 8 == 0) << "stride=" << stride;
   ICHECK(continuous % (vector_size * 4) == 0)
       << "continuous=" << continuous << ", vector_size=" << vector_size;
   PrimExpr ts = FloorDiv(i, 8);
@@ -487,7 +491,8 @@ static Layout MakeFullBankSwizzleLayout2D(int stride, int continuous,
   Var i = InputPlaceholder(0);
   Var j = InputPlaceholder(1);
   int vector_size = 128 / element_size;
-  ICHECK(stride % 8 == 0) << "stride=" << stride;
+  // See MakeQuarterBankSwizzleLayout2D for stride==4 rationale.
+  ICHECK(stride == 4 || stride % 8 == 0) << "stride=" << stride;
   ICHECK(continuous % (vector_size * 8) == 0)
       << "continuous=" << continuous << ", vector_size=" << vector_size;
   PrimExpr ts = FloorDiv(i, 8);
@@ -949,20 +954,23 @@ SwizzleMode DetectSwizzleMode(const Layout &layout, const Buffer &buffer) {
   int vector_size = 128 / info.element_size;
 
   // Check from smallest to largest granularity
-  // Need to verify stride and continuous constraints before comparing
-  if (info.stride % 8 == 0 &&
+  // Need to verify stride and continuous constraints before comparing.
+  // stride==4 is the truncated 4-row period used by tile::gather4/scatter4
+  // (see Make{Quarter,Half,Full}BankSwizzleLayout2D).
+  bool stride_ok = info.stride == 4 || info.stride % 8 == 0;
+  if (stride_ok &&
       info.continuous % (static_cast<int64_t>(vector_size) * 2) == 0) {
     if (StructuralEqual()(layout, makeQuarterBankSwizzleLayout(buffer))) {
       return SwizzleMode::kQuarter;
     }
   }
-  if (info.stride % 8 == 0 &&
+  if (stride_ok &&
       info.continuous % (static_cast<int64_t>(vector_size) * 4) == 0) {
     if (StructuralEqual()(layout, makeHalfBankSwizzleLayout(buffer))) {
       return SwizzleMode::kHalf;
     }
   }
-  if (info.stride % 8 == 0 &&
+  if (stride_ok &&
       info.continuous % (static_cast<int64_t>(vector_size) * 8) == 0) {
     if (StructuralEqual()(layout, makeFullBankSwizzleLayout(buffer))) {
       return SwizzleMode::kFull;
diff --git a/src/op/builtin.cc b/src/op/builtin.cc
index 3ff0df96f..55e53c7d1 100644
--- a/src/op/builtin.cc
+++ b/src/op/builtin.cc
@@ -177,6 +177,16 @@ TIR_DEFINE_TL_BUILTIN(tma_load_im2col)
 TIR_DEFINE_TL_BUILTIN(tma_store).set_num_inputs(-1).set_attr<TCallEffectKind>(
     "TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
+TIR_DEFINE_TL_BUILTIN(tma_load_gather4)
+    .set_num_inputs(-1)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_TL_BUILTIN(tma_store_scatter4)
+    .set_num_inputs(-1)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_TL_BUILTIN(ptx_fence_barrier_init)
     .set_num_inputs(-1)
     .set_attr<TCallEffectKind>("TCallEffectKind",
diff --git a/src/op/builtin.h b/src/op/builtin.h
index 69a65ab10..be70bbd25 100644
--- a/src/op/builtin.h
+++ b/src/op/builtin.h
@@ -315,6 +315,31 @@ TVM_DLL const Op &tma_load_im2col();
  */
 TVM_DLL const Op &tma_store();
 
+/*!
+ * \brief tvm intrinsics for tile::gather4 TMA load (sm_100a).
+ *
+ * Loads four rows from a 2D global tensor (described by a tiled CUtensorMap)
+ * into a shared memory tile. The four rows can be at arbitrary indices.
+ *
+ *   tma_load_gather4(descriptor, mbarrier, smem_data, col,
+ *                    row0, row1, row2, row3, eviction_policy)
+ *
+ * The descriptor must be encoded with rank=2 and box dim along axis 1 = 1
+ * (the four-row pack is implicit in the gather4 PTX mode).
+ */
+TVM_DLL const Op &tma_load_gather4();
+
+/*!
+ * \brief tvm intrinsics for tile::scatter4 TMA store (sm_100a).
+ *
+ * Stores four shared-memory rows back to four arbitrary rows of a 2D global
+ * tensor (described by a tiled CUtensorMap).
+ *
+ *   tma_store_scatter4(descriptor, smem_data, col,
+ *                      row0, row1, row2, row3, eviction_policy)
+ */
+TVM_DLL const Op &tma_store_scatter4();
+
 /*!
  * \brief tvm intrinsics for barrier initialization fence
  *
diff --git a/src/tl_templates/cuda/copy_sm100.h b/src/tl_templates/cuda/copy_sm100.h
index a871bb1ca..921e79f11 100644
--- a/src/tl_templates/cuda/copy_sm100.h
+++ b/src/tl_templates/cuda/copy_sm100.h
@@ -429,4 +429,50 @@ TL_DEVICE void tma_load_2sm(const CUtensorMap &descriptor,
                             : "memory");
 }
 
+// cp.async.bulk.tensor.2d.tile::{gather4,scatter4} (PTX 8.6, sm_100a).
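+// `col` is the dimension-0 (innermost) element offset and r0..r3 index
+// dimension 1, one gathered/scattered row each (the descriptor's axes are
+// the reverse of the Buffer's row-major order, see LowerBulkGather4).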
+// Five coordinate operands: {col, r0, r1, r2, r3}; the 4-row pack is implicit.
+// CacheHintSm90 reused from copy_sm90.h (always included alongside this file).
+#if (__CUDACC_VER_MAJOR__ > 12) ||                                            \
+    (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)
+template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL,
+          typename BarrierType>
+TL_DEVICE void tma_load_gather4(const CUtensorMap &descriptor,
+                                BarrierType &smem_mbar,
+                                void const *const smem_ptr, int32_t const &col,
+                                int32_t const &r0, int32_t const &r1,
+                                int32_t const &r2, int32_t const &r3) {
+  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
+  uint32_t smem_int_mbar;
+  if constexpr (std::is_pointer_v<BarrierType>) {
+    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<void *>(smem_mbar));
+  } else {
+    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<void *>(&smem_mbar));
+  }
+  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
+  asm volatile("cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4."
+               "mbarrier::complete_tx::bytes.L2::cache_hint"
+               " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;"
+               :
+               : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
+                 "r"(col), "r"(r0), "r"(r1), "r"(r2), "r"(r3),
+                 "l"(static_cast<uint64_t>(cache_hint))
+               : "memory");
+}
+
+template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
+TL_DEVICE void
+tma_store_scatter4(const CUtensorMap &descriptor, void const *const smem_ptr,
+                   int32_t const &col, int32_t const &r0, int32_t const &r1,
+                   int32_t const &r2, int32_t const &r3) {
+  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
+  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
+  asm volatile(
+      "cp.async.bulk.tensor.2d.global.shared::cta.tile::scatter4.bulk_group"
+      ".L2::cache_hint [%0, {%2, %3, %4, %5, %6}], [%1], %7;"
+      :
+      : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(col), "r"(r0), "r"(r1),
+        "r"(r2), "r"(r3), "l"(static_cast<uint64_t>(cache_hint))
+      : "memory");
+}
+#endif // CUDA 12.8+
+
 } // namespace tl
diff --git a/testing/python/language/test_tilelang_language_tma_gather_scatter.py b/testing/python/language/test_tilelang_language_tma_gather_scatter.py
new file mode 100644
index 000000000..4b2807c14
--- /dev/null
+++ b/testing/python/language/test_tilelang_language_tma_gather_scatter.py
@@ -0,0 +1,192 @@
+"""Round-trip test for TMA tile::gather4 / tile::scatter4 (sm_100a, Blackwell)."""
+
+import pytest
+
+from tilelang import tvm as tvm
+import tilelang.testing
+import tilelang.language as T
+import tilelang
+
+
+def _has_sm100():
+    try:
+        import torch
+    except ImportError:
+        return False
+    if not torch.cuda.is_available():
+        return False
+    major, _ = torch.cuda.get_device_capability(0)
+    return major >= 10
+
+
+requires_sm100 = pytest.mark.skipif(
+    not _has_sm100(), reason="tile::gather4/scatter4 require sm_100a (Blackwell)")
+
+
+def gather_scatter_program(N: int, K: int, K_box: int, in_dtype: str = "float16"):
+
+    @T.prim_func
+    def main(
+            Src: T.Tensor((N, K), in_dtype),
+            Idx: T.Tensor((4,), "int32"),
+            Dst: T.Tensor((N, K), in_dtype),
+    ):
+        with T.Kernel(1, 1, threads=128) as (bx, by):
+            T.reads(Src[0:N, 0:K], Idx[0:4])
+            T.writes(Dst[0:N, 0:K])
+
+            smem = T.alloc_shared((4, K_box), in_dtype)
+            mbar = T.alloc_barrier(1)
+
+            r0 = Idx[0]
+            r1 = Idx[1]
+            r2 = Idx[2]
+            r3 = Idx[3]
+
+            if T.shuffle_elect(128):
+                T.mbarrier_expect_tx(mbar, T.tma_gather4_bytes(K_box, in_dtype))
+                T.tma_gather4(Src, smem, 0, [r0, r1, r2, r3], barrier=mbar)
+                T.barrier_arrive(mbar)
+            T.mbarrier_wait_parity(mbar, 0)
+
+            if T.shuffle_elect(128):
+                T.tma_scatter4(smem, Dst, 0, [r0, r1, r2, r3])
+                T.tma_store_arrive()
+                T.tma_store_wait(0)
+
+    return main
+
+
+def run_gather_scatter(N=64, K=64, K_box=64):
+    program = gather_scatter_program(N=N, K=K, K_box=K_box)
+    kernel = tilelang.compile(
+        program,
+        target="cuda",
+        pass_configs={
+            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+        },
+    )
+    src = kernel.get_kernel_source()
+    assert "tma_load_gather4" in src, "tma_load_gather4 missing from emitted CUDA"
+    assert "tma_store_scatter4" in src, "tma_store_scatter4 missing from emitted CUDA"
+    assert "CUtensorMap" in src, "CUtensorMap descriptor missing from kernel signature"
+
+    import torch
+
+    Src = torch.randn(N, K, dtype=torch.float16, device="cuda")
+    Idx = torch.tensor([5, 17, 42, 9], dtype=torch.int32, device="cuda")
+    Dst = torch.zeros_like(Src)
+
+    kernel(Src, Idx, Dst)
+    torch.cuda.synchronize()
+
+    expected = torch.zeros_like(Src)
+    rows = Idx.tolist()
+    for r in rows:
+        expected[r] = Src[r]
+
+    torch.testing.assert_close(Dst, expected)
+
+
+@requires_sm100
+def test_gather_scatter_basic():
+    run_gather_scatter(N=64, K=64, K_box=64)
+
+
+# Swizzled round-trip: LowerBulkGather4 infers desc.swizzle from the annotated
+# shared layout via DetectSwizzleMode. K_box * 2 bytes must match the swizzle
+# period for fp16: 64 → 128B, 32 → 64B, 16 → 32B.
+
+
+def gather_scatter_swizzled_program(N: int, K: int, K_box: int, swizzle_kind: str,
+                                    in_dtype: str = "float16"):
+    from tilelang.layout import (
+        make_full_bank_swizzled_layout,
+        make_half_bank_swizzled_layout,
+        make_quarter_bank_swizzled_layout,
+    )
+
+    swizzle_factories = {
+        "128B": make_full_bank_swizzled_layout,
+        "64B": make_half_bank_swizzled_layout,
+        "32B": make_quarter_bank_swizzled_layout,
+    }
+    make_layout = swizzle_factories[swizzle_kind]
+
+    @T.prim_func
+    def main(
+            Src: T.Tensor((N, K), in_dtype),
+            Idx: T.Tensor((4,), "int32"),
+            Dst: T.Tensor((N, K), in_dtype),
+    ):
+        with T.Kernel(1, 1, threads=128) as (bx, by):
+            T.reads(Src[0:N, 0:K], Idx[0:4])
+            T.writes(Dst[0:N, 0:K])
+
+            smem = T.alloc_shared((4, K_box), in_dtype)
+            T.annotate_layout({smem: make_layout(smem)})
+
+            mbar = T.alloc_barrier(1)
+
+            r0 = Idx[0]
+            r1 = Idx[1]
+            r2 = Idx[2]
+            r3 = Idx[3]
+
+            if T.shuffle_elect(128):
+                T.mbarrier_expect_tx(mbar, T.tma_gather4_bytes(K_box, in_dtype))
+                T.tma_gather4(Src, smem, 0, [r0, r1, r2, r3], barrier=mbar)
+                T.barrier_arrive(mbar)
+            T.mbarrier_wait_parity(mbar, 0)
+
+            if T.shuffle_elect(128):
+                T.tma_scatter4(smem, Dst, 0, [r0, r1, r2, r3])
+                T.tma_store_arrive()
+                T.tma_store_wait(0)
+
+    return main
+
+
+def run_gather_scatter_swizzled(N, K, K_box, swizzle_kind):
+    program = gather_scatter_swizzled_program(
+        N=N, K=K, K_box=K_box, swizzle_kind=swizzle_kind)
+    kernel = tilelang.compile(
+        program,
+        target="cuda",
+        pass_configs={
+            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+        },
+    )
+    src = kernel.get_kernel_source()
+    assert "tma_load_gather4" in src
+    assert "tma_store_scatter4" in src
+
+    import torch
+
+    Src = torch.randn(N, K, dtype=torch.float16, device="cuda")
+    Idx = torch.tensor([5, 17, 42, 9], dtype=torch.int32, device="cuda")
+    Dst = torch.zeros_like(Src)
+
+    kernel(Src, Idx, Dst)
+    torch.cuda.synchronize()
+
+    expected = torch.zeros_like(Src)
+    rows = Idx.tolist()
+    for r in rows:
+        expected[r] = Src[r]
+
+    torch.testing.assert_close(Dst, expected)
+
+
+@requires_sm100
+@pytest.mark.parametrize(
+    "K_box, swizzle_kind",
+    [
+        (64, "128B"),  # row = 128 bytes fp16 -> full-bank swizzle
+        (32, "64B"),  # row = 64 bytes fp16 -> half-bank swizzle
+        (16, "32B"),  # row = 32 bytes fp16 -> quarter-bank swizzle
+    ],
+)
+def test_gather_scatter_swizzled(K_box, swizzle_kind):
+    run_gather_scatter_swizzled(N=64, K=K_box, K_box=K_box, swizzle_kind=swizzle_kind)
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()
diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py
index 43c70563a..60490ae91 100644
--- a/tilelang/language/__init__.py
+++ b/tilelang/language/__init__.py
@@ -55,7 +55,7 @@
     empty,  # noqa: F401
 )
 from tvm.script.parser.tir import allocate as allocate  # noqa: F401
-from .copy_op import copy, async_copy, tma_copy, transpose, c2d_im2col  # noqa: F401
+from .copy_op import copy, async_copy, tma_copy, tma_gather4, tma_gather4_bytes, tma_scatter4, transpose, c2d_im2col  # noqa: F401
 from tilelang.tileop.base import GemmWarpPolicy  # noqa: F401
 from .gemm_op import (  # noqa: F401
     gemm,
diff --git a/tilelang/language/copy_op.py b/tilelang/language/copy_op.py
index 7e9cdddb0..d0fd7499a 100644
--- a/tilelang/language/copy_op.py
+++ b/tilelang/language/copy_op.py
@@ -8,6 +8,7 @@
     legalize_pairwise_extents,
 )
 from tilelang.language.utils import get_extent
+import tvm
 from tvm import ir, tir
 
 
@@ -229,6 +230,186 @@ def tma_copy(
     return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.tma_copy"), src, dst,
                            annotations=ann if ann else None)
 
 
+_TMA_SUPPORTED_DTYPES = frozenset({
+    "uint8",
+    "uint16",
+    "uint32",
+    "int32",
+    "uint64",
+    "int64",
+    "float16",
+    "float32",
+    "float64",
+    "bfloat16",
+})
+
+
+def tma_gather4(
+    src: tir.Buffer,
+    dst: tir.Buffer,
+    col: tir.PrimExpr,
+    rows,
+    *,
+    barrier,
+    swizzle=None,
+    eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None,
+):
+    """Issue a TMA tile::gather4 load (sm_100a, Blackwell).
+
+    Loads four arbitrary rows of a 2D global tensor ``src`` into a 2D shared
+    tile ``dst`` of shape ``(4, K_box)``. The CUtensorMap descriptor (dtype +
+    swizzle) is built by the compiler from buffer + layout info.
+
+    Caller must wrap this with ``T.shuffle_elect`` and pair it with
+    ``T.mbarrier_expect_tx`` (use :func:`tma_gather4_bytes`) before, and
+    ``barrier_arrive`` / ``mbarrier_wait_parity`` after.
+
+    The ``swizzle`` kwarg is deprecated; mark the shared tile via
+    ``T.annotate_layout`` for non-default swizzle.
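+
+    Example (sketch, mirroring the round-trip test added in this change)::
+
+        if T.shuffle_elect(128):
+            T.mbarrier_expect_tx(mbar, T.tma_gather4_bytes(K_box, "float16"))
+            T.tma_gather4(Src, smem, 0, [r0, r1, r2, r3], barrier=mbar)
+            T.barrier_arrive(mbar)
+        T.mbarrier_wait_parity(mbar, 0)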
+ """ + if not isinstance(src, tir.Buffer): + raise TypeError("tma_gather4 src must be a tir.Buffer (global)") + if not isinstance(dst, tir.Buffer): + raise TypeError("tma_gather4 dst must be a tir.Buffer (shared)") + if src.scope() != "global": + raise ValueError(f"tma_gather4 src must be a global buffer, got scope={src.scope()}") + if dst.scope() not in ("shared", "shared.dyn"): + raise ValueError(f"tma_gather4 dst must be a shared buffer, got scope={dst.scope()}") + if len(src.shape) != 2: + raise ValueError(f"tma_gather4 expects rank-2 global buffer, got {len(src.shape)}") + if len(dst.shape) != 2: + raise ValueError(f"tma_gather4 expects rank-2 shared buffer (4 x K_box), got {len(dst.shape)}") + if src.dtype != dst.dtype: + raise ValueError(f"tma_gather4 dtype mismatch: src={src.dtype}, dst={dst.dtype}") + if not (isinstance(dst.shape[0], int) and dst.shape[0] == 4) and not (hasattr(dst.shape[0], "value") and int(dst.shape[0].value) == 4): + raise ValueError(f"tma_gather4 shared tile leading dim must be 4, got {dst.shape[0]}") + if src.strides: + inner = src.strides[1] + if not ((isinstance(inner, int) and inner == 1) or (hasattr(inner, "value") and int(inner.value) == 1)): + raise ValueError(f"tma_gather4 requires unit innermost global stride, got {inner}") + rows = list(rows) + if len(rows) != 4: + raise ValueError(f"tma_gather4 expects exactly 4 row indices, got {len(rows)}") + if swizzle not in (None, "none", 0): + import warnings + + warnings.warn( + f"tma_gather4 swizzle={swizzle!r} is deprecated; use T.annotate_layout.", + DeprecationWarning, + stacklevel=2, + ) + + from .builtin import _mbar_to_buffer_load + + bar_load = _mbar_to_buffer_load(barrier) + + eviction_policy_map = {"evict_normal": 0, "evict_first": 1, "evict_last": 2} + ep = 0 if eviction_policy is None else eviction_policy_map[eviction_policy] + + # Matching (4, K_box) extents satisfy CopyNode's shape check; the actual + # access pattern lives in the gather4_rows / gather4_col annotations. + K_box = dst.shape[1] + src_region = to_buffer_region(src, access_type="r", extents=[4, K_box]) + dst_region = to_buffer_region(dst, access_type="w", extents=[4, K_box]) + + ann = { + "is_gather4": True, + "gather4_rows": rows, + "gather4_col": col, + "barrier": bar_load, + "eviction_policy": ep, + } + return tir.call_intrin( + "handle", + tir.op.Op.get("tl.tileop.copy"), + src_region, + dst_region, + annotations=ann, + ) + + +def tma_gather4_bytes(K_box, dtype: str) -> int: + """Transaction byte count for a 4-row gather4 of width ``K_box``. Pass + to ``T.mbarrier_expect_tx`` immediately before ``T.tma_gather4``. + """ + if dtype not in _TMA_SUPPORTED_DTYPES: + raise ValueError(f"Unsupported dtype: {dtype}") + elem_bytes = tvm.DataType(dtype).bits // 8 + return 4 * K_box * elem_bytes + + +def tma_scatter4( + src: tir.Buffer, + dst: tir.Buffer, + col: tir.PrimExpr, + rows, + *, + swizzle=None, + eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None, +): + """Issue a TMA tile::scatter4 store (sm_100a, Blackwell). + + Stores a 2D shared tile of shape ``(4, K_box)`` to four arbitrary rows of + a 2D global tensor ``dst``. Caller is responsible for ``tma_store_arrive`` + / ``tma_store_wait`` and the ``T.shuffle_elect`` guard. See + :func:`tma_gather4` for descriptor / swizzle inference details. 
+ """ + if not isinstance(src, tir.Buffer): + raise TypeError("tma_scatter4 src must be a tir.Buffer (shared)") + if not isinstance(dst, tir.Buffer): + raise TypeError("tma_scatter4 dst must be a tir.Buffer (global)") + if src.scope() not in ("shared", "shared.dyn"): + raise ValueError(f"tma_scatter4 src must be a shared buffer, got scope={src.scope()}") + if dst.scope() != "global": + raise ValueError(f"tma_scatter4 dst must be a global buffer, got scope={dst.scope()}") + if len(src.shape) != 2: + raise ValueError(f"tma_scatter4 expects rank-2 shared buffer (4 x K_box), got {len(src.shape)}") + if len(dst.shape) != 2: + raise ValueError(f"tma_scatter4 expects rank-2 global buffer, got {len(dst.shape)}") + if src.dtype != dst.dtype: + raise ValueError(f"tma_scatter4 dtype mismatch: src={src.dtype}, dst={dst.dtype}") + if not (isinstance(src.shape[0], int) and src.shape[0] == 4) and not (hasattr(src.shape[0], "value") and int(src.shape[0].value) == 4): + raise ValueError(f"tma_scatter4 shared tile leading dim must be 4, got {src.shape[0]}") + if dst.strides: + inner = dst.strides[1] + if not ((isinstance(inner, int) and inner == 1) or (hasattr(inner, "value") and int(inner.value) == 1)): + raise ValueError(f"tma_scatter4 requires unit innermost global stride, got {inner}") + rows = list(rows) + if len(rows) != 4: + raise ValueError(f"tma_scatter4 expects exactly 4 row indices, got {len(rows)}") + if swizzle not in (None, "none", 0): + import warnings + + warnings.warn( + f"tma_scatter4 swizzle={swizzle!r} is deprecated; use T.annotate_layout.", + DeprecationWarning, + stacklevel=2, + ) + + eviction_policy_map = {"evict_normal": 0, "evict_first": 1, "evict_last": 2} + ep = 0 if eviction_policy is None else eviction_policy_map[eviction_policy] + + K_box = src.shape[1] + src_region = to_buffer_region(src, access_type="r", extents=[4, K_box]) + dst_region = to_buffer_region(dst, access_type="w", extents=[4, K_box]) + + ann = { + "is_scatter4": True, + "gather4_rows": rows, + "gather4_col": col, + "eviction_policy": ep, + } + return tir.call_intrin( + "handle", + tir.op.Op.get("tl.tileop.copy"), + src_region, + dst_region, + annotations=ann, + ) + + def transpose( src: BufferLikeType, dst: BufferLikeType,