121 changes: 111 additions & 10 deletions examples/gemm/example_gemm_autotune.py
@@ -113,6 +113,53 @@ def get_best_config(
K,
with_roller: bool = False,
profile_backend: str = "event",
execution_backend: str = "auto",
warmup: int = 3,
rep: int = 20,
timeout: int = 30,
skip_check: bool = False,
cache_input_tensors: bool = False,
topk: int = 20,
use_pipeline: bool = False,
enable_grouped_compile: bool = False,
group_compile_size: int = 2,
benchmark_multi_gpu: bool = False,
benchmark_devices: list[int] | None = None,
):
autotuner, _, _ = _build_autotuner(
M=M,
N=N,
K=K,
with_roller=with_roller,
profile_backend=profile_backend,
execution_backend=execution_backend,
skip_check=skip_check,
cache_input_tensors=cache_input_tensors,
topk=topk,
)
autotuner_result = autotuner.run(
warmup=warmup,
rep=rep,
timeout=timeout,
use_pipeline=use_pipeline,
enable_grouped_compile=enable_grouped_compile,
group_compile_size=group_compile_size,
benchmark_multi_gpu=benchmark_multi_gpu,
benchmark_devices=benchmark_devices,
)
return autotuner_result
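
A minimal usage sketch of the extended entry point (the GEMM shape and flag values below are illustrative, not taken from this diff):

    # Hypothetical call exercising the new autotune knobs.
    result = get_best_config(
        4096, 4096, 4096,
        use_pipeline=True,            # compile/benchmark pipeline
        enable_grouped_compile=True,  # batch configs into grouped compile units
        group_compile_size=4,         # configs per grouped compile unit
        benchmark_multi_gpu=True,
        benchmark_devices=[0, 1],     # CUDA ordinals for benchmark workers
    )
    print(result.config)    # the winning config, as main() prints below
    kernel = result.kernel  # the compiled kernel for that config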


def _build_autotuner(
M: int,
N: int,
K: int,
with_roller: bool,
profile_backend: str,
execution_backend: str,
skip_check: bool,
cache_input_tensors: bool,
topk: int,
):
def kernel(
block_M=None,
@@ -152,20 +199,23 @@ def main(

return main

configs = get_configs(M, N, K, with_roller, topk=topk)
autotuner = (
AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N, K, with_roller))
AutoTuner.from_kernel(kernel=kernel, configs=configs)
.set_compile_args(
out_idx=[-1],
target="auto",
execution_backend=execution_backend,
)
.set_profile_args(
supply_type=tl.TensorSupplyType.Integer,
ref_prog=ref_program,
skip_check=False,
ref_prog=None if skip_check else ref_program,
skip_check=skip_check,
cache_input_tensors=cache_input_tensors,
backend=profile_backend,
)
)
return autotuner.run(warmup=3, rep=20)
return autotuner, configs, kernel


def get_heuristic_config() -> dict:
@@ -221,14 +271,26 @@ def main(
use_autotune: bool = False,
with_roller: bool = False,
profile_backend: str = "event",
use_pipeline: bool = False,
enable_grouped_compile: bool = False,
group_compile_size: int = 2,
benchmark_multi_gpu: bool = False,
benchmark_devices: list[int] | None = None,
):
benchmark_devices = benchmark_devices or []

if use_autotune:
result = get_best_config(
M,
N,
K,
with_roller=with_roller,
profile_backend=profile_backend,
use_pipeline=use_pipeline,
enable_grouped_compile=enable_grouped_compile,
group_compile_size=group_compile_size,
benchmark_multi_gpu=benchmark_multi_gpu,
benchmark_devices=benchmark_devices,
)
print(result.config)
kernel = result.kernel
@@ -267,12 +329,51 @@ def run_regression_perf(M: int = 4096, N: int = 4096, K: int = 4096):
parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs")
parser.add_argument("--with_roller", action="store_true", default=False, help="Whether to enable BitBLAS roller for search space")
parser.add_argument("--profile_backend", type=str, default="event", help="Profiler backend")
pipeline_group = parser.add_mutually_exclusive_group()
pipeline_group.add_argument(
"--pipeline", dest="use_pipeline", action="store_true", help="Enable compile/benchmark pipeline in autotune"
)
pipeline_group.add_argument(
"--no-pipeline", dest="use_pipeline", action="store_false", help="Disable compile/benchmark pipeline in autotune"
)
parser.set_defaults(use_pipeline=False)

grouped_compile_group = parser.add_mutually_exclusive_group()
grouped_compile_group.add_argument(
"--grouped-compile", dest="enable_grouped_compile", action="store_true", help="Enable grouped compilation in autotune"
)
grouped_compile_group.add_argument(
"--no-grouped-compile", dest="enable_grouped_compile", action="store_false", help="Disable grouped compilation in autotune"
)
parser.set_defaults(enable_grouped_compile=False)
parser.add_argument("--group-compile-size", type=int, default=2, help="Number of configs per grouped compile unit")

benchmark_multi_gpu_group = parser.add_mutually_exclusive_group()
benchmark_multi_gpu_group.add_argument(
"--benchmark-multi-gpu", dest="benchmark_multi_gpu", action="store_true", help="Benchmark autotune configs across multiple GPUs"
)
benchmark_multi_gpu_group.add_argument(
"--no-benchmark-multi-gpu", dest="benchmark_multi_gpu", action="store_false", help="Benchmark autotune configs on a single GPU"
)
parser.set_defaults(benchmark_multi_gpu=False)
parser.add_argument(
"--benchmark-devices",
action="append",
type=int,
default=[],
help="Repeatable CUDA device ordinals for benchmark workers (e.g. --benchmark-devices 0 --benchmark-devices 1)",
)
args = parser.parse_args()
main(
args.m,
args.n,
args.k,
args.use_autotune,
args.with_roller,
args.profile_backend,
M=args.m,
N=args.n,
K=args.k,
use_autotune=args.use_autotune,
with_roller=args.with_roller,
profile_backend=args.profile_backend,
use_pipeline=args.use_pipeline,
enable_grouped_compile=args.enable_grouped_compile,
group_compile_size=args.group_compile_size,
benchmark_multi_gpu=args.benchmark_multi_gpu,
benchmark_devices=args.benchmark_devices,
)
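
A hypothetical invocation wiring the new flags together (this assumes the unchanged part of the file defines --m/--n/--k, as the args.m/args.n/args.k accesses above suggest):

    python examples/gemm/example_gemm_autotune.py \
        --m 4096 --n 4096 --k 4096 --use_autotune \
        --pipeline --grouped-compile --group-compile-size 4 \
        --benchmark-multi-gpu --benchmark-devices 0 --benchmark-devices 1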
155 changes: 155 additions & 0 deletions tilelang/autotuner/grouped_compile.py
@@ -0,0 +1,155 @@
"""Grouped compilation helpers for autotuner.

This module isolates backend-aware grouped compilation logic from AutoTuner.run
so tuner.py can stay focused on orchestration.
"""

from __future__ import annotations

from typing import Any, Callable

from tilelang import tvm
from tvm.tir import PrimFunc

from tilelang.autotuner.param import CompileArgs
from tilelang.engine.lower import lower_to_host_device_ir, device_codegen, host_codegen
from tilelang.engine.param import CompiledArtifact
from tilelang.jit.adapter import TVMFFIKernelAdapter
from tilelang.jit.kernel import JITKernel
from tilelang.transform import PassConfigKey

CompileUnitResult = tuple[int, dict[str, Any], JITKernel | None, Exception | None]


def compile_grouped_unit_tvm_ffi(
unit_items: list[tuple[int, dict[str, Any]]],
compile_args: CompileArgs,
elaborate_func: Callable[..., PrimFunc],
) -> list[CompileUnitResult]:
"""Compile one grouped unit for CUDA+tvm_ffi backend.

Flow:
1. Elaborate each config into a PrimFunc.
2. Lower each PrimFunc into host/device IR modules.
3. Merge all device IR into one IRModule and compile device code once.
4. Build host runtime module per config and import shared device module.
5. Construct per-config JITKernel objects that share the grouped device module.
"""

pass_configs = dict(compile_args.pass_configs) if compile_args.pass_configs else {}
pass_instruments = []
if pass_configs.get(PassConfigKey.TL_ENABLE_DUMP_IR):
dump_ir_path = pass_configs.get(PassConfigKey.TL_DUMP_IR_DIR, "./dump_ir")
pass_instruments.append(tvm.ir.instrument.DumpIR(dump_dir=dump_ir_path))

unit_results: list[CompileUnitResult] = []
lowered_items: list[dict[str, Any]] = []

for idx, config_arg in unit_items:
try:
program = elaborate_func(**config_arg)
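# Suffix each config's global symbol so the per-config device functions
# can coexist in one merged IRModule without name collisions.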
original_symbol = str(program.attrs["global_symbol"])
unique_symbol = f"{original_symbol}_gc_{idx}"
program = program.with_attr("global_symbol", unique_symbol)

with tvm.transform.PassContext(opt_level=3, config=pass_configs, instruments=pass_instruments), compile_args.target:
host_mod, device_mod, params, normalized_target, normalized_target_host = lower_to_host_device_ir(
program,
target=compile_args.target,
target_host=compile_args.target_host,
)

lowered_items.append(
{
"idx": idx,
"config_arg": config_arg,
"program": program,
"host_mod": host_mod,
"device_mod": device_mod,
"params": params,
"target": normalized_target,
"target_host": normalized_target_host,
}
)
except Exception as e:
unit_results.append((idx, config_arg, None, e))

if not lowered_items:
return unit_results

try:
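# Step 3 of the flow: merge every config's device functions into a single
# IRModule so device code is generated once for the whole unit.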
merged_funcs: dict[Any, Any] = {}
merged_attrs = None
merged_names: set[str] = set()
for item in lowered_items:
device_mod = item["device_mod"]
if merged_attrs is None:
merged_attrs = device_mod.attrs
for global_var, func in device_mod.functions.items():
name_hint = getattr(global_var, "name_hint", str(global_var))
if name_hint in merged_names:
raise RuntimeError(
f"Duplicate device global symbol '{name_hint}' during grouped compilation (config index={item['idx']})."
)
merged_names.add(name_hint)
merged_funcs[global_var] = func
merged_device_mod = tvm.IRModule(merged_funcs, attrs=merged_attrs)

reference_target = lowered_items[0]["target"]
with tvm.transform.PassContext(opt_level=3, config=pass_configs, instruments=pass_instruments), reference_target:
grouped_device_rt_mod = device_codegen(merged_device_mod, reference_target)

grouped_kernel_source = grouped_device_rt_mod.inspect_source()

for item in lowered_items:
idx = item["idx"]
config_arg = item["config_arg"]
try:
with tvm.transform.PassContext(opt_level=3, config=pass_configs, instruments=pass_instruments), item["target"]:
grouped_host_rt_mod = host_codegen(item["host_mod"], item["target_host"], target=item["target"])
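# Each config gets its own host runtime module, which then imports the
# shared grouped device module (flow steps 4-5).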

grouped_host_rt_mod.import_module(grouped_device_rt_mod)

artifact = CompiledArtifact(
host_mod=grouped_host_rt_mod,
device_mod=item["device_mod"],
params=item["params"],
kernel_source=grouped_kernel_source,
rt_mod=grouped_host_rt_mod,
)

adapter = TVMFFIKernelAdapter(
params=artifact.params,
result_idx=compile_args.out_idx,
target=compile_args.target,
func_or_mod=item["program"],
host_mod=artifact.host_mod,
device_mod=artifact.device_mod,
rt_mod=artifact.rt_mod,
device_kernel_source=artifact.kernel_source,
verbose=compile_args.verbose,
pass_configs=pass_configs,
)
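# Build the kernel shell without triggering a fresh per-config compile
# (from_database=True); the grouped artifact and adapter are attached below.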

jit_kernel = JITKernel(
func=item["program"],
out_idx=compile_args.out_idx,
execution_backend=compile_args.execution_backend,
target=compile_args.target,
target_host=compile_args.target_host,
verbose=compile_args.verbose,
pass_configs=pass_configs,
from_database=True,
)
jit_kernel.artifact = artifact
jit_kernel.adapter = adapter
jit_kernel.torch_function = adapter.func

unit_results.append((idx, config_arg, jit_kernel, None))
except Exception as e:
unit_results.append((idx, config_arg, None, e))
except Exception as e:
for item in lowered_items:
unit_results.append((item["idx"], item["config_arg"], None, e))

return unit_results
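
A sketch of how a caller such as AutoTuner.run might drive this helper; the chunking loop and the names below are assumptions about the caller, not part of this file:

    # Hypothetical driver: split candidate configs into units and compile each unit.
    def compile_all_grouped(configs, compile_args, elaborate_func, group_size=2):
        kernels = {}
        for start in range(0, len(configs), group_size):
            unit_items = [(start + i, cfg) for i, cfg in enumerate(configs[start:start + group_size])]
            for idx, cfg, kernel, err in compile_grouped_unit_tvm_ffi(unit_items, compile_args, elaborate_func):
                if err is not None:
                    continue  # this config failed to elaborate/lower/codegen; skip it
                kernels[idx] = kernel  # kernels in one unit share a device module
        return kernels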
1 change: 1 addition & 0 deletions tilelang/autotuner/param.py
@@ -62,6 +62,7 @@ def compile_program(self, program: PrimFunc):
return tilelang.compile(
program,
out_idx=self.out_idx,
execution_backend=self.execution_backend,
target=self.target,
target_host=self.target_host,
verbose=self.verbose,
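
A minimal sketch of what this one-line change enables (the backend name "tvm_ffi" is an assumption for illustration):

    # execution_backend set on the compile args now reaches tilelang.compile,
    # so tuning can target the tvm_ffi execution path used by grouped compilation.
    autotuner.set_compile_args(
        out_idx=[-1],
        target="auto",
        execution_backend="tvm_ffi",
    )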