Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/op/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ GemmInst GemmNode::getGemmInst(int block_size, Target target) const {
return GemmInst::kMMA;
} else if (TargetIsCPU(target)) {
return GemmInst::kScalar;
} else if (TargetIsMetal(target)) {
return GemmInst::kScalar;
} else {
ICHECK(0) << "Unsupported target for gemm: " << target->str();
return GemmInst::kMMA;
Expand Down
3 changes: 2 additions & 1 deletion src/op/math.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ TVM_REGISTER_OP("tl.infinity")
Integer(CallEffectKind::kPure))
.set_attr<TScriptPrinterName>("TScriptPrinterName", "infinity")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", infinity_op)
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic", infinity_op);
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic", infinity_op)
.set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", infinity_op);

} // namespace tl
} // namespace tvm
52 changes: 49 additions & 3 deletions src/op/parallel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,29 @@ class IfBufferRemapLoopGenerator : public StmtExprMutator {
Map<Buffer, Layout> layout_map_;
};

bool IndicesAreLoopInvariant(const Array<PrimExpr> &indices,
const Array<IterVar> &loop_vars) {
bool depends_on_loop_var = false;
auto visitor = [&](const ObjectRef &obj) {
if (depends_on_loop_var)
return;
if (const auto *var = obj.as<VarNode>()) {
for (const auto &iv : loop_vars) {
if (var == iv->var.get()) {
depends_on_loop_var = true;
return;
}
}
}
};
for (const auto &index : indices) {
PostOrderVisit(index, visitor);
if (depends_on_loop_var)
return false;
}
return true;
}

} // anonymous namespace

