Skip to content

[Autotune] Add pipeline, grouped compilation, and multi-GPU benchmark support#2159

Open
Wazrrr wants to merge 8 commits intotile-ai:mainfrom
Wazrrr:pr-clean
Open

[Autotune] Add pipeline, grouped compilation, and multi-GPU benchmark support#2159
Wazrrr wants to merge 8 commits intotile-ai:mainfrom
Wazrrr:pr-clean

Conversation

@Wazrrr
Copy link
Copy Markdown

@Wazrrr Wazrrr commented May 7, 2026

Summary

This PR adds pipeline mode, grouped compilation, and multi-GPU benchmark support to autotuning.

Key Updates

1) Grouped compilation (tvm-ffi only)

  • Added grouped compilation support for autotune execution.
  • Current backend scope is tvm-ffi only (CUDA + tvm-ffi path).
  • Grouped compile helper and integration are included in:
    • tilelang/autotuner/grouped_compile.py
    • tilelang/autotuner/tuner.py
    • related argument/plumbing updates in tilelang/autotuner/param.py and tilelang/engine/lower.py

2) Pipeline mode via tuner runtime refactor

  • Refactored AutoTuner.run() logic in tilelang/autotuner/tuner.py to support pipeline execution flow.
  • The refactor reorganizes compile/benchmark orchestration so pipeline and non-pipeline modes share a clearer execution structure.

3) Multi-GPU benchmark support

  • Added benchmark worker sharding support across multiple GPUs.
  • Supports explicit benchmark device selection.

4) Example entrypoint support

  • Updated examples/gemm/example_gemm_autotune.py main entry to accept and forward:
    • --pipeline / --no-pipeline
    • --grouped-compile / --no-grouped-compile
    • --group-compile-size
    • --benchmark-multi-gpu / --no-benchmark-multi-gpu
    • --benchmark-devices (repeatable)

Summary by CodeRabbit

  • New Features

    • Expanded GEMM autotune controls: choose execution backend, warmup/rep/timeout, skip correctness check, cache inputs, top-k search hint, grouped compile options, pipeline toggle, group size, per-device benchmark list, and multi-GPU benchmarking.
  • Refactor

    • Autotuner reworked for pipelined, concurrent elaborate/compile/benchmark execution with grouped compilation and per-worker caching.
    • Added host/device lowering helper to improve lowering and device codegen flow.

@github-actions
Copy link
Copy Markdown

github-actions Bot commented May 7, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 7, 2026

Review Change Stack

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

This PR extracts host/device lowering into a helper, adds grouped CUDA compilation, rewrites the autotuner for concurrent compile+benchmark pipelines (optional grouped compile and multi‑GPU), forwards execution_backend through compile args, and updates the gemm example CLI to expose the new flags.

Changes

Autotuner grouped/parallel tuning

Layer / File(s) Summary
IR Shape / Lowering
tilelang/engine/lower.py
Add lower_to_host_device_ir to canonicalize targets, run pre-lower checks, lowering/legalize/optimize, and split the lowered IR into host and device modules; lower() now delegates to it.
Compile args wiring
tilelang/autotuner/param.py
CompileArgs.compile_program() forwards execution_backend into tilelang.compile.
Grouped Compilation
tilelang/autotuner/grouped_compile.py
New compile_grouped_unit_tvm_ffi elaborates configs, lowers under a shared PassContext, merges device functions, runs device codegen once, then produces per-config host runtimes and JIT kernels; per-config failures are recorded without aborting the group.
Autotuner Orchestration
tilelang/autotuner/tuner.py
AutoTuner.run rewritten to support pipeline benchmarking, grouped compilation, grouped compile sizing, multi-GPU device resolution, thread-pooled compilation (grouped or per-config), and benchmark worker threads consuming per-device queues; introduces _BenchmarkWorkerState, jit_elaborate, and multiple helper methods.
Example CLI / Usage
examples/gemm/example_gemm_autotune.py
get_best_config, _build_autotuner, and main accept/forward new flags (execution_backend, warmup/rep/timeout, skip_check, cache_input_tensors, topk, pipeline/grouped-compile, group size, multi-GPU, benchmark devices); CLI parsing adds mutually exclusive flags and repeatable device args.

Sequence Diagram(s)

sequenceDiagram
    participant Main as Main Thread
    participant CompExec as Compilation Pool
    participant BenchWorkers as Benchmark Workers
    participant DeviceQueues as Per-Device Queues

    Main->>CompExec: schedule compilation tasks (grouped or per-config)
    CompExec->>CompExec: elaborate configs -> lower (shared PassContext)
    alt grouped
        CompExec->>CompExec: merge device funcs -> device codegen once
    end
    CompExec->>DeviceQueues: enqueue compiled kernels
    Main->>BenchWorkers: launch worker threads (per-device or shared)
    BenchWorkers->>DeviceQueues: consume kernel tasks
    BenchWorkers->>BenchWorkers: execute kernel, cache inputs, measure latency, optional ref-check
    BenchWorkers->>Main: report latency & config
    Main->>Main: track and select best result
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • tile-ai/tilelang#861: Both PRs modify tilelang/autotuner/tuner.py and change AutoTuner internals (autotuning orchestration and parameter handling).

Suggested reviewers

  • LeiWang1999

"I'm a rabbit with a tiny shell,
I hop through kernels, weave and dwell.
Grouped compiles hum, workers drum along,
I nibble latencies and sing this song.
Fastest config found — hop, hop, hooray!"

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 12.50% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and concisely summarizes the three main features added: pipeline, grouped compilation, and multi-GPU benchmark support for the autotuner.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Tip

💬 Introducing Slack Agent: The best way for teams to turn conversations into code.

Slack Agent is built on CodeRabbit's deep understanding of your code, so your team can collaborate across the entire SDLC without losing context.

  • Generate code and open pull requests
  • Plan features and break down work
  • Investigate incidents and troubleshoot customer tickets together
  • Automate recurring tasks and respond to alerts with triggers
  • Summarize progress and report instantly

Built for teams:

  • Shared memory across your entire org—no repeating context
  • Per-thread sandboxes to safely plan and execute work
  • Governance built-in—scoped access, auditability, and budget controls

One agent for your entire SDLC. Right inside Slack.

👉 Get started


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🧹 Nitpick comments (6)
tilelang/autotuner/tuner.py (5)

460-476: 💤 Low value

Grouped-compile gate uses str() on already-string execution_backend.

self.compile_args.execution_backend is typed as a Literal[str] and resolved via resolve_execution_backend(...) in set_compile_args, so str(...) here is a no-op. Minor cleanup; the gate logic itself looks correct.

