Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ jobs:
SKIP_CASES: ${{ github.event.inputs.skip_cases || '' }}
RUN_ONLY_CASES: ${{ github.event.inputs.run_only_cases || '' }}
PTO_ISA_REPO: ${{ github.event.inputs.pto_isa_repo || 'https://gitcode.com/cann/pto-isa.git' }}
PTO_ISA_COMMIT: ${{ github.event.inputs.pto_isa_commit || '662d7f2a916d6bbde3109ce4a16ed5c28f5d900a' }}
PTO_ISA_COMMIT: ${{ github.event.inputs.pto_isa_commit || '5dbf1b2f6b8ea934f03e62367c2f540ece21134e' }}
REMOTE_HOST: ${{ github.event.inputs.remote_host || '101.245.68.6' }}
REMOTE_USER: ${{ github.event.inputs.remote_user || 'zhongxuan' }}
REMOTE_PORT: ${{ github.event.inputs.remote_port || '22' }}
Expand Down
22 changes: 14 additions & 8 deletions docs/PTO_IR_manual.md
Original file line number Diff line number Diff line change
Expand Up @@ -6846,7 +6846,8 @@ pto.tscatter ins(%src, %idx : !pto.tile_buf<...>, !pto.tile_buf<...>)
**Semantics:**

```
dst[i, j] = mem[idx[i, j]]
row mode (default): dst[r, j] = mem[idx[r], j]
elem mode: dst[i, j] = mem[idx[i, j]]
```

**Arguments:**
Expand All @@ -6867,13 +6868,15 @@ dst[i, j] = mem[idx[i, j]]
- `idx` element type must be signless `i32`.

- **Tile / memory roles**
- `dst` and `idx` must be `loc=vec`, `blayout=row_major`, `slayout=none_box`.
- `dst` must be `loc=vec`, `blayout=row_major`, `slayout=none_box`.
- `idx` must be `loc=vec`, `slayout=none_box`. `row_major` and `col_major` are both accepted for row mode.
- `mem` must denote a GlobalTensor in GM memory.
- `mem` must use `ND` layout when layout can be inferred.

- **Shape**
- `dst row == idx row`.
- `idx column == 1` or `idx column == dst column`.
- Element mode: `idx valid_shape == dst valid_shape`.
- Row mode: `idx valid_shape` may be `[1, dst.valid_row]` or `[dst.valid_row, 1]`.
- The `[1, R]` row-mode variant uses `row_major`; the `[R, 1]` row-mode variant uses `col_major`.
- If `mem` is a rank-5 static GM memref, it must satisfy `<1, 1, 1, Rows, RowWidth>`.

- **Out-of-bounds mode**
Expand Down Expand Up @@ -6904,7 +6907,8 @@ pto.mgather ins(%mem, %idx : memref<...>, !pto.tile_buf<...>)
**Semantics:**

```
mem[idx[i, j]] = src[i, j]
row mode (default): mem[idx[r], j] = src[r, j]
elem mode: mem[idx[i, j]] = src[i, j]
```

**Arguments:**
Expand All @@ -6926,13 +6930,15 @@ mem[idx[i, j]] = src[i, j]
- `idx` element type must be signless `i32`.

- **Tile / memory roles**
- `src` and `idx` must be `loc=vec`, `blayout=row_major`, `slayout=none_box`.
- `src` must be `loc=vec`, `blayout=row_major`, `slayout=none_box`.
- `idx` must be `loc=vec`, `slayout=none_box`. `row_major` and `col_major` are both accepted for row mode.
- `mem` must denote a GlobalTensor in GM memory.
- `mem` must use `ND` layout when layout can be inferred.

- **Shape**
- `src row == idx row`.
- `idx column == 1` or `idx column == src column`.
- Element mode: `idx valid_shape == src valid_shape`.
- Row mode: `idx valid_shape` may be `[1, src.valid_row]` or `[src.valid_row, 1]`.
- The `[1, R]` row-mode variant uses `row_major`; the `[R, 1]` row-mode variant uses `col_major`.
- If `mem` is a rank-5 static GM memref, it must satisfy `<1, 1, 1, Rows, RowWidth>`.

