Skip to content

Merge branch 'main' into hp/tiles-use-unpacked-vector

fbde88e
Select commit
Loading
Failed to load commit list.
Sign in for the full log view
Open

[Tile] Use unpacked vector field for Tile16x16/Tile32x32 register storage #722

Merge branch 'main' into hp/tiles-use-unpacked-vector
fbde88e
Select commit
Loading
Failed to load commit list.
GitHub Actions / Coverage Report succeeded Jun 8, 2026 in 0s

Diff Coverage Report

See details below for per-line coverage annotations.

Details

Coverage Report (fbde88ef6)

Metric Value
Diff coverage (changed lines only) 59%
Overall project coverage 73%

Total: 258 lines, 106 missing, 59% covered

🟢 python/quadrants/__init__.py (100%)
🟢   59          from quadrants.lang.simt._tile import outer  # noqa: I001  # pylint: disable=import-outside-toplevel
🟢 python/quadrants/lang/simt/__init__.py (100%)
      6      from quadrants.lang.simt._tile import Tile16x16Proxy as Tile16x16
      7      from quadrants.lang.simt._tile import Tile32x32Proxy as Tile32x32
🟢   13      if name in ("Tile16x16", "Tile32x32"):
🟢   14          from quadrants.lang.simt._tile import (  # pylint: disable=import-outside-toplevel
🟢   19          proxy = Tile16x16Proxy if name == "Tile16x16" else Tile32x32Proxy
🟢   20          globals()[name] = proxy
🟢   21          return proxy
🔴 python/quadrants/lang/simt/_tile.py (57%)
      1  # pyright: reportInvalidTypeForm=false
      2  
      3  """
      4  Register-resident NxN tile operations.
      5  
      6  Each tile is an NxN matrix distributed across N threads in a subgroup, one row per thread, with each row stored in N
      7  scalar registers held in an unpacked vector field (``self.r``).  Cross-thread communication uses subgroup shuffles --
      8  no shared memory needed.
      9  
     10  A single factory ``_make_tile_class(N, dtype)`` builds the tile dataclass for both supported tile sizes (N == 16 and
     11  N == 32).  The user-facing entry points are the proxies ``qd.simt.Tile16x16`` and ``qd.simt.Tile32x32``, which defer
     12  dtype resolution to kernel compile time (defaulting to the runtime ``default_fp``).
     13  
     14  The thread's lane index (tid) is obtained internally via ``subgroup.invocation_id()``, so callers never need to pass
     15  it.  See docs/source/user_guide/tile.md for usage documentation.
     16  """
     17  
🟢   18  from typing import TYPE_CHECKING as _TYPE_CHECKING
🟢   19  from typing import Any, NoReturn
     20  
🟢   21  import quadrants as qd
     22  
🟢   23  if _TYPE_CHECKING:
     24  
🔴   25      class _TileProto:  # noqa: E303
     26          """Static type stub so pyright sees TileNxN methods correctly (shared by Tile16x16 and Tile32x32)."""
     27  
🔴   28          SIZE: int
     29  
     30          def __init__(self, *args: Any, **kwargs: Any) -> None: ...  # noqa: E704
     31          @classmethod
     32          def zeros(cls) -> "_TileProto": ...  # noqa: E704
     33          @classmethod
     34          def eye(cls) -> "_TileProto": ...  # noqa: E704
     35          def eye_(self) -> None: ...  # noqa: E704
     36          def cholesky_(self, eps: Any) -> None: ...  # noqa: E704
     37          def solve_triangular_(self, B: "_TileProto", lower: bool = True) -> None: ...  # noqa: E704
     38          def _load(self, arr: Any, row_start: Any, row_end: Any, col_start: Any, col_end: Any) -> None: ...  # noqa: E704
     39          def _store(
     40              self, arr: Any, row_start: Any, row_end: Any, col_start: Any, col_end: Any
     41          ) -> None: ...  # noqa: E704
     42          def _load3d(
     43              self, arr: Any, batch: Any, row_start: Any, row_end: Any, col_start: Any, col_end: Any
     44          ) -> None: ...  # noqa: E704
     45          def _store3d(
     46              self, arr: Any, batch: Any, row_start: Any, row_end: Any, col_start: Any, col_end: Any
     47          ) -> None: ...  # noqa: E704
     48          def _ger_sub(self, a: Any, b: Any) -> None: ...  # noqa: E704
     49          def _trsm(self, L: "_TileProto") -> None: ...  # noqa: E704
     50          def __isub__(self, other: Any) -> "_TileProto": ...  # noqa: E704
     51          def __getitem__(self, key: Any) -> Any: ...  # noqa: E704
     52          def __setitem__(self, key: Any, value: Any) -> None: ...  # noqa: E704
     53  
     54  
🟢   55  class _OuterProduct:
     56      """Deferred outer product proxy for use with augmented assignment on a Tile.
     57  
     58      Created by qd.outer(a, b). Not a quadrants expression -- only valid as the RHS of ``tile -= qd.outer(a, b)``.
     59      """
     60  
🟢   61      _qd_is_deferred = True
     62  
🟢   63      def __init__(self, a: Any, b: Any) -> None:
🟢   64          self.a = a
🟢   65          self.b = b
     66  
🟢   67      def __add__(self, other: Any) -> NoReturn:
🟢   68          raise TypeError("OuterProduct does not support composition; apply each update separately")
     69  
🟢   70      def __radd__(self, other: Any) -> NoReturn:
🔴   71          raise TypeError("OuterProduct does not support composition; apply each update separately")
     72  
     73  
🟢   74  def outer(a: Any, b: Any) -> _OuterProduct:
     75      """Create a deferred outer product for use with Tile augmented assignment.
     76  
     77      Usage::
     78  
     79          t -= qd.outer(a, b)   # equivalent to t._ger_sub(a, b)
     80          t -= qd.outer(v, v)   # symmetric case (a == b)
     81      """
🟢   82      return _OuterProduct(a, b)
     83  
     84  
🟢   85  class _DeferredProxyMixin:
     86      """Raises clear errors if a deferred tile proxy is accidentally used as a value."""
     87  
🟢   88      _proxy_description = "Tile proxy"
     89  
🟢   90      def _misuse(self, op: str = "used") -> NoReturn:
🟢   91          raise TypeError(
     92              f"{self._proxy_description} was {op}, but it is only valid in tile operations (tile[:] = ..., ... = tile, qd.outer(...))"
     93          )
     94  
🟢   95      def __add__(self, other: Any) -> NoReturn:
🟢   96          self._misuse("added")
     97  
🟢   98      def __radd__(self, other: Any) -> NoReturn:
🟢   99          self._misuse("added")
    100  
🟢  101      def __sub__(self, other: Any) -> NoReturn:
🟢  102          self._misuse("subtracted")
    103  
🟢  104      def __mul__(self, other: Any) -> NoReturn:
🟢  105          self._misuse("multiplied")
    106  
🟢  107      def __getitem__(self, key: Any) -> NoReturn:
🟢  108          self._misuse("subscripted")
    109  
🟢  110      def __repr__(self) -> str:
🟢  111          return f"<{self._proxy_description} — not a value; use with tile[:] = ... or qd.outer(...)>"
    112  
    113  
🟢  114  class _TileSliceProxy(_DeferredProxyMixin):
    115      """Deferred 2D/3D array slice for tile load/store.
    116  
    117      Created by subscripting a Field or ndarray with 2D slices, e.g. ``arr[row_start:row_stop, col_start:col_stop]``.
    118      Not a quadrants expression -- only valid as the RHS of a tile assignment (load) or as the LHS target (store).
    119      """
    120  
🟢  121      _qd_is_deferred = True
🟢  122      _proxy_description = "Array slice proxy (arr[r0:r1, c0:c1])"
    123  
🟢  124      def __init__(
    125          self, arr: Any, row_start: Any, row_stop: Any, col_start: Any, col_stop: Any, batch_idx: Any = None
    126      ) -> None:
🟢  127          self.arr = arr
🟢  128          self.row_start = row_start
🟢  129          self.row_stop = row_stop
🟢  130          self.col_start = col_start
🟢  131          self.col_stop = col_stop
🟢  132          self.batch_idx = batch_idx
    133  
🟢  134      def _assign(self, tile: Any) -> None:
    135          """Store path: arr[r:r+n_rows, c:c+n_cols] = tile."""
🟢  136          if self.batch_idx is not None:
🟢  137              tile._store3d(self.arr, self.batch_idx, self.row_start, self.row_stop, self.col_start, self.col_stop)
    138          else:
🟢  139              tile._store(self.arr, self.row_start, self.row_stop, self.col_start, self.col_stop)
    140  
    141  
🟢  142  class _VecSliceProxy(_DeferredProxyMixin):
    143      """Deferred column-vector load from a 2D/3D array.
    144  
    145      Created by ``arr[row_start:row_stop, col]`` or ``arr[batch_idx, row_start:row_stop, col]``.
    146      Each subgroup thread loads one element; out-of-range threads get 0.
    147      Only valid as an argument to ``qd.outer()`` in tile augmented assignment.
    148      """
    149  
🟢  150      _qd_is_deferred = True
🟢  151      _proxy_description = "Vec slice proxy (arr[r0:r1, col])"
    152  
🟢  153      def __init__(self, arr: Any, row_start: Any, row_stop: Any, col: Any, batch_idx: Any = None) -> None:
🟢  154          self.arr = arr
🟢  155          self.row_start = row_start
🟢  156          self.row_stop = row_stop
🟢  157          self.col = col
🟢  158          self.batch_idx = batch_idx
    159  
    160  
🟢  161  class _TileRefProxy:
    162      """Proxy returned by tile[:] for the LHS of a load assignment.
    163  
    164      Enables ``tile[:] = arr[r:r+N, c:n]``.  The ``[:]`` is required to distinguish in-place tile loads from
    165      variable rebinding.
    166      """
    167  
🟢  168      _qd_is_deferred = True
    169  
🟢  170      def __init__(self, tile: Any) -> None:
🟢  171          self.tile = tile
    172  
🟢  173      def _assign(self, value: Any) -> None:
    174          """Load path: tile[:] = arr[r:r+n, c:c+n]. Dispatches to _load or _load3d."""
🟢  175          if isinstance(value, _TileSliceProxy):
🟢  176              if value.batch_idx is not None:
🟢  177                  self.tile._load3d(
    178                      value.arr, value.batch_idx, value.row_start, value.row_stop, value.col_start, value.col_stop
    179                  )
    180              else:
🟢  181                  self.tile._load(value.arr, value.row_start, value.row_stop, value.col_start, value.col_stop)
    182          else:
🔴  183              raise TypeError(f"Tile[:] can only be assigned from an array slice, got {type(value)}")
    184  
    185  
🟢  186  _tile_cache: dict = {}
    187  
    188  
🟢  189  def _make_tile(N: int, dtype=None) -> "type[_TileProto]":
    190      """Create a TileNxN dataclass whose registers use the given scalar dtype (qd.f32 or qd.f64).
    191  
    192      This is an internal factory.  Use ``qd.simt.Tile16x16`` / ``qd.simt.Tile32x32`` (the proxies) instead.
    193      """
🟢  194      if dtype is None:
🔴  195          dtype = qd.f32
🟢  196      key = (N, dtype)
🟢  197      if key in _tile_cache:
🟢  198          return _tile_cache[key]  # pyright: ignore[reportReturnType]
🟢  199      cls = _make_tile_class(N, dtype)
🟢  200      _tile_cache[key] = cls
🟢  201      return cls  # pyright: ignore[reportReturnType]
    202  
    203  
🟢  204  def _make_tile_class(N: int, dtype):
🟢  205      name = f"Tile{N}x{N}"
    206  
🟢  207      class _Tile:
    208          """An NxN tile distributed one row per subgroup thread, with each row held in N scalar registers via an
    209          unpacked vector field.  ``TileNxN()`` creates a zero tile."""
    210  
🟢  211          r: qd.types.vector(N, dtype, unpacked=True)
    212  
🟢  213          @qd.func
🟢  214          def _load(self, arr: qd.template(), row_start, row_stop, col_start, col_stop):
    215              """Load from a 2D array within [row_start, row_stop) x [col_start, col_stop).
    216  
    217              Each thread loads arr[row_start + tid, col_start:col_stop].  Threads where row_start + tid >= row_stop
    218              skip the load (tile row unchanged).
    219              """
🔴  220              arr_row_stop = arr.shape[0]
🔴  221              if arr_row_stop < row_stop:
🔴  222                  row_stop = arr_row_stop
🔴  223              row = row_start + qd.simt.subgroup.invocation_id()
🔴  224              if row < row_stop:
🔴  225                  arr_col_stop = arr.shape[1]
🔴  226                  if arr_col_stop < col_stop:
🔴  227                      col_stop = arr_col_stop
🔴  228                  for j in qd.static(range(N)):
🔴  229                      if col_start + j < col_stop:
🔴  230                          self.r[j] = arr[row, col_start + j]
    231  
🟢  232          @qd.func
🟢  233          def _load3d(self, arr: qd.template(), batch, row_start, row_stop, col_start, col_stop):
    234              """Load from a 3D array within [row_start, row_stop) x [col_start, col_stop).
    235  
    236              Each thread loads arr[batch, row_start+tid, col_start:col_stop].  Threads where row_start + tid >=
    237              row_stop skip the load (tile row unchanged).
    238              """
🔴  239              arr_row_stop = arr.shape[1]
🔴  240              if arr_row_stop < row_stop:
🔴  241                  row_stop = arr_row_stop
🔴  242              row = row_start + qd.simt.subgroup.invocation_id()
🔴  243              if row < row_stop:
🔴  244                  arr_col_stop = arr.shape[2]
🔴  245                  if arr_col_stop < col_stop:
🔴  246                      col_stop = arr_col_stop
🔴  247                  for j in qd.static(range(N)):
🔴  248                      if col_start + j < col_stop:
🔴  249                          self.r[j] = arr[batch, row, col_start + j]
    250  
🟢  251          @qd.func
🟢  252          def _store(self, arr: qd.template(), row_start, row_stop, col_start, col_stop):
    253              """Store to a 2D array within [row_start, row_stop) x [col_start, col_stop).
    254  
    255              Each thread stores to arr[row_start + tid, col_start:col_stop].  Threads where row_start + tid >=
    256              row_stop skip the store.
    257              """
🔴  258              arr_row_stop = arr.shape[0]
🔴  259              if arr_row_stop < row_stop:
🔴  260                  row_stop = arr_row_stop
🔴  261              row = row_start + qd.simt.subgroup.invocation_id()
🔴  262              if row < row_stop:
🔴  263                  arr_col_stop = arr.shape[1]
🔴  264                  if arr_col_stop < col_stop:
🔴  265                      col_stop = arr_col_stop
🔴  266                  for j in qd.static(range(N)):
🔴  267                      if col_start + j < col_stop:
🔴  268                          arr[row, col_start + j] = self.r[j]
    269  
🟢  270          @qd.func
🟢  271          def _store3d(self, arr: qd.template(), batch, row_start, row_stop, col_start, col_stop):
    272              """Store to a 3D array within [row_start, row_stop) x [col_start, col_stop).
    273  
    274              Each thread stores to arr[batch, row_start+tid, col_start:col_stop].  Threads where row_start + tid >=
    275              row_stop skip the store.
    276              """
🔴  277              arr_row_stop = arr.shape[1]
🔴  278              if arr_row_stop < row_stop:
🔴  279                  row_stop = arr_row_stop
🔴  280              row = row_start + qd.simt.subgroup.invocation_id()
🔴  281              if row < row_stop:
🔴  282                  arr_col_stop = arr.shape[2]
🔴  283                  if arr_col_stop < col_stop:
🔴  284                      col_stop = arr_col_stop
🔴  285                  for j in qd.static(range(N)):
🔴  286                      if col_start + j < col_stop:
🔴  287                          arr[batch, row, col_start + j] = self.r[j]
    288  
🟢  289          @qd.func
🟢  290          def eye_(self):
    291              """Set this tile to the NxN identity matrix.  Each thread sets its diagonal element to 1.0 and all
    292              others to 0.0."""
🔴  293              tid = qd.simt.subgroup.invocation_id()
🔴  294              for j in qd.static(range(N)):
🔴  295                  self.r[j] = 1.0 if tid == j else 0.0
    296  
🟢  297          @qd.func
🟢  298          def _ger_sub(self, a, b):
    299              """General rank-1 subtract in-place: self -= a @ b^T."""
🔴  300              for j in qd.static(range(N)):
🔴  301                  bc = qd.simt.subgroup.shuffle(b, qd.u32(j))
🔴  302                  self.r[j] = self.r[j] - a * bc
    303  
🟢  304          @qd.func
🟢  305          def cholesky_(self, eps):
    306              """In-place NxN Cholesky factorization via subgroup shuffles.
    307  
    308              On return, the lower triangle holds L such that A = L @ L^T.  Diagonal clamped to
    309              sqrt(max(value, eps)) for numerical stability.
    310              """
    311              # ``k`` and ``j`` are wrapped in qd.static so the ``if k > j`` predicate folds at compile time and the
    312              # ``self.r[k]`` / ``self.r[j]`` accesses resolve to a single unpacked-register slot per use (no runtime
    313              # cascade).  The per-lane row-norm used for the diagonal update is carried in ``my_norm_sq``, so each
    314              # diagonal step is O(1) rather than O(k).  The off-diagonal ``dot`` is split into two interleaved partial
    315              # sums (``dot0`` / ``dot1``) so the back-to-back FMA dependency chain is cut in half, exposing more
    316              # instruction-level parallelism.
🔴  317              tid = qd.i32(qd.simt.subgroup.invocation_id())
🔴  318              my_norm_sq = qd.cast(0.0, dtype)
🔴  319              for k in qd.static(range(N)):
🔴  320                  diag_val = qd.cast(0.0, dtype)
🔴  321                  if tid == k:
🔴  322                      diag_val = qd.sqrt(qd.max(self.r[k] - my_norm_sq, eps))
🔴  323                      self.r[k] = diag_val
    324  
🔴  325                  diag_k = qd.simt.subgroup.shuffle(diag_val, qd.u32(k))
    326  
🔴  327                  dot0 = qd.cast(0.0, dtype)
🔴  328                  dot1 = qd.cast(0.0, dtype)
🔴  329                  for j in qd.static(range(N)):
🔴  330                      if k > j:
🔴  331                          my_col = self.r[j]
🔴  332                          Lkj = qd.simt.subgroup.shuffle(my_col, qd.u32(k))
🔴  333                          if j % 2 == 0:
🔴  334                              dot0 += Lkj * my_col  # type: ignore[reportOperatorIssue]
    335                          else:
🔴  336                              dot1 += Lkj * my_col  # type: ignore[reportOperatorIssue]
🔴  337                  dot = dot0 + dot1
    338  
🔴  339                  new_val = qd.cast(0.0, dtype)
🔴  340                  if tid > k:  # type: ignore[reportOperatorIssue]
🔴  341                      new_val = (self.r[k] - dot) / diag_k  # type: ignore[reportOperatorIssue]
🔴  342                      self.r[k] = new_val
🔴  343                  if tid > k:  # type: ignore[reportOperatorIssue]
🔴  344                      my_norm_sq += new_val * new_val
    345  
🟢  346          @qd.func
🟢  347          def _trsm(self, L):
    348              """In-place triangular solve: solve self @ L^T = B (original self).
    349  
    350              L is a TileNxN holding the lower-triangular Cholesky factor (from cholesky_).  On return, self holds the
    351              solution X.
    352              """
🔴  353              for c in qd.static(range(N)):
🔴  354                  dot = qd.cast(0.0, dtype)
🔴  355                  for j in qd.static(range(N)):
🔴  356                      if c > j:
🔴  357                          Lkj = qd.simt.subgroup.shuffle(L.r[j], qd.u32(c))
🔴  358                          dot += self.r[j] * Lkj  # type: ignore[reportOperatorIssue]
    359  
🔴  360                  diag_c = qd.simt.subgroup.shuffle(L.r[c], qd.u32(c))
🔴  361                  self.r[c] = (self.r[c] - dot) / diag_c  # type: ignore[reportOperatorIssue]
    362  
🟢  363          def solve_triangular_(self, B: Any, lower: bool = True) -> None:
    364              """Triangular solve: X @ self^T = B, storing result X in B in-place.
    365  
    366              self must be lower-triangular and non-singular (all diagonal elements non-zero).  Passing a singular
    367              matrix causes division by zero, producing inf/NaN without warning.  Only lower=True is supported.
    368              """
🟢  369              if not lower:
🟢  370                  raise TypeError(f"{name}.solve_triangular_: only lower=True is supported")
🟢  371              B._trsm(self)
    372  
🟢  373          @qd.func
🟢  374          def _resolve_vec2d(self, arr: qd.template(), row_start, row_stop, col):
    375              """Load one scalar per thread from a 2D array column, clamped to array bounds."""
🔴  376              tid = qd.i32(qd.simt.subgroup.invocation_id())
🔴  377              arr_row_stop = arr.shape[0]
🔴  378              if arr_row_stop < row_stop:
🔴  379                  row_stop = arr_row_stop
🔴  380              v = dtype(0.0)
🔴  381              if row_start + tid < row_stop:
🔴  382                  v = arr[row_start + tid, col]
🔴  383              return v
    384  
🟢  385          @qd.func
🟢  386          def _resolve_vec3d(self, arr: qd.template(), batch, row_start, row_stop, col):
    387              """Load one scalar per thread from a 3D array column, clamped to array bounds."""
🔴  388              tid = qd.i32(qd.simt.subgroup.invocation_id())
🔴  389              arr_row_stop = arr.shape[1]
🔴  390              if arr_row_stop < row_stop:
🔴  391                  row_stop = arr_row_stop
🔴  392              v = dtype(0.0)
🔴  393              if row_start + tid < row_stop:
🔴  394                  v = arr[batch, row_start + tid, col]
🔴  395              return v
    396  
🟢  397          def _resolve_vec_proxy(self, proxy: _VecSliceProxy) -> Any:
    398              """Materialize a _VecSliceProxy into a scalar by dispatching to _resolve_vec2d or _resolve_vec3d."""
🟢  399              if proxy.batch_idx is not None:
🟢  400                  return self._resolve_vec3d(proxy.arr, proxy.batch_idx, proxy.row_start, proxy.row_stop, proxy.col)
🟢  401              return self._resolve_vec2d(proxy.arr, proxy.row_start, proxy.row_stop, proxy.col)
    402  
🟢  403          def _augassign(self, other: Any, op: str) -> None:
    404              """Handle augmented assignment (e.g. tile -= qd.outer(a, b)).
    405  
    406              Resolves _VecSliceProxy arguments and dispatches to _ger_sub.  Only 'Sub' is supported.
    407              """
🟢  408              if isinstance(other, _OuterProduct):
🟢  409                  if op == "Sub":
🟢  410                      a_orig = other.a
🟢  411                      b_orig = other.b
🟢  412                      a = self._resolve_vec_proxy(a_orig) if isinstance(a_orig, _VecSliceProxy) else a_orig
🟢  413                      b = (
    414                          a
    415                          if (b_orig is a_orig)
    416                          else (self._resolve_vec_proxy(b_orig) if isinstance(b_orig, _VecSliceProxy) else b_orig)
    417                      )
🟢  418                      self._ger_sub(a, b)
    419                  else:
🟢  420                      raise TypeError(f"{name}: unsupported augmented assignment op '{op}' with outer product")
    421              else:
🟢  422                  raise TypeError(f"{name}: unsupported augmented assignment with {type(other)}")
    423  
🟢  424      _Tile.__name__ = f"_{name}"
🟢  425      _Tile.__qualname__ = f"_make_tile_class.<locals>._{name}"
    426  
    427      # StructType.__call__ already defaults missing args to 0, so Tile() produces a zero-initialized tile
    428      # without needing default values in the class definition (which @qd.dataclass doesn't support).
🟢  429      result = qd.dataclass(_Tile)
🟢  430      result.SIZE = N  # type: ignore[reportAttributeAccessIssue]
🟢  431      result.zeros = result  # type: ignore[reportAttributeAccessIssue]
    432  
🟢  433      @qd.func
🟢  434      def _eye():
🔴  435          t = result()
🔴  436          t.eye_()  # type: ignore[reportAttributeAccessIssue]
🔴  437          return t
    438  
🟢  439      result.eye = _eye  # type: ignore[reportAttributeAccessIssue]
🟢  440      return result
    441  
    442  
🟢  443  class _TileProxy:
    444      """Proxy for dtype-at-point-of-use tile creation.
    445  
    446      Use as ``qd.simt.Tile16x16.zeros(dtype=qd.f32)`` or ``qd.simt.Tile32x32.zeros(dtype=qd.f32)`` inside a kernel.
    447      The dtype is resolved at kernel compilation time, defaulting to the compile config's ``default_fp`` if omitted.
    448      """
    449  
🟢  450      def __init__(self, N: int) -> None:
🟢  451          self._N = N
🟢  452          self.SIZE = N
    453  
🟢  454      def _resolve(self, dtype):
🟢  455          from quadrants.lang import impl  # pylint: disable=import-outside-toplevel
🟢  456          from quadrants.lang.exception import (  # pylint: disable=import-outside-toplevel
    457              QuadrantsSyntaxError,
    458          )
    459  
🟢  460          arch = impl.current_cfg().arch
🟢  461          if arch in (qd.cpu, qd.x64, getattr(qd, "arm64", None)):
🟢  462              raise QuadrantsSyntaxError(
    463                  f"Tile{self._N}x{self._N} requires a GPU backend (cuda, metal, vulkan, amdgpu). "
    464                  f"Current arch is {arch}."
    465              )
🟢  466          if dtype is None:
🟢  467              dtype = impl.get_runtime().default_fp
🟢  468          return _make_tile(self._N, dtype)
    469  
🟢  470      def zeros(self, *, dtype=None):
    471          """Zero-initialized tile."""
🟢  472          return self._resolve(dtype)()
    473  
🟢  474      def eye(self, *, dtype=None):
    475          """Identity tile (diagonal = 1, rest = 0)."""
🟢  476          return self._resolve(dtype).eye()
    477  
    478  
🟢  479  Tile16x16Proxy = _TileProxy(16)
🟢  480  Tile32x32Proxy = _TileProxy(32)
🟢 python/quadrants/lang/simt/tile_slicing.py (100%)
🟢    9  from quadrants.lang.simt._tile import (
     10      _tile_cache,
🟢   20      return any(isinstance(value, t) for t in _tile_cache.values())
🟢   25      return bool(_tile_cache)
🟢 tests/python/test_tile.py (100%)
🟢   11  from quadrants.lang.simt._tile import (
     12      _make_tile,
🟢   27  import functools as _functools  # noqa: E402
     32          _types.SimpleNamespace(
     33              proxy=qd.simt.Tile16x16, make=_functools.partial(_make_tile, 16), size=16, m_size=40, name="tile16"
     34          ),
     38          _types.SimpleNamespace(
     39              proxy=qd.simt.Tile32x32, make=_functools.partial(_make_tile, 32), size=32, m_size=80, name="tile32"
     40          ),
    428      """_make_tile must return the same object for the same (N, dtype)."""