♻️ Suggested cleanup
-        execution_backend = str(self.compile_args.execution_backend)
+        execution_backend = self.compile_args.execution_backend
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/autotuner/tuner.py` around lines 460 - 476, The code in
_resolve_grouped_compile_mode unnecessarily calls str() on
self.compile_args.execution_backend (already a Literal[str]); remove the
redundant str() conversion and use execution_backend =
self.compile_args.execution_backend directly (or keep a local variable named
execution_backend) so the gate logic remains unchanged; update any related
logging or returned tuple to reference the execution_backend variable instead of
the str(...) call.

768-795: ⚡ Quick win

Document and validate the new run() parameters.

run() now takes five new tuning knobs but only group_compile_size is validated (> 0 at line 794). A few small additions would help users:

  • Validate benchmark_devices entries are non-negative ints (currently any negative value is silently filtered out by _resolve_benchmark_devices).
  • Document that timeout is best-effort post-completion (see related comment) for the worker-thread benchmark path.
  • Mention explicitly that use_pipeline=True means benchmarks may begin before all compilations complete (i.e., earlier configs benchmark on a less-loaded GPU than later ones), which can bias measurements.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/autotuner/tuner.py` around lines 768 - 795, The run() signature adds
new knobs but lacks validation and doc clarity: add validation in run() to
ensure each entry in benchmark_devices (or provided list) is an int >= 0 (or
raise ValueError) and ensure incompatible combinations (e.g.,
benchmark_multi_gpu=True requires benchmark_devices non-empty) are checked;
update the docstring of run() to explicitly state that timeout is best-effort
after worker completion (may not abort already finished runs) and that
use_pipeline=True means benchmarking can start before all compilations finish
(earlier configs may see less-loaded GPUs, introducing bias); reference the
helper _resolve_benchmark_devices in the validation note so the implementation
uses or mirrors that logic.

517-536: ⚡ Quick win

cuda_device_wrapper captures the submitting thread's current device, not the worker's.

get_compile_func/get_elaborate_func are called inside compile_unit on the worker thread, which is fine — but torch.cuda.current_device() is read inside the wrapper on the worker thread the first time it runs, then torch.cuda.set_device(device) is invoked per call. If the user has set a non-default current device on the main thread before calling run(), the workers will inherit the default device (since current_device() is per-thread), not the main thread's selection. If the intent is to compile on the user's chosen device, capture torch.cuda.current_device() on the main thread once (in _prepare_compile_execution) and pass it down to the wrapper.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/autotuner/tuner.py` around lines 517 - 536, The cuda_device_wrapper
currently captures torch.cuda.current_device() on the worker thread; change the
flow so the device is captured once on the main thread in
_prepare_compile_execution and then passed into
get_compile_func/get_elaborate_func so cuda_device_wrapper uses that captured
device instead of calling torch.cuda.current_device() inside the worker;
specifically, update _prepare_compile_execution to read device =
torch.cuda.current_device() (if cuda available) and thread-safely pass that
device into get_compile_func/get_elaborate_func which should call
cuda_device_wrapper(compile_func, device) / cuda_device_wrapper(elaborate_func,
device) so each worker calls torch.cuda.set_device(device) using the
main-thread-chosen device before invoking compile_func/elaborate_func (ensure
compile_unit still calls get_compile_func/get_elaborate_func as before).

597-610: ⚡ Quick win

Benchmark timeout is post-hoc and cannot interrupt a hung kernel.

benchmark_target(...) is invoked inline; time.perf_counter() is read only after it returns, so a kernel that hangs (e.g., infinite GPU-side wait) keeps the worker thread blocked indefinitely and never reaches the elapsed > timeout check. The warning at lines 946-949 acknowledges signal-based timeouts can't be used here, but this means the timeout argument is effectively a post-completion filter, not a deadline. Worth reflecting that in the docstring of run() (and ideally short-circuiting the result_queue.put for already-overrun configs without raising) so users don't expect deadline enforcement when they set benchmark_multi_gpu=True or use_pipeline=True.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/autotuner/tuner.py` around lines 597 - 610, The timeout check in
run() is currently post-hoc because benchmark_target(...) runs inline and cannot
be interrupted; update the run() docstring to explicitly state that the timeout
parameter is a post-completion filter (it cannot abort a hung kernel when
benchmark_multi_gpu or use_pipeline is used) and then change the post-benchmark
logic in run(): after calling benchmark_target(...) compute elapsed and if
timeout>0 and elapsed>timeout, do not treat the result as a successful
measurement — put a timeout result into result_queue.put((idx, config,
jit_kernel, None, None, "timeout", "")) (instead of the success tuple) and avoid
raising TimeoutException there; reference the run() function, benchmark_target,
result_queue.put and TimeoutException so you modify both the docstring and the
post-benchmark result handling accordingly.

1019-1043: 💤 Low value

Compile-loop blind except Exception swallows non-compilation errors.