- **Atomic modes**
Expand Down
52 changes: 50 additions & 2 deletions include/PTO/IR/PTOOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2504,7 +2504,31 @@ def MGatherOp : PTO_TOp<"mgather", [

let extraClassDeclaration = [{
static StringRef getIntrinsicName() { return "MGATHER"; }
::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_MTE2; }
// Selects the pipe for MGATHER: PIPE_V when the enclosing module is tagged as
// an A5 target, PIPE_MTE2 otherwise (including when no parent module exists).
::mlir::pto::PIPE getPipe() {
  const auto defaultPipe = ::mlir::pto::PIPE::PIPE_MTE2;
  const auto a5Pipe = ::mlir::pto::PIPE::PIPE_V;

  auto parentModule = getOperation()->getParentOfType<::mlir::ModuleOp>();
  if (!parentModule)
    return defaultPipe;

  // An explicit "pto.target_arch" of "a5" (any letter case) selects A5.
  auto archAttr =
      parentModule->getAttrOfType<::mlir::StringAttr>("pto.target_arch");
  if (archAttr && archAttr.getValue().equals_insensitive("a5"))
    return a5Pipe;

  // Otherwise fall back to the "pto.device-spec" name prefix.
  auto specAttr =
      parentModule->getAttrOfType<::mlir::StringAttr>("pto.device-spec");
  if (specAttr) {
    auto specName = specAttr.getValue();
    if (specName.starts_with("Ascend950") ||
        specName.starts_with("Ascend910_95"))
      return a5Pipe;
  }

  return defaultPipe;
}
::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); }
}];
}
Expand Down Expand Up @@ -2660,7 +2684,31 @@ def MScatterOp : PTO_TOp<"mscatter", [

let extraClassDeclaration = [{
static StringRef getIntrinsicName() { return "MSCATTER"; }
::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; }
// Selects the pipe for MSCATTER: PIPE_V when the enclosing module is tagged as
// an A5 target, PIPE_MTE3 otherwise (including when no parent module exists).
::mlir::pto::PIPE getPipe() {
  const auto defaultPipe = ::mlir::pto::PIPE::PIPE_MTE3;
  const auto a5Pipe = ::mlir::pto::PIPE::PIPE_V;

  auto parentModule = getOperation()->getParentOfType<::mlir::ModuleOp>();
  if (!parentModule)
    return defaultPipe;

  // An explicit "pto.target_arch" of "a5" (any letter case) selects A5.
  auto archAttr =
      parentModule->getAttrOfType<::mlir::StringAttr>("pto.target_arch");
  if (archAttr && archAttr.getValue().equals_insensitive("a5"))
    return a5Pipe;

  // Otherwise fall back to the "pto.device-spec" name prefix.
  auto specAttr =
      parentModule->getAttrOfType<::mlir::StringAttr>("pto.device-spec");
  if (specAttr) {
    auto specName = specAttr.getValue();
    if (specName.starts_with("Ascend950") ||
        specName.starts_with("Ascend910_95"))
      return a5Pipe;
  }

  return defaultPipe;
}
::mlir::MutableOperandRange getDpsInitsMutable() { return getMemMutable(); }
}];
}
Expand Down
74 changes: 57 additions & 17 deletions lib/PTO/IR/PTO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3019,27 +3019,67 @@ static LogicalResult verifyMGatherMScatterMemOperand(Operation *op,
"expects mem to be !pto.partition_tensor_view or a GM/ZERO memref");
}

static bool hasCompatibleKnownExtent(int64_t lhs, int64_t rhs);
static bool isKnownUnitExtent(int64_t value);

static LogicalResult verifyMGatherMScatterTileShape(Operation *op, Type dataTy,
Type idxTy,
StringRef dataName) {
auto dataShape = getShapeVec(dataTy);
auto idxShape = getShapeVec(idxTy);
if (dataShape.size() != 2 || idxShape.size() != 2)
return op->emitOpError() << "expects " << dataName
<< " and idx to be rank-2";

if (dataShape[0] != ShapedType::kDynamic &&
idxShape[0] != ShapedType::kDynamic && dataShape[0] != idxShape[0])
auto dataValid = getValidShapeVec(dataTy);
auto idxValid = getValidShapeVec(idxTy);
if (dataValid.size() != 2 || idxValid.size() != 2)
return op->emitOpError() << "expects " << dataName
<< " and idx static row dimensions to match";
<< " and idx to have rank-2 valid_shape";

auto idxTile = dyn_cast<pto::TileBufType>(idxTy);
if (!idxTile)
return op->emitOpError("expects idx to be a tile_buf type");

const bool idxRowMajor =
idxTile.getBLayoutValueI32() ==
static_cast<int32_t>(pto::BLayout::RowMajor);
const bool idxColMajor =
idxTile.getBLayoutValueI32() ==
static_cast<int32_t>(pto::BLayout::ColMajor);

const bool rowCoalesce1xR =
idxRowMajor && isKnownUnitExtent(idxValid[0]) &&
hasCompatibleKnownExtent(idxValid[1], dataValid[0]);
const bool rowCoalesceRx1 =
idxColMajor && hasCompatibleKnownExtent(idxValid[0], dataValid[0]) &&
isKnownUnitExtent(idxValid[1]);
const bool elemCoalesce =
hasCompatibleKnownExtent(idxValid[0], dataValid[0]) &&
hasCompatibleKnownExtent(idxValid[1], dataValid[1]);

if (!(rowCoalesce1xR || rowCoalesceRx1 || elemCoalesce))
return op->emitOpError()
<< "expects idx valid_shape to be [1, " << dataName
<< ".valid_row], [" << dataName
<< ".valid_row, 1], or match " << dataName << " valid_shape";

int64_t dataCols = dataShape[1];
int64_t idxCols = idxShape[1];
if (idxCols != ShapedType::kDynamic && dataCols != ShapedType::kDynamic &&
idxCols != 1 && idxCols != dataCols)
return op->emitOpError() << "expects idx cols to be 1 or equal to "
<< dataName << " cols";
return success();
}

