Skip to content

[Graph] Add qd.checkpoint#725

Open
hughperkins wants to merge 95 commits into
mainfrom
hp/graph-checkpoint
Open

[Graph] Add qd.checkpoint#725
hughperkins wants to merge 95 commits into
mainfrom
hp/graph-checkpoint

factor checkpoint launch helpers out of GfxRuntime::launch_kernel

3b84936
Select commit
Loading
Failed to load commit list.
Sign in for the full log view
GitHub Actions / Coverage Report succeeded Jun 18, 2026 in 0s

Diff Coverage Report

Per-line coverage annotations (truncated — see artifacts for full report).

Details

Coverage Report (3b849362e)

Metric Value
Diff coverage (changed lines only) 91%
Overall project coverage 72%

Total: 943 lines, 82 missing, 91% covered

🟢 python/quadrants/lang/_quadrants_callable.py (83%)
🟢   99      def resume(self, *args, from_checkpoint, **kwargs):
    100          """Continues a paused graph kernel from the checkpoint labelled ``from_checkpoint``.
    101  
    102          .. warning::
    103  
    104              **Experimental.** ``kernel.resume`` is part of the experimental ``qd.checkpoint`` surface; the signature
    105              (in particular the ``from_checkpoint=`` kwarg) and behaviour may change in any future release without a
    106              deprecation cycle.
    107  
    108          Use only on ``@qd.kernel(graph=True, checkpoints=True)`` kernels with at least one
    109          ``qd.checkpoint(cp_id, yield_on=flag)`` block. ``from_checkpoint`` is a ``cp_id`` label (typically an
    110          ``IntEnum`` value, often ``status.checkpoint`` from the previous launch): everything before that label in
    111          source order is skipped on this launch, and execution continues from there. The host loop pattern is::
    112  
    113              from enum import IntEnum
    114  
    115              class Stage(IntEnum):
    116                  SIM = 0
    117  
    118              overflow_flag[()] = 0  # initialise before the first launch
    119              status = step(arr, overflow_flag, newton_cond)
    120              while status.yielded:
    121                  handle(status.checkpoint, ...)
    122                  overflow_flag[()] = 0  # the framework never clears your yield_on flag
    123                  status = step.resume(arr, overflow_flag, newton_cond,
    124                                       from_checkpoint=status.checkpoint)
    125  
    126          Returns the same ``GraphStatus`` shape as the plain call.
    127  
    128          Raises ``RuntimeError`` if invoked on a kernel without any ``yield_on=`` checkpoint, or if ``from_checkpoint``
    129          does not match any declared ``cp_id`` in the kernel.
    130          """
🟢  131          if not isinstance(from_checkpoint, int):
🟢  132              raise RuntimeError(
    133                  f"from_checkpoint= must be an int or IntEnum value matching a `qd.checkpoint(cp_id=...)` label in "
    134                  f"the kernel (typically `status.checkpoint` from the previous launch's GraphStatus); "
    135                  f"got {from_checkpoint!r}."
    136              )
    137          # Smuggle the resume cookie past the AST-mapped kwargs path; `Kernel.__call__` pops it before anything else
    138          # looks at kwargs.
🟢  139          return self.wrapper.__call__(*args, _qd_from_checkpoint=from_checkpoint, **kwargs)
    140  
    170  
🟢  171      def resume(self, *args, from_checkpoint, **kwargs):
    172          """Bound-method form of `QuadrantsCallable.resume` (see that docstring)."""
🔴  173          return self.quadrants_callable.resume(self.instance, *args, from_checkpoint=from_checkpoint, **kwargs)
🟢 python/quadrants/lang/ast/ast_transformer.py (100%)
🟢   29  from quadrants.lang.ast.ast_transformers.checkpoint_transformer import (
     30      CheckpointTransformer,
     31  )
🟢 1368      @staticmethod
🟢 1369      def _is_checkpoint_call(node: ast.expr, global_vars: dict):
   1370          """Thin forwarding wrapper around ``CheckpointTransformer.is_checkpoint_call``; the actual logic lives in module
   1371          ``ast_transformers/checkpoint_transformer.py`` to keep this file from growing per-feature. Returns a
   1372          ``CheckpointCallInfo`` or ``None``."""
🟢 1373          return CheckpointTransformer.is_checkpoint_call(node, global_vars)
   1374  
   1613  
🟢 1614          checkpoint_info = ASTTransformer._is_checkpoint_call(item.context_expr, ctx.global_vars)
🟢 1615          if checkpoint_info is not None:
🟢 1616              return ASTTransformer._build_checkpoint_with(ctx, node, checkpoint_info)
   1617  
🟢 1619              raise QuadrantsSyntaxError(
   1620                  "'with' in Quadrants kernels only supports qd.stream_parallel() or qd.checkpoint()"
   1621              )
🟢 1629      @staticmethod
🟢 1630      def _build_checkpoint_with(
   1631          ctx: ASTTransformerFuncContext,
   1632          node: ast.With,
   1633          info,
   1634      ) -> None:
   1635          """Thin forwarding wrapper around ``CheckpointTransformer.build_checkpoint_with``; the actual logic lives in
   1636          ``ast_transformers/checkpoint_transformer.py``."""
🟢 1637          return CheckpointTransformer.build_checkpoint_with(ctx, node, info, build_stmts)
   1638  
🟢 python/quadrants/lang/ast/ast_transformers/checkpoint_transformer.py (89%)
      1  # type: ignore
      2  """AST recognition, validation, and auto-wrap lowering for ``qd.checkpoint(...)`` ``with`` blocks.
      3  
      4  Lives alongside ``call_transformer.py`` / ``function_def_transformer.py`` so that ``ast_transformer.py`` doesn't have to
      5  grow per-feature. ``ASTTransformer.build_With`` and ``ASTTransformer._is_checkpoint_call`` forward calls into the static
      6  methods here. ``FunctionDefTransformer.build_FunctionDef`` calls ``auto_wrap_for_loops`` on the kernel body when the
      7  kernel was decorated with ``@qd.kernel(graph=True, checkpoints=True)``, so that every top-level for-loop not already
      8  inside a ``with qd.checkpoint(...)`` becomes its own implicit no-yield checkpoint.
      9  
     10  See ``docs/source/user_guide/graph.md`` for the user-facing surface and ``perso_hugh/doc/qipc/reentrant.md`` for the
     11  design.
     12  """
     13  
🟢   14  from __future__ import annotations
     15  
🟢   16  import ast
🟢   17  from dataclasses import dataclass
     18  
🟢   19  from quadrants.lang.ast.ast_transformer_utils import ASTTransformerFuncContext
🟢   20  from quadrants.lang.exception import QuadrantsSyntaxError
     21  
     22  # Sentinel name used by `_kernel_coverage.py` (`FIELD_VAR_NAME`) for the probe-tracking field. The checkpoint validator
     23  # hard-codes the literal so coverage probes can be exempted from the bare-statement check without taking a runtime dep
     24  # on the optional `_kernel_coverage` module (it is only imported when `QD_KERNEL_COVERAGE=1`). The two must stay in
     25  # sync; if `FIELD_VAR_NAME` ever changes, update this constant and the corresponding test.
🟢   26  _KERNEL_COVERAGE_FIELD_NAME = "_qd_cov"
     27  
     28  # Attribute name attached to synthetic `qd.checkpoint()` AST Call nodes generated by `auto_wrap_for_loops` for the
     29  # implicit-no-yield wrap of top-level for-loops in a `checkpoints=True` kernel. Distinguishes auto-wrap calls from
     30  # user-written `qd.checkpoint(cp_id, yield_on)` calls so the validator can apply different rules (implicit calls take no
     31  # args; explicit calls require both `cp_id` and `yield_on`).
🟢   32  _IMPLICIT_MARKER_ATTR = "_qd_implicit"
     33  
     34  
🟢   35  @dataclass
🟢   36  class CheckpointCallInfo:
     37      """Resolved metadata for a `qd.checkpoint(...)` call recognised in the AST.
     38  
     39      - ``cp_id``: the user-supplied label (an ``int`` or ``IntEnum`` value), or ``None`` for an auto-wrap implicit
     40        checkpoint.
     41      - ``yield_on``: name of the kernel parameter passed as ``yield_on=`` (an ``ast.Name`` is required), or ``None`` for
     42        an implicit checkpoint.
     43      - ``is_implicit``: ``True`` iff this Call was synthesised by ``auto_wrap_for_loops``.
     44      """
     45  
🟢   46      cp_id: int | None
🟢   47      yield_on: str | None
🟢   48      is_implicit: bool
     49  
     50  
🟢   51  class CheckpointTransformer:
🟢   52      @staticmethod
🟢   53      def _is_coverage_probe_assign(stmt: ast.stmt) -> bool:
     54          """Return True iff *stmt* is the synthesized ``_qd_cov[<probe_id>] = 1`` assignment inserted by
     55          ``_kernel_coverage.py`` when ``QD_KERNEL_COVERAGE=1``. Keeping these out of the bare-statement rejection in
     56          ``build_checkpoint_with`` lets coverage CI exercise every checkpoint kernel without the user having to wrap the
     57          synthetic probes themselves.
     58          """
🟢   59          if not isinstance(stmt, ast.Assign) or len(stmt.targets) != 1:
🔴   60              return False
🟢   61          tgt = stmt.targets[0]
🟢   62          if not isinstance(tgt, ast.Subscript):
🔴   63              return False
🟢   64          return isinstance(tgt.value, ast.Name) and tgt.value.id == _KERNEL_COVERAGE_FIELD_NAME
     65  
🟢   66      @staticmethod
🟢   67      def _looks_like_checkpoint_call(node: ast.expr) -> bool:
     68          """Cheap structural check: is *node* `qd.checkpoint(...)` or bare `checkpoint(...)`? Doesn't validate args or
     69          resolve names; used by ``auto_wrap_for_loops`` to skip already-wrapped `with` blocks without paying for the full
     70          argument-resolution pass (which needs ``global_vars`` and a real Kernel object)."""
🟢   71          if not isinstance(node, ast.Call):
🔴   72              return False
🟢   73          func = node.func
🟢   74          return (isinstance(func, ast.Attribute) and func.attr == "checkpoint") or (
     75              isinstance(func, ast.Name) and func.id == "checkpoint"
     76          )
     77  
🟢   78      @staticmethod
🟢   79      def is_explicit_checkpoint_with(stmt: ast.stmt) -> bool:
     80          """Return True iff *stmt* is a single-item `with qd.checkpoint(...):` block (explicit or implicit). Used by
     81          ``auto_wrap_for_loops`` to leave already-wrapped for-loops alone."""
🟢   82          if not isinstance(stmt, ast.With) or len(stmt.items) != 1:
🟢   83              return False
🟢   84          return CheckpointTransformer._looks_like_checkpoint_call(stmt.items[0].context_expr)
     85  
🟢   86      @staticmethod
🟢   87      def _resolve_cp_id(node: ast.expr, global_vars: dict) -> int:
     88          """Resolve the first positional arg of `qd.checkpoint(cp_id, ...)` to a Python int (or IntEnum instance).
     89  
     90          Accepts (a) `ast.Constant` int literals (``qd.checkpoint(0, ...)``), (b) `ast.Attribute` references to
     91          `IntEnum` values resolved against the kernel's `global_vars` (``qd.checkpoint(Stage.SIM, ...)`` where ``Stage``
     92          is an `IntEnum` defined at module scope), and (c) `ast.Name` references to module-level int constants
     93          (``qd.checkpoint(CP_LOAD, ...)`` where ``CP_LOAD`` is an int defined at module scope). For (b) and (c) we
     94          return the resolved value AS-IS (without re-wrapping through `int(...)`) so an IntEnum member identity is
     95          preserved end-to-end -- the user writes `qd.checkpoint(Stage.SIM, ...)`, then reads `status.checkpoint` and
     96          gets back `Stage.SIM`, not the raw int. Rejects everything else with a clear error so the user gets a
     97          compile-time diagnostic rather than a confusing template-mapper failure later.
     98          """
     99          # Plain `qd.checkpoint(0, ...)` literal.
🟢  100          if isinstance(node, ast.Constant) and isinstance(node.value, int) and not isinstance(node.value, bool):
🟢  101              return node.value
    102          # `qd.checkpoint(Stage.SIM, ...)` -- look up `Stage` in the kernel's module globals, then `.SIM` on it.
🟢  103          if isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name):
🟢  104              container_name = node.value.id
🟢  105              if container_name in global_vars:
🟢  106                  container = global_vars[container_name]
🟢  107                  if hasattr(container, node.attr):
🟢  108                      val = getattr(container, node.attr)
🟢  109                      if isinstance(val, int) and not isinstance(val, bool):
🟢  110                          return val
    111          # `qd.checkpoint(CP_LOAD, ...)` -- bare name referencing a module-level int constant.
🟢  112          if isinstance(node, ast.Name) and node.id in global_vars:
🔴  113              val = global_vars[node.id]
🔴  114              if isinstance(val, int) and not isinstance(val, bool):
🔴  115                  return val
🟢  116          raise QuadrantsSyntaxError(
    117              "qd.checkpoint() first argument must be an int literal or an IntEnum value (e.g. `Stage.SIM`). Got an "
    118              f"unresolvable expression at line {getattr(node, 'lineno', '?')}; the cp_id must be statically "
    119              "determinable at AST-walk time so the framework can build the label -> internal cp_id map and the "
    120              "host-side resume API can refer to checkpoints by name."
    121          )
    122  
🟢  123      @staticmethod
🟢  124      def is_checkpoint_call(node: ast.expr, global_vars: dict) -> CheckpointCallInfo | None:
    125          """If *node* is a `qd.checkpoint(...)` call return a `CheckpointCallInfo`; otherwise return `None`.
    126  
    127          Validates the call shape and raises `QuadrantsSyntaxError` for misuse so the user gets a clear message at the
    128          `with` site rather than a vague "not stream_parallel" error later. Implicit calls (the synthetic
    129          `qd.checkpoint()` no-arg calls produced by `auto_wrap_for_loops`) are recognised via the `_qd_implicit`
    130          attribute and bypass user-call argument validation entirely.
    131          """
🟢  132          if not CheckpointTransformer._looks_like_checkpoint_call(node):
🟢  133              return None
    134          # Auto-wrap-synthesised implicit checkpoint: no user-facing args, no validation.
🟢  135          if getattr(node, _IMPLICIT_MARKER_ATTR, False):
🟢  136              return CheckpointCallInfo(cp_id=None, yield_on=None, is_implicit=True)
    137          # User-written `qd.checkpoint(cp_id, yield_on)` -- both args are required.
🟢  138          if len(node.args) + len(node.keywords) == 0:
🟢  139              raise QuadrantsSyntaxError("qd.checkpoint() takes two arguments: `qd.checkpoint(cp_id, yield_on=flag)`.")
    140          # Collect cp_id (positional 0 or kw `cp_id=`) and yield_on (positional 1 or kw `yield_on=`).
🟢  141          cp_id_arg: ast.expr | None = None
🟢  142          yield_on_arg: ast.expr | None = None
🟢  143          for i, arg in enumerate(node.args):
🟢  144              if i == 0:
🟢  145                  cp_id_arg = arg
🔴  146              elif i == 1:
🔴  147                  yield_on_arg = arg
    148              else:
🔴  149                  raise QuadrantsSyntaxError(
    150                      f"qd.checkpoint() takes at most 2 positional arguments (cp_id, yield_on); got {len(node.args)}"
    151                  )
🟢  152          for kw in node.keywords:
🟢  153              if kw.arg == "cp_id":
🔴  154                  if cp_id_arg is not None:
🔴  155                      raise QuadrantsSyntaxError("qd.checkpoint() got `cp_id` both positionally and as a keyword")
🔴  156                  cp_id_arg = kw.value
🟢  157              elif kw.arg == "yield_on":
🟢  158                  if yield_on_arg is not None:
🔴  159                      raise QuadrantsSyntaxError("qd.checkpoint() got `yield_on` both positionally and as a keyword")
🟢  160                  yield_on_arg = kw.value
    161              else:
