
[Bug] TileType memory_space presence mismatch on print/parse roundtrip after auto-flatten of Mat tile.load #1278

@lyfne123

Component

Printer

Description

After PR #1266 canonicalized TileType.tile_view at construction (fixing #1261's tile_view presence mismatch), a related but distinct print/parse roundtrip bug remains: a memory_space presence mismatch between Var.type and the TileType produced by re-deducing the right-hand-side tile.load Call.

PyPTO has two in-memory TileType encodings that represent the same semantic state (a 2D Mat-resident tile loaded from a rank>2 tensor) but which are NOT print/parse equivalent:

Form A — auto-flattened (flatten_tile_nd_to_2d_pass.cpp:1369-1387 tile.load handler):

auto flat_tile_view = std::make_optional(
    tile_view_semantics::GetImplicitTileView(flat_shape_exprs, result_tile->memory_space_));
auto flat_tile_type = std::make_shared<TileType>(flat_shape_exprs, result_tile->dtype_, std::nullopt,
                                                  flat_tile_view, result_tile->memory_space_);

result_tile->memory_space_ is nullopt (DeduceTileLoadType doesn't set it). flat_tile_view is therefore the implicit form for (shape, nullopt). The constructor's CanonicalizeTileViewInPlace collapses it to nullopt. Final: TileType(shape, dtype, memref=nullopt, tile_view=nullopt, memory_space=nullopt). Printed as bare pl.Tile[shape, dtype].
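The collapse can be reproduced in a minimal, self-contained C++ model. This is a sketch, not PyPTO's real classes: SketchTileView, SketchTileType, ImplicitViewSketch and BuildFormA are hypothetical names, shape/dtype/memref are omitted, and only the nullopt case of the implicit view (blayout=row_major, slayout=none_box) comes from this issue; the Mat case is an assumption.

```cpp
#include <cassert>
#include <optional>

// Hypothetical, simplified stand-ins; not PyPTO's real enums or classes.
enum class MemorySpace { Mat };
enum class Layout { row_major, col_major, none_box };

struct SketchTileView {
  Layout blayout;
  Layout slayout;
  bool operator==(const SketchTileView& o) const {
    return blayout == o.blayout && slayout == o.slayout;
  }
};

// Stand-in for GetImplicitTileView (the shape argument is omitted).
SketchTileView ImplicitViewSketch(std::optional<MemorySpace> ms) {
  if (ms == MemorySpace::Mat) {
    return {Layout::col_major, Layout::row_major};  // assumption for Mat
  }
  return {Layout::row_major, Layout::none_box};  // nullopt case, as stated in the issue
}

struct SketchTileType {
  std::optional<SketchTileView> tile_view;
  std::optional<MemorySpace> memory_space;

  SketchTileType(std::optional<SketchTileView> tv, std::optional<MemorySpace> ms)
      : tile_view(tv), memory_space(ms) {
    // Models CanonicalizeTileViewInPlace: a view equal to the implicit one is dropped.
    if (tile_view && *tile_view == ImplicitViewSketch(memory_space)) tile_view.reset();
  }
};

// Form A: the implicit view is computed for whatever memory_space the deducer
// produced (nullopt today), so the constructor immediately collapses it.
SketchTileType BuildFormA(std::optional<MemorySpace> result_memory_space) {
  auto flat_tile_view = std::make_optional(ImplicitViewSketch(result_memory_space));
  return SketchTileType(flat_tile_view, result_memory_space);
}
```

In this model, BuildFormA(std::nullopt) yields a type with neither tile_view nor memory_space set, so a printer has nothing to emit beyond pl.Tile[shape, dtype].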

Form B — Strategy 1 transcribed (flatten_tile_nd_to_2d_pass.cpp:580-627 ExtractBatchPage):

flat_tile_view = TileView(flat_shape_exprs, /*stride=*/{}, /*start_offset=*/nullptr,
                          orig_tv.blayout, orig_tv.slayout, orig_tv.fractal, orig_tv.pad);
auto flat_type = std::make_shared<TileType>(
    flat_shape_exprs, ..., flat_tile_view, batch_load_type ? batch_load_type->memory_space_ : ...);

The transcribed tile_view carries blayout=col_major, slayout=row_major from the original Mat-load deducer's tile_view. CanonicalizeTileViewInPlace does NOT collapse it, because col_major/row_major is not the implicit layout when memory_space=nullopt (the implicit layout there is row_major/none_box). Final: TileType(shape, dtype, memref=nullopt, tile_view=explicit, memory_space=nullopt). Printed with an explicit pl.TileView(blayout=col_major, slayout=row_major) annotation.
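The printed difference between the two forms can be sketched with the same kind of simplified model. All names here are hypothetical, the Mat implicit view is an assumption, and PrintAnnotation only imitates the annotation shape quoted in this issue (shape and dtype are hard-coded to the [16, 128] FP16 tile from the failing test); the real printer's syntax may differ.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical, simplified stand-ins; not PyPTO's real enums or classes.
enum class MemorySpace { Mat };
enum class Layout { row_major, col_major, none_box };

struct SketchTileView {
  Layout blayout;
  Layout slayout;
  bool operator==(const SketchTileView& o) const {
    return blayout == o.blayout && slayout == o.slayout;
  }
};

SketchTileView ImplicitViewSketch(std::optional<MemorySpace> ms) {
  if (ms == MemorySpace::Mat) return {Layout::col_major, Layout::row_major};  // assumed
  return {Layout::row_major, Layout::none_box};  // implicit view for nullopt, per the issue
}

struct SketchTileType {
  std::optional<SketchTileView> tile_view;
  std::optional<MemorySpace> memory_space;
  SketchTileType(std::optional<SketchTileView> tv, std::optional<MemorySpace> ms)
      : tile_view(tv), memory_space(ms) {
    // Canonicalization: drop tile_view only if it equals the implicit view.
    if (tile_view && *tile_view == ImplicitViewSketch(memory_space)) tile_view.reset();
  }
};

const char* LayoutName(Layout l) {
  switch (l) {
    case Layout::row_major: return "row_major";
    case Layout::col_major: return "col_major";
    case Layout::none_box: return "none_box";
  }
  return "?";
}

// Illustrative printer: emits a TileView annotation only when tile_view is set.
std::string PrintAnnotation(const SketchTileType& t) {
  std::string s = "pl.Tile[[16, 128], pl.FP16";
  if (t.tile_view) {
    s += ", pl.TileView(blayout=" + std::string(LayoutName(t.tile_view->blayout)) +
         ", slayout=" + LayoutName(t.tile_view->slayout) + ")";
  }
  return s + "]";
}

// Form A: implicit view for nullopt, collapsed by the constructor.
SketchTileType BuildFormA() {
  return SketchTileType(ImplicitViewSketch(std::nullopt), std::nullopt);
}

// Form B: layout transcribed from the Mat-load deducer, memory_space still nullopt.
SketchTileType BuildFormB() {
  SketchTileView transcribed{Layout::col_major, Layout::row_major};
  return SketchTileType(transcribed, /*memory_space=*/std::nullopt);
}
```

In this model both forms leave memory_space unset, but only Form B's printed text carries layout information for a parser to reconstruct.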

When Form A is roundtripped through print → parse, the parser produces a TileType whose memory_space.has_value() differs from the original (presence mismatch on body[0].var). Form B roundtrips cleanly because the explicit TileView annotation carries enough information for the parser to reconstruct the same canonical type.

Steps to Reproduce

The bug is currently dormant because PR #1276 explicitly avoids Form A in FlattenTileNdTo2D::LowerBatchMatmul by gating its single-matmul fast path on batch_count_lhs > 1, leaving the batch_count_lhs == 1 case to Strategy 1 (Form B). To exercise the bug, remove that gate so the unified fast path also fires for batch_count_lhs == 1:

  1. In src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp, weaken the new fast-path condition to batch_count_rhs == 1 && !lhs_info.transpose (drop the batch_count_lhs > 1 clause). Also adjust the corresponding WillTakeSingleMatmulFastPath predicate in the pre-scan.
  2. Rebuild and run the batch-matmul test:
       cmake --build build --parallel
       export PYTHONPATH=$(pwd)/python
       python -m pytest tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py::TestFlattenTileNdTo2DBatchMatmul::test_batch_matmul_unrolls_kwargs -x --tb=long
  3. The [single_batch] parametrization (lhs [1,16,128], rhs [1,128,64], both Mat-loaded) fails inside RoundtripInstrument.

Expected Behavior

assert_structural_equal should be reflexive on roundtripped IR: any pass output must compare equal to its own print → parse roundtrip. Two semantically equivalent in-memory TileType encodings must roundtrip to the same canonical form so that RoundtripInstrument accepts either one as valid pass output.

Actual Behavior

RuntimeError: [RoundtripInstrument] Structural equality failed after pass 'FlattenTileNdTo2D'.

Error: Structural equality assertion failed at: ['main_incore_0'].body[0].var

Reason: TileType memory_space presence mismatch

The printed text for body[0] is:

lhs_tile: pl.Tile[[16, 128], pl.FP16] = pl.tile.load(lhs, [0, 0, 0], [1, 16, 128], [1, 16, 128], target_memory=pl.Mem.Mat, transpose=False)

The Var-annotation side parses to TileType(memory_space=nullopt); the Call's re-deduced result type ends up with memory_space.has_value() differing — hence the presence mismatch.
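The "presence mismatch" wording suggests that the structural-equality check compares whether an optional field is set before comparing its value. A minimal C++ sketch of that ordering follows; CompareOptionalField is a hypothetical helper, not PyPTO's implementation, and its "value mismatch" message is invented for illustration.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical stand-in for PyPTO's memory space enum.
enum class MemorySpace { Mat };

// Hypothetical helper: returns an empty string when the fields match,
// otherwise a reason string. Presence is checked before the value.
template <typename T>
std::string CompareOptionalField(const std::string& field, const std::optional<T>& lhs,
                                 const std::optional<T>& rhs) {
  if (lhs.has_value() != rhs.has_value()) {
    return "TileType " + field + " presence mismatch";
  }
  if (lhs && !(*lhs == *rhs)) {
    return "TileType " + field + " value mismatch";  // invented wording
  }
  return "";
}
```

The issue only states that the two sides differ in presence; in the check below, the Call side holding Mat is an illustrative choice.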

Git Commit ID

e6ef4fb

NPU Kind

N/A (not hardware-specific)

Host Platform

macOS (aarch64)

Additional Context

Suggested fixes (need design discussion — not picking one in this issue):

  1. Make DeduceTileLoadType write target_memory directly into memory_space_ when present (src/ir/op/tile_ops/memory.cpp:188). Eliminates the "Mat-style tile_view but memory_space=nullopt" intermediate state at construction. Would need to verify that downstream InferTileMemorySpace is OK with memory_space already being set.
  2. Make the parser backfill target_memory into the result Var's TileType.memory_space_ when it sees pl.tile.load(target_memory=X, ...). Symmetric with option 1 but on the parse side.
  3. Tighten CanonicalizeTileViewInPlace: don't collapse tile_view to nullopt when its layout is Mat-style and memory_space is nullopt — keep it explicit so the printer emits the TileView annotation that Form B produces. This effectively makes Form A converge to Form B at construction time.
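Option 1 can be sketched in a simplified model: the deducer copies target_memory into memory_space, so the implicit view is computed for the real memory space and a Mat-style view never coexists with memory_space=nullopt. All names below (DeduceTileLoadTypeOption1, HasMatViewWithoutSpace, the Sketch types) are hypothetical, and treating blayout=col_major, slayout=row_major as the implicit Mat view is an assumption based on the layout this issue attributes to the Mat-load deducer.

```cpp
#include <cassert>
#include <optional>

// Hypothetical, simplified stand-ins; not PyPTO's real enums or classes.
enum class MemorySpace { Mat };
enum class Layout { row_major, col_major, none_box };

struct SketchTileView {
  Layout blayout;
  Layout slayout;
  bool operator==(const SketchTileView& o) const {
    return blayout == o.blayout && slayout == o.slayout;
  }
};

SketchTileView ImplicitViewSketch(std::optional<MemorySpace> ms) {
  if (ms == MemorySpace::Mat) return {Layout::col_major, Layout::row_major};  // assumption
  return {Layout::row_major, Layout::none_box};  // nullopt case, per the issue
}

struct SketchTileType {
  std::optional<SketchTileView> tile_view;
  std::optional<MemorySpace> memory_space;
  SketchTileType(std::optional<SketchTileView> tv, std::optional<MemorySpace> ms)
      : tile_view(tv), memory_space(ms) {
    // Canonicalization: drop tile_view only if it equals the implicit view.
    if (tile_view && *tile_view == ImplicitViewSketch(memory_space)) tile_view.reset();
  }
};

// Hypothetical option-1 deducer: target_memory flows straight into memory_space.
SketchTileType DeduceTileLoadTypeOption1(std::optional<MemorySpace> target_memory) {
  return SketchTileType(ImplicitViewSketch(target_memory), target_memory);
}

// The intermediate state option 1 aims to eliminate: a Mat-style view with
// memory_space left unset.
bool HasMatViewWithoutSpace(const SketchTileType& t) {
  return t.tile_view && !t.memory_space &&
         *t.tile_view == ImplicitViewSketch(MemorySpace::Mat);
}
```

Under this sketch, the auto-flatten site would presumably compute GetImplicitTileView for Mat rather than nullopt; whether InferTileMemorySpace tolerates the pre-set memory_space is the open question the option already raises.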

Workaround in PR #1276:

const bool can_single_matmul =
    batch_count_lhs > 1 && batch_count_rhs == 1 && !lhs_info.transpose;

The batch_count_lhs > 1 clause exists solely to avoid Form A on Mat-loaded operands. Once this issue is fixed, the workaround can be relaxed.
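To see which cases the gate diverts, here is a small C++ sketch comparing the PR #1276 predicate with the relaxed one from the reproduction steps. Only the boolean expressions come from this issue; the function names are hypothetical, and deriving batch counts of 1 for the [single_batch] case (lhs [1,16,128], rhs [1,128,64]) assumes the batch count is the product of the leading dimensions.

```cpp
#include <cassert>

// Predicate shipped in PR #1276: the batch_count_lhs > 1 clause keeps
// single-batch operands on Strategy 1 (Form B) instead of the fast path (Form A).
bool CanSingleMatmulGated(long batch_count_lhs, long batch_count_rhs, bool lhs_transpose) {
  return batch_count_lhs > 1 && batch_count_rhs == 1 && !lhs_transpose;
}

// Relaxed predicate from step 1 of the reproduction: the lhs clause is dropped.
bool CanSingleMatmulRelaxed(long batch_count_rhs, bool lhs_transpose) {
  return batch_count_rhs == 1 && !lhs_transpose;
}
```

For the [single_batch] case the gated predicate returns false while the relaxed one returns true, which is exactly the case that exposes Form A.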

Severity: medium — currently dormant, but blocks any future pass / test from producing auto-flattened Mat tiles in an assert_structural_equal or RoundtripInstrument context.

Code locations:

  • src/ir/op/tile_ops/memory.cpp:188, DeduceTileLoadType (omits memory_space)
  • src/ir/type.cpp:56-62, CanonicalizeTileViewInPlace (collapses tile_view to nullopt)
  • src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp:1369-1387, auto-flatten construction site
  • src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp:580-627, Strategy 1 (the working form)
  • include/pypto/ir/tile_view_semantics.h:62-89, GetImplicitTileView

Related: #1261 (closed: same family of roundtrip bug, but for tile_view presence; fixed by #1266), #1262 (open: Var/Call memory_space asymmetry from InferTileMemorySpace), #1276 (the PR that worked around this issue).

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: Status Ready
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests