Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/en/dev/passes/14-flatten_tile_nd_to_2d.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ Per-statement handling:
| `tile.store` (2D tensor) | Pass through unchanged |
| `tile.create`/`tile.full` (>2D) | Rebuild with flattened 2D shape directly |
| `tile.sum`/`tile.max`/`tile.min` (>2D) | Remap axis to 1 (last axis of 2D) |
| `tile.batch_matmul` | Expand to per-batch 2D `tile.matmul`, honoring batch broadcast and any operand-side transpose carried in the producer `tile.load(target_memory=Mat, transpose=True)` |
| `tile.batch_matmul` | Expand to per-batch 2D `tile.matmul`, honoring batch broadcast and any operand-side transpose carried in the producer `tile.load(target_memory=Mat, transpose=True)`. When the rhs has no effective batch (`prod(rhs_batch_dims) == 1`), the rhs page extraction is hoisted out of the per-batch unroll so only one rhs.load is emitted (saves B-1 redundant rhs loads) |
| `tile.batch_matmul_acc` | Expand to per-batch 2D `tile.matmul_acc`, slicing the (already-flattened) accumulator per batch index; an explicit `tile.move(target_memory=Acc)` is inserted when the accumulator is in another memory space |
| Other tile ops (>2D) | Substitute vars, re-create with 2D types |
| 1D/2D tile ops | Unchanged |
Expand Down
2 changes: 1 addition & 1 deletion docs/zh-cn/dev/passes/14-flatten_tile_nd_to_2d.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ program_2d = flatten_pass(program)
| `tile.store`(2D 张量) | 直接透传 |
| `tile.create`/`tile.full`(>2D) | 直接使用展平的 2D 形状重建 |
| `tile.sum`/`tile.max`/`tile.min`(>2D) | 将 axis 映射为 1(2D 的最后轴) |
| `tile.batch_matmul` | 展开为逐 batch 的 2D `tile.matmul`,处理 batch broadcast;operand 的 transpose 通过生产侧 `tile.load(target_memory=Mat, transpose=True)` 携带 |
| `tile.batch_matmul` | 展开为逐 batch 的 2D `tile.matmul`,并处理 batch broadcast;operand 的 transpose 通过生产侧 `tile.load(target_memory=Mat, transpose=True)` 携带。当 rhs 实质无 batch(`prod(rhs_batch_dims) == 1`)时,rhs 的 page 提取会被提到 unroll 循环外,整段只发射一次 rhs.load(省掉 B-1 次冗余 rhs 加载) |
| `tile.batch_matmul_acc` | 展开为逐 batch 的 2D `tile.matmul_acc`,按 batch 索引切分(已展平的)累加器;累加器若不在 Acc 内存空间会插入显式 `tile.move(target_memory=Acc)` |
| 其他 Tile 操作(>2D) | 替换变量,使用 2D 类型重新创建 |
| 1D/2D Tile 操作 | 不变 |
Expand Down
2 changes: 1 addition & 1 deletion runtime
Submodule runtime updated 354 files
36 changes: 31 additions & 5 deletions src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,19 @@ BatchMatmulResult LowerBatchMatmul(const AssignStmtPtr& assign, const CallPtr& c
}
}

// Hoist rhs.ExtractBatchPage when rhs has no effective batch (all rhs batch
// dims are 1, or rhs is 2D). In that case BuildOperandFlatBatchIndex would
// return 0 for every iteration, so ExtractBatchPage emits identical IR each
// round — emit it once outside the loop and reuse the resulting Var across
// all per-batch matmuls. Saves (batch_count - 1) redundant rhs loads.
const int64_t batch_count_rhs = MultiplyStaticDims(rhs_batch_dims, "tile.batch_matmul rhs batch size");
std::optional<BatchPageResult> hoisted_rhs_page;
if (batch_count_rhs == 1 && batch_count > 1) {
hoisted_rhs_page = ExtractBatchPage(rhs_info, rhs_dims, rhs_batch_dims, /*batch_index=*/0, "rhs", def_map,
ctx, op_registry, span);
out.stmts.insert(out.stmts.end(), hoisted_rhs_page->stmts.begin(), hoisted_rhs_page->stmts.end());
}

// Unroll batch dimensions.
for (int64_t i = 0; i < batch_count; ++i) {
auto output_batch_indices = BuildBatchIndices(i, output_batch_dims);
Expand All @@ -858,16 +871,24 @@ BatchMatmulResult LowerBatchMatmul(const AssignStmtPtr& assign, const CallPtr& c
int64_t rhs_batch_idx =
BuildOperandFlatBatchIndex(rhs_batch_dims, output_batch_dims, output_batch_indices);

// Extract 2D pages.
// Extract 2D pages. lhs always extracted per batch; rhs reuses the hoisted
// var when applicable.
auto lhs_page = ExtractBatchPage(lhs_info, lhs_dims, lhs_batch_dims, lhs_batch_idx, "lhs", def_map, ctx,
op_registry, span);
auto rhs_page = ExtractBatchPage(rhs_info, rhs_dims, rhs_batch_dims, rhs_batch_idx, "rhs", def_map, ctx,
op_registry, span);
out.stmts.insert(out.stmts.end(), lhs_page.stmts.begin(), lhs_page.stmts.end());
out.stmts.insert(out.stmts.end(), rhs_page.stmts.begin(), rhs_page.stmts.end());

VarPtr rhs_var;
if (hoisted_rhs_page) {
rhs_var = hoisted_rhs_page->var;
} else {
auto rhs_page = ExtractBatchPage(rhs_info, rhs_dims, rhs_batch_dims, rhs_batch_idx, "rhs", def_map, ctx,
op_registry, span);
out.stmts.insert(out.stmts.end(), rhs_page.stmts.begin(), rhs_page.stmts.end());
rhs_var = rhs_page.var;
}

// Emit tile.matmul.
auto matmul = op_registry.Create("tile.matmul", {lhs_page.var, rhs_page.var}, span);
auto matmul = op_registry.Create("tile.matmul", {lhs_page.var, rhs_var}, span);
auto matmul_var = std::make_shared<Var>("matmul_" + std::to_string(i), matmul->GetType(), span);
out.stmts.push_back(std::make_shared<AssignStmt>(matmul_var, matmul, span));

Expand Down Expand Up @@ -1147,6 +1168,11 @@ std::vector<StmtPtr> TransformBody(const std::vector<StmtPtr>& stmts, FlattenCon
// from the original tensor, the full-batch load becomes dead code. Skip emitting
// it to avoid wasted memory and potential hardware pipeline interference.
//
// Exception: when LowerBatchMatmul will take the unified single-matmul fast path,
// the operand loads are NOT dead — the fast path uses tile.reshape on the original
// load result, not Strategy-1 per-batch re-emission. Operands of such batch_matmuls
// must stay in the IR.
//
// Safety: we count ALL Var references across every statement type (Return, Yield,
// If conditions, For/While bounds, etc.), not just Call arguments. A Var used
// anywhere outside a tile.batch_matmul Call prevents it from being skipped.
Expand Down
77 changes: 77 additions & 0 deletions tests/st/runtime/test_batch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,68 @@ def compute_expected(self, tensors, params=None):
tensors["c"][:] = torch.bmm(tensors["a"].transpose(-2, -1), tensors["b"])


class TestBatchMatmulLhsNdRhs2d(PTOTestCase):
    """Tile-level batch matmul with rhs as a true 2D (non-batched) operand.

    Exercises FlattenTileNdTo2D's rhs-hoist optimization: lhs is `[B, M, K]`,
    rhs is shared `[K, N]` across all batches, so the lowering hoists the
    rhs.load out of the per-batch unroll loop (saves B-1 redundant loads).
    """

    # Not collected by pytest directly; instantiated with explicit sizes from
    # the parametrized driver in TestBatchMatmulOperations.
    __test__ = False

    def __init__(self, batch: int = 2, m: int = 64, k: int = 64, n: int = 64, config=None):
        # batch/M/K/N define lhs [batch, M, K] x rhs [K, N] -> out [batch, M, N].
        super().__init__(config)
        self.batch = batch
        self.M = m
        self.K = k
        self.N = n

    def get_name(self) -> str:
        # Unique case name used by the runner for reporting/artifacts.
        return f"batch_matmul_lhs_nd_rhs_2d_{self.batch}x{self.M}x{self.K}x{self.N}"

    def define_tensors(self) -> list[TensorSpec]:
        # rhs "b" is deliberately 2D (no batch dim) so the flatten pass can
        # hoist its load out of the per-batch unroll.
        return [
            TensorSpec("a", [self.batch, self.M, self.K], DataType.FP32, init_value=torch.randn),
            TensorSpec("b", [self.K, self.N], DataType.FP32, init_value=torch.randn),
            TensorSpec("c", [self.batch, self.M, self.N], DataType.FP32, is_output=True),
        ]

    def get_program(self) -> Any:
        """Build the pl.program under test: a single InCore batch_matmul kernel
        plus an Orchestration wrapper that forwards its arguments."""
        B, M, K, N = self.batch, self.M, self.K, self.N

        @pl.program
        class BatchMatmulLhsNdRhs2dProgram:
            @pl.function(type=pl.FunctionType.InCore)
            def batch_matmul_lhs_nd_rhs_2d(
                self,
                a: pl.Tensor[[B, M, K], pl.FP32],
                b: pl.Tensor[[K, N], pl.FP32],
                c: pl.Out[pl.Tensor[[B, M, N], pl.FP32]],
            ) -> pl.Tensor[[B, M, N], pl.FP32]:
                # Full-tensor loads into Mat memory; batch_matmul broadcasts
                # the 2D rhs across all lhs batches.
                tile_a = pl.load(a, offsets=[0, 0, 0], shapes=[B, M, K], target_memory=pl.MemorySpace.Mat)
                tile_b = pl.load(b, offsets=[0, 0], shapes=[K, N], target_memory=pl.MemorySpace.Mat)
                tile_c = pl.batch_matmul(tile_a, tile_b)
                out_c = pl.store(tile_c, offsets=[0, 0, 0], output_tensor=c)
                return out_c

            @pl.function(type=pl.FunctionType.Orchestration)
            def orchestrator(
                self,
                a: pl.Tensor[[B, M, K], pl.FP32],
                b: pl.Tensor[[K, N], pl.FP32],
                c: pl.Out[pl.Tensor[[B, M, N], pl.FP32]],
            ) -> pl.Tensor[[B, M, N], pl.FP32]:
                out_c = self.batch_matmul_lhs_nd_rhs_2d(a, b, c)
                return out_c

        return BatchMatmulLhsNdRhs2dProgram

    def compute_expected(self, tensors, params=None):
        # torch.matmul broadcasts 2D rhs across all batches of 3D lhs.
        tensors["c"][:] = torch.matmul(tensors["a"], tensors["b"])


class TestBatchMatmulOperations:
"""Test suite for tile-level batch matrix multiplication.

Expand Down Expand Up @@ -260,6 +322,21 @@ def test_batch_matmul_a_transpose(self, test_runner, batch, m, k, n):
result = test_runner.run(test_case)
assert result.passed, f"Test failed (A-trans): {result.error}"

@pytest.mark.parametrize(
"batch,m,k,n",
[
(2, 64, 64, 64),
(4, 32, 64, 32),
# batch=1 falls through the existing batch_count==1 fast path.
(1, 64, 64, 64),
],
)
def test_batch_matmul_lhs_nd_rhs_2d(self, test_runner, batch, m, k, n):
"""lhs ND × rhs 2D triggers FlattenTileNdTo2D's rhs-hoist optimization."""
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Use plain x instead of × in test text to avoid Ruff ambiguous-character warnings.

Suggested edit
-        """lhs ND × rhs 2D triggers FlattenTileNdTo2D's rhs-hoist optimization."""
+        """lhs ND x rhs 2D triggers FlattenTileNdTo2D's rhs-hoist optimization."""
...
-        assert result.passed, f"Test failed (lhs ND × rhs 2D): {result.error}"
+        assert result.passed, f"Test failed (lhs ND x rhs 2D): {result.error}"

Also applies to: 338-338

🧰 Tools
🪛 Ruff (0.15.12)

[warning] 335-335: Docstring contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF002)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/st/runtime/test_batch_matmul.py` at line 335, Replace the Unicode
multiplication sign '×' with a plain ASCII 'x' in the docstrings of the batch
matmul tests to avoid Ruff ambiguous-character warnings; specifically update the
docstring that currently reads "lhs ND × rhs 2D triggers FlattenTileNdTo2D's
rhs-hoist optimization." (and the similar docstring at the nearby line) to use
"lhs ND x rhs 2D ..." so the test function/docstring text containing that phrase
in tests/st/runtime/test_batch_matmul.py is changed accordingly.

test_case = TestBatchMatmulLhsNdRhs2d(batch=batch, m=m, k=k, n=n)
result = test_runner.run(test_case)
assert result.passed, f"Test failed (lhs ND × rhs 2D): {result.error}"


if __name__ == "__main__":
pytest.main([__file__, "-v", "--forked"])
163 changes: 163 additions & 0 deletions tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -1507,6 +1507,169 @@ def load_tile_shape(shape: list[int], transpose: bool) -> list[int]:
"expected_store_shapes"
]

def test_batch_matmul_rhs_2d_hoists_rhs_load(self):
"""lhs [B, M, K] x rhs [K, N] (B>1): rhs.load is hoisted out of the per-batch loop.

Saves (B-1) redundant rhs tile.load calls. Per-batch loop still emits B copies of
(lhs.load, matmul, store), but only ONE rhs.load is emitted for the whole sequence.
"""

@pl.program
class Before:
@pl.function(type=pl.FunctionType.InCore)
def main_incore_0(
self,
lhs: pl.Tensor[[2, 16, 128], pl.FP16],
rhs: pl.Tensor[[128, 64], pl.FP16],
out_0: pl.Out[pl.Tensor[[2, 16, 64], pl.FP16]],
) -> pl.Tensor[[2, 16, 64], pl.FP16]:
lhs_tile: pl.Tile[[2, 16, 128], pl.FP16] = pl.load(
lhs, [0, 0, 0], [2, 16, 128], target_memory=pl.MemorySpace.Mat
)
rhs_tile: pl.Tile[[128, 64], pl.FP16] = pl.load(
rhs, [0, 0], [128, 64], target_memory=pl.MemorySpace.Mat
)
out_tile: pl.Tile[[2, 16, 64], pl.FP32] = pl.tile.batch_matmul(lhs_tile, rhs_tile)
out_0 = pl.store(out_tile, [0, 0, 0], out_0)
return out_0

@pl.function
def main(
self,
lhs: pl.Tensor[[2, 16, 128], pl.FP16],
rhs: pl.Tensor[[128, 64], pl.FP16],
) -> pl.Tensor[[2, 16, 64], pl.FP16]:
out_0 = pl.create_tensor([2, 16, 64], dtype=pl.FP16)
y = self.main_incore_0(lhs, rhs, out_0)
return y

func = self._flattened_incore(Before)
calls = self._top_level_calls(func)
# Expect: 1 rhs.load (hoisted), then per-batch (lhs.load, matmul, store) × 2.
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Replace ambiguous × characters in comments to satisfy Ruff.

These flagged comment lines should use plain x to avoid RUF003 warnings.

Suggested edit
-        # Expect: 1 rhs.load (hoisted), then per-batch (lhs.load, matmul, store) × 2.
+        # Expect: 1 rhs.load (hoisted), then per-batch (lhs.load, matmul, store) x 2.
...
-        # 1 hoisted rhs.load + 3 × (lhs.load, matmul, store).
+        # 1 hoisted rhs.load + 3 x (lhs.load, matmul, store).
...
-        # Both batched -> no hoist: 2 × (lhs.load, rhs.load, matmul, store).
+        # Both batched -> no hoist: 2 x (lhs.load, rhs.load, matmul, store).

Also applies to: 1609-1609, 1661-1661

🧰 Tools
🪛 Ruff (0.15.12)

[warning] 1548-1548: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF003)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py` at line 1548, In
test_flatten_tile_nd_to_2d.py replace the Unicode multiplication character '×'
used in comment lines with a plain ASCII 'x' to satisfy Ruff RUF003; update the
three comment occurrences (the one shown and the two other occurrences flagged)
so comments like "× 2" become "x 2" — no behavior change, only ASCII comment
character replacement.

assert [call.op.name for call in calls] == [
"tile.load", # hoisted rhs
"tile.load", # batch 0 lhs
"tile.matmul",
"tile.store",
"tile.load", # batch 1 lhs
"tile.matmul",
"tile.store",
]
load_calls = [call for call in calls if call.op.name == "tile.load"]
# First load is rhs (2D, no batch shift). Subsequent loads are per-batch lhs.
assert [self._tuple_const_values(call.args[1]) for call in load_calls] == [
[0, 0], # hoisted rhs offset
[0, 0, 0], # batch 0 lhs
[1, 0, 0], # batch 1 lhs
]
assert [self._tuple_const_values(call.args[2]) for call in load_calls] == [
[128, 64], # rhs shape
[1, 16, 128], # per-batch lhs shape (single batch slice)
[1, 16, 128],
]

def test_batch_matmul_rhs_size1_batch_hoists_rhs_load(self):
"""lhs [B, M, K] x rhs [1, K, N] (B>1): rhs hoist also fires for size-1-batch rhs.

rhs has a leading batch dim of 1 (broadcasted across all output batches), so
ExtractBatchPage with batch_idx=0 gives the same result every iteration. Hoist.
"""

@pl.program
class Before:
@pl.function(type=pl.FunctionType.InCore)
def main_incore_0(
self,
lhs: pl.Tensor[[3, 16, 128], pl.FP16],
rhs: pl.Tensor[[1, 128, 64], pl.FP16],
out_0: pl.Out[pl.Tensor[[3, 16, 64], pl.FP16]],
) -> pl.Tensor[[3, 16, 64], pl.FP16]:
lhs_tile: pl.Tile[[3, 16, 128], pl.FP16] = pl.load(
lhs, [0, 0, 0], [3, 16, 128], target_memory=pl.MemorySpace.Mat
)
rhs_tile: pl.Tile[[1, 128, 64], pl.FP16] = pl.load(
rhs, [0, 0, 0], [1, 128, 64], target_memory=pl.MemorySpace.Mat
)
out_tile: pl.Tile[[3, 16, 64], pl.FP32] = pl.tile.batch_matmul(lhs_tile, rhs_tile)
out_0 = pl.store(out_tile, [0, 0, 0], out_0)
return out_0

@pl.function
def main(
self,
lhs: pl.Tensor[[3, 16, 128], pl.FP16],
rhs: pl.Tensor[[1, 128, 64], pl.FP16],
) -> pl.Tensor[[3, 16, 64], pl.FP16]:
out_0 = pl.create_tensor([3, 16, 64], dtype=pl.FP16)
y = self.main_incore_0(lhs, rhs, out_0)
return y

func = self._flattened_incore(Before)
op_names = [call.op.name for call in self._top_level_calls(func)]
# 1 hoisted rhs.load + 3 × (lhs.load, matmul, store).
assert op_names == [
"tile.load", # hoisted rhs
"tile.load",
"tile.matmul",
"tile.store",
"tile.load",
"tile.matmul",
"tile.store",
"tile.load",
"tile.matmul",
"tile.store",
]
# Exactly one load whose shape is the rhs shape (rest are lhs).
load_calls = [call for call in self._top_level_calls(func) if call.op.name == "tile.load"]
rhs_loads = [call for call in load_calls if self._tuple_const_values(call.args[2]) == [1, 128, 64]]
assert len(rhs_loads) == 1, f"expected 1 hoisted rhs load, got {len(rhs_loads)}"

def test_batch_matmul_both_batched_no_hoist(self):
"""lhs [B, M, K] x rhs [B, K, N] (B>1): rhs varies per batch, hoist must NOT fire."""

@pl.program
class Before:
@pl.function(type=pl.FunctionType.InCore)
def main_incore_0(
self,
lhs: pl.Tensor[[2, 16, 128], pl.FP16],
rhs: pl.Tensor[[2, 128, 64], pl.FP16],
out_0: pl.Out[pl.Tensor[[2, 16, 64], pl.FP16]],
) -> pl.Tensor[[2, 16, 64], pl.FP16]:
lhs_tile: pl.Tile[[2, 16, 128], pl.FP16] = pl.load(
lhs, [0, 0, 0], [2, 16, 128], target_memory=pl.MemorySpace.Mat
)
rhs_tile: pl.Tile[[2, 128, 64], pl.FP16] = pl.load(
rhs, [0, 0, 0], [2, 128, 64], target_memory=pl.MemorySpace.Mat
)
out_tile: pl.Tile[[2, 16, 64], pl.FP32] = pl.tile.batch_matmul(lhs_tile, rhs_tile)
out_0 = pl.store(out_tile, [0, 0, 0], out_0)
return out_0

@pl.function
def main(
self,
lhs: pl.Tensor[[2, 16, 128], pl.FP16],
rhs: pl.Tensor[[2, 128, 64], pl.FP16],
) -> pl.Tensor[[2, 16, 64], pl.FP16]:
out_0 = pl.create_tensor([2, 16, 64], dtype=pl.FP16)
y = self.main_incore_0(lhs, rhs, out_0)
return y

func = self._flattened_incore(Before)
op_names = [call.op.name for call in self._top_level_calls(func)]
# Both batched -> no hoist: 2 × (lhs.load, rhs.load, matmul, store).
assert op_names == [
"tile.load",
"tile.load",
"tile.matmul",
"tile.store",
"tile.load",
"tile.load",
"tile.matmul",
"tile.store",
]


# ----------------------------------------------------------------------------
# tile.batch_matmul_acc lowering
Expand Down