/// Verifies the index tile operand of pto.mgather / pto.mscatter.
///
/// Checks, in order: the common tile_buf constraints, that the tile lives in
/// the vec address space, that it is a tile_buf type, that its blayout is
/// row_major or col_major, and that its slayout is none_box. The first failing
/// check emits an op error mentioning `name` and returns failure.
static LogicalResult verifyMGatherMScatterIdxTile(Operation *op, Type ty,
                                                  StringRef name) {
  // Every rejection shares the "expects <name> <requirement>" shape.
  auto reject = [&](StringRef requirement) -> LogicalResult {
    return op->emitOpError() << "expects " << name << requirement;
  };

  if (failed(verifyTileBufCommon(op, ty, name)))
    return failure();

  auto memorySpace = getPTOMemorySpaceEnum(ty);
  const bool inVec = memorySpace && *memorySpace == pto::AddressSpace::VEC;
  if (!inVec)
    return reject(" to be in the vec address space");

  auto tileTy = dyn_cast<pto::TileBufType>(ty);
  if (!tileTy)
    return reject(" to be a tile_buf type");

  const int32_t blayout = tileTy.getBLayoutValueI32();
  const bool isRowMajor =
      blayout == static_cast<int32_t>(pto::BLayout::RowMajor);
  const bool isColMajor =
      blayout == static_cast<int32_t>(pto::BLayout::ColMajor);
  if (!isRowMajor && !isColMajor)
    return reject(" to use row_major or col_major blayout");

  const int32_t slayout = tileTy.getSLayoutValueI32();
  if (slayout != static_cast<int32_t>(pto::SLayout::NoneBox))
    return reject(" to use the none_box slayout");

  return success();
}

Expand Down Expand Up @@ -6428,7 +6468,7 @@ LogicalResult MScatterOp::verify() {
return emitOpError("expects src, idx, and mem to use supported PTO shapes");

if (failed(verifyNDStyleVecTile(*this, srcTy, "src")) ||
failed(verifyNDStyleVecTile(*this, idxTy, "idx")))
failed(verifyMGatherMScatterIdxTile(getOperation(), idxTy, "idx")))
return failure();

Type srcElem = getElemTy(srcTy);
Expand Down Expand Up @@ -6483,7 +6523,7 @@ LogicalResult MGatherOp::verify() {
return emitOpError("expects mem, idx, and dst to use supported PTO shapes");

if (failed(verifyNDStyleVecTile(*this, dstTy, "dst")) ||
failed(verifyNDStyleVecTile(*this, idxTy, "idx")))
failed(verifyMGatherMScatterIdxTile(getOperation(), idxTy, "idx")))
return failure();

Type dstElem = getElemTy(dstTy);
Expand Down
114 changes: 93 additions & 21 deletions lib/PTO/Transforms/PTOToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,69 @@ static Value maybeWrapGlobalMemrefAsGlobalTensor(
ConversionPatternRewriter &rewriter, Location loc, Value loweredValue,
Type originalType, Operation *anchor);

/// Returns true when two extents may describe the same size: either one is
/// ShapedType::kDynamic, or both values are equal.
static bool hasCompatibleKnownExtentForMGather(int64_t lhs, int64_t rhs) {
  if (lhs == ShapedType::kDynamic)
    return true;
  if (rhs == ShapedType::kDynamic)
    return true;
  return lhs == rhs;
}

