Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions lib/PTO/Transforms/PTOToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2864,9 +2864,11 @@ struct SubviewToEmitCPattern : public OpConversionPattern<memref::SubViewOp> {
// 2. Build the Shape template parameters; the valid dimensions are later
//    right-aligned and padded to 5 dims (the added leading dims are filled with 1).
SmallVector<std::string> shapeParamsVec;
SmallVector<Value> sizeValues; // runtime size for each dimension (normalized to unsigned)
SmallVector<bool> staticSingletonDims; // dimensions that can be statically proven to be 1
auto resShape = resTy.getShape();
auto mixedSizes = op.getMixedSizes();
// Pre-size the per-dimension vectors; one entry is appended per result dimension below.
sizeValues.reserve(rank);
staticSingletonDims.reserve(rank);
for (int i = 0; i < resTy.getRank(); ++i) {
if (resShape[i] == ShapedType::kDynamic) {
shapeParamsVec.push_back("-1");
Expand All @@ -2879,6 +2881,13 @@ struct SubviewToEmitCPattern : public OpConversionPattern<memref::SubViewOp> {
else
sizeValues.push_back(
mkU32(resShape[i] == ShapedType::kDynamic ? 1 : resShape[i]));

// Decide whether dimension i is statically known to have extent 1.
// First try the op's size operand for this index.
// NOTE(review): extractStaticInt is defined elsewhere; presumably it yields a
// value only when the size is a compile-time constant — confirm.
// NOTE(review): mixedSizes is indexed with the result-dimension index i. For a
// rank-reducing subview the operand index space and the result index space may
// differ — confirm this pattern never sees rank-reducing subviews.
std::optional<int64_t> staticDim;
if (i < (int)mixedSizes.size())
staticDim = extractStaticInt(mixedSizes[i]);
// Otherwise fall back to the static extent recorded in the result type.
if (!staticDim && resShape[i] != ShapedType::kDynamic)
staticDim = resShape[i];
// A dimension with no static value is recorded as non-singleton.
staticSingletonDims.push_back(staticDim && *staticDim == 1);
}

// 3. Build the Stride template parameters + the runtime stride values (taking the subview step into account)
Expand Down Expand Up @@ -2917,6 +2926,57 @@ struct SubviewToEmitCPattern : public OpConversionPattern<memref::SubViewOp> {
rewriter.create<emitc::MulOp>(loc, u32Ty, srcV, stepV));
}

// 3.0 Move singleton axes to the front for partition views where this is provably safe:
//     only axes with size == 1 are allowed to cross other axes, so the set of
//     addressed elements is unchanged; dynamic axes are treated as non-singleton
//     and are never reordered.
// NOTE(review): the loops below index staticSingletonDims, shapeParamsVec,
// sizeValues, dummyStrideVec and strideValues with 0..rank-1; this assumes each
// of them holds at least `rank` entries (staticSingletonDims is filled once per
// result dimension) — confirm rank == resTy.getRank().
if (rank > 2) {
// Count the axes that are not statically known to be 1.
int nonSingletonCount = 0;
for (bool isSingleton : staticSingletonDims) {
if (!isSingleton)
++nonSingletonCount;
}
// Only reorder when at most two axes are non-singleton.
if (nonSingletonCount <= 2) {
// Build the permutation as a stable partition: all singleton axes first,
// then all non-singleton axes, each group keeping its original relative order.
SmallVector<unsigned, 8> permutation;
permutation.reserve(rank);
for (int i = 0; i < rank; ++i) {
if (staticSingletonDims[i])
permutation.push_back(i);
}
for (int i = 0; i < rank; ++i) {
if (!staticSingletonDims[i])
permutation.push_back(i);
}

// Detect the identity permutation so the vectors are only rebuilt when needed.
bool changed = false;
for (int i = 0; i < rank; ++i) {
if (permutation[i] != static_cast<unsigned>(i)) {
changed = true;
break;
}
}

if (changed) {
// Apply the same permutation to all four parallel per-axis vectors so each
// axis keeps its own shape string, size value, stride string and stride value.
SmallVector<std::string> reorderedShape;
SmallVector<Value> reorderedSizes;
SmallVector<std::string> reorderedStrides;
SmallVector<Value> reorderedStrideValues;
reorderedShape.reserve(rank);
reorderedSizes.reserve(rank);
reorderedStrides.reserve(rank);
reorderedStrideValues.reserve(rank);
for (unsigned idx : permutation) {
reorderedShape.push_back(shapeParamsVec[idx]);
reorderedSizes.push_back(sizeValues[idx]);
reorderedStrides.push_back(dummyStrideVec[idx]);
reorderedStrideValues.push_back(strideValues[idx]);
}
// Replace the originals with the reordered copies.
shapeParamsVec = std::move(reorderedShape);
sizeValues = std::move(reorderedSizes);
dummyStrideVec = std::move(reorderedStrides);
strideValues = std::move(reorderedStrideValues);
}
}
}

// 3.1 Right-align to 5 dims: the shape is padded with 1; existing dims keep their original stride;
//     the padded leading dims derive their strides contiguously by the "dense rank-extension" rule: stride[i] = shape[i+1] * stride[i+1]
// Start from five "1" entries.
SmallVector<std::string, 5> finalShape(5, "1");
Expand Down
45 changes: 45 additions & 0 deletions test/basic/issue453_partition_view_singleton_axis_reorder.pto
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5

module {
// Safe case: only singleton axes are moved. This should canonicalize
// [1, 16, 1, 16] -> [1, 1, 16, 16] before 5D right-align, so emitted
// GlobalTensor shape becomes <1, 1, 1, 16, 16>.
// The view has size-1 axes at positions 0 and 2 and exactly two non-singleton
// axes (16 and 16), so moving the singletons to the front only moves axis 2
// ahead of axis 1.
func.func @issue453_singleton_axis_reorder_positive(%src: !pto.ptr<f16>) {
// Index constants used for the offsets, shape and strides below.
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c256 = arith.constant 256 : index

// 4-D tensor view over %src with shape [1, 16, 1, 16] and strides [256, 16, 16, 1].
%tv = pto.make_tensor_view %src, shape = [%c1, %c16, %c1, %c16], strides = [%c256, %c16, %c16, %c1] {layout = #pto.layout<nd>} : !pto.tensor_view<?x?x?x?xf16>
// Partition covering the whole view (all offsets 0, sizes equal to the view
// shape); the result type is fully static: 1x16x1x16.
%sv = pto.partition_view %tv, offsets = [%c0, %c0, %c0, %c0], sizes = [%c1, %c16, %c1, %c16] : !pto.tensor_view<?x?x?x?xf16> -> !pto.partition_tensor_view<1x16x1x16xf16>
// 16x16 destination tile for the load.
%tile = pto.alloc_tile : !pto.tile_buf<loc=mat, dtype=f16, rows=16, cols=16, v_row=16, v_col=16, blayout=col_major, slayout=row_major, fractal=512, pad=0>
// Load from the partition view into the tile.
// NOTE(review): presumably this use is what makes the lowering emit the
// GlobalTensor checked by the A5 FileCheck lines — confirm.
pto.tload ins(%sv : !pto.partition_tensor_view<1x16x1x16xf16>)
outs(%tile : !pto.tile_buf<loc=mat, dtype=f16, rows=16, cols=16, v_row=16, v_col=16, blayout=col_major, slayout=row_major, fractal=512, pad=0>)
return
}

// Unsafe case: three non-singleton axes (2,16,16). Reordering is not legal
// and must not happen.
// The expected emitted shape is therefore the original [2, 16, 1, 16]
// right-aligned to 5 dims, i.e. <1, 2, 16, 1, 16>.
func.func @issue453_singleton_axis_reorder_negative(%src: !pto.ptr<f16>) {
// Index constants used for the offsets, shape and strides below.
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c16 = arith.constant 16 : index
%c256 = arith.constant 256 : index

// 4-D tensor view over %src with shape [2, 16, 1, 16] and strides [256, 16, 16, 1];
// only axis 2 is a singleton.
%tv = pto.make_tensor_view %src, shape = [%c2, %c16, %c1, %c16], strides = [%c256, %c16, %c16, %c1] {layout = #pto.layout<nd>} : !pto.tensor_view<?x?x?x?xf16>
// Partition covering the whole view (all offsets 0, sizes equal to the view
// shape); the result type is fully static: 2x16x1x16.
%sv = pto.partition_view %tv, offsets = [%c0, %c0, %c0, %c0], sizes = [%c2, %c16, %c1, %c16] : !pto.tensor_view<?x?x?x?xf16> -> !pto.partition_tensor_view<2x16x1x16xf16>
// 32x16 destination tile for the load.
%tile = pto.alloc_tile : !pto.tile_buf<loc=mat, dtype=f16, rows=32, cols=16, v_row=32, v_col=16, blayout=col_major, slayout=row_major, fractal=512, pad=0>
// Load from the partition view into the tile.
pto.tload ins(%sv : !pto.partition_tensor_view<2x16x1x16xf16>)
outs(%tile : !pto.tile_buf<loc=mat, dtype=f16, rows=32, cols=16, v_row=32, v_col=16, blayout=col_major, slayout=row_major, fractal=512, pad=0>)
return
}
}

// A5-LABEL: AICORE void issue453_singleton_axis_reorder_positive(
// A5: GlobalTensor<half, pto::Shape<1, 1, 1, 16, 16>
// A5-NOT: GlobalTensor<half, pto::Shape<1, 1, 16, 1, 16>

// A5-LABEL: AICORE void issue453_singleton_axis_reorder_negative(
// A5: GlobalTensor<half, pto::Shape<1, 2, 16, 1, 16>
// A5-NOT: GlobalTensor<half, pto::Shape<1, 1, 2, 16, 16>
Loading