-
Notifications
You must be signed in to change notification settings - Fork 50
Fix A5 GEMV verifier for aligned ACC rows #651
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -137,6 +137,8 @@ static LogicalResult verifyMatTileOperandsA2A3(Operation *op, Type lhsTy, | |
| Type rhsTy, Type dstTy); | ||
| static LogicalResult verifyMatTileOperandsA5(Operation *op, Type lhsTy, | ||
| Type rhsTy, Type dstTy); | ||
| static LogicalResult verifyMadTileLayoutsA5(Operation *op, Type lhsTy, | ||
| Type rhsTy, Type dstTy); | ||
| static LogicalResult verifyGemvTileOperands(Operation *op, Type lhsTy, Type rhsTy, | ||
| Type dstTy); | ||
| static LogicalResult verifyGemvTileOperandsA2A3(Operation *op, Type lhsTy, | ||
|
|
@@ -3131,6 +3133,11 @@ static LogicalResult verifyMatTileOperandsA5(Operation *op, Type lhsTy, | |
| if (failed(verifyMatTileOperandsA2A3(op, lhsTy, rhsTy, dstTy))) | ||
| return failure(); | ||
|
|
||
| return verifyMadTileLayoutsA5(op, lhsTy, rhsTy, dstTy); | ||
| } | ||
|
|
||
| static LogicalResult verifyMadTileLayoutsA5(Operation *op, Type lhsTy, | ||
| Type rhsTy, Type dstTy) { | ||
| auto lhsTb = mlir::dyn_cast<pto::TileBufType>(lhsTy); | ||
| auto rhsTb = mlir::dyn_cast<pto::TileBufType>(rhsTy); | ||
| auto dstTb = mlir::dyn_cast<pto::TileBufType>(dstTy); | ||
|
|
@@ -3204,7 +3211,7 @@ static LogicalResult verifyGemvTileOperandsA5(Operation *op, Type lhsTy, | |
| Type rhsTy, Type dstTy) { | ||
| if (failed(verifyGemvTileOperandsA2A3(op, lhsTy, rhsTy, dstTy))) | ||
| return failure(); | ||
| return verifyMatTileOperandsA5(op, lhsTy, rhsTy, dstTy); | ||
| return verifyMadTileLayoutsA5(op, lhsTy, rhsTy, dstTy); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Changing A5 GEMV verification to call only `verifyMadTileLayoutsA5` … Useful? React with 👍 / 👎. |
||
| } | ||
|
|
||
| static LogicalResult verifyGemvTileOperands(Operation *op, Type lhsTy, Type rhsTy, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| // RUN: ptoas --pto-arch=a5 %s | FileCheck %s | ||
|
|
||
| module { | ||
| func.func @tgemv_a5_aligned_acc_rows() { | ||
| %lhs = pto.alloc_tile : !pto.tile_buf<loc=left, dtype=f16, rows=1, cols=64, v_row=1, v_col=64, blayout=col_major, slayout=row_major, fractal=512, pad=0> | ||
| %rhs = pto.alloc_tile : !pto.tile_buf<loc=right, dtype=f16, rows=64, cols=80, v_row=64, v_col=80, blayout=row_major, slayout=col_major, fractal=512, pad=0> | ||
| %bias = pto.alloc_tile : !pto.tile_buf<loc=bias, dtype=f32, rows=1, cols=80, v_row=1, v_col=80, blayout=row_major, slayout=none_box, fractal=512, pad=0> | ||
| %acc_in = pto.alloc_tile : !pto.tile_buf<loc=acc, dtype=f32, rows=80, cols=80, v_row=1, v_col=80, blayout=col_major, slayout=row_major, fractal=1024, pad=0> | ||
| %dst0 = pto.alloc_tile : !pto.tile_buf<loc=acc, dtype=f32, rows=80, cols=80, v_row=1, v_col=80, blayout=col_major, slayout=row_major, fractal=1024, pad=0> | ||
| %dst1 = pto.alloc_tile : !pto.tile_buf<loc=acc, dtype=f32, rows=80, cols=80, v_row=1, v_col=80, blayout=col_major, slayout=row_major, fractal=1024, pad=0> | ||
| %dst2 = pto.alloc_tile : !pto.tile_buf<loc=acc, dtype=f32, rows=80, cols=80, v_row=1, v_col=80, blayout=col_major, slayout=row_major, fractal=1024, pad=0> | ||
|
|
||
| pto.tgemv ins(%lhs, %rhs : !pto.tile_buf<loc=left, dtype=f16, rows=1, cols=64, v_row=1, v_col=64, blayout=col_major, slayout=row_major, fractal=512, pad=0>, !pto.tile_buf<loc=right, dtype=f16, rows=64, cols=80, v_row=64, v_col=80, blayout=row_major, slayout=col_major, fractal=512, pad=0>) outs(%dst0 : !pto.tile_buf<loc=acc, dtype=f32, rows=80, cols=80, v_row=1, v_col=80, blayout=col_major, slayout=row_major, fractal=1024, pad=0>) | ||
| pto.tgemv.acc ins(%acc_in, %lhs, %rhs : !pto.tile_buf<loc=acc, dtype=f32, rows=80, cols=80, v_row=1, v_col=80, blayout=col_major, slayout=row_major, fractal=1024, pad=0>, !pto.tile_buf<loc=left, dtype=f16, rows=1, cols=64, v_row=1, v_col=64, blayout=col_major, slayout=row_major, fractal=512, pad=0>, !pto.tile_buf<loc=right, dtype=f16, rows=64, cols=80, v_row=64, v_col=80, blayout=row_major, slayout=col_major, fractal=512, pad=0>) outs(%dst1 : !pto.tile_buf<loc=acc, dtype=f32, rows=80, cols=80, v_row=1, v_col=80, blayout=col_major, slayout=row_major, fractal=1024, pad=0>) | ||
| pto.tgemv.bias ins(%lhs, %rhs, %bias : !pto.tile_buf<loc=left, dtype=f16, rows=1, cols=64, v_row=1, v_col=64, blayout=col_major, slayout=row_major, fractal=512, pad=0>, !pto.tile_buf<loc=right, dtype=f16, rows=64, cols=80, v_row=64, v_col=80, blayout=row_major, slayout=col_major, fractal=512, pad=0>, !pto.tile_buf<loc=bias, dtype=f32, rows=1, cols=80, v_row=1, v_col=80, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%dst2 : !pto.tile_buf<loc=acc, dtype=f32, rows=80, cols=80, v_row=1, v_col=80, blayout=col_major, slayout=row_major, fractal=1024, pad=0>) | ||
|
|
||
| return | ||
| } | ||
| } | ||
|
|
||
| // CHECK-LABEL: __global__ AICORE void tgemv_a5_aligned_acc_rows() | ||
| // CHECK: TGEMV( | ||
| // CHECK: TGEMV_ACC( | ||
| // CHECK: TGEMV_BIAS( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current implementation of
`verifyMadTileLayoutsA5` (refactored from `verifyMatTileOperandsA5`) uses an "all-or-nothing" check for `TileBufType` operands. If any of the operands is not a `TileBufType` (e.g., a `MemRefType` during intermediate lowering stages), all layout checks are skipped for the remaining operands. It would be more robust to check each operand independently if it is a `TileBufType`.