[Autotune] Add pipeline, grouped compilation, and multi-GPU benchmark support #2159
Wazrrr wants to merge 8 commits into tile-ai:main from
Conversation
📝 Walkthrough
This PR extracts host/device lowering into a helper, adds grouped CUDA compilation, rewrites the autotuner for concurrent compile+benchmark pipelines (optional grouped compile and multi-GPU), forwards execution_backend through compile args, and updates the gemm example CLI to expose the new flags.
Changes: Autotuner grouped/parallel tuning
Sequence Diagram(s)
sequenceDiagram
participant Main as Main Thread
participant CompExec as Compilation Pool
participant BenchWorkers as Benchmark Workers
participant DeviceQueues as Per-Device Queues
Main->>CompExec: schedule compilation tasks (grouped or per-config)
CompExec->>CompExec: elaborate configs -> lower (shared PassContext)
alt grouped
CompExec->>CompExec: merge device funcs -> device codegen once
end
CompExec->>DeviceQueues: enqueue compiled kernels
Main->>BenchWorkers: launch worker threads (per-device or shared)
BenchWorkers->>DeviceQueues: consume kernel tasks
BenchWorkers->>BenchWorkers: execute kernel, cache inputs, measure latency, optional ref-check
BenchWorkers->>Main: report latency & config
Main->>Main: track and select best result
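The pipelined flow in the diagram above can be pictured with a small Python sketch. This is illustrative only: run_pipeline, compile_config, and benchmark_kernel are hypothetical stand-ins for the tuner's internals, not the PR's actual API.

```python
# Hypothetical sketch of the compile -> benchmark pipeline; not the PR's real code.
import queue
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed

def run_pipeline(configs, devices, compile_config, benchmark_kernel):
    task_queues = {dev: queue.SimpleQueue() for dev in devices}
    results = queue.SimpleQueue()

    def benchmark_worker(dev):
        while True:
            item = task_queues[dev].get()
            if item is None:                # shutdown sentinel
                return
            idx, cfg, kernel = item
            try:
                results.put((idx, cfg, benchmark_kernel(kernel, dev)))
            except Exception:
                results.put((idx, cfg, None))

    workers = [threading.Thread(target=benchmark_worker, args=(d,), daemon=True)
               for d in devices]
    for w in workers:
        w.start()

    with ThreadPoolExecutor() as pool:
        futures = {pool.submit(compile_config, cfg): (i, cfg)
                   for i, cfg in enumerate(configs)}
        for fut in as_completed(futures):
            i, cfg = futures[fut]
            try:
                kernel = fut.result()
            except Exception:
                results.put((i, cfg, None))  # record failed compilation
                continue
            # Hand the kernel to a device queue as soon as it compiles, so
            # benchmarking overlaps the remaining compilations.
            task_queues[devices[i % len(devices)]].put((i, cfg, kernel))

    for dev in devices:
        task_queues[dev].put(None)
    for w in workers:
        w.join()
    return [results.get() for _ in range(len(configs))]
```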
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 3
🧹 Nitpick comments (6)
tilelang/autotuner/tuner.py (5)
460-476: 💤 Low value
Grouped-compile gate uses str() on already-string execution_backend.
self.compile_args.execution_backend is typed as a Literal[str] and resolved via resolve_execution_backend(...) in set_compile_args, so str(...) here is a no-op. Minor cleanup; the gate logic itself looks correct.
♻️ Suggested cleanup
- execution_backend = str(self.compile_args.execution_backend)
+ execution_backend = self.compile_args.execution_backend
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/autotuner/tuner.py` around lines 460 - 476, The code in _resolve_grouped_compile_mode unnecessarily calls str() on self.compile_args.execution_backend (already a Literal[str]); remove the redundant str() conversion and use execution_backend = self.compile_args.execution_backend directly (or keep a local variable named execution_backend) so the gate logic remains unchanged; update any related logging or returned tuple to reference the execution_backend variable instead of the str(...) call.
768-795: ⚡ Quick win
Document and validate the new run() parameters.
run() now takes five new tuning knobs but only group_compile_size is validated (> 0 at line 794). A few small additions would help users:
- Validate benchmark_devices entries are non-negative ints (currently any negative value is silently filtered out by _resolve_benchmark_devices).
- Document that timeout is best-effort post-completion (see related comment) for the worker-thread benchmark path.
- Mention explicitly that use_pipeline=True means benchmarks may begin before all compilations complete (i.e., earlier configs benchmark on a less-loaded GPU than later ones), which can bias measurements.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/autotuner/tuner.py` around lines 768 - 795, The run() signature adds new knobs but lacks validation and doc clarity: add validation in run() to ensure each entry in benchmark_devices (or provided list) is an int >= 0 (or raise ValueError) and ensure incompatible combinations (e.g., benchmark_multi_gpu=True requires benchmark_devices non-empty) are checked; update the docstring of run() to explicitly state that timeout is best-effort after worker completion (may not abort already finished runs) and that use_pipeline=True means benchmarking can start before all compilations finish (earlier configs may see less-loaded GPUs, introducing bias); reference the helper _resolve_benchmark_devices in the validation note so the implementation uses or mirrors that logic.
517-536: ⚡ Quick win
cuda_device_wrapper captures the worker thread's current device, not the submitting thread's.
get_compile_func / get_elaborate_func are called inside compile_unit on the worker thread, which is fine — but torch.cuda.current_device() is read inside the wrapper on the worker thread the first time it runs, then torch.cuda.set_device(device) is invoked per call. If the user has set a non-default current device on the main thread before calling run(), the workers will inherit the default device (since current_device() is per-thread), not the main thread's selection. If the intent is to compile on the user's chosen device, capture torch.cuda.current_device() on the main thread once (in _prepare_compile_execution) and pass it down to the wrapper.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/autotuner/tuner.py` around lines 517 - 536, The cuda_device_wrapper currently captures torch.cuda.current_device() on the worker thread; change the flow so the device is captured once on the main thread in _prepare_compile_execution and then passed into get_compile_func/get_elaborate_func so cuda_device_wrapper uses that captured device instead of calling torch.cuda.current_device() inside the worker; specifically, update _prepare_compile_execution to read device = torch.cuda.current_device() (if cuda available) and thread-safely pass that device into get_compile_func/get_elaborate_func which should call cuda_device_wrapper(compile_func, device) / cuda_device_wrapper(elaborate_func, device) so each worker calls torch.cuda.set_device(device) using the main-thread-chosen device before invoking compile_func/elaborate_func (ensure compile_unit still calls get_compile_func/get_elaborate_func as before).
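A minimal sketch of the suggested direction, assuming a hypothetical make_cuda_device_wrapper helper (the PR's actual wrapper and call sites may differ):

```python
# Hypothetical sketch: capture the device once on the main thread, then pin
# each worker-thread call to that device before compiling/elaborating.
import functools
import torch

def make_cuda_device_wrapper(func, device: int):
    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        if torch.cuda.is_available():
            # Worker threads have their own per-thread "current device",
            # so explicitly pin them to the device chosen on the main thread.
            torch.cuda.set_device(device)
        return func(*args, **kwargs)
    return wrapped

# On the main thread, e.g. in _prepare_compile_execution (illustrative):
# main_device = torch.cuda.current_device() if torch.cuda.is_available() else 0
# compile_func = make_cuda_device_wrapper(compile_func, main_device)
# elaborate_func = make_cuda_device_wrapper(elaborate_func, main_device)
```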
597-610: ⚡ Quick win
Benchmark timeout is post-hoc and cannot interrupt a hung kernel.
benchmark_target(...) is invoked inline; time.perf_counter() is read only after it returns, so a kernel that hangs (e.g., infinite GPU-side wait) keeps the worker thread blocked indefinitely and never reaches the elapsed > timeout check. The warning at lines 946-949 acknowledges signal-based timeouts can't be used here, but this means the timeout argument is effectively a post-completion filter, not a deadline. Worth reflecting that in the docstring of run() (and ideally short-circuiting the result_queue.put for already-overrun configs without raising) so users don't expect deadline enforcement when they set benchmark_multi_gpu=True or use_pipeline=True.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/autotuner/tuner.py` around lines 597 - 610, The timeout check in run() is currently post-hoc because benchmark_target(...) runs inline and cannot be interrupted; update the run() docstring to explicitly state that the timeout parameter is a post-completion filter (it cannot abort a hung kernel when benchmark_multi_gpu or use_pipeline is used) and then change the post-benchmark logic in run(): after calling benchmark_target(...) compute elapsed and if timeout>0 and elapsed>timeout, do not treat the result as a successful measurement — put a timeout result into result_queue.put((idx, config, jit_kernel, None, None, "timeout", "")) (instead of the success tuple) and avoid raising TimeoutException there; reference the run() function, benchmark_target, result_queue.put and TimeoutException so you modify both the docstring and the post-benchmark result handling accordingly.
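A sketch of treating timeout as a post-completion filter, roughly along the lines suggested above; the helper name and the seven-field result tuple follow the review prompt and are not necessarily the final code:

```python
# Illustrative only: mark an overrun benchmark as "timeout" after the fact
# instead of raising; a hung kernel still cannot be interrupted this way.
import time

def benchmark_with_post_hoc_timeout(benchmark_target, jit_kernel, worker_state,
                                    idx, config, timeout, result_queue):
    start = time.perf_counter()
    latency, ref_latency = benchmark_target(jit_kernel=jit_kernel,
                                            benchmark_state=worker_state)
    elapsed = time.perf_counter() - start
    if timeout and elapsed > timeout:
        # The kernel already finished; we only refuse to count it.
        result_queue.put((idx, config, jit_kernel, None, None, "timeout", ""))
    else:
        result_queue.put((idx, config, jit_kernel, latency, ref_latency, None, ""))
```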
1019-1043: 💤 Low value
Compile-loop blind except Exception swallows non-compilation errors.
The except Exception as e at line 1029 wraps future.result(), which can raise both compilation exceptions raised inside compile_unit and genuine programming errors (e.g., a TypeError from _prepare_compile_execution plumbing). Both paths currently log at DEBUG and continue, which can hide real bugs during development. Consider logging at WARNING for non-compilation or unexpected exception types, or at minimum elevating the log level so the user sees that a unit failed without re-running with a debug logger.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/autotuner/tuner.py` around lines 1019 - 1043, The current broad except around future.result() in the compile loop swallows real programming errors from functions like compile_unit and _prepare_compile_execution; change this to (1) catch the specific compilation exception type if one exists (e.g., CompilationError) and handle it as before, and (2) for the general except Exception as e branch, escalate the log to logger.warning and include exc_info=True (and the unit_indexes) so the traceback is visible; if you cannot identify a specific compilation exception type, at minimum log with logger.warning(..., exc_info=True) and then re-raise unexpected exceptions to avoid silently continuing on real bugs.
tilelang/autotuner/grouped_compile.py (1)
81-97: ⚖️ Poor tradeoff
Add validation that all device modules have identical attrs before merging.
Currently, merged_attrs is taken from lowered_items[0]["device_mod"].attrs without checking that all subsequent device modules have the same attrs. While all configs are lowered with identical PassContext and target (making identical attrs likely), this assumption is implicit and unenforced. If future changes introduce per-config lowering variations or pass-context differences, attrs could diverge silently without notice. Either assert that all device_mod.attrs match the first, or explicitly document this invariant.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/autotuner/grouped_compile.py` around lines 81 - 97, Validate that every device module's attrs equals the first module's attrs before building merged_device_mod: capture the first attrs into merged_attrs from lowered_items[0]["device_mod"].attrs, then in the loop over lowered_items compare each device_mod.attrs to merged_attrs and raise a RuntimeError (including the config index/item['idx'] and differing attrs) if they differ; keep merged_attrs, merged_funcs, merged_names logic and only construct tvm.IRModule(merged_funcs, attrs=merged_attrs) after the validation passes.
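A minimal sketch of that guard, assuming the lowered_items layout described above (a list of dicts with "device_mod" and "idx" keys); the merge_device_modules name and the string-based attrs comparison are illustrative:

```python
# Illustrative sketch of refusing to merge device modules whose attrs diverge.
import tvm

def merge_device_modules(lowered_items):
    merged_attrs = lowered_items[0]["device_mod"].attrs
    merged_funcs = {}
    for item in lowered_items:
        device_mod = item["device_mod"]
        # Crude comparison for illustration; a structural comparison of the
        # attrs dictionaries would also work when attrs is not None.
        if str(device_mod.attrs) != str(merged_attrs):
            raise RuntimeError(
                f"config {item['idx']}: device module attrs diverge from the group: "
                f"{device_mod.attrs} vs {merged_attrs}")
        for gvar, func in device_mod.functions.items():
            merged_funcs[gvar] = func
    return tvm.IRModule(merged_funcs, attrs=merged_attrs)
```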
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tilelang/autotuner/grouped_compile.py`:
- Around line 50-53: The code assumes program.attrs["global_symbol"] exists and
will crash if attrs is None or missing the key; update the block around
elaborate_func(...) so you defensively obtain the symbol (e.g., check if
program.attrs is truthy and "global_symbol" in program.attrs, or wrap the access
in a try/except catching TypeError/KeyError) and fall back to a safe default or
raise a controlled per-config error that the caller can handle; then construct
unique_symbol (used with with_attr("global_symbol", unique_symbol)) from that
safe value (for example a generated "gc_{idx}" when missing) so a single bad
config doesn't abort the whole grouping.
In `@tilelang/autotuner/tuner.py`:
- Around line 977-980: In multi-GPU mode ref_latency is recomputed per worker
because each worker creates a fresh _BenchmarkWorkerState and hits the
ref_latency_cache is None branch in _benchmark_target, leading
_process_benchmark_result to overwrite ref_latency nondeterministically; to fix,
compute the reference latency once (on the coordinator or the first worker) and
either (a) pre-seed each worker's _BenchmarkWorkerState.ref_latency_cache with
that single value before workers start so _benchmark_target skips re-running
ref_prog, or (b) make _process_benchmark_result only set the global ref_latency
if it is not already set (or if the result comes from the chosen coordinator
device) so the recorded ref_latency is pinned to a single chosen device when
benchmark_multi_gpu=True; modify the coordinator/worker setup code that
instantiates _BenchmarkWorkerState and the logic in
_benchmark_target/_process_benchmark_result to implement one of these
approaches.
- Around line 919-923: The current writeback of main_thread_benchmark_state into
self.jit_input_tensors/ref_input_tensors/ref_latency_cache is a no-op in
multi-GPU mode because workers use fresh _BenchmarkWorkerState instances; update
the code to either (A) maintain a per-device cache on self (e.g.,
self._per_device_jit_input_tensors, self._per_device_ref_input_tensors,
self._per_device_ref_latency_cache keyed by device id) and have each worker
merge its _BenchmarkWorkerState into the per-device entry, or (B) if
benchmark_multi_gpu is True skip the global writeback of
main_thread_benchmark_state into self.* and document that cross-run caching is
per-device only; make the change where main_thread_benchmark_state is created
and where it is currently written back (references: main_thread_benchmark_state,
_BenchmarkWorkerState, benchmark_multi_gpu, self.jit_input_tensors,
ref_input_tensors, ref_latency_cache).
---
Nitpick comments:
In `@tilelang/autotuner/grouped_compile.py`:
- Around line 81-97: Validate that every device module's attrs equals the first
module's attrs before building merged_device_mod: capture the first attrs into
merged_attrs from lowered_items[0]["device_mod"].attrs, then in the loop over
lowered_items compare each device_mod.attrs to merged_attrs and raise a
RuntimeError (including the config index/item['idx'] and differing attrs) if
they differ; keep merged_attrs, merged_funcs, merged_names logic and only
construct tvm.IRModule(merged_funcs, attrs=merged_attrs) after the validation
passes.
In `@tilelang/autotuner/tuner.py`:
- Around line 460-476: The code in _resolve_grouped_compile_mode unnecessarily
calls str() on self.compile_args.execution_backend (already a Literal[str]);
remove the redundant str() conversion and use execution_backend =
self.compile_args.execution_backend directly (or keep a local variable named
execution_backend) so the gate logic remains unchanged; update any related
logging or returned tuple to reference the execution_backend variable instead of
the str(...) call.
- Around line 768-795: The run() signature adds new knobs but lacks validation
and doc clarity: add validation in run() to ensure each entry in
benchmark_devices (or provided list) is an int >= 0 (or raise ValueError) and
ensure incompatible combinations (e.g., benchmark_multi_gpu=True requires
benchmark_devices non-empty) are checked; update the docstring of run() to
explicitly state that timeout is best-effort after worker completion (may not
abort already finished runs) and that use_pipeline=True means benchmarking can
start before all compilations finish (earlier configs may see less-loaded GPUs,
introducing bias); reference the helper _resolve_benchmark_devices in the
validation note so the implementation uses or mirrors that logic.
- Around line 517-536: The cuda_device_wrapper currently captures
torch.cuda.current_device() on the worker thread; change the flow so the device
is captured once on the main thread in _prepare_compile_execution and then
passed into get_compile_func/get_elaborate_func so cuda_device_wrapper uses that
captured device instead of calling torch.cuda.current_device() inside the
worker; specifically, update _prepare_compile_execution to read device =
torch.cuda.current_device() (if cuda available) and thread-safely pass that
device into get_compile_func/get_elaborate_func which should call
cuda_device_wrapper(compile_func, device) / cuda_device_wrapper(elaborate_func,
device) so each worker calls torch.cuda.set_device(device) using the
main-thread-chosen device before invoking compile_func/elaborate_func (ensure
compile_unit still calls get_compile_func/get_elaborate_func as before).
- Around line 597-610: The timeout check in run() is currently post-hoc because
benchmark_target(...) runs inline and cannot be interrupted; update the run()
docstring to explicitly state that the timeout parameter is a post-completion
filter (it cannot abort a hung kernel when benchmark_multi_gpu or use_pipeline
is used) and then change the post-benchmark logic in run(): after calling
benchmark_target(...) compute elapsed and if timeout>0 and elapsed>timeout, do
not treat the result as a successful measurement — put a timeout result into
result_queue.put((idx, config, jit_kernel, None, None, "timeout", "")) (instead
of the success tuple) and avoid raising TimeoutException there; reference the
run() function, benchmark_target, result_queue.put and TimeoutException so you
modify both the docstring and the post-benchmark result handling accordingly.
- Around line 1019-1043: The current broad except around future.result() in the
compile loop swallows real programming errors from functions like compile_unit
and _prepare_compile_execution; change this to (1) catch the specific
compilation exception type if one exists (e.g., CompilationError) and handle it
as before, and (2) for the general except Exception as e branch, escalate the
log to logger.warning and include exc_info=True (and the unit_indexes) so the
traceback is visible; if you cannot identify a specific compilation exception
type, at minimum log with logger.warning(..., exc_info=True) and then re-raise
unexpected exceptions to avoid silently continuing on real bugs.
📒 Files selected for processing (5)
- examples/gemm/example_gemm_autotune.py
- tilelang/autotuner/grouped_compile.py
- tilelang/autotuner/param.py
- tilelang/autotuner/tuner.py
- tilelang/engine/lower.py
program = elaborate_func(**config_arg)
original_symbol = str(program.attrs["global_symbol"])
unique_symbol = f"{original_symbol}_gc_{idx}"
program = program.with_attr("global_symbol", unique_symbol)
Defensive access to global_symbol attribute.
program.attrs["global_symbol"] will raise if attrs is None or missing the key. Although tilelang elaborated PrimFuncs typically carry it, the failure here aborts the whole group rather than just this config — guarding it ensures the per-config error path captures it cleanly.
🛡️ Suggested defensive lookup
- original_symbol = str(program.attrs["global_symbol"])
+ global_symbol_attr = program.attrs.get("global_symbol") if program.attrs else None
+ if global_symbol_attr is None:
+ raise RuntimeError(
+ "Grouped compilation requires PrimFunc to carry the 'global_symbol' attribute"
+ )
+ original_symbol = str(global_symbol_attr)
unique_symbol = f"{original_symbol}_gc_{idx}"
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tilelang/autotuner/grouped_compile.py` around lines 50 - 53, The code assumes
program.attrs["global_symbol"] exists and will crash if attrs is None or missing
the key; update the block around elaborate_func(...) so you defensively obtain
the symbol (e.g., check if program.attrs is truthy and "global_symbol" in
program.attrs, or wrap the access in a try/except catching TypeError/KeyError)
and fall back to a safe default or raise a controlled per-config error that the
caller can handle; then construct unique_symbol (used with
with_attr("global_symbol", unique_symbol)) from that safe value (for example a
generated "gc_{idx}" when missing) so a single bad config doesn't abort the
whole grouping.
main_thread_benchmark_state = _BenchmarkWorkerState(
    jit_input_tensors=self.jit_input_tensors,
    ref_input_tensors=self.ref_input_tensors,
    ref_latency_cache=self.ref_latency_cache,
)
Input-tensor cache propagation back to self is a no-op in multi-GPU mode.
main_thread_benchmark_state is initialized from self.jit_input_tensors etc. and then read back into self at lines 1065-1067, but in multi-GPU mode (line 995) workers use fresh _BenchmarkWorkerState() instances, so main_thread_benchmark_state is never mutated. The result is that self.jit_input_tensors/ref_input_tensors/ref_latency_cache retain stale or empty values across repeated run() calls when benchmark_multi_gpu=True, defeating the cross-run caching that single-device mode provides. Either keep a per-device map on self or skip the writeback when multi-GPU is active and document the trade-off.
Also applies to: 1065-1067
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tilelang/autotuner/tuner.py` around lines 919 - 923, The current writeback of
main_thread_benchmark_state into
self.jit_input_tensors/ref_input_tensors/ref_latency_cache is a no-op in
multi-GPU mode because workers use fresh _BenchmarkWorkerState instances; update
the code to either (A) maintain a per-device cache on self (e.g.,
self._per_device_jit_input_tensors, self._per_device_ref_input_tensors,
self._per_device_ref_latency_cache keyed by device id) and have each worker
merge its _BenchmarkWorkerState into the per-device entry, or (B) if
benchmark_multi_gpu is True skip the global writeback of
main_thread_benchmark_state into self.* and document that cross-run caching is
per-device only; make the change where main_thread_benchmark_state is created
and where it is currently written back (references: main_thread_benchmark_state,
_BenchmarkWorkerState, benchmark_multi_gpu, self.jit_input_tensors,
ref_input_tensors, ref_latency_cache).
It's kept for code simplicity, and there is no significant overhead in adding this line.
@Wazrrr, understood — since the no-op writeback in multi-GPU mode doesn't affect correctness or introduce measurable overhead, keeping it for simplicity is a reasonable trade-off. Thanks for the clarification!
if worker_ref_latency is not None:
    ref_latency = worker_ref_latency
assert latency is not None
_record_benchmark_result(latency=latency, config=config, jit_kernel=jit_kernel, idx=idx, progress_bar=progress_bar)
Multi-GPU benchmark: ref_latency is recomputed per worker, and the recorded value is non-deterministic.
In multi-GPU mode each worker thread is given a fresh _BenchmarkWorkerState() (line 995), so every worker hits the ref_latency_cache is None branch in _benchmark_target and re-runs ref_prog on its own device. Then in _process_benchmark_result (lines 977-978) ref_latency is overwritten with whichever worker's value arrives last, so the final reported ref_latency depends on completion ordering across heterogeneous-perf GPUs. Two improvements worth considering:
- Compute ref_latency once on the coordinator (or first worker) and pre-seed each _BenchmarkWorkerState.ref_latency_cache so subsequent workers skip the redundant ref benchmark.
- Pin the reported ref_latency to a single chosen device, or document explicitly that the reported ref latency is from an arbitrary GPU when benchmark_multi_gpu=True.
Also applies to: 994-1011
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tilelang/autotuner/tuner.py` around lines 977 - 980, In multi-GPU mode
ref_latency is recomputed per worker because each worker creates a fresh
_BenchmarkWorkerState and hits the ref_latency_cache is None branch in
_benchmark_target, leading _process_benchmark_result to overwrite ref_latency
nondeterministically; to fix, compute the reference latency once (on the
coordinator or the first worker) and either (a) pre-seed each worker's
_BenchmarkWorkerState.ref_latency_cache with that single value before workers
start so _benchmark_target skips re-running ref_prog, or (b) make
_process_benchmark_result only set the global ref_latency if it is not already
set (or if the result comes from the chosen coordinator device) so the recorded
ref_latency is pinned to a single chosen device when benchmark_multi_gpu=True;
modify the coordinator/worker setup code that instantiates _BenchmarkWorkerState
and the logic in _benchmark_target/_process_benchmark_result to implement one of
these approaches.
Multi-GPU benchmark only works for homogeneous GPUs, so it's not a problem.
@Wazrrr, understood — since multi-GPU benchmarking is scoped to homogeneous GPU setups, the ref_latency values computed independently on each worker will be equivalent, making the overwrite in _process_benchmark_result benign. Thanks for the clarification!
✏️ Learnings added
Learnt from: Wazrrr
Repo: tile-ai/tilelang PR: 2159
File: tilelang/autotuner/tuner.py:977-980
Timestamp: 2026-05-07T06:18:06.278Z
Learning: In `tilelang/autotuner/tuner.py`, the multi-GPU benchmark feature (`benchmark_multi_gpu`) is intentionally designed for homogeneous GPU setups only. Each benchmark worker independently computes `ref_latency` via a fresh `_BenchmarkWorkerState`, and the final `ref_latency` is overwritten by whichever worker result arrives last. This is acceptable because all GPUs are homogeneous, so ref_latency values across workers are equivalent. Do not flag this as a non-determinism issue.
Actionable comments posted: 1
🧹 Nitpick comments (6)
tilelang/autotuner/tuner.py (5)
855-857: 💤 Low value
Reassigning self.jit_compile / self.jit_elaborate mutates instance state on every run().
_ensure_jit_functions returns either user-provided callables or the defaults, then both are written back to self. On a second run() of the same tuner where the user later cleared self.jit_compile, the prior default is now sticky. Minor but possibly surprising for re-use; consider keeping the resolved callables purely local to run().
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/autotuner/tuner.py` around lines 855 - 857, The current run() mutates instance state by assigning the resolved callables from _ensure_jit_functions() back to self.jit_compile and self.jit_elaborate, making defaults “sticky”; instead, call _ensure_jit_functions() and use the returned compile_func and elaborate_func as local variables within run() (do not assign them back to self), so subsequent runs respect changes to self.jit_compile/self.jit_elaborate made by the caller.
959-963: ⚖️ Poor tradeoff
Static modulo assignment is not load-balanced across benchmark workers.
Tasks are distributed by idx % len(benchmark_task_queues), so worker 0 always handles configs 0, N, 2N, … regardless of how long each config takes. Heterogeneous compile completion order combined with variable benchmark duration will leave faster workers idle while a slower one is still draining its private queue.
A single shared SimpleQueue consumed by all workers (work-stealing) would let any free worker pick up the next ready kernel and would also simplify shutdown sentinels. Worth doing if benchmark wall-time matters more than per-device locality.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/autotuner/tuner.py` around lines 959 - 963, The current _enqueue_benchmark_task uses static assignment idx % len(benchmark_task_queues) which causes poor load balancing; replace the per-worker queues with a single shared queue (e.g., a multiprocessing.SimpleQueue or queue.SimpleQueue depending on context) and have all benchmark workers consume from that shared queue; update producers to put (jit_kernel, config, idx) into the shared queue instead of benchmark_task_queues[...], remove or adapt any per-queue shutdown sentinels to a single global sentinel, and keep benchmark_expected_results increments the same so result accounting remains correct.
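A sketch of the single shared-queue (work-stealing) alternative; start_benchmark_workers and benchmark_one are illustrative names, not the tuner's actual helpers:

```python
# Illustrative work-stealing layout: all workers drain one shared task queue.
import queue
import threading

def start_benchmark_workers(devices, benchmark_one, result_queue):
    shared_tasks = queue.SimpleQueue()

    def worker(device):
        while True:
            task = shared_tasks.get()
            if task is None:                 # one sentinel per worker shuts it down
                return
            idx, config, jit_kernel = task
            result_queue.put((idx, config, benchmark_one(jit_kernel, device)))

    threads = [threading.Thread(target=worker, args=(d,), daemon=True) for d in devices]
    for t in threads:
        t.start()
    return shared_tasks, threads

# Producers put (idx, config, jit_kernel) into shared_tasks as kernels finish
# compiling, and finally put one None per worker thread to shut the pool down.
```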
612-703: 💤 Low value
_benchmark_target correctness check uses possibly-stale jit_input_tensors_cache after recompute.
When cache_input_tensors=True and a shape/dtype mismatch is detected (lines 664-676), jit_input_tensors_cache is regenerated and the loop breaks — good. However, on the first iteration where jit_input_tensors_cache is None (lines 651-652), params was already fetched but never compared against the freshly-generated tensors; that's fine because there's nothing to compare to. The flow looks correct.
One nit: at line 700, benchmark_state.ref_input_tensors = ref_input_tensors_cache always writes back, even when ref_prog is None (in which case ref_input_tensors_cache is whatever it was on entry). Harmless but slightly wasteful. Not blocking.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/autotuner/tuner.py` around lines 612 - 703, The function _benchmark_target currently always writes back benchmark_state.ref_input_tensors and benchmark_state.ref_latency_cache even when ref_prog is None; change the final write-back so that benchmark_state.ref_input_tensors and benchmark_state.ref_latency_cache are only updated when ref_prog is not None (leave benchmark_state.jit_input_tensors update as-is), locating these symbols in the _benchmark_target function and gating the assignments to ref_input_tensors_cache/ref_latency_cache behind a check of ref_prog.
770-797: 💤 Low value
Validate other public-API parameters too.
group_compile_size correctly raises on <= 0 (lines 796-797), but warmup, rep, timeout accept any int and benchmark_devices accepts arbitrary ints (parsed in _resolve_benchmark_devices). Not a blocker — invalid values fail later with less obvious messages — but a small early validation block here would improve UX.
partial(autotuner.run, self.warmup, self.rep, self.timeout) at line 1160 still resolves correctly with the new keyword-only additions.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/autotuner/tuner.py` around lines 770 - 797, The run method currently only validates group_compile_size; add early validation in run to check warmup is >= 0, rep and timeout are > 0 and all three are integers, and validate benchmark_devices (if not None) is an iterable of non-negative ints (no negative ordinals) before calling _resolve_benchmark_devices; keep existing group_compile_size check and don’t change the function signature so partial(autotuner.run, self.warmup, self.rep, self.timeout) still works—raise ValueError with clear messages for each invalid parameter.
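A sketch of such an early-validation block; the _validate_run_args name is hypothetical and the exact bounds are assumptions about reasonable constraints:

```python
# Illustrative early validation for run()'s public knobs.
def _validate_run_args(warmup, rep, timeout, group_compile_size, benchmark_devices):
    if not isinstance(warmup, int) or warmup < 0:
        raise ValueError(f"warmup must be a non-negative int, got {warmup!r}")
    if not isinstance(rep, int) or rep <= 0:
        raise ValueError(f"rep must be a positive int, got {rep!r}")
    if not isinstance(timeout, int) or timeout <= 0:
        raise ValueError(f"timeout must be a positive int, got {timeout!r}")
    if group_compile_size <= 0:
        raise ValueError(f"group_compile_size must be > 0, got {group_compile_size!r}")
    if benchmark_devices is not None:
        for device in benchmark_devices:
            if not isinstance(device, int) or device < 0:
                raise ValueError(
                    f"benchmark_devices entries must be non-negative ints, got {device!r}")
```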
460-476: 💤 Low value
Silent fallback when grouped compilation isn't supported.
grouped_compile_reason is logged at INFO and then the third return value (grouped_compile_reason) is unused at the call site (line 874 destructures _, _, grouped_compile_active, _). For a feature the user explicitly opted into via --grouped-compile, surfacing the rejection at WARNING (or echoing it in the progress description / final summary) makes the fallback discoverable when users wonder why their --group-compile-size 4 had no effect.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/autotuner/tuner.py` around lines 460 - 476, The grouped-compile fallback message is only logged at INFO and its return value is being ignored by the caller, so users don't see why their --grouped-compile was disabled; change the logger call in _resolve_grouped_compile_mode to logger.warning(...) when grouped_compile_requested and not active, and update the caller that currently destructures as "_, _, grouped_compile_active, _" to capture the fourth return (grouped_compile_reason) and surface it into the progress description or final summary so the warning is visible to users.
tilelang/autotuner/grouped_compile.py (1)
80-96: ⚖️ Poor tradeoff
Group-wide failure on merge errors is coarse-grained.
A single duplicate name_hint (lines 91-93) or any failure in merged_device_mod construction or shared device_codegen (line 100) is caught by the broad except at line 151, which then attributes the error to every successfully lowered config in the group. That makes diagnostics misleading (one bad config taints siblings) and prevents partial progress for the group.
If feasible, detect duplicate name_hint early enough to mark only the colliding idx as failed, and only fall back to broadcasting the error if codegen of the merged module truly cannot proceed without the offending function.
Also applies to: 151-153
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/autotuner/grouped_compile.py` around lines 80 - 96, The merge currently treats any duplicate name_hint or any failure during merged IRModule construction / device_codegen as a group-wide error; instead, detect and isolate failing configs: build a map of name_hint -> list of item['idx'] from device_mod.functions and if any name_hint maps to multiple idxs mark only those colliding idxs as failed (record their indices and skip adding their funcs to merged_funcs), then attempt to construct merged_device_mod from the remaining funcs; wrap tvm.IRModule(...) in a try/except and if it fails, iterate remaining items and try constructing per-item IRModule or constructing merged IRModule while removing one item at a time to identify and mark specific failing item['idx'] values, then rebuild merged_device_mod without those; similarly, if device_codegen(merged_device_mod) raises, run device_codegen on each surviving single-item module to localize and mark only the failing configs instead of blaming the whole group (refer to merged_funcs, merged_names, merged_device_mod, device_mod.functions, lowered_items, and item['idx'] to locate code).
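A sketch of isolating name_hint collisions to the offending configs before merging; split_colliding_configs is an illustrative helper name and the lowered_items shape follows the review's description:

```python
# Illustrative: find configs whose device-function names collide so only they
# are marked as failed, and the rest of the group can still be merged.
from collections import defaultdict

def split_colliding_configs(lowered_items):
    owners = defaultdict(list)            # name_hint -> [config idx, ...]
    for item in lowered_items:
        for gvar in item["device_mod"].functions:
            owners[gvar.name_hint].append(item["idx"])
    colliding = {idx for idxs in owners.values() if len(idxs) > 1 for idx in idxs}
    mergeable = [item for item in lowered_items if item["idx"] not in colliding]
    return mergeable, colliding
```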
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tilelang/autotuner/tuner.py`:
- Around line 596-610: Worker threads can hang when benchmark_target blocks and
then deadlock on worker_thread.join; set the worker threads to daemon=True and
make the worker loop enforce a hard timeout around benchmark_target so a blocked
benchmark is recorded as a timeout: in the worker creation code that spawns the
thread (the place that later calls worker_thread.join), mark the Thread as
daemon=True; inside the worker loop that currently calls
benchmark_target(jit_kernel=..., benchmark_state=worker_state) wrap that call in
a short-lived sub-thread or use a timed wait (e.g. start a Thread for
benchmark_target and join it with the configured timeout) and on timeout put
(idx, config, jit_kernel, None, None, "timeout", "") into result_queue; ensure
any exception path still puts the error tuple into result_queue; also bound
queue.get/result_queue.get waits in _drain_benchmark_results to avoid blocking
forever so the orchestrator can continue.
---
Nitpick comments:
In `@tilelang/autotuner/grouped_compile.py`:
- Around line 80-96: The merge currently treats any duplicate name_hint or any
failure during merged IRModule construction / device_codegen as a group-wide
error; instead, detect and isolate failing configs: build a map of name_hint ->
list of item['idx'] from device_mod.functions and if any name_hint maps to
multiple idxs mark only those colliding idxs as failed (record their indices and
skip adding their funcs to merged_funcs), then attempt to construct
merged_device_mod from the remaining funcs; wrap tvm.IRModule(...) in a
try/except and if it fails, iterate remaining items and try constructing
per-item IRModule or constructing merged IRModule while removing one item at a
time to identify and mark specific failing item['idx'] values, then rebuild
merged_device_mod without those; similarly, if device_codegen(merged_device_mod)
raises, run device_codegen on each surviving single-item module to localize and
mark only the failing configs instead of blaming the whole group (refer to
merged_funcs, merged_names, merged_device_mod, device_mod.functions,
lowered_items, and item['idx'] to locate code).
In `@tilelang/autotuner/tuner.py`:
- Around line 855-857: The current run() mutates instance state by assigning the
resolved callables from _ensure_jit_functions() back to self.jit_compile and
self.jit_elaborate, making defaults “sticky”; instead, call
_ensure_jit_functions() and use the returned compile_func and elaborate_func as
local variables within run() (do not assign them back to self), so subsequent
runs respect changes to self.jit_compile/self.jit_elaborate made by the caller.
- Around line 959-963: The current _enqueue_benchmark_task uses static
assignment idx % len(benchmark_task_queues) which causes poor load balancing;
replace the per-worker queues with a single shared queue (e.g., a
multiprocessing.SimpleQueue or queue.SimpleQueue depending on context) and have
all benchmark workers consume from that shared queue; update producers to put
(jit_kernel, config, idx) into the shared queue instead of
benchmark_task_queues[...], remove or adapt any per-queue shutdown sentinels to
a single global sentinel, and keep benchmark_expected_results increments the
same so result accounting remains correct.
- Around line 612-703: The function _benchmark_target currently always writes
back benchmark_state.ref_input_tensors and benchmark_state.ref_latency_cache
even when ref_prog is None; change the final write-back so that
benchmark_state.ref_input_tensors and benchmark_state.ref_latency_cache are only
updated when ref_prog is not None (leave benchmark_state.jit_input_tensors
update as-is), locating these symbols in the _benchmark_target function and
gating the assignments to ref_input_tensors_cache/ref_latency_cache behind a
check of ref_prog.
- Around line 770-797: The run method currently only validates
group_compile_size; add early validation in run to check warmup is >= 0, rep and
timeout are > 0 and all three are integers, and validate benchmark_devices (if
not None) is an iterable of non-negative ints (no negative ordinals) before
calling _resolve_benchmark_devices; keep existing group_compile_size check and
don’t change the function signature so partial(autotuner.run, self.warmup,
self.rep, self.timeout) still works—raise ValueError with clear messages for
each invalid parameter.
- Around line 460-476: The grouped-compile fallback message is only logged at
INFO and its return value is being ignored by the caller, so users don't see why
their --grouped-compile was disabled; change the logger call in
_resolve_grouped_compile_mode to logger.warning(...) when
grouped_compile_requested and not active, and update the caller that currently
destructures as "_, _, grouped_compile_active, _" to capture the fourth return
(grouped_compile_reason) and surface it into the progress description or final
summary so the warning is visible to users.
📒 Files selected for processing (4)
- examples/gemm/example_gemm_autotune.py
- tilelang/autotuner/grouped_compile.py
- tilelang/autotuner/tuner.py
- tilelang/engine/lower.py
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/gemm/example_gemm_autotune.py
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tilelang/autotuner/tuner.py`:
- Around line 600-637: The nested function _run_benchmark_target closes over
loop-local jit_kernel and call_result_queue by reference, so if a previous
daemon benchmark_call_thread times out and the loop rebinds those names the
still-running thread will use the new values and corrupt results; fix by binding
the current values into the thread callback so they are captured by value (e.g.,
create local vars like _jit_kernel = jit_kernel and _call_result_queue =
call_result_queue or pass them as default arguments to _run_benchmark_target)
and then use those bound names when calling benchmark_target and putting
results, ensuring the thread always uses the intended kernel and queue for that
iteration before starting benchmark_call_thread and joining; keep the existing
TimeoutException and result_queue handling (idx, config) unchanged.
📒 Files selected for processing (1)
tilelang/autotuner/tuner.py
Actionable comments posted: 2
🧹 Nitpick comments (1)
tilelang/autotuner/tuner.py (1)
1114-1116: ⚡ Quick win
pool.shutdown() does not cancel pending compile futures on early exit.
On the normal completion path this is fine (all futures already drained). But on an exception or KeyboardInterrupt, any not-yet-started compilation tasks remain queued and pool.shutdown() blocks until each runs to completion. Since this codebase supports Python 3.9+, cancel_futures=True is available and should be passed to shed the pending work immediately.
♻️ Proposed fix
- pool.shutdown()
+ pool.shutdown(cancel_futures=True)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/autotuner/tuner.py` around lines 1114 - 1116, The current shutdown sequence closes compile_progress and progress_bar then calls pool.shutdown() which will block waiting for queued compile futures to run on early exit; update the shutdown call in tuner.py to call pool.shutdown(cancel_futures=True) so pending compile futures are cancelled immediately (use this argument in the exception/KeyboardInterrupt/early-exit cleanup path where compile_progress, progress_bar are closed) to avoid blocking; reference the existing variables compile_progress, progress_bar and the ThreadPool/ProcessPool instance pool when making this change.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tilelang/autotuner/tuner.py`:
- Around line 707-712: The current shape_equal helper silently accepts
different-rank tensors because it zips shapes; change shape_equal (used where p
and c are compared before do_bench) to first require equal ranks (e.g., if
len(a.shape) != len(b.shape) return False) and only then compare each dimension
allowing Vars (keep the existing isinstance(..., Var) logic). Update the
function used in the if condition (shape_equal) so mismatched ranks fail fast
and prevent stale tensors from reaching do_bench.
- Around line 735-744: The call to profiler.do_bench at the top of the block is
passing iteration-count variables warmup and rep using the wrong parameter names
(warmup=, rep=) causing them to be interpreted as millisecond targets; update
the profiler.do_bench invocation that assigns latency (the one using
jit_input_tensors_cache and backend) to pass these as n_warmup=warmup and
n_repeat=rep to match the later ref_prog call and the intended semantics.
---
Nitpick comments:
In `@tilelang/autotuner/tuner.py`:
- Around line 1114-1116: The current shutdown sequence closes compile_progress
and progress_bar then calls pool.shutdown() which will block waiting for queued
compile futures to run on early exit; update the shutdown call in tuner.py to
call pool.shutdown(cancel_futures=True) so pending compile futures are cancelled
immediately (use this argument in the exception/KeyboardInterrupt/early-exit
cleanup path where compile_progress, progress_bar are closed) to avoid blocking;
reference the existing variables compile_progress, progress_bar and the
ThreadPool/ProcessPool instance pool when making this change.
📒 Files selected for processing (1)
tilelang/autotuner/tuner.py
Actionable comments posted: 3
♻️ Duplicate comments (1)
tilelang/autotuner/tuner.py (1)
738-748: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Use the same do_bench() parameter mode for kernel and reference timings.
At line 738, warmup and rep (iteration counts) are passed to warmup= and rep= parameters, which are treated as millisecond targets with auto-calculated iteration counts. At lines 742–748, the same variables are passed to n_warmup= and n_repeat= parameters, which bypass auto-calculation and use them as explicit iteration counts. This inconsistency causes the kernel and reference timings to use different benchmarking semantics.
Expected fix
- latency = profiler.do_bench(warmup=warmup, rep=rep, input_tensors=jit_input_tensors_cache, backend=backend)
+ latency = profiler.do_bench(
+     n_warmup=warmup,
+     n_repeat=rep,
+     input_tensors=jit_input_tensors_cache,
+     backend=backend,
+ )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/autotuner/tuner.py` around lines 738 - 748, The kernel and reference benchmarking calls use different do_bench parameter modes: the first call passes warmup= and rep= (time-target mode) while the reference call passes n_warmup= and n_repeat= (explicit-iteration mode), causing inconsistent semantics; update the reference call inside the ref_latency_cache block (the profiler.do_bench invocation that currently uses ref_prog, n_warmup, n_repeat, input_tensors=ref_input_tensors_cache, backend=backend) to use the same parameter names as the kernel timing (warmup= and rep=) so both measurements use the same benchmarking mode and inputs supplied by ref_input_tensors_supply.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tilelang/autotuner/tuner.py`:
- Around line 372-377: The current supply_prog closure freezes a single list of
captured tensors (frozen_inputs) and returns the same objects to every worker,
causing cross-device CUDA tensor reuse; instead, change supply_prog (the lambda
assigned where get_autotune_inputs() is used) to produce a per-device cache:
when first called for a given device id, materialize/move/clone the frozen
inputs onto that device and store them in a dict keyed by device, then return
the device-specific list on subsequent calls. Use the existing
get_autotune_inputs(), frozen_inputs, and supply_prog symbols to implement a
small per-device memoization (e.g., device_cache = {}) and ensure the
supply_prog signature reads the device argument and returns
device_cache[device].
- Around line 785-794: The current logic sets requested_devices to single_device
when benchmark_devices is None and CUDA_VISIBLE_DEVICES is unset, which prevents
multi-GPU benchmarking (benchmark_multi_gpu) from using all GPUs; change the
fallback to detect the total available device count and set requested_devices to
list(range(device_count)) (using the framework's GPU count API, e.g.,
torch.cuda.device_count() or equivalent) instead of single_device, while
preserving the existing behavior when CUDA_VISIBLE_DEVICES is present or
benchmark_devices is provided; adjust references in this block that assign
requested_devices, and ensure single_device remains available as a last-resort
fallback if device_count == 0.
- Around line 607-628: The per-call thread currently reuses the shared
_BenchmarkWorkerState (worker_state) causing races when a thread times out; to
fix, create a per-call deep copy of worker_state (e.g.,
copy.deepcopy(worker_state)) and pass that copy into _run_benchmark_target
(rename param to _worker_state_copy) so the daemon thread only mutates its
private state; when you read the call_result_queue and see a successful ("ok",
...) result, merge the relevant mutable pieces (cached tensors,
ref_latency_cache, or other worker-state caches) from the per-call copy back
into the original worker_state in a well-defined merge step; ensure
timeout/error paths keep the original worker_state unchanged and only the "ok"
branch performs the merge, leaving variable names _run_benchmark_target,
benchmark_target, _BenchmarkWorkerState, call_result_queue, and result_queue
intact for easy location.
---
Duplicate comments:
In `@tilelang/autotuner/tuner.py`:
- Around line 738-748: The kernel and reference benchmarking calls use different
do_bench parameter modes: the first call passes warmup= and rep= (time-target
mode) while the reference call passes n_warmup= and n_repeat=
(explicit-iteration mode), causing inconsistent semantics; update the reference
call inside the ref_latency_cache block (the profiler.do_bench invocation that
currently uses ref_prog, n_warmup, n_repeat,
input_tensors=ref_input_tensors_cache, backend=backend) to use the same
parameter names as the kernel timing (warmup= and rep=) so both measurements use
the same benchmarking mode and inputs supplied by ref_input_tensors_supply.
📒 Files selected for processing (1)
tilelang/autotuner/tuner.py
captured_inputs = get_autotune_inputs()
if captured_inputs is not None:
    if supply_prog is not None:
        logger.warning("`supply_prog` will be ignored as this program is under `with set_autotune_inputs` context.")
    supply_prog = lambda _: get_autotune_inputs()  # noqa: E731
    frozen_inputs = list(captured_inputs)
    supply_prog = lambda _, _frozen_inputs=frozen_inputs: _frozen_inputs  # noqa: E731
Freeze autotune inputs per device, not once globally.
This lambda returns the exact captured tensor objects to every benchmark worker. In multi-GPU mode, CUDA tensors captured on one device get reused on the other workers, which can either fail with device mismatches or benchmark the wrong GPU. Materialize/cache per-device inputs instead of closing over one shared list.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tilelang/autotuner/tuner.py` around lines 372 - 377, The current supply_prog
closure freezes a single list of captured tensors (frozen_inputs) and returns
the same objects to every worker, causing cross-device CUDA tensor reuse;
instead, change supply_prog (the lambda assigned where get_autotune_inputs() is
used) to produce a per-device cache: when first called for a given device id,
materialize/move/clone the frozen inputs onto that device and store them in a
dict keyed by device, then return the device-specific list on subsequent calls.
Use the existing get_autotune_inputs(), frozen_inputs, and supply_prog symbols
to implement a small per-device memoization (e.g., device_cache = {}) and ensure
the supply_prog signature reads the device argument and returns
device_cache[device].
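A sketch of the per-device memoization idea; make_per_device_supply is a hypothetical helper, and moving tensors with .to() is an assumption about how the captured inputs should be materialized on each GPU:

```python
# Illustrative per-device cache for captured autotune inputs.
import torch

def make_per_device_supply(frozen_inputs):
    device_cache = {}

    def supply_prog(device: int):
        if device not in device_cache:
            # Materialize the captured tensors on the requested device exactly once.
            device_cache[device] = [
                t.to(f"cuda:{device}") if isinstance(t, torch.Tensor) else t
                for t in frozen_inputs
            ]
        return device_cache[device]

    return supply_prog
```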
def _run_benchmark_target(
    _jit_kernel: tilelang.JITKernel = jit_kernel,
    _worker_state: _BenchmarkWorkerState = worker_state,
    _call_result_queue: queue.SimpleQueue = call_result_queue,
):
    try:
        latency, worker_ref_latency = benchmark_target(
            jit_kernel=_jit_kernel,
            benchmark_state=_worker_state,
        )
        _call_result_queue.put(("ok", latency, worker_ref_latency, ""))
    except TimeoutException:
        _call_result_queue.put(("timeout", None, None, ""))
    except Exception:
        _call_result_queue.put(("error", None, None, traceback.format_exc()))

benchmark_call_thread = threading.Thread(target=_run_benchmark_target, daemon=True)
benchmark_call_thread.start()
benchmark_call_thread.join(timeout=timeout)
if benchmark_call_thread.is_alive():
    result_queue.put((idx, config, jit_kernel, None, None, "timeout", ""))
    continue
A timed-out benchmark can corrupt the next task on the same worker.
When join(timeout=...) expires, the daemon sub-thread keeps running with the same worker_state object that the next config will reuse. If that late call eventually updates cached tensors or ref_latency_cache, it races with the next benchmark and can skew later results. Pass a per-call copy of _BenchmarkWorkerState into _run_benchmark_target and merge it back only after a successful completion.
Minimal containment fix
call_result_queue: queue.SimpleQueue = queue.SimpleQueue()
+ call_state = _BenchmarkWorkerState(
+ jit_input_tensors=worker_state.jit_input_tensors,
+ ref_input_tensors=worker_state.ref_input_tensors,
+ ref_latency_cache=worker_state.ref_latency_cache,
+ )
def _run_benchmark_target(
_jit_kernel: tilelang.JITKernel = jit_kernel,
- _worker_state: _BenchmarkWorkerState = worker_state,
+ _worker_state: _BenchmarkWorkerState = call_state,
_call_result_queue: queue.SimpleQueue = call_result_queue,
):
try:
@@
if status == "ok":
+ worker_state.jit_input_tensors = call_state.jit_input_tensors
+ worker_state.ref_input_tensors = call_state.ref_input_tensors
+ worker_state.ref_latency_cache = call_state.ref_latency_cache
result_queue.put((idx, config, jit_kernel, latency, worker_ref_latency, None, ""))
🧰 Tools
🪛 Ruff (0.15.12)
[warning] 620-620: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tilelang/autotuner/tuner.py` around lines 607 - 628, The per-call thread
currently reuses the shared _BenchmarkWorkerState (worker_state) causing races
when a thread times out; to fix, create a per-call deep copy of worker_state
(e.g., copy.deepcopy(worker_state)) and pass that copy into
_run_benchmark_target (rename param to _worker_state_copy) so the daemon thread
only mutates its private state; when you read the call_result_queue and see a
successful ("ok", ...) result, merge the relevant mutable pieces (cached
tensors, ref_latency_cache, or other worker-state caches) from the per-call copy
back into the original worker_state in a well-defined merge step; ensure
timeout/error paths keep the original worker_state unchanged and only the "ok"
branch performs the merge, leaving variable names _run_benchmark_target,
benchmark_target, _BenchmarkWorkerState, call_result_queue, and result_queue
intact for easy location.
requested_devices: list[int] = []
if benchmark_devices:
    requested_devices = list(dict.fromkeys(int(device) for device in benchmark_devices))
else:
    raw_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    parsed_visible_devices = [token.strip() for token in raw_visible_devices.split(",") if token.strip()]
    if parsed_visible_devices:
        requested_devices = list(range(len(parsed_visible_devices)))
    else:
        requested_devices = single_device
benchmark_multi_gpu=True still resolves to one GPU on the common path.
If callers do not pass benchmark_devices and CUDA_VISIBLE_DEVICES is unset, this branch falls back to single_device, so a multi-GPU host never enables sharding by default. Using all visible devices here would make the flag behave as advertised.
Suggested change
else:
raw_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
parsed_visible_devices = [token.strip() for token in raw_visible_devices.split(",") if token.strip()]
if parsed_visible_devices:
requested_devices = list(range(len(parsed_visible_devices)))
else:
- requested_devices = single_device
+ requested_devices = list(range(visible_device_count))
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tilelang/autotuner/tuner.py` around lines 785 - 794, The current logic sets
requested_devices to single_device when benchmark_devices is None and
CUDA_VISIBLE_DEVICES is unset, which prevents multi-GPU benchmarking
(benchmark_multi_gpu) from using all GPUs; change the fallback to detect the
total available device count and set requested_devices to
list(range(device_count)) (using the framework's GPU count API, e.g.,
torch.cuda.device_count() or equivalent) instead of single_device, while
preserving the existing behavior when CUDA_VISIBLE_DEVICES is present or
benchmark_devices is provided; adjust references in this block that assign
requested_devices, and ensure single_device remains available as a last-resort
fallback if device_count == 0.
Summary
This PR adds pipeline mode, grouped compilation, and multi-GPU benchmark support to autotuning.
Key Updates
1) Grouped compilation (tvm-ffi only)
- tilelang/autotuner/grouped_compile.py
- tilelang/autotuner/tuner.py
- tilelang/autotuner/param.py and tilelang/engine/lower.py
2) Pipeline mode via tuner runtime refactor
- Refactored AutoTuner.run() logic in tilelang/autotuner/tuner.py to support the pipeline execution flow.
3) Multi-GPU benchmark support
4) Example entrypoint support
- Updated the examples/gemm/example_gemm_autotune.py main entry to accept and forward the following flags (an illustrative argparse sketch follows the list):
  - --pipeline / --no-pipeline
  - --grouped-compile / --no-grouped-compile
  - --group-compile-size
  - --benchmark-multi-gpu / --no-benchmark-multi-gpu
  - --benchmark-devices (repeatable)
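How such flags are typically wired with argparse (illustrative only; the example's actual parser, defaults, and help strings may differ):

```python
# Illustrative argparse wiring for the new autotune flags (defaults are assumptions).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--pipeline", action=argparse.BooleanOptionalAction, default=True)
parser.add_argument("--grouped-compile", action=argparse.BooleanOptionalAction, default=False)
parser.add_argument("--group-compile-size", type=int, default=8)
parser.add_argument("--benchmark-multi-gpu", action=argparse.BooleanOptionalAction, default=False)
parser.add_argument("--benchmark-devices", type=int, action="append", default=None,
                    help="repeatable, e.g. --benchmark-devices 0 --benchmark-devices 1")
args = parser.parse_args()
```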