Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 99 additions & 10 deletions examples/gemm/example_gemm_autotune.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,53 @@ def get_best_config(
K,
with_roller: bool = False,
profile_backend: str = "event",
execution_backend: str = "auto",
warmup: int = 3,
rep: int = 20,
timeout: int = 30,
skip_check: bool = False,
cache_input_tensors: bool = False,
topk: int = 20,
use_pipeline: bool = False,
enable_grouped_compile: bool = False,
group_compile_size: int = 2,
benchmark_multi_gpu: bool = False,
benchmark_devices: list[int] | None = None,
):
autotuner, _, _ = _build_autotuner(
M=M,
N=N,
K=K,
with_roller=with_roller,
profile_backend=profile_backend,
execution_backend=execution_backend,
skip_check=skip_check,
cache_input_tensors=cache_input_tensors,
topk=topk,
)
autotuner_result = autotuner.run(
warmup=warmup,
rep=rep,
timeout=timeout,
use_pipeline=use_pipeline,
enable_grouped_compile=enable_grouped_compile,
group_compile_size=group_compile_size,
benchmark_multi_gpu=benchmark_multi_gpu,
benchmark_devices=benchmark_devices,
)
return autotuner_result


def _build_autotuner(
M: int,
N: int,
K: int,
with_roller: bool,
profile_backend: str,
execution_backend: str,
skip_check: bool,
cache_input_tensors: bool,
topk: int,
):
def kernel(
block_M=None,
Expand Down Expand Up @@ -152,20 +199,23 @@ def main(

return main

configs = get_configs(M, N, K, with_roller, topk=topk)
autotuner = (
AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N, K, with_roller))
AutoTuner.from_kernel(kernel=kernel, configs=configs)
.set_compile_args(
out_idx=[-1],
target="auto",
execution_backend=execution_backend,
)
.set_profile_args(
supply_type=tl.TensorSupplyType.Integer,
ref_prog=ref_program,
skip_check=False,
ref_prog=None if skip_check else ref_program,
skip_check=skip_check,
cache_input_tensors=cache_input_tensors,
backend=profile_backend,
)
)
return autotuner.run(warmup=3, rep=20)
return autotuner, configs, kernel


def get_heuristic_config() -> dict:
Expand Down Expand Up @@ -221,14 +271,26 @@ def main(
use_autotune: bool = False,
with_roller: bool = False,
profile_backend: str = "event",
use_pipeline: bool = False,
enable_grouped_compile: bool = False,
group_compile_size: int = 2,
benchmark_multi_gpu: bool = False,
benchmark_devices: list[int] | None = None,
):
benchmark_devices = benchmark_devices or []

if use_autotune:
result = get_best_config(
M,
N,
K,
with_roller=with_roller,
profile_backend=profile_backend,
use_pipeline=use_pipeline,
enable_grouped_compile=enable_grouped_compile,
group_compile_size=group_compile_size,
benchmark_multi_gpu=benchmark_multi_gpu,
benchmark_devices=benchmark_devices,
)
print(result.config)
kernel = result.kernel
Expand Down Expand Up @@ -267,12 +329,39 @@ def run_regression_perf(M: int = 4096, N: int = 4096, K: int = 4096):
parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs")
parser.add_argument("--with_roller", action="store_true", default=False, help="Whether to enable BitBLAS roller for search space")
parser.add_argument("--profile_backend", type=str, default="event", help="Profiler backend")
pipeline_group = parser.add_mutually_exclusive_group()
pipeline_group.add_argument("--pipeline", dest="use_pipeline", action="store_true", help="Enable compile/benchmark pipeline in autotune")
pipeline_group.add_argument("--no-pipeline", dest="use_pipeline", action="store_false", help="Disable compile/benchmark pipeline in autotune")
parser.set_defaults(use_pipeline=False)

grouped_compile_group = parser.add_mutually_exclusive_group()
grouped_compile_group.add_argument("--grouped-compile", dest="enable_grouped_compile", action="store_true", help="Enable grouped compilation in autotune")
grouped_compile_group.add_argument("--no-grouped-compile", dest="enable_grouped_compile", action="store_false", help="Disable grouped compilation in autotune")
parser.set_defaults(enable_grouped_compile=False)
parser.add_argument("--group-compile-size", type=int, default=2, help="Number of configs per grouped compile unit")

benchmark_multi_gpu_group = parser.add_mutually_exclusive_group()
benchmark_multi_gpu_group.add_argument("--benchmark-multi-gpu", dest="benchmark_multi_gpu", action="store_true", help="Benchmark autotune configs across multiple GPUs")
benchmark_multi_gpu_group.add_argument("--no-benchmark-multi-gpu", dest="benchmark_multi_gpu", action="store_false", help="Benchmark autotune configs on a single GPU")
parser.set_defaults(benchmark_multi_gpu=False)
parser.add_argument(
"--benchmark-devices",
action="append",
type=int,
default=[],
help="Repeatable CUDA device ordinals for benchmark workers (e.g. --benchmark-devices 0 --benchmark-devices 1)",
)
args = parser.parse_args()
main(
args.m,
args.n,
args.k,
args.use_autotune,
args.with_roller,
args.profile_backend,
M=args.m,
N=args.n,
K=args.k,
use_autotune=args.use_autotune,
with_roller=args.with_roller,
profile_backend=args.profile_backend,
use_pipeline=args.use_pipeline,
enable_grouped_compile=args.enable_grouped_compile,
group_compile_size=args.group_compile_size,
benchmark_multi_gpu=args.benchmark_multi_gpu,
benchmark_devices=args.benchmark_devices,
)
156 changes: 156 additions & 0 deletions tilelang/autotuner/grouped_compile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
"""Grouped compilation helpers for autotuner.

This module isolates backend-aware grouped compilation logic from AutoTuner.run
so tuner.py can stay focused on orchestration.
"""

from __future__ import annotations

from typing import Any, Callable

from tilelang import tvm
from tvm.tir import PrimFunc

from tilelang.autotuner.param import CompileArgs
from tilelang.engine.lower import lower_to_host_device_ir, device_codegen, host_codegen
from tilelang.engine.param import CompiledArtifact
from tilelang.jit.adapter import TVMFFIKernelAdapter
from tilelang.jit.kernel import JITKernel
from tilelang.transform import PassConfigKey

CompileUnitResult = tuple[int, dict[str, Any], JITKernel | None, Exception | None]