🟢  162                  raise QuadrantsSyntaxError(
    163                      f"qd.checkpoint() got unexpected keyword argument {kw.arg!r}; only 'cp_id' and 'yield_on' are "
    164                      "supported"
    165                  )
🟢  166          if cp_id_arg is None:
🔴  167              raise QuadrantsSyntaxError(
    168                  "qd.checkpoint() is missing required argument `cp_id` (e.g. `qd.checkpoint(0, yield_on=flag)` or "
    169                  "`qd.checkpoint(Stage.SIM, yield_on=flag)`)"
    170              )
🟢  171          if yield_on_arg is None:
🟢  172              raise QuadrantsSyntaxError(
    173                  "qd.checkpoint() is missing required argument `yield_on` (e.g. "
    174                  "`qd.checkpoint(0, yield_on=overflow_flag)`)"
    175              )
🟢  176          if not isinstance(yield_on_arg, ast.Name):
🟢  177              raise QuadrantsSyntaxError(
    178                  "qd.checkpoint(yield_on=...) must be the bare name of a kernel parameter (e.g. "
    179                  "`yield_on=overflow_flag`); expressions are not supported"
    180              )
🟢  181          cp_id_value = CheckpointTransformer._resolve_cp_id(cp_id_arg, global_vars)
🟢  182          return CheckpointCallInfo(cp_id=cp_id_value, yield_on=yield_on_arg.id, is_implicit=False)
    183  
🟢  184      @staticmethod
🟢  185      def build_checkpoint_with(
    186          ctx: ASTTransformerFuncContext,
    187          node: ast.With,
    188          info: CheckpointCallInfo,
    189          build_stmts,
    190      ) -> None:
    191          """Handles a `with qd.checkpoint(...):` block (explicit or auto-wrapped implicit).
    192  
    193          Validates the use-site (kernel must be `graph=True, checkpoints=True`, no nesting; for explicit calls, the
    194          `yield_on` arg must be a kernel parameter and the `cp_id` must be unique across the kernel) and appends an entry
    195          to `kernel.checkpoint_yield_on_args` + `kernel.checkpoint_user_labels_by_cp_id`. Walks the body transparently --
    196          for-loops inside the `with` become normal top-level for-loops in the kernel's frontend IR. The internal `cp_id`
    197          is assigned by declaration order (the C++ `ASTBuilder.begin_checkpoint()` counter, mirrored as the list index in
    198          `checkpoint_yield_on_args`).
    199          """
🟢  200          if not ctx.is_kernel:
🔴  201              raise QuadrantsSyntaxError("qd.checkpoint() can only be used inside @qd.kernel, not @qd.func")
🟢  202          kernel = ctx.global_context.current_kernel
🟢  203          if not kernel.use_graph:
🔴  204              raise QuadrantsSyntaxError(
    205                  "qd.checkpoint() requires @qd.kernel(graph=True); the resume model only applies to graph kernels"
    206              )
🟢  207          if not kernel.use_checkpoints:
🟢  208              raise QuadrantsSyntaxError(
    209                  f"qd.checkpoint() found in kernel {kernel.func.__name__!r}, but this kernel was not decorated with "
    210                  "`checkpoints=True`. Add `checkpoints=True` to the decorator -- e.g. "
    211                  "`@qd.kernel(graph=True, checkpoints=True)` -- to opt into the pause / resume model."
    212              )
🟢  213          if getattr(ctx, "_in_checkpoint", False):
🟢  214              raise QuadrantsSyntaxError(
    215                  "qd.checkpoint() cannot be nested inside another qd.checkpoint(); checkpoints in the "
    216                  "same kernel must be flat siblings (a checkpoint inside qd.graph_do_while is fine)"
    217              )
    218  
🟢  219          if not info.is_implicit:
    220              # Validate `yield_on=` names a real kernel parameter.
🟢  221              arg_names = [m.name for m in kernel.arg_metas]
🟢  222              if info.yield_on not in arg_names:
🟢  223                  raise QuadrantsSyntaxError(
    224                      f"qd.checkpoint(yield_on={info.yield_on!r}) does not match any parameter of kernel "
    225                      f"{kernel.func.__name__!r}. Available parameters: {arg_names}"
    226                  )
    227              # Reject duplicate user-supplied cp_id labels.
🟢  228              existing = [lbl for lbl in kernel.checkpoint_user_labels_by_cp_id if lbl is not None]
🟢  229              if info.cp_id in existing:
🟢  230                  raise QuadrantsSyntaxError(
    231                      f"qd.checkpoint(cp_id={info.cp_id!r}, ...) declares a cp_id that is already used by another "
    232                      f"checkpoint in kernel {kernel.func.__name__!r}. Each `cp_id` label must be unique so the host "
    233                      "loop can map `status.checkpoint` and `kernel.resume(from_checkpoint=...)` unambiguously."
    234                  )
    235  
    236          # Reject bare top-level statements (Assign / AugAssign / AnnAssign / non-docstring Expr) at the top of the
    237          # checkpoint body and ask the user to wrap them in their own for-loop. The offloader's pending-serial bucket
    238          # loses the surrounding `checkpoint_id` and emits such statements as `serial` tasks with `cp_id == -1`, so they
    239          # would run unconditionally even when the checkpoint is skipped -- a silent correctness bug. Rather than
    240          # auto-wrapping them transparently (which hides the fact that each bare stmt becomes its own kernel / graph
    241          # node and surprises users when they look at the lowered IR or `prog.get_num_offloaded_tasks_on_last_call()`),
    242          # we surface a clear compile-time error and have the user write `for _ in range(1): <stmt>` themselves.
    243          # Implicit (auto-wrapped) checkpoints can never trip this -- they wrap exactly one ast.For. Coverage probes
    244          # (synthesized `_qd_cov[i] = 1` assignments under `QD_KERNEL_COVERAGE=1`) are explicitly exempt; the cp_id
    245          # propagation in `quadrants/transforms/offload.cpp` (`assemble_serial_statements`) makes them inherit the
    246          # surrounding checkpoint's cp_id so the launcher's "last task in checkpoint" detection stays correct.
🟢  247          for stmt in node.body:
🟢  248              is_bare = isinstance(stmt, (ast.Assign, ast.AugAssign, ast.AnnAssign))
🟢  249              if not is_bare and isinstance(stmt, ast.Expr):
    250                  # Any `Expr(Constant)` is a no-op (Python's docstring pattern, e.g. a leading triple-quoted string).
    251                  # We accept it anywhere in the body rather than only at position 0 because the kernel-coverage AST
    252                  # transformer (see `_kernel_coverage.py`) prepends a `_qd_cov[...] = 1` probe under
    253                  # `QD_KERNEL_COVERAGE=1`, which would otherwise push the docstring to position 1 and have us flag it as
    254                  # a bare top-level statement.
🟢  255                  is_constant_expr = isinstance(stmt.value, ast.Constant)
🟢  256                  is_bare = not is_constant_expr
🟢  257              if not is_bare:
🟢  258                  continue
🟢  259              if CheckpointTransformer._is_coverage_probe_assign(stmt):
🟢  260                  continue
🟢  261              stmt_kind = type(stmt).__name__
🟢  262              raise QuadrantsSyntaxError(
    263                  f"qd.checkpoint() body cannot contain a bare top-level {stmt_kind} statement "
    264                  f"(line {getattr(stmt, 'lineno', '?')}): every top-level statement in a checkpoint must be inside a "
    265                  f"for-loop (or other control-flow construct), so the compiler can lower it as its own offloaded task "
    266                  f"with the correct cp_id. Wrap the statement in `for _ in range(1):` to keep the original intent:\n"
    267                  f"\n"
    268                  f"    with qd.checkpoint(cp_id, yield_on=flag):\n"
    269                  f"        for _ in range(1):\n"
    270                  f"            <your statement here>\n"
    271                  f"        for i in range(arr.shape[0]):\n"
    272                  f"            ...\n"
    273              )
    274  