The except Exception as e at line 1029 wraps future.result(), which can raise both compilation exceptions raised inside compile_unit and genuine programming errors (e.g., a TypeError from _prepare_compile_execution plumbing). Both paths currently log at DEBUG and continue, which can hide real bugs during development. Consider logging at WARNING for non-Exception or unexpected types, or at minimum elevating the log level so the user sees that a unit failed without re-running with a debug logger.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/autotuner/tuner.py` around lines 1019 - 1043, The current broad
except around future.result() in the compile loop swallows real programming
errors from functions like compile_unit and _prepare_compile_execution; change
this to (1) catch the specific compilation exception type if one exists (e.g.,
CompilationError) and handle it as before, and (2) for the general except
Exception as e branch, escalate the log to logger.warning and include
exc_info=True (and the unit_indexes) so the traceback is visible; if you cannot
identify a specific compilation exception type, at minimum log with
logger.warning(..., exc_info=True) and then re-raise unexpected exceptions to
avoid silently continuing on real bugs.
tilelang/autotuner/grouped_compile.py (1)

81-97: ⚖️ Poor tradeoff

Add validation that all device modules have identical attrs before merging.

Currently, merged_attrs is taken from lowered_items[0]["device_mod"].attrs without checking that all subsequent device modules have the same attrs. While all configs are lowered with identical PassContext and target (making identical attrs likely), this assumption is implicit and unenforced. If future changes introduce per-config lowering variations or pass-context differences, attrs could diverge silently without notice. Either assert that all device_mod.attrs match the first, or explicitly document this invariant.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/autotuner/grouped_compile.py` around lines 81 - 97, Validate that
every device module's attrs equals the first module's attrs before building
merged_device_mod: capture the first attrs into merged_attrs from
lowered_items[0]["device_mod"].attrs, then in the loop over lowered_items
compare each device_mod.attrs to merged_attrs and raise a RuntimeError
(including the config index/item['idx'] and differing attrs) if they differ;
keep merged_attrs, merged_funcs, merged_names logic and only construct
tvm.IRModule(merged_funcs, attrs=merged_attrs) after the validation passes.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tilelang/autotuner/grouped_compile.py`:
- Around line 50-53: The code assumes program.attrs["global_symbol"] exists and
will crash if attrs is None or missing the key; update the block around
elaborate_func(...) so you defensively obtain the symbol (e.g., check if
program.attrs is truthy and "global_symbol" in program.attrs, or wrap the access
in a try/except catching TypeError/KeyError) and fall back to a safe default or
raise a controlled per-config error that the caller can handle; then construct
unique_symbol (used with with_attr("global_symbol", unique_symbol)) from that
safe value (for example a generated "gc_{idx}" when missing) so a single bad
config doesn't abort the whole grouping.

In `@tilelang/autotuner/tuner.py`:
- Around line 977-980: In multi-GPU mode ref_latency is recomputed per worker
because each worker creates a fresh _BenchmarkWorkerState and hits the
ref_latency_cache is None branch in _benchmark_target, leading
_process_benchmark_result to overwrite ref_latency nondeterministically; to fix,
compute the reference latency once (on the coordinator or the first worker) and
either (a) pre-seed each worker's _BenchmarkWorkerState.ref_latency_cache with
that single value before workers start so _benchmark_target skips re-running
ref_prog, or (b) make _process_benchmark_result only set the global ref_latency
if it is not already set (or if the result comes from the chosen coordinator
device) so the recorded ref_latency is pinned to a single chosen device when
benchmark_multi_gpu=True; modify the coordinator/worker setup code that
instantiates _BenchmarkWorkerState and the logic in
_benchmark_target/_process_benchmark_result to implement one of these
approaches.
- Around line 919-923: The current writeback of main_thread_benchmark_state into
self.jit_input_tensors/ref_input_tensors/ref_latency_cache is a no-op in
multi-GPU mode because workers use fresh _BenchmarkWorkerState instances; update
the code to either (A) maintain a per-device cache on self (e.g.,
self._per_device_jit_input_tensors, self._per_device_ref_input_tensors,
self._per_device_ref_latency_cache keyed by device id) and have each worker
merge its _BenchmarkWorkerState into the per-device entry, or (B) if
benchmark_multi_gpu is True skip the global writeback of
main_thread_benchmark_state into self.* and document that cross-run caching is
per-device only; make the change where main_thread_benchmark_state is created
and where it is currently written back (references: main_thread_benchmark_state,
_BenchmarkWorkerState, benchmark_multi_gpu, self.jit_input_tensors,
ref_input_tensors, ref_latency_cache).

---

Nitpick comments:
In `@tilelang/autotuner/grouped_compile.py`:
- Around line 81-97: Validate that every device module's attrs equals the first
module's attrs before building merged_device_mod: capture the first attrs into
merged_attrs from lowered_items[0]["device_mod"].attrs, then in the loop over
lowered_items compare each device_mod.attrs to merged_attrs and raise a
RuntimeError (including the config index/item['idx'] and differing attrs) if
they differ; keep merged_attrs, merged_funcs, merged_names logic and only
construct tvm.IRModule(merged_funcs, attrs=merged_attrs) after the validation
passes.

In `@tilelang/autotuner/tuner.py`:
- Around line 460-476: The code in _resolve_grouped_compile_mode unnecessarily
calls str() on self.compile_args.execution_backend (already a Literal[str]);
remove the redundant str() conversion and use execution_backend =
self.compile_args.execution_backend directly (or keep a local variable named
execution_backend) so the gate logic remains unchanged; update any related
logging or returned tuple to reference the execution_backend variable instead of
the str(...) call.
- Around line 768-795: The run() signature adds new knobs but lacks validation
and doc clarity: add validation in run() to ensure each entry in
benchmark_devices (or provided list) is an int >= 0 (or raise ValueError) and
ensure incompatible combinations (e.g., benchmark_multi_gpu=True requires
benchmark_devices non-empty) are checked; update the docstring of run() to
explicitly state that timeout is best-effort after worker completion (may not
abort already finished runs) and that use_pipeline=True means benchmarking can
start before all compilations finish (earlier configs may see less-loaded GPUs,
introducing bias); reference the helper _resolve_benchmark_devices in the
validation note so the implementation uses or mirrors that logic.
- Around line 517-536: The cuda_device_wrapper currently captures
torch.cuda.current_device() on the worker thread; change the flow so the device
is captured once on the main thread in _prepare_compile_execution and then
passed into get_compile_func/get_elaborate_func so cuda_device_wrapper uses that
captured device instead of calling torch.cuda.current_device() inside the
worker; specifically, update _prepare_compile_execution to read device =
torch.cuda.current_device() (if cuda available) and thread-safely pass that
device into get_compile_func/get_elaborate_func which should call
cuda_device_wrapper(compile_func, device) / cuda_device_wrapper(elaborate_func,
device) so each worker calls torch.cuda.set_device(device) using the
main-thread-chosen device before invoking compile_func/elaborate_func (ensure
compile_unit still calls get_compile_func/get_elaborate_func as before).
- Around line 597-610: The timeout check in run() is currently post-hoc because
benchmark_target(...) runs inline and cannot be interrupted; update the run()
docstring to explicitly state that the timeout parameter is a post-completion
filter (it cannot abort a hung kernel when benchmark_multi_gpu or use_pipeline
is used) and then change the post-benchmark logic in run(): after calling
benchmark_target(...) compute elapsed and if timeout>0 and elapsed>timeout, do
not treat the result as a successful measurement — put a timeout result into
result_queue.put((idx, config, jit_kernel, None, None, "timeout", "")) (instead
of the success tuple) and avoid raising TimeoutException there; reference the
run() function, benchmark_target, result_queue.put and TimeoutException so you
modify both the docstring and the post-benchmark result handling accordingly.
- Around line 1019-1043: The current broad except around future.result() in the
compile loop swallows real programming errors from functions like compile_unit
and _prepare_compile_execution; change this to (1) catch the specific
compilation exception type if one exists (e.g., CompilationError) and handle it
as before, and (2) for the general except Exception as e branch, escalate the
log to logger.warning and include exc_info=True (and the unit_indexes) so the
traceback is visible; if you cannot identify a specific compilation exception
type, at minimum log with logger.warning(..., exc_info=True) and then re-raise
unexpected exceptions to avoid silently continuing on real bugs.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 8054724d-b0b9-413b-9946-39efbc1f76e3

📥 Commits

Reviewing files that changed from the base of the PR and between c2cabd9 and cb34fb0.

📒 Files selected for processing (5)
  • examples/gemm/example_gemm_autotune.py
  • tilelang/autotuner/grouped_compile.py
  • tilelang/autotuner/param.py
  • tilelang/autotuner/tuner.py
  • tilelang/engine/lower.py

Comment on lines +50 to +53
program = elaborate_func(**config_arg)
original_symbol = str(program.attrs["global_symbol"])
unique_symbol = f"{original_symbol}_gc_{idx}"
program = program.with_attr("global_symbol", unique_symbol)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Defensive access to global_symbol attribute.

program.attrs["global_symbol"] will raise if attrs is None or missing the key. Although tilelang elaborated PrimFuncs typically carry it, the failure here aborts the whole group rather than just this config — guarding it ensures the per-config error path captures it cleanly.

🛡️ Suggested defensive lookup
-            original_symbol = str(program.attrs["global_symbol"])
+            global_symbol_attr = program.attrs.get("global_symbol") if program.attrs else None
+            if global_symbol_attr is None:
+                raise RuntimeError(
+                    "Grouped compilation requires PrimFunc to carry the 'global_symbol' attribute"
+                )
+            original_symbol = str(global_symbol_attr)
             unique_symbol = f"{original_symbol}_gc_{idx}"
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/autotuner/grouped_compile.py` around lines 50 - 53, The code assumes
program.attrs["global_symbol"] exists and will crash if attrs is None or missing
the key; update the block around elaborate_func(...) so you defensively obtain
the symbol (e.g., check if program.attrs is truthy and "global_symbol" in
program.attrs, or wrap the access in a try/except catching TypeError/KeyError)
and fall back to a safe default or raise a controlled per-config error that the
caller can handle; then construct unique_symbol (used with
with_attr("global_symbol", unique_symbol)) from that safe value (for example a
generated "gc_{idx}" when missing) so a single bad config doesn't abort the
whole grouping.

