Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion lib/PTO/IR/PTO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ static LogicalResult verifyMatTileOperandsA2A3(Operation *op, Type lhsTy,
Type rhsTy, Type dstTy);
static LogicalResult verifyMatTileOperandsA5(Operation *op, Type lhsTy,
Type rhsTy, Type dstTy);
static LogicalResult verifyMadTileLayoutsA5(Operation *op, Type lhsTy,
Type rhsTy, Type dstTy);
static LogicalResult verifyGemvTileOperands(Operation *op, Type lhsTy, Type rhsTy,
Type dstTy);
static LogicalResult verifyGemvTileOperandsA2A3(Operation *op, Type lhsTy,
Expand Down Expand Up @@ -3131,6 +3133,11 @@ static LogicalResult verifyMatTileOperandsA5(Operation *op, Type lhsTy,
if (failed(verifyMatTileOperandsA2A3(op, lhsTy, rhsTy, dstTy)))
return failure();

return verifyMadTileLayoutsA5(op, lhsTy, rhsTy, dstTy);
}

static LogicalResult verifyMadTileLayoutsA5(Operation *op, Type lhsTy,
Type rhsTy, Type dstTy) {
Comment on lines +3139 to +3140
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation of verifyMadTileLayoutsA5 (refactored from verifyMatTileOperandsA5) uses an "all-or-nothing" check for TileBufType operands. If any of the operands is not a TileBufType (e.g., a MemRefType during intermediate lowering stages), all layout checks are skipped for the remaining operands. It would be more robust to check each operand independently if it is a TileBufType.

auto lhsTb = mlir::dyn_cast<pto::TileBufType>(lhsTy);
auto rhsTb = mlir::dyn_cast<pto::TileBufType>(rhsTy);
auto dstTb = mlir::dyn_cast<pto::TileBufType>(dstTy);
Expand Down Expand Up @@ -3204,7 +3211,7 @@ static LogicalResult verifyGemvTileOperandsA5(Operation *op, Type lhsTy,
Type rhsTy, Type dstTy) {
if (failed(verifyGemvTileOperandsA2A3(op, lhsTy, rhsTy, dstTy)))
return failure();
return verifyMatTileOperandsA5(op, lhsTy, rhsTy, dstTy);
return verifyMadTileLayoutsA5(op, lhsTy, rhsTy, dstTy);
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P2] Reinstate GEMV K/N bound checks on A5

Changing A5 GEMV verification to call only verifyMadTileLayoutsA5 drops the [1, 4095] valid-size guard that previously came from verifyMatTileOperandsA2A3 via verifyMatTileOperandsA5. As a result, pto.tgemv* now accepts static valid_shape values such as K=5000 or N=5000 on A5, even though matmul-family verification still treats those sizes as out of range; this can allow invalid TGEMV dimensions through verification and fail later in lowering/codegen.

Useful? React with 👍 / 👎.

}

static LogicalResult verifyGemvTileOperands(Operation *op, Type lhsTy, Type rhsTy,
Expand Down
24 changes: 24 additions & 0 deletions test/lit/pto/tgemv_a5_aligned_acc_rows.pto
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: ptoas --pto-arch=a5 %s | FileCheck %s

// Runs the three GEMV op forms used below (pto.tgemv, pto.tgemv.acc,
// pto.tgemv.bias) through the tool with the a5 arch selected by the run line.
// All accumulator-located tiles declare rows=80 while v_row=1.
// NOTE(review): v_row/v_col presumably denote the valid (used) sub-shape and
// rows=80 an aligned/padded allocation, as the test name suggests -- confirm
// against the tile_buf type definition.
module {
func.func @tgemv_a5_aligned_acc_rows() {
// Left operand: 1x64 f16 tile in loc=left.
%lhs = pto.alloc_tile : !pto.tile_buf<loc=left, dtype=f16, rows=1, cols=64, v_row=1, v_col=64, blayout=col_major, slayout=row_major, fractal=512, pad=0>
// Right operand: 64x80 f16 tile in loc=right.
%rhs = pto.alloc_tile : !pto.tile_buf<loc=right, dtype=f16, rows=64, cols=80, v_row=64, v_col=80, blayout=row_major, slayout=col_major, fractal=512, pad=0>
// Bias operand: 1x80 f32 tile in loc=bias; only the pto.tgemv.bias op below reads it.
%bias = pto.alloc_tile : !pto.tile_buf<loc=bias, dtype=f32, rows=1, cols=80, v_row=1, v_col=80, blayout=row_major, slayout=none_box, fractal=512, pad=0>
// Accumulator input and the three destinations all share one type:
// f32 tiles in loc=acc with rows=80, cols=80, v_row=1, v_col=80.
%acc_in = pto.alloc_tile : !pto.tile_buf<loc=acc, dtype=f32, rows=80, cols=80, v_row=1, v_col=80, blayout=col_major, slayout=row_major, fractal=1024, pad=0>
%dst0 = pto.alloc_tile : !pto.tile_buf<loc=acc, dtype=f32, rows=80, cols=80, v_row=1, v_col=80, blayout=col_major, slayout=row_major, fractal=1024, pad=0>
%dst1 = pto.alloc_tile : !pto.tile_buf<loc=acc, dtype=f32, rows=80, cols=80, v_row=1, v_col=80, blayout=col_major, slayout=row_major, fractal=1024, pad=0>
%dst2 = pto.alloc_tile : !pto.tile_buf<loc=acc, dtype=f32, rows=80, cols=80, v_row=1, v_col=80, blayout=col_major, slayout=row_major, fractal=1024, pad=0>

// Plain form: inputs (%lhs, %rhs), output %dst0.
pto.tgemv ins(%lhs, %rhs : !pto.tile_buf<loc=left, dtype=f16, rows=1, cols=64, v_row=1, v_col=64, blayout=col_major, slayout=row_major, fractal=512, pad=0>, !pto.tile_buf<loc=right, dtype=f16, rows=64, cols=80, v_row=64, v_col=80, blayout=row_major, slayout=col_major, fractal=512, pad=0>) outs(%dst0 : !pto.tile_buf<loc=acc, dtype=f32, rows=80, cols=80, v_row=1, v_col=80, blayout=col_major, slayout=row_major, fractal=1024, pad=0>)
// Accumulating form: %acc_in is passed as an extra leading input; output %dst1.
pto.tgemv.acc ins(%acc_in, %lhs, %rhs : !pto.tile_buf<loc=acc, dtype=f32, rows=80, cols=80, v_row=1, v_col=80, blayout=col_major, slayout=row_major, fractal=1024, pad=0>, !pto.tile_buf<loc=left, dtype=f16, rows=1, cols=64, v_row=1, v_col=64, blayout=col_major, slayout=row_major, fractal=512, pad=0>, !pto.tile_buf<loc=right, dtype=f16, rows=64, cols=80, v_row=64, v_col=80, blayout=row_major, slayout=col_major, fractal=512, pad=0>) outs(%dst1 : !pto.tile_buf<loc=acc, dtype=f32, rows=80, cols=80, v_row=1, v_col=80, blayout=col_major, slayout=row_major, fractal=1024, pad=0>)
// Bias form: %bias is passed as an extra trailing input; output %dst2.
pto.tgemv.bias ins(%lhs, %rhs, %bias : !pto.tile_buf<loc=left, dtype=f16, rows=1, cols=64, v_row=1, v_col=64, blayout=col_major, slayout=row_major, fractal=512, pad=0>, !pto.tile_buf<loc=right, dtype=f16, rows=64, cols=80, v_row=64, v_col=80, blayout=row_major, slayout=col_major, fractal=512, pad=0>, !pto.tile_buf<loc=bias, dtype=f32, rows=1, cols=80, v_row=1, v_col=80, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%dst2 : !pto.tile_buf<loc=acc, dtype=f32, rows=80, cols=80, v_row=1, v_col=80, blayout=col_major, slayout=row_major, fractal=1024, pad=0>)

return
}
}

// CHECK-LABEL: __global__ AICORE void tgemv_a5_aligned_acc_rows()
// CHECK: TGEMV(
// CHECK: TGEMV_ACC(
// CHECK: TGEMV_BIAS(
Loading