/**
Expand Down Expand Up @@ -360,12 +383,24 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
continue;

auto frag = T.layout_map[buffer].as<Fragment>().value();
bool is_completed_replicated = frag->IsCompletedReplicated();
bool is_fully_replicated =
IsBufferCompletelyReplicated(buffer, T.layout_map);

if (access.is_write) {
source_buffer = buffer;
if (is_fully_replicated) {
if (!replicated_write_buffer.defined()) {
replicated_write_buffer = buffer;
}
} else {
source_buffer = buffer;
}
} else {
// Copy-out loops should not inherit a fully replicated source layout:
// that would guard the shared/global store to one replicate thread.
if (is_completed_replicated && !store_shared_global_buffers_.empty()) {
continue;
}
// Keep the buffer with largest number of indices
// (which means the inference based on that buffer is more accurate)
// as read_source_buffer to get more accurate layout
Expand All @@ -385,6 +420,12 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
}
}
}
if (!source_buffer.defined() && read_source_buffer.defined()) {
source_buffer = read_source_buffer;
}
if (!source_buffer.defined() && replicated_write_buffer.defined()) {
source_buffer = replicated_write_buffer;
}
// moved to ComputeLoopLayoutFromBuffer

// Try to infer loop layout from buffers in order of preference only if we
Expand Down Expand Up @@ -571,7 +612,12 @@ bool ParallelOpNode::ValidateCandidateAgainstFragments(
auto fragment = T.layout_map[buffer].as<Fragment>().value();
std::ostringstream oss;
bool success = true;
if (access.is_read &&
bool replicated_read = access.is_read && fragment->IsCompletedReplicated();
bool replicated_local_write =
access.is_write && fragment->IsCompletedReplicated() &&
store_shared_global_buffers_.empty() &&
IndicesAreLoopInvariant(access.indices, loop_vars_);
if (access.is_read && !replicated_read &&
!ProveFragmentContains(candidate, fragment, vars, access.indices,
analyzer_, check_forward_index)) {
if (throw_on_error) {
Expand All @@ -582,7 +628,7 @@ bool ParallelOpNode::ValidateCandidateAgainstFragments(
}
success = false;
}
if (access.is_write &&
if (access.is_write && !replicated_local_write &&
!ProveFragmentContains(fragment, candidate, access.indices, vars,
analyzer_, check_forward_index)) {
if (throw_on_error) {
Expand Down
118 changes: 90 additions & 28 deletions src/op/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,30 @@ static Fragment ComputeReducerLayout(const Fragment &src_layout, int dim) {
return reducer_layout;
}

// Shape of the reduction result: the source layout's input shape with
// dimension `dim` dropped. A fully-reduced (scalar) result is modeled as a
// one-element shape [1].
static Array<PrimExpr> ComputeReducerShape(const Fragment &src_layout,
                                           int dim) {
  Array<PrimExpr> result_shape;
  auto input_shape = src_layout->InputShape();
  for (int i = 0; i < static_cast<int>(input_shape.size()); ++i) {
    if (i == dim)
      continue; // skip the reduced dimension
    result_shape.push_back(input_shape[i]);
  }
  if (result_shape.empty()) {
    result_shape.push_back(1);
  }
  return result_shape;
}

// Builds the destination fragment layout used on Metal when the reduce
// source is fully replicated: the result keeps the source's replicate extent
// (via a replication placeholder), presumably so every participating thread
// retains a full copy of the reduced value — matches the
// IsCompletedReplicated() guards at the call sites; confirm against Fragment
// semantics.
static Fragment ComputeMetalReplicatedReducerLayout(const Fragment &src_layout,
int dim,
Range thread_bounds) {
// Result shape = source input shape with `dim` removed ([1] if scalar).
auto reducer_shape = ComputeReducerShape(src_layout, dim);
// Identity forward mapping: one placeholder per output dimension.
auto forward_index = InputPlaceholders(reducer_shape.size());
// Prefer the source's own thread range; fall back to the caller-supplied
// bounds when the source layout does not carry one.
auto thread_range = src_layout->ThreadRange();
if (!thread_range.defined()) {
thread_range = thread_bounds;
}
return Fragment(reducer_shape, forward_index, ReplicationPlaceholder(),
src_layout->ReplicateExtent(), std::nullopt)
->BindThreadRange(thread_range);
}

/**
* @brief Lower the Reduce operator to a TIR statement.
*
Expand Down Expand Up @@ -391,6 +415,36 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
src_var_compressed.push_back(var);
}

bool can_use_metal_replicated_path =
TargetIsMetal(T.target) && src_layout->IsCompletedReplicated() &&
dst_layout->IsCompletedReplicated() &&
!src_buffer->data.same_as(dst_buffer->data);
if (can_use_metal_replicated_path) {
PrimExpr init_value = this->clear ? this->MakeInitValue()
: BufferLoad(dst_buffer, dst_indices);

Stmt init = BufferStore(dst_buffer, init_value, dst_indices);
Stmt reduce_local = BufferStore(
dst_buffer,
this->MakeReduce(BufferLoad(dst_buffer, dst_indices),
BufferLoad(src_buffer, src_indice_compressed)),
dst_indices);
for (int i = static_cast<int>(src_layout->OutputDim()) - 1; i >= 0; --i) {
reduce_local = For(src_var_compressed[i]->var, 0,
src_var_compressed[i]->dom->extent, ForKind::kSerial,
reduce_local, std::nullopt,
{{tir::attr::pragma_unroll_explicit, Bool(false)}});
}

Stmt body = SeqStmt({init, reduce_local});
for (int i = static_cast<int>(dst_vars.size()) - 1; i >= 0; --i) {
body = For(dst_vars[i]->var, 0, dst_vars[i]->dom->extent,
ForKind::kSerial, body, std::nullopt,
{{tir::attr::pragma_unroll_explicit, Bool(false)}});
}
return body;
}

Stmt reduce_local = BufferStore(
clear_buffer,
this->MakeReduce(BufferLoad(clear_buffer, red_indices),
Expand Down Expand Up @@ -733,39 +787,47 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {

/**
 * @brief Infer the destination fragment layout for a reduce op.
 *
 * NOTE: as pasted, this span interleaved the pre-change and post-change diff
 * lines (duplicate `src_layout`/`reducer_layout`/guard statements), which is
 * not valid C++. This is the coherent post-change body.
 *
 * Behavior: only fragment->fragment reduces with a known source layout are
 * handled. On Metal, a fully replicated source takes a dedicated replicated
 * reducer layout and is allowed even at strict inference level; otherwise
 * strict level infers nothing. If `dst` has no layout yet, propose the
 * reducer layout; if it already has one, verify compatibility and throw
 * LayoutConflictException on mismatch.
 */
LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
                                    InferLevel level) const {
  // Nothing to infer unless both sides are fragments and src is mapped.
  if (!(IsFragmentBuffer(src) && IsFragmentBuffer(dst) &&
        T.layout_map.count(src))) {
    return {};
  }

  auto src_layout = T.layout_map[src].as<Fragment>().value();
  // Metal keeps fully replicated sources on a dedicated replicated path,
  // which must run even under strict inference.
  bool use_metal_replicated_layout =
      TargetIsMetal(T.target) && src_layout->IsCompletedReplicated();
  if (level >= InferLevel::kStrict && !use_metal_replicated_layout) {
    return {};
  }

  Fragment reducer_layout = use_metal_replicated_layout
                                ? ComputeMetalReplicatedReducerLayout(
                                      src_layout, this->dim, T.thread_bounds)
                                : ComputeReducerLayout(src_layout, this->dim);

  // No existing layout for dst: propose the reducer layout.
  if (!T.layout_map.count(dst)) {
    return {{dst, reducer_layout}};
  }

  auto orig_dst_layout = T.layout_map.Get(dst).value().as<Fragment>().value();
  ICHECK(reducer_layout->InputDim() == orig_dst_layout->InputDim());

  // Bind each placeholder index to its valid range so the containment proof
  // can reason about bounds.
  auto indices = InputPlaceholders(reducer_layout->InputDim());
  arith::Analyzer analyzer;
  for (size_t i = 0; i < indices.size(); i++) {
    analyzer.Bind(Downcast<Var>(indices[i]),
                  Range(0, reducer_layout->InputShape()[i]));
  }
  if (!ProveFragmentContains(orig_dst_layout, reducer_layout, indices, indices,
                             analyzer)) {
    std::ostringstream oss;
    oss << "Layout may conflict with ReduceOp for buffer " << dst << " vs. "
        << src << "\n"
        << "src_layout = " << src_layout << "\n"
        << "reducer_layout = " << reducer_layout << "\n"
        << "orig_dst_layout = " << orig_dst_layout << "\n"
        << "You may need to use a shared memory to transform the "
           "layout";
    throw LayoutConflictException(oss.str());
  }
  return {};
}
Expand Down
24 changes: 16 additions & 8 deletions src/transform/layout_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1284,14 +1284,22 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
/**
 * @brief Create the tl.LayoutInference PrimFunc pass.
 *
 * NOTE: as pasted, this span fused the pre-change body (ending at `return f;`)
 * with the post-change body, leaving unreachable duplicate statements. This
 * is the coherent post-change version: the inference now runs with the
 * PrimFunc's target (when present) installed as the ambient current target.
 */
tvm::transform::Pass LayoutInference() {
  using namespace tir::transform;
  auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
    auto run = [&]() {
      // If the body declares no thread bindings, skip thread partitioning.
      ThreadBindingCollector collector;
      collector(f->body);
      bool has_thread_binding = !collector.thread_binding_.empty();
      bool skip_thread_partition = !has_thread_binding;
      f = LayoutInferencer::Substitute(std::move(f), skip_thread_partition);
      // Validate parallel loop layout annotations
      ParallelLoopLayoutValidator::Validate(f->body);
      return f;
    };
    if (target.defined()) {
      // RAII scope: target-dependent decisions inside the inferencer see
      // this target via Target::Current() until target_scope destructs.
      With<Target> target_scope(target.value());
      return run();
    }
    return run();
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {});
}
Expand Down
6 changes: 6 additions & 0 deletions src/transform/legalize_vectorized_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
*/

#include <tvm/ffi/reflection/registry.h>
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
Expand Down Expand Up @@ -82,6 +83,11 @@ tvm::transform::Pass LegalizeVectorizedLoop() {
using namespace tir::transform;
// Define the transformation function to be applied
// Pass body: legalize vectorized loops in `f`. When the PrimFunc carries a
// target attribute, install it as the ambient Target::Current() for the
// duration of the rewrite — presumably so target-dependent checks inside the
// legalizer see it (mirrors the other tl passes); otherwise run unscoped.
auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
if (target.defined()) {
// RAII: target stays current until target_scope destructs after return.
With<Target> target_scope(target.value());
return LoopVectorizedLegalizer::Substitute(std::move(f));
}
return LoopVectorizedLegalizer::Substitute(std::move(f));
};
// Create and return a PrimFunc pass with the transformation function
Expand Down
7 changes: 6 additions & 1 deletion src/transform/lower_tile_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1305,7 +1305,7 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
DataType from_ty = cast->value.dtype();
DataType target_ty = cast->dtype;
if (IsCudaVectorizableCast(from_ty, target_ty) &&
TargetIsCuda(Target::Current())) {
TargetIsCuda(target_)) {
has_cast_operations = true;
}
}
Expand Down Expand Up @@ -1388,6 +1388,11 @@ using namespace tir::transform;

tvm::transform::Pass LowerTileOp() {
// Pass body: lower tile ops in `f`. When the PrimFunc carries a target
// attribute, install it as the ambient Target::Current() while lowering —
// the pass reads the target (e.g. the CUDA vectorizable-cast check uses it);
// otherwise run unscoped.
auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
if (target.defined()) {
// RAII: target stays current until target_scope destructs after return.
With<Target> target_scope(target.value());
return LowerTileOpPass::Substitute(std::move(f));
}
return LowerTileOpPass::Substitute(std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tl.LowerTileOp", {});
Expand Down
Loading
Loading