Comment on lines +919 to +923
main_thread_benchmark_state = _BenchmarkWorkerState(
jit_input_tensors=self.jit_input_tensors,
ref_input_tensors=self.ref_input_tensors,
ref_latency_cache=self.ref_latency_cache,
)
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot May 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Input-tensor cache propagation back to self is a no-op in multi-GPU mode.

main_thread_benchmark_state is initialized from self.jit_input_tensors etc. and then read back into self at lines 1065-1067, but in multi-GPU mode (line 995) workers use fresh _BenchmarkWorkerState() instances, so main_thread_benchmark_state is never mutated. The result is that self.jit_input_tensors/ref_input_tensors/ref_latency_cache retain stale or empty values across repeated run() calls when benchmark_multi_gpu=True, defeating the cross-run caching that single-device mode provides. Either keep a per-device map on self or skip the writeback when multi-GPU is active and document the trade-off.

Also applies to: 1065-1067

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/autotuner/tuner.py` around lines 919 - 923, The current writeback of
main_thread_benchmark_state into
self.jit_input_tensors/ref_input_tensors/ref_latency_cache is a no-op in
multi-GPU mode because workers use fresh _BenchmarkWorkerState instances; update
the code to either (A) maintain a per-device cache on self (e.g.,
self._per_device_jit_input_tensors, self._per_device_ref_input_tensors,
self._per_device_ref_latency_cache keyed by device id) and have each worker
merge its _BenchmarkWorkerState into the per-device entry, or (B) if
benchmark_multi_gpu is True skip the global writeback of
main_thread_benchmark_state into self.* and document that cross-run caching is
per-device only; make the change where main_thread_benchmark_state is created
and where it is currently written back (references: main_thread_benchmark_state,
_BenchmarkWorkerState, benchmark_multi_gpu, self.jit_input_tensors,
ref_input_tensors, ref_latency_cache).

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It serves for code simplicity, and there is no significant overhead in adding this line.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Wazrrr, understood — since the no-op writeback in multi-GPU mode doesn't affect correctness or introduce measurable overhead, keeping it for simplicity is a reasonable trade-off. Thanks for the clarification!

Comment on lines +977 to +980
if worker_ref_latency is not None:
ref_latency = worker_ref_latency
assert latency is not None
_record_benchmark_result(latency=latency, config=config, jit_kernel=jit_kernel, idx=idx, progress_bar=progress_bar)
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot May 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Multi-GPU benchmark: ref_latency is recomputed per worker, and the recorded value is non-deterministic.

In multi-GPU mode each worker thread is given a fresh _BenchmarkWorkerState() (line 995), so every worker hits the ref_latency_cache is None branch in _benchmark_target and re-runs ref_prog on its own device. Then in _process_benchmark_result (lines 977-978) ref_latency is overwritten with whichever worker's value arrives last, so the final reported ref_latency depends on completion ordering across heterogeneous-perf GPUs. Two improvements worth considering:

  1. Compute ref_latency once on the coordinator (or first worker) and pre-seed each _BenchmarkWorkerState.ref_latency_cache so subsequent workers skip the redundant ref benchmark.
  2. Pin the reported ref_latency to a single chosen device, or document explicitly that the reported ref latency is from an arbitrary GPU when benchmark_multi_gpu=True.

Also applies to: 994-1011

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/autotuner/tuner.py` around lines 977 - 980, In multi-GPU mode
ref_latency is recomputed per worker because each worker creates a fresh
_BenchmarkWorkerState and hits the ref_latency_cache is None branch in
_benchmark_target, leading _process_benchmark_result to overwrite ref_latency
nondeterministically; to fix, compute the reference latency once (on the
coordinator or the first worker) and either (a) pre-seed each worker's
_BenchmarkWorkerState.ref_latency_cache with that single value before workers
start so _benchmark_target skips re-running ref_prog, or (b) make
_process_benchmark_result only set the global ref_latency if it is not already
set (or if the result comes from the chosen coordinator device) so the recorded
ref_latency is pinned to a single chosen device when benchmark_multi_gpu=True;
modify the coordinator/worker setup code that instantiates _BenchmarkWorkerState
and the logic in _benchmark_target/_process_benchmark_result to implement one of
these approaches.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Multi-GPU benchmark only works for homogeneous GPUs, it's not a problem.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Wazrrr, understood — since multi-GPU benchmarking is scoped to homogeneous GPU setups, the ref_latency values computed independently on each worker will be equivalent, making the overwrite in _process_benchmark_result benign. Thanks for the clarification!


✏️ Learnings added
Learnt from: Wazrrr
Repo: tile-ai/tilelang PR: 2159
File: tilelang/autotuner/tuner.py:977-980
Timestamp: 2026-05-07T06:18:06.278Z
Learning: In `tilelang/autotuner/tuner.py`, the multi-GPU benchmark feature (`benchmark_multi_gpu`) is intentionally designed for homogeneous GPU setups only. Each benchmark worker independently computes `ref_latency` via a fresh `_BenchmarkWorkerState`, and the final `ref_latency` is overwritten by whichever worker result arrives last. This is acceptable because all GPUs are homogeneous, so ref_latency values across workers are equivalent. Do not flag this as a non-determinism issue.

Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.

@SiriusNEO SiriusNEO self-requested a review May 7, 2026 04:22
@Wazrrr
Copy link
Copy Markdown
Author

Wazrrr commented May 7, 2026

Experiments

Common setup for individual feature validation

All individual feature tests (pipeline-only, grouped-compilation-only, multi-GPU-benchmark-only) use the same common settings:

--shape 4096x4096x4096
--cpu-count 2
--runs 1
--with-roller
--topk 20
--warmup 3
--rep 20
--timeout 180
--skip-check
--disable-cache
--no-cache-input-tensors

These settings are intentionally lightweight for fast functional validation, and are directly extensible to larger --cpu-count, larger config spaces, and more shapes/runs.

1) Pipeline only

  • Figure: pipeline_only_end_to_end_s
  • NoPipe: 79.258s
  • Pipe: 69.742s
  • End-to-end speedup: ~1.14x (79.258 / 69.742)
  • Note: --cpu-count 2 is small, so this run is mainly to demonstrate pipeline functionality and runtime behavior. With tuned worker counts and larger search
    spaces, gains are expected to be higher.

2) Multi-GPU benchmark

  • Figure: multi_gpu_benchmark_benchmark_s
  • 1-GPU benchmark time: 10.655s
  • MultiGPU benchmark time: 5.368s
  • Benchmark-stage speedup: ~1.98x (10.655 / 5.368)
  • Since this setup uses 2 GPUs for benchmarking, the benchmark stage is close to the expected ~2x reduction.

3) Grouped compilation

  • Figures: grouped_compilation_compilation_sgrouped_compilation_benchmark_sgrouped_compilation_end_to_end_s
  • G* means grouped compilation size:
    • G1: no grouping baseline
    • G2: group size 2
    • G4: group size 4

Observed behavior:

  • Compilation time drops strongly with larger group size:
    • G1: 69.720s
    • G2: 37.768s
    • G4: 26.259s
  • Benchmark time increases moderately with larger group size:
    • G1: 10.649s
    • G2: 12.954s
    • G4: 17.290s
  • This is expected: grouped compilation reduces compile work, while benchmark-side dispatch/coordination overhead grows somewhat.
  • End-to-end time still improves significantly:
    • G1: 80.373s
    • G2: 50.726s
    • G4: 43.553s
    • Speedups vs G1: ~1.58x (G2), ~1.85x (G4)

4) All-features combined

  • Figure: all_features_end_to_end_s
  • Baseline: 222.513s
  • AllFeat (pipeline + grouped compile(size=4) + 2-GPU benchmark): 138.225s
  • End-to-end speedup: ~1.61x

Final common experiment setting:

 --shape 4096x4096x4096
 --cpu-count 32
 --runs 1
 --without-roller
 --warmup 3
 --rep 20
 --timeout 180
 --skip-check
 --disable-cache
 --no-cache-input-tensors
 --execution-backend auto

Overall, the combined result is already strong, and further parameter tuning (CPU workers, config scale, shape mix, and benchmark device strategy) should provide additional gains.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (6)
tilelang/autotuner/tuner.py (5)

855-857: 💤 Low value

Reassigning self.jit_compile / self.jit_elaborate mutates instance state on every run().

_ensure_jit_functions returns either user-provided callables or the defaults, then both are written back to self. On a second run() of the same tuner where the user later cleared self.jit_compile, the prior default is now sticky. Minor but possibly surprising for re-use; consider keeping the resolved callables purely local to run().

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/autotuner/tuner.py` around lines 855 - 857, The current run()
mutates instance state by assigning the resolved callables from
_ensure_jit_functions() back to self.jit_compile and self.jit_elaborate, making
defaults “sticky”; instead, call _ensure_jit_functions() and use the returned
compile_func and elaborate_func as local variables within run() (do not assign
them back to self), so subsequent runs respect changes to
self.jit_compile/self.jit_elaborate made by the caller.

959-963: ⚖️ Poor tradeoff

Static modulo assignment is not load-balanced across benchmark workers.

Tasks are distributed by idx % len(benchmark_task_queues), so worker 0 always handles configs 0, N, 2N, … regardless of how long each config takes. Heterogeneous compile completion order combined with variable benchmark duration will leave faster workers idle while a slower one is still draining its private queue.

A single shared SimpleQueue consumed by all workers (work-stealing) would let any free worker pick up the next ready kernel and would also simplify shutdown sentinels. Worth doing if benchmark wall-time matters more than per-device locality.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/autotuner/tuner.py` around lines 959 - 963, The current
_enqueue_benchmark_task uses static assignment idx % len(benchmark_task_queues)
which causes poor load balancing; replace the per-worker queues with a single
shared queue (e.g., a multiprocessing.SimpleQueue or queue.SimpleQueue depending
on context) and have all benchmark workers consume from that shared queue;
update producers to put (jit_kernel, config, idx) into the shared queue instead
of benchmark_task_queues[...], remove or adapt any per-queue shutdown sentinels
to a single global sentinel, and keep benchmark_expected_results increments the
same so result accounting remains correct.

612-703: 💤 Low value

_benchmark_target correctness check uses possibly-stale jit_input_tensors_cache after recompute.

When cache_input_tensors=True and a shape/dtype mismatch is detected (lines 664-676), jit_input_tensors_cache is regenerated and the loop breaks — good. However, on the first iteration where jit_input_tensors_cache is None (line 651-652), params was already fetched but never compared against the freshly-generated tensors; that's fine because there's nothing to compare to. The flow looks correct.

One nit: at line 700, benchmark_state.ref_input_tensors = ref_input_tensors_cache always writes back, even when ref_prog is None (in which case ref_input_tensors_cache is whatever it was on entry). Harmless but slightly wasteful. Not blocking.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/autotuner/tuner.py` around lines 612 - 703, The function
_benchmark_target currently always writes back benchmark_state.ref_input_tensors
and benchmark_state.ref_latency_cache even when ref_prog is None; change the
final write-back so that benchmark_state.ref_input_tensors and
benchmark_state.ref_latency_cache are only updated when ref_prog is not None
(leave benchmark_state.jit_input_tensors update as-is), locating these symbols
in the _benchmark_target function and gating the assignments to
ref_input_tensors_cache/ref_latency_cache behind a check of ref_prog.

770-797: 💤 Low value

Validate other public-API parameters too.

group_compile_size correctly raises on <= 0 (line 796-797), but warmup, rep, timeout accept any int and benchmark_devices accepts arbitrary ints (parsed in _resolve_benchmark_devices). Not a blocker — invalid values fail later with less obvious messages — but a small early validation block here would improve UX. partial(autotuner.run, self.warmup, self.rep, self.timeout) at line 1160 still resolves correctly with the new keyword-only additions.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/autotuner/tuner.py` around lines 770 - 797, The run method currently
only validates group_compile_size; add early validation in run to check warmup
is >= 0, rep and timeout are > 0 and all three are integers, and validate
benchmark_devices (if not None) is an iterable of non-negative ints (no negative
ordinals) before calling _resolve_benchmark_devices; keep existing
group_compile_size check and don’t change the function signature so
partial(autotuner.run, self.warmup, self.rep, self.timeout) still works—raise
ValueError with clear messages for each invalid parameter.

460-476: 💤 Low value

Silent fallback when grouped compilation isn't supported.