🟢  275          kernel.checkpoint_yield_on_args.append(info.yield_on)
🟢  276          kernel.checkpoint_user_labels_by_cp_id.append(info.cp_id)
    277          # Hand control to the C++ ASTBuilder so that every for-loop emitted by `build_stmts` below is tagged with this
    278          # checkpoint's internal `cp_id` on its `ForLoopConfig.checkpoint_id`. The C++ counter is the source of truth for
    279          # the dense internal cp_id; we cross-check it against the Python list index so that a future refactor that
    280          # misaligns the two surfaces fires immediately.
🟢  281          cpp_cp_id = ctx.ast_builder.begin_checkpoint()
🟢  282          py_cp_id = len(kernel.checkpoint_yield_on_args) - 1
🟢  283          assert cpp_cp_id == py_cp_id, (
    284              f"C++ ASTBuilder.begin_checkpoint() returned cp_id={cpp_cp_id} but Python "
    285              f"kernel.checkpoint_yield_on_args index expected {py_cp_id}; these counters "
    286              f"must stay in lockstep so the GraphManager can index yield_on by cp_id"
    287          )
🟢  288          ctx._in_checkpoint = True
🟢  289          try:
🟢  290              build_stmts(ctx, node.body)
    291          finally:
🟢  292              ctx._in_checkpoint = False
🟢  293              ctx.ast_builder.end_checkpoint()
🟢  294          return None
    295  
🟢  296      @staticmethod
🟢  297      def _make_implicit_checkpoint_with(for_stmt: ast.For) -> ast.With:
    298          """Construct a synthetic ``with qd.checkpoint(): <for_stmt>`` AST node carrying the `_qd_implicit` marker.
    299  
    300          The `qd` name resolves via the kernel's module globals at AST-walk time; in practice every kernel that imports
    301          `quadrants as qd` (the canonical pattern) has it available. If the user imported as `import quadrants as t`,
    302          `with qd.checkpoint(...)` would already fail to resolve, so we don't worry about that case here.
    303          """
🟢  304          call = ast.Call(
    305              func=ast.Attribute(value=ast.Name(id="qd", ctx=ast.Load()), attr="checkpoint", ctx=ast.Load()),
    306              args=[],
    307              keywords=[],
    308          )
🟢  309          setattr(call, _IMPLICIT_MARKER_ATTR, True)
🟢  310          with_stmt = ast.With(
    311              items=[ast.withitem(context_expr=call, optional_vars=None)],
    312              body=[for_stmt],
    313              type_comment=None,
    314          )
🟢  315          ast.copy_location(with_stmt, for_stmt)
🟢  316          ast.copy_location(call, for_stmt)
🟢  317          ast.copy_location(call.func, for_stmt)
🟢  318          ast.copy_location(call.func.value, for_stmt)
    319          # ast.With requires a complete locations table for downstream passes; fix up missing fields.
🟢  320          ast.fix_missing_locations(with_stmt)
🟢  321          return with_stmt
    322  
🟢  323      @staticmethod
🟢  324      def auto_wrap_for_loops(stmts: list[ast.stmt]) -> list[ast.stmt]:
    325          """Auto-wrap pass for `@qd.kernel(graph=True, checkpoints=True)` kernels.
    326  
    327          Walks *stmts* and returns a new list where every `ast.For` not already inside a `with qd.checkpoint(...)` block
    328          is wrapped in a synthetic implicit checkpoint. Recurses into `while qd.graph_do_while(...)` bodies so for-loops
    329          nested in the WHILE body get the same treatment. Other compound statements (`ast.If`, `ast.With`,
    330          non-graph_do_while `ast.While`, `ast.Try`) are passed through unchanged -- they're not the common pattern in
    331          Quadrants kernel bodies and recursing into them risks wrapping nested for-loops that the user intended to be
    332          sub-tasks of a larger control-flow block. Bare top-level statements (assignments, expressions, coverage probes)
    333          are also passed through unchanged so they remain in the kernel prologue with `cp_id=-1` and run on every launch.
    334          """
🟢  335          new_stmts: list[ast.stmt] = []
🟢  336          for stmt in stmts:
🟢  337              if isinstance(stmt, ast.For):
🟢  338                  new_stmts.append(CheckpointTransformer._make_implicit_checkpoint_with(stmt))
🟢  339              elif CheckpointTransformer.is_explicit_checkpoint_with(stmt):
    340                  # User has already wrapped this for-loop (or block of for-loops) -- leave alone.
🟢  341                  new_stmts.append(stmt)
🟢  342              elif isinstance(stmt, ast.While):
    343                  # `while qd.graph_do_while(...):` -- recurse into the loop body. (Plain `while <cond>:` is not a
    344                  # supported Quadrants kernel construct; the `build_While` transformer will reject it later, so we don't
    345                  # special-case it here.)
🟢  346                  stmt.body = CheckpointTransformer.auto_wrap_for_loops(stmt.body)
🟢  347                  new_stmts.append(stmt)
    348              else:
🟢  349                  new_stmts.append(stmt)
🟢  350          return new_stmts
🔴 python/quadrants/lang/ast/ast_transformers/function_def_transformer.py (78%)
🟢   29  from quadrants.lang.ast.ast_transformers.checkpoint_transformer import (
     30      CheckpointTransformer,
     31  )
🟢  511              kernel = ctx.global_context.current_kernel
🟢  512              if kernel is not None:
    513                  # Reset before walking the body so re-materialisations (e.g. when a templated kernel is compiled with a
    514                  # different argument shape) start from an empty list. Mirrors how `graph_do_while_arg` gets overwritten
    515                  # unconditionally during AST traversal.
🟢  516                  kernel.checkpoint_yield_on_args = []
🟢  517                  kernel.checkpoint_user_labels_by_cp_id = []
    518                  # Auto-wrap pass for `@qd.kernel(graph=True, checkpoints=True)` kernels. Mutates `node.body` in place so
    519                  # every top-level for-loop (and every for-loop inside a `qd.graph_do_while` body) that the user did not
    520                  # already wrap in a `with qd.checkpoint(...)` gets wrapped in a synthetic implicit no-yield checkpoint.
    521                  # Implicit checkpoints share the same dense source-order internal cp_id space as explicit ones, but
    522                  # carry `None` in `checkpoint_user_labels_by_cp_id` so they never appear in `GraphStatus.checkpoint` /
    523                  # `kernel.resume(from_checkpoint=...)`. Runs here (after coverage instrumentation has already injected
    524                  # its top-level `_qd_cov[i] = 1` probes, which are bare assigns that the wrap pass intentionally leaves
    525                  # alone) so that the regular `build_stmts` walk below sees a uniform stream of `with qd.checkpoint(...)`
    526                  # blocks and bare prologue stmts.
🟢  527                  if kernel.use_checkpoints:
🟢  528                      node.body = CheckpointTransformer.auto_wrap_for_loops(node.body)
🟢  573      @staticmethod
🟢  574      def _is_checkpoint_with(stmt: ast.With) -> bool:
    575          """Syntactic check matching CheckpointTransformer: a ``with qd.checkpoint(...):`` block. Accepted as a
    576          well-formed statement inside a graph_do_while body (the checkpoint itself enforces its own restrictions)."""
🟢  577          if not isinstance(stmt, ast.With) or len(stmt.items) != 1:
🔴  578              return False
🟢  579          ctx = stmt.items[0].context_expr
🟢  580          if not isinstance(ctx, ast.Call):
🔴  581              return False
🟢  582          func = ctx.func
🟢  583          if isinstance(func, ast.Attribute) and func.attr == "checkpoint":
🟢  584              return True
🔴  585          if isinstance(func, ast.Name) and func.id == "checkpoint":
🔴  586              return True
🔴  587          return False
    588  