def compile_grouped_unit_tvm_ffi(
    unit_items: list[tuple[int, dict[str, Any]]],
    compile_args: CompileArgs,
    elaborate_func: Callable[..., PrimFunc],
) -> list[CompileUnitResult]:
    """Compile one grouped unit for the CUDA + tvm_ffi backend.

    Flow:
    1. Elaborate each config into a PrimFunc.
    2. Lower each PrimFunc into host/device IR modules.
    3. Merge all device IR into one IRModule and compile device code once.
    4. Build host runtime module per config and import shared device module.
    5. Construct per-config JITKernel objects that share the grouped device module.

    Args:
        unit_items: ``(config index, config kwargs)`` pairs belonging to this
            compile unit.
        compile_args: Shared compilation arguments (target, pass configs,
            ``out_idx``, execution backend, ...).
        elaborate_func: Kernel factory; called with a config's kwargs to
            produce a PrimFunc.

    Returns:
        One ``(idx, config_arg, kernel_or_None, error_or_None)`` tuple per
        input item. Per-config failures are captured in the tuple rather than
        raised, so a single bad config does not abort the whole unit; a
        failure in the shared device-codegen step is reported against every
        config that reached that stage.
    """

    pass_configs = dict(compile_args.pass_configs) if compile_args.pass_configs else {}
    pass_instruments = []
    if pass_configs.get(PassConfigKey.TL_ENABLE_DUMP_IR):
        dump_ir_path = pass_configs.get(PassConfigKey.TL_DUMP_IR_DIR, "./dump_ir")
        pass_instruments.append(tvm.ir.instrument.DumpIR(dump_dir=dump_ir_path))

    unit_results: list[CompileUnitResult] = []
    lowered_items: list[dict[str, Any]] = []

    # Stage 1-2: elaborate and lower each config independently so one failing
    # config only removes itself from the group.
    for idx, config_arg in unit_items:
        try:
            program = elaborate_func(**config_arg)
            # Defensive lookup: elaborated PrimFuncs normally carry
            # 'global_symbol', but a missing attribute must fail only this
            # config (via the except below), not crash the whole group with
            # an opaque TypeError/KeyError.
            global_symbol_attr = program.attrs.get("global_symbol") if program.attrs else None
            if global_symbol_attr is None:
                raise RuntimeError(
                    "Grouped compilation requires PrimFunc to carry the 'global_symbol' attribute"
                )
            original_symbol = str(global_symbol_attr)
            # Suffix with the config index so symbols stay unique once the
            # device modules are merged into a single IRModule.
            unique_symbol = f"{original_symbol}_gc_{idx}"
            program = program.with_attr("global_symbol", unique_symbol)

            with tvm.transform.PassContext(opt_level=3, config=pass_configs, instruments=pass_instruments), compile_args.target:
                host_mod, device_mod, params, normalized_target, normalized_target_host = lower_to_host_device_ir(
                    program,
                    target=compile_args.target,
                    target_host=compile_args.target_host,
                )

            lowered_items.append(
                {
                    "idx": idx,
                    "config_arg": config_arg,
                    "program": program,
                    "host_mod": host_mod,
                    "device_mod": device_mod,
                    "params": params,
                    "target": normalized_target,
                    "target_host": normalized_target_host,
                }
            )
        except Exception as e:
            unit_results.append((idx, config_arg, None, e))

    if not lowered_items:
        return unit_results

    try:
        # Stage 3: merge all per-config device IR into one module and run
        # device codegen once for the whole group.
        merged_funcs: dict[Any, Any] = {}
        merged_attrs = None
        merged_names: set[str] = set()
        for item in lowered_items:
            device_mod = item["device_mod"]
            if merged_attrs is None:
                # Use the first module's attrs for the merged module; all
                # configs in a unit target the same device.
                merged_attrs = device_mod.attrs
            for global_var, func in device_mod.functions.items():
                name_hint = getattr(global_var, "name_hint", str(global_var))
                if name_hint in merged_names:
                    raise RuntimeError(
                        f"Duplicate device global symbol '{name_hint}' during grouped compilation "
                        f"(config index={item['idx']})."
                    )
                merged_names.add(name_hint)
                merged_funcs[global_var] = func
        merged_device_mod = tvm.IRModule(merged_funcs, attrs=merged_attrs)

        reference_target = lowered_items[0]["target"]
        with tvm.transform.PassContext(opt_level=3, config=pass_configs, instruments=pass_instruments), reference_target:
            grouped_device_rt_mod = device_codegen(merged_device_mod, reference_target)

        grouped_kernel_source = grouped_device_rt_mod.inspect_source()

        # Stage 4-5: per-config host codegen + JITKernel construction, all
        # sharing the single grouped device runtime module.
        for item in lowered_items:
            idx = item["idx"]
            config_arg = item["config_arg"]
            try:
                with tvm.transform.PassContext(opt_level=3, config=pass_configs, instruments=pass_instruments), item["target"]:
                    grouped_host_rt_mod = host_codegen(item["host_mod"], item["target_host"], target=item["target"])

                grouped_host_rt_mod.import_module(grouped_device_rt_mod)

                artifact = CompiledArtifact(
                    host_mod=grouped_host_rt_mod,
                    device_mod=item["device_mod"],
                    params=item["params"],
                    kernel_source=grouped_kernel_source,
                    rt_mod=grouped_host_rt_mod,
                )

                adapter = TVMFFIKernelAdapter(
                    params=artifact.params,
                    result_idx=compile_args.out_idx,
                    target=compile_args.target,
                    func_or_mod=item["program"],
                    host_mod=artifact.host_mod,
                    device_mod=artifact.device_mod,
                    rt_mod=artifact.rt_mod,
                    device_kernel_source=artifact.kernel_source,
                    verbose=compile_args.verbose,
                    pass_configs=pass_configs,
                )

                # from_database=True skips recompilation inside JITKernel; the
                # grouped artifact/adapter are attached manually below.
                jit_kernel = JITKernel(
                    func=item["program"],
                    out_idx=compile_args.out_idx,
                    execution_backend=compile_args.execution_backend,
                    target=compile_args.target,
                    target_host=compile_args.target_host,
                    verbose=compile_args.verbose,
                    pass_configs=pass_configs,
                    from_database=True,
                )
                jit_kernel.artifact = artifact
                jit_kernel.adapter = adapter
                jit_kernel.torch_function = adapter.func

                unit_results.append((idx, config_arg, jit_kernel, None))
            except Exception as e:
                unit_results.append((idx, config_arg, None, e))
    except Exception as e:
        # Group-level failure (merge or device codegen): report it against
        # every config that had lowered successfully.
        for item in lowered_items:
            unit_results.append((item["idx"], item["config_arg"], None, e))

    return unit_results
1 change: 1 addition & 0 deletions tilelang/autotuner/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def compile_program(self, program: PrimFunc):
return tilelang.compile(
program,
out_idx=self.out_idx,
execution_backend=self.execution_backend,
target=self.target,
target_host=self.target_host,
verbose=self.verbose,
Expand Down
Loading