Changes from 42 commits
95 commits
9c3920d
[Graph] Add qd.checkpoint AST surface (slice 1a)
hughperkins Jun 4, 2026
ec5bdd2
[Graph] Plumb checkpoint_id end-to-end (slice 1b)
hughperkins Jun 4, 2026
1af49f0
[Graph] Add checkpoint IF-gate kernel source (slice 1c, part 1/N)
hughperkins Jun 4, 2026
1e18653
[Graph] Drop sm_110 from checkpoint gate fatbin script (CUDA 12.9 too…
hughperkins Jun 4, 2026
a5c614f
[Graph] GraphManager wires IF nodes per qd.checkpoint (slice 1c, CUDA…
hughperkins Jun 4, 2026
b4612a0
[Graph] Slice 1d (part 1): yield-check + cond-with-yield kernel sourc…
hughperkins Jun 5, 2026
e138cfd
[Graph] Slice 1d (part 2): regenerate fatbins for yield-check + cond-…
Jun 5, 2026
4492b35
[Graph] Slice 1d (part 3): wire yield-check kernel + cond-with-yield …
hughperkins Jun 5, 2026
79ebf57
[Graph] Slice 1d (part 4): tests for yield-check / yield-race / WHILE…
hughperkins Jun 5, 2026
28a2f29
[Graph] Slice 1d (fix): route first launch through launch_cached_grap…
hughperkins Jun 5, 2026
5dedc4e
[Graph] Slice 1d (tests): correct yield-first-wins semantics
hughperkins Jun 5, 2026
977f3f0
[Graph] Slice 1d (docs): document yield mechanism in graph.md + updat…
hughperkins Jun 5, 2026
bf906d7
[Graph] Slice 2 (impl): GraphStatus return value + kernel.resume(from…
hughperkins Jun 5, 2026
dd37b28
[Graph] Slice 2 (fix): move GraphStatus to its own module to avoid ci…
hughperkins Jun 5, 2026
788f607
[Graph] Slice 2 (tests): GraphStatus return + resume(from_checkpoint=…
hughperkins Jun 5, 2026
37ed4e4
[Graph] Slice 2 (docs): GraphStatus + kernel.resume() user guide
hughperkins Jun 5, 2026
4d5bc25
[Graph] Slice 3: port qipc test_resume_offset.cu scenarios A+B + rese…
hughperkins Jun 5, 2026
d7fe3f6
[Graph] Slice 3: regenerate condition kernel fatbin (cond-with-yield …
hughperkins Jun 5, 2026
7123030
[Graph] Slice 3 (fix): rewrite resume-offset tests to use for-loops s…
hughperkins Jun 5, 2026
c685f24
[Graph] Slice 3 (fix): use for _ in range(1) for scalar work inside c…
hughperkins Jun 5, 2026
0c94986
[Graph] Slice 7 (docs polish): backend-support table now lists qd.che…
hughperkins Jun 5, 2026
e35e9cb
[Graph] Slice 6: CPU fallback - host-branch gating for qd.checkpoint …
hughperkins Jun 5, 2026
83a6224
[Graph] Slice 6 (fix): clear resume_from_checkpoint after first WHILE…
hughperkins Jun 5, 2026
b265cdf
[Graph] Slice 6: gate yield bookkeeping on yield-capable kernels + op…
hughperkins Jun 5, 2026
e42c496
[Graph] Slice 7: graph.md final pass for slice 6 CPU coverage
hughperkins Jun 5, 2026
9a369c4
[Graph] Slice 4 (AMDGPU): sub-graph orchestration for qd.checkpoint
hughperkins Jun 5, 2026
22f8d26
[Graph] Slice 4 (AMDGPU): open up yield/resume tests on amdgpu (seque…
hughperkins Jun 5, 2026
22f2069
[Graph] Slice 4 (AMDGPU): WHILE + checkpoint via streaming-launcher h…
hughperkins Jun 5, 2026
a25c241
[Graph] Slice 5 (Vulkan/Metal): GFX-runtime host gating for qd.checkp…
hughperkins Jun 5, 2026
840ac3d
[Docs] Update graph.md for slice 4/5 (AMDGPU + Vulkan + Metal coverage)
hughperkins Jun 5, 2026
68a4d54
[Graph] Slice 4/5/6: arm64 also runs the CPU host-branch gating path
hughperkins Jun 5, 2026
cd2881d
[Graph] Slice 5 rewrite: Vulkan/Metal GPU-side gating via indirect di…
hughperkins Jun 5, 2026
fcc001f
[Graph] Pre-Hopper CUDA: GPU-side qd.checkpoint gating via codegen pr…
hughperkins Jun 5, 2026
559fbbf
amdgpu: WIP GPU-side checkpoint gating via codegen prologue + flat HI…
hughperkins Jun 5, 2026
fded02a
amdgpu graph: pass device ptr (not host vector addr) to yield-check k…
hughperkins Jun 5, 2026
73628e2
amdgpu streaming launcher: GPU-side checkpoint gating for graph_do_wh…
hughperkins Jun 5, 2026
01795ab
docs: AMDGPU qd.checkpoint is now GPU-side (codegen prologue + flat H…
hughperkins Jun 5, 2026
0983393
Merge branch 'main' into hp/graph-checkpoint
hughperkins Jun 11, 2026
5058e34
PR 725 review: tighten graph.md qd.checkpoint section + linter fix
hughperkins Jun 11, 2026
c8bfb4b
PR 725 lint: pre-commit run -a (clang-format, black, trailing whitesp…
hughperkins Jun 11, 2026
7c9b76a
PR 725 review round 2: graph.md tightening + AST detection of the bar…
hughperkins Jun 11, 2026
61ee679
PR 725 review round 3: auto-wrap bare stmts in qd.checkpoint, trim cp…
hughperkins Jun 11, 2026
a0f2b54
PR 725 fixes: pyright tuple, _resume_from_checkpoint compat, clang-ti…
hughperkins Jun 11, 2026
cf4633e
PR 725 CI: feature factorization (3 modules) + restore deleted ration…
hughperkins Jun 11, 2026
b6ca290
PR 725 fix: test_api -- register qd.checkpoint + qd.GraphStatus, hide…
hughperkins Jun 11, 2026
d2afddf
PR 725: rewrap comments to 120c per find_underwrapped.py audit
hughperkins Jun 11, 2026
b6ee9be
PR 725: rewrap comments to 120c per find_underwrapped.py audit (follo…
hughperkins Jun 11, 2026
2a483e7
PR 725 CI: address line wrapping, deleted comments, and SPIR-V test c…
hughperkins Jun 11, 2026
9240ca3
PR 725 CI: additional factorization + restore two more upstream comments
hughperkins Jun 11, 2026
fb799c6
PR 725: rewrap one more comment in checkpoint_yield_check_shader.h
hughperkins Jun 11, 2026
6ea701b
PR 725 fix: offload-pass cp_id propagation through intervening serial…
hughperkins Jun 11, 2026
8b45f47
PR 725 lint: rewrap mixed-width comment paragraphs to a single 120c p…
hughperkins Jun 11, 2026
dd9548b
PR 725 AMDGPU fix: add gfx1010/1011/1012 (RDNA1) to checkpoint yield-…
hughperkins Jun 12, 2026
d741a46
PR 725 AMDGPU fix: clang-format the regenerated HSACO header (and pre…
hughperkins Jun 12, 2026
197a892
qd.checkpoint: fuse adjacent bare statements into one wrapper task
hughperkins Jun 12, 2026
3402601
qd.checkpoint: reject bare top-level statements with a clear compile-…
hughperkins Jun 12, 2026
9e3eae5
PR 725 line-wrap: reflow three 80c comments flagged by the CI agent
hughperkins Jun 12, 2026
10c36a6
PR 725 line-wrap: reflow three more 80c comments flagged by the CI agent
hughperkins Jun 12, 2026
5b3e656
Merge branch 'main' into hp/graph-checkpoint
hughperkins Jun 12, 2026
0e26368
PR 725 line-wrap: reflow the last two 80c comments flagged by the CI …
hughperkins Jun 12, 2026
9aa7c69
Merge branch 'main' into hp/graph-checkpoint
hughperkins Jun 12, 2026
0bb6027
PR 725 line-wrap: reflow three more 80c comments flagged by the CI ag…
hughperkins Jun 12, 2026
935cc37
qd.checkpoint: mark as experimental in user-facing docs and docstrings
hughperkins Jun 13, 2026
7ddee8d
qd.checkpoint experimental note: drop the "pin to a version / file an…
hughperkins Jun 13, 2026
5d1c181
PR 725 line-wrap: bulk reflow of under-wrapped // comment runs across…
hughperkins Jun 13, 2026
7aa1f60
qd.checkpoint v2: auto-wrap top-level for-loops + IntEnum-friendly cp…
hughperkins Jun 17, 2026
b3e726e
qd.checkpoint v2: tests, docs, GraphStatus repr, line-wrapping pass
hughperkins Jun 17, 2026
9698f73
docs: drop converted-from-bare qd.checkpoint() calls from the graph.m…
hughperkins Jun 17, 2026
60290d7
qd.checkpoint: reject bare-Name cp_id (fastcache no-globals safety)
hughperkins Jun 17, 2026
aa408c6
docs: drop "auto-wrap" / "implicit checkpoint" from user-facing surface
hughperkins Jun 17, 2026
9e02885
docs: drop 'and are the expected pattern' from gdw + checkpoint note
hughperkins Jun 17, 2026
a7f198f
qd.checkpoint: re-allow bare-Name cp_id, just don't advertise it
hughperkins Jun 17, 2026
9d91543
docs: drop misleading 'first-yielder-wins-of-several' clause from gra…
hughperkins Jun 17, 2026
24139e0
qd.checkpoint: stop clearing user's yield_on flag in the yield-check
hughperkins Jun 17, 2026
be27aa3
amdgpu: regenerate checkpoint_yield_check HSACO after dropping yield_…
Jun 17, 2026
2db5db8
cuda: regenerate checkpoint_yield_check fatbin after dropping yield_o…
Jun 17, 2026
9a35cb9
docs: drop the skip-past-yielder paragraph (suggested API that doesn'…
hughperkins Jun 17, 2026
b0ae987
docs: spell out that yield_on must be initialised before the first la…
hughperkins Jun 17, 2026
c5fa6c9
docs: drop internal '(skip + yield_on=)' tag from qd.checkpoint row i…
hughperkins Jun 17, 2026
9e2cc90
docs: drop misleading 'kernel pauses at that checkpoint' from yield m…
hughperkins Jun 17, 2026
a3520a0
docs: drop the 'write the flag unconditionally' alternative pattern
hughperkins Jun 17, 2026
14edbf7
docs: GraphStatus returned by kernels with checkpoints=True (not 'at …
hughperkins Jun 17, 2026
5927d31
docs: 'iff a checkpoint' instead of 'iff some checkpoint'
hughperkins Jun 17, 2026
d3559df
docs: 'including from kernel.resume(...)' instead of 'and from'
hughperkins Jun 17, 2026
f21617d
Merge branch 'main' into hp/graph-checkpoint
hughperkins Jun 17, 2026
67da464
address review commetns
hughperkins Jun 17, 2026
a398139
precoommit
hughperkins Jun 17, 2026
e939301
kernel: extract qd.checkpoint plumbing into kernel_checkpoint.py
hughperkins Jun 17, 2026
6544ff6
fix CI failures from kernel_checkpoint extract
hughperkins Jun 17, 2026
0464423
wrap: tighten 3 comment runs flagged by AI bot on 6544ff668
hughperkins Jun 17, 2026
cb4f308
Merge origin/main into hp/graph-checkpoint (nested graph_do_while)
hughperkins Jun 18, 2026
28f630b
address comments
hughperkins Jun 18, 2026
1cf0343
add Resume where setion
hughperkins Jun 18, 2026
35e49d4
precommit
hughperkins Jun 18, 2026
3b84936
factor checkpoint launch helpers out of GfxRuntime::launch_kernel
hughperkins Jun 18, 2026
67 changes: 63 additions & 4 deletions docs/source/user_guide/graph.md
@@ -4,14 +4,11 @@ Graphs reduce kernel launch overhead by capturing a sequence of GPU operations i

## Backend support

Both features run on every backend. They are *hardware accelerated* on CUDA (via CUDA graphs) and AMDGPU (via HIP graphs); `graph_do_while` additionally requires CUDA SM 9.0+ / Hopper for its hardware-accelerated path. On other backends, `graph=True` is silently ignored and the kernel runs via the normal launch path, and `graph_do_while` falls back to a host-side do-while loop that copies the condition value GPU → host each iteration (causing a pipeline stall — see [Caveats](#caveats)).

| Feature | `qd.cuda` SM 9.0+ | `qd.cuda` < SM 9.0 | `qd.amdgpu` | `qd.metal` | `qd.vulkan` | `qd.cpu` |
| --- | --- | --- | --- | --- | --- | --- |
| `graph=True` | hardware accelerated | hardware accelerated | hardware accelerated | runs (no acceleration) | runs (no acceleration) | runs (no acceleration) |
| `graph_do_while` | hardware accelerated | host fallback | host fallback | host fallback | host fallback | host fallback |

AMDGPU `graph_do_while` falls back to the host-side loop because HIP does not currently expose conditional / while graph nodes (as of ROCm 7.2).
| `qd.checkpoint` (skip + `yield_on=`) | GPU-side | GPU-side | GPU-side | GPU-side | GPU-side | host-side |

## Basic usage

@@ -154,3 +151,65 @@ Note: the basic `graph=True` path (without `graph_do_while`) does **not** stall
Therefore on unsupported platforms, you might consider creating a second implementation, which works differently. e.g.:
- fixed number of loop iterations, so no dependency on gpu data for kernel launch; combined perhaps with:
- make each kernel 'short-circuit', exit quickly, if the task has already been completed; to avoid running the GPU more than necessary

## Checkpoints with `qd.checkpoint` *(experimental)*

Collaborator Author

Note: after the backwards-incompatible disaster that was the algorithms.md changes for qipc, which kept changing 😅, I think I'd like to mark things as 'experimental' for a few weeks/months, until we are confident the API is stable.

Collaborator Author

Note to AI: this is not a request to you, it is an observation for other human reviewers.

Contributor

If checkpoint is a graph-specific API, it should be prefixed by 'graph_', as for all the other functions.

Contributor

Now that I think about it, we should probably just have a new 'qd.graph.' submodule. That would make everything both simpler and less confusing.

Collaborator Author

Seems reasonable. But perhaps not in this PR?

Contributor

In that case, make sure it is tracked somewhere. Still, it is weird to address this in another PR, since this is literally the PR that introduces the function. But I can understand that you want to move faster and that qipc is already relying on this one.


`qd.checkpoint()` marks a section of a graph kernel as a *skippable, optionally yieldable stage*. An example use-case is an algorithm implemented as a graph where you might need to allocate additional memory part-way through, the graph operations are in-place, and simply retrying the whole graph from the start is not an option. `qd.checkpoint` lets the kernel break at some point in the graph, surface the reason to the host, let the host fix things up, and resume from that point on the next launch.

```python
@qd.kernel(graph=True)
def step(
    arr: qd.types.ndarray(qd.f32, ndim=1),
    overflow_flag: qd.types.ndarray(qd.i32, ndim=0),
    newton_cond: qd.types.ndarray(qd.i32, ndim=0),
):
    while qd.graph_do_while(newton_cond):
        with qd.checkpoint():  # cp_id 0

Contributor

It would be nice to be able to specify some enum-based ID, to allow cleaner implementation of checkpoint recovery.

Collaborator Author

Do you mean, instead of incrementally auto-assigning an integer id, require the id to be passed in as an integer, so that we could pass in an IntEnum value, converted to int?

Contributor

Yes exactly.

            for i in range(arr.shape[0]):
                # ...
                pass
        with qd.checkpoint(yield_on=overflow_flag):  # cp_id 1 (can yield)
            for i in range(arr.shape[0]):
                # ...
                pass
        with qd.checkpoint():  # cp_id 2
            for i in range(arr.shape[0]):
                # ...
                pass
```

Each `with qd.checkpoint(...)` block gets a `cp_id` assigned. You can use the `cp_id` to identify which checkpoint yielded and which checkpoint to resume from — see [Host-side yield / resume loop](#host-side-yield--resume-loop) below.

### Yield mechanism

If `yield_on=foo` is supplied, the body may write a non-zero value into `foo[()]` (for example, when a pre-allocated buffer is too small) to signal "the host needs to handle something before this checkpoint can complete". When that happens:

1. The framework records the checkpoint that yielded (first yielder in declaration order wins).
2. Every later checkpoint in the same launch is skipped.
3. `qd.checkpoint` will exit any surrounding `qd.graph_do_while`.
4. `foo[()]` is reset to `0`.
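
The skip rules in steps 1 and 2 can be modelled in plain Python. This is only a host-side toy, not the GPU-side gating the PR implements: `run_launch` and `make_stage` are hypothetical names, each stage stands in for one `with qd.checkpoint(...)` body, and a `True` return stands in for a non-zero `yield_on` flag. Flag reset (step 4) and `graph_do_while` exit (step 3) are not modelled.

```python
from __future__ import annotations

from typing import Callable, Optional


def run_launch(stages: list[Callable[[], bool]], from_checkpoint: int = 0) -> Optional[int]:
    """Toy model of one launch; returns the cp_id that yielded, or None."""
    yielded_at: Optional[int] = None
    for cp_id, stage in enumerate(stages):
        if cp_id < from_checkpoint:
            continue  # skipped because the launch resumes past this checkpoint
        if yielded_at is not None:
            continue  # an earlier checkpoint yielded, so later ones are skipped
        if stage():
            yielded_at = cp_id  # first yielder in declaration order wins
    return yielded_at


def make_stage(name: str, wants_yield: bool, trace: list[str]) -> Callable[[], bool]:
    """Builds a stage that records that it ran and reports whether it yields."""
    def stage() -> bool:
        trace.append(name)
        return wants_yield
    return stage


trace: list[str] = []
stages = [
    make_stage("cp0", False, trace),
    make_stage("cp1", True, trace),   # this checkpoint yields
    make_stage("cp2", False, trace),
]
first = run_launch(stages)
trace_after_first = list(trace)
second = run_launch(stages, from_checkpoint=first)
```

In this model the yielding stage runs again on the resume launch, which matches the author's answer in the review thread further down that the checkpoint with `cp_id == from_checkpoint` is re-executed.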

### Host-side yield / resume loop

Kernels with at least one `yield_on=` checkpoint return a `qd.GraphStatus` from every launch (and from `kernel.resume(...)`). The status carries two fields:

- `status.yielded` — `True` iff some `yield_on=` flag was non-zero during this launch.
- `status.checkpoint` — `cp_id` of the first (in declaration order) checkpoint that fired its flag, or `None` when `yielded` is `False`.

Resume by calling `kernel.resume(..., from_checkpoint=status.checkpoint)`. Every `qd.checkpoint` with `cp_id < from_checkpoint` is skipped on the resume launch; the rest run normally. The canonical host loop:

Contributor

> Every `qd.checkpoint` with `cp_id < from_checkpoint` is skipped on the resume launch; the rest run normally.

Does this mean the checkpoint with `cp_id == from_checkpoint` will be executed again in its entirety?

Collaborator Author

Yes.

Contributor

This should be clearly stated if not already the case.


```python
status = step(arr, overflow_flag, newton_cond)
while status.yielded:
    handle_overflow_for(status.checkpoint, ...)
    status = step.resume(arr, overflow_flag, newton_cond,
                         from_checkpoint=status.checkpoint)
```
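
As a sketch of how a host might dispatch on `status.checkpoint`, the following runs the same loop against a stand-in kernel, with no Quadrants dependency. Everything here is hypothetical: `Stage`, `FakeStatus` and `FakeStep` are illustration-only names, and the capacity-doubling recovery is just one possible way to handle an overflow. The `IntEnum` is used purely on the host to name integer `cp_id` values; it makes no claim about what `qd.checkpoint` itself accepts.

```python
from __future__ import annotations

from dataclasses import dataclass
from enum import IntEnum


class Stage(IntEnum):
    """Hypothetical host-side names for cp_id 0, 1 and 2."""
    ASSEMBLE = 0
    SORT = 1
    SOLVE = 2


@dataclass
class FakeStatus:
    """Stand-in exposing the two documented fields of GraphStatus."""
    yielded: bool
    checkpoint: int | None


class FakeStep:
    """Stand-in kernel: the SORT stage yields until capacity >= needed."""

    def __init__(self, needed: int) -> None:
        self.needed = needed
        self.capacity = 4
        self.ran: list[Stage] = []

    def __call__(self) -> FakeStatus:
        return self.resume(from_checkpoint=0)

    def resume(self, *, from_checkpoint: int) -> FakeStatus:
        for stage in Stage:
            if stage < from_checkpoint:
                continue  # checkpoints before the resume point are skipped
            if stage is Stage.SORT and self.capacity < self.needed:
                return FakeStatus(True, int(stage))  # yield before doing any work
            self.ran.append(stage)
        return FakeStatus(False, None)


step = FakeStep(needed=10)
status = step()
while status.yielded:
    if Stage(status.checkpoint) is Stage.SORT:
        step.capacity *= 2  # host-side fix-up before resuming
    status = step.resume(from_checkpoint=status.checkpoint)
```

The stand-in yields before mutating anything, which mirrors the "succeeds completely, or fails without changing anything" property the author describes for checkpoint blocks in the review thread.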

Kernels with `qd.checkpoint()` but no `yield_on=` keep their previous return contract (typically `None`) — the `GraphStatus` surface is opt-in via `yield_on=`.

### Restrictions

Contributor

Could you explain / be more explicit about what happens during resume?

Stating clearly that the entire checkpoint block is re-executed, and that it is the user's responsibility to ensure idempotent behaviour when a checkpoint is needed? Because if the state is altered during the checkpoint block, resuming is not going to save you, I guess? Whatever the answer, it should be very clear in the doc.

Beyond that, how does checkpointing work under the hood? Does it snapshot all the input data by copy before yielding, or does it just return as-is? If no copy is made, this means that resuming must be done "right away", without further altering the data in between, otherwise it is some kind of undefined behaviour.

Another important point: what if I don't want to resume in such a case, and I just want to move on to another kernel and continue from there? Is that supported, or must resume happen?

I think it is essential to clarify all these points in the documentation.

Collaborator Author

checkpoint does NOT require idempotent behavior. This is the entire purpose of checkpoint: to be able to interrupt and resume graphs that are NOT idempotent.

Collaborator Author

ah, the checkpoint block itself. right.

hughperkins, Jun 18, 2026
Collaborator Author

the checkpoint block itself actually does not so much require idempotence, as requiring that it is atomic: it either succeeds completely, or fails without changing anything.

hughperkins, Jun 18, 2026
Collaborator Author

As an example, in the case of allocation issues, the checkpoint block looks like:

  • do we have enough memory available?
    • no: exit now
    • yes: ok, let's proceed with running the sort etc.

Collaborator Author

added 'resume where' section

Contributor

> the checkpoint block itself actually does not so much require idempotence, as requiring that it is atomic: it either succeeds completely, or fails without changing anything.

Yeah, this is exactly what I meant by « ensure idempotent behaviour when checkpoint is needed »

Collaborator Author

"fails without changing anything" I feel is not idempotent? Idempotent means that calling the function multiple times is identical in effect to calling it once. But if it fails the first time, it would only be idempotent if it always failed thereafter I feel?


- Must be used inside `@qd.kernel(graph=True)`.
- `yield_on=` (when supplied) must be a kernel parameter that is a 0-d `qd.types.ndarray(qd.i32, ndim=0)`.
- Checkpoints cannot be nested inside other checkpoints.
30 changes: 30 additions & 0 deletions python/quadrants/lang/_quadrants_callable.py
@@ -95,6 +95,32 @@ def __init__(self, fn: Callable, wrapper: Callable) -> None:
    def __call__(self, *args, **kwargs):
        return self.wrapper.__call__(*args, **kwargs)

    def resume(self, *args, from_checkpoint: int, **kwargs):
        """Re-launches the kernel, skipping every ``qd.checkpoint`` with ``cp_id < from_checkpoint``.

        Use only on kernels decorated with ``@qd.kernel(graph=True)`` that contain at least
        one ``qd.checkpoint(yield_on=...)`` block. The host loop pattern is::

            status = step(arr, overflow_flag, newton_cond)
            while status.yielded:
                handle_overflow_for(status.checkpoint, ...)
                status = step.resume(arr, overflow_flag, newton_cond,
                                     from_checkpoint=status.checkpoint)

        Returns the same ``GraphStatus`` shape as the plain call.

        Raises ``RuntimeError`` if invoked on a kernel without any ``yield_on=`` checkpoint
        (there is no resume_point slot to write to, so the call would be a no-op).
        """
        if not isinstance(from_checkpoint, int) or from_checkpoint < 0:
            raise RuntimeError(
                f"from_checkpoint= must be a non-negative integer (typically `status.checkpoint` "
                f"from the previous launch's GraphStatus); got {from_checkpoint!r}."
            )
        # Smuggle the resume cookie past the AST-mapped kwargs path; `Kernel.__call__` pops it
        # before anything else looks at kwargs.
        return self.wrapper.__call__(*args, _qd_from_checkpoint=from_checkpoint, **kwargs)

    def __get__(self, instance, owner):
        if instance is None:
            return self
@@ -124,3 +150,7 @@ def __setattr__(self, k: str, v: Any) -> None:
    def grad(self, *args, **kwargs) -> "Kernel":
        assert self.quadrants_callable._adjoint is not None
        return self.quadrants_callable._adjoint(self.instance, *args, **kwargs)

    def resume(self, *args, from_checkpoint: int, **kwargs):
        """Bound-method form of `QuadrantsCallable.resume` (see that docstring)."""
        return self.quadrants_callable.resume(self.instance, *args, from_checkpoint=from_checkpoint, **kwargs)
135 changes: 134 additions & 1 deletion python/quadrants/lang/ast/ast_transformer.py
@@ -1362,6 +1362,42 @@ def _is_graph_do_while_call(node: ast.expr) -> str | None:
return node.args[0].id
return None

    @staticmethod
    def _is_checkpoint_call(node: ast.expr) -> tuple[bool, str | None]:
        """If *node* is a ``qd.checkpoint(...)`` call return ``(True, yield_on_arg_name)``; otherwise
        ``(False, None)``. ``yield_on_arg_name`` is ``None`` when the user wrote
        ``qd.checkpoint()`` with no ``yield_on`` kwarg.

        Validates the call shape (no positional args, only ``yield_on=`` as a bare ``ast.Name``)
        and raises ``QuadrantsSyntaxError`` for misuse so the user gets a clear message at the
        ``with`` site rather than a vague "not stream_parallel" error later.
        """
        if not isinstance(node, ast.Call):
            return False, None
        func = node.func
        is_checkpoint = (isinstance(func, ast.Attribute) and func.attr == "checkpoint") or (
            isinstance(func, ast.Name) and func.id == "checkpoint"
        )
        if not is_checkpoint:
            return False, None
        if node.args:
            raise QuadrantsSyntaxError(
                "qd.checkpoint() takes no positional arguments; use qd.checkpoint(yield_on=flag) instead"
            )
        yield_on_name: str | None = None
        for kw in node.keywords:
            if kw.arg != "yield_on":
                raise QuadrantsSyntaxError(
                    f"qd.checkpoint() got unexpected keyword argument {kw.arg!r}; only 'yield_on' is supported"
                )
            if not isinstance(kw.value, ast.Name):
                raise QuadrantsSyntaxError(
                    "qd.checkpoint(yield_on=...) must be the bare name of a kernel parameter "
                    "(e.g. `yield_on=overflow_flag`); expressions are not supported"
                )
            yield_on_name = kw.value.id
        return True, yield_on_name

    @staticmethod
    def build_While(ctx: ASTTransformerFuncContext, node: ast.While) -> None:
        if node.orelse:
@@ -1575,15 +1611,112 @@ def build_With(ctx: ASTTransformerFuncContext, node: ast.With) -> None:
            raise QuadrantsSyntaxError("'with ... as ...' is not supported in Quadrants kernels")
        if not isinstance(item.context_expr, ast.Call):
            raise QuadrantsSyntaxError("'with' in Quadrants kernels requires a call expression")

        is_checkpoint, yield_on_name = ASTTransformer._is_checkpoint_call(item.context_expr)
        if is_checkpoint:
            return ASTTransformer._build_checkpoint_with(ctx, node, yield_on_name)

        if not FunctionDefTransformer._is_stream_parallel_with(node, ctx.global_vars):
            raise QuadrantsSyntaxError("'with' in Quadrants kernels only supports qd.stream_parallel()")
            raise QuadrantsSyntaxError(
                "'with' in Quadrants kernels only supports qd.stream_parallel() or qd.checkpoint()"
            )
        if not ctx.is_kernel:
            raise QuadrantsSyntaxError("qd.stream_parallel() can only be used inside @qd.kernel, not @qd.func")
        ctx.ast_builder.begin_stream_parallel()
        build_stmts(ctx, node.body)
        ctx.ast_builder.end_stream_parallel()
        return None

    @staticmethod
    def _build_checkpoint_with(
        ctx: ASTTransformerFuncContext,
        node: ast.With,
        yield_on_name: str | None,
    ) -> None:
        """Handles ``with qd.checkpoint(yield_on=arg):`` blocks.

        Slice 1a: validates the use-site (kernel must be graph=True, no nesting, yield_on must be a kernel
        parameter) and records the checkpoint's ``yield_on`` arg on the kernel object. Walks the body
        transparently -- for-loops inside the ``with`` become normal top-level for-loops in the kernel's
        frontend IR. The ``cp_id`` is assigned by declaration order (list index in
        ``kernel.checkpoint_yield_on_args``).

        Later slices wire ``cp_id`` through ForLoopConfig → OffloadedTask so the GraphManager can wrap
        each checkpoint's body kernels in an IF conditional node and insert the yield-check kernel.
        """
        if not ctx.is_kernel:
            raise QuadrantsSyntaxError("qd.checkpoint() can only be used inside @qd.kernel, not @qd.func")
        kernel = ctx.global_context.current_kernel
        if not kernel.use_graph:
            raise QuadrantsSyntaxError("qd.checkpoint() requires @qd.kernel(graph=True)")
        if getattr(ctx, "_in_checkpoint", False):
            raise QuadrantsSyntaxError(
                "qd.checkpoint() cannot be nested inside another qd.checkpoint(); checkpoints in the "
                "same kernel must be flat siblings (a checkpoint inside qd.graph_do_while is fine)"
            )
        if yield_on_name is not None:
            arg_names = [m.name for m in kernel.arg_metas]
            if yield_on_name not in arg_names:
                raise QuadrantsSyntaxError(
                    f"qd.checkpoint(yield_on={yield_on_name!r}) does not match any parameter of kernel "
                    f"{kernel.func.__name__!r}. Available parameters: {arg_names}"
                )

P2: Validate yield_on is a scalar i32 ndarray

This only checks that yield_on names a parameter, but the runtime kernels later cast the resolved pointer to int32_t * and read/write a single scalar. If a user names a f32 flag or an ndarray with ndim != 0, the graph will interpret or clear the first 4 bytes/first element instead of failing as documented, producing bogus yields or corrupting user data. Please validate the matched arg metadata here before recording it.



        # Auto-wrap bare top-level statements in the checkpoint body in a one-iteration
        # `for` loop. The offloader's pending-serial bucket loses the surrounding
        # `checkpoint_id` and emits such statements as `serial` tasks with `cp_id == -1`,
        # meaning they would run unconditionally even when the checkpoint is skipped -- a
        # silent correctness bug. The fix is to lower them as `range_for` tasks instead by
        # wrapping each bare statement in `for _ in range(1): <stmt>`. We target the specific
        # statement kinds known to hit the footgun (Assign / AugAssign / AnnAssign /
        # non-docstring Expr) and leave everything else (For, While, If, With, Pass,
        # docstring) untouched so they keep working transparently; nested
        # `with qd.checkpoint(...)` in particular still falls through to the existing
        # nested-checkpoint check at the start of this method.
        new_body: list[ast.stmt] = []
        for i, stmt in enumerate(node.body):
            needs_wrap = isinstance(stmt, (ast.Assign, ast.AugAssign, ast.AnnAssign))
            if not needs_wrap and isinstance(stmt, ast.Expr):
                is_docstring = i == 0 and isinstance(stmt.value, ast.Constant)
                needs_wrap = not is_docstring

P1: Preserve checkpoint ids for compound serial bodies

When a checkpoint body starts with a top-level if/while (for example, with qd.checkpoint(): if flag[None] != 0: x[None] += 1), this wrapping logic leaves the compound statement as serial IR, but the new checkpoint plumbing only propagates checkpoint_id through frontend for statements into offloaded tasks. That serial task keeps cp_id == -1, so it still runs on resume(..., from_checkpoint > cp) and after an earlier checkpoint yields, violating the skip/resume contract for valid kernel control flow inside a checkpoint.
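
For readers unfamiliar with the rewrite being discussed, here is a standalone sketch that uses only the standard-library `ast` module (Python 3.9+ for `ast.unparse`). It copies the wrapping shape from the diff but is not the PR's code: `wrap_in_single_iteration_for` is a hypothetical helper name, and the docstring/`Expr` handling is omitted. It shows that a bare assignment gets wrapped while a top-level `if` passes through unchanged, which is the case the comment above is concerned with.

```python
import ast


def wrap_in_single_iteration_for(stmt: ast.stmt) -> ast.For:
    """Wraps one statement as `for _ in range(1): <stmt>` (same shape as the diff)."""
    wrapped = ast.For(
        target=ast.Name(id="_", ctx=ast.Store()),
        iter=ast.Call(
            func=ast.Name(id="range", ctx=ast.Load()),
            args=[ast.Constant(value=1)],
            keywords=[],
        ),
        body=[stmt],
        orelse=[],
    )
    ast.copy_location(wrapped, stmt)
    ast.fix_missing_locations(wrapped)
    return wrapped


module = ast.parse("x = 1\nif x:\n    x += 1\n")
module.body = [
    # Only plain assignments are wrapped here; compound statements are left as-is.
    wrap_in_single_iteration_for(s)
    if isinstance(s, (ast.Assign, ast.AugAssign, ast.AnnAssign))
    else s
    for s in module.body
]
source = ast.unparse(module)
print(source)

# The rewritten module still compiles and behaves like the original.
namespace: dict = {}
exec(compile(module, "<demo>", "exec"), namespace)
```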


            if needs_wrap:
                wrapped = ast.For(
                    target=ast.Name(id="_", ctx=ast.Store()),
                    iter=ast.Call(
                        func=ast.Name(id="range", ctx=ast.Load()),
                        args=[ast.Constant(value=1)],
                        keywords=[],
                    ),
                    body=[stmt],
                    orelse=[],
                )
                ast.copy_location(wrapped, stmt)
                ast.fix_missing_locations(wrapped)
                new_body.append(wrapped)
            else:
                new_body.append(stmt)
        node.body = new_body

        kernel.checkpoint_yield_on_args.append(yield_on_name)
        # Hand control to the C++ ASTBuilder so that every for-loop emitted by `build_stmts`
        # below is tagged with this checkpoint's `cp_id` on its `ForLoopConfig.checkpoint_id`.
        # The C++ counter is the source of truth for cp_id; we cross-check it against the
        # Python list index so a future refactor that misaligns the two surfaces immediately.
        cpp_cp_id = ctx.ast_builder.begin_checkpoint()
        py_cp_id = len(kernel.checkpoint_yield_on_args) - 1
        assert cpp_cp_id == py_cp_id, (
            f"C++ ASTBuilder.begin_checkpoint() returned cp_id={cpp_cp_id} but Python "
            f"kernel.checkpoint_yield_on_args index expected {py_cp_id}; these counters "
            f"must stay in lockstep so the GraphManager (slice 1c) can index yield_on by cp_id"
        )
        ctx._in_checkpoint = True
        try:
            build_stmts(ctx, node.body)
        finally:
            ctx._in_checkpoint = False
            ctx.ast_builder.end_checkpoint()
        return None

    @staticmethod
    def build_Pass(ctx: ASTTransformerFuncContext, node: ast.Pass) -> None:
        return None
@@ -504,6 +504,12 @@ def build_FunctionDef(

        if ctx.is_kernel:
            FunctionDefTransformer._validate_stream_parallel_exclusivity(node.body, ctx.global_vars)
            kernel = ctx.global_context.current_kernel
            if kernel is not None:
                # Reset before walking the body so re-materialisations (e.g. when a templated kernel is compiled
                # with a different argument shape) start from an empty list. Mirrors how `graph_do_while_arg`
                # gets overwritten unconditionally during AST traversal.
                kernel.checkpoint_yield_on_args = []

        with ctx.variable_scope_guard():
            build_stmts(ctx, node.body)
50 changes: 50 additions & 0 deletions python/quadrants/lang/graph_status.py
@@ -0,0 +1,50 @@
"""Plain-Python container returned from graph kernels that contain ``qd.checkpoint(yield_on=...)``.

Lives in its own module (with no Quadrants-internal imports) so it can be imported safely from
both ``kernel.py`` and ``misc.py`` without re-introducing the circular import chain that
``misc.py -> impl.py -> kernel.py`` would create.

Re-exported via ``qd.lang.misc`` (and therefore as ``qd.GraphStatus``) for the user-facing
canonical import path.
"""

from __future__ import annotations


class GraphStatus:
    """Result returned by a graph kernel that contains ``qd.checkpoint(yield_on=...)`` blocks.

    Returned from ``kernel(...)`` and ``kernel.resume(..., from_checkpoint=cp)`` whenever the
    kernel was decorated with ``@qd.kernel(graph=True)`` and contains at least one checkpoint
    that declared a ``yield_on=`` parameter. Read ``status.yielded`` to decide whether to keep
    running the host loop, and ``status.checkpoint`` to find out which checkpoint asked the
    host to handle something.

    Canonical usage (mirrors the qipc re-entrant pattern; see ``graph.md``)::

        status = step(arr, overflow_flag, newton_cond)
        while status.yielded:
            handle_overflow_for(status.checkpoint, ...)
            status = step.resume(arr, overflow_flag, newton_cond,
                                 from_checkpoint=status.checkpoint)

    Attributes:
        yielded: ``True`` iff one of the kernel's ``yield_on=`` checkpoints fired its flag on
            the most recent launch. ``False`` means the kernel completed normally and the host
            loop should exit.
        checkpoint: ``cp_id`` of the checkpoint whose ``yield_on=`` flag was non-zero (or
            ``None`` when ``yielded`` is ``False``). Pass it to ``kernel.resume(...,
            from_checkpoint=cp)`` to skip every checkpoint with a lower ``cp_id`` on the next
            launch.
    """

    __slots__ = ("yielded", "checkpoint")

    def __init__(self, yielded: bool, checkpoint: int | None):
        self.yielded = yielded
        self.checkpoint = checkpoint

    def __repr__(self) -> str:
        if self.yielded:
            return f"GraphStatus(yielded=True, checkpoint={self.checkpoint})"
        return "GraphStatus(yielded=False)"
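
A quick way to see what this container does is to exercise a local copy of it. The snippet below re-declares the class exactly as shown in the diff so that it runs without Quadrants installed; the variable names `done`, `paused` and `typo_rejected` are illustrative only.

```python
from __future__ import annotations


class GraphStatus:
    """Local copy of the class from this diff, for illustration only."""

    __slots__ = ("yielded", "checkpoint")

    def __init__(self, yielded: bool, checkpoint: int | None):
        self.yielded = yielded
        self.checkpoint = checkpoint

    def __repr__(self) -> str:
        if self.yielded:
            return f"GraphStatus(yielded=True, checkpoint={self.checkpoint})"
        return "GraphStatus(yielded=False)"


done = GraphStatus(yielded=False, checkpoint=None)
paused = GraphStatus(yielded=True, checkpoint=1)
print(repr(done), repr(paused))

# __slots__ means instances have no __dict__, so a misspelled field fails loudly
# instead of silently creating a new attribute.
try:
    paused.checkpont = 2  # deliberate typo
    typo_rejected = False
except AttributeError:
    typo_rejected = True
```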