From 0aa895276a0902c49d009310e885cff8a2f47602 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Fri, 5 Jun 2026 09:27:57 -0700 Subject: [PATCH 1/2] [Tile] Use unpacked vector field for Tile16x16/Tile32x32 register storage Replace hand-rolled ``r0..rN-1: dtype`` field declarations and their matching ``if k == 0: self.r0 = val; ...`` cascades with a single ``r: qd.types.vector(_TILE, dtype, unpacked=True)`` field accessed via ``self.r[k]``. This shrinks the surface area significantly (net -870 lines) without changing the generated PTX/LLVM IR: with python-int / qd.static-resolved indices the unpacked field still maps to one register slot per use, matching what the explicit cascade produced. Also removes the now-redundant private helpers ``_get_col``, ``_set_col``, ``_r`` and the ``_REGS`` field-name table. --- python/quadrants/lang/simt/_tile16.py | 358 ++------------ python/quadrants/lang/simt/_tile32.py | 642 ++------------------------ 2 files changed, 65 insertions(+), 935 deletions(-) diff --git a/python/quadrants/lang/simt/_tile16.py b/python/quadrants/lang/simt/_tile16.py index 325e5bf61d..d0422a23f6 100644 --- a/python/quadrants/lang/simt/_tile16.py +++ b/python/quadrants/lang/simt/_tile16.py @@ -4,7 +4,8 @@ Register-resident 16x16 tile operations. Each tile is a 16x16 matrix distributed across 16 threads in a subgroup, one row per thread, with each row stored -in 16 scalar registers (r0-r15). Cross-thread communication uses subgroup shuffles -- no shared memory needed. +in 16 scalar registers held in an unpacked vector field (``self.r``). Cross-thread communication uses subgroup +shuffles -- no shared memory needed. The thread's lane index (tid) is obtained internally via ``subgroup.invocation_id()``, so callers never need to pass it. See docs/source/user_guide/tile.md for usage documentation. @@ -40,8 +41,6 @@ def _load3d( def _store3d( self, arr: Any, batch: Any, row_start: Any, row_end: Any, col_start: Any, col_end: Any ) -> None: ... # noqa: E704 - def _get_col(self, k: Any) -> Any: ... # noqa: E704 - def _set_col(self, k: Any, val: Any) -> None: ... # noqa: E704 def _ger_sub(self, a: Any, b: Any) -> None: ... # noqa: E704 def _trsm(self, L: "_Tile16x16Proto") -> None: ... # noqa: E704 def __isub__(self, other: Any) -> "_Tile16x16Proto": ... # noqa: E704 @@ -51,14 +50,6 @@ def __setitem__(self, key: Any, value: Any) -> None: ... # noqa: E704 _TILE = 16 -# Field-name lookup table for direct register access in qd.static-unrolled loops. Used via ``self._r(k)`` (defined -# below) which is just ``getattr(self, _REGS[k])``. With a python-int ``k`` (which is what ``qd.static(range(_TILE))`` -# binds inside its body) this collapses to a single field-reference AST node, vs. the _TILE-way ``if k == 0: val = -# self.r0; if k == 1: ...`` cascade emitted by a dynamic ``_get_col(k)`` call. Empirically this cuts cold-compile time -# significantly on hot Cholesky paths because every such call site avoids re-emitting (and later folding) _TILE -# conditional nodes per use. -_REGS = tuple(f"r{i}" for i in range(_TILE)) - class _OuterProduct: """Deferred outer product proxy for use with augmented assignment on Tile16x16. @@ -210,25 +201,10 @@ def _make_tile16x16(dtype=None) -> "type[_Tile16x16Proto]": def _make_tile16x16_class(dtype): class _Tile16x16: - """A 16x16 tile distributed one row per subgroup thread, held in 16 scalar registers. All fields default to - 0.0 when omitted: ``Tile16x16()`` creates a zero tile.""" - - r0: dtype - r1: dtype - r2: dtype - r3: dtype - r4: dtype - r5: dtype - r6: dtype - r7: dtype - r8: dtype - r9: dtype - r10: dtype - r11: dtype - r12: dtype - r13: dtype - r14: dtype - r15: dtype + """A 16x16 tile distributed one row per subgroup thread, with each row held in 16 scalar registers via an + unpacked vector field. ``Tile16x16()`` creates a zero tile.""" + + r: qd.types.vector(_TILE, dtype, unpacked=True) @qd.func def _load(self, arr: qd.template(), row_start, row_stop, col_start, col_stop): @@ -245,43 +221,9 @@ def _load(self, arr: qd.template(), row_start, row_stop, col_start, col_stop): arr_col_stop = arr.shape[1] if arr_col_stop < col_stop: col_stop = arr_col_stop - # Inline cascade: with j a python-int from qd.static, only the matching branch is emitted into the AST. - # Avoids the _TILE-way duplication that calling _set_col(j) through the @qd.func boundary would force. for j in qd.static(range(_TILE)): if col_start + j < col_stop: - val = arr[row, col_start + j] - if j == 0: - self.r0 = val - if j == 1: - self.r1 = val - if j == 2: - self.r2 = val - if j == 3: - self.r3 = val - if j == 4: - self.r4 = val - if j == 5: - self.r5 = val - if j == 6: - self.r6 = val - if j == 7: - self.r7 = val - if j == 8: - self.r8 = val - if j == 9: - self.r9 = val - if j == 10: - self.r10 = val - if j == 11: - self.r11 = val - if j == 12: - self.r12 = val - if j == 13: - self.r13 = val - if j == 14: - self.r14 = val - if j == 15: - self.r15 = val + self.r[j] = arr[row, col_start + j] @qd.func def _load3d(self, arr: qd.template(), batch, row_start, row_stop, col_start, col_stop): @@ -300,39 +242,7 @@ def _load3d(self, arr: qd.template(), batch, row_start, row_stop, col_start, col col_stop = arr_col_stop for j in qd.static(range(_TILE)): if col_start + j < col_stop: - val = arr[batch, row, col_start + j] - if j == 0: - self.r0 = val - if j == 1: - self.r1 = val - if j == 2: - self.r2 = val - if j == 3: - self.r3 = val - if j == 4: - self.r4 = val - if j == 5: - self.r5 = val - if j == 6: - self.r6 = val - if j == 7: - self.r7 = val - if j == 8: - self.r8 = val - if j == 9: - self.r9 = val - if j == 10: - self.r10 = val - if j == 11: - self.r11 = val - if j == 12: - self.r12 = val - if j == 13: - self.r13 = val - if j == 14: - self.r14 = val - if j == 15: - self.r15 = val + self.r[j] = arr[batch, row, col_start + j] @qd.func def _store(self, arr: qd.template(), row_start, row_stop, col_start, col_stop): @@ -351,7 +261,7 @@ def _store(self, arr: qd.template(), row_start, row_stop, col_start, col_stop): col_stop = arr_col_stop for j in qd.static(range(_TILE)): if col_start + j < col_stop: - arr[row, col_start + j] = self._r(j) + arr[row, col_start + j] = self.r[j] @qd.func def _store3d(self, arr: qd.template(), batch, row_start, row_stop, col_start, col_stop): @@ -370,7 +280,7 @@ def _store3d(self, arr: qd.template(), batch, row_start, row_stop, col_start, co col_stop = arr_col_stop for j in qd.static(range(_TILE)): if col_start + j < col_stop: - arr[batch, row, col_start + j] = self._r(j) + arr[batch, row, col_start + j] = self.r[j] @qd.func def eye_(self): @@ -378,152 +288,14 @@ def eye_(self): others to 0.0.""" tid = qd.simt.subgroup.invocation_id() for j in qd.static(range(_TILE)): - val = 1.0 if tid == j else 0.0 - if j == 0: - self.r0 = val - if j == 1: - self.r1 = val - if j == 2: - self.r2 = val - if j == 3: - self.r3 = val - if j == 4: - self.r4 = val - if j == 5: - self.r5 = val - if j == 6: - self.r6 = val - if j == 7: - self.r7 = val - if j == 8: - self.r8 = val - if j == 9: - self.r9 = val - if j == 10: - self.r10 = val - if j == 11: - self.r11 = val - if j == 12: - self.r12 = val - if j == 13: - self.r13 = val - if j == 14: - self.r14 = val - if j == 15: - self.r15 = val - - @qd.func - def _get_col(self, k): - """Return the value of register (column) k.""" - val = qd.cast(0.0, dtype) - if k == 0: - val = self.r0 - if k == 1: - val = self.r1 - if k == 2: - val = self.r2 - if k == 3: - val = self.r3 - if k == 4: - val = self.r4 - if k == 5: - val = self.r5 - if k == 6: - val = self.r6 - if k == 7: - val = self.r7 - if k == 8: - val = self.r8 - if k == 9: - val = self.r9 - if k == 10: - val = self.r10 - if k == 11: - val = self.r11 - if k == 12: - val = self.r12 - if k == 13: - val = self.r13 - if k == 14: - val = self.r14 - if k == 15: - val = self.r15 - return val - - @qd.func - def _set_col(self, k, val): - """Set register (column) k to val.""" - if k == 0: - self.r0 = val - if k == 1: - self.r1 = val - if k == 2: - self.r2 = val - if k == 3: - self.r3 = val - if k == 4: - self.r4 = val - if k == 5: - self.r5 = val - if k == 6: - self.r6 = val - if k == 7: - self.r7 = val - if k == 8: - self.r8 = val - if k == 9: - self.r9 = val - if k == 10: - self.r10 = val - if k == 11: - self.r11 = val - if k == 12: - self.r12 = val - if k == 13: - self.r13 = val - if k == 14: - self.r14 = val - if k == 15: - self.r15 = val + self.r[j] = 1.0 if tid == j else 0.0 @qd.func def _ger_sub(self, a, b): """General rank-1 subtract in-place: self -= a @ b^T.""" for j in qd.static(range(_TILE)): bc = qd.simt.subgroup.shuffle(b, qd.u32(j)) - val = self._r(j) - a * bc - if j == 0: - self.r0 = val - if j == 1: - self.r1 = val - if j == 2: - self.r2 = val - if j == 3: - self.r3 = val - if j == 4: - self.r4 = val - if j == 5: - self.r5 = val - if j == 6: - self.r6 = val - if j == 7: - self.r7 = val - if j == 8: - self.r8 = val - if j == 9: - self.r9 = val - if j == 10: - self.r10 = val - if j == 11: - self.r11 = val - if j == 12: - self.r12 = val - if j == 13: - self.r13 = val - if j == 14: - self.r14 = val - if j == 15: - self.r15 = val + self.r[j] = self.r[j] - a * bc @qd.func def cholesky_(self, eps): @@ -532,52 +304,19 @@ def cholesky_(self, eps): On return, the lower triangle holds L such that A = L @ L^T. Diagonal clamped to sqrt(max(value, eps)) for numerical stability. """ - # ``k`` and ``j`` are wrapped in qd.static so the ``if k > j`` predicates fold at compile time and register - # access on the outer ``k`` and inner ``j`` collapses to a single field reference via ``self._r()`` - # (a thin getattr wrapper) rather than a _TILE-deep register-indexing cascade. Writes use an inline - # ``if k == N: self.rN = ...`` chain (setattr is rejected by the quadrants AST builder) which the AST - # transformer folds at build time when ``k`` is a python int. The per-lane row-norm used for the diagonal - # update is carried in ``my_norm_sq``, so each diagonal step is O(1) rather than O(k). The off-diagonal - # ``dot`` is split into two interleaved partial sums (``dot0`` / ``dot1``) so the back-to-back FMA - # dependency chain is cut in half, exposing more instruction-level parallelism. + # ``k`` and ``j`` are wrapped in qd.static so the ``if k > j`` predicate folds at compile time and the + # ``self.r[k]`` / ``self.r[j]`` accesses resolve to a single unpacked-register slot per use (no runtime + # cascade). The per-lane row-norm used for the diagonal update is carried in ``my_norm_sq``, so each + # diagonal step is O(1) rather than O(k). The off-diagonal ``dot`` is split into two interleaved partial + # sums (``dot0`` / ``dot1``) so the back-to-back FMA dependency chain is cut in half, exposing more + # instruction-level parallelism. tid = qd.i32(qd.simt.subgroup.invocation_id()) my_norm_sq = qd.cast(0.0, dtype) for k in qd.static(range(_TILE)): diag_val = qd.cast(0.0, dtype) if tid == k: - diag_val = qd.sqrt(qd.max(self._r(k) - my_norm_sq, eps)) - if k == 0: - self.r0 = diag_val - if k == 1: - self.r1 = diag_val - if k == 2: - self.r2 = diag_val - if k == 3: - self.r3 = diag_val - if k == 4: - self.r4 = diag_val - if k == 5: - self.r5 = diag_val - if k == 6: - self.r6 = diag_val - if k == 7: - self.r7 = diag_val - if k == 8: - self.r8 = diag_val - if k == 9: - self.r9 = diag_val - if k == 10: - self.r10 = diag_val - if k == 11: - self.r11 = diag_val - if k == 12: - self.r12 = diag_val - if k == 13: - self.r13 = diag_val - if k == 14: - self.r14 = diag_val - if k == 15: - self.r15 = diag_val + diag_val = qd.sqrt(qd.max(self.r[k] - my_norm_sq, eps)) + self.r[k] = diag_val diag_k = qd.simt.subgroup.shuffle(diag_val, qd.u32(k)) @@ -585,7 +324,7 @@ def cholesky_(self, eps): dot1 = qd.cast(0.0, dtype) for j in qd.static(range(_TILE)): if k > j: - my_col = self._r(j) + my_col = self.r[j] Lkj = qd.simt.subgroup.shuffle(my_col, qd.u32(k)) if j % 2 == 0: dot0 += Lkj * my_col # type: ignore[reportOperatorIssue] @@ -595,39 +334,8 @@ def cholesky_(self, eps): new_val = qd.cast(0.0, dtype) if tid > k: # type: ignore[reportOperatorIssue] - new_val = (self._r(k) - dot) / diag_k # type: ignore[reportOperatorIssue] - if k == 0: - self.r0 = new_val - if k == 1: - self.r1 = new_val - if k == 2: - self.r2 = new_val - if k == 3: - self.r3 = new_val - if k == 4: - self.r4 = new_val - if k == 5: - self.r5 = new_val - if k == 6: - self.r6 = new_val - if k == 7: - self.r7 = new_val - if k == 8: - self.r8 = new_val - if k == 9: - self.r9 = new_val - if k == 10: - self.r10 = new_val - if k == 11: - self.r11 = new_val - if k == 12: - self.r12 = new_val - if k == 13: - self.r13 = new_val - if k == 14: - self.r14 = new_val - if k == 15: - self.r15 = new_val + new_val = (self.r[k] - dot) / diag_k # type: ignore[reportOperatorIssue] + self.r[k] = new_val if tid > k: # type: ignore[reportOperatorIssue] my_norm_sq += new_val * new_val @@ -638,16 +346,15 @@ def _trsm(self, L): L is a Tile16x16 holding the lower-triangular Cholesky factor (from cholesky_). On return, self holds the solution X. """ - for c in range(_TILE): + for c in qd.static(range(_TILE)): dot = qd.cast(0.0, dtype) - for j in range(_TILE): + for j in qd.static(range(_TILE)): if c > j: - Lkj = qd.simt.subgroup.shuffle(L._get_col(j), qd.u32(c)) - dot += self._get_col(j) * Lkj # type: ignore[reportOperatorIssue] + Lkj = qd.simt.subgroup.shuffle(L.r[j], qd.u32(c)) + dot += self.r[j] * Lkj # type: ignore[reportOperatorIssue] - diag_c = qd.simt.subgroup.shuffle(L._get_col(c), qd.u32(c)) - new_val = (self._get_col(c) - dot) / diag_c # type: ignore[reportOperatorIssue] - self._set_col(c, new_val) + diag_c = qd.simt.subgroup.shuffle(L.r[c], qd.u32(c)) + self.r[c] = (self.r[c] - dot) / diag_c # type: ignore[reportOperatorIssue] def solve_triangular_(self, B: Any, lower: bool = True) -> None: """Triangular solve: X @ self^T = B, storing result X in B in-place. @@ -659,13 +366,6 @@ def solve_triangular_(self, B: Any, lower: bool = True) -> None: raise TypeError("Tile16x16.solve_triangular_: only lower=True is supported") B._trsm(self) - def _r(self, k): - """Direct field read by python-int index. Used at qd.static-unrolled call sites to bypass the _TILE-way - ``_get_col(k)`` cascade: with ``k`` a python int (from ``qd.static(range(_TILE))``), - ``getattr(self, _REGS[k])`` is evaluated by the AST transformer at build time and returns a single - field-reference expression.""" - return getattr(self, _REGS[k]) - @qd.func def _resolve_vec2d(self, arr: qd.template(), row_start, row_stop, col): """Load one scalar per thread from a 2D array column, clamped to array bounds.""" diff --git a/python/quadrants/lang/simt/_tile32.py b/python/quadrants/lang/simt/_tile32.py index 9808f462ac..2d03f63f87 100644 --- a/python/quadrants/lang/simt/_tile32.py +++ b/python/quadrants/lang/simt/_tile32.py @@ -4,7 +4,8 @@ Register-resident 32x32 tile operations. Each tile is a 32x32 matrix distributed across 32 threads in a subgroup, one row per thread, with each row stored -in 32 scalar registers (r0-r31). Cross-thread communication uses subgroup shuffles -- no shared memory needed. +in 32 scalar registers held in an unpacked vector field (``self.r``). Cross-thread communication uses subgroup +shuffles -- no shared memory needed. Surface mirrors :mod:`quadrants.lang.simt._tile16`: use ``qd.simt.Tile32x32.zeros(dtype=...)`` / ``qd.simt.Tile32x32.eye(dtype=...)`` inside a kernel, then load / store via slice syntax @@ -48,8 +49,6 @@ def _load3d( def _store3d( self, arr: Any, batch: Any, row_start: Any, row_end: Any, col_start: Any, col_end: Any ) -> None: ... # noqa: E704 - def _get_col(self, k: Any) -> Any: ... # noqa: E704 - def _set_col(self, k: Any, val: Any) -> None: ... # noqa: E704 def _ger_sub(self, a: Any, b: Any) -> None: ... # noqa: E704 def _trsm(self, L: "_Tile32x32Proto") -> None: ... # noqa: E704 def __isub__(self, other: Any) -> "_Tile32x32Proto": ... # noqa: E704 @@ -59,11 +58,6 @@ def __setitem__(self, key: Any, value: Any) -> None: ... # noqa: E704 _TILE = 32 -# Field-name lookup table for direct register access in qd.static-unrolled loops. See ``_tile16._REGS`` for the -# rationale; the same trick applies here, just with 32 fields instead of 16. -_REGS = tuple(f"r{i}" for i in range(_TILE)) - - _tile32_cache = {} @@ -83,41 +77,10 @@ def _make_tile32x32(dtype=None) -> "type[_Tile32x32Proto]": def _make_tile32x32_class(dtype): class _Tile32x32: - """A 32x32 tile distributed one row per subgroup thread, held in 32 scalar registers. All fields default to - 0.0 when omitted: ``Tile32x32()`` creates a zero tile.""" - - r0: dtype - r1: dtype - r2: dtype - r3: dtype - r4: dtype - r5: dtype - r6: dtype - r7: dtype - r8: dtype - r9: dtype - r10: dtype - r11: dtype - r12: dtype - r13: dtype - r14: dtype - r15: dtype - r16: dtype - r17: dtype - r18: dtype - r19: dtype - r20: dtype - r21: dtype - r22: dtype - r23: dtype - r24: dtype - r25: dtype - r26: dtype - r27: dtype - r28: dtype - r29: dtype - r30: dtype - r31: dtype + """A 32x32 tile distributed one row per subgroup thread, with each row held in 32 scalar registers via an + unpacked vector field. ``Tile32x32()`` creates a zero tile.""" + + r: qd.types.vector(_TILE, dtype, unpacked=True) @qd.func def _load(self, arr: qd.template(), row_start, row_stop, col_start, col_stop): @@ -134,75 +97,9 @@ def _load(self, arr: qd.template(), row_start, row_stop, col_start, col_stop): arr_col_stop = arr.shape[1] if arr_col_stop < col_stop: col_stop = arr_col_stop - # Inline cascade: with j a python-int from qd.static, only the matching branch is emitted into the AST. - # Avoids the 32x duplication that calling _set_col(j) through the @qd.func boundary would force. - for j in qd.static(range(32)): + for j in qd.static(range(_TILE)): if col_start + j < col_stop: - val = arr[row, col_start + j] - if j == 0: - self.r0 = val - if j == 1: - self.r1 = val - if j == 2: - self.r2 = val - if j == 3: - self.r3 = val - if j == 4: - self.r4 = val - if j == 5: - self.r5 = val - if j == 6: - self.r6 = val - if j == 7: - self.r7 = val - if j == 8: - self.r8 = val - if j == 9: - self.r9 = val - if j == 10: - self.r10 = val - if j == 11: - self.r11 = val - if j == 12: - self.r12 = val - if j == 13: - self.r13 = val - if j == 14: - self.r14 = val - if j == 15: - self.r15 = val - if j == 16: - self.r16 = val - if j == 17: - self.r17 = val - if j == 18: - self.r18 = val - if j == 19: - self.r19 = val - if j == 20: - self.r20 = val - if j == 21: - self.r21 = val - if j == 22: - self.r22 = val - if j == 23: - self.r23 = val - if j == 24: - self.r24 = val - if j == 25: - self.r25 = val - if j == 26: - self.r26 = val - if j == 27: - self.r27 = val - if j == 28: - self.r28 = val - if j == 29: - self.r29 = val - if j == 30: - self.r30 = val - if j == 31: - self.r31 = val + self.r[j] = arr[row, col_start + j] @qd.func def _load3d(self, arr: qd.template(), batch, row_start, row_stop, col_start, col_stop): @@ -219,73 +116,9 @@ def _load3d(self, arr: qd.template(), batch, row_start, row_stop, col_start, col arr_col_stop = arr.shape[2] if arr_col_stop < col_stop: col_stop = arr_col_stop - for j in qd.static(range(32)): + for j in qd.static(range(_TILE)): if col_start + j < col_stop: - val = arr[batch, row, col_start + j] - if j == 0: - self.r0 = val - if j == 1: - self.r1 = val - if j == 2: - self.r2 = val - if j == 3: - self.r3 = val - if j == 4: - self.r4 = val - if j == 5: - self.r5 = val - if j == 6: - self.r6 = val - if j == 7: - self.r7 = val - if j == 8: - self.r8 = val - if j == 9: - self.r9 = val - if j == 10: - self.r10 = val - if j == 11: - self.r11 = val - if j == 12: - self.r12 = val - if j == 13: - self.r13 = val - if j == 14: - self.r14 = val - if j == 15: - self.r15 = val - if j == 16: - self.r16 = val - if j == 17: - self.r17 = val - if j == 18: - self.r18 = val - if j == 19: - self.r19 = val - if j == 20: - self.r20 = val - if j == 21: - self.r21 = val - if j == 22: - self.r22 = val - if j == 23: - self.r23 = val - if j == 24: - self.r24 = val - if j == 25: - self.r25 = val - if j == 26: - self.r26 = val - if j == 27: - self.r27 = val - if j == 28: - self.r28 = val - if j == 29: - self.r29 = val - if j == 30: - self.r30 = val - if j == 31: - self.r31 = val + self.r[j] = arr[batch, row, col_start + j] @qd.func def _store(self, arr: qd.template(), row_start, row_stop, col_start, col_stop): @@ -302,9 +135,9 @@ def _store(self, arr: qd.template(), row_start, row_stop, col_start, col_stop): arr_col_stop = arr.shape[1] if arr_col_stop < col_stop: col_stop = arr_col_stop - for j in qd.static(range(32)): + for j in qd.static(range(_TILE)): if col_start + j < col_stop: - arr[row, col_start + j] = self._r(j) + arr[row, col_start + j] = self.r[j] @qd.func def _store3d(self, arr: qd.template(), batch, row_start, row_stop, col_start, col_stop): @@ -321,290 +154,24 @@ def _store3d(self, arr: qd.template(), batch, row_start, row_stop, col_start, co arr_col_stop = arr.shape[2] if arr_col_stop < col_stop: col_stop = arr_col_stop - for j in qd.static(range(32)): + for j in qd.static(range(_TILE)): if col_start + j < col_stop: - arr[batch, row, col_start + j] = self._r(j) + arr[batch, row, col_start + j] = self.r[j] @qd.func def eye_(self): """Set this tile to the 32x32 identity matrix. Each thread sets its diagonal element to 1.0 and all others to 0.0.""" tid = qd.simt.subgroup.invocation_id() - for j in qd.static(range(32)): - val = 1.0 if tid == j else 0.0 - if j == 0: - self.r0 = val - if j == 1: - self.r1 = val - if j == 2: - self.r2 = val - if j == 3: - self.r3 = val - if j == 4: - self.r4 = val - if j == 5: - self.r5 = val - if j == 6: - self.r6 = val - if j == 7: - self.r7 = val - if j == 8: - self.r8 = val - if j == 9: - self.r9 = val - if j == 10: - self.r10 = val - if j == 11: - self.r11 = val - if j == 12: - self.r12 = val - if j == 13: - self.r13 = val - if j == 14: - self.r14 = val - if j == 15: - self.r15 = val - if j == 16: - self.r16 = val - if j == 17: - self.r17 = val - if j == 18: - self.r18 = val - if j == 19: - self.r19 = val - if j == 20: - self.r20 = val - if j == 21: - self.r21 = val - if j == 22: - self.r22 = val - if j == 23: - self.r23 = val - if j == 24: - self.r24 = val - if j == 25: - self.r25 = val - if j == 26: - self.r26 = val - if j == 27: - self.r27 = val - if j == 28: - self.r28 = val - if j == 29: - self.r29 = val - if j == 30: - self.r30 = val - if j == 31: - self.r31 = val - - @qd.func - def _get_col(self, k): - """Return the value of register (column) k.""" - val = qd.cast(0.0, dtype) - if k == 0: - val = self.r0 - if k == 1: - val = self.r1 - if k == 2: - val = self.r2 - if k == 3: - val = self.r3 - if k == 4: - val = self.r4 - if k == 5: - val = self.r5 - if k == 6: - val = self.r6 - if k == 7: - val = self.r7 - if k == 8: - val = self.r8 - if k == 9: - val = self.r9 - if k == 10: - val = self.r10 - if k == 11: - val = self.r11 - if k == 12: - val = self.r12 - if k == 13: - val = self.r13 - if k == 14: - val = self.r14 - if k == 15: - val = self.r15 - if k == 16: - val = self.r16 - if k == 17: - val = self.r17 - if k == 18: - val = self.r18 - if k == 19: - val = self.r19 - if k == 20: - val = self.r20 - if k == 21: - val = self.r21 - if k == 22: - val = self.r22 - if k == 23: - val = self.r23 - if k == 24: - val = self.r24 - if k == 25: - val = self.r25 - if k == 26: - val = self.r26 - if k == 27: - val = self.r27 - if k == 28: - val = self.r28 - if k == 29: - val = self.r29 - if k == 30: - val = self.r30 - if k == 31: - val = self.r31 - return val - - @qd.func - def _set_col(self, k, val): - """Set register (column) k to val.""" - if k == 0: - self.r0 = val - if k == 1: - self.r1 = val - if k == 2: - self.r2 = val - if k == 3: - self.r3 = val - if k == 4: - self.r4 = val - if k == 5: - self.r5 = val - if k == 6: - self.r6 = val - if k == 7: - self.r7 = val - if k == 8: - self.r8 = val - if k == 9: - self.r9 = val - if k == 10: - self.r10 = val - if k == 11: - self.r11 = val - if k == 12: - self.r12 = val - if k == 13: - self.r13 = val - if k == 14: - self.r14 = val - if k == 15: - self.r15 = val - if k == 16: - self.r16 = val - if k == 17: - self.r17 = val - if k == 18: - self.r18 = val - if k == 19: - self.r19 = val - if k == 20: - self.r20 = val - if k == 21: - self.r21 = val - if k == 22: - self.r22 = val - if k == 23: - self.r23 = val - if k == 24: - self.r24 = val - if k == 25: - self.r25 = val - if k == 26: - self.r26 = val - if k == 27: - self.r27 = val - if k == 28: - self.r28 = val - if k == 29: - self.r29 = val - if k == 30: - self.r30 = val - if k == 31: - self.r31 = val + for j in qd.static(range(_TILE)): + self.r[j] = 1.0 if tid == j else 0.0 @qd.func def _ger_sub(self, a, b): """General rank-1 subtract in-place: self -= a @ b^T.""" - for j in qd.static(range(32)): + for j in qd.static(range(_TILE)): bc = qd.simt.subgroup.shuffle(b, qd.u32(j)) - val = self._r(j) - a * bc - if j == 0: - self.r0 = val - if j == 1: - self.r1 = val - if j == 2: - self.r2 = val - if j == 3: - self.r3 = val - if j == 4: - self.r4 = val - if j == 5: - self.r5 = val - if j == 6: - self.r6 = val - if j == 7: - self.r7 = val - if j == 8: - self.r8 = val - if j == 9: - self.r9 = val - if j == 10: - self.r10 = val - if j == 11: - self.r11 = val - if j == 12: - self.r12 = val - if j == 13: - self.r13 = val - if j == 14: - self.r14 = val - if j == 15: - self.r15 = val - if j == 16: - self.r16 = val - if j == 17: - self.r17 = val - if j == 18: - self.r18 = val - if j == 19: - self.r19 = val - if j == 20: - self.r20 = val - if j == 21: - self.r21 = val - if j == 22: - self.r22 = val - if j == 23: - self.r23 = val - if j == 24: - self.r24 = val - if j == 25: - self.r25 = val - if j == 26: - self.r26 = val - if j == 27: - self.r27 = val - if j == 28: - self.r28 = val - if j == 29: - self.r29 = val - if j == 30: - self.r30 = val - if j == 31: - self.r31 = val + self.r[j] = self.r[j] - a * bc @qd.func def cholesky_(self, eps): @@ -613,92 +180,26 @@ def cholesky_(self, eps): On return, the lower triangle holds L such that A = L @ L^T. Diagonal clamped to sqrt(max(value, eps)) for numerical stability. """ - # `k` and `j` are wrapped in qd.static so the `if k > j` predicates fold at compile time and register access - # on the outer `k` and inner `j` collapses to a single field reference via `self._r()` (a thin - # getattr wrapper) rather than a 32-deep register-indexing cascade. Writes use an inline `if k == N: - # self.rN = ...` chain (setattr is rejected by the quadrants AST builder) which the AST transformer folds at - # build time when `k` is a python int. The per-lane row-norm used for the diagonal update is carried in - # `my_norm_sq`, so each diagonal step is O(1) rather than O(k). The off-diagonal `dot` is split into two - # interleaved partial sums (`dot0`/`dot1`) so the back-to-back FMA dependency chain is cut in half, - # exposing more instruction-level parallelism. + # `k` and `j` are wrapped in qd.static so the `if k > j` predicate folds at compile time and the + # `self.r[k]` / `self.r[j]` accesses resolve to a single unpacked-register slot per use (no runtime cascade). + # The per-lane row-norm used for the diagonal update is carried in `my_norm_sq`, so each diagonal step is + # O(1) rather than O(k). The off-diagonal `dot` is split into two interleaved partial sums (`dot0`/`dot1`) + # so the back-to-back FMA dependency chain is cut in half, exposing more instruction-level parallelism. tid = qd.i32(qd.simt.subgroup.invocation_id()) my_norm_sq = qd.cast(0.0, dtype) - for k in qd.static(range(32)): + for k in qd.static(range(_TILE)): diag_val = qd.cast(0.0, dtype) if tid == k: - diag_val = qd.sqrt(qd.max(self._r(k) - my_norm_sq, eps)) - if k == 0: - self.r0 = diag_val - if k == 1: - self.r1 = diag_val - if k == 2: - self.r2 = diag_val - if k == 3: - self.r3 = diag_val - if k == 4: - self.r4 = diag_val - if k == 5: - self.r5 = diag_val - if k == 6: - self.r6 = diag_val - if k == 7: - self.r7 = diag_val - if k == 8: - self.r8 = diag_val - if k == 9: - self.r9 = diag_val - if k == 10: - self.r10 = diag_val - if k == 11: - self.r11 = diag_val - if k == 12: - self.r12 = diag_val - if k == 13: - self.r13 = diag_val - if k == 14: - self.r14 = diag_val - if k == 15: - self.r15 = diag_val - if k == 16: - self.r16 = diag_val - if k == 17: - self.r17 = diag_val - if k == 18: - self.r18 = diag_val - if k == 19: - self.r19 = diag_val - if k == 20: - self.r20 = diag_val - if k == 21: - self.r21 = diag_val - if k == 22: - self.r22 = diag_val - if k == 23: - self.r23 = diag_val - if k == 24: - self.r24 = diag_val - if k == 25: - self.r25 = diag_val - if k == 26: - self.r26 = diag_val - if k == 27: - self.r27 = diag_val - if k == 28: - self.r28 = diag_val - if k == 29: - self.r29 = diag_val - if k == 30: - self.r30 = diag_val - if k == 31: - self.r31 = diag_val + diag_val = qd.sqrt(qd.max(self.r[k] - my_norm_sq, eps)) + self.r[k] = diag_val diag_k = qd.simt.subgroup.shuffle(diag_val, qd.u32(k)) dot0 = qd.cast(0.0, dtype) dot1 = qd.cast(0.0, dtype) - for j in qd.static(range(32)): + for j in qd.static(range(_TILE)): if k > j: - my_col = self._r(j) + my_col = self.r[j] Lkj = qd.simt.subgroup.shuffle(my_col, qd.u32(k)) if j % 2 == 0: dot0 += Lkj * my_col # type: ignore[reportOperatorIssue] @@ -708,71 +209,8 @@ def cholesky_(self, eps): new_val = qd.cast(0.0, dtype) if tid > k: # type: ignore[reportOperatorIssue] - new_val = (self._r(k) - dot) / diag_k # type: ignore[reportOperatorIssue] - if k == 0: - self.r0 = new_val - if k == 1: - self.r1 = new_val - if k == 2: - self.r2 = new_val - if k == 3: - self.r3 = new_val - if k == 4: - self.r4 = new_val - if k == 5: - self.r5 = new_val - if k == 6: - self.r6 = new_val - if k == 7: - self.r7 = new_val - if k == 8: - self.r8 = new_val - if k == 9: - self.r9 = new_val - if k == 10: - self.r10 = new_val - if k == 11: - self.r11 = new_val - if k == 12: - self.r12 = new_val - if k == 13: - self.r13 = new_val - if k == 14: - self.r14 = new_val - if k == 15: - self.r15 = new_val - if k == 16: - self.r16 = new_val - if k == 17: - self.r17 = new_val - if k == 18: - self.r18 = new_val - if k == 19: - self.r19 = new_val - if k == 20: - self.r20 = new_val - if k == 21: - self.r21 = new_val - if k == 22: - self.r22 = new_val - if k == 23: - self.r23 = new_val - if k == 24: - self.r24 = new_val - if k == 25: - self.r25 = new_val - if k == 26: - self.r26 = new_val - if k == 27: - self.r27 = new_val - if k == 28: - self.r28 = new_val - if k == 29: - self.r29 = new_val - if k == 30: - self.r30 = new_val - if k == 31: - self.r31 = new_val + new_val = (self.r[k] - dot) / diag_k # type: ignore[reportOperatorIssue] + self.r[k] = new_val if tid > k: # type: ignore[reportOperatorIssue] my_norm_sq += new_val * new_val @@ -783,16 +221,15 @@ def _trsm(self, L): L is a Tile32x32 holding the lower-triangular Cholesky factor (from cholesky_). On return, self holds the solution X. """ - for c in range(32): + for c in qd.static(range(_TILE)): dot = qd.cast(0.0, dtype) - for j in range(32): + for j in qd.static(range(_TILE)): if c > j: - Lkj = qd.simt.subgroup.shuffle(L._get_col(j), qd.u32(c)) - dot += self._get_col(j) * Lkj # type: ignore[reportOperatorIssue] + Lkj = qd.simt.subgroup.shuffle(L.r[j], qd.u32(c)) + dot += self.r[j] * Lkj # type: ignore[reportOperatorIssue] - diag_c = qd.simt.subgroup.shuffle(L._get_col(c), qd.u32(c)) - new_val = (self._get_col(c) - dot) / diag_c # type: ignore[reportOperatorIssue] - self._set_col(c, new_val) + diag_c = qd.simt.subgroup.shuffle(L.r[c], qd.u32(c)) + self.r[c] = (self.r[c] - dot) / diag_c # type: ignore[reportOperatorIssue] def solve_triangular_(self, B: Any, lower: bool = True) -> None: """Triangular solve: X @ self^T = B, storing result X in B in-place. @@ -804,13 +241,6 @@ def solve_triangular_(self, B: Any, lower: bool = True) -> None: raise TypeError("Tile32x32.solve_triangular_: only lower=True is supported") B._trsm(self) - def _r(self, k): - """Direct field read by python-int index. Used at qd.static-unrolled call sites to bypass the _TILE-way - ``_get_col(k)`` cascade: with ``k`` a python int (from ``qd.static(range(_TILE))``), - ``getattr(self, _REGS[k])`` is evaluated by the AST transformer at build time and returns a single - field-reference expression.""" - return getattr(self, _REGS[k]) - @qd.func def _resolve_vec2d(self, arr: qd.template(), row_start, row_stop, col): """Load one scalar per thread from a 2D array column, clamped to array bounds.""" From 28fa995c2fa626ab73271445cf9307f16ea73345 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Fri, 5 Jun 2026 13:10:07 -0700 Subject: [PATCH 2/2] [Tile] Fuse _tile16.py / _tile32.py into a single _tile.py parametrized on N The two factory bodies were structurally identical except for ``_TILE = 16`` vs ``_TILE = 32``. Replace them with a single ``_make_tile_class(N, dtype)`` factory and a single ``_TileProxy(N)`` proxy class, then instantiate ``Tile16x16Proxy = _TileProxy(16)`` and ``Tile32x32Proxy = _TileProxy(32)``. Net diff for this commit: -343 lines. Same generated IR. Updates the few internal consumers (``simt/__init__.py``, ``tile_slicing.py``, ``quadrants/__init__.py``, ``tests/python/test_tile.py``) and a couple of stale ``test_tile16`` references in the docs. --- docs/source/user_guide/contributing.md | 2 +- docs/source/user_guide/unit_testing.md | 6 +- python/quadrants/__init__.py | 2 +- python/quadrants/lang/simt/__init__.py | 19 +- .../lang/simt/{_tile16.py => _tile.py} | 134 +++---- python/quadrants/lang/simt/_tile32.py | 347 ------------------ python/quadrants/lang/simt/tile_slicing.py | 11 +- tests/python/test_tile.py | 16 +- 8 files changed, 97 insertions(+), 440 deletions(-) rename python/quadrants/lang/simt/{_tile16.py => _tile.py} (79%) delete mode 100644 python/quadrants/lang/simt/_tile32.py diff --git a/docs/source/user_guide/contributing.md b/docs/source/user_guide/contributing.md index 3573179e84..f2f22e35b2 100644 --- a/docs/source/user_guide/contributing.md +++ b/docs/source/user_guide/contributing.md @@ -11,7 +11,7 @@ Run the test suite with `python tests/run_tests.py`. CLI arguments are forwarded to pytest. For example, to run only Metal tests matching a keyword: ``` -python tests/run_tests.py --arch metal -k "test_tile16_cholesky" +python tests/run_tests.py --arch metal -k "test_cholesky" ``` The target architecture can also be set via the `QD_WANTED_ARCHS` environment variable (comma-separated, e.g. `QD_WANTED_ARCHS=metal,vulkan`). diff --git a/docs/source/user_guide/unit_testing.md b/docs/source/user_guide/unit_testing.md index 08453a9912..02266bc010 100644 --- a/docs/source/user_guide/unit_testing.md +++ b/docs/source/user_guide/unit_testing.md @@ -16,14 +16,14 @@ Common one-liners: ``` # run one file -python tests/run_tests.py test_tile16 +python tests/run_tests.py test_tile # run one test (any pytest -k expression) -python tests/run_tests.py -k test_tile16_cholesky +python tests/run_tests.py -k test_cholesky # run on a specific backend (or comma-separated list) python tests/run_tests.py --arch cuda -python tests/run_tests.py --arch metal -k tile16 +python tests/run_tests.py --arch metal -k tile # same, via env var (handy for CI) QD_WANTED_ARCHS=metal,vulkan python tests/run_tests.py diff --git a/python/quadrants/__init__.py b/python/quadrants/__init__.py index 3bbadb160d..7c38336a76 100644 --- a/python/quadrants/__init__.py +++ b/python/quadrants/__init__.py @@ -56,7 +56,7 @@ def __getattr__(attr): if attr == "cfg": return None if lang.impl.get_runtime()._prog is None else lang.impl.current_cfg() if attr == "outer": - from quadrants.lang.simt._tile16 import outer # noqa: I001 # pylint: disable=import-outside-toplevel + from quadrants.lang.simt._tile import outer # noqa: I001 # pylint: disable=import-outside-toplevel return outer raise AttributeError(f"module '{__name__}' has no attribute '{attr}'") diff --git a/python/quadrants/lang/simt/__init__.py b/python/quadrants/lang/simt/__init__.py index 498d0424d4..549bbd348c 100644 --- a/python/quadrants/lang/simt/__init__.py +++ b/python/quadrants/lang/simt/__init__.py @@ -3,25 +3,20 @@ from quadrants.lang.simt import block, grid, subgroup, warp if TYPE_CHECKING: - from quadrants.lang.simt._tile16 import Tile16x16Proxy as Tile16x16 - from quadrants.lang.simt._tile32 import Tile32x32Proxy as Tile32x32 + from quadrants.lang.simt._tile import Tile16x16Proxy as Tile16x16 + from quadrants.lang.simt._tile import Tile32x32Proxy as Tile32x32 __all__ = ["warp", "subgroup", "block", "grid", "Tile16x16", "Tile32x32"] def __getattr__(name): - if name == "Tile16x16": - from quadrants.lang.simt._tile16 import ( # pylint: disable=import-outside-toplevel + if name in ("Tile16x16", "Tile32x32"): + from quadrants.lang.simt._tile import ( # pylint: disable=import-outside-toplevel Tile16x16Proxy, - ) - - globals()["Tile16x16"] = Tile16x16Proxy - return Tile16x16Proxy - if name == "Tile32x32": - from quadrants.lang.simt._tile32 import ( # pylint: disable=import-outside-toplevel Tile32x32Proxy, ) - globals()["Tile32x32"] = Tile32x32Proxy - return Tile32x32Proxy + proxy = Tile16x16Proxy if name == "Tile16x16" else Tile32x32Proxy + globals()[name] = proxy + return proxy raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/python/quadrants/lang/simt/_tile16.py b/python/quadrants/lang/simt/_tile.py similarity index 79% rename from python/quadrants/lang/simt/_tile16.py rename to python/quadrants/lang/simt/_tile.py index d0422a23f6..d3361cddec 100644 --- a/python/quadrants/lang/simt/_tile16.py +++ b/python/quadrants/lang/simt/_tile.py @@ -1,14 +1,18 @@ # pyright: reportInvalidTypeForm=false """ -Register-resident 16x16 tile operations. +Register-resident NxN tile operations. -Each tile is a 16x16 matrix distributed across 16 threads in a subgroup, one row per thread, with each row stored -in 16 scalar registers held in an unpacked vector field (``self.r``). Cross-thread communication uses subgroup -shuffles -- no shared memory needed. +Each tile is an NxN matrix distributed across N threads in a subgroup, one row per thread, with each row stored in N +scalar registers held in an unpacked vector field (``self.r``). Cross-thread communication uses subgroup shuffles -- +no shared memory needed. -The thread's lane index (tid) is obtained internally via ``subgroup.invocation_id()``, so callers never need to -pass it. See docs/source/user_guide/tile.md for usage documentation. +A single factory ``_make_tile_class(N, dtype)`` builds the tile dataclass for both supported tile sizes (N == 16 and +N == 32). The user-facing entry points are the proxies ``qd.simt.Tile16x16`` and ``qd.simt.Tile32x32``, which defer +dtype resolution to kernel compile time (defaulting to the runtime ``default_fp``). + +The thread's lane index (tid) is obtained internally via ``subgroup.invocation_id()``, so callers never need to pass +it. See docs/source/user_guide/tile.md for usage documentation. """ from typing import TYPE_CHECKING as _TYPE_CHECKING @@ -18,19 +22,19 @@ if _TYPE_CHECKING: - class _Tile16x16Proto: # noqa: E303 - """Static type stub so pyright sees Tile16x16 methods correctly.""" + class _TileProto: # noqa: E303 + """Static type stub so pyright sees TileNxN methods correctly (shared by Tile16x16 and Tile32x32).""" SIZE: int def __init__(self, *args: Any, **kwargs: Any) -> None: ... # noqa: E704 @classmethod - def zeros(cls) -> "_Tile16x16Proto": ... # noqa: E704 + def zeros(cls) -> "_TileProto": ... # noqa: E704 @classmethod - def eye(cls) -> "_Tile16x16Proto": ... # noqa: E704 + def eye(cls) -> "_TileProto": ... # noqa: E704 def eye_(self) -> None: ... # noqa: E704 def cholesky_(self, eps: Any) -> None: ... # noqa: E704 - def solve_triangular_(self, B: "_Tile16x16Proto", lower: bool = True) -> None: ... # noqa: E704 + def solve_triangular_(self, B: "_TileProto", lower: bool = True) -> None: ... # noqa: E704 def _load(self, arr: Any, row_start: Any, row_end: Any, col_start: Any, col_end: Any) -> None: ... # noqa: E704 def _store( self, arr: Any, row_start: Any, row_end: Any, col_start: Any, col_end: Any @@ -42,17 +46,14 @@ def _store3d( self, arr: Any, batch: Any, row_start: Any, row_end: Any, col_start: Any, col_end: Any ) -> None: ... # noqa: E704 def _ger_sub(self, a: Any, b: Any) -> None: ... # noqa: E704 - def _trsm(self, L: "_Tile16x16Proto") -> None: ... # noqa: E704 - def __isub__(self, other: Any) -> "_Tile16x16Proto": ... # noqa: E704 + def _trsm(self, L: "_TileProto") -> None: ... # noqa: E704 + def __isub__(self, other: Any) -> "_TileProto": ... # noqa: E704 def __getitem__(self, key: Any) -> Any: ... # noqa: E704 def __setitem__(self, key: Any, value: Any) -> None: ... # noqa: E704 -_TILE = 16 - - class _OuterProduct: - """Deferred outer product proxy for use with augmented assignment on Tile16x16. + """Deferred outer product proxy for use with augmented assignment on a Tile. Created by qd.outer(a, b). Not a quadrants expression -- only valid as the RHS of ``tile -= qd.outer(a, b)``. """ @@ -71,7 +72,7 @@ def __radd__(self, other: Any) -> NoReturn: def outer(a: Any, b: Any) -> _OuterProduct: - """Create a deferred outer product for use with Tile16x16 augmented assignment. + """Create a deferred outer product for use with Tile augmented assignment. Usage:: @@ -160,7 +161,7 @@ def __init__(self, arr: Any, row_start: Any, row_stop: Any, col: Any, batch_idx: class _TileRefProxy: """Proxy returned by tile[:] for the LHS of a load assignment. - Enables ``tile[:] = arr[r:r+16, c:n]``. The ``[:]`` is required to distinguish in-place tile loads from + Enables ``tile[:] = arr[r:r+N, c:n]``. The ``[:]`` is required to distinguish in-place tile loads from variable rebinding. """ @@ -179,32 +180,35 @@ def _assign(self, value: Any) -> None: else: self.tile._load(value.arr, value.row_start, value.row_stop, value.col_start, value.col_stop) else: - raise TypeError(f"Tile16x16[:] can only be assigned from an array slice, got {type(value)}") + raise TypeError(f"Tile[:] can only be assigned from an array slice, got {type(value)}") -_tile16_cache = {} +_tile_cache: dict = {} -def _make_tile16x16(dtype=None) -> "type[_Tile16x16Proto]": - """Create a Tile16x16 dataclass whose registers use the given scalar dtype (qd.f32 or qd.f64). +def _make_tile(N: int, dtype=None) -> "type[_TileProto]": + """Create a TileNxN dataclass whose registers use the given scalar dtype (qd.f32 or qd.f64). - This is an internal factory. Use ``qd.simt.Tile16x16`` (the proxy) instead. + This is an internal factory. Use ``qd.simt.Tile16x16`` / ``qd.simt.Tile32x32`` (the proxies) instead. """ if dtype is None: dtype = qd.f32 - if dtype in _tile16_cache: - return _tile16_cache[dtype] # pyright: ignore[reportReturnType] - cls = _make_tile16x16_class(dtype) - _tile16_cache[dtype] = cls + key = (N, dtype) + if key in _tile_cache: + return _tile_cache[key] # pyright: ignore[reportReturnType] + cls = _make_tile_class(N, dtype) + _tile_cache[key] = cls return cls # pyright: ignore[reportReturnType] -def _make_tile16x16_class(dtype): - class _Tile16x16: - """A 16x16 tile distributed one row per subgroup thread, with each row held in 16 scalar registers via an - unpacked vector field. ``Tile16x16()`` creates a zero tile.""" +def _make_tile_class(N: int, dtype): + name = f"Tile{N}x{N}" - r: qd.types.vector(_TILE, dtype, unpacked=True) + class _Tile: + """An NxN tile distributed one row per subgroup thread, with each row held in N scalar registers via an + unpacked vector field. ``TileNxN()`` creates a zero tile.""" + + r: qd.types.vector(N, dtype, unpacked=True) @qd.func def _load(self, arr: qd.template(), row_start, row_stop, col_start, col_stop): @@ -221,7 +225,7 @@ def _load(self, arr: qd.template(), row_start, row_stop, col_start, col_stop): arr_col_stop = arr.shape[1] if arr_col_stop < col_stop: col_stop = arr_col_stop - for j in qd.static(range(_TILE)): + for j in qd.static(range(N)): if col_start + j < col_stop: self.r[j] = arr[row, col_start + j] @@ -240,7 +244,7 @@ def _load3d(self, arr: qd.template(), batch, row_start, row_stop, col_start, col arr_col_stop = arr.shape[2] if arr_col_stop < col_stop: col_stop = arr_col_stop - for j in qd.static(range(_TILE)): + for j in qd.static(range(N)): if col_start + j < col_stop: self.r[j] = arr[batch, row, col_start + j] @@ -259,7 +263,7 @@ def _store(self, arr: qd.template(), row_start, row_stop, col_start, col_stop): arr_col_stop = arr.shape[1] if arr_col_stop < col_stop: col_stop = arr_col_stop - for j in qd.static(range(_TILE)): + for j in qd.static(range(N)): if col_start + j < col_stop: arr[row, col_start + j] = self.r[j] @@ -278,28 +282,28 @@ def _store3d(self, arr: qd.template(), batch, row_start, row_stop, col_start, co arr_col_stop = arr.shape[2] if arr_col_stop < col_stop: col_stop = arr_col_stop - for j in qd.static(range(_TILE)): + for j in qd.static(range(N)): if col_start + j < col_stop: arr[batch, row, col_start + j] = self.r[j] @qd.func def eye_(self): - """Set this tile to the 16x16 identity matrix. Each thread sets its diagonal element to 1.0 and all + """Set this tile to the NxN identity matrix. Each thread sets its diagonal element to 1.0 and all others to 0.0.""" tid = qd.simt.subgroup.invocation_id() - for j in qd.static(range(_TILE)): + for j in qd.static(range(N)): self.r[j] = 1.0 if tid == j else 0.0 @qd.func def _ger_sub(self, a, b): """General rank-1 subtract in-place: self -= a @ b^T.""" - for j in qd.static(range(_TILE)): + for j in qd.static(range(N)): bc = qd.simt.subgroup.shuffle(b, qd.u32(j)) self.r[j] = self.r[j] - a * bc @qd.func def cholesky_(self, eps): - """In-place 16x16 Cholesky factorization via subgroup shuffles. + """In-place NxN Cholesky factorization via subgroup shuffles. On return, the lower triangle holds L such that A = L @ L^T. Diagonal clamped to sqrt(max(value, eps)) for numerical stability. @@ -312,7 +316,7 @@ def cholesky_(self, eps): # instruction-level parallelism. tid = qd.i32(qd.simt.subgroup.invocation_id()) my_norm_sq = qd.cast(0.0, dtype) - for k in qd.static(range(_TILE)): + for k in qd.static(range(N)): diag_val = qd.cast(0.0, dtype) if tid == k: diag_val = qd.sqrt(qd.max(self.r[k] - my_norm_sq, eps)) @@ -322,7 +326,7 @@ def cholesky_(self, eps): dot0 = qd.cast(0.0, dtype) dot1 = qd.cast(0.0, dtype) - for j in qd.static(range(_TILE)): + for j in qd.static(range(N)): if k > j: my_col = self.r[j] Lkj = qd.simt.subgroup.shuffle(my_col, qd.u32(k)) @@ -343,12 +347,12 @@ def cholesky_(self, eps): def _trsm(self, L): """In-place triangular solve: solve self @ L^T = B (original self). - L is a Tile16x16 holding the lower-triangular Cholesky factor (from cholesky_). On return, self holds - the solution X. + L is a TileNxN holding the lower-triangular Cholesky factor (from cholesky_). On return, self holds the + solution X. """ - for c in qd.static(range(_TILE)): + for c in qd.static(range(N)): dot = qd.cast(0.0, dtype) - for j in qd.static(range(_TILE)): + for j in qd.static(range(N)): if c > j: Lkj = qd.simt.subgroup.shuffle(L.r[j], qd.u32(c)) dot += self.r[j] * Lkj # type: ignore[reportOperatorIssue] @@ -363,7 +367,7 @@ def solve_triangular_(self, B: Any, lower: bool = True) -> None: matrix causes division by zero, producing inf/NaN without warning. Only lower=True is supported. """ if not lower: - raise TypeError("Tile16x16.solve_triangular_: only lower=True is supported") + raise TypeError(f"{name}.solve_triangular_: only lower=True is supported") B._trsm(self) @qd.func @@ -413,14 +417,17 @@ def _augassign(self, other: Any, op: str) -> None: ) self._ger_sub(a, b) else: - raise TypeError(f"Tile16x16: unsupported augmented assignment op '{op}' with outer product") + raise TypeError(f"{name}: unsupported augmented assignment op '{op}' with outer product") else: - raise TypeError(f"Tile16x16: unsupported augmented assignment with {type(other)}") + raise TypeError(f"{name}: unsupported augmented assignment with {type(other)}") + + _Tile.__name__ = f"_{name}" + _Tile.__qualname__ = f"_make_tile_class.._{name}" # StructType.__call__ already defaults missing args to 0, so Tile() produces a zero-initialized tile # without needing default values in the class definition (which @qd.dataclass doesn't support). - result = qd.dataclass(_Tile16x16) - result.SIZE = _TILE # type: ignore[reportAttributeAccessIssue] + result = qd.dataclass(_Tile) + result.SIZE = N # type: ignore[reportAttributeAccessIssue] result.zeros = result # type: ignore[reportAttributeAccessIssue] @qd.func @@ -433,17 +440,18 @@ def _eye(): return result -class _Tile16x16Proxy: +class _TileProxy: """Proxy for dtype-at-point-of-use tile creation. - Use as ``qd.simt.Tile16x16.zeros(dtype=qd.f32)`` inside a kernel. The dtype is resolved at kernel compilation - time, defaulting to the compile config's ``default_fp`` if omitted. + Use as ``qd.simt.Tile16x16.zeros(dtype=qd.f32)`` or ``qd.simt.Tile32x32.zeros(dtype=qd.f32)`` inside a kernel. + The dtype is resolved at kernel compilation time, defaulting to the compile config's ``default_fp`` if omitted. """ - SIZE = _TILE + def __init__(self, N: int) -> None: + self._N = N + self.SIZE = N - @staticmethod - def _resolve(dtype): + def _resolve(self, dtype): from quadrants.lang import impl # pylint: disable=import-outside-toplevel from quadrants.lang.exception import ( # pylint: disable=import-outside-toplevel QuadrantsSyntaxError, @@ -452,13 +460,12 @@ def _resolve(dtype): arch = impl.current_cfg().arch if arch in (qd.cpu, qd.x64, getattr(qd, "arm64", None)): raise QuadrantsSyntaxError( - "Tile16x16 requires a GPU backend (cuda, metal, vulkan, amdgpu). " f"Current arch is {arch}." + f"Tile{self._N}x{self._N} requires a GPU backend (cuda, metal, vulkan, amdgpu). " + f"Current arch is {arch}." ) if dtype is None: dtype = impl.get_runtime().default_fp - if dtype in _tile16_cache: - return _tile16_cache[dtype] - return _make_tile16x16(dtype) + return _make_tile(self._N, dtype) def zeros(self, *, dtype=None): """Zero-initialized tile.""" @@ -469,4 +476,5 @@ def eye(self, *, dtype=None): return self._resolve(dtype).eye() -Tile16x16Proxy = _Tile16x16Proxy() +Tile16x16Proxy = _TileProxy(16) +Tile32x32Proxy = _TileProxy(32) diff --git a/python/quadrants/lang/simt/_tile32.py b/python/quadrants/lang/simt/_tile32.py deleted file mode 100644 index 2d03f63f87..0000000000 --- a/python/quadrants/lang/simt/_tile32.py +++ /dev/null @@ -1,347 +0,0 @@ -# pyright: reportInvalidTypeForm=false - -""" -Register-resident 32x32 tile operations. - -Each tile is a 32x32 matrix distributed across 32 threads in a subgroup, one row per thread, with each row stored -in 32 scalar registers held in an unpacked vector field (``self.r``). Cross-thread communication uses subgroup -shuffles -- no shared memory needed. - -Surface mirrors :mod:`quadrants.lang.simt._tile16`: use ``qd.simt.Tile32x32.zeros(dtype=...)`` / -``qd.simt.Tile32x32.eye(dtype=...)`` inside a kernel, then load / store via slice syntax -(``tile[:] = arr[r0:r1, c0:c1]`` / ``arr[r0:r1, c0:c1] = tile``) and update via ``tile -= qd.outer(a, b)``. - -The slice / outer-product proxy classes are imported from :mod:`._tile16` so a single set of dispatch helpers serves -both tile sizes (a tile instance is duck-typed against ``_load`` / ``_store`` / ``_ger_sub``). -""" - -from typing import TYPE_CHECKING as _TYPE_CHECKING -from typing import Any - -import quadrants as qd -from quadrants.lang.simt._tile16 import ( - _OuterProduct, - _VecSliceProxy, -) - -if _TYPE_CHECKING: - - class _Tile32x32Proto: # noqa: E303 - """Static type stub so pyright sees Tile32x32 methods correctly.""" - - SIZE: int - - def __init__(self, *args: Any, **kwargs: Any) -> None: ... # noqa: E704 - @classmethod - def zeros(cls) -> "_Tile32x32Proto": ... # noqa: E704 - @classmethod - def eye(cls) -> "_Tile32x32Proto": ... # noqa: E704 - def eye_(self) -> None: ... # noqa: E704 - def cholesky_(self, eps: Any) -> None: ... # noqa: E704 - def solve_triangular_(self, B: "_Tile32x32Proto", lower: bool = True) -> None: ... # noqa: E704 - def _load(self, arr: Any, row_start: Any, row_end: Any, col_start: Any, col_end: Any) -> None: ... # noqa: E704 - def _store( - self, arr: Any, row_start: Any, row_end: Any, col_start: Any, col_end: Any - ) -> None: ... # noqa: E704 - def _load3d( - self, arr: Any, batch: Any, row_start: Any, row_end: Any, col_start: Any, col_end: Any - ) -> None: ... # noqa: E704 - def _store3d( - self, arr: Any, batch: Any, row_start: Any, row_end: Any, col_start: Any, col_end: Any - ) -> None: ... # noqa: E704 - def _ger_sub(self, a: Any, b: Any) -> None: ... # noqa: E704 - def _trsm(self, L: "_Tile32x32Proto") -> None: ... # noqa: E704 - def __isub__(self, other: Any) -> "_Tile32x32Proto": ... # noqa: E704 - def __getitem__(self, key: Any) -> Any: ... # noqa: E704 - def __setitem__(self, key: Any, value: Any) -> None: ... # noqa: E704 - - -_TILE = 32 - -_tile32_cache = {} - - -def _make_tile32x32(dtype=None) -> "type[_Tile32x32Proto]": - """Create a Tile32x32 dataclass whose registers use the given scalar dtype (qd.f32 or qd.f64). - - This is an internal factory. Use ``qd.simt.Tile32x32`` (the proxy) instead. - """ - if dtype is None: - dtype = qd.f32 - if dtype in _tile32_cache: - return _tile32_cache[dtype] # pyright: ignore[reportReturnType] - cls = _make_tile32x32_class(dtype) - _tile32_cache[dtype] = cls - return cls # pyright: ignore[reportReturnType] - - -def _make_tile32x32_class(dtype): - class _Tile32x32: - """A 32x32 tile distributed one row per subgroup thread, with each row held in 32 scalar registers via an - unpacked vector field. ``Tile32x32()`` creates a zero tile.""" - - r: qd.types.vector(_TILE, dtype, unpacked=True) - - @qd.func - def _load(self, arr: qd.template(), row_start, row_stop, col_start, col_stop): - """Load from a 2D array within [row_start, row_stop) x [col_start, col_stop). - - Each thread loads arr[row_start + tid, col_start:col_stop]. Threads where row_start + tid >= row_stop - skip the load (tile row unchanged). - """ - arr_row_stop = arr.shape[0] - if arr_row_stop < row_stop: - row_stop = arr_row_stop - row = row_start + qd.simt.subgroup.invocation_id() - if row < row_stop: - arr_col_stop = arr.shape[1] - if arr_col_stop < col_stop: - col_stop = arr_col_stop - for j in qd.static(range(_TILE)): - if col_start + j < col_stop: - self.r[j] = arr[row, col_start + j] - - @qd.func - def _load3d(self, arr: qd.template(), batch, row_start, row_stop, col_start, col_stop): - """Load from a 3D array within [row_start, row_stop) x [col_start, col_stop). - - Each thread loads arr[batch, row_start+tid, col_start:col_stop]. Threads where row_start + tid >= - row_stop skip the load (tile row unchanged). - """ - arr_row_stop = arr.shape[1] - if arr_row_stop < row_stop: - row_stop = arr_row_stop - row = row_start + qd.simt.subgroup.invocation_id() - if row < row_stop: - arr_col_stop = arr.shape[2] - if arr_col_stop < col_stop: - col_stop = arr_col_stop - for j in qd.static(range(_TILE)): - if col_start + j < col_stop: - self.r[j] = arr[batch, row, col_start + j] - - @qd.func - def _store(self, arr: qd.template(), row_start, row_stop, col_start, col_stop): - """Store to a 2D array within [row_start, row_stop) x [col_start, col_stop). - - Each thread stores to arr[row_start + tid, col_start:col_stop]. Threads where row_start + tid >= - row_stop skip the store. - """ - arr_row_stop = arr.shape[0] - if arr_row_stop < row_stop: - row_stop = arr_row_stop - row = row_start + qd.simt.subgroup.invocation_id() - if row < row_stop: - arr_col_stop = arr.shape[1] - if arr_col_stop < col_stop: - col_stop = arr_col_stop - for j in qd.static(range(_TILE)): - if col_start + j < col_stop: - arr[row, col_start + j] = self.r[j] - - @qd.func - def _store3d(self, arr: qd.template(), batch, row_start, row_stop, col_start, col_stop): - """Store to a 3D array within [row_start, row_stop) x [col_start, col_stop). - - Each thread stores to arr[batch, row_start+tid, col_start:col_stop]. Threads where row_start + tid >= - row_stop skip the store. - """ - arr_row_stop = arr.shape[1] - if arr_row_stop < row_stop: - row_stop = arr_row_stop - row = row_start + qd.simt.subgroup.invocation_id() - if row < row_stop: - arr_col_stop = arr.shape[2] - if arr_col_stop < col_stop: - col_stop = arr_col_stop - for j in qd.static(range(_TILE)): - if col_start + j < col_stop: - arr[batch, row, col_start + j] = self.r[j] - - @qd.func - def eye_(self): - """Set this tile to the 32x32 identity matrix. Each thread sets its diagonal element to 1.0 and all - others to 0.0.""" - tid = qd.simt.subgroup.invocation_id() - for j in qd.static(range(_TILE)): - self.r[j] = 1.0 if tid == j else 0.0 - - @qd.func - def _ger_sub(self, a, b): - """General rank-1 subtract in-place: self -= a @ b^T.""" - for j in qd.static(range(_TILE)): - bc = qd.simt.subgroup.shuffle(b, qd.u32(j)) - self.r[j] = self.r[j] - a * bc - - @qd.func - def cholesky_(self, eps): - """In-place 32x32 Cholesky factorization via subgroup shuffles. - - On return, the lower triangle holds L such that A = L @ L^T. Diagonal clamped to sqrt(max(value, eps)) - for numerical stability. - """ - # `k` and `j` are wrapped in qd.static so the `if k > j` predicate folds at compile time and the - # `self.r[k]` / `self.r[j]` accesses resolve to a single unpacked-register slot per use (no runtime cascade). - # The per-lane row-norm used for the diagonal update is carried in `my_norm_sq`, so each diagonal step is - # O(1) rather than O(k). The off-diagonal `dot` is split into two interleaved partial sums (`dot0`/`dot1`) - # so the back-to-back FMA dependency chain is cut in half, exposing more instruction-level parallelism. - tid = qd.i32(qd.simt.subgroup.invocation_id()) - my_norm_sq = qd.cast(0.0, dtype) - for k in qd.static(range(_TILE)): - diag_val = qd.cast(0.0, dtype) - if tid == k: - diag_val = qd.sqrt(qd.max(self.r[k] - my_norm_sq, eps)) - self.r[k] = diag_val - - diag_k = qd.simt.subgroup.shuffle(diag_val, qd.u32(k)) - - dot0 = qd.cast(0.0, dtype) - dot1 = qd.cast(0.0, dtype) - for j in qd.static(range(_TILE)): - if k > j: - my_col = self.r[j] - Lkj = qd.simt.subgroup.shuffle(my_col, qd.u32(k)) - if j % 2 == 0: - dot0 += Lkj * my_col # type: ignore[reportOperatorIssue] - else: - dot1 += Lkj * my_col # type: ignore[reportOperatorIssue] - dot = dot0 + dot1 - - new_val = qd.cast(0.0, dtype) - if tid > k: # type: ignore[reportOperatorIssue] - new_val = (self.r[k] - dot) / diag_k # type: ignore[reportOperatorIssue] - self.r[k] = new_val - if tid > k: # type: ignore[reportOperatorIssue] - my_norm_sq += new_val * new_val - - @qd.func - def _trsm(self, L): - """In-place triangular solve: solve self @ L^T = B (original self). - - L is a Tile32x32 holding the lower-triangular Cholesky factor (from cholesky_). On return, self holds - the solution X. - """ - for c in qd.static(range(_TILE)): - dot = qd.cast(0.0, dtype) - for j in qd.static(range(_TILE)): - if c > j: - Lkj = qd.simt.subgroup.shuffle(L.r[j], qd.u32(c)) - dot += self.r[j] * Lkj # type: ignore[reportOperatorIssue] - - diag_c = qd.simt.subgroup.shuffle(L.r[c], qd.u32(c)) - self.r[c] = (self.r[c] - dot) / diag_c # type: ignore[reportOperatorIssue] - - def solve_triangular_(self, B: Any, lower: bool = True) -> None: - """Triangular solve: X @ self^T = B, storing result X in B in-place. - - self must be lower-triangular and non-singular (all diagonal elements non-zero). Passing a singular - matrix causes division by zero, producing inf/NaN without warning. Only lower=True is supported. - """ - if not lower: - raise TypeError("Tile32x32.solve_triangular_: only lower=True is supported") - B._trsm(self) - - @qd.func - def _resolve_vec2d(self, arr: qd.template(), row_start, row_stop, col): - """Load one scalar per thread from a 2D array column, clamped to array bounds.""" - tid = qd.i32(qd.simt.subgroup.invocation_id()) - arr_row_stop = arr.shape[0] - if arr_row_stop < row_stop: - row_stop = arr_row_stop - v = dtype(0.0) - if row_start + tid < row_stop: - v = arr[row_start + tid, col] - return v - - @qd.func - def _resolve_vec3d(self, arr: qd.template(), batch, row_start, row_stop, col): - """Load one scalar per thread from a 3D array column, clamped to array bounds.""" - tid = qd.i32(qd.simt.subgroup.invocation_id()) - arr_row_stop = arr.shape[1] - if arr_row_stop < row_stop: - row_stop = arr_row_stop - v = dtype(0.0) - if row_start + tid < row_stop: - v = arr[batch, row_start + tid, col] - return v - - def _resolve_vec_proxy(self, proxy: _VecSliceProxy) -> Any: - """Materialize a _VecSliceProxy into a scalar by dispatching to _resolve_vec2d or _resolve_vec3d.""" - if proxy.batch_idx is not None: - return self._resolve_vec3d(proxy.arr, proxy.batch_idx, proxy.row_start, proxy.row_stop, proxy.col) - return self._resolve_vec2d(proxy.arr, proxy.row_start, proxy.row_stop, proxy.col) - - def _augassign(self, other: Any, op: str) -> None: - """Handle augmented assignment (e.g. tile -= qd.outer(a, b)). - - Resolves _VecSliceProxy arguments and dispatches to _ger_sub. Only 'Sub' is supported. - """ - if isinstance(other, _OuterProduct): - if op == "Sub": - a_orig = other.a - b_orig = other.b - a = self._resolve_vec_proxy(a_orig) if isinstance(a_orig, _VecSliceProxy) else a_orig - b = ( - a - if (b_orig is a_orig) - else (self._resolve_vec_proxy(b_orig) if isinstance(b_orig, _VecSliceProxy) else b_orig) - ) - self._ger_sub(a, b) - else: - raise TypeError(f"Tile32x32: unsupported augmented assignment op '{op}' with outer product") - else: - raise TypeError(f"Tile32x32: unsupported augmented assignment with {type(other)}") - - # StructType.__call__ already defaults missing args to 0, so Tile() produces a zero-initialized tile - # without needing default values in the class definition (which @qd.dataclass doesn't support). - result = qd.dataclass(_Tile32x32) - result.SIZE = _TILE # type: ignore[reportAttributeAccessIssue] - result.zeros = result # type: ignore[reportAttributeAccessIssue] - - @qd.func - def _eye(): - t = result() - t.eye_() # type: ignore[reportAttributeAccessIssue] - return t - - result.eye = _eye # type: ignore[reportAttributeAccessIssue] - return result - - -class _Tile32x32Proxy: - """Proxy for dtype-at-point-of-use tile creation. - - Use as ``qd.simt.Tile32x32.zeros(dtype=qd.f32)`` inside a kernel. The dtype is resolved at kernel compilation - time, defaulting to the compile config's ``default_fp`` if omitted. - """ - - SIZE = _TILE - - @staticmethod - def _resolve(dtype): - from quadrants.lang import impl # pylint: disable=import-outside-toplevel - from quadrants.lang.exception import ( # pylint: disable=import-outside-toplevel - QuadrantsSyntaxError, - ) - - arch = impl.current_cfg().arch - if arch in (qd.cpu, qd.x64, getattr(qd, "arm64", None)): - raise QuadrantsSyntaxError( - "Tile32x32 requires a GPU backend (cuda, metal, vulkan, amdgpu). " f"Current arch is {arch}." - ) - if dtype is None: - dtype = impl.get_runtime().default_fp - if dtype in _tile32_cache: - return _tile32_cache[dtype] - return _make_tile32x32(dtype) - - def zeros(self, *, dtype=None): - """Zero-initialized tile.""" - return self._resolve(dtype)() - - def eye(self, *, dtype=None): - """Identity tile (diagonal = 1, rest = 0).""" - return self._resolve(dtype).eye() - - -Tile32x32Proxy = _Tile32x32Proxy() diff --git a/python/quadrants/lang/simt/tile_slicing.py b/python/quadrants/lang/simt/tile_slicing.py index bbee6b333f..c7436f7316 100644 --- a/python/quadrants/lang/simt/tile_slicing.py +++ b/python/quadrants/lang/simt/tile_slicing.py @@ -6,26 +6,23 @@ """ from quadrants.lang.exception import QuadrantsSyntaxError -from quadrants.lang.simt._tile16 import ( - _tile16_cache, +from quadrants.lang.simt._tile import ( + _tile_cache, _TileRefProxy, _TileSliceProxy, _VecSliceProxy, ) -from quadrants.lang.simt._tile32 import _tile32_cache from quadrants.lang.struct import Struct def _is_tile(value): """Return True if ``value`` is an instance of any registered tile dataclass.""" - return any(isinstance(value, t) for t in _tile16_cache.values()) or any( - isinstance(value, t) for t in _tile32_cache.values() - ) + return any(isinstance(value, t) for t in _tile_cache.values()) def _any_tile_built(): """Return True if any tile dataclass has been built (i.e. at least one tile is in use).""" - return bool(_tile16_cache) or bool(_tile32_cache) + return bool(_tile_cache) def try_tile_ref(value, _indices): diff --git a/tests/python/test_tile.py b/tests/python/test_tile.py index d71b514f92..0f0278a476 100644 --- a/tests/python/test_tile.py +++ b/tests/python/test_tile.py @@ -8,12 +8,11 @@ import quadrants as qd from quadrants.lang.exception import QuadrantsSyntaxError -from quadrants.lang.simt._tile16 import ( - _make_tile16x16, +from quadrants.lang.simt._tile import ( + _make_tile, _TileSliceProxy, _VecSliceProxy, ) -from quadrants.lang.simt._tile32 import _make_tile32x32 from tests import test_utils @@ -25,15 +24,20 @@ # --- Parametrize over tile size ---------------------------------------------------------------- +import functools as _functools # noqa: E402 import types as _types # noqa: E402 _TILE_PARAMS = [ pytest.param( - _types.SimpleNamespace(proxy=qd.simt.Tile16x16, make=_make_tile16x16, size=16, m_size=40, name="tile16"), + _types.SimpleNamespace( + proxy=qd.simt.Tile16x16, make=_functools.partial(_make_tile, 16), size=16, m_size=40, name="tile16" + ), id="tile16", ), pytest.param( - _types.SimpleNamespace(proxy=qd.simt.Tile32x32, make=_make_tile32x32, size=32, m_size=80, name="tile32"), + _types.SimpleNamespace( + proxy=qd.simt.Tile32x32, make=_functools.partial(_make_tile, 32), size=32, m_size=80, name="tile32" + ), id="tile32", ), ] @@ -421,7 +425,7 @@ def k1(src_arr: Ann, dst_arr: Ann, NCOLS: qd.i32, N: qd.Template): def test_make_caching(TILE, make_tile, tdim, m_size): - """_make_tile16x16 must return the same object for the same dtype.""" + """_make_tile must return the same object for the same (N, dtype).""" a = make_tile(qd.f32) b = make_tile(qd.f32) assert a is b