🟢  655              if isinstance(stmt, ast.With) and FunctionDefTransformer._is_checkpoint_with(stmt):
    656                  # `with qd.checkpoint(...)` is a legal placement for graph kernels (the checkpoint may sit at the
    657                  # kernel top level or inside any graph_do_while level). Recurse so a malformed graph_do_while
    658                  # nested inside a checkpoint body is still caught; the checkpoint's own body restrictions are
    659                  # enforced by `CheckpointTransformer.build_checkpoint_with`.
🟢  660                  FunctionDefTransformer._validate_graph_do_while_stmt_list(stmt.body, is_kernel_top=is_kernel_top)
🟢  661                  continue
🟢 python/quadrants/lang/checkpoint.py (100%)
      1  """User-facing ``qd.checkpoint`` context-manager and its no-op Python-runtime stub.
      2  
      3  Mirrors ``graph_status.py``, which holds the other half of the same feature surface (``GraphStatus``). Kept in its own
      4  module to keep ``lang/misc.py`` from growing further -- the AST transformer and the C++ runtime are doing all the actual
      5  implementation work; this file is just the public API entry point.
      6  
      7  Re-exported via ``qd.lang.misc`` (and therefore as ``qd.checkpoint``) for the user-facing canonical import path.
      8  """
      9  
🟢   10  from __future__ import annotations
     11  
🟢   12  from contextlib import contextmanager
     13  
     14  
🟢   15  @contextmanager
🟢   16  def checkpoint(cp_id, yield_on):
     17      """Marks a section of a graph kernel as a pause / resume point.
     18  
     19      .. warning::
     20  
     21          **Experimental.** ``qd.checkpoint`` (together with ``qd.GraphStatus`` and ``kernel.resume(from_checkpoint=...)``)
     22          is an experimental API. The signature, the lowering across backends, the error messages, and the host-side
     23          yield/resume contract may change in any future release without a deprecation cycle.
     24  
     25      Used as ``with qd.checkpoint(cp_id, yield_on=flag):`` inside a ``@qd.kernel(graph=True, checkpoints=True)`` kernel
     26      body. When the body writes a non-zero value into ``flag``, the kernel pauses at this checkpoint and returns a
     27      ``GraphStatus`` to the host carrying ``status.checkpoint == cp_id``. The host can then fix things up and call
     28      ``kernel.resume(..., from_checkpoint=cp_id)`` to continue from the same point on the next launch.
     29  
     30      Arguments:
     31          cp_id: User-facing label identifying this checkpoint to the host. Must be an ``int`` literal or an ``IntEnum``
     32              value, and must be unique within the kernel. The value is preserved as-is end-to-end -- if you pass
     33              ``Stage.SIM`` (an ``IntEnum`` member), ``status.checkpoint`` round-trips back as ``Stage.SIM`` rather than
     34              the raw int.
     35          yield_on: Name of a kernel parameter that is a 0-d ``qd.types.ndarray(qd.i32, ndim=0)``. The body may write a
     36              non-zero value into it to signal "pause here, host needs to handle something". The framework never writes
     37              into this buffer -- the host owns it end-to-end and must initialise it to ``0`` before the first launch
     38              (``qd.ndarray`` is not zero-initialised) AND reset it to ``0`` before each ``kernel.resume(...)`` call
     39              (otherwise the same checkpoint sees the stale non-zero value and yields again).
     40  
     41      Restrictions (enforced at kernel compile time):
     42        - Must be used inside ``@qd.kernel(graph=True, checkpoints=True)``.
     43        - ``cp_id`` must be an ``int`` (or ``IntEnum`` value), and must be unique across the kernel.
     44        - ``yield_on`` must name a kernel parameter that is a 0-d ``qd.types.ndarray(qd.i32, ndim=0)``.
     45        - Checkpoints cannot be nested inside other checkpoints. Checkpoints inside a ``qd.graph_do_while`` body are fine.
     46        - Cannot be combined with ``qd.stream_parallel()`` in the same kernel.
     47        - The body cannot contain bare top-level statements (assignments, expressions); wrap them in
     48          ``for _ in range(1):`` so the lowering surfaces the per-statement task cost.
     49  
     50      This function should not be called directly at runtime; it is recognised and transformed during AST compilation. At
     51      Python runtime (outside kernels), this is a no-op context manager so that doctests / type-checking can import the
     52      symbol freely.
     53  
     54      See ``docs/source/user_guide/graph.md`` for the host-side yield/resume loop and cross-backend semantics.
     55      """
🟢   56      del cp_id, yield_on
🟢   57      yield
🔴 python/quadrants/lang/graph_status.py (70%)
      1  """Plain-Python container returned from graph kernels that contain ``qd.checkpoint(yield_on=...)``.
      2  
      3  Lives in its own module (with no Quadrants-internal imports) so it can be imported safely from both ``kernel.py`` and
      4  ``misc.py`` without re-introducing the circular import chain that ``misc.py -> impl.py -> kernel.py`` would create.
      5  
      6  Re-exported via ``qd.lang.misc`` (and therefore as ``qd.GraphStatus``) for the user-facing canonical import path.
      7  """
      8  
🟢    9  from __future__ import annotations
     10  
     11  
🟢   12  class GraphStatus:
     13      """Result returned by a graph kernel that contains ``qd.checkpoint(cp_id, yield_on=...)`` blocks.
     14  
     15      .. warning::
     16  
     17          **Experimental.** ``GraphStatus`` is part of the experimental ``qd.checkpoint`` surface; its attributes and
     18          the conditions under which it is returned may change in any future release without a deprecation cycle.
     19  
     20      Returned from ``kernel(...)`` and ``kernel.resume(..., from_checkpoint=label)`` whenever the kernel was decorated
     21      with ``@qd.kernel(graph=True, checkpoints=True)``. Read ``status.yielded`` to decide whether to keep running the
     22      host loop, and ``status.checkpoint`` to find out which checkpoint asked the host to handle something.
     23  
     24      Canonical usage (see ``graph.md``)::
     25  
     26          from enum import IntEnum
     27  
     28          class Stage(IntEnum):
     29              SIM = 0
     30  
     31          overflow_flag[()] = 0  # initialise before the first launch
     32          status = step(arr, overflow_flag, newton_cond)
     33          while status.yielded:
     34              handle_overflow_for(status.checkpoint, ...)
     35              overflow_flag[()] = 0  # the framework never clears your yield_on flag
     36              status = step.resume(arr, overflow_flag, newton_cond,
     37                                   from_checkpoint=status.checkpoint)
     38  
     39      Attributes:
     40          yielded: ``True`` iff one of the kernel's checkpoints paused on the most recent launch (i.e. its ``yield_on=``
     41              flag was non-zero). ``False`` means the kernel completed normally and the host loop should exit.
     42          checkpoint: The ``cp_id`` label of the checkpoint that paused (an ``int`` or the original ``IntEnum`` instance
     43              you passed to ``qd.checkpoint(cp_id, ...)``), or ``None`` when ``yielded`` is ``False``. Pass it to
     44              ``kernel.resume(..., from_checkpoint=...)`` to continue from that point on the next launch.
     45      """
     46  
🟢   47      __slots__ = ("yielded", "checkpoint")
     48  
🟢   49      def __init__(self, yielded: bool, checkpoint: int | None):
🟢   50          self.yielded = yielded
🟢   51          self.checkpoint = checkpoint
     52  
🟢   53      def __repr__(self) -> str:
🔴   54          if self.yielded:
     55              # Use `!r` so an `IntEnum` cp_id is shown as `<Stage.LOAD: 0>` rather than collapsing to its raw int via
     56              # `int.__format__` (which Python 3.10's `f"{IntEnum.X}"` does). Plain ints round-trip unchanged.
🔴   57              return f"GraphStatus(yielded=True, checkpoint={self.checkpoint!r})"
🔴   58          return "GraphStatus(yielded=False)"
🟢 python/quadrants/lang/kernel.py (100%)
     39  
     40  # `qd.checkpoint` pause / resume model helpers. See `kernel_checkpoint.py` for the full extracted surface; `Kernel`
     41  # delegates the resume-cookie validation, label translation, per-launch yield_on= arg-id table build, and GraphStatus
     42  # construction to those free functions so this hot file doesn't accrete checkpoint-feature-specific blocks.
