diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2b020193b..4064a2c92 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -295,7 +295,7 @@ jobs: SKIP_CASES: ${{ github.event.inputs.skip_cases || '' }} RUN_ONLY_CASES: ${{ github.event.inputs.run_only_cases || '' }} PTO_ISA_REPO: ${{ github.event.inputs.pto_isa_repo || 'https://gitcode.com/cann/pto-isa.git' }} - PTO_ISA_COMMIT: ${{ github.event.inputs.pto_isa_commit || '662d7f2a916d6bbde3109ce4a16ed5c28f5d900a' }} + PTO_ISA_COMMIT: ${{ github.event.inputs.pto_isa_commit || '5dbf1b2f6b8ea934f03e62367c2f540ece21134e' }} REMOTE_HOST: ${{ github.event.inputs.remote_host || '101.245.68.6' }} REMOTE_USER: ${{ github.event.inputs.remote_user || 'zhongxuan' }} REMOTE_PORT: ${{ github.event.inputs.remote_port || '22' }} diff --git a/docs/PTO_IR_manual.md b/docs/PTO_IR_manual.md index be8148d90..2a1b97709 100644 --- a/docs/PTO_IR_manual.md +++ b/docs/PTO_IR_manual.md @@ -6846,7 +6846,8 @@ pto.tscatter ins(%src, %idx : !pto.tile_buf<...>, !pto.tile_buf<...>) **Semantics:** ``` -dst[i, j] = mem[idx[i, j]] +row mode (default): dst[r, j] = mem[idx[r], j] +elem mode: dst[i, j] = mem[idx[i, j]] ``` **Arguments:** @@ -6867,13 +6868,15 @@ dst[i, j] = mem[idx[i, j]] - `idx` element type must be signless `i32`. - **Tile / memory roles** - - `dst` and `idx` must be `loc=vec`, `blayout=row_major`, `slayout=none_box`. + - `dst` must be `loc=vec`, `blayout=row_major`, `slayout=none_box`. + - `idx` must be `loc=vec`, `slayout=none_box`. `row_major` and `col_major` are both accepted for row mode. - `mem` must denote a GlobalTensor in GM memory. - `mem` must use `ND` layout when layout can be inferred. - **Shape** - - `dst row == idx row`. - - `idx column == 1` or `idx column == dst column`. + - Element mode: `idx valid_shape == dst valid_shape`. + - Row mode: `idx valid_shape` may be `[1, dst.valid_row]` or `[dst.valid_row, 1]`. + - The `[1, R]` row-mode variant uses `row_major`; the `[R, 1]` row-mode variant uses `col_major`. - If `mem` is a rank-5 static GM memref, it must satisfy `<1, 1, 1, Rows, RowWidth>`. - **Out-of-bounds mode** @@ -6904,7 +6907,8 @@ pto.mgather ins(%mem, %idx : memref<...>, !pto.tile_buf<...>) **Semantics:** ``` -mem[idx[i, j]] = src[i, j] +row mode (default): mem[idx[r], j] = src[r, j] +elem mode: mem[idx[i, j]] = src[i, j] ``` **Arguments:** @@ -6926,13 +6930,15 @@ mem[idx[i, j]] = src[i, j] - `idx` element type must be signless `i32`. - **Tile / memory roles** - - `src` and `idx` must be `loc=vec`, `blayout=row_major`, `slayout=none_box`. + - `src` must be `loc=vec`, `blayout=row_major`, `slayout=none_box`. + - `idx` must be `loc=vec`, `slayout=none_box`. `row_major` and `col_major` are both accepted for row mode. - `mem` must denote a GlobalTensor in GM memory. - `mem` must use `ND` layout when layout can be inferred. - **Shape** - - `src row == idx row`. - - `idx column == 1` or `idx column == src column`. + - Element mode: `idx valid_shape == src valid_shape`. + - Row mode: `idx valid_shape` may be `[1, src.valid_row]` or `[src.valid_row, 1]`. + - The `[1, R]` row-mode variant uses `row_major`; the `[R, 1]` row-mode variant uses `col_major`. - If `mem` is a rank-5 static GM memref, it must satisfy `<1, 1, 1, Rows, RowWidth>`. - **Atomic modes** diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index fd42d8736..0cea9ab48 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -2504,7 +2504,31 @@ def MGatherOp : PTO_TOp<"mgather", [ let extraClassDeclaration = [{ static StringRef getIntrinsicName() { return "MGATHER"; } - ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_MTE2; } + ::mlir::pto::PIPE getPipe() { + auto isA5Target = [&]() -> bool { + auto moduleOp = getOperation()->getParentOfType<::mlir::ModuleOp>(); + if (!moduleOp) + return false; + + if (auto arch = + moduleOp->getAttrOfType<::mlir::StringAttr>("pto.target_arch")) { + if (arch.getValue().equals_insensitive("a5")) + return true; + } + + if (auto spec = + moduleOp->getAttrOfType<::mlir::StringAttr>("pto.device-spec")) { + auto s = spec.getValue(); + if (s.starts_with("Ascend950") || s.starts_with("Ascend910_95")) + return true; + } + + return false; + }; + + return isA5Target() ? ::mlir::pto::PIPE::PIPE_V + : ::mlir::pto::PIPE::PIPE_MTE2; + } ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } }]; } @@ -2660,7 +2684,31 @@ def MScatterOp : PTO_TOp<"mscatter", [ let extraClassDeclaration = [{ static StringRef getIntrinsicName() { return "MSCATTER"; } - ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; } + ::mlir::pto::PIPE getPipe() { + auto isA5Target = [&]() -> bool { + auto moduleOp = getOperation()->getParentOfType<::mlir::ModuleOp>(); + if (!moduleOp) + return false; + + if (auto arch = + moduleOp->getAttrOfType<::mlir::StringAttr>("pto.target_arch")) { + if (arch.getValue().equals_insensitive("a5")) + return true; + } + + if (auto spec = + moduleOp->getAttrOfType<::mlir::StringAttr>("pto.device-spec")) { + auto s = spec.getValue(); + if (s.starts_with("Ascend950") || s.starts_with("Ascend910_95")) + return true; + } + + return false; + }; + + return isA5Target() ? ::mlir::pto::PIPE::PIPE_V + : ::mlir::pto::PIPE::PIPE_MTE3; + } ::mlir::MutableOperandRange getDpsInitsMutable() { return getMemMutable(); } }]; } diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 0b5f6cbcd..1c12ef331 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -3019,27 +3019,67 @@ static LogicalResult verifyMGatherMScatterMemOperand(Operation *op, "expects mem to be !pto.partition_tensor_view or a GM/ZERO memref"); } +static bool hasCompatibleKnownExtent(int64_t lhs, int64_t rhs); +static bool isKnownUnitExtent(int64_t value); + static LogicalResult verifyMGatherMScatterTileShape(Operation *op, Type dataTy, Type idxTy, StringRef dataName) { - auto dataShape = getShapeVec(dataTy); - auto idxShape = getShapeVec(idxTy); - if (dataShape.size() != 2 || idxShape.size() != 2) - return op->emitOpError() << "expects " << dataName - << " and idx to be rank-2"; - - if (dataShape[0] != ShapedType::kDynamic && - idxShape[0] != ShapedType::kDynamic && dataShape[0] != idxShape[0]) + auto dataValid = getValidShapeVec(dataTy); + auto idxValid = getValidShapeVec(idxTy); + if (dataValid.size() != 2 || idxValid.size() != 2) return op->emitOpError() << "expects " << dataName - << " and idx static row dimensions to match"; + << " and idx to have rank-2 valid_shape"; + + auto idxTile = dyn_cast(idxTy); + if (!idxTile) + return op->emitOpError("expects idx to be a tile_buf type"); + + const bool idxRowMajor = + idxTile.getBLayoutValueI32() == + static_cast(pto::BLayout::RowMajor); + const bool idxColMajor = + idxTile.getBLayoutValueI32() == + static_cast(pto::BLayout::ColMajor); + + const bool rowCoalesce1xR = + idxRowMajor && isKnownUnitExtent(idxValid[0]) && + hasCompatibleKnownExtent(idxValid[1], dataValid[0]); + const bool rowCoalesceRx1 = + idxColMajor && hasCompatibleKnownExtent(idxValid[0], dataValid[0]) && + isKnownUnitExtent(idxValid[1]); + const bool elemCoalesce = + hasCompatibleKnownExtent(idxValid[0], dataValid[0]) && + hasCompatibleKnownExtent(idxValid[1], dataValid[1]); + + if (!(rowCoalesce1xR || rowCoalesceRx1 || elemCoalesce)) + return op->emitOpError() + << "expects idx valid_shape to be [1, " << dataName + << ".valid_row], [" << dataName + << ".valid_row, 1], or match " << dataName << " valid_shape"; - int64_t dataCols = dataShape[1]; - int64_t idxCols = idxShape[1]; - if (idxCols != ShapedType::kDynamic && dataCols != ShapedType::kDynamic && - idxCols != 1 && idxCols != dataCols) - return op->emitOpError() << "expects idx cols to be 1 or equal to " - << dataName << " cols"; + return success(); +} +static LogicalResult verifyMGatherMScatterIdxTile(Operation *op, Type ty, + StringRef name) { + if (failed(verifyTileBufCommon(op, ty, name))) + return failure(); + auto as = getPTOMemorySpaceEnum(ty); + if (!as || *as != pto::AddressSpace::VEC) + return op->emitOpError() << "expects " << name + << " to be in the vec address space"; + auto tb = dyn_cast(ty); + if (!tb) + return op->emitOpError() << "expects " << name << " to be a tile_buf type"; + int32_t blayout = tb.getBLayoutValueI32(); + if (blayout != static_cast(pto::BLayout::RowMajor) && + blayout != static_cast(pto::BLayout::ColMajor)) + return op->emitOpError() << "expects " << name + << " to use row_major or col_major blayout"; + if (tb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) + return op->emitOpError() << "expects " << name + << " to use the none_box slayout"; return success(); } @@ -6428,7 +6468,7 @@ LogicalResult MScatterOp::verify() { return emitOpError("expects src, idx, and mem to use supported PTO shapes"); if (failed(verifyNDStyleVecTile(*this, srcTy, "src")) || - failed(verifyNDStyleVecTile(*this, idxTy, "idx"))) + failed(verifyMGatherMScatterIdxTile(getOperation(), idxTy, "idx"))) return failure(); Type srcElem = getElemTy(srcTy); @@ -6483,7 +6523,7 @@ LogicalResult MGatherOp::verify() { return emitOpError("expects mem, idx, and dst to use supported PTO shapes"); if (failed(verifyNDStyleVecTile(*this, dstTy, "dst")) || - failed(verifyNDStyleVecTile(*this, idxTy, "idx"))) + failed(verifyMGatherMScatterIdxTile(getOperation(), idxTy, "idx"))) return failure(); Type dstElem = getElemTy(dstTy); diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index 2a8f2c122..632e66760 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -140,6 +140,69 @@ static Value maybeWrapGlobalMemrefAsGlobalTensor( ConversionPatternRewriter &rewriter, Location loc, Value loweredValue, Type originalType, Operation *anchor); +static bool hasCompatibleKnownExtentForMGather(int64_t lhs, int64_t rhs) { + return lhs == ShapedType::kDynamic || rhs == ShapedType::kDynamic || + lhs == rhs; +} + +static bool isKnownUnitExtentForMGather(int64_t value) { + return value == ShapedType::kDynamic || value == 1; +} + +struct GatherScatterShapeLayoutInfo { + SmallVector shape; + bool rowMajor = false; + bool colMajor = false; +}; + +static std::optional +getGatherScatterShapeLayoutInfo(Type ty) { + if (auto tileTy = dyn_cast(ty)) { + ArrayRef validShape = tileTy.getValidShape(); + if (validShape.size() != 2) + return std::nullopt; + + GatherScatterShapeLayoutInfo info; + info.shape.assign(validShape.begin(), validShape.end()); + int32_t blayout = tileTy.getBLayoutValueI32(); + info.rowMajor = blayout == static_cast(pto::BLayout::RowMajor); + info.colMajor = blayout == static_cast(pto::BLayout::ColMajor); + return info; + } + + auto memRefTy = dyn_cast(ty); + if (!memRefTy || memRefTy.getRank() != 2) + return std::nullopt; + + SmallVector strides; + int64_t offset = ShapedType::kDynamic; + if (failed(getStridesAndOffset(memRefTy, strides, offset)) || + strides.size() != 2) + return std::nullopt; + + GatherScatterShapeLayoutInfo info; + info.shape.assign(memRefTy.getShape().begin(), memRefTy.getShape().end()); + info.rowMajor = strides[1] == 1; + info.colMajor = strides[0] == 1; + return info; +} + +static bool isRowCoalescedMGatherIndexType(Type dataTy, Type idxTy) { + auto dataInfo = getGatherScatterShapeLayoutInfo(dataTy); + auto idxInfo = getGatherScatterShapeLayoutInfo(idxTy); + if (!dataInfo || !idxInfo) + return false; + + const bool rowCoalesce1xR = + idxInfo->rowMajor && isKnownUnitExtentForMGather(idxInfo->shape[0]) && + hasCompatibleKnownExtentForMGather(idxInfo->shape[1], dataInfo->shape[0]); + const bool rowCoalesceRx1 = + idxInfo->colMajor && + hasCompatibleKnownExtentForMGather(idxInfo->shape[0], dataInfo->shape[0]) && + isKnownUnitExtentForMGather(idxInfo->shape[1]); + return rowCoalesce1xR || rowCoalesceRx1; +} + static std::optional getLayoutAttrFromOp(Operation *op) { if (!op) return std::nullopt; @@ -2577,24 +2640,30 @@ struct PTOMGatherToMGATHER : public OpConversionPattern { Value memArg = maybeWrapGlobalMemrefAsGlobalTensor( rewriter, op.getLoc(), mem, op.getMem().getType(), op.getOperation()); - ArrayAttr templateArgs; + auto gatherOobTok = [&](pto::GatherOOB mode) -> StringRef { + switch (mode) { + case pto::GatherOOB::Undefined: + return "pto::GatherOOB::Undefined"; + case pto::GatherOOB::Clamp: + return "pto::GatherOOB::Clamp"; + case pto::GatherOOB::Wrap: + return "pto::GatherOOB::Wrap"; + case pto::GatherOOB::Zero: + return "pto::GatherOOB::Zero"; + } + llvm_unreachable("unknown GatherOOB"); + }; + + SmallVector templateArgVec; + const bool rowCoalesce = + isRowCoalescedMGatherIndexType(op.getDst().getType(), op.getIdx().getType()); + templateArgVec.push_back(emitc::OpaqueAttr::get( + ctx, rowCoalesce ? "pto::Coalesce::Row" : "pto::Coalesce::Elem")); if (op.getGatherOob() != pto::GatherOOB::Undefined) { - auto gatherOobTok = [&](pto::GatherOOB mode) -> StringRef { - switch (mode) { - case pto::GatherOOB::Undefined: - return "pto::GatherOOB::Undefined"; - case pto::GatherOOB::Clamp: - return "pto::GatherOOB::Clamp"; - case pto::GatherOOB::Wrap: - return "pto::GatherOOB::Wrap"; - case pto::GatherOOB::Zero: - return "pto::GatherOOB::Zero"; - } - llvm_unreachable("unknown GatherOOB"); - }; - templateArgs = rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(ctx, gatherOobTok(op.getGatherOob()))}); + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, gatherOobTok(op.getGatherOob()))); } + ArrayAttr templateArgs = rewriter.getArrayAttr(templateArgVec); rewriter.create( op.getLoc(), TypeRange{}, "MGATHER", @@ -5233,17 +5302,20 @@ struct PTOMScatterToMSCATTER : public OpConversionPattern { llvm_unreachable("unknown ScatterOOB"); }; - SmallVector templateArgVec; + SmallVector templateArgVec; + const bool rowCoalesce = + isRowCoalescedMGatherIndexType(op.getSrc().getType(), op.getIdx().getType()); + templateArgVec.push_back(emitc::OpaqueAttr::get( + ctx, rowCoalesce ? "pto::Coalesce::Row" : "pto::Coalesce::Elem")); if (op.getScatterAtomicOp() != pto::ScatterAtomicOp::None || op.getScatterOob() != pto::ScatterOOB::Undefined) { - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, scatterAtomicTok(op.getScatterAtomicOp()))); + templateArgVec.push_back(emitc::OpaqueAttr::get( + ctx, scatterAtomicTok(op.getScatterAtomicOp()))); if (op.getScatterOob() != pto::ScatterOOB::Undefined) templateArgVec.push_back( emitc::OpaqueAttr::get(ctx, scatterOobTok(op.getScatterOob()))); } - ArrayAttr templateArgs = - templateArgVec.empty() ? ArrayAttr{} : rewriter.getArrayAttr(templateArgVec); + ArrayAttr templateArgs = rewriter.getArrayAttr(templateArgVec); rewriter.create( op.getLoc(), TypeRange{}, "MSCATTER", diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index f9f8c052f..1df9008c1 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -3332,7 +3332,8 @@ struct PTOViewToMemrefPass TypeRange{}, mem, idx, - dst); + dst, + op.getGatherOobAttr()); } SmallVector mascatterops; @@ -3360,7 +3361,9 @@ struct PTOViewToMemrefPass TypeRange{}, src, idx, - mem); + mem, + op.getScatterAtomicOpAttr(), + op.getScatterOobAttr()); } SmallVector printops; func.walk([&](mlir::pto::TPrintOp op) { printops.push_back(op); }); diff --git a/test/lit/pto/issue664_mscatter_pipe_selection.pto b/test/lit/pto/issue664_mscatter_pipe_selection.pto index 5df05fb25..1aa95e740 100644 --- a/test/lit/pto/issue664_mscatter_pipe_selection.pto +++ b/test/lit/pto/issue664_mscatter_pipe_selection.pto @@ -60,4 +60,4 @@ module attributes {pto.target_arch = "a5"} { // CHECK-NOT: set_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0); // CHECK: wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); // CHECK-NOT: wait_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0); -// CHECK: MSCATTER( +// CHECK: MSCATTER{{(<.*>)?}}( diff --git a/test/lit/pto/issue664_mscatter_pipe_selection_gss.pto b/test/lit/pto/issue664_mscatter_pipe_selection_gss.pto index 16a69269f..94ba43f80 100644 --- a/test/lit/pto/issue664_mscatter_pipe_selection_gss.pto +++ b/test/lit/pto/issue664_mscatter_pipe_selection_gss.pto @@ -60,4 +60,4 @@ module attributes {pto.target_arch = "a5"} { // CHECK-NOT: set_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID{{[0-9]+}}); // CHECK: wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID{{[0-9]+}}); // CHECK-NOT: wait_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID{{[0-9]+}}); -// CHECK: MSCATTER( +// CHECK: MSCATTER{{(<.*>)?}}( diff --git a/test/npu_validation/scripts/generate_testcase.py b/test/npu_validation/scripts/generate_testcase.py index 1a164898c..644e567f9 100644 --- a/test/npu_validation/scripts/generate_testcase.py +++ b/test/npu_validation/scripts/generate_testcase.py @@ -1861,6 +1861,8 @@ def generate_testcase( kernel_has_tscatter = "TSCATTER" in raw_kernel kernel_has_tgather = "TGATHER" in raw_kernel kernel_has_tgatherb = "TGATHERB" in raw_kernel + kernel_has_mscatter = "MSCATTER" in raw_kernel + kernel_has_mgather = "MGATHER" in raw_kernel # Some kernels use an integer tensor as "indices". The safe in-range domain # depends on the op semantics: # - TSCATTER: use a deterministic, collision-free permutation so NPU-vs-NPU @@ -1872,6 +1874,27 @@ def generate_testcase( index_mod = max(elem_count, 1) elif kernel_has_tgather and not kernel_has_tgatherb: index_mod = max(elem_count, 1) + mgather_table_input = None + if kernel_has_mgather: + for p in init_ptrs: + if p.get("role") == "input": + mgather_table_input = p + break + mscatter_indices_input = None + mscatter_output = output_ptrs[0] if kernel_has_mscatter and output_ptrs else None + if kernel_has_mscatter: + for p in reversed(init_ptrs): + p_dtype = _np_dtype_for_cpp(p["cpp_type"]) + if p.get("role") == "input" and ( + p_dtype.startswith("np.int") or p_dtype.startswith("np.uint") + ): + mscatter_indices_input = p + break + if mscatter_output is not None: + index_mod = max( + int(ptr_elem_counts.get(mscatter_output["name"], logical_elem_count)), + 1, + ) mrgsort_packed = "TMRGSORT" in raw_kernel for p in init_ptrs: np_dtype = _np_dtype_for_cpp(p["cpp_type"]) @@ -1880,6 +1903,18 @@ def generate_testcase( is_output = p.get("role") == "output" is_integer = np_dtype.startswith("np.int") or np_dtype.startswith("np.uint") is_tscatter_indices = kernel_has_tscatter and p.get("role") == "input" and is_integer and size == elem_count + is_mscatter_indices = ( + kernel_has_mscatter + and mscatter_indices_input is not None + and name == mscatter_indices_input["name"] + ) + is_mgather_indices = ( + kernel_has_mgather + and mgather_table_input is not None + and p.get("role") == "input" + and is_integer + and name != mgather_table_input["name"] + ) is_tgatherb_offset = kernel_has_tgatherb and p.get("role") == "input" and is_integer and size < elem_count is_tgatherb_src = kernel_has_tgatherb and p.get("role") == "input" and not is_tgatherb_offset # If the kernel has both inputs and outputs, default to zero-init for @@ -1954,6 +1989,26 @@ def generate_testcase( f" {name} = ({name}__row_perm * {cols} + {name}__cols).astype({np_dtype}).reshape(-1)" ) input_generate.append(f" {name}.tofile(\"{name}.bin\")") + elif is_mscatter_indices: + out_count = ( + int(ptr_elem_counts.get(mscatter_output["name"], logical_elem_count)) + if mscatter_output is not None + else max(size, 1) + ) + input_generate.append( + f" {name} = (np.arange({size}, dtype=np.int64) % {out_count}).astype({np_dtype}, copy=False)" + ) + input_generate.append(f" {name}.tofile(\"{name}.bin\")") + elif is_mgather_indices: + table_count = ( + int(ptr_elem_counts.get(mgather_table_input['name'], logical_elem_count)) + if mgather_table_input is not None + else max(size, 1) + ) + input_generate.append( + f" {name} = (np.arange({size}, dtype=np.int64) % {table_count}).astype({np_dtype}, copy=False)" + ) + input_generate.append(f" {name}.tofile(\"{name}.bin\")") elif is_tgatherb_offset: input_generate.append(f" {name} = (np.arange({size}, dtype=np.uint32) * 32).astype({np_dtype})") input_generate.append(f" {name}.tofile(\"{name}.bin\")") @@ -2205,13 +2260,15 @@ def generate_testcase( compare_template = (templates_root / "compare_template.py").read_text(encoding="utf-8") compare_lines = [" ok = True"] compare_prefix_counts = {} - tscatter_indices_input = None + scatter_indices_input = None if kernel_has_tscatter: for p in init_ptrs: p_dtype = _np_dtype_for_cpp(p["cpp_type"]) if p.get("role") == "input" and (p_dtype.startswith("np.int") or p_dtype.startswith("np.uint")): - tscatter_indices_input = p + scatter_indices_input = p break + elif kernel_has_mscatter and mscatter_indices_input is not None: + scatter_indices_input = mscatter_indices_input for p in output_ptrs: name = p["name"] req = inferred_counts.get(name) @@ -2242,16 +2299,16 @@ def generate_testcase( eps = _default_eps_for_cpp_type(p["cpp_type"]) is_bf16_output = _is_bf16_cpp_type(p["cpp_type"]) bf16_max_ulp = _default_bf16_max_ulp_for_cpp_type(p["cpp_type"]) - if kernel_has_tscatter and tscatter_indices_input is not None: + if (kernel_has_tscatter or kernel_has_mscatter) and scatter_indices_input is not None: if is_bf16_output: compare_lines.append( f" ok = compare_bf16_bin_at_indices(\"golden_{name}.bin\", \"{name}.bin\", {bf16_max_ulp}, " - f"\"{tscatter_indices_input['name']}.bin\", {_np_dtype_for_cpp(tscatter_indices_input['cpp_type'])}) and ok" + f"\"{scatter_indices_input['name']}.bin\", {_np_dtype_for_cpp(scatter_indices_input['cpp_type'])}) and ok" ) else: compare_lines.append( f" ok = compare_bin_at_indices(\"golden_{name}.bin\", \"{name}.bin\", {np_dtype}, {eps}, " - f"\"{tscatter_indices_input['name']}.bin\", {_np_dtype_for_cpp(tscatter_indices_input['cpp_type'])}) and ok" + f"\"{scatter_indices_input['name']}.bin\", {_np_dtype_for_cpp(scatter_indices_input['cpp_type'])}) and ok" ) elif has_packed_pred_mask and p["cpp_type"] in {"uint8_t", "int8_t"}: compare_lines.append( diff --git a/test/npu_validation/scripts/run_remote_npu_validation.sh b/test/npu_validation/scripts/run_remote_npu_validation.sh index 771e7aa1b..974c3d18e 100644 --- a/test/npu_validation/scripts/run_remote_npu_validation.sh +++ b/test/npu_validation/scripts/run_remote_npu_validation.sh @@ -231,6 +231,15 @@ else fi fi +pto_isa_has_symbol() { + local symbol="$1" + [[ -n "${symbol}" ]] || return 1 + find "${PTO_ISA_ROOT}/include" "${PTO_ISA_ROOT}/tests" \ + -type f \( -name '*.h' -o -name '*.hpp' -o -name '*.cpp' -o -name '*.cc' \) \ + -print0 2>/dev/null \ + | xargs -0 grep -F -q "${symbol}" +} + status=0 ok_count=0 fail_count=0 @@ -267,6 +276,12 @@ while IFS= read -r -d '' cpp; do log "SKIP: ${testcase} (SKIP_CASES)" continue fi + if [[ "${testcase}" == "partarg" ]] && ! pto_isa_has_symbol "TPARTARGMAX("; then + skip_count=$((skip_count + 1)) + printf "%s\tSKIP\t%s\tpto-isa missing TPARTARGMAX/TPARTARGMIN\n" "${testcase}" "${STAGE}" >> "${RESULTS_TSV}" + log "SKIP: ${testcase} (pto-isa missing TPARTARG intrinsics)" + continue + fi if [[ "${testcase}" == "gemvmx" ]]; then soc_lc="$(printf '%s' "${SOC_VERSION:-}" | tr '[:upper:]' '[:lower:]')" if [[ "$soc_lc" != *"a5"* && "$soc_lc" != *"950"* ]]; then diff --git a/test/samples/Abs/abs.py b/test/samples/Abs/abs.py index 759d93e70..7c16c1dce 100644 --- a/test/samples/Abs/abs.py +++ b/test/samples/Abs/abs.py @@ -46,10 +46,7 @@ def build(): tv0 = pto.MakeTensorViewOp(tv2_f32, arg0, [c32, c32], [c32, c1]).result tv1 = pto.MakeTensorViewOp(tv2_f32, arg1, [c32, c32], [c32, c1]).result - # Test pto.get_tensor_view_dim: get dim sizes from tensor_view and use as partition sizes - dim0 = pto.GetTensorViewDimOp(tv0, c0).result - dim1 = pto.GetTensorViewDimOp(tv0, c1).result - sv0 = pto.PartitionViewOp(tile_view_32, tv0, offsets=[c0, c0], sizes=[dim0, dim1]).result + sv0 = pto.PartitionViewOp(tile_view_32, tv0, offsets=[c0, c0], sizes=[c32, c32]).result tb0 = pto.AllocTileOp(tile_buf_32).result tb1 = pto.AllocTileOp(tile_buf_32).result diff --git a/test/samples/Mgather/mgather.py b/test/samples/Mgather/mgather.py index d63a2008c..a9b5e0cb6 100644 --- a/test/samples/Mgather/mgather.py +++ b/test/samples/Mgather/mgather.py @@ -21,9 +21,9 @@ def build(): i32 = IntegerType.get_signless(32, ctx) ptr_i32 = pto.PtrType.get(i32, ctx) - tv2_i32 = pto.TensorViewType.get(2, i32, ctx) tile_view_32 = pto.PartitionTensorViewType.get([32, 32], i32, ctx) + tile_view_1x32 = pto.PartitionTensorViewType.get([1, 32], i32, ctx) vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) @@ -31,7 +31,8 @@ def build(): fractal_ab_size = pto.TileConfig.fractalABSize cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) - tile_buf_i32 = pto.TileBufType.get([32, 32], i32, vec, [32, 32], cfg, ctx) + tile_buf_data_i32 = pto.TileBufType.get([32, 32], i32, vec, [32, 32], cfg, ctx) + tile_buf_idx_i32 = pto.TileBufType.get([1, 32], i32, vec, [1, 32], cfg, ctx) fn_ty = func.FunctionType.get([ptr_i32, ptr_i32, ptr_i32], []) with InsertionPoint(m.body): @@ -46,26 +47,20 @@ def build(): arg0, arg1, arg2 = entry.arguments - # %0/%1/%2 = pto.make_tensor_view %arg?, shape=[%c32,%c32] strides=[%c32,%c1] tv0 = pto.MakeTensorViewOp(tv2_i32, arg0, [c32, c32], [c32, c1]).result - tv1 = pto.MakeTensorViewOp(tv2_i32, arg1, [c32, c32], [c32, c1]).result + tv1 = pto.MakeTensorViewOp(tv2_i32, arg1, [c1, c32], [c32, c1]).result tv2 = pto.MakeTensorViewOp(tv2_i32, arg2, [c32, c32], [c32, c1]).result sv0 = pto.PartitionViewOp(tile_view_32, tv0, offsets=[c0, c0], sizes=[c32, c32]).result - sv1 = pto.PartitionViewOp(tile_view_32, tv1, offsets=[c0, c0], sizes=[c32, c32]).result - - tb1 = pto.AllocTileOp(tile_buf_i32).result - tb2 = pto.AllocTileOp(tile_buf_i32).result + sv1 = pto.PartitionViewOp(tile_view_1x32, tv1, offsets=[c0, c0], sizes=[c1, c32]).result - # pto.load_dps_tb ins(%sv) outs(%tb) - pto.TLoadOp(None, sv1, tb1) # result=None + tb1 = pto.AllocTileOp(tile_buf_idx_i32).result + tb2 = pto.AllocTileOp(tile_buf_data_i32).result + pto.TLoadOp(None, sv1, tb1) pto.MGatherOp(sv0, tb1, tb2) - # %8 = subview on output tensor_view sv2 = pto.PartitionViewOp(tile_view_32, tv2, offsets=[c0, c0], sizes=[c32, c32]).result - - # pto.store_dps_tb ins(%tb2) outs(%sv2) pto.TStoreOp(None, tb2, sv2) func.ReturnOp([]) diff --git a/test/samples/Mscatter/mscatter.py b/test/samples/Mscatter/mscatter.py index d3a3652df..7700a1279 100644 --- a/test/samples/Mscatter/mscatter.py +++ b/test/samples/Mscatter/mscatter.py @@ -21,9 +21,9 @@ def build(): i32 = IntegerType.get_signless(32, ctx) ptr_i32 = pto.PtrType.get(i32, ctx) - tv2_i32 = pto.TensorViewType.get(2, i32, ctx) tile_view_32 = pto.PartitionTensorViewType.get([32, 32], i32, ctx) + tile_view_1x32 = pto.PartitionTensorViewType.get([1, 32], i32, ctx) vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) @@ -31,7 +31,8 @@ def build(): fractal_ab_size = pto.TileConfig.fractalABSize cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) - tile_buf_i32 = pto.TileBufType.get([32, 32], i32, vec, [32, 32], cfg, ctx) + tile_buf_data_i32 = pto.TileBufType.get([32, 32], i32, vec, [32, 32], cfg, ctx) + tile_buf_idx_i32 = pto.TileBufType.get([1, 32], i32, vec, [1, 32], cfg, ctx) fn_ty = func.FunctionType.get([ptr_i32, ptr_i32, ptr_i32], []) with InsertionPoint(m.body): @@ -46,21 +47,19 @@ def build(): arg0, arg1, arg2 = entry.arguments - # %0/%1/%2 = pto.make_tensor_view %arg?, shape=[%c32,%c32] strides=[%c32,%c1] tv0 = pto.MakeTensorViewOp(tv2_i32, arg0, [c32, c32], [c32, c1]).result - tv1 = pto.MakeTensorViewOp(tv2_i32, arg1, [c32, c32], [c32, c1]).result + tv1 = pto.MakeTensorViewOp(tv2_i32, arg1, [c1, c32], [c32, c1]).result tv2 = pto.MakeTensorViewOp(tv2_i32, arg2, [c32, c32], [c32, c1]).result sv0 = pto.PartitionViewOp(tile_view_32, tv0, offsets=[c0, c0], sizes=[c32, c32]).result - sv1 = pto.PartitionViewOp(tile_view_32, tv1, offsets=[c0, c0], sizes=[c32, c32]).result + sv1 = pto.PartitionViewOp(tile_view_1x32, tv1, offsets=[c0, c0], sizes=[c1, c32]).result sv2 = pto.PartitionViewOp(tile_view_32, tv2, offsets=[c0, c0], sizes=[c32, c32]).result - tb0 = pto.AllocTileOp(tile_buf_i32).result - tb1 = pto.AllocTileOp(tile_buf_i32).result + tb0 = pto.AllocTileOp(tile_buf_data_i32).result + tb1 = pto.AllocTileOp(tile_buf_idx_i32).result - # pto.load_dps_tb ins(%sv) outs(%tb) pto.TLoadOp(None, sv0, tb0) - pto.TLoadOp(None, sv1, tb1) # result=None + pto.TLoadOp(None, sv1, tb1) pto.MScatterOp(tb0, tb1, sv2) diff --git a/test/samples/Quant/quant.py b/test/samples/Quant/quant.py index 78e68d332..7e802bbc6 100644 --- a/test/samples/Quant/quant.py +++ b/test/samples/Quant/quant.py @@ -8,9 +8,9 @@ """TQuant INT8_SYM kernel sample. - tquant(src_f32, fp_f32) -> dst_i8 + tquant(src_f32, scale_f32[row]) -> dst_i8 -Loads a 32x32 f32 tile (src) and a 32x32 f32 scaling-factor tile (fp), +Loads a 32x32 f32 tile (src) and a 32x1 per-row scaling tile (scale), performs symmetric INT8 quantization, and stores the int8 result tile. Note: int8 tiles require Cols*sizeof(T) to be a multiple of 32 bytes @@ -49,18 +49,27 @@ def _make_common_types(ctx): tv2_i8 = pto.TensorViewType.get(2, i8, ctx) ptv_f32 = pto.PartitionTensorViewType.get(_SHAPE, f32, ctx) + ptv_scale = pto.PartitionTensorViewType.get([_SHAPE[0], 1], f32, ctx) ptv_i8 = pto.PartitionTensorViewType.get(_SHAPE, i8, ctx) vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + bl_col = pto.BLayoutAttr.get(pto.BLayout.ColMajor, ctx) sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) cfg = pto.TileBufConfigAttr.get(bl, sl, pto.TileConfig.fractalABSize, pd, ctx) + cfg_col = pto.TileBufConfigAttr.get( + bl_col, sl, pto.TileConfig.fractalABSize, pd, ctx + ) tb_f32 = pto.TileBufType.get(_SHAPE, f32, vec, _SHAPE, cfg, ctx) + tb_scale = pto.TileBufType.get( + [_SHAPE[0], 1], f32, vec, [_SHAPE[0], 1], cfg_col, ctx + ) tb_i8 = pto.TileBufType.get(_SHAPE, i8, vec, _SHAPE, cfg, ctx) quant_sym = pto.QuantTypeAttr.get(pto.QuantType.INT8_SYM, ctx) + layout_dn = pto.LayoutAttr.get(pto.Layout.DN, ctx) class NS: pass @@ -74,10 +83,13 @@ class NS: ns.tv2_f32 = tv2_f32 ns.tv2_i8 = tv2_i8 ns.ptv_f32 = ptv_f32 + ns.ptv_scale = ptv_scale ns.ptv_i8 = ptv_i8 ns.tb_f32 = tb_f32 + ns.tb_scale = tb_scale ns.tb_i8 = tb_i8 ns.quant_sym = quant_sym + ns.layout_dn = layout_dn return ns @@ -91,7 +103,7 @@ def build(): # ------------------------------------------------------------------ # @tquant_sym_kernel(src_ptr: !pto.ptr, - # fp_ptr: !pto.ptr, + # scale_ptr: !pto.ptr, # dst_ptr: !pto.ptr) # ------------------------------------------------------------------ fn_sym_ty = func.FunctionType.get([t.ptr_f32, t.ptr_f32, t.ptr_i8], []) @@ -109,14 +121,18 @@ def build(): c1 = arith.ConstantOp(idx, 1).result c32 = arith.ConstantOp(idx, 32).result - src_ptr, fp_ptr, dst_ptr = entry_sym.arguments + src_ptr, scale_ptr, dst_ptr = entry_sym.arguments # Make tensor views over the flat global-memory pointers. tv_src = pto.MakeTensorViewOp( t.tv2_f32, src_ptr, [c32, c32], [c32, c1] ).result - tv_fp = pto.MakeTensorViewOp( - t.tv2_f32, fp_ptr, [c32, c32], [c32, c1] + tv_scale = pto.MakeTensorViewOp( + t.tv2_f32, + scale_ptr, + [c32, c1], + [c1, c1], + layout=t.layout_dn, ).result tv_dst = pto.MakeTensorViewOp( t.tv2_i8, dst_ptr, [c32, c32], [c32, c1] @@ -126,8 +142,8 @@ def build(): sv_src = pto.PartitionViewOp( t.ptv_f32, tv_src, offsets=[c0, c0], sizes=[c32, c32] ).result - sv_fp = pto.PartitionViewOp( - t.ptv_f32, tv_fp, offsets=[c0, c0], sizes=[c32, c32] + sv_scale = pto.PartitionViewOp( + t.ptv_scale, tv_scale, offsets=[c0, c0], sizes=[c32, c1] ).result sv_dst = pto.PartitionViewOp( t.ptv_i8, tv_dst, offsets=[c0, c0], sizes=[c32, c32] @@ -135,15 +151,15 @@ def build(): # Allocate on-chip tile buffers. tb_src = pto.AllocTileOp(t.tb_f32).result - tb_fp = pto.AllocTileOp(t.tb_f32).result + tb_scale = pto.AllocTileOp(t.tb_scale).result tb_dst = pto.AllocTileOp(t.tb_i8).result - # Load src and fp tiles from global memory. + # Load src and per-row scale tiles from global memory. pto.TLoadOp(None, sv_src, tb_src) - pto.TLoadOp(None, sv_fp, tb_fp) + pto.TLoadOp(None, sv_scale, tb_scale) # INT8_SYM quantization (no offset operand). - pto.TQuantOp(tb_src, tb_fp, tb_dst, quant_type=t.quant_sym) + pto.TQuantOp(tb_src, tb_scale, tb_dst, quant_type=t.quant_sym) # Store result back to global memory. pto.TStoreOp(None, tb_dst, sv_dst) diff --git a/test/samples/Quant/quant_asym.py b/test/samples/Quant/quant_asym.py index 4a641a3a8..67490c96e 100644 --- a/test/samples/Quant/quant_asym.py +++ b/test/samples/Quant/quant_asym.py @@ -8,10 +8,10 @@ """TQuant INT8_ASYM kernel sample. - tquant(src_f32, fp_f32, offset_f32) -> dst_ui8 + tquant(src_f32, scale_f32[row], offset_f32[row]) -> dst_ui8 -Loads a 32x32 f32 tile (src), a 32x32 f32 scaling-factor tile (fp), and a -32x32 f32 offset tile, performs asymmetric UINT8 quantization, and stores the +Loads a 32x32 f32 tile (src), a 32x1 per-row scaling tile (scale), and a +32x1 per-row offset tile, performs asymmetric UINT8 quantization, and stores the uint8 result tile. Note: uint8 tiles require Cols*sizeof(T) to be a multiple of 32 bytes @@ -50,18 +50,27 @@ def _make_common_types(ctx): tv2_ui8 = pto.TensorViewType.get(2, ui8, ctx) ptv_f32 = pto.PartitionTensorViewType.get(_SHAPE, f32, ctx) + ptv_param = pto.PartitionTensorViewType.get([_SHAPE[0], 1], f32, ctx) ptv_ui8 = pto.PartitionTensorViewType.get(_SHAPE, ui8, ctx) vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + bl_col = pto.BLayoutAttr.get(pto.BLayout.ColMajor, ctx) sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) cfg = pto.TileBufConfigAttr.get(bl, sl, pto.TileConfig.fractalABSize, pd, ctx) + cfg_col = pto.TileBufConfigAttr.get( + bl_col, sl, pto.TileConfig.fractalABSize, pd, ctx + ) tb_f32 = pto.TileBufType.get(_SHAPE, f32, vec, _SHAPE, cfg, ctx) + tb_param = pto.TileBufType.get( + [_SHAPE[0], 1], f32, vec, [_SHAPE[0], 1], cfg_col, ctx + ) tb_ui8 = pto.TileBufType.get(_SHAPE, ui8, vec, _SHAPE, cfg, ctx) quant_asym = pto.QuantTypeAttr.get(pto.QuantType.INT8_ASYM, ctx) + layout_dn = pto.LayoutAttr.get(pto.Layout.DN, ctx) class NS: pass @@ -75,10 +84,13 @@ class NS: ns.tv2_f32 = tv2_f32 ns.tv2_ui8 = tv2_ui8 ns.ptv_f32 = ptv_f32 + ns.ptv_param = ptv_param ns.ptv_ui8 = ptv_ui8 ns.tb_f32 = tb_f32 + ns.tb_param = tb_param ns.tb_ui8 = tb_ui8 ns.quant_asym = quant_asym + ns.layout_dn = layout_dn return ns @@ -92,7 +104,7 @@ def build(): # ------------------------------------------------------------------ # @tquant_asym_kernel(src_ptr: !pto.ptr, - # fp_ptr: !pto.ptr, + # scale_ptr: !pto.ptr, # offset_ptr: !pto.ptr, # dst_ptr: !pto.ptr) # ------------------------------------------------------------------ @@ -113,17 +125,25 @@ def build(): c1 = arith.ConstantOp(idx, 1).result c32 = arith.ConstantOp(idx, 32).result - src_ptr, fp_ptr, off_ptr, dst_ptr = entry_asym.arguments + src_ptr, scale_ptr, off_ptr, dst_ptr = entry_asym.arguments # Make tensor views over the flat global-memory pointers. tv_src = pto.MakeTensorViewOp( t.tv2_f32, src_ptr, [c32, c32], [c32, c1] ).result - tv_fp = pto.MakeTensorViewOp( - t.tv2_f32, fp_ptr, [c32, c32], [c32, c1] + tv_scale = pto.MakeTensorViewOp( + t.tv2_f32, + scale_ptr, + [c32, c1], + [c1, c1], + layout=t.layout_dn, ).result tv_off = pto.MakeTensorViewOp( - t.tv2_f32, off_ptr, [c32, c32], [c32, c1] + t.tv2_f32, + off_ptr, + [c32, c1], + [c1, c1], + layout=t.layout_dn, ).result tv_dst = pto.MakeTensorViewOp( t.tv2_ui8, dst_ptr, [c32, c32], [c32, c1] @@ -133,11 +153,11 @@ def build(): sv_src = pto.PartitionViewOp( t.ptv_f32, tv_src, offsets=[c0, c0], sizes=[c32, c32] ).result - sv_fp = pto.PartitionViewOp( - t.ptv_f32, tv_fp, offsets=[c0, c0], sizes=[c32, c32] + sv_scale = pto.PartitionViewOp( + t.ptv_param, tv_scale, offsets=[c0, c0], sizes=[c32, c1] ).result sv_off = pto.PartitionViewOp( - t.ptv_f32, tv_off, offsets=[c0, c0], sizes=[c32, c32] + t.ptv_param, tv_off, offsets=[c0, c0], sizes=[c32, c1] ).result sv_dst = pto.PartitionViewOp( t.ptv_ui8, tv_dst, offsets=[c0, c0], sizes=[c32, c32] @@ -145,18 +165,18 @@ def build(): # Allocate on-chip tile buffers. tb_src = pto.AllocTileOp(t.tb_f32).result - tb_fp = pto.AllocTileOp(t.tb_f32).result - tb_off = pto.AllocTileOp(t.tb_f32).result + tb_scale = pto.AllocTileOp(t.tb_param).result + tb_off = pto.AllocTileOp(t.tb_param).result tb_dst = pto.AllocTileOp(t.tb_ui8).result # Load tiles from global memory. pto.TLoadOp(None, sv_src, tb_src) - pto.TLoadOp(None, sv_fp, tb_fp) + pto.TLoadOp(None, sv_scale, tb_scale) pto.TLoadOp(None, sv_off, tb_off) # INT8_ASYM quantization (offset operand required). pto.TQuantOp( - tb_src, tb_fp, tb_dst, quant_type=t.quant_asym, offset=tb_off + tb_src, tb_scale, tb_dst, quant_type=t.quant_asym, offset=tb_off ) # Store result back to global memory. diff --git a/test/samples/Quant/quant_asym_golden.py b/test/samples/Quant/quant_asym_golden.py index 4774c23f2..c48b42c22 100644 --- a/test/samples/Quant/quant_asym_golden.py +++ b/test/samples/Quant/quant_asym_golden.py @@ -7,12 +7,7 @@ # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. -"""Golden reference for the TQuant INT8_ASYM kernel. - -Formula: dst[i] = clip(round(src[i] * fp[i] + offset[i]), 0, 255) (→ uint8) - -Note: fp is a scale multiplier; offset is the zero-point added before rounding. -""" +"""Golden reference for the TQuant INT8_ASYM kernel.""" import numpy as np from pathlib import Path @@ -20,14 +15,6 @@ _ROWS = 32 _COLS = 32 -_r = np.arange(_ROWS)[:, None] -_c = np.arange(_COLS)[None, :] - - -def _effective_scale(fp_flat): - """TRowExpandMul block-broadcast: element [r][c] uses fp_flat[r*8 + c%8].""" - return fp_flat[_r * 8 + _c % 8] - for search_root in ( Path(__file__).resolve().parent, Path(__file__).resolve().parents[1], @@ -49,20 +36,22 @@ def _effective_scale(fp_flat): def main(): meta = load_case_meta() - src_name, fp_name, off_name = meta.inputs + src_name, scale_name, off_name = meta.inputs generator = rng() src = float_values(generator, meta.elem_counts[src_name], style="signed") - fp = float_values(generator, meta.elem_counts[fp_name], style="positive") + scale = float_values(generator, meta.elem_counts[scale_name], style="positive") off = float_values(generator, meta.elem_counts[off_name], style="signed_small") buffers = default_buffers(meta) buffers[src_name] = src - buffers[fp_name] = fp + buffers[scale_name] = scale buffers[off_name] = off write_buffers(meta, buffers) + scale_2d = scale.reshape(_ROWS, 1) + off_2d = off.reshape(_ROWS, 1) out = np.clip( np.round( - src.reshape(_ROWS, _COLS) * _effective_scale(fp) - + _effective_scale(off) + src.reshape(_ROWS, _COLS) * scale_2d + + off_2d ), 0, 255, ).astype(np.uint8) diff --git a/test/samples/Quant/quant_golden.py b/test/samples/Quant/quant_golden.py index 83f5f8a40..61fc0b613 100644 --- a/test/samples/Quant/quant_golden.py +++ b/test/samples/Quant/quant_golden.py @@ -7,12 +7,7 @@ # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. -"""Golden reference for the TQuant INT8_SYM kernel. - -Formula: dst[i] = clip(round(src[i] * fp[i]), -128, 127) (→ int8) - -Note: fp is a scale multiplier (e.g. fp = 1/step_size), not a divisor. -""" +"""Golden reference for the TQuant INT8_SYM kernel.""" import numpy as np from pathlib import Path @@ -20,14 +15,6 @@ _ROWS = 32 _COLS = 32 -_r = np.arange(_ROWS)[:, None] -_c = np.arange(_COLS)[None, :] - - -def _effective_scale(fp_flat): - """TRowExpandMul block-broadcast: element [r][c] uses fp_flat[r*8 + c%8].""" - return fp_flat[_r * 8 + _c % 8] - for search_root in ( Path(__file__).resolve().parent, Path(__file__).resolve().parents[1], @@ -49,16 +36,17 @@ def _effective_scale(fp_flat): def main(): meta = load_case_meta() - src_name, fp_name = meta.inputs + src_name, scale_name = meta.inputs generator = rng() src = float_values(generator, meta.elem_counts[src_name], style="signed") - fp = float_values(generator, meta.elem_counts[fp_name], style="positive") + scale = float_values(generator, meta.elem_counts[scale_name], style="positive") buffers = default_buffers(meta) buffers[src_name] = src - buffers[fp_name] = fp + buffers[scale_name] = scale write_buffers(meta, buffers) + scale_2d = scale.reshape(_ROWS, 1) out = np.clip( - np.round(src.reshape(_ROWS, _COLS) * _effective_scale(fp)), + np.round(src.reshape(_ROWS, _COLS) * scale_2d), -128, 127, ).astype(np.int8) write_golden(meta, {single_output(meta): out})