[Graph] Add qd.checkpoint#725
Open
hughperkins wants to merge 95 commits into
Open
GitHub Actions / Coverage Report
succeeded
Jun 18, 2026 in 0s
Diff Coverage Report
Per-line coverage annotations (truncated — see artifacts for full report).
Details
Coverage Report (3b849362e)
| Metric | Value |
|---|---|
| Diff coverage (changed lines only) | 91% |
| Overall project coverage | 72% |
Total: 943 lines, 82 missing, 91% covered
🟢 python/quadrants/lang/_quadrants_callable.py (83%)
🟢 99 def resume(self, *args, from_checkpoint, **kwargs):
100 """Continues a paused graph kernel from the checkpoint labelled ``from_checkpoint``.
101
102 .. warning::
103
104 **Experimental.** ``kernel.resume`` is part of the experimental ``qd.checkpoint`` surface; the signature
105 (in particular the ``from_checkpoint=`` kwarg) and behaviour may change in any future release without a
106 deprecation cycle.
107
108 Use only on ``@qd.kernel(graph=True, checkpoints=True)`` kernels with at least one
109 ``qd.checkpoint(cp_id, yield_on=flag)`` block. ``from_checkpoint`` is a ``cp_id`` label (typically an
110 ``IntEnum`` value, often ``status.checkpoint`` from the previous launch): everything before that label in
111 source order is skipped on this launch, and execution continues from there. The host loop pattern is::
112
113 from enum import IntEnum
114
115 class Stage(IntEnum):
116 SIM = 0
117
118 overflow_flag[()] = 0 # initialise before the first launch
119 status = step(arr, overflow_flag, newton_cond)
120 while status.yielded:
121 handle(status.checkpoint, ...)
122 overflow_flag[()] = 0 # the framework never clears your yield_on flag
123 status = step.resume(arr, overflow_flag, newton_cond,
124 from_checkpoint=status.checkpoint)
125
126 Returns the same ``GraphStatus`` shape as the plain call.
127
128 Raises ``RuntimeError`` if invoked on a kernel without any ``yield_on=`` checkpoint, or if ``from_checkpoint``
129 does not match any declared ``cp_id`` in the kernel.
130 """
🟢 131 if not isinstance(from_checkpoint, int):
🟢 132 raise RuntimeError(
133 f"from_checkpoint= must be an int or IntEnum value matching a `qd.checkpoint(cp_id=...)` label in "
134 f"the kernel (typically `status.checkpoint` from the previous launch's GraphStatus); "
135 f"got {from_checkpoint!r}."
136 )
137 # Smuggle the resume cookie past the AST-mapped kwargs path; `Kernel.__call__` pops it before anything else
138 # looks at kwargs.
🟢 139 return self.wrapper.__call__(*args, _qd_from_checkpoint=from_checkpoint, **kwargs)
140
170
🟢 171 def resume(self, *args, from_checkpoint, **kwargs):
172 """Bound-method form of `QuadrantsCallable.resume` (see that docstring)."""
🔴 173 return self.quadrants_callable.resume(self.instance, *args, from_checkpoint=from_checkpoint, **kwargs)
🟢 python/quadrants/lang/ast/ast_transformer.py (100%)
🟢 29 from quadrants.lang.ast.ast_transformers.checkpoint_transformer import (
30 CheckpointTransformer,
31 )
🟢 1368 @staticmethod
🟢 1369 def _is_checkpoint_call(node: ast.expr, global_vars: dict):
1370 """Thin forwarding wrapper around ``CheckpointTransformer.is_checkpoint_call``; the actual logic lives in module
1371 ``ast_transformers/checkpoint_transformer.py`` to keep this file from growing per-feature. Returns a
1372 ``CheckpointCallInfo`` or ``None``."""
🟢 1373 return CheckpointTransformer.is_checkpoint_call(node, global_vars)
1374
1613
🟢 1614 checkpoint_info = ASTTransformer._is_checkpoint_call(item.context_expr, ctx.global_vars)
🟢 1615 if checkpoint_info is not None:
🟢 1616 return ASTTransformer._build_checkpoint_with(ctx, node, checkpoint_info)
1617
🟢 1619 raise QuadrantsSyntaxError(
1620 "'with' in Quadrants kernels only supports qd.stream_parallel() or qd.checkpoint()"
1621 )
🟢 1629 @staticmethod
🟢 1630 def _build_checkpoint_with(
1631 ctx: ASTTransformerFuncContext,
1632 node: ast.With,
1633 info,
1634 ) -> None:
1635 """Thin forwarding wrapper around ``CheckpointTransformer.build_checkpoint_with``; the actual logic lives in
1636 ``ast_transformers/checkpoint_transformer.py``."""
🟢 1637 return CheckpointTransformer.build_checkpoint_with(ctx, node, info, build_stmts)
1638
🟢 python/quadrants/lang/ast/ast_transformers/checkpoint_transformer.py (89%)
1 # type: ignore
2 """AST recognition, validation, and auto-wrap lowering for ``qd.checkpoint(...)`` ``with`` blocks.
3
4 Lives alongside ``call_transformer.py`` / ``function_def_transformer.py`` so that ``ast_transformer.py`` doesn't have to
5 grow per-feature. ``ASTTransformer.build_With`` and ``ASTTransformer._is_checkpoint_call`` forward calls into the static
6 methods here. ``FunctionDefTransformer.build_FunctionDef`` calls ``auto_wrap_for_loops`` on the kernel body when the
7 kernel was decorated with ``@qd.kernel(graph=True, checkpoints=True)``, so that every top-level for-loop not already
8 inside a ``with qd.checkpoint(...)`` becomes its own implicit no-yield checkpoint.
9
10 See ``docs/source/user_guide/graph.md`` for the user-facing surface and ``perso_hugh/doc/qipc/reentrant.md`` for the
11 design.
12 """
13
🟢 14 from __future__ import annotations
15
🟢 16 import ast
🟢 17 from dataclasses import dataclass
18
🟢 19 from quadrants.lang.ast.ast_transformer_utils import ASTTransformerFuncContext
🟢 20 from quadrants.lang.exception import QuadrantsSyntaxError
21
22 # Sentinel name used by `_kernel_coverage.py` (`FIELD_VAR_NAME`) for the probe-tracking field. The checkpoint validator
23 # hard-codes the literal so coverage probes can be exempted from the bare-statement check without taking a runtime dep
24 # on the optional `_kernel_coverage` module (it is only imported when `QD_KERNEL_COVERAGE=1`). The two must stay in
25 # sync; if `FIELD_VAR_NAME` ever changes, update this constant and the corresponding test.
🟢 26 _KERNEL_COVERAGE_FIELD_NAME = "_qd_cov"
27
28 # Attribute name attached to synthetic `qd.checkpoint()` AST Call nodes generated by `auto_wrap_for_loops` for the
29 # implicit-no-yield wrap of top-level for-loops in a `checkpoints=True` kernel. Distinguishes auto-wrap calls from
30 # user-written `qd.checkpoint(cp_id, yield_on)` calls so the validator can apply different rules (implicit calls take no
31 # args; explicit calls require both `cp_id` and `yield_on`).
🟢 32 _IMPLICIT_MARKER_ATTR = "_qd_implicit"
33
34
🟢 35 @dataclass
🟢 36 class CheckpointCallInfo:
37 """Resolved metadata for a `qd.checkpoint(...)` call recognised in the AST.
38
39 - ``cp_id``: the user-supplied label (an ``int`` or ``IntEnum`` value), or ``None`` for an auto-wrap implicit
40 checkpoint.
41 - ``yield_on``: name of the kernel parameter passed as ``yield_on=`` (an ``ast.Name`` is required), or ``None`` for
42 an implicit checkpoint.
43 - ``is_implicit``: ``True`` iff this Call was synthesised by ``auto_wrap_for_loops``.
44 """
45
🟢 46 cp_id: int | None
🟢 47 yield_on: str | None
🟢 48 is_implicit: bool
49
50
🟢 51 class CheckpointTransformer:
🟢 52 @staticmethod
🟢 53 def _is_coverage_probe_assign(stmt: ast.stmt) -> bool:
54 """Return True iff *stmt* is the synthesized ``_qd_cov[<probe_id>] = 1`` assignment inserted by
55 ``_kernel_coverage.py`` when ``QD_KERNEL_COVERAGE=1``. Keeping these out of the bare-statement rejection in
56 ``build_checkpoint_with`` lets coverage CI exercise every checkpoint kernel without the user having to wrap the
57 synthetic probes themselves.
58 """
🟢 59 if not isinstance(stmt, ast.Assign) or len(stmt.targets) != 1:
🔴 60 return False
🟢 61 tgt = stmt.targets[0]
🟢 62 if not isinstance(tgt, ast.Subscript):
🔴 63 return False
🟢 64 return isinstance(tgt.value, ast.Name) and tgt.value.id == _KERNEL_COVERAGE_FIELD_NAME
65
🟢 66 @staticmethod
🟢 67 def _looks_like_checkpoint_call(node: ast.expr) -> bool:
68 """Cheap structural check: is *node* `qd.checkpoint(...)` or bare `checkpoint(...)`? Doesn't validate args or
69 resolve names; used by ``auto_wrap_for_loops`` to skip already-wrapped `with` blocks without paying for the full
70 argument-resolution pass (which needs ``global_vars`` and a real Kernel object)."""
🟢 71 if not isinstance(node, ast.Call):
🔴 72 return False
🟢 73 func = node.func
🟢 74 return (isinstance(func, ast.Attribute) and func.attr == "checkpoint") or (
75 isinstance(func, ast.Name) and func.id == "checkpoint"
76 )
77
🟢 78 @staticmethod
🟢 79 def is_explicit_checkpoint_with(stmt: ast.stmt) -> bool:
80 """Return True iff *stmt* is a single-item `with qd.checkpoint(...):` block (explicit or implicit). Used by
81 ``auto_wrap_for_loops`` to leave already-wrapped for-loops alone."""
🟢 82 if not isinstance(stmt, ast.With) or len(stmt.items) != 1:
🟢 83 return False
🟢 84 return CheckpointTransformer._looks_like_checkpoint_call(stmt.items[0].context_expr)
85
🟢 86 @staticmethod
🟢 87 def _resolve_cp_id(node: ast.expr, global_vars: dict) -> int:
88 """Resolve the first positional arg of `qd.checkpoint(cp_id, ...)` to a Python int (or IntEnum instance).
89
90 Accepts (a) `ast.Constant` int literals (``qd.checkpoint(0, ...)``), (b) `ast.Attribute` references to
91 `IntEnum` values resolved against the kernel's `global_vars` (``qd.checkpoint(Stage.SIM, ...)`` where ``Stage``
92 is an `IntEnum` defined at module scope), and (c) `ast.Name` references to module-level int constants
93 (``qd.checkpoint(CP_LOAD, ...)`` where ``CP_LOAD`` is an int defined at module scope). For (b) and (c) we
94 return the resolved value AS-IS (without re-wrapping through `int(...)`) so an IntEnum member identity is
95 preserved end-to-end -- the user writes `qd.checkpoint(Stage.SIM, ...)`, then reads `status.checkpoint` and
96 gets back `Stage.SIM`, not the raw int. Rejects everything else with a clear error so the user gets a
97 compile-time diagnostic rather than a confusing template-mapper failure later.
98 """
99 # Plain `qd.checkpoint(0, ...)` literal.
🟢 100 if isinstance(node, ast.Constant) and isinstance(node.value, int) and not isinstance(node.value, bool):
🟢 101 return node.value
102 # `qd.checkpoint(Stage.SIM, ...)` -- look up `Stage` in the kernel's module globals, then `.SIM` on it.
🟢 103 if isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name):
🟢 104 container_name = node.value.id
🟢 105 if container_name in global_vars:
🟢 106 container = global_vars[container_name]
🟢 107 if hasattr(container, node.attr):
🟢 108 val = getattr(container, node.attr)
🟢 109 if isinstance(val, int) and not isinstance(val, bool):
🟢 110 return val
111 # `qd.checkpoint(CP_LOAD, ...)` -- bare name referencing a module-level int constant.
🟢 112 if isinstance(node, ast.Name) and node.id in global_vars:
🔴 113 val = global_vars[node.id]
🔴 114 if isinstance(val, int) and not isinstance(val, bool):
🔴 115 return val
🟢 116 raise QuadrantsSyntaxError(
117 "qd.checkpoint() first argument must be an int literal or an IntEnum value (e.g. `Stage.SIM`). Got an "
118 f"unresolvable expression at line {getattr(node, 'lineno', '?')}; the cp_id must be statically "
119 "determinable at AST-walk time so the framework can build the label -> internal cp_id map and the "
120 "host-side resume API can refer to checkpoints by name."
121 )
122
🟢 123 @staticmethod
🟢 124 def is_checkpoint_call(node: ast.expr, global_vars: dict) -> CheckpointCallInfo | None:
125 """If *node* is a `qd.checkpoint(...)` call return a `CheckpointCallInfo`; otherwise return `None`.
126
127 Validates the call shape and raises `QuadrantsSyntaxError` for misuse so the user gets a clear message at the
128 `with` site rather than a vague "not stream_parallel" error later. Implicit calls (the synthetic
129 `qd.checkpoint()` no-arg calls produced by `auto_wrap_for_loops`) are recognised via the `_qd_implicit`
130 attribute and bypass user-call argument validation entirely.
131 """
🟢 132 if not CheckpointTransformer._looks_like_checkpoint_call(node):
🟢 133 return None
134 # Auto-wrap-synthesised implicit checkpoint: no user-facing args, no validation.
🟢 135 if getattr(node, _IMPLICIT_MARKER_ATTR, False):
🟢 136 return CheckpointCallInfo(cp_id=None, yield_on=None, is_implicit=True)
137 # User-written `qd.checkpoint(cp_id, yield_on)` -- both args are required.
🟢 138 if len(node.args) + len(node.keywords) == 0:
🟢 139 raise QuadrantsSyntaxError("qd.checkpoint() takes two arguments: `qd.checkpoint(cp_id, yield_on=flag)`.")
140 # Collect cp_id (positional 0 or kw `cp_id=`) and yield_on (positional 1 or kw `yield_on=`).
🟢 141 cp_id_arg: ast.expr | None = None
🟢 142 yield_on_arg: ast.expr | None = None
🟢 143 for i, arg in enumerate(node.args):
🟢 144 if i == 0:
🟢 145 cp_id_arg = arg
🔴 146 elif i == 1:
🔴 147 yield_on_arg = arg
148 else:
🔴 149 raise QuadrantsSyntaxError(
150 f"qd.checkpoint() takes at most 2 positional arguments (cp_id, yield_on); got {len(node.args)}"
151 )
🟢 152 for kw in node.keywords:
🟢 153 if kw.arg == "cp_id":
🔴 154 if cp_id_arg is not None:
🔴 155 raise QuadrantsSyntaxError("qd.checkpoint() got `cp_id` both positionally and as a keyword")
🔴 156 cp_id_arg = kw.value
🟢 157 elif kw.arg == "yield_on":
🟢 158 if yield_on_arg is not None:
🔴 159 raise QuadrantsSyntaxError("qd.checkpoint() got `yield_on` both positionally and as a keyword")
🟢 160 yield_on_arg = kw.value
161 else:
🟢 162 raise QuadrantsSyntaxError(
163 f"qd.checkpoint() got unexpected keyword argument {kw.arg!r}; only 'cp_id' and 'yield_on' are "
164 "supported"
165 )
🟢 166 if cp_id_arg is None:
🔴 167 raise QuadrantsSyntaxError(
168 "qd.checkpoint() is missing required argument `cp_id` (e.g. `qd.checkpoint(0, yield_on=flag)` or "
169 "`qd.checkpoint(Stage.SIM, yield_on=flag)`)"
170 )
🟢 171 if yield_on_arg is None:
🟢 172 raise QuadrantsSyntaxError(
173 "qd.checkpoint() is missing required argument `yield_on` (e.g. "
174 "`qd.checkpoint(0, yield_on=overflow_flag)`)"
175 )
🟢 176 if not isinstance(yield_on_arg, ast.Name):
🟢 177 raise QuadrantsSyntaxError(
178 "qd.checkpoint(yield_on=...) must be the bare name of a kernel parameter (e.g. "
179 "`yield_on=overflow_flag`); expressions are not supported"
180 )
🟢 181 cp_id_value = CheckpointTransformer._resolve_cp_id(cp_id_arg, global_vars)
🟢 182 return CheckpointCallInfo(cp_id=cp_id_value, yield_on=yield_on_arg.id, is_implicit=False)
183
🟢 184 @staticmethod
🟢 185 def build_checkpoint_with(
186 ctx: ASTTransformerFuncContext,
187 node: ast.With,
188 info: CheckpointCallInfo,
189 build_stmts,
190 ) -> None:
191 """Handles a `with qd.checkpoint(...):` block (explicit or auto-wrapped implicit).
192
193 Validates the use-site (kernel must be `graph=True, checkpoints=True`, no nesting; for explicit calls, the
194 `yield_on` arg must be a kernel parameter and the `cp_id` must be unique across the kernel) and appends an entry
195 to `kernel.checkpoint_yield_on_args` + `kernel.checkpoint_user_labels_by_cp_id`. Walks the body transparently --
196 for-loops inside the `with` become normal top-level for-loops in the kernel's frontend IR. The internal `cp_id`
197 is assigned by declaration order (the C++ `ASTBuilder.begin_checkpoint()` counter, mirrored as the list index in
198 `checkpoint_yield_on_args`).
199 """
🟢 200 if not ctx.is_kernel:
🔴 201 raise QuadrantsSyntaxError("qd.checkpoint() can only be used inside @qd.kernel, not @qd.func")
🟢 202 kernel = ctx.global_context.current_kernel
🟢 203 if not kernel.use_graph:
🔴 204 raise QuadrantsSyntaxError(
205 "qd.checkpoint() requires @qd.kernel(graph=True); the resume model only applies to graph kernels"
206 )
🟢 207 if not kernel.use_checkpoints:
🟢 208 raise QuadrantsSyntaxError(
209 f"qd.checkpoint() found in kernel {kernel.func.__name__!r}, but this kernel was not decorated with "
210 "`checkpoints=True`. Add `checkpoints=True` to the decorator -- e.g. "
211 "`@qd.kernel(graph=True, checkpoints=True)` -- to opt into the pause / resume model."
212 )
🟢 213 if getattr(ctx, "_in_checkpoint", False):
🟢 214 raise QuadrantsSyntaxError(
215 "qd.checkpoint() cannot be nested inside another qd.checkpoint(); checkpoints in the "
216 "same kernel must be flat siblings (a checkpoint inside qd.graph_do_while is fine)"
217 )
218
🟢 219 if not info.is_implicit:
220 # Validate `yield_on=` names a real kernel parameter.
🟢 221 arg_names = [m.name for m in kernel.arg_metas]
🟢 222 if info.yield_on not in arg_names:
🟢 223 raise QuadrantsSyntaxError(
224 f"qd.checkpoint(yield_on={info.yield_on!r}) does not match any parameter of kernel "
225 f"{kernel.func.__name__!r}. Available parameters: {arg_names}"
226 )
227 # Reject duplicate user-supplied cp_id labels.
🟢 228 existing = [lbl for lbl in kernel.checkpoint_user_labels_by_cp_id if lbl is not None]
🟢 229 if info.cp_id in existing:
🟢 230 raise QuadrantsSyntaxError(
231 f"qd.checkpoint(cp_id={info.cp_id!r}, ...) declares a cp_id that is already used by another "
232 f"checkpoint in kernel {kernel.func.__name__!r}. Each `cp_id` label must be unique so the host "
233 "loop can map `status.checkpoint` and `kernel.resume(from_checkpoint=...)` unambiguously."
234 )
235
236 # Reject bare top-level statements (Assign / AugAssign / AnnAssign / non-docstring Expr) at the top of the
237 # checkpoint body and ask the user to wrap them in their own for-loop. The offloader's pending-serial bucket
238 # loses the surrounding `checkpoint_id` and emits such statements as `serial` tasks with `cp_id == -1`, so they
239 # would run unconditionally even when the checkpoint is skipped -- a silent correctness bug. Rather than
240 # auto-wrapping them transparently (which hides the fact that each bare stmt becomes its own kernel / graph
241 # node and surprises users when they look at the lowered IR or `prog.get_num_offloaded_tasks_on_last_call()`),
242 # we surface a clear compile-time error and have the user write `for _ in range(1): <stmt>` themselves.
243 # Implicit (auto-wrapped) checkpoints can never trip this -- they wrap exactly one ast.For. Coverage probes
244 # (synthesized `_qd_cov[i] = 1` assignments under `QD_KERNEL_COVERAGE=1`) are explicitly exempt; the cp_id
245 # propagation in `quadrants/transforms/offload.cpp` (`assemble_serial_statements`) makes them inherit the
246 # surrounding checkpoint's cp_id so the launcher's "last task in checkpoint" detection stays correct.
🟢 247 for stmt in node.body:
🟢 248 is_bare = isinstance(stmt, (ast.Assign, ast.AugAssign, ast.AnnAssign))
🟢 249 if not is_bare and isinstance(stmt, ast.Expr):
250 # Any `Expr(Constant)` is a no-op (Python's docstring pattern, e.g. a leading triple-quoted string).
251 # We accept it anywhere in the body rather than only at position 0 because the kernel-coverage AST
252 # transformer (see `_kernel_coverage.py`) prepends a `_qd_cov[...] = 1` probe under
253 # `QD_KERNEL_COVERAGE=1`, which would otherwise push the docstring to position 1 and have us flag it as
254 # a bare top-level statement.
🟢 255 is_constant_expr = isinstance(stmt.value, ast.Constant)
🟢 256 is_bare = not is_constant_expr
🟢 257 if not is_bare:
🟢 258 continue
🟢 259 if CheckpointTransformer._is_coverage_probe_assign(stmt):
🟢 260 continue
🟢 261 stmt_kind = type(stmt).__name__
🟢 262 raise QuadrantsSyntaxError(
263 f"qd.checkpoint() body cannot contain a bare top-level {stmt_kind} statement "
264 f"(line {getattr(stmt, 'lineno', '?')}): every top-level statement in a checkpoint must be inside a "
265 f"for-loop (or other control-flow construct), so the compiler can lower it as its own offloaded task "
266 f"with the correct cp_id. Wrap the statement in `for _ in range(1):` to keep the original intent:\n"
267 f"\n"
268 f" with qd.checkpoint(cp_id, yield_on=flag):\n"
269 f" for _ in range(1):\n"
270 f" <your statement here>\n"
271 f" for i in range(arr.shape[0]):\n"
272 f" ...\n"
273 )
274
🟢 275 kernel.checkpoint_yield_on_args.append(info.yield_on)
🟢 276 kernel.checkpoint_user_labels_by_cp_id.append(info.cp_id)
277 # Hand control to the C++ ASTBuilder so that every for-loop emitted by `build_stmts` below is tagged with this
278 # checkpoint's internal `cp_id` on its `ForLoopConfig.checkpoint_id`. The C++ counter is the source of truth for
279 # the dense internal cp_id; we cross-check it against the Python list index so that a future refactor that
280 # misaligns the two surfaces fires immediately.
🟢 281 cpp_cp_id = ctx.ast_builder.begin_checkpoint()
🟢 282 py_cp_id = len(kernel.checkpoint_yield_on_args) - 1
🟢 283 assert cpp_cp_id == py_cp_id, (
284 f"C++ ASTBuilder.begin_checkpoint() returned cp_id={cpp_cp_id} but Python "
285 f"kernel.checkpoint_yield_on_args index expected {py_cp_id}; these counters "
286 f"must stay in lockstep so the GraphManager can index yield_on by cp_id"
287 )
🟢 288 ctx._in_checkpoint = True
🟢 289 try:
🟢 290 build_stmts(ctx, node.body)
291 finally:
🟢 292 ctx._in_checkpoint = False
🟢 293 ctx.ast_builder.end_checkpoint()
🟢 294 return None
295
🟢 296 @staticmethod
🟢 297 def _make_implicit_checkpoint_with(for_stmt: ast.For) -> ast.With:
298 """Construct a synthetic ``with qd.checkpoint(): <for_stmt>`` AST node carrying the `_qd_implicit` marker.
299
300 The `qd` name resolves via the kernel's module globals at AST-walk time; in practice every kernel that imports
301 `quadrants as qd` (the canonical pattern) has it available. If the user imported as `import quadrants as t`,
302 `with qd.checkpoint(...)` would already fail to resolve, so we don't worry about that case here.
303 """
🟢 304 call = ast.Call(
305 func=ast.Attribute(value=ast.Name(id="qd", ctx=ast.Load()), attr="checkpoint", ctx=ast.Load()),
306 args=[],
307 keywords=[],
308 )
🟢 309 setattr(call, _IMPLICIT_MARKER_ATTR, True)
🟢 310 with_stmt = ast.With(
311 items=[ast.withitem(context_expr=call, optional_vars=None)],
312 body=[for_stmt],
313 type_comment=None,
314 )
🟢 315 ast.copy_location(with_stmt, for_stmt)
🟢 316 ast.copy_location(call, for_stmt)
🟢 317 ast.copy_location(call.func, for_stmt)
🟢 318 ast.copy_location(call.func.value, for_stmt)
319 # ast.With requires a complete locations table for downstream passes; fix up missing fields.
🟢 320 ast.fix_missing_locations(with_stmt)
🟢 321 return with_stmt
322
🟢 323 @staticmethod
🟢 324 def auto_wrap_for_loops(stmts: list[ast.stmt]) -> list[ast.stmt]:
325 """Auto-wrap pass for `@qd.kernel(graph=True, checkpoints=True)` kernels.
326
327 Walks *stmts* and returns a new list where every `ast.For` not already inside a `with qd.checkpoint(...)` block
328 is wrapped in a synthetic implicit checkpoint. Recurses into `while qd.graph_do_while(...)` bodies so for-loops
329 nested in the WHILE body get the same treatment. Other compound statements (`ast.If`, `ast.With`,
330 non-graph_do_while `ast.While`, `ast.Try`) are passed through unchanged -- they're not the common pattern in
331 Quadrants kernel bodies and recursing into them risks wrapping nested for-loops that the user intended to be
332 sub-tasks of a larger control-flow block. Bare top-level statements (assignments, expressions, coverage probes)
333 are also passed through unchanged so they remain in the kernel prologue with `cp_id=-1` and run on every launch.
334 """
🟢 335 new_stmts: list[ast.stmt] = []
🟢 336 for stmt in stmts:
🟢 337 if isinstance(stmt, ast.For):
🟢 338 new_stmts.append(CheckpointTransformer._make_implicit_checkpoint_with(stmt))
🟢 339 elif CheckpointTransformer.is_explicit_checkpoint_with(stmt):
340 # User has already wrapped this for-loop (or block of for-loops) -- leave alone.
🟢 341 new_stmts.append(stmt)
🟢 342 elif isinstance(stmt, ast.While):
343 # `while qd.graph_do_while(...):` -- recurse into the loop body. (Plain `while <cond>:` is not a
344 # supported Quadrants kernel construct; the `build_While` transformer will reject it later, so we don't
345 # special-case it here.)
🟢 346 stmt.body = CheckpointTransformer.auto_wrap_for_loops(stmt.body)
🟢 347 new_stmts.append(stmt)
348 else:
🟢 349 new_stmts.append(stmt)
🟢 350 return new_stmts
🔴 python/quadrants/lang/ast/ast_transformers/function_def_transformer.py (78%)
🟢 29 from quadrants.lang.ast.ast_transformers.checkpoint_transformer import (
30 CheckpointTransformer,
31 )
🟢 511 kernel = ctx.global_context.current_kernel
🟢 512 if kernel is not None:
513 # Reset before walking the body so re-materialisations (e.g. when a templated kernel is compiled with a
514 # different argument shape) start from an empty list. Mirrors how `graph_do_while_arg` gets overwritten
515 # unconditionally during AST traversal.
🟢 516 kernel.checkpoint_yield_on_args = []
🟢 517 kernel.checkpoint_user_labels_by_cp_id = []
518 # Auto-wrap pass for `@qd.kernel(graph=True, checkpoints=True)` kernels. Mutates `node.body` in place so
519 # every top-level for-loop (and every for-loop inside a `qd.graph_do_while` body) that the user did not
520 # already wrap in a `with qd.checkpoint(...)` gets wrapped in a synthetic implicit no-yield checkpoint.
521 # Implicit checkpoints share the same dense source-order internal cp_id space as explicit ones, but
522 # carry `None` in `checkpoint_user_labels_by_cp_id` so they never appear in `GraphStatus.checkpoint` /
523 # `kernel.resume(from_checkpoint=...)`. Runs here (after coverage instrumentation has already injected
524 # its top-level `_qd_cov[i] = 1` probes, which are bare assigns that the wrap pass intentionally leaves
525 # alone) so that the regular `build_stmts` walk below sees a uniform stream of `with qd.checkpoint(...)`
526 # blocks and bare prologue stmts.
🟢 527 if kernel.use_checkpoints:
🟢 528 node.body = CheckpointTransformer.auto_wrap_for_loops(node.body)
🟢 573 @staticmethod
🟢 574 def _is_checkpoint_with(stmt: ast.With) -> bool:
575 """Syntactic check matching CheckpointTransformer: a ``with qd.checkpoint(...):`` block. Accepted as a
576 well-formed statement inside a graph_do_while body (the checkpoint itself enforces its own restrictions)."""
🟢 577 if not isinstance(stmt, ast.With) or len(stmt.items) != 1:
🔴 578 return False
🟢 579 ctx = stmt.items[0].context_expr
🟢 580 if not isinstance(ctx, ast.Call):
🔴 581 return False
🟢 582 func = ctx.func
🟢 583 if isinstance(func, ast.Attribute) and func.attr == "checkpoint":
🟢 584 return True
🔴 585 if isinstance(func, ast.Name) and func.id == "checkpoint":
🔴 586 return True
🔴 587 return False
588
🟢 655 if isinstance(stmt, ast.With) and FunctionDefTransformer._is_checkpoint_with(stmt):
656 # `with qd.checkpoint(...)` is a legal placement for graph kernels (the checkpoint may sit at the
657 # kernel top level or inside any graph_do_while level). Recurse so a malformed graph_do_while
658 # nested inside a checkpoint body is still caught; the checkpoint's own body restrictions are
659 # enforced by `CheckpointTransformer.build_checkpoint_with`.
🟢 660 FunctionDefTransformer._validate_graph_do_while_stmt_list(stmt.body, is_kernel_top=is_kernel_top)
🟢 661 continue
🟢 python/quadrants/lang/checkpoint.py (100%)
1 """User-facing ``qd.checkpoint`` context-manager and its no-op Python-runtime stub.
2
3 Mirrors ``graph_status.py``, which holds the other half of the same feature surface (``GraphStatus``). Kept in its own
4 module to keep ``lang/misc.py`` from growing further -- the AST transformer and the C++ runtime are doing all the actual
5 implementation work; this file is just the public API entry point.
6
7 Re-exported via ``qd.lang.misc`` (and therefore as ``qd.checkpoint``) for the user-facing canonical import path.
8 """
9
🟢 10 from __future__ import annotations
11
🟢 12 from contextlib import contextmanager
13
14
🟢 15 @contextmanager
🟢 16 def checkpoint(cp_id, yield_on):
17 """Marks a section of a graph kernel as a pause / resume point.
18
19 .. warning::
20
21 **Experimental.** ``qd.checkpoint`` (together with ``qd.GraphStatus`` and ``kernel.resume(from_checkpoint=...)``)
22 is an experimental API. The signature, the lowering across backends, the error messages, and the host-side
23 yield/resume contract may change in any future release without a deprecation cycle.
24
25 Used as ``with qd.checkpoint(cp_id, yield_on=flag):`` inside a ``@qd.kernel(graph=True, checkpoints=True)`` kernel
26 body. When the body writes a non-zero value into ``flag``, the kernel pauses at this checkpoint and returns a
27 ``GraphStatus`` to the host carrying ``status.checkpoint == cp_id``. The host can then fix things up and call
28 ``kernel.resume(..., from_checkpoint=cp_id)`` to continue from the same point on the next launch.
29
30 Arguments:
31 cp_id: User-facing label identifying this checkpoint to the host. Must be an ``int`` literal or an ``IntEnum``
32 value, and must be unique within the kernel. The value is preserved as-is end-to-end -- if you pass
33 ``Stage.SIM`` (an ``IntEnum`` member), ``status.checkpoint`` round-trips back as ``Stage.SIM`` rather than
34 the raw int.
35 yield_on: Name of a kernel parameter that is a 0-d ``qd.types.ndarray(qd.i32, ndim=0)``. The body may write a
36 non-zero value into it to signal "pause here, host needs to handle something". The framework never writes
37 into this buffer -- the host owns it end-to-end and must initialise it to ``0`` before the first launch
38 (``qd.ndarray`` is not zero-initialised) AND reset it to ``0`` before each ``kernel.resume(...)`` call
39 (otherwise the same checkpoint sees the stale non-zero value and yields again).
40
41 Restrictions (enforced at kernel compile time):
42 - Must be used inside ``@qd.kernel(graph=True, checkpoints=True)``.
43 - ``cp_id`` must be an ``int`` (or ``IntEnum`` value), and must be unique across the kernel.
44 - ``yield_on`` must name a kernel parameter that is a 0-d ``qd.types.ndarray(qd.i32, ndim=0)``.
45 - Checkpoints cannot be nested inside other checkpoints. Checkpoints inside a ``qd.graph_do_while`` body are fine.
46 - Cannot be combined with ``qd.stream_parallel()`` in the same kernel.
47 - The body cannot contain bare top-level statements (assignments, expressions); wrap them in
48 ``for _ in range(1):`` so the lowering surfaces the per-statement task cost.
49
50 This function should not be called directly at runtime; it is recognised and transformed during AST compilation. At
51 Python runtime (outside kernels), this is a no-op context manager so that doctests / type-checking can import the
52 symbol freely.
53
54 See ``docs/source/user_guide/graph.md`` for the host-side yield/resume loop and cross-backend semantics.
55 """
🟢 56 del cp_id, yield_on
🟢 57 yield
🔴 python/quadrants/lang/graph_status.py (70%)
1 """Plain-Python container returned from graph kernels that contain ``qd.checkpoint(yield_on=...)``.
2
3 Lives in its own module (with no Quadrants-internal imports) so it can be imported safely from both ``kernel.py`` and
4 ``misc.py`` without re-introducing the circular import chain that ``misc.py -> impl.py -> kernel.py`` would create.
5
6 Re-exported via ``qd.lang.misc`` (and therefore as ``qd.GraphStatus``) for the user-facing canonical import path.
7 """
8
🟢 9 from __future__ import annotations
10
11
🟢 12 class GraphStatus:
13 """Result returned by a graph kernel that contains ``qd.checkpoint(cp_id, yield_on=...)`` blocks.
14
15 .. warning::
16
17 **Experimental.** ``GraphStatus`` is part of the experimental ``qd.checkpoint`` surface; its attributes and
18 the conditions under which it is returned may change in any future release without a deprecation cycle.
19
20 Returned from ``kernel(...)`` and ``kernel.resume(..., from_checkpoint=label)`` whenever the kernel was decorated
21 with ``@qd.kernel(graph=True, checkpoints=True)``. Read ``status.yielded`` to decide whether to keep running the
22 host loop, and ``status.checkpoint`` to find out which checkpoint asked the host to handle something.
23
24 Canonical usage (see ``graph.md``)::
25
26 from enum import IntEnum
27
28 class Stage(IntEnum):
29 SIM = 0
30
31 overflow_flag[()] = 0 # initialise before the first launch
32 status = step(arr, overflow_flag, newton_cond)
33 while status.yielded:
34 handle_overflow_for(status.checkpoint, ...)
35 overflow_flag[()] = 0 # the framework never clears your yield_on flag
36 status = step.resume(arr, overflow_flag, newton_cond,
37 from_checkpoint=status.checkpoint)
38
39 Attributes:
40 yielded: ``True`` iff one of the kernel's checkpoints paused on the most recent launch (i.e. its ``yield_on=``
41 flag was non-zero). ``False`` means the kernel completed normally and the host loop should exit.
42 checkpoint: The ``cp_id`` label of the checkpoint that paused (an ``int`` or the original ``IntEnum`` instance
43 you passed to ``qd.checkpoint(cp_id, ...)``), or ``None`` when ``yielded`` is ``False``. Pass it to
44 ``kernel.resume(..., from_checkpoint=...)`` to continue from that point on the next launch.
45 """
46
🟢 47 __slots__ = ("yielded", "checkpoint")
48
🟢 49 def __init__(self, yielded: bool, checkpoint: int | None):
🟢 50 self.yielded = yielded
🟢 51 self.checkpoint = checkpoint
52
🟢 53 def __repr__(self) -> str:
🔴 54 if self.yielded:
55 # Use `!r` so an `IntEnum` cp_id is shown as `<Stage.LOAD: 0>` rather than collapsing to its raw int via
56 # `int.__format__` (which Python 3.10's `f"{IntEnum.X}"` does). Plain ints round-trip unchanged.
🔴 57 return f"GraphStatus(yielded=True, checkpoint={self.checkpoint!r})"
🔴 58 return "GraphStatus(yielded=False)"
🟢 python/quadrants/lang/kernel.py (100%)
39
40 # `qd.checkpoint` pause / resume model helpers. See `kernel_checkpoint.py` for the full extracted surface; `Kernel`
41 # delegates the resume-cookie validation, label translation, per-launch yield_on= arg-id table build, and GraphStatus
42 # construction to those free functions so this hot file doesn't accrete checkpoint-feature-specific blocks.
🟢 43 from quadrants.lang import kernel_checkpoint as _checkpoint_helpers
331 # Opt-in flag set by `@qd.kernel(graph=True, checkpoints=True)`. When True, the AST transformer enables
332 # `qd.checkpoint(...)` recognition AND auto-wraps every top-level for-loop that isn't already inside a
333 # `with qd.checkpoint(...)` block in an implicit no-yield checkpoint. When False, any use of
334 # `qd.checkpoint(...)` in the kernel body is rejected at compile time with a fix-it pointing at
335 # `checkpoints=True`.
🟢 336 self.use_checkpoints: bool = False
345 # Per-checkpoint metadata, one entry per `with qd.checkpoint(...)` block (explicit AND auto-injected implicit)
346 # in declaration order. List index is the checkpoint's internal `cp_id` (0, 1, 2, ... dense, flat across the
347 # kernel). Each entry is the name of the `yield_on=` kernel parameter, or `None` for implicit checkpoints
348 # (which never yield). Populated by the AST transformer; empty means the kernel uses no checkpoints.
🟢 349 self.checkpoint_yield_on_args: list[str | None] = []
350 # User-facing labels for explicit checkpoints. Same indexing as `checkpoint_yield_on_args`: entry `i` is the int
351 # (or IntEnum value) the user passed as the first positional arg of `qd.checkpoint(cp_id, yield_on)` for the
352 # checkpoint whose internal cp_id is `i`. Implicit checkpoints (auto-wrapped) get `None` (they have no
353 # user-facing label and can never appear in `GraphStatus.checkpoint`). The label is preserved as-is so an
354 # `IntEnum` round-trips: writing `qd.checkpoint(Stage.SIM, ...)` and then reading `status.checkpoint` returns
355 # `Stage.SIM` rather than the raw int.
🟢 356 self.checkpoint_user_labels_by_cp_id: list[int | None] = []
535 self,
536 key,
537 t_kernel: KernelCxx,
538 compiled_kernel_data: CompiledKernelData | None,
539 *args,
540 qd_stream=None,
541 _resume_from_checkpoint: int | None = None,
🟢 590 _checkpoint_helpers.init_yield_on_arg_id_table(self)
🟢 608 _checkpoint_helpers.maybe_record_yield_on_arg(self, self.arg_metas[i_in].name, i_out - template_num)
🟢 684 _checkpoint_helpers.forward_yield_on_table_to_ctx(self, launch_ctx)
685 # `_resume_from_checkpoint` is `None` for fresh launches (host-side default 0 in `LaunchContextBuilder`,
686 # which means "run every checkpoint"). When `Kernel.resume` plumbs an int through, copy it onto the launch
687 # context so the GraphManager's `launch_cached_graph` memcpys it into the device-side `resume_point` slot
688 # instead of clearing to 0. Slice 2 implementation; pre-CUDA-12.4 / non-CUDA backends ignore the value since
689 # they don't have a resume_point slot today (slices 4-6 will add an indirect-dispatch equivalent).
🟢 690 if _resume_from_checkpoint is not None:
🟢 691 launch_ctx.resume_from_checkpoint = int(_resume_from_checkpoint)
780 # Pop the resume cookie before anything else touches kwargs -- the AST mapper sees user parameter names only, so
781 # a stray `from_checkpoint=` would raise "unexpected kwarg". `_resume_from_checkpoint` is the resolved cp_id to
782 # copy into the device-side `resume_point` slot before launch; `None` means "fresh start, reset to 0". Plumbed
783 # via `Kernel.resume()` only; users do not pass this directly.
🟢 784 _resume_from_checkpoint = kwargs.pop("_qd_from_checkpoint", None)
🟢 785 _checkpoint_helpers.validate_resume_cookie(self, _resume_from_checkpoint)
858 # Translate the user-supplied `from_checkpoint=` label into the dense, source-order internal cp_id the runtime
859 # uses. Translation happens here (after `ensure_compiled`) because `checkpoint_user_labels_by_cp_id` is
860 # populated during AST processing inside `ensure_compiled`.
🟢 861 if _resume_from_checkpoint is not None:
🟢 862 _resume_from_checkpoint = _checkpoint_helpers.translate_user_label_to_internal_cp_id(
863 self, _resume_from_checkpoint
864 )
865 # Only forward `_resume_from_checkpoint` when the caller actually supplied one (i.e. via `Kernel.resume(...)`).
866 # Otherwise omit the kwarg entirely so subclasses / monkeypatches of `launch_kernel` that pre-date this kwarg
867 # keep working unmodified. The host-side default in `LaunchContextBuilder` is 0 ("run every checkpoint"), which
868 # matches the `None` semantics in `launch_kernel`.
🟢 869 if _resume_from_checkpoint is None:
🟢 870 ret = self.launch_kernel(
871 key,
872 kernel_cpp,
873 compiled_kernel_data,
874 *py_args,
875 qd_stream=qd_stream,
876 )
877 else:
🟢 878 ret = self.launch_kernel(
879 key,
880 kernel_cpp,
881 compiled_kernel_data,
882 *py_args,
883 qd_stream=qd_stream,
884 _resume_from_checkpoint=_resume_from_checkpoint,
885 )
889 # Surface a GraphStatus for kernels with `qd.checkpoint(yield_on=...)` so the host can drive the qipc-style
890 # re-entrant loop. Kernels without yield-capable checkpoints get `ret` (typically `None`) passed through.
🟢 891 return _checkpoint_helpers.maybe_build_graph_status(self, ret)
🟢 python/quadrants/lang/kernel_checkpoint.py (97%)
1 """Helpers extracted from ``kernel.py`` for the ``qd.checkpoint(...)`` pause / resume model.
2
3 ``Kernel.__call__`` / ``Kernel.launch_kernel`` delegate the resume-cookie validation, the user-label-to-internal-
4 cp_id translation, the per-launch ``yield_on=`` arg-id table construction, and the ``GraphStatus`` build to the free
5 functions below so the central ``Kernel`` class doesn't accrete checkpoint-feature-specific blocks. See
6 ``qd.checkpoint`` / ``kernel.resume`` / ``docs/source/user_guide/graph.md`` for the user-facing surface.
7 """
8
🟢 9 from __future__ import annotations
10
🟢 11 from typing import Any
12
🟢 13 from quadrants.lang import impl
🟢 14 from quadrants.lang.graph_status import GraphStatus
15
16
🟢 17 def validate_resume_cookie(kernel: Any, resume_from_checkpoint: int | None) -> None:
18 """Raise if ``_qd_from_checkpoint`` was passed to a kernel without any ``qd.checkpoint(yield_on=...)`` block.
19
20 Called from the preamble of ``Kernel.__call__`` so the user gets a clear error before any compile / launch work
21 happens, rather than a confusing "no GraphStatus surface" failure later.
22 """
🟢 23 if resume_from_checkpoint is not None and not kernel.checkpoint_yield_on_args:
🔴 24 raise RuntimeError(
25 "`from_checkpoint=` is only valid for kernels that contain at least one "
26 "qd.checkpoint(yield_on=...) block; this kernel has none."
27 )
28
29
🟢 30 def translate_user_label_to_internal_cp_id(kernel: Any, user_label: int) -> int:
31 """Translate a user-supplied ``from_checkpoint=`` label (int or IntEnum) to the runtime's internal dense cp_id.
32
33 The runtime indexes checkpoints by source-declaration order (0, 1, 2, ...). The user-facing label is whatever int /
34 IntEnum they passed as the first positional arg of ``qd.checkpoint(cp_id, yield_on=...)``; the
35 ``checkpoint_user_labels_by_cp_id`` table maps internal cp_id -> user label. Compared with ``==`` so an IntEnum
36 value matches its underlying int. Implicit (auto-wrapped) checkpoints have ``None`` in the table and are never
37 resume targets. Raises ``RuntimeError`` listing the available labels when the user passed an unknown one.
38 """
🟢 39 for internal_cp_id, label in enumerate(kernel.checkpoint_user_labels_by_cp_id):
🟢 40 if label is not None and label == user_label:
🟢 41 return internal_cp_id
🟢 42 available = [lbl for lbl in kernel.checkpoint_user_labels_by_cp_id if lbl is not None]
🟢 43 raise RuntimeError(
44 f"from_checkpoint={user_label!r} does not match any qd.checkpoint(cp_id=...) in "
45 f"kernel {kernel.func.__name__!r}. Available cp_id labels (source-declaration order): {available}."
46 )
47
48
🟢 49 def init_yield_on_arg_id_table(kernel: Any) -> None:
50 """Allocate / reset the per-launch ``cp_id -> C++ arg-id`` table at the top of ``launch_kernel``'s arg iteration.
51
52 Each entry defaults to ``-1`` ("no yield_on"); the per-arg loop below fills in the C++ arg id when it visits the
53 named parameter. Sized to the kernel's checkpoint count once per launch so any changes to the checkpoint set (only
54 possible via re-AST-walk) reset the table cleanly. No-op for kernels with no ``yield_on=`` checkpoints.
55 """
🟢 56 if kernel.checkpoint_yield_on_args:
🟢 57 kernel._checkpoint_yield_on_cpp_arg_ids = [-1] * len(kernel.checkpoint_yield_on_args)
58
59
🟢 60 def maybe_record_yield_on_arg(kernel: Any, arg_name: str, cpp_arg_id: int) -> None:
61 """Fill the ``cp_id -> C++ arg-id`` slot when the arg iterator visits a named ``yield_on=`` kernel parameter.
62
63 Walked once per kernel arg in ``launch_kernel``; cheap O(checkpoints) match. A single parameter can be the
64 ``yield_on=`` for multiple checkpoints (the inner loop fills every matching slot).
65 """
🟢 66 if not kernel.checkpoint_yield_on_args:
🟢 67 return
🟢 68 for cp_idx, yield_name in enumerate(kernel.checkpoint_yield_on_args):
🟢 69 if yield_name is not None and arg_name == yield_name:
🟢 70 kernel._checkpoint_yield_on_cpp_arg_ids[cp_idx] = cpp_arg_id
71
72
🟢 73 def forward_yield_on_table_to_ctx(kernel: Any, launch_ctx: Any) -> None:
74 """Copy the resolved ``cp_id -> C++ arg-id`` table onto the launch context so the runtime can find each
75 ``yield_on=`` ndarray's device address at launch.
76 """
🟢 77 if kernel.checkpoint_yield_on_args and hasattr(kernel, "_checkpoint_yield_on_cpp_arg_ids"):
🟢 78 launch_ctx.checkpoint_yield_on_arg_ids = tuple(kernel._checkpoint_yield_on_cpp_arg_ids)
79
80
🟢 81 def maybe_build_graph_status(kernel: Any, default_ret: Any) -> Any:
82 """Translate the runtime's internal yielding cp_id back to the user-supplied label and return a ``GraphStatus``.
83
84 Returns ``default_ret`` unchanged for kernels without any yielding checkpoint -- there's no ``yield_on=`` parameter
85 to surface a status from, so the value would always be ``yielded=False, checkpoint=None`` (no information).
86 Implicit (auto-wrapped) checkpoints have ``None`` in ``checkpoint_user_labels_by_cp_id`` but they never have
87 ``yield_on=``, so the runtime can't surface them as the yielding cp -- the lookup is always to an explicit
88 checkpoint.
89 """
🟢 90 if not (kernel.checkpoint_yield_on_args and any(n is not None for n in kernel.checkpoint_yield_on_args)):
🟢 91 return default_ret
🟢 92 cp = impl.get_runtime().prog.get_graph_last_yield_cp_id_on_last_call()
🟢 93 if cp >= 0:
🟢 94 user_label = (
95 kernel.checkpoint_user_labels_by_cp_id[cp] if cp < len(kernel.checkpoint_user_labels_by_cp_id) else None
96 )
🟢 97 return GraphStatus(yielded=True, checkpoint=user_label if user_label is not None else cp)
🟢 98 return GraphStatus(yielded=False, checkpoint=None)
🟢 python/quadrants/lang/kernel_impl.py (100%)
161 checkpoints: bool = False,
🟢 172 primal.use_checkpoints = checkpoints
🟢 173 adjoint.use_checkpoints = checkpoints
215 def kernel(
216 _fn: None = None, *, pure: bool = False, graph: bool = False, checkpoints: bool = False
217 ) -> Callable[[Any], Any]: ...
227 def kernel(_fn: Any, *, pure: bool = False, graph: bool = False, checkpoints: bool = False) -> Any: ...
236 checkpoints: bool = False,
253 checkpoints: If True, opt into the (experimental) ``qd.checkpoint`` pause / resume model.
254 ``with qd.checkpoint(cp_id, yield_on=flag):`` blocks in the kernel body become pause points the host can
255 resume from via ``kernel.resume(from_checkpoint=cp_id)``. Requires ``graph=True``.
256 ``qd.checkpoint(...)`` in the body is rejected unless this flag is set.
🟢 276 if checkpoints and not graph:
🟢 277 raise QuadrantsSyntaxError(
278 f"@qd.kernel({fn.__name__!r}, checkpoints=True) requires graph=True; "
279 "the checkpoint resume model is only meaningful for graph kernels."
280 )
🟢 281 wrapped = _kernel_impl(fn, level_of_class_stackframe=level, graph=graph, checkpoints=checkpoints)
🟢 python/quadrants/lang/misc.py (100%)
🟢 13 from quadrants.lang.checkpoint import checkpoint
🟢 15 from quadrants.lang.graph_status import GraphStatus
890 "GraphStatus",
891 "checkpoint",
🟢 tests/python/test_checkpoint.py (90%)
1 """Tests for ``qd.checkpoint`` -- yield/resume stage primitive for graph kernels.
2
3 These tests cover the auto-checkpoint surface:
4
5 - The user-facing API is ``qd.checkpoint(cp_id, yield_on=flag)``. Both arguments are required; ``cp_id`` is a user
6 label (``int`` or ``IntEnum`` value), and ``yield_on`` is a 0-d ``qd.i32`` ndarray kernel parameter.
7 - The kernel must be decorated with ``@qd.kernel(graph=True, checkpoints=True)`` to use ``qd.checkpoint(...)``. The
8 flag opts the kernel into the resume model and enables the auto-wrap pass.
9 - Auto-wrap: every top-level for-loop in the kernel body (including inside ``while qd.graph_do_while(...):``) that
10 is not inside a ``with qd.checkpoint(...)`` becomes an implicit no-yield checkpoint. Implicit checkpoints carry no
11 user label and never appear in ``GraphStatus.checkpoint``, but they DO consume an internal cp_id slot so a resume
12 launch can skip them along with the explicit checkpoints declared earlier in source order.
13 - ``status.checkpoint`` round-trips the user-supplied label (so ``qd.checkpoint(Stage.SIM, ...)`` surfaces as
14 ``Stage.SIM`` on yield). ``kernel.resume(from_checkpoint=Stage.SIM)`` skips every checkpoint (implicit + explicit)
15 declared before ``Stage.SIM`` in source order.
16
17 The behavioural assertions (yield, resume, kernel completes normally on no-yield) run on every backend that implements
18 the host-side yield/resume contract -- see ``_supports_checkpoint_yield_resume`` below. The CUDA-native-only counters
19 (IF conditional node count) are guarded behind ``_is_checkpoint_if_path_native``.
20 """
21
🟢 22 from enum import IntEnum
23
🟢 24 import numpy as np
🟢 25 import pytest
26
🟢 27 import quadrants as qd
🟢 28 from quadrants.lang import impl
29
🟢 30 from tests import test_utils
31
32
🟢 33 def _on_cuda():
🟢 34 return impl.current_cfg().arch == qd.cuda
35
36
🟢 37 def _is_checkpoint_if_path_native():
38 """The CUDA-native IF-conditional path requires SM 9.0+ / CUDA 12.4+ (slice 1c).
39
40 On other devices and backends the kernel still runs through every checkpoint body, so the behavioural tests pass
41 everywhere, but the GraphManager-introspection assertions only apply on the native path.
42 """
🟢 43 return _on_cuda() and qd.lang.impl.get_cuda_compute_capability() >= 90
44
45
🟢 46 def _supports_checkpoint_yield_resume():
47 """Backends that implement the checkpoint yield/resume host contract.
48
49 Wider than `_is_checkpoint_if_path_native()`: also includes the CPU/x64 path (slice 6) and AMDGPU host-orchestrated
50 sub-graph path (slice 4). Use this predicate for tests of the behavioural yield/resume + `kernel.resume(...)` API;
51 use `_is_checkpoint_if_path_native()` only for graph-introspection counters that exist on CUDA alone.
52 """
🟢 53 if _is_checkpoint_if_path_native():
🔴 54 return True
55 # CPU backend: same `runtime/cpu/kernel_launcher.cpp` host-branch gating runs on both x64 and arm64 (the launcher is
56 # arch-agnostic; only the LLVM codegen target differs). Apple Silicon surfaces as `qd.arm64`; Linux x86 as `qd.x64`.
57 # Both go through the slice 6 path.
🟢 58 if impl.current_cfg().arch in (qd.x64, qd.arm64):
🟢 59 return True
🟢 60 if impl.current_cfg().arch == qd.amdgpu:
🔴 61 return True
62 # GFX backends (Vulkan, Metal): per-task host gating + readback yield-check in `GfxRuntime` (slice 4 cont.); see
63 # `runtime/gfx/runtime.cpp`'s task loop.
🟢 64 if impl.current_cfg().arch in (qd.vulkan, qd.metal):
🔴 65 return True
🟢 66 return False
67
68
🟢 69 def _supports_checkpoint_yield_resume_in_while_loop():
70 """Strict subset of `_supports_checkpoint_yield_resume`: returns true on backends where yield/resume also works
71 inside a `qd.graph_do_while` body. Same predicate today since slice 4 ported the CPU launcher's host-branch gating
72 plus per-iter resume_point reset to the AMDGPU streaming path."""
🟢 73 return _supports_checkpoint_yield_resume()
74
75
🟢 76 def _num_checkpoints_on_last_call():
🔴 77 return impl.get_runtime().prog.get_graph_num_checkpoints_on_last_call()
78
79
🟢 80 def _last_yield_cp_id_on_last_call():
🔴 81 return impl.get_runtime().prog.get_graph_last_yield_cp_id_on_last_call()
82
83
84 # ----------------------------------------------------------------------------------------------------------------------
85 # Python-runtime surface (works outside @qd.kernel).
86 # ----------------------------------------------------------------------------------------------------------------------
87
88
🟢 89 def test_checkpoint_is_no_op_outside_kernels():
90 """At Python runtime (outside kernels) ``qd.checkpoint`` must be a usable no-op context manager.
91
92 Lets downstream consumers import the symbol unconditionally and use it inside helpers that are sometimes called
93 from Python and sometimes from kernels. The new API has two required args; both are accepted unchanged here (the
94 Python runtime stub is just ``del cp_id, yield_on; yield``).
95 """
🟢 96 sentinel = []
🟢 97 with qd.checkpoint(0, None):
🟢 98 sentinel.append("body ran")
🟢 99 with qd.checkpoint(7, None):
🟢 100 sentinel.append("body ran")
🟢 101 assert sentinel == ["body ran", "body ran"]
102
103
104 # ----------------------------------------------------------------------------------------------------------------------
105 # Decorator + opt-in flag.
106 # ----------------------------------------------------------------------------------------------------------------------
107
108
🟢 109 @test_utils.test()
🟢 110 def test_checkpoint_in_non_checkpoints_kernel_raises():
111 """Using ``qd.checkpoint(...)`` in a ``@qd.kernel(graph=True)`` without ``checkpoints=True`` must error at compile
112 time, pointing at the fix-it (add ``checkpoints=True``)."""
113
🟢 114 @qd.kernel(graph=True)
🟢 115 def k(x: qd.types.ndarray(qd.i32, ndim=1), flag: qd.types.ndarray(qd.i32, ndim=0)):
🔴 116 with qd.checkpoint(0, yield_on=flag):
🔴 117 for i in range(x.shape[0]):
🔴 118 x[i] = x[i] + 1
119
🟢 120 x = qd.ndarray(qd.i32, shape=(4,))
🟢 121 flag = qd.ndarray(qd.i32, shape=())
🟢 122 with pytest.raises(qd.QuadrantsSyntaxError, match=r"checkpoints=True"):
🟢 123 k(x, flag)
124
125
🟢 126 def test_kernel_checkpoints_requires_graph_true():
127 """``@qd.kernel(checkpoints=True)`` without ``graph=True`` is rejected at decorator time -- the resume model is only
128 meaningful for graph kernels (the gate / yield-check lowering, the resume_point slot, and the kernel.resume API all
129 depend on the graph-capture path)."""
🟢 130 with pytest.raises(qd.QuadrantsSyntaxError, match=r"checkpoints=True\) requires graph=True"):
131
🟢 132 @qd.kernel(checkpoints=True)
🟢 133 def k(x: qd.types.ndarray(qd.i32, ndim=1)):
🔴 134 for i in range(x.shape[0]):
🔴 135 x[i] = x[i] + 1
136
137
138 # ----------------------------------------------------------------------------------------------------------------------
139 # New API signature: cp_id (int or IntEnum), yield_on (required, must be parameter name).
140 # ----------------------------------------------------------------------------------------------------------------------
141
142
🟢 143 @test_utils.test()
🟢 144 def test_checkpoint_missing_cp_id_raises():
145 """``qd.checkpoint()`` with no args must raise with a message that points at the required signature."""
146
🟢 147 @qd.kernel(graph=True, checkpoints=True)
🟢 148 def k(x: qd.types.ndarray(qd.i32, ndim=1)):
🔴 149 with qd.checkpoint():
🔴 150 for i in range(x.shape[0]):
🔴 151 x[i] = x[i] + 1
152
🟢 153 x = qd.ndarray(qd.i32, shape=(4,))
🟢 154 with pytest.raises(qd.QuadrantsSyntaxError, match=r"qd\.checkpoint\(cp_id, yield_on=flag\)"):
🟢 155 k(x)
156
157
🟢 158 @test_utils.test()
🟢 159 def test_checkpoint_missing_yield_on_raises():
160 """``qd.checkpoint(cp_id)`` without ``yield_on=`` is rejected at compile time."""
161
🟢 162 @qd.kernel(graph=True, checkpoints=True)
🟢 163 def k(x: qd.types.ndarray(qd.i32, ndim=1)):
🔴 164 with qd.checkpoint(0):
🔴 165 for i in range(x.shape[0]):
🔴 166 x[i] = x[i] + 1
167
🟢 168 x = qd.ndarray(qd.i32, shape=(4,))
🟢 169 with pytest.raises(qd.QuadrantsSynta
Loading