Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "PTO/Transforms/InsertSync/InsertSyncAnalysis.h"
#include "PTO/Transforms/InsertSync/SyncCommon.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "llvm/Support/Casting.h"
Expand All @@ -36,6 +37,35 @@ static bool isValidPipeIndex(PipelineType pipe) {
return static_cast<unsigned>(pipe) < kPipeStateSize;
}

/// Extracts the compile-time integer behind `value`.
///
/// Returns the sign-extended 64-bit constant when `value` matches the
/// `m_ConstantInt` pattern, and `std::nullopt` for anything else.
static std::optional<int64_t> getConstantIndexValue(Value value) {
  llvm::APInt constant;
  const bool isConstant = matchPattern(value, m_ConstantInt(&constant));
  if (isConstant)
    return constant.getSExtValue();
  return std::nullopt;
}

/// Integer core of the trip-count test: reports whether a loop over the
/// half-open range [lowerBound, upperBound) advancing by `step` runs its body
/// zero or one time.
///
/// A non-positive `step` yields false, so such loops are never classified as
/// single-iteration.
static bool spansAtMostOneStep(int64_t lowerBound, int64_t upperBound,
                               int64_t step) {
  if (step <= 0)
    return false;

  // Empty range: the body never runs.
  if (upperBound <= lowerBound)
    return true;

  // upperBound > lowerBound here, so the real distance lies in [1, 2^64 - 1].
  // Unsigned (modular) subtraction produces exactly that value, whereas the
  // signed `upperBound - lowerBound` overflows for far-apart bounds.
  const uint64_t span =
      static_cast<uint64_t>(upperBound) - static_cast<uint64_t>(lowerBound);

  // ceil(span / step) <= 1 is equivalent to span <= step for positive values.
  // Comparing directly avoids the intermediate `span + step - 1`, which can
  // itself overflow when either operand is large.
  return span <= static_cast<uint64_t>(step);
}

/// Returns true when `loopElement` wraps an `scf.for` whose lower bound, upper
/// bound and step are all constants and whose body executes at most once.
///
/// Returns false for a null element, a null `elementOp`, an op that is not an
/// `scf::ForOp`, any non-constant bound or step, or a non-positive step.
static bool hasAtMostOneIteration(const LoopInstanceElement *loopElement) {
  if (loopElement == nullptr || loopElement->elementOp == nullptr)
    return false;

  auto forOp = dyn_cast<scf::ForOp>(loopElement->elementOp);
  if (!forOp)
    return false;

  const std::optional<int64_t> lowerBound =
      getConstantIndexValue(forOp.getLowerBound());
  const std::optional<int64_t> upperBound =
      getConstantIndexValue(forOp.getUpperBound());
  const std::optional<int64_t> step = getConstantIndexValue(forOp.getStep());
  if (!lowerBound || !upperBound || !step)
    return false;

  return spansAtMostOneStep(*lowerBound, *upperBound, *step);
}

// ==============================================================================
// 1. Entry Point
// ==============================================================================
Expand Down Expand Up @@ -80,6 +110,12 @@ void InsertSyncAnalysis::DealWithLoopSync(LoopInstanceElement *nowElement) {
return;
}

// No loop-carried dependence exists when a loop executes at most once.
// In that case, skip backward-edge sync analysis entirely.
if (hasAtMostOneIteration(nowElement)) {
return;
}

SyncIRs backSyncIr;
assert(syncIR_.size() >= nowElement->endId);
for (unsigned i = nowElement->beginId; i < nowElement->endId; i++) {
Expand Down
37 changes: 37 additions & 0 deletions test/basic/issue233_single_iter_loop_skip_backedge_sync.pto
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// RUN: ptoas --pto-arch=a3 --enable-insert-sync %s | FileCheck %s
//
// Regression guard:
// - A constant-trip scf.for with at most one iteration must not trigger
// loop-carried backward sync seed/drain.
// - Keep ordinary intra-iteration sync insertion behavior unchanged.
//
// CHECK-LABEL: __global__ AICORE void single_iter_loop_skip_backedge_sync()
// CHECK: for (size_t
// CHECK-NOT: set_flag(PIPE_M, PIPE_MTE1
// CHECK-NOT: wait_flag(PIPE_M, PIPE_MTE1
// CHECK: ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll);

// Test input: a function with one cube section containing a single scf.for
// whose constant bounds (0 to 1, step 1) give it exactly one iteration.
module {
func.func @single_iter_loop_skip_backedge_sync() {
pto.section.cube {
// Constants used as the loop's lower bound (%c0) and as both its upper bound
// and step (%c1).
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index

// Tile buffers, all 32x32: two f16 `mat` tiles, one f16 `left` tile, one f16
// `right` tile and one f32 `acc` tile.
%m0 = pto.alloc_tile : !pto.tile_buf<loc=mat, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=col_major, slayout=row_major, fractal=512, pad=0>
%m1 = pto.alloc_tile : !pto.tile_buf<loc=mat, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=col_major, slayout=row_major, fractal=512, pad=0>
%l0 = pto.alloc_tile : !pto.tile_buf<loc=left, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=row_major, slayout=row_major, fractal=512, pad=0>
%r0 = pto.alloc_tile : !pto.tile_buf<loc=right, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=row_major, slayout=col_major, fractal=512, pad=0>
%acc = pto.alloc_tile : !pto.tile_buf<loc=acc, dtype=f32, rows=32, cols=32, v_row=32, v_col=32, blayout=col_major, slayout=row_major, fractal=1024, pad=0>

// Before the loop: tmov with %m0 as input and %l0 as output, then %m1 as
// input and %r0 as output.
pto.tmov ins(%m0 : !pto.tile_buf<loc=mat, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=col_major, slayout=row_major, fractal=512, pad=0>) outs(%l0 : !pto.tile_buf<loc=left, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=row_major, slayout=row_major, fractal=512, pad=0>)
pto.tmov ins(%m1 : !pto.tile_buf<loc=mat, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=col_major, slayout=row_major, fractal=512, pad=0>) outs(%r0 : !pto.tile_buf<loc=right, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=row_major, slayout=col_major, fractal=512, pad=0>)

// Single-iteration loop. Within the body, %l0 is the output of the first tmov
// and an input of the tmatmul, and %m0 is the input of the first tmov and the
// output of the last one, so a second iteration would reuse buffers touched by
// the first.
// NOTE(review): presumably this reuse is what would otherwise produce the
// PIPE_M -> PIPE_MTE1 set_flag/wait_flag pair rejected by the CHECK-NOT lines
// in this file's header — confirm against the pipe mapping of these ops.
scf.for %i = %c0 to %c1 step %c1 {
pto.tmov ins(%m0 : !pto.tile_buf<loc=mat, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=col_major, slayout=row_major, fractal=512, pad=0>) outs(%l0 : !pto.tile_buf<loc=left, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=row_major, slayout=row_major, fractal=512, pad=0>)
pto.tmatmul ins(%l0, %r0 : !pto.tile_buf<loc=left, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=row_major, slayout=row_major, fractal=512, pad=0>, !pto.tile_buf<loc=right, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=row_major, slayout=col_major, fractal=512, pad=0>) outs(%acc : !pto.tile_buf<loc=acc, dtype=f32, rows=32, cols=32, v_row=32, v_col=32, blayout=col_major, slayout=row_major, fractal=1024, pad=0>)
pto.tmov ins(%acc : !pto.tile_buf<loc=acc, dtype=f32, rows=32, cols=32, v_row=32, v_col=32, blayout=col_major, slayout=row_major, fractal=1024, pad=0>) outs(%m0 : !pto.tile_buf<loc=mat, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=col_major, slayout=row_major, fractal=512, pad=0>)
}
}
return
}
}
Loading