diff --git a/src/backend/cuda/codegen/codegen_cuda.cc b/src/backend/cuda/codegen/codegen_cuda.cc index 832a8fba9..fbecaaa8e 100644 --- a/src/backend/cuda/codegen/codegen_cuda.cc +++ b/src/backend/cuda/codegen/codegen_cuda.cc @@ -1033,6 +1033,10 @@ void CodeGenTileLangCUDA::PrintVecBinaryOp(const std::string &op, DataType t, tl_func = "min2"; else if (op == "max") tl_func = "max2"; + else if (op == "min_nan") + tl_func = "min2_nan"; + else if (op == "max_nan") + tl_func = "max2_nan"; if (!tl_func.empty()) { // Decompose into lanes/2 independent x2 packed operations. @@ -3768,6 +3772,7 @@ bool CodeGenTileLangCUDA::HandleLateIntrinsicCall(const CallNode *op, } else if (op->op.same_as(tl::add2()) || op->op.same_as(tl::sub2()) || op->op.same_as(tl::mul2()) || op->op.same_as(tl::fma2()) || op->op.same_as(tl::max2()) || op->op.same_as(tl::min2()) || + op->op.same_as(tl::max2_nan()) || op->op.same_as(tl::min2_nan()) || op->op.same_as(tl::abs2())) { // Packed x2 element-wise math intrinsics. // @@ -3790,6 +3795,10 @@ bool CodeGenTileLangCUDA::HandleLateIntrinsicCall(const CallNode *op, op_name = "max2"; else if (op->op.same_as(tl::min2())) op_name = "min2"; + else if (op->op.same_as(tl::max2_nan())) + op_name = "max2_nan"; + else if (op->op.same_as(tl::min2_nan())) + op_name = "min2_nan"; else op_name = "abs2"; @@ -4427,10 +4436,29 @@ void CodeGenTileLangCUDA::VisitExpr_(const ShuffleNode *op, os << "uint1{__pack_nv_bfloat162(" << e0 << ", " << e1 << ")}"; } else { enable_fp16_ = true; - // __pack_half2 returns __half2 which is 32-bit. - // Reinterpret via aggregate initialisation. - os << "uint1{*(unsigned*)&(__pack_half2((__half)(" << e0 << "), (__half)(" - << e1 << ")))}"; + os << "uint1{tl::pack_half2(" << e0 << ", " << e1 << ")}"; + } + return; + } + // Handle ExtractElement: extract a scalar lane from a bfloat16x2 / float16x2 + // vector (produced by packed reduction, etc.). The vector is stored as an + // opaque uint1 in the lowered code, but semantically it is a packed pair. + DataType vec_t = + op->vectors.size() == 1 ? op->vectors[0].dtype() : DataType(); + bool vec_is_bf16x2 = vec_t.is_bfloat16() && vec_t.lanes() == 2; + bool vec_is_fp16x2 = vec_t.is_float16() && vec_t.lanes() == 2; + if ((vec_is_bf16x2 || vec_is_fp16x2) && op->vectors.size() == 1 && + op->indices.size() == 1) { + int lane = Downcast(op->indices[0])->value; + std::string vec = PrintExpr(op->vectors[0]); + if (vec_is_bf16x2) { + enable_bf16_ = true; + os << "bfloat16_t(((nv_bfloat162*)(&(" << vec << ")))->" + << (lane == 0 ? "x" : "y") << ")"; + } else { + enable_fp16_ = true; + os << "half_t(((half2*)(&(" << vec << ")))->" << (lane == 0 ? 
"x" : "y") + << ")"; } return; } diff --git a/src/backend/cuda/op/reduce.cc b/src/backend/cuda/op/reduce.cc index f9da12c12..e7f3cd791 100644 --- a/src/backend/cuda/op/reduce.cc +++ b/src/backend/cuda/op/reduce.cc @@ -17,10 +17,12 @@ #include #include #include +#include #include #include #include +#include #include #include #include @@ -87,105 +89,148 @@ uint64_t UnsignedMax(int bits) { return (static_cast(1) << bits) - 1; } -PrimExpr MakeInitValue(const ReduceOpNode &op) { +int GetPreferedVectorizedSize(DataType dt) { + if (dt.is_bfloat16() || dt.is_float16()) + return 2; + return 1; +} + +PrimExpr MakeInitValue(const ReduceOpNode &op, int vsize = 1) { auto dst_dtype = op.dst->dtype; auto is_int = dst_dtype.is_int(); bool is_uint = dst_dtype.is_uint(); auto bits = dst_dtype.bits(); + PrimExpr scalar; if (op.type->isSum() || op.type->isAbsSum()) { - return make_zero(op.dst->dtype); + scalar = make_zero(op.dst->dtype); } else if (op.type->isMax()) { if (is_int) { - return make_const(op.dst->dtype, SignedMin(bits)); + scalar = make_const(op.dst->dtype, SignedMin(bits)); } else if (is_uint) { - return make_const(op.dst->dtype, 0); + scalar = make_const(op.dst->dtype, 0); } else { - return make_const(op.dst->dtype, -INFINITY); + scalar = make_const(op.dst->dtype, -INFINITY); } } else if (op.type->isMin()) { if (is_int) { - return make_const(op.dst->dtype, SignedMax(bits)); + scalar = make_const(op.dst->dtype, SignedMax(bits)); } else if (is_uint) { - return make_const(op.dst->dtype, UnsignedMax(bits)); + scalar = make_const(op.dst->dtype, UnsignedMax(bits)); } else { - return make_const(op.dst->dtype, INFINITY); + scalar = make_const(op.dst->dtype, INFINITY); } } else if (op.type->isAbsMax()) { - return make_const(op.dst->dtype, 0); + scalar = make_const(op.dst->dtype, 0); } else if (op.type->isBitAnd()) { if (is_int) { - return make_const(op.dst->dtype, -1); + scalar = make_const(op.dst->dtype, -1); } else if (is_uint) { - return make_const(op.dst->dtype, UnsignedMax(bits)); + scalar = make_const(op.dst->dtype, UnsignedMax(bits)); } else { - return make_const(op.dst->dtype, -INFINITY); + scalar = make_const(op.dst->dtype, -INFINITY); } } else if (op.type->isBitOr() || op.type->isBitXor()) { - return make_zero(op.dst->dtype); + scalar = make_zero(op.dst->dtype); + } else { + LOG(FATAL) << "Unsupported reduce type: " << op.type->type; + scalar = PrimExpr(); } - LOG(FATAL) << "Unsupported reduce type: " << op.type->type; - return PrimExpr(); + + if (vsize <= 1) + return scalar; + return Broadcast(scalar, vsize); } -PrimExpr MakeReduce(const ReduceOpNode &op, const PrimExpr &acc, - const PrimExpr &b) { - PrimExpr rhs = b; - if (acc->dtype != rhs->dtype) { - rhs = Cast(acc->dtype, rhs); +std::optional MakeReduce(const ReduceOpNode &op, int vsize, + const PrimExpr &acc, const PrimExpr &b) { + if (vsize == 1) { + PrimExpr rhs = b; + if (acc->dtype != rhs->dtype) { + rhs = Cast(acc->dtype, rhs); + } + const bool use_nan_op = op.nan_propagate && (acc.dtype().is_float16() || + acc.dtype().is_bfloat16()); + if (op.type->isSum()) { + return acc + rhs; + } else if (op.type->isAbsSum()) { + return acc + Max(rhs, -rhs); + } else if (op.type->isMax()) { + return use_nan_op ? Call(acc.dtype(), tl::max_nan(), {acc, rhs}) + : PrimExpr(Max(acc, rhs)); + } else if (op.type->isMin()) { + return use_nan_op ? Call(acc.dtype(), tl::min_nan(), {acc, rhs}) + : PrimExpr(Min(acc, rhs)); + } else if (op.type->isAbsMax()) { + auto abs_rhs = Max(rhs, -rhs); + return use_nan_op ? 
Call(acc.dtype(), tl::max_nan(), {acc, abs_rhs}) + : PrimExpr(Max(acc, abs_rhs)); + } else if (op.type->isBitAnd()) { + return acc & rhs; + } else if (op.type->isBitOr()) { + return acc | rhs; + } else if (op.type->isBitXor()) { + return acc ^ rhs; + } + LOG(FATAL) << "Unsupported reduce type: " << op.type->type; + return std::nullopt; } - const bool use_nan_op = op.nan_propagate && (acc.dtype().is_float16() || - acc.dtype().is_bfloat16()); + + if (vsize != 2) + return std::nullopt; + if (op.type->isSum()) { - return acc + rhs; + return Call(acc.dtype(), tl::add2(), {acc, b}); } else if (op.type->isAbsSum()) { - return acc + Max(rhs, -rhs); + return Call(acc.dtype(), tl::add2(), + {acc, Call(acc.dtype(), tl::abs2(), {b})}); } else if (op.type->isMax()) { - if (use_nan_op) { - return Call(acc.dtype(), tl::max_nan(), {acc, rhs}); - } - return Max(acc, rhs); + return Call(acc.dtype(), op.nan_propagate ? tl::max2_nan() : tl::max2(), + {acc, b}); } else if (op.type->isMin()) { - if (use_nan_op) { - return Call(acc.dtype(), tl::min_nan(), {acc, rhs}); - } - return Min(acc, rhs); + return Call(acc.dtype(), op.nan_propagate ? tl::min2_nan() : tl::min2(), + {acc, b}); } else if (op.type->isAbsMax()) { - if (use_nan_op) { - return Call(acc.dtype(), tl::max_nan(), {acc, tvm::abs(rhs)}); - } - return Max(acc, tvm::abs(rhs)); - } else if (op.type->isBitAnd()) { - return acc & rhs; - } else if (op.type->isBitOr()) { - return acc | rhs; - } else if (op.type->isBitXor()) { - return acc ^ rhs; + return Call(acc.dtype(), op.nan_propagate ? tl::max2_nan() : tl::max2(), + {acc, Call(acc.dtype(), tl::abs2(), {b})}); } - LOG(FATAL) << "Unsupported reduce type: " << op.type->type; - return PrimExpr(); + return std::nullopt; } -std::string MakeCodegenReducer(const ReduceOpNode &op) { +std::optional MakeCodegenReducer(const ReduceOpNode &op, + int vsize = 1) { const bool use_nan_op = op.nan_propagate && (op.dst->dtype.is_float16() || op.dst->dtype.is_bfloat16()); - if (op.type->isSum() || op.type->isAbsSum()) { - return "tl::SumOp"; - } else if (op.type->isMax()) { - return use_nan_op ? "tl::MaxOpNan" : "tl::MaxOp"; - } else if (op.type->isMin()) { - return use_nan_op ? "tl::MinOpNan" : "tl::MinOp"; - } else if (op.type->isAbsMax()) { - return use_nan_op ? "tl::MaxOpNan" : "tl::MaxOp"; - } else if (op.type->isBitAnd()) { - return "tl::BitAndOp"; - } else if (op.type->isBitOr()) { - return "tl::BitOrOp"; - } else if (op.type->isBitXor()) { - return "tl::BitXorOp"; + + auto base = [&]() -> std::string { + if (op.type->isSum() || op.type->isAbsSum()) + return "tl::SumOp"; + if (op.type->isMax()) + return use_nan_op ? "tl::MaxOpNan" : "tl::MaxOp"; + if (op.type->isMin()) + return use_nan_op ? "tl::MinOpNan" : "tl::MinOp"; + if (op.type->isAbsMax()) + return use_nan_op ? 
"tl::MaxOpNan" : "tl::MaxOp"; + if (op.type->isBitAnd()) + return "tl::BitAndOp"; + if (op.type->isBitOr()) + return "tl::BitOrOp"; + if (op.type->isBitXor()) + return "tl::BitXorOp"; + LOG(FATAL) << "Unsupported reduce type: " << op.type->type; + return ""; + }(); + + if (vsize <= 1) + return base; + + if (vsize == 2) { + if (op.dst->dtype.is_bfloat16()) + return base + "_bf16x2"; + if (op.dst->dtype.is_float16()) + return base + "_fp16x2"; } - LOG(FATAL) << "Unsupported reduce type: " << op.type->type; - return ""; + return std::nullopt; } PrimExpr MakeUpdate(const ReduceOpNode &op, PrimExpr dst_val, @@ -315,13 +360,8 @@ struct Reduce { dst_buffer->name + "_clear", GetPtrStorageScope(dst_buffer->data)); } - if (require_init || - (need_duplicate && - (op.type->isMax() || op.type->isMin() || op.type->isAbsMax()))) { - stmts.push_back( - BufferStore(clear_buffer, MakeInitValue(op), red_indices)); - } + // make thread-local reduce Array src_indice_compressed; Array src_var_compressed; for (size_t i = 0; i < src_layout->OutputDim(); ++i) { @@ -331,19 +371,111 @@ struct Reduce { src_var_compressed.push_back(var); } - Stmt reduce_local = - BufferStore(clear_buffer, - MakeReduce(op, BufferLoad(clear_buffer, red_indices), - BufferLoad(src_buffer, src_indice_compressed)), - red_indices); - - for (int i = static_cast(src_layout->OutputDim()) - 1; i >= 0; --i) { - reduce_local = For(src_var_compressed[i]->var, 0, - src_var_compressed[i]->dom->extent, - ForKind::kUnrolled, reduce_local, std::nullopt, - {{tir::attr::pragma_unroll_explicit, Bool(false)}}); + bool can_pack = false; + bool need_pack_buffer = false; + bool need_batch_pack_buffer = false; + Buffer clear_buffer_packed; + Buffer clear_batch_pack_buffer; + { + int vsize = GetPreferedVectorizedSize(clear_buffer->dtype); + if (vsize > 1 && !src_var_compressed.empty()) { + auto *ext = src_var_compressed.back()->dom->extent.as(); + if (ext && ext->value >= vsize && ext->value % vsize == 0) { + can_pack = true; + DataType vec_dtype = clear_buffer->dtype.with_lanes(vsize); + clear_buffer_packed = + decl_buffer(red_layout->OutputShape(), vec_dtype, + clear_buffer->name + "_pack", + GetPtrStorageScope(clear_buffer->data)); + need_pack_buffer = true; + + Array local_body; + + if (require_init || + (need_duplicate && (op.type->isMax() || op.type->isMin() || + op.type->isAbsMax()))) { + local_body.push_back(BufferStore(clear_buffer_packed, + MakeInitValue(op, vsize), + red_indices)); + } + + const auto *ext_int = + as_const_int(src_var_compressed.back()->dom->extent); + int64_t inner_extent = *ext_int; + PrimExpr halved_extent = Integer(inner_extent / vsize); + + auto &inner_var = src_var_compressed.back(); + + PrimExpr ramp_base = + Substitute(src_indice_compressed.back(), + {{inner_var->var, inner_var->var * Integer(2)}}); + src_indice_compressed.Set( + src_indice_compressed.size() - 1, + Ramp(ramp_base, IntImm(DataType::Int(32), 1), vsize)); + + auto src_load = BufferLoad(src_buffer, src_indice_compressed); + auto *src_writer = src_load.CopyOnWrite(); + src_writer->dtype = vec_dtype; + + Stmt reduce_local = BufferStore( + clear_buffer_packed, + MakeReduce(op, vsize, + BufferLoad(clear_buffer_packed, red_indices), + src_load) + .value(), + red_indices); + + reduce_local = + For(inner_var->var, 0, halved_extent, ForKind::kUnrolled, + reduce_local, std::nullopt, + {{tir::attr::pragma_unroll_explicit, Bool(false)}}); + + for (int i = static_cast(src_layout->OutputDim()) - 2; i >= 0; + --i) { + reduce_local = + For(src_var_compressed[i]->var, 0, + 
src_var_compressed[i]->dom->extent, ForKind::kUnrolled, + reduce_local, std::nullopt, + {{tir::attr::pragma_unroll_explicit, Bool(false)}}); + } + local_body.push_back(reduce_local); + + auto acc_vec = BufferLoad(clear_buffer_packed, red_indices); + auto lane0 = Shuffle::ExtractElement(acc_vec, 0); + auto lane1 = Shuffle::ExtractElement(acc_vec, 1); + auto scalar_result = MakeReduce(op, 1, lane0, lane1).value(); + local_body.push_back( + BufferStore(clear_buffer, scalar_result, red_indices)); + + stmts.push_back(SeqStmt(local_body)); + } + } + } + + if (!can_pack) { + if (require_init || + (need_duplicate && (op.type->isMax() || op.type->isMin() || + op.type->isAbsMax()))) { + stmts.push_back( + BufferStore(clear_buffer, MakeInitValue(op), red_indices)); + } + + Stmt reduce_local = BufferStore( + clear_buffer, + MakeReduce(op, 1, BufferLoad(clear_buffer, red_indices), + BufferLoad(src_buffer, src_indice_compressed)) + .value(), + red_indices); + + for (int i = static_cast(src_layout->OutputDim()) - 1; i >= 0; + --i) { + reduce_local = For(src_var_compressed[i]->var, 0, + src_var_compressed[i]->dom->extent, + ForKind::kUnrolled, reduce_local, std::nullopt, + {{tir::attr::pragma_unroll_explicit, Bool(false)}}); + } + stmts.push_back(reduce_local); } - stmts.push_back(reduce_local); auto src_thread = src_layout->ForwardThread( src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }), {}); @@ -395,76 +527,179 @@ struct Reduce { }; if (use_batch) { + // ================================================================ + // Batched AllReduce path — three phases: + // 1. Loop: init + thread-local reduce + // 2. Flat: batched AllReduce (single butterfly pass for all values) + // 3. Loop: copy-back (only when need_duplicate) + // ================================================================ + + // Phase 1: pre-reduce loop Stmt pre_body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0]; pre_body = make_dst_loop(pre_body, dst_vars); Array phases; phases.push_back(pre_body); + // Phase 2: batched AllReduce call(s). for (const auto &iter_split : iter_sum->args) { auto mark = iter_split->source->source.as(); - if (!mark) { + if (!mark) continue; - } - if (!mark.value().same_as(src_vars[op.dim]->var)) { + if (!mark.value().same_as(src_vars[op.dim]->var)) continue; - } auto scale = as_const_int(iter_split->scale); auto extent = as_const_int(iter_split->extent); ICHECK(scale != nullptr && extent != nullptr); - if (*extent == 1) { + if (*extent == 1) continue; - } int reducing_threads = (*extent) * (*scale); auto thread_offset = T.thread_bounds->min; + std::stringstream ss; + + int vsize = GetPreferedVectorizedSize(clear_buffer->dtype); + bool can_batch_pack = + vsize > 1 && batch >= vsize && batch % vsize == 0; + int eff_batch = can_batch_pack ? (batch / vsize) : batch; + + std::string reducer = + MakeCodegenReducer(op, can_batch_pack ? 
vsize : 1).value(); + + if (TargetHasSMVersionGE(T.target, 90)) { + auto all_threads = T.thread_bounds->extent; + ss << "tl::AllReduce<" << reducer << ", " << reducing_threads + << ", " << (*scale) << ", " << thread_offset + << ", tl::NamedBarrier<" << all_threads << ">, " << eff_batch + << ", " << reducing_threads << ">::run_batch"; + } else { + ss << "tl::AllReduce<" << reducer << ", " << reducing_threads + << ", " << (*scale) << ", " << thread_offset + << ", tl::SyncThreadsBarrier, " << eff_batch << ", " + << reducing_threads << ">::run_batch"; + } - std::string allreduce = MakeBatchAllReduce( - MakeCodegenReducer(op), reducing_threads, *scale, thread_offset, - T.thread_bounds->extent, batch, reducing_threads, T.target); - + DataType ws_dtype = can_batch_pack + ? clear_buffer->dtype.with_lanes(vsize) + : clear_buffer->dtype; PrimExpr workspace; bool need_workspace = reducing_threads > 32; if (need_workspace) { - int ws_size = reducing_threads * batch; - workspace = T.AddWorkspace(ws_size, clear_buffer->dtype); + int ws_size = reducing_threads * eff_batch; + workspace = T.AddWorkspace(ws_size, ws_dtype); } int64_t N_total = 1; - for (const auto &s : clear_buffer->shape) { + for (const auto &s : clear_buffer->shape) N_total *= *as_const_int(s); - } int num_chunks = static_cast(N_total / batch); int buf_ndim = static_cast(clear_buffer->shape.size()); std::vector buf_shape_vals; - for (const auto &s : clear_buffer->shape) { + for (const auto &s : clear_buffer->shape) buf_shape_vals.push_back(*as_const_int(s)); - } std::vector buf_strides(buf_ndim, 1); - for (int d = buf_ndim - 2; d >= 0; d--) { + for (int d = buf_ndim - 2; d >= 0; d--) buf_strides[d] = buf_strides[d + 1] * buf_shape_vals[d + 1]; - } - for (int chunk = 0; chunk < num_chunks; chunk++) { - int64_t flat_offset = static_cast(chunk) * batch; - Array chunk_indices; - for (int d = 0; d < buf_ndim; d++) { - int64_t idx = (flat_offset / buf_strides[d]) % buf_shape_vals[d]; - chunk_indices.push_back(Integer(idx)); + std::string template_str = ss.str(); + + if (can_batch_pack) { + int K = vsize; + int packed_batch = batch / K; + + Buffer pack_buf = + decl_buffer({Integer(packed_batch)}, + clear_buffer->dtype.with_lanes(K), + clear_buffer->name + "_pack", + GetPtrStorageScope(clear_buffer->data)); + + need_batch_pack_buffer = true; + clear_batch_pack_buffer = pack_buf; + + for (int chunk = 0; chunk < num_chunks; chunk++) { + int64_t flat_offset = (int64_t)chunk * batch; + + // --- Pack loop --- + Var pack_j("pack_j"); + PrimExpr base = Integer(flat_offset); + PrimExpr scaled = pack_j * K; + + Array idx_a, idx_b; + PrimExpr fa = base + scaled; + PrimExpr fb = base + scaled + Integer(1); + for (int d = 0; d < buf_ndim; d++) { + idx_a.push_back(FloorMod(FloorDiv(fa, Integer(buf_strides[d])), + Integer(buf_shape_vals[d]))); + idx_b.push_back(FloorMod(FloorDiv(fb, Integer(buf_strides[d])), + Integer(buf_shape_vals[d]))); + } + auto a_load = BufferLoad(clear_buffer, idx_a); + auto b_load = BufferLoad(clear_buffer, idx_b); + Stmt pack_body = BufferStore( + pack_buf, Shuffle({a_load, b_load}, {0, 1}), {pack_j}); + Stmt pack_loop = + For(pack_j, 0, packed_batch, ForKind::kUnrolled, pack_body); + phases.push_back(pack_loop); + + // --- AllReduce on packed buffer --- + PrimExpr packed_ptr = + Call(DataType::Handle(), builtin::address_of(), + {BufferLoad(pack_buf, {Integer(0)})}); + Array args = {StringImm(template_str), packed_ptr}; + if (need_workspace) + args.push_back(workspace); + phases.push_back(Evaluate( + Call(DataType::Handle(), 
builtin::call_extern(), args))); + + // --- Unpack loop --- + Var unpack_j("unpack_j"); + PrimExpr ubase = Integer(flat_offset); + PrimExpr uscaled = unpack_j * K; + Array uidx_a, uidx_b; + PrimExpr ufa = ubase + uscaled; + PrimExpr ufb = ubase + uscaled + Integer(1); + for (int d = 0; d < buf_ndim; d++) { + uidx_a.push_back( + FloorMod(FloorDiv(ufa, Integer(buf_strides[d])), + Integer(buf_shape_vals[d]))); + uidx_b.push_back( + FloorMod(FloorDiv(ufb, Integer(buf_strides[d])), + Integer(buf_shape_vals[d]))); + } + auto packed_val = BufferLoad(pack_buf, {unpack_j}); + Stmt unpack_body = SeqStmt({ + BufferStore(clear_buffer, + Shuffle::ExtractElement(packed_val, 0), uidx_a), + BufferStore(clear_buffer, + Shuffle::ExtractElement(packed_val, 1), uidx_b), + }); + Stmt unpack_loop = For(unpack_j, 0, packed_batch, + ForKind::kUnrolled, unpack_body); + phases.push_back(unpack_loop); } - PrimExpr ptr = Call(DataType::Handle(), builtin::address_of(), - {BufferLoad(clear_buffer, chunk_indices)}); - - Array args = {StringImm(allreduce), ptr}; - if (need_workspace) { - args.push_back(workspace); + } else { + for (int chunk = 0; chunk < num_chunks; chunk++) { + int64_t flat_offset = (int64_t)chunk * batch; + Array chunk_indices; + for (int d = 0; d < buf_ndim; d++) { + int64_t idx = + (flat_offset / buf_strides[d]) % buf_shape_vals[d]; + chunk_indices.push_back(Integer(idx)); + } + PrimExpr ptr = Call(DataType::Handle(), builtin::address_of(), + {BufferLoad(clear_buffer, chunk_indices)}); + + Array args = {StringImm(template_str), ptr}; + if (need_workspace) + args.push_back(workspace); + phases.push_back(Evaluate( + Call(DataType::Handle(), builtin::call_extern(), args))); } - phases.push_back(Evaluate( - Call(DataType::Handle(), builtin::call_extern(), args))); } } + // Phase 3: copy-back (only when a temp buffer was used) if (need_duplicate) { auto [post_vars, post_dst_idx, post_red_idx] = make_fresh_dst_vars("_p"); @@ -475,16 +710,16 @@ struct Reduce { dst_th.push_back(T.thread_var); auto inv = dst_layout->Inverse()->Forward(dst_th); inv.pop_back(); - for (int i = 0; i < static_cast(dst_layout->InputDim()); i++) { + for (int i = 0; i < static_cast(dst_layout->InputDim()); i++) predicate = predicate && (inv[i] == post_vars[i]->var); - } predicate = analyzer->Simplify(predicate); } PrimExpr update = - need_update ? MakeUpdate(op, BufferLoad(dst_buffer, post_dst_idx), - BufferLoad(clear_buffer, post_red_idx)) - : BufferLoad(clear_buffer, post_red_idx); + need_update + ? 
MakeUpdate(op, BufferLoad(dst_buffer, post_dst_idx), + BufferLoad(clear_buffer, post_red_idx)) + : BufferLoad(clear_buffer, post_red_idx); auto store = BufferStore(dst_buffer, update, post_dst_idx); Stmt post_body; if (analyzer->CanProve(predicate)) { @@ -500,86 +735,105 @@ struct Reduce { body = Allocate(clear_buffer->data, clear_buffer->dtype, clear_buffer->shape, const_true(), body); } + if (need_pack_buffer) { + body = + Allocate(clear_buffer_packed->data, clear_buffer_packed->dtype, + clear_buffer_packed->shape, const_true(), body); + } + if (need_batch_pack_buffer) { + body = Allocate(clear_batch_pack_buffer->data, + clear_batch_pack_buffer->dtype, + clear_batch_pack_buffer->shape, const_true(), body); + } return body; - } - for (const auto &iter_split : iter_sum->args) { - auto mark = iter_split->source->source.as(); - if (!mark) { - continue; - } - if (mark.value().same_as(src_vars[op.dim]->var)) { - auto scale = as_const_int(iter_split->scale); - auto extent = as_const_int(iter_split->extent); - ICHECK(scale != nullptr && extent != nullptr); - if (*extent == 1) { + } else { + // ================================================================ + // Original scalar AllReduce path. + // ================================================================ + for (const auto &iter_split : iter_sum->args) { + auto mark = iter_split->source->source.as(); + if (!mark) continue; + if (mark.value().same_as(src_vars[op.dim]->var)) { + auto scale = as_const_int(iter_split->scale); + auto extent = as_const_int(iter_split->extent); + ICHECK(scale != nullptr && extent != nullptr); + if (*extent == 1) + continue; + + int reducing_threads = (*extent) * (*scale); + auto thread_offset = T.thread_bounds->min; + std::string allreduce = MakeScalarAllReduce( + MakeCodegenReducer(op).value(), reducing_threads, *scale, + thread_offset, T.thread_bounds->extent, T.target); + Array thread_reduce_args = { + StringImm(allreduce), BufferLoad(clear_buffer, red_indices)}; + if (reducing_threads > 32) { + int workspace_size = + static_cast(*as_const_int(T.thread_bounds->extent)); + PrimExpr workspace = + T.AddWorkspace(workspace_size, clear_buffer->dtype); + thread_reduce_args.push_back(workspace); + } + auto call = Call(clear_buffer->dtype, builtin::call_extern(), + thread_reduce_args); + stmts.push_back(BufferStore(clear_buffer, call, red_indices)); } + } - int reducing_threads = (*extent) * (*scale); - auto thread_offset = T.thread_bounds->min; - std::string allreduce = MakeScalarAllReduce( - MakeCodegenReducer(op), reducing_threads, *scale, thread_offset, - T.thread_bounds->extent, T.target); - Array thread_reduce_args = { - StringImm(allreduce), BufferLoad(clear_buffer, red_indices)}; - if (reducing_threads > 32) { - int workspace_size = - static_cast(*as_const_int(T.thread_bounds->extent)); - PrimExpr workspace = - T.AddWorkspace(workspace_size, clear_buffer->dtype); - thread_reduce_args.push_back(workspace); + PrimExpr predicate = Bool(true); + { + auto dst_th_indices = dst_indices; + dst_th_indices.push_back(T.thread_var); + auto inv = dst_layout->Inverse()->Forward(dst_th_indices); + inv.pop_back(); + for (int i = 0; i < static_cast(dst_layout->InputDim()); i++) { + predicate = predicate && (inv[i] == dst_vars[i]->var); + } + predicate = analyzer->Simplify(predicate); + } + if (need_duplicate) { + PrimExpr update = + need_update + ? 
MakeUpdate(op, BufferLoad(dst_buffer, dst_indices), + BufferLoad(clear_buffer, red_indices)) + : BufferLoad(clear_buffer, red_indices); + auto store = BufferStore(dst_buffer, update, dst_indices); + if (analyzer->CanProve(predicate)) { + stmts.push_back(store); + } else { + stmts.push_back(IfThenElse(predicate, store)); } - auto call = Call(clear_buffer->dtype, builtin::call_extern(), - thread_reduce_args); - stmts.push_back(BufferStore(clear_buffer, call, red_indices)); } - } - PrimExpr predicate = Bool(true); - { - auto dst_th_indices = dst_indices; - dst_th_indices.push_back(T.thread_var); - auto inv = dst_layout->Inverse()->Forward(dst_th_indices); - inv.pop_back(); - for (int i = 0; i < static_cast(dst_layout->InputDim()); i++) { - predicate = predicate && (inv[i] == dst_vars[i]->var); + auto body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0]; + for (int i = static_cast(dst_layout->InputDim()) - 1; i >= 0; + --i) { + body = For(dst_vars[i]->var, 0, dst_vars[i]->dom->extent, + ForKind::kParallel, body); } - predicate = analyzer->Simplify(predicate); - } - if (need_duplicate) { - PrimExpr update = - need_update ? MakeUpdate(op, BufferLoad(dst_buffer, dst_indices), - BufferLoad(clear_buffer, red_indices)) - : BufferLoad(clear_buffer, red_indices); - auto store = BufferStore(dst_buffer, update, dst_indices); - if (analyzer->CanProve(predicate)) { - stmts.push_back(store); + + if (dst_layout->InputDim() > 0) { + body = PartitionLoop(Downcast(body), T.thread_var, analyzer, + red_layout); + body = PragmaUnrollLoop(Downcast(body)); } else { - stmts.push_back(IfThenElse(predicate, store)); + auto guard = (T.thread_var == T.thread_bounds->min); + body = IfThenElse(guard, body); } - } - auto body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0]; - for (int i = static_cast(dst_layout->InputDim()) - 1; i >= 0; --i) { - body = For(dst_vars[i]->var, 0, dst_vars[i]->dom->extent, - ForKind::kParallel, body); - } - - if (dst_layout->InputDim() > 0) { - body = PartitionLoop(Downcast(body), T.thread_var, analyzer, - red_layout); - body = PragmaUnrollLoop(Downcast(body)); - } else { - auto guard = (T.thread_var == T.thread_bounds->min); - body = IfThenElse(guard, body); - } - - if (need_duplicate) { - body = Allocate(clear_buffer->data, clear_buffer->dtype, - clear_buffer->shape, const_true(), body); + if (need_duplicate) { + body = Allocate(clear_buffer->data, clear_buffer->dtype, + clear_buffer->shape, const_true(), body); + } + if (need_pack_buffer) { + body = + Allocate(clear_buffer_packed->data, clear_buffer_packed->dtype, + clear_buffer_packed->shape, const_true(), body); + } + return body; } - return body; } LOG(FATAL) << "Reduce for buffers in scope (" << src_scope << ", " @@ -587,23 +841,6 @@ struct Reduce { return Stmt(); } - static std::string MakeBatchAllReduce(std::string reducer, - int reducing_threads, int scale, - PrimExpr thread_offset, - PrimExpr all_threads, int batch, - int workspace_stride, Target target) { - std::stringstream ss; - ss << "tl::AllReduce<" << reducer << ", " << reducing_threads << ", " - << scale << ", " << thread_offset; - if (TargetHasSMVersionGE(target, 90)) { - ss << ", tl::NamedBarrier<" << all_threads << ">"; - } else { - ss << ", tl::SyncThreadsBarrier"; - } - ss << ", " << batch << ", " << workspace_stride << ">::run_batch"; - return ss.str(); - } - static std::string MakeScalarAllReduce(std::string reducer, int reducing_threads, int scale, PrimExpr thread_offset, diff --git a/src/backend/rocm/op/reduce.cc b/src/backend/rocm/op/reduce.cc index 
e2a8d98f5..a2e25c3b5 100644 --- a/src/backend/rocm/op/reduce.cc +++ b/src/backend/rocm/op/reduce.cc @@ -335,7 +335,8 @@ struct Reduce { BufferLoad(src_buffer, src_indice_compressed)), red_indices); - for (int i = static_cast(src_layout->OutputDim()) - 1; i >= 0; --i) { + for (int i = static_cast(src_layout->OutputDim()) - 1; i >= 0; + --i) { reduce_local = For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent, ForKind::kUnrolled, reduce_local, std::nullopt, diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 3ff0df96f..f38ec85e6 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -145,6 +145,12 @@ TIR_DEFINE_TL_BUILTIN(min2).set_num_inputs(2).set_attr( TIR_DEFINE_TL_BUILTIN(abs2).set_num_inputs(1).set_attr( "TCallEffectKind", Integer(CallEffectKind::kPure)); +TIR_DEFINE_TL_BUILTIN(max2_nan).set_num_inputs(2).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kPure)); + +TIR_DEFINE_TL_BUILTIN(min2_nan).set_num_inputs(2).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kPure)); + TIR_DEFINE_TL_BUILTIN(rng_init).set_num_inputs(4).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/builtin.h b/src/op/builtin.h index 69a65ab10..f418a157f 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -260,6 +260,8 @@ TVM_DLL const Op &fma2(); TVM_DLL const Op &max2(); TVM_DLL const Op &min2(); TVM_DLL const Op &abs2(); +TVM_DLL const Op &max2_nan(); +TVM_DLL const Op &min2_nan(); // random op TVM_DLL const Op &rng_init(); diff --git a/src/tl_templates/cuda/common.h b/src/tl_templates/cuda/common.h index eba257309..77e3d21d5 100644 --- a/src/tl_templates/cuda/common.h +++ b/src/tl_templates/cuda/common.h @@ -691,6 +691,13 @@ template TL_DEVICE uint1 to_uint1(T v) { return r; } +// Pack two half_t into a uint1. 
+TL_DEVICE uint1 pack_half2(half_t a, half_t b) { + unsigned packed = + __pack_half2(static_cast<__half>(a), static_cast<__half>(b)); + return uint1{packed}; +} + // --- add2 ---------------------------------------------------------------- TL_DEVICE float2 add2(float2 a, float2 b) { @@ -846,6 +853,42 @@ TL_DEVICE __half2 min2(__half2 a, __half2 b) { #endif } +// --- max2_nan ------------------------------------------------------------ + +TL_DEVICE __nv_bfloat162 max2_nan(__nv_bfloat162 a, __nv_bfloat162 b) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __hmax2_nan(a, b); +#else + return __nv_bfloat162{__hmax_nan(a.x, b.x), __hmax_nan(a.y, b.y)}; +#endif +} + +TL_DEVICE __half2 max2_nan(__half2 a, __half2 b) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + return __hmax2_nan(a, b); +#else + return __half2{__hmax_nan(a.x, b.x), __hmax_nan(a.y, b.y)}; +#endif +} + +// --- min2_nan ------------------------------------------------------------ + +TL_DEVICE __nv_bfloat162 min2_nan(__nv_bfloat162 a, __nv_bfloat162 b) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __hmin2_nan(a, b); +#else + return __nv_bfloat162{__hmin_nan(a.x, b.x), __hmin_nan(a.y, b.y)}; +#endif +} + +TL_DEVICE __half2 min2_nan(__half2 a, __half2 b) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + return __hmin2_nan(a, b); +#else + return __half2{__hmin_nan(a.x, b.x), __hmin_nan(a.y, b.y)}; +#endif +} + // --- abs2 ---------------------------------------------------------------- TL_DEVICE float2 abs2(float2 a) { return make_float2(fabsf(a.x), fabsf(a.y)); } @@ -974,4 +1017,26 @@ TL_DEVICE bfloat16_t shfl_sync(unsigned mask, bfloat16_t val, int srcLane) { return reinterpret_cast(ret16); } +// Specializations for uint1 (packed bfloat16x2 / float16x2). +// uint1 is a 32-bit struct { unsigned x; } used to represent packed pairs. +// __shfl_xor_sync operates on native 32-bit types, so we pass the raw unsigned. 
+ +template <> +TL_DEVICE uint1 shfl_xor_sync(unsigned mask, uint1 val, int laneMask) { + return uint1{__shfl_xor_sync(mask, val.x, laneMask)}; +} + +template <> +TL_DEVICE uint1 shfl_down_sync(unsigned mask, uint1 val, int delta) { + return uint1{__shfl_down_sync(mask, val.x, delta)}; +} + +template <> TL_DEVICE uint1 shfl_up_sync(unsigned mask, uint1 val, int delta) { + return uint1{__shfl_up_sync(mask, val.x, delta)}; +} + +template <> TL_DEVICE uint1 shfl_sync(unsigned mask, uint1 val, int srcLane) { + return uint1{__shfl_sync(mask, val.x, srcLane)}; +} + } // namespace tl diff --git a/src/tl_templates/cuda/reduce.h b/src/tl_templates/cuda/reduce.h index 1099b8959..fa4d15c0a 100644 --- a/src/tl_templates/cuda/reduce.h +++ b/src/tl_templates/cuda/reduce.h @@ -88,6 +88,76 @@ struct MinOpNan { } }; +struct SumOp_bf16x2 { + template TL_DEVICE T operator()(T const &x, T const &y) { + return tl::to_uint1(tl::add2(tl::from_uint1<__nv_bfloat162>(x), + tl::from_uint1<__nv_bfloat162>(y))); + } +}; + +struct MaxOp_bf16x2 { + template TL_DEVICE T operator()(T const &x, T const &y) { + return tl::to_uint1(tl::max2(tl::from_uint1<__nv_bfloat162>(x), + tl::from_uint1<__nv_bfloat162>(y))); + } +}; + +struct MinOp_bf16x2 { + template TL_DEVICE T operator()(T const &x, T const &y) { + return tl::to_uint1(tl::min2(tl::from_uint1<__nv_bfloat162>(x), + tl::from_uint1<__nv_bfloat162>(y))); + } +}; + +struct SumOp_fp16x2 { + template TL_DEVICE T operator()(T const &x, T const &y) { + return tl::to_uint1( + tl::add2(tl::from_uint1<__half2>(x), tl::from_uint1<__half2>(y))); + } +}; + +struct MaxOp_fp16x2 { + template TL_DEVICE T operator()(T const &x, T const &y) { + return tl::to_uint1( + tl::max2(tl::from_uint1<__half2>(x), tl::from_uint1<__half2>(y))); + } +}; + +struct MinOp_fp16x2 { + template TL_DEVICE T operator()(T const &x, T const &y) { + return tl::to_uint1( + tl::min2(tl::from_uint1<__half2>(x), tl::from_uint1<__half2>(y))); + } +}; + +struct MaxOpNan_bf16x2 { + template TL_DEVICE T operator()(T const &x, T const &y) { + return tl::to_uint1(tl::max2_nan(tl::from_uint1<__nv_bfloat162>(x), + tl::from_uint1<__nv_bfloat162>(y))); + } +}; + +struct MinOpNan_bf16x2 { + template TL_DEVICE T operator()(T const &x, T const &y) { + return tl::to_uint1(tl::min2_nan(tl::from_uint1<__nv_bfloat162>(x), + tl::from_uint1<__nv_bfloat162>(y))); + } +}; + +struct MaxOpNan_fp16x2 { + template TL_DEVICE T operator()(T const &x, T const &y) { + return tl::to_uint1( + tl::max2_nan(tl::from_uint1<__half2>(x), tl::from_uint1<__half2>(y))); + } +}; + +struct MinOpNan_fp16x2 { + template TL_DEVICE T operator()(T const &x, T const &y) { + return tl::to_uint1( + tl::min2_nan(tl::from_uint1<__half2>(x), tl::from_uint1<__half2>(y))); + } +}; + struct BitAndOp { template TL_DEVICE T operator()(T const &x, T const &y) { return x & y; diff --git a/testing/python/language/test_tilelang_language_reduce.py b/testing/python/language/test_tilelang_language_reduce.py index 3b49fa602..8c73b6c63 100644 --- a/testing/python/language/test_tilelang_language_reduce.py +++ b/testing/python/language/test_tilelang_language_reduce.py @@ -7,7 +7,6 @@ tilelang.testing.set_random_seed() - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -76,6 +75,8 @@ def _reduce_op(T, op, src, dst, dim, batch=1): ("sum", T.float32, 128, 64, "shared", "fragment", 256, 2), ("sum", T.float32, 128, 64, "shared", "fragment", 256, 4), ("sum", T.float16, 
64, 128, "fragment", "fragment", 256, 4), + ("sum", T.bfloat16, 128, 128, "fragment", "fragment", 32, 1), + ("sum", T.bfloat16, 64, 128, "fragment", "fragment", 256, 4), ("max", T.bfloat16, 128, 64, "shared", "fragment", 256, 2), ("max", T.float32, 128, 128, "fragment", "fragment", 256, 4), ("min", T.float32, 64, 128, "shared", "fragment", 128, 2), @@ -121,7 +122,6 @@ def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M,), dtype)): src = jit_kernel.get_kernel_source() m = re.search(r",\s*(\d+)\s*,\s*\d+\s*>::run_batch\(", src) assert m is not None, f"Expected run_batch in generated source.\n{src}" - assert int(m.group(1)) > 1, f"Expected batch_size > 1, got {m.group(1)}.\n{src}" A = _make_input(M, N, dtype) B = jit_kernel(A) @@ -307,5 +307,90 @@ def kernel(A: T.Tensor((block_M, 64), T.float32), B: T.Tensor((block_M,), T.floa tl.compile(k, out_idx=-1, pass_configs=_COMPILE_FLAGS) +# --------------------------------------------------------------------------- +# nan_propagate tests – packed (vsize=2) path for bf16/fp16 +# --------------------------------------------------------------------------- + + +def _compile(prim_func): + return tilelang.compile(prim_func, out_idx=-1, target="cuda") + + +def _make_nan_reduce_kernel(reduce_fn, M, N, dtype, threads, *, nan_propagate): + @T.prim_func + def kernel(A: T.Tensor((M, N), dtype), B: T.Tensor((M,), dtype)): + with T.Kernel(1, threads=threads): + src = T.alloc_fragment((M, N), dtype) + dst = T.alloc_fragment((M,), dtype) + T.copy(A, src) + reduce_fn(src, dst, dim=1, nan_propagate=nan_propagate) + T.copy(dst, B) + + return kernel + + +@tilelang.testing.requires_cuda +def test_reduce_packed_max_nan_propagate_uses_nan_intrinsics(): + k = _compile(_make_nan_reduce_kernel(T.reduce_max, 128, 128, T.float16, threads=256, nan_propagate=True)) + src = k.get_kernel_source() + assert "max2_nan" in src + assert "tl::MaxOpNan" in src + + +@tilelang.testing.requires_cuda +def test_reduce_packed_min_nan_propagate_uses_nan_intrinsics(): + k = _compile(_make_nan_reduce_kernel(T.reduce_min, 128, 128, T.bfloat16, threads=256, nan_propagate=True)) + src = k.get_kernel_source() + assert "min2_nan" in src + assert "tl::MinOpNan" in src + + +@tilelang.testing.requires_cuda +def test_reduce_packed_absmax_nan_propagate_uses_nan_intrinsics(): + k = _compile(_make_nan_reduce_kernel(T.reduce_absmax, 128, 128, T.float16, threads=256, nan_propagate=True)) + src = k.get_kernel_source() + assert "max2_nan" in src + assert "tl::MaxOpNan" in src + + +@tilelang.testing.requires_cuda +def test_reduce_packed_max_nan_propagate_runtime(): + import math + + for tl_dtype, torch_dtype in [(T.float16, torch.float16), (T.bfloat16, torch.bfloat16)]: + M, N = 128, 128 + A = torch.arange(N, dtype=torch.float32).to(torch_dtype).repeat(M, 1).cuda() + A[0, 7] = float("nan") + B = _compile(_make_nan_reduce_kernel(T.reduce_max, M, N, tl_dtype, threads=256, nan_propagate=True))(A) + assert not math.isnan(B[1:].float().max().item()), f"{tl_dtype}: non-NaN rows should not produce NaN" + assert math.isnan(B[0].float().item()), f"{tl_dtype}: NaN row must produce NaN" + + +@tilelang.testing.requires_cuda +def test_reduce_packed_min_nan_propagate_runtime(): + import math + + for tl_dtype, torch_dtype in [(T.float16, torch.float16), (T.bfloat16, torch.bfloat16)]: + M, N = 128, 128 + A = torch.arange(N, dtype=torch.float32).to(torch_dtype).repeat(M, 1).cuda() + A[1, 13] = float("nan") + B = _compile(_make_nan_reduce_kernel(T.reduce_min, M, N, tl_dtype, threads=256, nan_propagate=True))(A) + assert not 
math.isnan(B[0].float().item()), f"{tl_dtype}: non-NaN rows should not produce NaN"
+        assert math.isnan(B[1].float().item()), f"{tl_dtype}: NaN row must produce NaN"
+
+
+@tilelang.testing.requires_cuda
+def test_reduce_packed_max_nan_batch_runtime():
+    import math
+
+    for tl_dtype, torch_dtype in [(T.float16, torch.float16), (T.bfloat16, torch.bfloat16)]:
+        M, N = 64, 128
+        A = torch.arange(N, dtype=torch.float32).to(torch_dtype).repeat(M, 1).cuda()
+        A[2, 7] = float("nan")
+        B = _compile(_make_nan_reduce_kernel(T.reduce_max, M, N, tl_dtype, threads=256, nan_propagate=True))(A)
+        assert not math.isnan(B[0].float().item()), f"{tl_dtype}: non-NaN rows should not produce NaN"
+        assert math.isnan(B[2].float().item()), f"{tl_dtype}: NaN row must produce NaN"
+
+
 if __name__ == "__main__":
     tilelang.testing.main()
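For reference, the packed thread-local path added in src/backend/cuda/op/reduce.cc can be pictured as the standalone CUDA sketch below: one packed NaN-propagating max per pair of fp16 elements, then a single cross-lane combine, which is what the Shuffle::ExtractElement lowering in codegen_cuda.cc prints. This is illustrative only and not part of the patch; the name row_max_nan_fp16 is made up, and the sketch assumes sm_80+ so __hmax2_nan/__hmax_nan are available and that row is 4-byte aligned.

#include <cuda_fp16.h>

// Illustrative sketch: reduce an even-length row of fp16 values to its
// NaN-propagating maximum two lanes at a time (the vsize == 2 path).
__device__ __half row_max_nan_fp16(const __half *row, int n /* even */) {
  // Both lanes start at -inf (0xFC00 is fp16 -inf), i.e. MakeInitValue
  // broadcast across the packed accumulator.
  __half2 acc = __half2half2(__ushort_as_half(0xFC00u));
  const __half2 *row2 = reinterpret_cast<const __half2 *>(row);
  for (int i = 0; i < n / 2; ++i)
    acc = __hmax2_nan(acc, row2[i]);  // one packed op per element pair
  // Final cross-lane combine: the scalar MakeReduce applied to the two
  // lanes extracted from the packed accumulator.
  return __hmax_nan(__low2half(acc), __high2half(acc));
}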
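The ExtractElement branch added to VisitExpr_(const ShuffleNode *) prints exactly this kind of access: the packed pair lives in an opaque 32-bit slot, and a lane read is a reinterpret plus an x/y member access. A minimal equivalent, with an illustrative name and memcpy in place of the pointer cast the codegen emits:

#include <cstring>
#include <cuda_fp16.h>

// Illustrative only: pull lane 0 or 1 out of a packed fp16 pair stored as a
// raw 32-bit word (the uint1 payload in the lowered code).
__device__ __half extract_lane(unsigned packed, int lane) {
  __half2 pair;
  memcpy(&pair, &packed, sizeof(pair));
  return lane == 0 ? __low2half(pair) : __high2half(pair);
}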
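The uint1 shfl_*_sync specializations in common.h exist so that tl::AllReduce can move a packed pair through the warp as a single 32-bit payload. The shape of that exchange, stripped down to a bare warp-level butterfly with an illustrative name and memcpy standing in for the tl::to_uint1/from_uint1 helpers, assuming sm_80+:

#include <cstring>
#include <cuda_fp16.h>

// Illustrative sketch: butterfly AllReduce across one warp, carrying a packed
// __half2 per lane.  The raw 32 bits travel through __shfl_xor_sync and the
// packed NaN-propagating max recombines after every exchange.
__device__ __half2 warp_allreduce_max2_nan(__half2 v) {
  for (int lane_mask = 16; lane_mask > 0; lane_mask >>= 1) {
    unsigned bits;
    memcpy(&bits, &v, sizeof(bits));
    bits = __shfl_xor_sync(0xffffffffu, bits, lane_mask);
    __half2 other;
    memcpy(&other, &bits, sizeof(other));
    v = __hmax2_nan(v, other);
  }
  // Every lane now holds the lane-wise maximum over the warp; one scalar
  // __hmax_nan across the two lanes finishes the reduction.
  return v;
}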
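The pack/unpack loops in the batched AllReduce path compute their buffer indices by dividing a flat element offset by row-major strides; the FloorDiv/FloorMod expressions they emit are the symbolic form of this small host-side helper. The name unflatten and the example shape are made up for illustration.

#include <cstdint>
#include <vector>

// Recover multi-dimensional indices from a flat row-major offset; this is the
// arithmetic the pack/unpack loops build with FloorDiv/FloorMod.
std::vector<int64_t> unflatten(int64_t flat, const std::vector<int64_t> &shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (int d = static_cast<int>(shape.size()) - 2; d >= 0; --d)
    strides[d] = strides[d + 1] * shape[d + 1];
  std::vector<int64_t> idx(shape.size());
  for (size_t d = 0; d < shape.size(); ++d)
    idx[d] = (flat / strides[d]) % shape[d];
  return idx;
}
// e.g. unflatten(13, {4, 8}) == {1, 5}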