grouped_compile_reason is logged at INFO and then the third return value (grouped_compile_reason) is unused at the call site (line 874 destructures _, _, grouped_compile_active, _). For a feature the user explicitly opted into via --grouped-compile, surfacing the rejection at WARNING (or echoing it in the progress description / final summary) makes the fallback discoverable when users wonder why their --group-compile-size 4 had no effect.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/autotuner/tuner.py` around lines 460 - 476, The grouped-compile
fallback message is only logged at INFO and its return value is being ignored by
the caller, so users don't see why their --grouped-compile was disabled; change
the logger call in _resolve_grouped_compile_mode to logger.warning(...) when
grouped_compile_requested and not active, and update the caller that currently
destructures as "_, _, grouped_compile_active, _" to capture the fourth return
(grouped_compile_reason) and surface it into the progress description or final
summary so the warning is visible to users.
tilelang/autotuner/grouped_compile.py (1)

80-96: ⚖️ Poor tradeoff

Group-wide failure on merge errors is coarse-grained.

A single duplicate name_hint (line 91-93) or any failure in merged_device_mod construction or shared device_codegen (line 100) is caught by the broad except at line 151, which then attributes the error to every successfully lowered config in the group. That makes diagnostics misleading (one bad config taints siblings) and prevents partial progress for the group.

If feasible, detect duplicate name_hint early enough to mark only the colliding idx as failed, and only fall back to broadcasting the error if codegen of the merged module truly cannot proceed without the offending function.

Also applies to: 151-153

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/autotuner/grouped_compile.py` around lines 80 - 96, The merge
currently treats any duplicate name_hint or any failure during merged IRModule
construction / device_codegen as a group-wide error; instead, detect and isolate
failing configs: build a map of name_hint -> list of item['idx'] from
device_mod.functions and if any name_hint maps to multiple idxs mark only those
colliding idxs as failed (record their indices and skip adding their funcs to
merged_funcs), then attempt to construct merged_device_mod from the remaining
funcs; wrap tvm.IRModule(...) in a try/except and if it fails, iterate remaining
items and try constructing per-item IRModule or constructing merged IRModule
while removing one item at a time to identify and mark specific failing
item['idx'] values, then rebuild merged_device_mod without those; similarly, if
device_codegen(merged_device_mod) raises, run device_codegen on each surviving
single-item module to localize and mark only the failing configs instead of
blaming the whole group (refer to merged_funcs, merged_names, merged_device_mod,
device_mod.functions, lowered_items, and item['idx'] to locate code).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tilelang/autotuner/tuner.py`:
- Around line 596-610: Worker threads can hang when benchmark_target blocks and
then deadlock on worker_thread.join; set the worker threads to daemon=True and
make the worker loop enforce a hard timeout around benchmark_target so a blocked
benchmark is recorded as a timeout: in the worker creation code that spawns the
thread (the place that later calls worker_thread.join), mark the Thread as
daemon=True; inside the worker loop that currently calls
benchmark_target(jit_kernel=..., benchmark_state=worker_state) wrap that call in
a short-lived sub-thread or use a timed wait (e.g. start a Thread for
benchmark_target and join it with the configured timeout) and on timeout put
(idx, config, jit_kernel, None, None, "timeout", "") into result_queue; ensure
any exception path still puts the error tuple into result_queue; also bound
queue.get/result_queue.get waits in _drain_benchmark_results to avoid blocking
forever so the orchestrator can continue.

---

Nitpick comments:
In `@tilelang/autotuner/grouped_compile.py`:
- Around line 80-96: The merge currently treats any duplicate name_hint or any
failure during merged IRModule construction / device_codegen as a group-wide
error; instead, detect and isolate failing configs: build a map of name_hint ->
list of item['idx'] from device_mod.functions and if any name_hint maps to
multiple idxs mark only those colliding idxs as failed (record their indices and
skip adding their funcs to merged_funcs), then attempt to construct
merged_device_mod from the remaining funcs; wrap tvm.IRModule(...) in a
try/except and if it fails, iterate remaining items and try constructing
per-item IRModule or constructing merged IRModule while removing one item at a
time to identify and mark specific failing item['idx'] values, then rebuild
merged_device_mod without those; similarly, if device_codegen(merged_device_mod)
raises, run device_codegen on each surviving single-item module to localize and
mark only the failing configs instead of blaming the whole group (refer to
merged_funcs, merged_names, merged_device_mod, device_mod.functions,
lowered_items, and item['idx'] to locate code).

In `@tilelang/autotuner/tuner.py`:
- Around line 855-857: The current run() mutates instance state by assigning the
resolved callables from _ensure_jit_functions() back to self.jit_compile and
self.jit_elaborate, making defaults “sticky”; instead, call
_ensure_jit_functions() and use the returned compile_func and elaborate_func as
local variables within run() (do not assign them back to self), so subsequent
runs respect changes to self.jit_compile/self.jit_elaborate made by the caller.
- Around line 959-963: The current _enqueue_benchmark_task uses static
assignment idx % len(benchmark_task_queues) which causes poor load balancing;
replace the per-worker queues with a single shared queue (e.g., a
multiprocessing.SimpleQueue or queue.SimpleQueue depending on context) and have
all benchmark workers consume from that shared queue; update producers to put
(jit_kernel, config, idx) into the shared queue instead of
benchmark_task_queues[...], remove or adapt any per-queue shutdown sentinels to
a single global sentinel, and keep benchmark_expected_results increments the
same so result accounting remains correct.
- Around line 612-703: The function _benchmark_target currently always writes
back benchmark_state.ref_input_tensors and benchmark_state.ref_latency_cache
even when ref_prog is None; change the final write-back so that
benchmark_state.ref_input_tensors and benchmark_state.ref_latency_cache are only
updated when ref_prog is not None (leave benchmark_state.jit_input_tensors
update as-is), locating these symbols in the _benchmark_target function and
gating the assignments to ref_input_tensors_cache/ref_latency_cache behind a
check of ref_prog.
- Around line 770-797: The run method currently only validates
group_compile_size; add early validation in run to check warmup is >= 0, rep and
timeout are > 0 and all three are integers, and validate benchmark_devices (if
not None) is an iterable of non-negative ints (no negative ordinals) before
calling _resolve_benchmark_devices; keep existing group_compile_size check and
don’t change the function signature so partial(autotuner.run, self.warmup,
self.rep, self.timeout) still works—raise ValueError with clear messages for
each invalid parameter.
- Around line 460-476: The grouped-compile fallback message is only logged at
INFO and its return value is being ignored by the caller, so users don't see why
their --grouped-compile was disabled; change the logger call in
_resolve_grouped_compile_mode to logger.warning(...) when
grouped_compile_requested and not active, and update the caller that currently
destructures as "_, _, grouped_compile_active, _" to capture the fourth return
(grouped_compile_reason) and surface it into the progress description or final
summary so the warning is visible to users.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f01d791b-c952-42c4-bfc7-3d8590f5dd27

📥 Commits

Reviewing files that changed from the base of the PR and between cb34fb0 and f3bda82.

📒 Files selected for processing (4)
  • examples/gemm/example_gemm_autotune.py
  • tilelang/autotuner/grouped_compile.py
  • tilelang/autotuner/tuner.py
  • tilelang/engine/lower.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/gemm/example_gemm_autotune.py

Comment thread tilelang/autotuner/tuner.py
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tilelang/autotuner/tuner.py`:
- Around line 600-637: The nested function _run_benchmark_target closes over
loop-local jit_kernel and call_result_queue by reference, so if a previous
daemon benchmark_call_thread times out and the loop rebinds those names the
still-running thread will use the new values and corrupt results; fix by binding
the current values into the thread callback so they are captured by value (e.g.,
create local vars like _jit_kernel = jit_kernel and _call_result_queue =
call_result_queue or pass them as default arguments to _run_benchmark_target)
and then use those bound names when calling benchmark_target and putting
results, ensuring the thread always uses the intended kernel and queue for that
iteration before starting benchmark_call_thread and joining; keep the existing
TimeoutException and result_queue handling (idx, config) unchanged.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 12b8aa86-c52f-46c4-a707-5ba48dca429e

📥 Commits

Reviewing files that changed from the base of the PR and between f3bda82 and 8d4c0fc.

