Skip to content
Open
229 changes: 183 additions & 46 deletions src/op/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,41 +132,58 @@ PrimExpr ReduceOpNode::MakeInitValue() const {
}
}

PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &acc,
const PrimExpr &b) const {
PrimExpr rhs = b;
if (acc->dtype != rhs->dtype) {
rhs = Cast(acc->dtype, rhs);
PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &acc, const PrimExpr &b,
int pack_lanes) const {
if (pack_lanes == 1) {
PrimExpr rhs = b;
if (acc->dtype != rhs->dtype) {
rhs = Cast(acc->dtype, rhs);
}
const bool use_nan_op =
nan_propagate &&
(acc.dtype().is_float16() || acc.dtype().is_bfloat16());
if (type->isSum()) {
return acc + rhs;
} else if (type->isAbsSum()) {
return acc + Max(rhs, -rhs);
} else if (type->isMax()) {
if (use_nan_op) {
return Call(acc.dtype(), tl::max_nan(), {acc, rhs});
}
return Max(acc, rhs);
} else if (type->isMin()) {
if (use_nan_op) {
return Call(acc.dtype(), tl::min_nan(), {acc, rhs});
}
return Min(acc, rhs);
} else if (type->isAbsMax()) {
if (use_nan_op) {
return Call(acc.dtype(), tl::max_nan(), {acc, tvm::abs(rhs)});
}
return Max(acc, tvm::abs(rhs));
} else if (type->isBitAnd()) {
return acc & rhs;
} else if (type->isBitOr()) {
return acc | rhs;
} else if (type->isBitXor()) {
return acc ^ rhs;
} else {
LOG(FATAL) << "Unsupported reduce type: " << type->type;
}
}
const bool use_nan_op =
nan_propagate && (acc.dtype().is_float16() || acc.dtype().is_bfloat16());
if (type->isSum()) {
return acc + rhs;
} else if (type->isAbsSum()) {
return acc + Max(rhs, -rhs);

if (type->isSum() || type->isAbsSum()) {
return Call(acc.dtype(), tl::add2(), {acc, b});
} else if (type->isMax()) {
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
if (use_nan_op) {
return Call(acc.dtype(), tl::max_nan(), {acc, rhs});
}
return Max(acc, rhs);
return Call(acc.dtype(), tl::max2(), {acc, b});
} else if (type->isMin()) {
if (use_nan_op) {
return Call(acc.dtype(), tl::min_nan(), {acc, rhs});
}
return Min(acc, rhs);
return Call(acc.dtype(), tl::min2(), {acc, b});
} else if (type->isAbsMax()) {
if (use_nan_op) {
return Call(acc.dtype(), tl::max_nan(), {acc, tvm::abs(rhs)});
}
return Max(acc, tvm::abs(rhs));
} else if (type->isBitAnd()) {
return acc & rhs;
} else if (type->isBitOr()) {
return acc | rhs;
} else if (type->isBitXor()) {
return acc ^ rhs;
return Call(acc.dtype(), tl::max2(),
{acc, Call(acc.dtype(), tl::abs2(), {b})});
} else {
LOG(FATAL) << "Unsupported reduce type: " << type->type;
LOG(FATAL) << "Unsupported packed reduce type: " << type->type;
return PrimExpr();
}
}

Expand Down Expand Up @@ -265,6 +282,28 @@ static Fragment ComputeReducerLayout(const Fragment &src_layout, int dim) {
* normalization.
* @return Stmt Lowered TIR statement implementing the reduction.
*/

struct ReducePackConfig {
int pack_lanes;
DataType vec_dtype;
};

static std::optional<ReducePackConfig>
GetReducePackConfig(DataType dt, Target target) {
if (!TargetIsCuda(target))
return std::nullopt;
int pack = 0;
if (dt.is_bfloat16() || dt.is_float16()) {
pack = 2;
} else if (dt.is_float() && dt.bits() == 32) {
if (TargetHasSMVersionGE(target, 100))
pack = 2;
}
if (pack == 0)
return std::nullopt;
return ReducePackConfig{pack, dt.with_lanes(pack)};
}

Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (nan_propagate && (dst->dtype.is_float16() || dst->dtype.is_bfloat16()) &&
!TargetIsCuda(T.target)) {
Expand Down Expand Up @@ -374,12 +413,6 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
// For max/min/absmax with clear=false and need_duplicate, we still need to
// initialize the temporary buffer with identity values since the original
// dst values will be combined later via need_update
if (require_init ||
(need_duplicate && (this->type->isMax() || this->type->isMin() ||
this->type->isAbsMax()))) {
stmts.push_back(
BufferStore(clear_buffer, this->MakeInitValue(), red_indices));
}

// make thread-local reduce
Array<PrimExpr> src_indice_compressed;
Expand All @@ -391,19 +424,115 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
src_var_compressed.push_back(var);
}

Stmt reduce_local = BufferStore(
clear_buffer,
this->MakeReduce(BufferLoad(clear_buffer, red_indices),
BufferLoad(src_buffer, src_indice_compressed)),
red_indices);
bool can_pack = false;
bool need_pack_buffer = false;
Buffer clear_buffer_packed;
if (auto cfg =
GetReducePackConfig(clear_buffer->dtype, T.target)) {
if (!src_var_compressed.empty() && !nan_propagate) {
auto *ext =
src_var_compressed.back()->dom->extent.as<IntImmNode>();
if (ext && ext->value >= cfg->pack_lanes &&
ext->value % cfg->pack_lanes == 0) {
can_pack = true;
clear_buffer_packed = decl_buffer(
red_layout->OutputShape(), cfg->vec_dtype,
clear_buffer->name + "_pack",
GetPtrStorageScope(clear_buffer->data));
need_pack_buffer = true;

Array<Stmt> local_body;

if (require_init ||
(need_duplicate &&
(this->type->isMax() || this->type->isMin() ||
this->type->isAbsMax()))) {
auto init = this->MakeInitValue();
local_body.push_back(BufferStore(
clear_buffer_packed,
Broadcast(init, cfg->pack_lanes), red_indices));
}

for (int i = static_cast<int>(src_layout->OutputDim()) - 1; i >= 0; --i) {
reduce_local =
For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent,
ForKind::kUnrolled, reduce_local, std::nullopt,
const auto *ext_int = as_const_int(
src_var_compressed.back()->dom->extent);
int64_t inner_extent = *ext_int;
PrimExpr halved_extent =
Integer(inner_extent / cfg->pack_lanes);

auto &inner_var = src_var_compressed.back();

PrimExpr ramp_base =
Substitute(src_indice_compressed.back(),
{{inner_var->var, inner_var->var * Integer(2)}});
src_indice_compressed.Set(
src_indice_compressed.size() - 1,
Ramp(ramp_base,
IntImm(DataType::Int(32), 1), cfg->pack_lanes));

auto src_load = BufferLoad(src_buffer, src_indice_compressed);
auto *src_writer = src_load.CopyOnWrite();
src_writer->dtype = cfg->vec_dtype;

Stmt reduce_local = BufferStore(
clear_buffer_packed,
this->MakeReduce(BufferLoad(clear_buffer_packed, red_indices),
src_load, cfg->pack_lanes),
red_indices);

reduce_local = For(
inner_var->var, 0, halved_extent, ForKind::kUnrolled,
reduce_local, std::nullopt,
{{tir::attr::pragma_unroll_explicit, Bool(false)}});

for (int i =
static_cast<int>(src_layout->OutputDim()) - 2;
i >= 0; --i) {
reduce_local = For(
src_var_compressed[i]->var, 0,
src_var_compressed[i]->dom->extent, ForKind::kUnrolled,
reduce_local, std::nullopt,
{{tir::attr::pragma_unroll_explicit, Bool(false)}});
}
local_body.push_back(reduce_local);

auto acc_vec =
BufferLoad(clear_buffer_packed, red_indices);
auto lane0 = Shuffle::ExtractElement(acc_vec, 0);
auto lane1 = Shuffle::ExtractElement(acc_vec, 1);
auto scalar_result =
this->MakeReduce(lane0, lane1, /*pack_lanes=*/1);
local_body.push_back(BufferStore(
clear_buffer, scalar_result, red_indices));

stmts.push_back(SeqStmt(local_body));
}
}
}

if (!can_pack) {
if (require_init ||
(need_duplicate && (this->type->isMax() || this->type->isMin() ||
this->type->isAbsMax()))) {
stmts.push_back(
BufferStore(clear_buffer, this->MakeInitValue(), red_indices));
}

Stmt reduce_local = BufferStore(
clear_buffer,
this->MakeReduce(BufferLoad(clear_buffer, red_indices),
BufferLoad(src_buffer, src_indice_compressed)),
red_indices);

for (int i = static_cast<int>(src_layout->OutputDim()) - 1; i >= 0;
--i) {
reduce_local = For(
src_var_compressed[i]->var, 0,
src_var_compressed[i]->dom->extent, ForKind::kUnrolled,
reduce_local, std::nullopt,
{{tir::attr::pragma_unroll_explicit, Bool(false)}});
}
stmts.push_back(reduce_local);
}
stmts.push_back(reduce_local);

auto src_thread = src_layout->ForwardThread(
src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }), {});
Expand Down Expand Up @@ -615,6 +744,10 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
body = Allocate(clear_buffer->data, clear_buffer->dtype,
clear_buffer->shape, const_true(), body);
}
if (need_pack_buffer) {
body = Allocate(clear_buffer_packed->data, clear_buffer_packed->dtype,
clear_buffer_packed->shape, const_true(), body);
}
return body;

} else {
Expand Down Expand Up @@ -722,6 +855,10 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
body = Allocate(clear_buffer->data, clear_buffer->dtype,
clear_buffer->shape, const_true(), body);
}
if (need_pack_buffer) {
body = Allocate(clear_buffer_packed->data, clear_buffer_packed->dtype,
clear_buffer_packed->shape, const_true(), body);
}
return body;
}
}
Expand Down
4 changes: 3 additions & 1 deletion src/op/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ class ReduceOpNode : public TileOperatorNode {
/// Generate initial value for reduction
PrimExpr MakeInitValue() const;
/// Generate reduction expression
PrimExpr MakeReduce(const PrimExpr &acc, const PrimExpr &b) const;
/// pack_lanes = 1 for scalar, 2 for add2/max2/min2, etc.
PrimExpr MakeReduce(const PrimExpr &acc, const PrimExpr &b,
int pack_lanes = 1) const;
/// Generate codegen reducer string
std::string MakeCodegenReducer() const;
};
Expand Down
Loading
Loading