🟢   43  from quadrants.lang import kernel_checkpoint as _checkpoint_helpers
    331          # Opt-in flag set by `@qd.kernel(graph=True, checkpoints=True)`. When True, the AST transformer enables
    332          # `qd.checkpoint(...)` recognition AND auto-wraps every top-level for-loop that isn't already inside a
    333          # `with qd.checkpoint(...)` block in an implicit no-yield checkpoint. When False, any use of
    334          # `qd.checkpoint(...)` in the kernel body is rejected at compile time with a fix-it pointing at
    335          # `checkpoints=True`.
🟢  336          self.use_checkpoints: bool = False
    345          # Per-checkpoint metadata, one entry per `with qd.checkpoint(...)` block (explicit AND auto-injected implicit)
    346          # in declaration order. List index is the checkpoint's internal `cp_id` (0, 1, 2, ... dense, flat across the
    347          # kernel). Each entry is the name of the `yield_on=` kernel parameter, or `None` for implicit checkpoints
    348          # (which never yield). Populated by the AST transformer; empty means the kernel uses no checkpoints.
🟢  349          self.checkpoint_yield_on_args: list[str | None] = []
    350          # User-facing labels for explicit checkpoints. Same indexing as `checkpoint_yield_on_args`: entry `i` is the int
    351          # (or IntEnum value) the user passed as the first positional arg of `qd.checkpoint(cp_id, yield_on)` for the
    352          # checkpoint whose internal cp_id is `i`. Implicit checkpoints (auto-wrapped) get `None` (they have no
    353          # user-facing label and can never appear in `GraphStatus.checkpoint`). The label is preserved as-is so an
    354          # `IntEnum` round-trips: writing `qd.checkpoint(Stage.SIM, ...)` and then reading `status.checkpoint` returns
    355          # `Stage.SIM` rather than the raw int.
🟢  356          self.checkpoint_user_labels_by_cp_id: list[int | None] = []
    535          self,
    536          key,
    537          t_kernel: KernelCxx,
    538          compiled_kernel_data: CompiledKernelData | None,
    539          *args,
    540          qd_stream=None,
    541          _resume_from_checkpoint: int | None = None,
🟢  590              _checkpoint_helpers.init_yield_on_arg_id_table(self)
🟢  608                  _checkpoint_helpers.maybe_record_yield_on_arg(self, self.arg_metas[i_in].name, i_out - template_num)
🟢  684              _checkpoint_helpers.forward_yield_on_table_to_ctx(self, launch_ctx)
    685              # `_resume_from_checkpoint` is `None` for fresh launches (host-side default 0 in `LaunchContextBuilder`,
    686              # which means "run every checkpoint"). When `Kernel.resume` plumbs an int through, copy it onto the launch
    687              # context so the GraphManager's `launch_cached_graph` memcpys it into the device-side `resume_point` slot
    688              # instead of clearing to 0. Slice 2 implementation; pre-CUDA-12.4 / non-CUDA backends ignore the value since
    689              # they don't have a resume_point slot today (slices 4-6 will add an indirect-dispatch equivalent).
🟢  690              if _resume_from_checkpoint is not None:
🟢  691                  launch_ctx.resume_from_checkpoint = int(_resume_from_checkpoint)
    780          # Pop the resume cookie before anything else touches kwargs -- the AST mapper sees user parameter names only, so
    781          # a stray `from_checkpoint=` would raise "unexpected kwarg". `_resume_from_checkpoint` is the resolved cp_id to
    782          # copy into the device-side `resume_point` slot before launch; `None` means "fresh start, reset to 0". Plumbed
    783          # via `Kernel.resume()` only; users do not pass this directly.
🟢  784          _resume_from_checkpoint = kwargs.pop("_qd_from_checkpoint", None)
🟢  785          _checkpoint_helpers.validate_resume_cookie(self, _resume_from_checkpoint)
    858          # Translate the user-supplied `from_checkpoint=` label into the dense, source-order internal cp_id the runtime
    859          # uses. Translation happens here (after `ensure_compiled`) because `checkpoint_user_labels_by_cp_id` is
    860          # populated during AST processing inside `ensure_compiled`.
🟢  861          if _resume_from_checkpoint is not None:
🟢  862              _resume_from_checkpoint = _checkpoint_helpers.translate_user_label_to_internal_cp_id(
    863                  self, _resume_from_checkpoint
    864              )
    865          # Only forward `_resume_from_checkpoint` when the caller actually supplied one (i.e. via `Kernel.resume(...)`).
    866          # Otherwise omit the kwarg entirely so subclasses / monkeypatches of `launch_kernel` that pre-date this kwarg
    867          # keep working unmodified. The host-side default in `LaunchContextBuilder` is 0 ("run every checkpoint"), which
    868          # matches the `None` semantics in `launch_kernel`.
🟢  869          if _resume_from_checkpoint is None:
🟢  870              ret = self.launch_kernel(
    871                  key,
    872                  kernel_cpp,
    873                  compiled_kernel_data,
    874                  *py_args,
    875                  qd_stream=qd_stream,
    876              )
    877          else:
🟢  878              ret = self.launch_kernel(
    879                  key,
    880                  kernel_cpp,
    881                  compiled_kernel_data,
    882                  *py_args,
    883                  qd_stream=qd_stream,
    884                  _resume_from_checkpoint=_resume_from_checkpoint,
    885              )
    889          # Surface a GraphStatus for kernels with `qd.checkpoint(yield_on=...)` so the host can drive the qipc-style
    890          # re-entrant loop. Kernels without yield-capable checkpoints get `ret` (typically `None`) passed through.
🟢  891          return _checkpoint_helpers.maybe_build_graph_status(self, ret)
🟢 python/quadrants/lang/kernel_checkpoint.py (97%)
      1  """Helpers extracted from ``kernel.py`` for the ``qd.checkpoint(...)`` pause / resume model.
      2  
      3  ``Kernel.__call__`` / ``Kernel.launch_kernel`` delegate the resume-cookie validation, the user-label-to-internal-
      4  cp_id translation, the per-launch ``yield_on=`` arg-id table construction, and the ``GraphStatus`` build to the free
      5  functions below so the central ``Kernel`` class doesn't accrete checkpoint-feature-specific blocks. See
      6  ``qd.checkpoint`` / ``kernel.resume`` / ``docs/source/user_guide/graph.md`` for the user-facing surface.
      7  """
      8  
🟢    9  from __future__ import annotations
     10  
🟢   11  from typing import Any
     12  
🟢   13  from quadrants.lang import impl
🟢   14  from quadrants.lang.graph_status import GraphStatus
     15  
     16  
🟢   17  def validate_resume_cookie(kernel: Any, resume_from_checkpoint: int | None) -> None:
     18      """Raise if ``_qd_from_checkpoint`` was passed to a kernel without any ``qd.checkpoint(yield_on=...)`` block.
     19  
     20      Called from the preamble of ``Kernel.__call__`` so the user gets a clear error before any compile / launch work
     21      happens, rather than a confusing "no GraphStatus surface" failure later.
     22      """
🟢   23      if resume_from_checkpoint is not None and not kernel.checkpoint_yield_on_args:
🔴   24          raise RuntimeError(
     25              "`from_checkpoint=` is only valid for kernels that contain at least one "
     26              "qd.checkpoint(yield_on=...) block; this kernel has none."
     27          )
     28  
     29  
🟢   30  def translate_user_label_to_internal_cp_id(kernel: Any, user_label: int) -> int:
     31      """Translate a user-supplied ``from_checkpoint=`` label (int or IntEnum) to the runtime's internal dense cp_id.
     32  
     33      The runtime indexes checkpoints by source-declaration order (0, 1, 2, ...). The user-facing label is whatever int /
     34      IntEnum they passed as the first positional arg of ``qd.checkpoint(cp_id, yield_on=...)``; the
     35      ``checkpoint_user_labels_by_cp_id`` table maps internal cp_id -> user label. Compared with ``==`` so an IntEnum
     36      value matches its underlying int. Implicit (auto-wrapped) checkpoints have ``None`` in the table and are never
     37      resume targets. Raises ``RuntimeError`` listing the available labels when the user passed an unknown one.
     38      """
