Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/pypto/codegen/pto/pto_codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ class PTOCodegen : public CodegenBase {
std::map<const ir::Var*, std::string> memref_to_var_name; ///< keyed by base_ Ptr
std::vector<std::pair<ir::VarPtr, std::shared_ptr<const ir::TileType>>> tile_var_allocs;
std::set<const ir::Var*> emitted_tile_alloc_vars;
std::set<std::string> emitted_tile_alloc_ssas;
std::map<const ir::Var*, TpopResultInfo> tpop_result_vars;

ir::FunctionPtr current_function;
Expand Down Expand Up @@ -535,6 +536,7 @@ class PTOCodegen : public CodegenBase {
memref_to_var_name.clear();
tile_var_allocs.clear();
emitted_tile_alloc_vars.clear();
emitted_tile_alloc_ssas.clear();
tpop_result_vars.clear();

current_function.reset();
Expand Down
15 changes: 9 additions & 6 deletions src/backend/common/pto_ops_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2365,22 +2365,25 @@ void RegisterPTOOps(Backend& backend, const std::unordered_set<std::string>& exc
});

// In-place accumulation ops (matmul_acc, gemv_acc): ptoas expects the
// accumulator in ins() to be the same SSA value as outs(). InitMemRef
// guarantees that the output shares the MemRef of the accumulator input
// (via set_output_reuses_input), so we use the result buffer (dst) as the
// accumulator operand instead of the IR-level input arg.
// accumulator in ins() to be the same SSA value as outs(). Multiple IR tile
// vars may share one MemRef but still have distinct per-var alloc_tile SSA
// names, so use the accumulator input's SSA value as both src and dst and
// bind the assignment result to that same SSA value.
auto make_acc_codegen = [](const std::string& pto_op) {
return [pto_op](const ir::CallPtr& op, codegen::CodegenBase& codegen_base) -> std::string {
auto& codegen = dynamic_cast<codegen::PTOCodegen&>(codegen_base);
CHECK(op->args_.size() == 3) << pto_op << " requires 3 arguments: acc, lhs, rhs";

std::string dst = codegen.GetCurrentResultTarget();
std::string dst = codegen.GetExprAsCode(op->args_[0]);
std::string lhs = codegen.GetExprAsCode(op->args_[1]);
std::string rhs = codegen.GetExprAsCode(op->args_[2]);
std::string dst_type = codegen.GetCurrentResultTileBufTypeString();
std::string dst_type = codegen.GetExprTypeAnnotation(op->args_[0]);
std::string lhs_type = codegen.GetExprTypeAnnotation(op->args_[1]);
std::string rhs_type = codegen.GetExprTypeAnnotation(op->args_[2]);

INTERNAL_CHECK_SPAN(!dst.empty(), op->span_) << pto_op << " accumulator operand has no tile buffer";
codegen.SetCurrentResultBuf(dst);

std::ostringstream acc_inst;
acc_inst << pto_op << " ins(" << dst << ", " << lhs << ", " << rhs;
std::vector<std::string> ins_type_parts;
Expand Down
79 changes: 71 additions & 8 deletions src/codegen/pto/pto_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@
#include <set>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "pypto/backend/common/backend.h"
#include "pypto/backend/common/backend_config.h"
#include "pypto/backend/common/backend_handler.h"
#include "pypto/codegen/pto/pto_type_utils.h"
#include "pypto/codegen/pto/tile_buf_signature.h"
#include "pypto/core/dtype.h"
#include "pypto/core/logging.h"
#include "pypto/ir/expr.h"
Expand Down Expand Up @@ -145,6 +147,42 @@ bool ShouldAliasScatterUpdateResultToInput(const AssignStmtPtr& stmt) {

const auto& FlattenBody = transform_utils::FlattenToStmts;

bool IsInPlaceAccumulatorCall(const CallPtr& call) {
if (!call || !call->op_) return false;
return call->op_->name_ == "tile.matmul_acc" || call->op_->name_ == "tile.gemv_acc";
}

// True when the alloc-relevant tile dimensions are compile-time constants:
// the single dimension of a 1-D shape, or the leading two dimensions of a
// 2-D-or-higher shape. Null tile types and empty shapes count as dynamic.
bool HasStaticAllocTileShape(const std::shared_ptr<const TileType>& tile_type) {
  if (!tile_type || tile_type->shape_.empty()) {
    return false;
  }
  const size_t dims_to_check = tile_type->shape_.size() == 1 ? 1 : 2;
  for (size_t i = 0; i < dims_to_check; ++i) {
    if (As<ir::ConstInt>(tile_type->shape_[i]) == nullptr) {
      return false;
    }
  }
  return true;
}

// A tile is a matmul operand buffer when it lives in the Left or Right
// memory space (the spaces fed into the matmul unit).
bool IsMatmulOperandBuffer(const std::shared_ptr<const TileType>& tile_type) {
  if (tile_type == nullptr) {
    return false;
  }
  switch (tile_type->memory_space_) {
    case ir::MemorySpace::Left:
    case ir::MemorySpace::Right:
      return true;
    default:
      return false;
  }
}

// Builds the handle-reuse key for a fully static Left/Right matmul operand
// tile. Returns nullopt when the tile is not a matmul operand buffer, or
// when any shape/view dimension is dynamic; in those cases the caller must
// keep a distinct SSA handle per IR variable.
std::optional<std::string> GetStaticMatmulOperandReuseKey(const std::shared_ptr<const TileType>& tile_type,
                                                          const std::string& type_str) {
  if (!IsMatmulOperandBuffer(tile_type) || !HasStaticAllocTileShape(tile_type)) {
    return std::nullopt;
  }

  const auto signature = TileBufSignature::FromTileType(*tile_type);
  if (signature.v_row_dynamic || signature.v_col_dynamic) {
    return std::nullopt;
  }

  std::ostringstream key_stream;
  key_stream << type_str << "|v_row=" << signature.v_row << "|v_col=" << signature.v_col;
  return key_stream.str();
}

} // namespace

// Visitor to collect all MemRef objects from TileType variables
Expand Down Expand Up @@ -333,20 +371,42 @@ void PTOCodegen::GenerateFunction(const FunctionPtr& func) {
// Still collect fs_.memref_to_tile_type for GetTileBufTypeString fallback paths
fs_.memref_to_tile_type = collector.GetMemRefTileTypes();

// Per-var SSA binding: each tile variable gets its own SSA name
// Tile-buffer SSA binding. A PTO tile_buf SSA denotes a mutable tile handle,
// not an immutable value. Reuse handles only for static Left/Right matmul
// operand buffers that MemoryReuse placed in the same physical L0A/L0B
// MemRef; PTOAS otherwise cannot see the WAR/WAW dependency between a tmov
// writing the reused buffer and the following tmatmul consuming it. Keep the
// per-variable model for Vec/Mat/Acc, dynamic byte offsets, and
// shape/view-distinct signatures so unrelated ST kernels preserve their
// existing scheduling surface.
std::map<std::tuple<const ir::Var*, int64_t, std::string>, std::string> matmul_operand_reuse;
for (const auto& [tile_var, tile_type] : fs_.tile_var_allocs) {
std::string ssa_name = NewNamedTemp(tile_var->name_hint_);
BindVarToMlir(tile_var, ssa_name);

// Pre-populate type so body visitors (e.g., tile.reshape no-op check)
// can query it before per-variable alloc_tile emission runs.
std::string type_str = GetTileBufTypeStringFromTileType(tile_type);
fs_.ssa_to_tile_buf_type[ssa_name] = type_str;

auto memref = ir::GetDefinedMemRef(tile_type);
const ir::Var* base_ptr = memref->base_.get();

std::string ssa_name;
auto reuse_key = GetStaticMatmulOperandReuseKey(tile_type, type_str);
auto const_offset = As<ir::ConstInt>(memref->byte_offset_);
if (reuse_key.has_value() && const_offset && fs_.tpop_result_vars.count(tile_var.get()) == 0) {
auto key = std::make_tuple(base_ptr, const_offset->value_, *reuse_key);
auto reuse_it = matmul_operand_reuse.find(key);
if (reuse_it != matmul_operand_reuse.end()) {
ssa_name = reuse_it->second;
} else {
ssa_name = NewNamedTemp(tile_var->name_hint_);
matmul_operand_reuse.emplace(std::move(key), ssa_name);
}
} else {
ssa_name = NewNamedTemp(tile_var->name_hint_);
}

BindVarToMlir(tile_var, ssa_name);
fs_.ssa_to_tile_buf_type[ssa_name] = type_str;

// Also maintain fs_.memref_to_mlir for compatibility (first var per allocation)
const ir::Var* base_ptr = memref->base_.get();
if (fs_.memref_to_mlir.find(base_ptr) == fs_.memref_to_mlir.end()) {
fs_.memref_to_mlir[base_ptr] = ssa_name;
}
Expand Down Expand Up @@ -822,6 +882,9 @@ void PTOCodegen::EmitAllocTileForVar(const ir::VarPtr& tile_var,
INTERNAL_CHECK_SPAN(mlir_it != fs_.var_to_mlir.end(), tile_var->span_)
<< "Tile var " << tile_var->name_hint_ << " not found in fs_.var_to_mlir";
std::string tile_buf = mlir_it->second;
if (!fs_.emitted_tile_alloc_ssas.insert(tile_buf).second) {
return;
}

AllocTileFields fields = ComputeAllocTileFields(tile_type);

Expand Down Expand Up @@ -1059,7 +1122,7 @@ void PTOCodegen::VisitStmt_(const AssignStmtPtr& op) {

if (auto tile_type = ir::GetTileTypeWithMemRef(op->var_->GetType())) {
if (!is_set_validshape && fs_.tpop_result_vars.count(op->var_.get()) == 0 &&
!alias_scatter_result_to_input) {
!alias_scatter_result_to_input && !IsInPlaceAccumulatorCall(call)) {
EmitAllocTileForVar(op->var_, tile_type);
}
}
Expand Down
74 changes: 74 additions & 0 deletions tests/st/runtime/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class can run on multiple platforms via ``@pytest.mark.parametrize``.
import torch
from examples.kernels.matmul import matmul_acc_64
from harness.core.harness import PLATFORMS, DataType, PTOTestCase, TensorSpec
from pypto.runtime.runner import RunConfig


class TestMatmul(PTOTestCase):
Expand Down Expand Up @@ -463,6 +464,70 @@ def compute_expected(self, tensors, params=None):
tensors["c"][:] = torch.matmul(tensors["a"].to(torch.float32), tensors["b"].to(torch.float32))


class TestMatmulAccBTransposeNopad(PTOTestCase):
    """Issue #1213: C = X @ W^T with b_trans=True and K split across matmul_acc."""

    # Prevent pytest from collecting this class directly; it is driven via
    # the TestMatmulOperations wrapper functions instead.
    __test__ = False

    def __init__(
        self,
        dtype: DataType = DataType.FP32,
        *,
        platform: str | None = None,
        config=None,
    ):
        # Looser tolerances than the harness default to absorb the FP32/BF16
        # accumulation-order differences of the chunked K reduction.
        super().__init__(config or RunConfig(rtol=4e-3, atol=4e-3), platform=platform)
        self.M = 16  # rows of X / rows of the output
        self.K = 1024  # shared contraction dimension, split into K_CHUNK slices
        self.N = 32  # rows of W (columns of the output, since W is transposed)
        self.K_CHUNK = 512  # slice width per matmul/matmul_acc step
        self.dtype = dtype

    def get_name(self) -> str:
        # Unique test id embedding dtype and problem shape.
        return f"matmulacc_btranspose_nopad_{self.dtype.value}_{self.M}x{self.K}x{self.N}"

    def define_tensors(self) -> list[TensorSpec]:
        # Inputs are randomly initialized; the output is always FP32 because
        # the program accumulates in FP32 regardless of input dtype.
        return [
            TensorSpec("x", [self.M, self.K], self.dtype, init_value=torch.randn),
            TensorSpec("w", [self.N, self.K], self.dtype, init_value=torch.randn),
            TensorSpec("out", [self.M, self.N], DataType.FP32, is_output=True),
        ]

    def get_program(self) -> Any:
        """Build the pl program: an initial matmul over the first K chunk,
        then matmul_acc over each remaining chunk, all with b_trans=True."""
        M, K, N, K_CHUNK = self.M, self.K, self.N, self.K_CHUNK
        K_BLOCKS = K // K_CHUNK  # number of K slices; assumes K % K_CHUNK == 0
        elem_dtype = pl.FP32 if self.dtype is DataType.FP32 else pl.BF16

        @pl.program
        class MatmulAccBTransposeNopadProgram:
            @pl.function(type=pl.FunctionType.Opaque)
            def main(
                self,
                x: pl.Tensor[[M, K], elem_dtype],
                w: pl.Tensor[[N, K], elem_dtype],
                out: pl.Out[pl.Tensor[[M, N], pl.FP32]],
            ):
                with pl.at(
                    level=pl.Level.CORE_GROUP,
                    optimization=pl.chunked_loop_optimizer,
                    name_hint="linear",
                ):
                    # First chunk initializes the accumulator via a plain matmul.
                    x0 = pl.slice(x, [M, K_CHUNK], [0, 0])
                    w0 = pl.slice(w, [N, K_CHUNK], [0, 0])
                    acc = pl.matmul(x0, w0, b_trans=True, out_dtype=pl.FP32)
                    # Remaining chunks accumulate in place; this is the path
                    # that exercises the reused L0A/L0B buffers from issue #1213.
                    for kb in pl.range(1, K_BLOCKS):
                        k0 = kb * K_CHUNK
                        x_chunk = pl.slice(x, [M, K_CHUNK], [0, k0])
                        w_chunk = pl.slice(w, [N, K_CHUNK], [0, k0])
                        acc = pl.matmul_acc(acc, x_chunk, w_chunk, b_trans=True)
                    out = pl.assemble(out, acc, [0, 0])

        return MatmulAccBTransposeNopadProgram

    def compute_expected(self, tensors, params=None):
        # Reference result: upcast both inputs to FP32 and compute X @ W^T.
        tensors["out"][:] = tensors["x"].float() @ tensors["w"].float().T


# =============================================================================
# pytest test functions
# =============================================================================
Expand All @@ -484,6 +549,8 @@ def compute_expected(self, tensors, params=None):
# (BATCH=16, K_CHUNK=128, OUT_CHUNK=256). Same 2-iter K-loop, BF16 inputs +
# FP32 accumulator.
_AUTOL0_BF16_SHAPES = [(16, 128, 256)]
_ISSUE1213_DTYPES = [pytest.param(DataType.FP32, id="fp32"), pytest.param(DataType.BF16, id="bf16")]
_A2A3_ONLY = [pytest.param("a2a3", id="a2a3")]


class TestMatmulOperations:
Expand Down Expand Up @@ -559,6 +626,13 @@ def test_matmul_outer_pipelined_bf16(self, test_runner, platform):
result = test_runner.run(TestMatmulOuterPipelinedBF16(platform=platform))
assert result.passed, f"Test failed: {result.error}"

@pytest.mark.parametrize("platform", _A2A3_ONLY)
@pytest.mark.parametrize("dtype", _ISSUE1213_DTYPES)
def test_matmulacc_btranspose_nopad_issue1213(self, test_runner, platform, dtype):
    """Regression for b_trans=True matmul_acc over reused L0A/L0B buffers."""
    case = TestMatmulAccBTransposeNopad(dtype=dtype, platform=platform)
    result = test_runner.run(case)
    assert result.passed, f"Test failed: {result.error}"


if __name__ == "__main__":
pytest.main([__file__, "-v"])
Loading
Loading