Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion lib/PTO/IR/PTO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ static LogicalResult verifyMatTileOperandsA2A3(Operation *op, Type lhsTy,
Type rhsTy, Type dstTy);
static LogicalResult verifyMatTileOperandsA5(Operation *op, Type lhsTy,
Type rhsTy, Type dstTy);
static LogicalResult verifyMadTileLayoutsA5(Operation *op, Type lhsTy,
Type rhsTy, Type dstTy);
static LogicalResult verifyGemvTileOperands(Operation *op, Type lhsTy, Type rhsTy,
Type dstTy);
static LogicalResult verifyGemvTileOperandsA2A3(Operation *op, Type lhsTy,
Expand Down Expand Up @@ -3131,6 +3133,11 @@ static LogicalResult verifyMatTileOperandsA5(Operation *op, Type lhsTy,
if (failed(verifyMatTileOperandsA2A3(op, lhsTy, rhsTy, dstTy)))
return failure();

return verifyMadTileLayoutsA5(op, lhsTy, rhsTy, dstTy);
}

static LogicalResult verifyMadTileLayoutsA5(Operation *op, Type lhsTy,
Type rhsTy, Type dstTy) {
Comment on lines +3139 to +3140
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation of verifyMadTileLayoutsA5 (refactored from verifyMatTileOperandsA5) uses an "all-or-nothing" check for TileBufType operands: if any one operand is not a TileBufType (e.g., a MemRefType during intermediate lowering stages), the layout checks are skipped for all operands, including those that are TileBufType. It would be more robust to check each operand independently whenever it is a TileBufType.

auto lhsTb = mlir::dyn_cast<pto::TileBufType>(lhsTy);
auto rhsTb = mlir::dyn_cast<pto::TileBufType>(rhsTy);
auto dstTb = mlir::dyn_cast<pto::TileBufType>(dstTy);
Expand Down Expand Up @@ -3204,7 +3211,7 @@ static LogicalResult verifyGemvTileOperandsA5(Operation *op, Type lhsTy,
Type rhsTy, Type dstTy) {
if (failed(verifyGemvTileOperandsA2A3(op, lhsTy, rhsTy, dstTy)))
return failure();
return verifyMatTileOperandsA5(op, lhsTy, rhsTy, dstTy);
return verifyMadTileLayoutsA5(op, lhsTy, rhsTy, dstTy);
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Reinstate GEMV K/N bound checks on A5

Changing A5 GEMV verification to call only verifyMadTileLayoutsA5 drops the [1, 4095] valid-size guard that previously came from verifyMatTileOperandsA2A3 via verifyMatTileOperandsA5. As a result, pto.tgemv* now accepts static valid_shape values such as K=5000 or N=5000 on A5, even though matmul-family verification still treats those sizes as out of range; this can allow invalid TGEMV dimensions through verification and fail later in lowering/codegen.

Useful? React with 👍 / 👎.

}

static LogicalResult verifyGemvTileOperands(Operation *op, Type lhsTy, Type rhsTy,
Expand Down
12 changes: 6 additions & 6 deletions test/lit/pto/issue226_remove_redundant_pipe_pair.pto
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@

module {
func.func @remove_redundant_pipe_pair(
%arg0: memref<64x1xf16, strided<[1, 1]>, #pto.address_space<vec>>) {
%arg0: !pto.ptr<f16>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c64 = arith.constant 64 : index
%vbuf0 = pto.bind_tile %arg0, %c64, %c1
{config = #pto.tile_buf_config<blayout=1 : i32, slayout=2 : i32, s_fractal_size=512, pad=0 : i32>}
: memref<64x1xf16, strided<[1, 1]>, #pto.address_space<vec>>
-> memref<64x1xf16, strided<[1, 1], offset: ?>, #pto.address_space<vec>>
%vview = pto.make_tensor_view %arg0, shape = [%c64, %c1], strides = [%c1, %c1]
: !pto.tensor_view<64x1xf16>
%vbuf0 = pto.partition_view %vview, offsets = [%c0, %c0], sizes = [%c64, %c1]
: !pto.tensor_view<64x1xf16> -> !pto.partition_tensor_view<64x1xf16>

pto.section.cube {
%mat_a = pto.alloc_tile : !pto.tile_buf<loc=mat, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=col_major, slayout=row_major, fractal=512, pad=0>
Expand All @@ -33,7 +33,7 @@ module {
%acc = pto.alloc_tile : !pto.tile_buf<loc=acc, dtype=f32, rows=32, cols=32, v_row=32, v_col=32, blayout=col_major, slayout=row_major, fractal=1024, pad=0>

scf.for %i = %c0 to %c2 step %c1 {
pto.tload ins(%vbuf0 : memref<64x1xf16, strided<[1, 1], offset: ?>, #pto.address_space<vec>>)
pto.tload ins(%vbuf0 : !pto.partition_tensor_view<64x1xf16>)
outs(%mat_a : !pto.tile_buf<loc=mat, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=col_major, slayout=row_major, fractal=512, pad=0>)
pto.tmov ins(%mat_a : !pto.tile_buf<loc=mat, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=col_major, slayout=row_major, fractal=512, pad=0>)
outs(%left : !pto.tile_buf<loc=left, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=row_major, slayout=row_major, fractal=512, pad=0>)
Expand Down
24 changes: 24 additions & 0 deletions test/lit/pto/tgemv_a5_aligned_acc_rows.pto
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: ptoas --pto-arch=a5 %s | FileCheck %s

// Exercises the three GEMV op variants used below (pto.tgemv, pto.tgemv.acc,
// pto.tgemv.bias) with the a5 architecture selected on the command line at the
// top of this file. The FileCheck directives at the end of the file expect one
// TGEMV(, TGEMV_ACC( and TGEMV_BIAS( call in the generated kernel.
module {
// Takes no arguments: every tile is allocated locally with pto.alloc_tile.
func.func @tgemv_a5_aligned_acc_rows() {
// Inputs: a 1x64 f16 tile in loc=left and a 64x80 f16 tile in loc=right.
%lhs = pto.alloc_tile : !pto.tile_buf<loc=left, dtype=f16, rows=1, cols=64, v_row=1, v_col=64, blayout=col_major, slayout=row_major, fractal=512, pad=0>
%rhs = pto.alloc_tile : !pto.tile_buf<loc=right, dtype=f16, rows=64, cols=80, v_row=64, v_col=80, blayout=row_major, slayout=col_major, fractal=512, pad=0>
// 1x80 f32 tile in loc=bias, consumed only by pto.tgemv.bias.
%bias = pto.alloc_tile : !pto.tile_buf<loc=bias, dtype=f32, rows=1, cols=80, v_row=1, v_col=80, blayout=row_major, slayout=none_box, fractal=512, pad=0>
// All loc=acc tiles share one type: rows=80, cols=80 with v_row=1, v_col=80.
// NOTE(review): presumably rows=80 is the padded/aligned allocation and
// v_row=1 the valid GEMV row count, as the test name suggests -- confirm.
%acc_in = pto.alloc_tile : !pto.tile_buf<loc=acc, dtype=f32, rows=80, cols=80, v_row=1, v_col=80, blayout=col_major, slayout=row_major, fractal=1024, pad=0>
%dst0 = pto.alloc_tile : !pto.tile_buf<loc=acc, dtype=f32, rows=80, cols=80, v_row=1, v_col=80, blayout=col_major, slayout=row_major, fractal=1024, pad=0>
%dst1 = pto.alloc_tile : !pto.tile_buf<loc=acc, dtype=f32, rows=80, cols=80, v_row=1, v_col=80, blayout=col_major, slayout=row_major, fractal=1024, pad=0>
%dst2 = pto.alloc_tile : !pto.tile_buf<loc=acc, dtype=f32, rows=80, cols=80, v_row=1, v_col=80, blayout=col_major, slayout=row_major, fractal=1024, pad=0>

// Plain variant: ins(%lhs, %rhs), outs(%dst0).
pto.tgemv ins(%lhs, %rhs : !pto.tile_buf<loc=left, dtype=f16, rows=1, cols=64, v_row=1, v_col=64, blayout=col_major, slayout=row_major, fractal=512, pad=0>, !pto.tile_buf<loc=right, dtype=f16, rows=64, cols=80, v_row=64, v_col=80, blayout=row_major, slayout=col_major, fractal=512, pad=0>) outs(%dst0 : !pto.tile_buf<loc=acc, dtype=f32, rows=80, cols=80, v_row=1, v_col=80, blayout=col_major, slayout=row_major, fractal=1024, pad=0>)
// Accumulating variant: %acc_in is passed as the first input, outs(%dst1).
pto.tgemv.acc ins(%acc_in, %lhs, %rhs : !pto.tile_buf<loc=acc, dtype=f32, rows=80, cols=80, v_row=1, v_col=80, blayout=col_major, slayout=row_major, fractal=1024, pad=0>, !pto.tile_buf<loc=left, dtype=f16, rows=1, cols=64, v_row=1, v_col=64, blayout=col_major, slayout=row_major, fractal=512, pad=0>, !pto.tile_buf<loc=right, dtype=f16, rows=64, cols=80, v_row=64, v_col=80, blayout=row_major, slayout=col_major, fractal=512, pad=0>) outs(%dst1 : !pto.tile_buf<loc=acc, dtype=f32, rows=80, cols=80, v_row=1, v_col=80, blayout=col_major, slayout=row_major, fractal=1024, pad=0>)
// Bias variant: %bias is passed as the third input, outs(%dst2).
pto.tgemv.bias ins(%lhs, %rhs, %bias : !pto.tile_buf<loc=left, dtype=f16, rows=1, cols=64, v_row=1, v_col=64, blayout=col_major, slayout=row_major, fractal=512, pad=0>, !pto.tile_buf<loc=right, dtype=f16, rows=64, cols=80, v_row=64, v_col=80, blayout=row_major, slayout=col_major, fractal=512, pad=0>, !pto.tile_buf<loc=bias, dtype=f32, rows=1, cols=80, v_row=1, v_col=80, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%dst2 : !pto.tile_buf<loc=acc, dtype=f32, rows=80, cols=80, v_row=1, v_col=80, blayout=col_major, slayout=row_major, fractal=1024, pad=0>)

return
}
}

// CHECK-LABEL: __global__ AICORE void tgemv_a5_aligned_acc_rows()
// CHECK: TGEMV(
// CHECK: TGEMV_ACC(
// CHECK: TGEMV_BIAS(
Loading