📒 Files selected for processing (1)
  • tilelang/autotuner/tuner.py

Comment thread tilelang/autotuner/tuner.py
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (1)
tilelang/autotuner/tuner.py (1)

1114-1116: ⚡ Quick win

pool.shutdown() does not cancel pending compile futures on early exit.

On the normal completion path this is fine (all futures already drained). But on an exception or KeyboardInterrupt, any not-yet-started compilation tasks remain queued and pool.shutdown() blocks until each runs to completion. Since this codebase supports Python 3.9+, cancel_futures=True is available and should be passed to shed the pending work immediately.

♻️ Proposed fix
-            pool.shutdown()
+            pool.shutdown(cancel_futures=True)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/autotuner/tuner.py` around lines 1114 - 1116, The current shutdown
sequence closes compile_progress and progress_bar then calls pool.shutdown()
which will block waiting for queued compile futures to run on early exit; update
the shutdown call in tuner.py to call pool.shutdown(cancel_futures=True) so
pending compile futures are cancelled immediately (use this argument in the
exception/KeyboardInterrupt/early-exit cleanup path where compile_progress,
progress_bar are closed) to avoid blocking; reference the existing variables
compile_progress, progress_bar and the ThreadPool/ProcessPool instance pool when
making this change.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tilelang/autotuner/tuner.py`:
- Around line 707-712: The current shape_equal helper silently accepts
different-rank tensors because it zips shapes; change shape_equal (used where p
and c are compared before do_bench) to first require equal ranks (e.g., if
len(a.shape) != len(b.shape) return False) and only then compare each dimension
allowing Vars (keep the existing isinstance(..., Var) logic). Update the
function used in the if condition (shape_equal) so mismatched ranks fail fast
and prevent stale tensors from reaching do_bench.
- Around line 735-744: The call to profiler.do_bench at the top of the block is
passing iteration-count variables warmup and rep using the wrong parameter names
(warmup=, rep=) causing them to be interpreted as millisecond targets; update
the profiler.do_bench invocation that assigns latency (the one using
jit_input_tensors_cache and backend) to pass these as n_warmup=warmup and
n_repeat=rep to match the later ref_prog call and the intended semantics.

---

Nitpick comments:
In `@tilelang/autotuner/tuner.py`:
- Around line 1114-1116: The current shutdown sequence closes compile_progress
and progress_bar then calls pool.shutdown() which will block waiting for queued
compile futures to run on early exit; update the shutdown call in tuner.py to
call pool.shutdown(cancel_futures=True) so pending compile futures are cancelled
immediately (use this argument in the exception/KeyboardInterrupt/early-exit
cleanup path where compile_progress, progress_bar are closed) to avoid blocking;
reference the existing variables compile_progress, progress_bar and the
ThreadPool/ProcessPool instance pool when making this change.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c74dd60d-b354-4112-be79-52e31158806f

📥 Commits

Reviewing files that changed from the base of the PR and between 8d4c0fc and cd35151.

📒 Files selected for processing (1)
  • tilelang/autotuner/tuner.py

Comment thread tilelang/autotuner/tuner.py
Comment thread tilelang/autotuner/tuner.py Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

♻️ Duplicate comments (1)
tilelang/autotuner/tuner.py (1)

738-748: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Use the same do_bench() parameter mode for kernel and reference timings.

At line 738, warmup and rep (iteration counts) are passed to warmup= and rep= parameters, which are treated as millisecond targets with auto-calculated iteration counts. At lines 742–748, the same variables are passed to n_warmup= and n_repeat= parameters, which bypass auto-calculation and use them as explicit iteration counts. This inconsistency causes the kernel and reference timings to use different benchmarking semantics.

Expected fix
-        latency = profiler.do_bench(warmup=warmup, rep=rep, input_tensors=jit_input_tensors_cache, backend=backend)
+        latency = profiler.do_bench(
+            n_warmup=warmup,
+            n_repeat=rep,
+            input_tensors=jit_input_tensors_cache,
+            backend=backend,
+        )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/autotuner/tuner.py` around lines 738 - 748, The kernel and reference
benchmarking calls use different do_bench parameter modes: the first call passes
warmup= and rep= (time-target mode) while the reference call passes n_warmup=
and n_repeat= (explicit-iteration mode), causing inconsistent semantics; update
the reference call inside the ref_latency_cache block (the profiler.do_bench
invocation that currently uses ref_prog, n_warmup, n_repeat,
input_tensors=ref_input_tensors_cache, backend=backend) to use the same
parameter names as the kernel timing (warmup= and rep=) so both measurements use
the same benchmarking mode and inputs supplied by ref_input_tensors_supply.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tilelang/autotuner/tuner.py`:
- Around line 372-377: The current supply_prog closure freezes a single list of
captured tensors (frozen_inputs) and returns the same objects to every worker,
causing cross-device CUDA tensor reuse; instead, change supply_prog (the lambda
assigned where get_autotune_inputs() is used) to produce a per-device cache:
when first called for a given device id, materialize/move/clone the frozen
inputs onto that device and store them in a dict keyed by device, then return
the device-specific list on subsequent calls. Use the existing
get_autotune_inputs(), frozen_inputs, and supply_prog symbols to implement a
small per-device memoization (e.g., device_cache = {}) and ensure the
supply_prog signature reads the device argument and returns
device_cache[device].
- Around line 785-794: The current logic sets requested_devices to single_device
when benchmark_devices is None and CUDA_VISIBLE_DEVICES is unset, which prevents
multi-GPU benchmarking (benchmark_multi_gpu) from using all GPUs; change the
fallback to detect the total available device count and set requested_devices to
list(range(device_count)) (using the framework's GPU count API, e.g.,
torch.cuda.device_count() or equivalent) instead of single_device, while
preserving the existing behavior when CUDA_VISIBLE_DEVICES is present or
benchmark_devices is provided; adjust references in this block that assign
requested_devices, and ensure single_device remains available as a last-resort
fallback if device_count == 0.
- Around line 607-628: The per-call thread currently reuses the shared
_BenchmarkWorkerState (worker_state) causing races when a thread times out; to
fix, create a per-call deep copy of worker_state (e.g.,
copy.deepcopy(worker_state)) and pass that copy into _run_benchmark_target
(rename param to _worker_state_copy) so the daemon thread only mutates its
private state; when you read the call_result_queue and see a successful ("ok",
...) result, merge the relevant mutable pieces (cached tensors,
ref_latency_cache, or other worker-state caches) from the per-call copy back
into the original worker_state in a well-defined merge step; ensure
timeout/error paths keep the original worker_state unchanged and only the "ok"
branch performs the merge, leaving variable names _run_benchmark_target,
benchmark_target, _BenchmarkWorkerState, call_result_queue, and result_queue
intact for easy location.

---