/// Returns true when an extent is 1 or may turn out to be 1.
///
/// Despite the "Known" in the name, a ShapedType::kDynamic extent is accepted
/// as well; only a static extent other than 1 is rejected.
static bool isKnownUnitExtentForMGather(int64_t value) {
  if (value == 1)
    return true;
  return value == ShapedType::kDynamic;
}

/// Rank-2 shape plus layout flags for a gather/scatter data or index operand.
struct GatherScatterShapeLayoutInfo {
  /// The two extents of the operand (rows, then columns).
  SmallVector<int64_t, 2> shape;
  /// True when the operand is laid out row-major.
  bool rowMajor{false};
  /// True when the operand is laid out column-major.
  bool colMajor{false};
};

/// Extracts the rank-2 shape and row/column-major flags from `ty`.
///
/// For a pto::TileBufType the shape is the tile's valid shape and the flags
/// come from its blayout. For a MemRefType the shape is the memref shape and
/// the flags are derived from its strides (unit stride in the last dimension
/// means row-major, unit stride in the first means column-major; both may be
/// set at once). Returns std::nullopt for any other type, for shapes that are
/// not rank-2, or when the memref strides cannot be computed.
static std::optional<GatherScatterShapeLayoutInfo>
getGatherScatterShapeLayoutInfo(Type ty) {
  GatherScatterShapeLayoutInfo result;

  // Tile buffer: read the valid shape and the declared blayout.
  if (auto tileBuf = dyn_cast<pto::TileBufType>(ty)) {
    ArrayRef<int64_t> valid = tileBuf.getValidShape();
    if (valid.size() != 2)
      return std::nullopt;

    result.shape.assign(valid.begin(), valid.end());
    const int32_t layout = tileBuf.getBLayoutValueI32();
    result.rowMajor = layout == static_cast<int32_t>(pto::BLayout::RowMajor);
    result.colMajor = layout == static_cast<int32_t>(pto::BLayout::ColMajor);
    return result;
  }

  // Memref: only rank-2 memrefs with computable strides are supported.
  auto memTy = dyn_cast<MemRefType>(ty);
  if (!memTy)
    return std::nullopt;
  if (memTy.getRank() != 2)
    return std::nullopt;

  SmallVector<int64_t, 4> strides;
  int64_t offset = ShapedType::kDynamic;
  if (failed(getStridesAndOffset(memTy, strides, offset)))
    return std::nullopt;
  if (strides.size() != 2)
    return std::nullopt;

  ArrayRef<int64_t> dims = memTy.getShape();
  result.shape.assign(dims.begin(), dims.end());
  result.rowMajor = strides[1] == 1;
  result.colMajor = strides[0] == 1;
  return result;
}

/// Decides whether `idxTy` is a row-coalesced index for data of type `dataTy`.
///
/// Two index forms qualify:
///   - row-major idx of shape [1, R], where R is compatible with the data rows;
///   - column-major idx of shape [R, 1], where R is compatible with the data
///     rows.
/// Returns false when shape/layout info cannot be extracted for either type.
///
/// NOTE(review): the extent helpers accept dynamic extents, so an index whose
/// extents are dynamic is classified as row-coalesced here — confirm this is
/// the intended default for dynamically shaped element-mode indices.
static bool isRowCoalescedMGatherIndexType(Type dataTy, Type idxTy) {
  auto data = getGatherScatterShapeLayoutInfo(dataTy);
  auto idx = getGatherScatterShapeLayoutInfo(idxTy);
  if (!data || !idx)
    return false;

  const int64_t dataRows = data->shape[0];
  const int64_t idxRows = idx->shape[0];
  const int64_t idxCols = idx->shape[1];

  // Row-major [1, R] index.
  if (idx->rowMajor && isKnownUnitExtentForMGather(idxRows) &&
      hasCompatibleKnownExtentForMGather(idxCols, dataRows))
    return true;

  // Column-major [R, 1] index.
  return idx->colMajor &&
         hasCompatibleKnownExtentForMGather(idxRows, dataRows) &&
         isKnownUnitExtentForMGather(idxCols);
}

static std::optional<mlir::pto::Layout> getLayoutAttrFromOp(Operation *op) {
if (!op)
return std::nullopt;
Expand Down Expand Up @@ -2577,24 +2640,30 @@ struct PTOMGatherToMGATHER : public OpConversionPattern<pto::MGatherOp> {
Value memArg = maybeWrapGlobalMemrefAsGlobalTensor(
rewriter, op.getLoc(), mem, op.getMem().getType(), op.getOperation());

ArrayAttr templateArgs;
auto gatherOobTok = [&](pto::GatherOOB mode) -> StringRef {
switch (mode) {
case pto::GatherOOB::Undefined:
return "pto::GatherOOB::Undefined";
case pto::GatherOOB::Clamp:
return "pto::GatherOOB::Clamp";
case pto::GatherOOB::Wrap:
return "pto::GatherOOB::Wrap";
case pto::GatherOOB::Zero:
return "pto::GatherOOB::Zero";
}
llvm_unreachable("unknown GatherOOB");
};

SmallVector<Attribute, 2> templateArgVec;
const bool rowCoalesce =
isRowCoalescedMGatherIndexType(op.getDst().getType(), op.getIdx().getType());
templateArgVec.push_back(emitc::OpaqueAttr::get(
ctx, rowCoalesce ? "pto::Coalesce::Row" : "pto::Coalesce::Elem"));
if (op.getGatherOob() != pto::GatherOOB::Undefined) {
auto gatherOobTok = [&](pto::GatherOOB mode) -> StringRef {
switch (mode) {
case pto::GatherOOB::Undefined:
return "pto::GatherOOB::Undefined";
case pto::GatherOOB::Clamp:
return "pto::GatherOOB::Clamp";
case pto::GatherOOB::Wrap:
return "pto::GatherOOB::Wrap";
case pto::GatherOOB::Zero:
return "pto::GatherOOB::Zero";
}
llvm_unreachable("unknown GatherOOB");
};
templateArgs = rewriter.getArrayAttr(
{emitc::OpaqueAttr::get(ctx, gatherOobTok(op.getGatherOob()))});
templateArgVec.push_back(
emitc::OpaqueAttr::get(ctx, gatherOobTok(op.getGatherOob())));
}
ArrayAttr templateArgs = rewriter.getArrayAttr(templateArgVec);

rewriter.create<emitc::CallOpaqueOp>(
op.getLoc(), TypeRange{}, "MGATHER",
Expand Down Expand Up @@ -5233,17 +5302,20 @@ struct PTOMScatterToMSCATTER : public OpConversionPattern<pto::MScatterOp> {
llvm_unreachable("unknown ScatterOOB");
};

SmallVector<Attribute, 2> templateArgVec;
SmallVector<Attribute, 3> templateArgVec;
const bool rowCoalesce =
isRowCoalescedMGatherIndexType(op.getSrc().getType(), op.getIdx().getType());
templateArgVec.push_back(emitc::OpaqueAttr::get(
ctx, rowCoalesce ? "pto::Coalesce::Row" : "pto::Coalesce::Elem"));
if (op.getScatterAtomicOp() != pto::ScatterAtomicOp::None ||
op.getScatterOob() != pto::ScatterOOB::Undefined) {
templateArgVec.push_back(
emitc::OpaqueAttr::get(ctx, scatterAtomicTok(op.getScatterAtomicOp())));
templateArgVec.push_back(emitc::OpaqueAttr::get(
ctx, scatterAtomicTok(op.getScatterAtomicOp())));
if (op.getScatterOob() != pto::ScatterOOB::Undefined)
templateArgVec.push_back(
emitc::OpaqueAttr::get(ctx, scatterOobTok(op.getScatterOob())));
}
ArrayAttr templateArgs =
templateArgVec.empty() ? ArrayAttr{} : rewriter.getArrayAttr(templateArgVec);
ArrayAttr templateArgs = rewriter.getArrayAttr(templateArgVec);

rewriter.create<emitc::CallOpaqueOp>(
op.getLoc(), TypeRange{}, "MSCATTER",
Expand Down
7 changes: 5 additions & 2 deletions lib/PTO/Transforms/PTOViewToMemref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3332,7 +3332,8 @@ struct PTOViewToMemrefPass
TypeRange{},
mem,
idx,
dst);
dst,
op.getGatherOobAttr());
}

SmallVector<mlir::pto::MScatterOp, 8> mascatterops;
Expand Down Expand Up @@ -3360,7 +3361,9 @@ struct PTOViewToMemrefPass
TypeRange{},
src,
idx,
mem);
mem,
op.getScatterAtomicOpAttr(),
op.getScatterOobAttr());
}
SmallVector<mlir::pto::TPrintOp, 8> printops;
func.walk([&](mlir::pto::TPrintOp op) { printops.push_back(op); });
Expand Down
Loading
Loading