🟢   39      for internal_cp_id, label in enumerate(kernel.checkpoint_user_labels_by_cp_id):
🟢   40          if label is not None and label == user_label:
🟢   41              return internal_cp_id
🟢   42      available = [lbl for lbl in kernel.checkpoint_user_labels_by_cp_id if lbl is not None]
🟢   43      raise RuntimeError(
     44          f"from_checkpoint={user_label!r} does not match any qd.checkpoint(cp_id=...) in "
     45          f"kernel {kernel.func.__name__!r}. Available cp_id labels (source-declaration order): {available}."
     46      )
     47  
     48  
🟢   49  def init_yield_on_arg_id_table(kernel: Any) -> None:
     50      """Allocate / reset the per-launch ``cp_id -> C++ arg-id`` table at the top of ``launch_kernel``'s arg iteration.
     51  
     52      Each entry defaults to ``-1`` ("no yield_on"); the per-arg loop below fills in the C++ arg id when it visits the
     53      named parameter. Sized to the kernel's checkpoint count once per launch so any changes to the checkpoint set (only
     54      possible via re-AST-walk) reset the table cleanly. No-op for kernels with no ``yield_on=`` checkpoints.
     55      """
🟢   56      if kernel.checkpoint_yield_on_args:
🟢   57          kernel._checkpoint_yield_on_cpp_arg_ids = [-1] * len(kernel.checkpoint_yield_on_args)
     58  
     59  
🟢   60  def maybe_record_yield_on_arg(kernel: Any, arg_name: str, cpp_arg_id: int) -> None:
     61      """Fill the ``cp_id -> C++ arg-id`` slot when the arg iterator visits a named ``yield_on=`` kernel parameter.
     62  
     63      Walked once per kernel arg in ``launch_kernel``; cheap O(checkpoints) match. A single parameter can be the
     64      ``yield_on=`` for multiple checkpoints (the inner loop fills every matching slot).
     65      """
🟢   66      if not kernel.checkpoint_yield_on_args:
🟢   67          return
🟢   68      for cp_idx, yield_name in enumerate(kernel.checkpoint_yield_on_args):
🟢   69          if yield_name is not None and arg_name == yield_name:
🟢   70              kernel._checkpoint_yield_on_cpp_arg_ids[cp_idx] = cpp_arg_id
     71  
     72  
🟢   73  def forward_yield_on_table_to_ctx(kernel: Any, launch_ctx: Any) -> None:
     74      """Copy the resolved ``cp_id -> C++ arg-id`` table onto the launch context so the runtime can find each
     75      ``yield_on=`` ndarray's device address at launch.
     76      """
🟢   77      if kernel.checkpoint_yield_on_args and hasattr(kernel, "_checkpoint_yield_on_cpp_arg_ids"):
🟢   78          launch_ctx.checkpoint_yield_on_arg_ids = tuple(kernel._checkpoint_yield_on_cpp_arg_ids)
     79  
     80  
🟢   81  def maybe_build_graph_status(kernel: Any, default_ret: Any) -> Any:
     82      """Translate the runtime's internal yielding cp_id back to the user-supplied label and return a ``GraphStatus``.
     83  
     84      Returns ``default_ret`` unchanged for kernels without any yielding checkpoint -- there's no ``yield_on=`` parameter
     85      to surface a status from, so the value would always be ``yielded=False, checkpoint=None`` (no information).
     86      Implicit (auto-wrapped) checkpoints have ``None`` in ``checkpoint_user_labels_by_cp_id`` but they never have
     87      ``yield_on=``, so the runtime can't surface them as the yielding cp -- the lookup is always to an explicit
     88      checkpoint.
     89      """
🟢   90      if not (kernel.checkpoint_yield_on_args and any(n is not None for n in kernel.checkpoint_yield_on_args)):
🟢   91          return default_ret
🟢   92      cp = impl.get_runtime().prog.get_graph_last_yield_cp_id_on_last_call()
🟢   93      if cp >= 0:
🟢   94          user_label = (
     95              kernel.checkpoint_user_labels_by_cp_id[cp] if cp < len(kernel.checkpoint_user_labels_by_cp_id) else None
     96          )
🟢   97          return GraphStatus(yielded=True, checkpoint=user_label if user_label is not None else cp)
🟢   98      return GraphStatus(yielded=False, checkpoint=None)
🟢 python/quadrants/lang/kernel_impl.py (100%)
    161      checkpoints: bool = False,
🟢  172      primal.use_checkpoints = checkpoints
🟢  173      adjoint.use_checkpoints = checkpoints
    215  def kernel(
    216      _fn: None = None, *, pure: bool = False, graph: bool = False, checkpoints: bool = False
    217  ) -> Callable[[Any], Any]: ...
    227  def kernel(_fn: Any, *, pure: bool = False, graph: bool = False, checkpoints: bool = False) -> Any: ...
    236      checkpoints: bool = False,
    253          checkpoints: If True, opt into the (experimental) ``qd.checkpoint`` pause / resume model.
    254              ``with qd.checkpoint(cp_id, yield_on=flag):`` blocks in the kernel body become pause points the host can
    255              resume from via ``kernel.resume(from_checkpoint=cp_id)``. Requires ``graph=True``.
    256              ``qd.checkpoint(...)`` in the body is rejected unless this flag is set.
🟢  276          if checkpoints and not graph:
🟢  277              raise QuadrantsSyntaxError(
    278                  f"@qd.kernel({fn.__name__!r}, checkpoints=True) requires graph=True; "
    279                  "the checkpoint resume model is only meaningful for graph kernels."
    280              )
🟢  281          wrapped = _kernel_impl(fn, level_of_class_stackframe=level, graph=graph, checkpoints=checkpoints)
🟢 python/quadrants/lang/misc.py (100%)
🟢   13  from quadrants.lang.checkpoint import checkpoint
🟢   15  from quadrants.lang.graph_status import GraphStatus
    890      "GraphStatus",
    891      "checkpoint",
🟢 tests/python/test_checkpoint.py (90%)
      1  """Tests for ``qd.checkpoint`` -- yield/resume stage primitive for graph kernels.
      2  
      3  These tests cover the auto-checkpoint surface:
      4  
      5    - The user-facing API is ``qd.checkpoint(cp_id, yield_on=flag)``. Both arguments are required; ``cp_id`` is a user
      6      label (``int`` or ``IntEnum`` value), and ``yield_on`` is a 0-d ``qd.i32`` ndarray kernel parameter.
      7    - The kernel must be decorated with ``@qd.kernel(graph=True, checkpoints=True)`` to use ``qd.checkpoint(...)``. The
      8      flag opts the kernel into the resume model and enables the auto-wrap pass.
      9    - Auto-wrap: every top-level for-loop in the kernel body (including inside ``while qd.graph_do_while(...):``) that
     10      is not inside a ``with qd.checkpoint(...)`` becomes an implicit no-yield checkpoint. Implicit checkpoints carry no
     11      user label and never appear in ``GraphStatus.checkpoint``, but they DO consume an internal cp_id slot so a resume
     12      launch can skip them along with the explicit checkpoints declared earlier in source order.
     13    - ``status.checkpoint`` round-trips the user-supplied label (so ``qd.checkpoint(Stage.SIM, ...)`` surfaces as
     14      ``Stage.SIM`` on yield). ``kernel.resume(from_checkpoint=Stage.SIM)`` skips every checkpoint (implicit + explicit)
     15      declared before ``Stage.SIM`` in source order.
     16  
     17  The behavioural assertions (yield, resume, kernel completes normally on no-yield) run on every backend that implements
     18  the host-side yield/resume contract -- see ``_supports_checkpoint_yield_resume`` below. The CUDA-native-only counters
     19  (IF conditional node count) are guarded behind ``_is_checkpoint_if_path_native``.
     20  """
     21  
🟢   22  from enum import IntEnum
     23  
🟢   24  import numpy as np
🟢   25  import pytest
     26  
🟢   27  import quadrants as qd
🟢   28  from quadrants.lang import impl
     29  
🟢   30  from tests import test_utils
     31  
     32  
🟢   33  def _on_cuda():
🟢   34      return impl.current_cfg().arch == qd.cuda
     35  
     36  
🟢   37  def _is_checkpoint_if_path_native():
     38      """The CUDA-native IF-conditional path requires SM 9.0+ / CUDA 12.4+ (slice 1c).
     39  
     40      On other devices and backends the kernel still runs through every checkpoint body, so the behavioural tests pass
     41      everywhere, but the GraphManager-introspection assertions only apply on the native path.
     42      """
🟢   43      return _on_cuda() and qd.lang.impl.get_cuda_compute_capability() >= 90
     44  
     45  
🟢   46  def _supports_checkpoint_yield_resume():
     47      """Backends that implement the checkpoint yield/resume host contract.
     48  
     49      Wider than `_is_checkpoint_if_path_native()`: also includes the CPU/x64 path (slice 6) and AMDGPU host-orchestrated
     50      sub-graph path (slice 4). Use this predicate for tests of the behavioural yield/resume + `kernel.resume(...)` API;
     51      use `_is_checkpoint_if_path_native()` only for graph-introspection counters that exist on CUDA alone.
     52      """
🟢   53      if _is_checkpoint_if_path_native():
🔴   54          return True
     55      # CPU backend: same `runtime/cpu/kernel_launcher.cpp` host-branch gating runs on both x64 and arm64 (the launcher is
     56      # arch-agnostic; only the LLVM codegen target differs). Apple Silicon surfaces as `qd.arm64`; Linux x86 as `qd.x64`.
     57      # Both go through the slice 6 path.
🟢   58      if impl.current_cfg().arch in (qd.x64, qd.arm64):
🟢   59          return True
🟢   60      if impl.current_cfg().arch == qd.amdgpu:
🔴   61          return True
     62      # GFX backends (Vulkan, Metal): per-task host gating + readback yield-check in `GfxRuntime` (slice 4 cont.); see
     63      # `runtime/gfx/runtime.cpp`'s task loop.
🟢   64      if impl.current_cfg().arch in (qd.vulkan, qd.metal):
🔴   65          return True
🟢   66      return False
     67  
     68  
🟢   69  def _supports_checkpoint_yield_resume_in_while_loop():
     70      """Strict subset of `_supports_checkpoint_yield_resume`: returns true on backends where yield/resume also works
     71      inside a `qd.graph_do_while` body. Same predicate today since slice 4 ported the CPU launcher's host-branch gating
     72      plus per-iter resume_point reset to the AMDGPU streaming path."""
🟢   73      return _supports_checkpoint_yield_resume()
     74  
     75  
🟢   76  def _num_checkpoints_on_last_call():
🔴   77      return impl.get_runtime().prog.get_graph_num_checkpoints_on_last_call()
     78  
     79  
🟢   80  def _last_yield_cp_id_on_last_call():
🔴   81      return impl.get_runtime().prog.get_graph_last_yield_cp_id_on_last_call()
     82  
     83  
     84  # ----------------------------------------------------------------------------------------------------------------------
     85  # Python-runtime surface (works outside @qd.kernel).
     86  # ----------------------------------------------------------------------------------------------------------------------
     87  
     88  
🟢   89  def test_checkpoint_is_no_op_outside_kernels():
     90      """At Python runtime (outside kernels) ``qd.checkpoint`` must be a usable no-op context manager.
     91  
     92      Lets downstream consumers import the symbol unconditionally and use it inside helpers that are sometimes called
     93      from Python and sometimes from kernels. The new API has two required args; both are accepted unchanged here (the
     94      Python runtime stub is just ``del cp_id, yield_on; yield``).
     95      """
🟢   96      sentinel = []
🟢   97      with qd.checkpoint(0, None):
🟢   98          sentinel.append("body ran")
🟢   99      with qd.checkpoint(7, None):
🟢  100          sentinel.append("body ran")
🟢  101      assert sentinel == ["body ran", "body ran"]
    102  
    103  
    104  # ----------------------------------------------------------------------------------------------------------------------
    105  # Decorator + opt-in flag.
    106  # ----------------------------------------------------------------------------------------------------------------------
    107  
    108  
🟢  109  @test_utils.test()
🟢  110  def test_checkpoint_in_non_checkpoints_kernel_raises():
    111      """Using ``qd.checkpoint(...)`` in a ``@qd.kernel(graph=True)`` without ``checkpoints=True`` must error at compile
    112      time, pointing at the fix-it (add ``checkpoints=True``)."""
    113  
🟢  114      @qd.kernel(graph=True)
🟢  115      def k(x: qd.types.ndarray(qd.i32, ndim=1), flag: qd.types.ndarray(qd.i32, ndim=0)):
🔴  116          with qd.checkpoint(0, yield_on=flag):
🔴  117              for i in range(x.shape[0]):
🔴  118                  x[i] = x[i] + 1
    119  
🟢  120      x = qd.ndarray(qd.i32, shape=(4,))
🟢  121      flag = qd.ndarray(qd.i32, shape=())
🟢  122      with pytest.raises(qd.QuadrantsSyntaxError, match=r"checkpoints=True"):
🟢  123          k(x, flag)
    124  
    125  
🟢  126  def test_kernel_checkpoints_requires_graph_true():
    127      """``@qd.kernel(checkpoints=True)`` without ``graph=True`` is rejected at decorator time -- the resume model is only
    128      meaningful for graph kernels (the gate / yield-check lowering, the resume_point slot, and the kernel.resume API all
    129      depend on the graph-capture path)."""
🟢  130      with pytest.raises(qd.QuadrantsSyntaxError, match=r"checkpoints=True\) requires graph=True"):
    131  
🟢  132          @qd.kernel(checkpoints=True)
🟢  133          def k(x: qd.types.ndarray(qd.i32, ndim=1)):
🔴  134              for i in range(x.shape[0]):
🔴  135                  x[i] = x[i] + 1
    136  
    137  
    138  # ----------------------------------------------------------------------------------------------------------------------
    139  # New API signature: cp_id (int or IntEnum), yield_on (required, must be parameter name).
    140  # ----------------------------------------------------------------------------------------------------------------------
    141  
    142  
🟢  143  @test_utils.test()
🟢  144  def test_checkpoint_missing_cp_id_raises():
    145      """``qd.checkpoint()`` with no args must raise with a message that points at the required signature."""
    146  
🟢  147      @qd.kernel(graph=True, checkpoints=True)
🟢  148      def k(x: qd.types.ndarray(qd.i32, ndim=1)):
🔴  149          with qd.checkpoint():
🔴  150              for i in range(x.shape[0]):
🔴  151                  x[i] = x[i] + 1
    152  
🟢  153      x = qd.ndarray(qd.i32, shape=(4,))
🟢  154      with pytest.raises(qd.QuadrantsSyntaxError, match=r"qd\.checkpoint\(cp_id, yield_on=flag\)"):
🟢  155          k(x)
    156  
    157  
🟢  158  @test_utils.test()
🟢  159  def test_checkpoint_missing_yield_on_raises():
    160      """``qd.checkpoint(cp_id)`` without ``yield_on=`` is rejected at compile time."""
    161  
🟢  162      @qd.kernel(graph=True, checkpoints=True)
🟢  163      def k(x: qd.types.ndarray(qd.i32, ndim=1)):
🔴  164          with qd.checkpoint(0):
🔴  165              for i in range(x.shape[0]):
🔴  166                  x[i] = x[i] + 1
    167  
🟢  168      x = qd.ndarray(qd.i32, shape=(4,))
🟢  169      with pytest.raises(qd.QuadrantsSynta