Duplicate comments:
In `@tilelang/autotuner/tuner.py`:
- Around line 738-748: The kernel and reference benchmarking calls use different
do_bench parameter modes: the first call passes warmup= and rep= (time-target
mode) while the reference call passes n_warmup= and n_repeat=
(explicit-iteration mode), causing inconsistent semantics; update the reference
call inside the ref_latency_cache block (the profiler.do_bench invocation that
currently uses ref_prog, n_warmup, n_repeat,
input_tensors=ref_input_tensors_cache, backend=backend) to use the same
parameter names as the kernel timing (warmup= and rep=) so both measurements use
the same benchmarking mode and inputs supplied by ref_input_tensors_supply.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: be1be661-be78-477b-9049-e1fd6cb007f2

📥 Commits

Reviewing files that changed from the base of the PR and between cd35151 and a2dc696.

📒 Files selected for processing (1)
  • tilelang/autotuner/tuner.py

Comment on lines +372 to +377
captured_inputs = get_autotune_inputs()
if captured_inputs is not None:
if supply_prog is not None:
logger.warning("`supply_prog` will be ignored as this program is under `with set_autotune_inputs` context.")
supply_prog = lambda _: get_autotune_inputs() # noqa: E731
frozen_inputs = list(captured_inputs)
supply_prog = lambda _, _frozen_inputs=frozen_inputs: _frozen_inputs # noqa: E731
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Freeze autotune inputs per device, not once globally.

This lambda returns the exact captured tensor objects to every benchmark worker. In multi-GPU mode, CUDA tensors captured on one device get reused on the other workers, which can either fail with device mismatches or benchmark the wrong GPU. Materialize/cache per-device inputs instead of closing over one shared list.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/autotuner/tuner.py` around lines 372 - 377, The current supply_prog
closure freezes a single list of captured tensors (frozen_inputs) and returns
the same objects to every worker, causing cross-device CUDA tensor reuse;
instead, change supply_prog (the lambda assigned where get_autotune_inputs() is
used) to produce a per-device cache: when first called for a given device id,
materialize/move/clone the frozen inputs onto that device and store them in a
dict keyed by device, then return the device-specific list on subsequent calls.
Use the existing get_autotune_inputs(), frozen_inputs, and supply_prog symbols
to implement a small per-device memoization (e.g., device_cache = {}) and ensure
the supply_prog signature reads the device argument and returns
device_cache[device].

Comment on lines +607 to +628
def _run_benchmark_target(
_jit_kernel: tilelang.JITKernel = jit_kernel,
_worker_state: _BenchmarkWorkerState = worker_state,
_call_result_queue: queue.SimpleQueue = call_result_queue,
):
try:
latency, worker_ref_latency = benchmark_target(
jit_kernel=_jit_kernel,
benchmark_state=_worker_state,
)
_call_result_queue.put(("ok", latency, worker_ref_latency, ""))
except TimeoutException:
_call_result_queue.put(("timeout", None, None, ""))
except Exception:
_call_result_queue.put(("error", None, None, traceback.format_exc()))

benchmark_call_thread = threading.Thread(target=_run_benchmark_target, daemon=True)
benchmark_call_thread.start()
benchmark_call_thread.join(timeout=timeout)
if benchmark_call_thread.is_alive():
result_queue.put((idx, config, jit_kernel, None, None, "timeout", ""))
continue
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

A timed-out benchmark can corrupt the next task on the same worker.

When join(timeout=...) expires, the daemon sub-thread keeps running with the same worker_state object that the next config will reuse. If that late call eventually updates cached tensors or ref_latency_cache, it races with the next benchmark and can skew later results. Pass a per-call copy of _BenchmarkWorkerState into _run_benchmark_target and merge it back only after a successful completion.

Minimal containment fix
                     call_result_queue: queue.SimpleQueue = queue.SimpleQueue()
+                    call_state = _BenchmarkWorkerState(
+                        jit_input_tensors=worker_state.jit_input_tensors,
+                        ref_input_tensors=worker_state.ref_input_tensors,
+                        ref_latency_cache=worker_state.ref_latency_cache,
+                    )

                     def _run_benchmark_target(
                         _jit_kernel: tilelang.JITKernel = jit_kernel,
-                        _worker_state: _BenchmarkWorkerState = worker_state,
+                        _worker_state: _BenchmarkWorkerState = call_state,
                         _call_result_queue: queue.SimpleQueue = call_result_queue,
                     ):
                         try:
@@
                     if status == "ok":
+                        worker_state.jit_input_tensors = call_state.jit_input_tensors
+                        worker_state.ref_input_tensors = call_state.ref_input_tensors
+                        worker_state.ref_latency_cache = call_state.ref_latency_cache
                         result_queue.put((idx, config, jit_kernel, latency, worker_ref_latency, None, ""))
🧰 Tools
🪛 Ruff (0.15.12)

[warning] 620-620: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/autotuner/tuner.py` around lines 607 - 628, The per-call thread
currently reuses the shared _BenchmarkWorkerState (worker_state) causing races
when a thread times out; to fix, create a per-call deep copy of worker_state
(e.g., copy.deepcopy(worker_state)) and pass that copy into
_run_benchmark_target (rename param to _worker_state_copy) so the daemon thread
only mutates its private state; when you read the call_result_queue and see a
successful ("ok", ...) result, merge the relevant mutable pieces (cached
tensors, ref_latency_cache, or other worker-state caches) from the per-call copy
back into the original worker_state in a well-defined merge step; ensure
timeout/error paths keep the original worker_state unchanged and only the "ok"
branch performs the merge, leaving variable names _run_benchmark_target,
benchmark_target, _BenchmarkWorkerState, call_result_queue, and result_queue
intact for easy location.

Comment on lines +785 to +794
requested_devices: list[int] = []
if benchmark_devices:
requested_devices = list(dict.fromkeys(int(device) for device in benchmark_devices))
else:
raw_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
parsed_visible_devices = [token.strip() for token in raw_visible_devices.split(",") if token.strip()]
if parsed_visible_devices:
requested_devices = list(range(len(parsed_visible_devices)))
else:
requested_devices = single_device
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

benchmark_multi_gpu=True still resolves to one GPU on the common path.

If callers do not pass benchmark_devices and CUDA_VISIBLE_DEVICES is unset, this branch falls back to single_device, so a multi-GPU host never enables sharding by default. Using all visible devices here would make the flag behave as advertised.

Suggested change
         else:
             raw_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
             parsed_visible_devices = [token.strip() for token in raw_visible_devices.split(",") if token.strip()]
             if parsed_visible_devices:
                 requested_devices = list(range(len(parsed_visible_devices)))
             else:
-                requested_devices = single_device
+                requested_devices = list(range(visible_device_count))
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/autotuner/tuner.py` around lines 785 - 794, The current logic sets
requested_devices to single_device when benchmark_devices is None and
CUDA_VISIBLE_DEVICES is unset, which prevents multi-GPU benchmarking
(benchmark_multi_gpu) from using all GPUs; change the fallback to detect the
total available device count and set requested_devices to
list(range(device_count)) (using the framework's GPU count API, e.g.,
torch.cuda.device_count() or equivalent) instead of single_device, while
preserving the existing behavior when CUDA_VISIBLE_DEVICES is present or
benchmark_devices is provided; adjust references in this block that assign
requested_devices, and ensure single_device remains available as a last-resort
fallback if device_count == 0.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant