Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/backend/common/pto_ops_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,19 @@ static codegen::TileTypeComponents InferSubviewTileTypeComponents(const ir::Tile
*out_dynamic = false;
return;
}
// offset=0 and slice shape matches the full source valid extent
if (offset_const && offset_const->value_ == 0 && dim_idx < source_valid.size() &&
ExprsEquivalentForSubview(shape_tuple.elements_[dim_idx], source_valid[dim_idx])) {
*out_value = size;
*out_dynamic = false;
return;
}
// Any valid offset leaves exactly `size` valid elements in this dimension,
// so v_row/v_col is statically known regardless of the offset value.
if (valid_const && valid_const->value_ >= size) {
*out_value = size;
*out_dynamic = false;
return;
}
};

Expand Down
35 changes: 35 additions & 0 deletions tests/ut/codegen/test_pto_codegen_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,41 @@ def kernel(
f"subview result type should be rows=4,cols=64 (slice shape), got:\n{line}"
)

def test_tile_slice_static_valid_inferred_with_dynamic_offset(self):
    """Check that tile.slice with dynamic offsets still gets static v_row/v_col.

    Regression test for InferSubviewTileTypeComponents, which was missing the
    case where the source valid extent is a static constant >= the requested
    slice size but the offset is dynamic. Before the fix, v_row/v_col were left
    as dynamic ('?') even though any offset within bounds leaves exactly `size`
    valid elements.

    The kernel below loads a [32, 32] tile and slices a [16, 16] window out of
    it at offsets taken from scalar kernel parameters, so the offsets are not
    compile-time constants while the source extent and slice size are.
    """

    # Source extent [32, 32] and slice size [16, 16] are literals; only the
    # offsets (row_off, col_off) come in as runtime scalars.
    @pl.program
    class Prog:
        @pl.function(type=pl.FunctionType.InCore)
        def kernel(
            self,
            src: pl.Tensor[[32, 32], pl.FP32],
            row_off: pl.Scalar[pl.INDEX],
            col_off: pl.Scalar[pl.INDEX],
            dst: pl.Tensor[[16, 16], pl.FP32],
        ) -> pl.Tensor[[16, 16], pl.FP32]:
            src_tile: pl.Tile[[32, 32], pl.FP32] = pl.load(src, [0, 0], [32, 32])
            sliced: pl.Tile[[16, 16], pl.FP32] = pl.tile.slice(src_tile, [16, 16], [row_off, col_off])
            return pl.store(sliced, [0, 0], dst)

    mlir = self._generate_mlir(Prog)
    # Only lines mentioning pto.subview are relevant; fail early if none exist
    # so the index below cannot raise an unhelpful IndexError.
    subview_lines = [line for line in mlir.splitlines() if "pto.subview" in line]
    assert subview_lines, "no pto.subview line emitted"
    # Keep only the text after the first "->" of the first subview line, i.e.
    # the result type, so operand types on the left cannot satisfy the checks.
    # If the line has no "->", split() returns it whole and it is checked as-is.
    result_type = subview_lines[0].split("->", 1)[-1]
    # Positive check: both valid extents must be the static slice size.
    assert "v_row=16" in result_type and "v_col=16" in result_type, (
        f"Source valid [32,32] >= slice size [16,16] with dynamic offset must yield "
        f"static v_row=16, v_col=16 in result tile_buf type, got:\n{subview_lines[0]}"
    )
    # Negative check: explicitly reject the dynamic '?' marker that was emitted
    # before the fix.
    assert "v_row=?" not in result_type and "v_col=?" not in result_type, (
        f"v_row/v_col must not be dynamic ('?') when source valid >= slice size, got:\n{subview_lines[0]}"
    )


class TestTileAssembleCodegen:
"""Tests for tile.assemble PTO code generation (pto.subview + pto.tmov).
Expand Down
Loading