diff --git a/examples/gemm/example_gemm_advanced_autotune.py b/examples/gemm/example_gemm_advanced_autotune.py new file mode 100644 index 0000000000..4b9781bf82 --- /dev/null +++ b/examples/gemm/example_gemm_advanced_autotune.py @@ -0,0 +1,379 @@ +import argparse +import itertools +import tilelang as tl +import tilelang.language as T +from tilelang.autotuner import AutoTuner +from tilelang.carver.template import MatmulTemplate +from tilelang.carver.arch import CUDA +from tilelang.carver.arch import CDNA +from tilelang.carver.roller.rasterization import NoRasterization +import torch + + +def ref_program(A, B): + """ + Compute the matrix product of A and the transpose of B. + + A and B are expected to be 2-D tensors where A has shape (M, K) and B has shape (N, K). The result is a tensor with shape (M, N) equal to A @ B.T, using the inputs' dtypes. + """ + return A @ B.T + + +def get_configs(M, N, K, with_roller=False, topk=20): + """ + Generate a list of kernel tuning configuration dictionaries for a tiled matrix-multiply. + + When with_roller is True this queries the MatmulTemplate roller to produce up to `topk` recommended + configurations (device-specific TensorCore-friendly tilings). Each returned dict contains: + - block_M, block_N, block_K: tile sizes + - num_stages: pipeline staging (0 means no explicit staging) + - thread_num: total threads used for the block + - enable_rasteration: whether a rasterization/swizzle layout was recommended (note spelling) + + When with_roller is False this returns the Cartesian product of a fixed set of candidate + parameters; the returned dicts use the backward-compatible key name "enable_rasteration" for that flag. + + Parameters: + M, N, K (int): GEMM dimensions used to generate valid tile sizes. + with_roller (bool): If True, use MatmulTemplate's roller to generate device-aware hints; + otherwise use a predefined candidate grid. + topk (int): Maximum number of roller hints to request when with_roller is True. + + Returns: + List[dict]: A list of configuration dictionaries as described above. + + Raises: + ValueError: if with_roller is True but the roller returns no hints. 
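    Example (illustrative, with_roller=False path; values follow the hard-coded candidate lists below):
        >>> configs = get_configs(4096, 4096, 4096, with_roller=False)
        >>> len(configs)  # 3 * 3 * 2 * 4 * 2 * 2 candidates
        288
        >>> sorted(configs[0])
        ['block_K', 'block_M', 'block_N', 'enable_rasteration', 'num_stages', 'thread_num']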
+ """ + if with_roller: + arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip") + carve_template = MatmulTemplate( + M=M, + N=N, + K=K, + in_dtype=T.float16, + out_dtype=T.float16, + accum_dtype=T.float32, + ).with_arch(arch) + + func = carve_template.equivalent_function() + assert func is not None, "Function is None" + roller_hints = carve_template.recommend_hints(topk=topk) + if roller_hints is None: + raise ValueError("No Roller Hints Found for TensorCore Scheduling") + configs = [] + for hint in roller_hints: + config = {} + block_m, block_n = hint.block + warp_m, warp_n = hint.warp + # block_rows, block_cols represents warp partitioning + block_rows, block_cols = block_m // warp_m, block_n // warp_n + config["block_M"] = block_m + config["block_N"] = block_n + config["block_K"] = hint.rstep[0] + config["num_stages"] = hint.pipeline_stage if hint.pipeline_stage > 1 else 0 + config["thread_num"] = block_rows * block_cols * 32 + config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization + configs.append(config) + else: + block_M = [64, 128, 256] + block_N = [64, 128, 256] + block_K = [32, 64] + num_stages = [0, 1, 2, 3] + thread_num = [128, 256] + enable_rasterization = [True, False] + _configs = list( + itertools.product( + block_M, + block_N, + block_K, + num_stages, + thread_num, + enable_rasterization, + ) + ) + + configs = [ + { + "block_M": c[0], + "block_N": c[1], + "block_K": c[2], + "num_stages": c[3], + "thread_num": c[4], + "enable_rasteration": c[5], # keep param name for backward-compat + } + for c in _configs + ] + return configs + + +def get_best_config( + M, + N, + K, + with_roller: bool = False, + profile_backend: str = "event", + execution_backend: str = "auto", + warmup: int = 3, + rep: int = 20, + timeout: int = 30, + skip_check: bool = False, + cache_input_tensors: bool = False, + topk: int = 20, + use_pipeline: bool = False, + enable_grouped_compile: bool = False, + group_compile_size: int = 2, + benchmark_multi_gpu: bool = False, + benchmark_devices: list[int] | None = None, +): + autotuner, _, _ = _build_autotuner( + M=M, + N=N, + K=K, + with_roller=with_roller, + profile_backend=profile_backend, + execution_backend=execution_backend, + skip_check=skip_check, + cache_input_tensors=cache_input_tensors, + topk=topk, + ) + autotuner_result = autotuner.run( + warmup=warmup, + rep=rep, + timeout=timeout, + use_pipeline=use_pipeline, + enable_grouped_compile=enable_grouped_compile, + group_compile_size=group_compile_size, + benchmark_multi_gpu=benchmark_multi_gpu, + benchmark_devices=benchmark_devices, + ) + return autotuner_result + + +def _build_autotuner( + M: int, + N: int, + K: int, + with_roller: bool, + profile_backend: str, + execution_backend: str, + skip_check: bool, + cache_input_tensors: bool, + topk: int, +): + def kernel( + block_M=None, + block_N=None, + block_K=None, + num_stages=None, + thread_num=None, + enable_rasteration=None, + ): + dtype = T.bfloat16 + accum_dtype = T.float32 + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), dtype) + T.use_swizzle(panel_size=10, enable=enable_rasteration) + T.clear(C_local) + for k in 
T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm( + A_shared, + B_shared, + C_local, + transpose_B=True, + ) + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + configs = get_configs(M, N, K, with_roller, topk=topk) + autotuner = ( + AutoTuner.from_kernel(kernel=kernel, configs=configs) + .set_compile_args( + out_idx=[-1], + target="auto", + execution_backend=execution_backend, + ) + .set_profile_args( + supply_type=tl.TensorSupplyType.Integer, + ref_prog=None if skip_check else ref_program, + skip_check=skip_check, + cache_input_tensors=cache_input_tensors, + backend=profile_backend, + ) + ) + return autotuner, configs, kernel + + +def get_heuristic_config() -> dict: + # Get CUDA device properties + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available") + device = torch.cuda.current_device() + sm_major, sm_minor = torch.cuda.get_device_capability(device) + sm_version = sm_major * 10 + sm_minor + print(f"CUDA device capability: {sm_version}") + if sm_version in {80}: + return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True} + elif sm_version in {90}: + return {"block_M": 128, "block_N": 256, "block_K": 64, "num_stages": 3, "thread_num": 256, "enable_rasteration": True} + else: + return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True} + + +@tl.jit(out_idx=[-1]) +def matmul(M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def gemm_autotune( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), dtype) + T.use_swizzle(panel_size=10, enable=enable_rasteration) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm( + A_shared, + B_shared, + C_local, + transpose_B=True, + ) + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return gemm_autotune + + +def main( + M: int = 4096, + N: int = 4096, + K: int = 4096, + use_autotune: bool = False, + with_roller: bool = False, + profile_backend: str = "event", + use_pipeline: bool = False, + enable_grouped_compile: bool = False, + group_compile_size: int = 2, + benchmark_multi_gpu: bool = False, + benchmark_devices: list[int] | None = None, +): + benchmark_devices = benchmark_devices or [] + + if use_autotune: + result = get_best_config( + M, + N, + K, + with_roller=with_roller, + profile_backend=profile_backend, + use_pipeline=use_pipeline, + enable_grouped_compile=enable_grouped_compile, + group_compile_size=group_compile_size, + benchmark_multi_gpu=benchmark_multi_gpu, + benchmark_devices=benchmark_devices, + ) + print(result.config) + kernel = result.kernel + else: + config = get_heuristic_config() + kernel = matmul(M, N, K, **config) + + # benchmark + profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto) + 
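    # Note: the TFLOPS prints below assume do_bench reports latency in
    # milliseconds: 2*M*N*K / (latency_ms * 1e-3) / 1e12 simplifies to
    # 2*M*N*K / latency_ms * 1e-9.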
tilelang_latency = profiler.do_bench( + backend=profile_backend, + ) + ref_latency = profiler.do_bench( + ref_program, + backend=profile_backend, + ) + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + print(f"TileLang latency: {tilelang_latency}") + print(f"Ref latency: {ref_latency}") + print(f"TileLang TFlops: {2 * M * N * K / tilelang_latency * 1e-9}") + print(f"Ref TFlops: {2 * M * N * K / ref_latency * 1e-9}") + + +def run_regression_perf(M: int = 4096, N: int = 4096, K: int = 4096): + config = get_heuristic_config() + kernel = matmul(M, N, K, **config) + profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto) + return profiler.do_bench(backend="cupti") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") + parser.add_argument("--m", type=int, default=4096, help="Matrix dimension M") + parser.add_argument("--n", type=int, default=4096, help="Matrix dimension N") + parser.add_argument("--k", type=int, default=4096, help="Matrix dimension K") + parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs") + parser.add_argument("--with_roller", action="store_true", default=False, help="Whether to enable BitBLAS roller for search space") + parser.add_argument("--profile_backend", type=str, default="event", help="Profiler backend") + pipeline_group = parser.add_mutually_exclusive_group() + pipeline_group.add_argument( + "--pipeline", dest="use_pipeline", action="store_true", help="Enable compile/benchmark pipeline in autotune" + ) + pipeline_group.add_argument( + "--no-pipeline", dest="use_pipeline", action="store_false", help="Disable compile/benchmark pipeline in autotune" + ) + parser.set_defaults(use_pipeline=False) + + grouped_compile_group = parser.add_mutually_exclusive_group() + grouped_compile_group.add_argument( + "--grouped-compile", dest="enable_grouped_compile", action="store_true", help="Enable grouped compilation in autotune" + ) + grouped_compile_group.add_argument( + "--no-grouped-compile", dest="enable_grouped_compile", action="store_false", help="Disable grouped compilation in autotune" + ) + parser.set_defaults(enable_grouped_compile=False) + parser.add_argument("--group-compile-size", type=int, default=2, help="Number of configs per grouped compile unit") + + benchmark_multi_gpu_group = parser.add_mutually_exclusive_group() + benchmark_multi_gpu_group.add_argument( + "--benchmark-multi-gpu", dest="benchmark_multi_gpu", action="store_true", help="Benchmark autotune configs across multiple GPUs" + ) + benchmark_multi_gpu_group.add_argument( + "--no-benchmark-multi-gpu", dest="benchmark_multi_gpu", action="store_false", help="Benchmark autotune configs on a single GPU" + ) + parser.set_defaults(benchmark_multi_gpu=False) + parser.add_argument( + "--benchmark-devices", + action="append", + type=int, + default=[], + help="Repeatable CUDA device ordinals for benchmark workers (e.g. 
--benchmark-devices 0 --benchmark-devices 1)", + ) + args = parser.parse_args() + main( + M=args.m, + N=args.n, + K=args.k, + use_autotune=args.use_autotune, + with_roller=args.with_roller, + profile_backend=args.profile_backend, + use_pipeline=args.use_pipeline, + enable_grouped_compile=args.enable_grouped_compile, + group_compile_size=args.group_compile_size, + benchmark_multi_gpu=args.benchmark_multi_gpu, + benchmark_devices=args.benchmark_devices, + ) diff --git a/tilelang/autotuner/grouped_compile.py b/tilelang/autotuner/grouped_compile.py new file mode 100644 index 0000000000..73e02521b3 --- /dev/null +++ b/tilelang/autotuner/grouped_compile.py @@ -0,0 +1,155 @@ +"""Grouped compilation helpers for autotuner. + +This module isolates backend-aware grouped compilation logic from AutoTuner.run +so tuner.py can stay focused on orchestration. +""" + +from __future__ import annotations + +from typing import Any, Callable + +from tilelang import tvm +from tvm.tir import PrimFunc + +from tilelang.autotuner.param import CompileArgs +from tilelang.engine.lower import lower_to_host_device_ir, device_codegen, host_codegen +from tilelang.engine.param import CompiledArtifact +from tilelang.jit.adapter import TVMFFIKernelAdapter +from tilelang.jit.kernel import JITKernel +from tilelang.transform import PassConfigKey + +CompileUnitResult = tuple[int, dict[str, Any], JITKernel | None, Exception | None] + + +def compile_grouped_unit_tvm_ffi( + unit_items: list[tuple[int, dict[str, Any]]], + compile_args: CompileArgs, + elaborate_func: Callable[..., PrimFunc], +) -> list[CompileUnitResult]: + """Compile one grouped unit for CUDA+tvm_ffi backend. + + Flow: + 1. Elaborate each config into a PrimFunc. + 2. Lower each PrimFunc into host/device IR modules. + 3. Merge all device IR into one IRModule and compile device code once. + 4. Build host runtime module per config and import shared device module. + 5. Construct per-config JITKernel objects that share the grouped device module. 
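    Returns:
        One (config index, config dict, JITKernel | None, Exception | None) tuple per
        config in the unit; on success the exception slot is None, on failure the
        kernel slot is None and the exception slot holds the error.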
+ """ + + pass_configs = dict(compile_args.pass_configs) if compile_args.pass_configs else {} + pass_instruments = [] + if pass_configs.get(PassConfigKey.TL_ENABLE_DUMP_IR): + dump_ir_path = pass_configs.get(PassConfigKey.TL_DUMP_IR_DIR, "./dump_ir") + pass_instruments.append(tvm.ir.instrument.DumpIR(dump_dir=dump_ir_path)) + + unit_results: list[CompileUnitResult] = [] + lowered_items: list[dict[str, Any]] = [] + + for idx, config_arg in unit_items: + try: + program = elaborate_func(**config_arg) + original_symbol = str(program.attrs["global_symbol"]) + unique_symbol = f"{original_symbol}_gc_{idx}" + program = program.with_attr("global_symbol", unique_symbol) + + with tvm.transform.PassContext(opt_level=3, config=pass_configs, instruments=pass_instruments), compile_args.target: + host_mod, device_mod, params, normalized_target, normalized_target_host = lower_to_host_device_ir( + program, + target=compile_args.target, + target_host=compile_args.target_host, + ) + + lowered_items.append( + { + "idx": idx, + "config_arg": config_arg, + "program": program, + "host_mod": host_mod, + "device_mod": device_mod, + "params": params, + "target": normalized_target, + "target_host": normalized_target_host, + } + ) + except Exception as e: + unit_results.append((idx, config_arg, None, e)) + + if not lowered_items: + return unit_results + + try: + merged_funcs: dict[Any, Any] = {} + merged_attrs = None + merged_names: set[str] = set() + for item in lowered_items: + device_mod = item["device_mod"] + if merged_attrs is None: + merged_attrs = device_mod.attrs + for global_var, func in device_mod.functions.items(): + name_hint = getattr(global_var, "name_hint", str(global_var)) + if name_hint in merged_names: + raise RuntimeError( + f"Duplicate device global symbol '{name_hint}' during grouped compilation (config index={item['idx']})." 
+ ) + merged_names.add(name_hint) + merged_funcs[global_var] = func + merged_device_mod = tvm.IRModule(merged_funcs, attrs=merged_attrs) + + reference_target = lowered_items[0]["target"] + with tvm.transform.PassContext(opt_level=3, config=pass_configs, instruments=pass_instruments), reference_target: + grouped_device_rt_mod = device_codegen(merged_device_mod, reference_target) + + grouped_kernel_source = grouped_device_rt_mod.inspect_source() + + for item in lowered_items: + idx = item["idx"] + config_arg = item["config_arg"] + try: + with tvm.transform.PassContext(opt_level=3, config=pass_configs, instruments=pass_instruments), item["target"]: + grouped_host_rt_mod = host_codegen(item["host_mod"], item["target_host"], target=item["target"]) + + grouped_host_rt_mod.import_module(grouped_device_rt_mod) + + artifact = CompiledArtifact( + host_mod=grouped_host_rt_mod, + device_mod=item["device_mod"], + params=item["params"], + kernel_source=grouped_kernel_source, + rt_mod=grouped_host_rt_mod, + ) + + adapter = TVMFFIKernelAdapter( + params=artifact.params, + result_idx=compile_args.out_idx, + target=compile_args.target, + func_or_mod=item["program"], + host_mod=artifact.host_mod, + device_mod=artifact.device_mod, + rt_mod=artifact.rt_mod, + device_kernel_source=artifact.kernel_source, + verbose=compile_args.verbose, + pass_configs=pass_configs, + ) + + jit_kernel = JITKernel( + func=item["program"], + out_idx=compile_args.out_idx, + execution_backend=compile_args.execution_backend, + target=compile_args.target, + target_host=compile_args.target_host, + verbose=compile_args.verbose, + pass_configs=pass_configs, + from_database=True, + ) + jit_kernel.artifact = artifact + jit_kernel.adapter = adapter + jit_kernel.torch_function = adapter.func + + unit_results.append((idx, config_arg, jit_kernel, None)) + except Exception as e: + unit_results.append((idx, config_arg, None, e)) + except Exception as e: + for item in lowered_items: + unit_results.append((item["idx"], item["config_arg"], None, e)) + + return unit_results diff --git a/tilelang/autotuner/param.py b/tilelang/autotuner/param.py index 39cef841d1..ad741b4f9b 100644 --- a/tilelang/autotuner/param.py +++ b/tilelang/autotuner/param.py @@ -62,6 +62,7 @@ def compile_program(self, program: PrimFunc): return tilelang.compile( program, out_idx=self.out_idx, + execution_backend=self.execution_backend, target=self.target, target_host=self.target_host, verbose=self.verbose, diff --git a/tilelang/autotuner/tuner.py b/tilelang/autotuner/tuner.py index 9a82294825..3c144abe22 100644 --- a/tilelang/autotuner/tuner.py +++ b/tilelang/autotuner/tuner.py @@ -26,6 +26,7 @@ from tqdm.auto import tqdm import logging import concurrent.futures +import queue import torch import os import sys @@ -37,6 +38,7 @@ from pathlib import Path from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult +from tilelang.autotuner.grouped_compile import compile_grouped_unit_tvm_ffi from tilelang.utils.language import get_prim_func_name from tilelang.autotuner.capture import get_autotune_inputs from tilelang.utils.target import determine_target @@ -217,6 +219,13 @@ def _normalize_value(value, sort_dict_items: bool = False): return value +@dataclass +class _BenchmarkWorkerState: + jit_input_tensors: Any = None + ref_input_tensors: Any = None + ref_latency_cache: float | None = None + + class AutoTuner: """Auto-tuner for tilelang programs. 
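A hedged sketch of how this per-worker state flows through benchmarking (the `autotuner` and `jit_kernel` objects are assumed to exist; the real orchestration lives in `_benchmark_worker_loop` further down):

```python
# Illustrative only: each benchmark worker owns one _BenchmarkWorkerState so
# cached input tensors and the reference latency are reused across configs
# handled by that worker, without leaking into other workers (or devices).
state = _BenchmarkWorkerState()
latency, ref_latency = autotuner._benchmark_target(
    jit_kernel=jit_kernel, warmup=25, rep=100, benchmark_state=state)
# state.jit_input_tensors / state.ref_latency_cache are now populated and will
# be reused the next time this worker benchmarks a config.
```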
@@ -243,6 +252,7 @@ def __init__(self, fn: Callable, configs): self.jit_input_tensors = None self.ref_input_tensors = None self.jit_compile = None + self.jit_elaborate = None @classmethod def _get_cache_dir(cls) -> Path: @@ -357,11 +367,29 @@ def set_profile_args( AutoTuner: Self for method chaining. """ # If the program is under `with set_autotune_inputs` context, - # the `supply_prog` will be ignored and the `get_autotune_inputs` will be used instead. - if get_autotune_inputs() is not None: + # freeze captured tensors now so benchmark worker threads do not + # lose them via thread-local storage lookups. + captured_inputs = get_autotune_inputs() + if captured_inputs is not None: if supply_prog is not None: logger.warning("`supply_prog` will be ignored as this program is under `with set_autotune_inputs` context.") - supply_prog = lambda _: get_autotune_inputs() # noqa: E731 + frozen_inputs = list(captured_inputs) + device_cache = {} + + def supply_prog(device, _frozen_inputs=frozen_inputs, _device_cache=device_cache): + if not isinstance(device, (int, str, torch.device)): + device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu" + if device not in _device_cache: + if isinstance(device, torch.device): + target_device = device + elif isinstance(device, str): + target_device = torch.device(device) + else: + target_device = torch.device(f"cuda:{device}") if torch.cuda.is_available() else torch.device("cpu") + _device_cache[device] = [ + tensor.to(device=target_device).clone() if isinstance(tensor, torch.Tensor) else tensor for tensor in _frozen_inputs + ] + return _device_cache[device] self.profile_args = ProfileArgs( supply_type=supply_type, @@ -424,17 +452,424 @@ def _load_result_from_disk(self, key) -> AutotuneResult: result = AutotuneResult.load_from_disk(self.cache_dir / key, self.compile_args) return result - def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30): + # Compile-related helpers + def _default_compile( + self, + **config_arg, + ) -> tilelang.JITKernel: + compile_args = self.compile_args + return compile_args.compile_program(self.fn(**config_arg)) + + def _default_elaborate(self, **config_arg) -> PrimFunc: + return self.fn(**config_arg) + + def _ensure_jit_functions( + self, + ) -> tuple[Callable[..., tilelang.JITKernel], Callable[..., PrimFunc]]: + compile_func = self.jit_compile + elaborate_func = self.jit_elaborate + if compile_func is None: + compile_func = self._default_compile + if elaborate_func is None: + elaborate_func = self._default_elaborate + return compile_func, elaborate_func + + def _resolve_grouped_compile_mode( + self, + enable_grouped_compile: bool, + group_compile_size: int, + ) -> tuple[str, str, bool, str]: + target_kind = self.compile_args.target.kind.name if isinstance(self.compile_args.target, Target) else str(self.compile_args.target) + execution_backend = str(self.compile_args.execution_backend) + grouped_compile_requested = enable_grouped_compile and group_compile_size > 1 + grouped_compile_active = grouped_compile_requested and target_kind == "cuda" and execution_backend == "tvm_ffi" + grouped_compile_reason = "" + if grouped_compile_requested and not grouped_compile_active: + grouped_compile_reason = ( + f"grouped compilation is currently implemented for CUDA+tvm_ffi only; " + f"fallback to per-config mode (target={target_kind}, execution_backend={execution_backend})" + ) + logger.info("%s", grouped_compile_reason) + return target_kind, execution_backend, grouped_compile_active, grouped_compile_reason + + def 
_resolve_num_compile_workers(self) -> int: + available_cpu_count = get_available_cpu_count() + cpu_utilizations = float(env.TILELANG_AUTO_TUNING_CPU_UTILITIES) + cpu_counts = int(env.TILELANG_AUTO_TUNING_CPU_COUNTS) + max_cpu_count = int(env.TILELANG_AUTO_TUNING_MAX_CPU_COUNT) + if cpu_counts > 0: + num_workers = min(cpu_counts, available_cpu_count) + logger.info(f"Auto-tuning with {cpu_counts} CPU counts, {available_cpu_count} CPUs available, {num_workers} CPUs will be used") + else: + num_workers = max(1, int(available_cpu_count * cpu_utilizations)) + logger.info( + f"Auto-tuning with {cpu_utilizations} CPU utilizations, {available_cpu_count} CPUs available, {num_workers} CPUs will be used" + ) + + if max_cpu_count > 0 and num_workers > max_cpu_count: + logger.warning( + f"Auto-tuning with {cpu_utilizations} CPU utilizations, {available_cpu_count} CPUs available, {num_workers} CPUs will be used, but the max CPU count is {max_cpu_count}, so we will use {max_cpu_count} CPUs" + ) + num_workers = max_cpu_count + return num_workers + + def _prepare_compile_execution( + self, + config_args: list[dict[str, Any]], + grouped_compile_active: bool, + group_compile_size: int, + compile_func: Callable[..., tilelang.JITKernel], + elaborate_func: Callable[..., PrimFunc], + ) -> tuple[ + concurrent.futures.ThreadPoolExecutor, + list[concurrent.futures.Future], + dict[concurrent.futures.Future, list[tuple[int, dict[str, Any]]]], + str, + ]: + num_workers = self._resolve_num_compile_workers() + pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) + futures: list[concurrent.futures.Future] = [] + future_to_unit: dict[concurrent.futures.Future, list[tuple[int, dict[str, Any]]]] = {} + + def cuda_device_wrapper(func: Callable[..., Any], device: int): + def inner(**config_arg): + torch.cuda.set_device(device) + return func(**config_arg) + + return inner + + def get_compile_func(): + compile_impl = compile_func + if torch.cuda.is_available(): + device = torch.cuda.current_device() + compile_impl = cuda_device_wrapper(compile_func, device) + return compile_impl + + def get_elaborate_func(): + elaborate_impl = elaborate_func + if torch.cuda.is_available(): + device = torch.cuda.current_device() + elaborate_impl = cuda_device_wrapper(elaborate_func, device) + return elaborate_impl + + def compile_unit(unit_items: list[tuple[int, dict[str, Any]]]): + if grouped_compile_active: + return compile_grouped_unit_tvm_ffi( + unit_items=unit_items, + compile_args=self.compile_args, + elaborate_func=get_elaborate_func(), + ) + compile_impl = get_compile_func() + unit_results: list[tuple[int, dict[str, Any], tilelang.JITKernel | None, Exception | None]] = [] + for idx, config_arg in unit_items: + try: + jit_kernel = compile_impl(**config_arg) + unit_results.append((idx, config_arg, jit_kernel, None)) + except Exception as e: + unit_results.append((idx, config_arg, None, e)) + return unit_results + + compile_units: list[list[tuple[int, dict[str, Any]]]] = [] + if grouped_compile_active: + for start in range(0, len(config_args), group_compile_size): + end = min(start + group_compile_size, len(config_args)) + compile_units.append([(i, config_args[i]) for i in range(start, end)]) + else: + for i, config_arg in enumerate(config_args): + compile_units.append([(i, config_arg)]) + + for unit_items in compile_units: + future = pool.submit(compile_unit, unit_items) + futures.append(future) + future_to_unit[future] = unit_items + + compile_desc = "Compiling configurations (grouped)" if grouped_compile_active else 
"Compiling configurations" + return pool, futures, future_to_unit, compile_desc + + # Benchmark-related helpers + def _benchmark_worker_loop( + self, + worker_device: int, + worker_queue: queue.SimpleQueue, + result_queue: queue.SimpleQueue, + start_event: threading.Event, + target_kind: str, + benchmark_target: Callable[..., tuple[float, float | None]], + timeout: int, + worker_state: _BenchmarkWorkerState, + ) -> None: + if torch.cuda.is_available() and target_kind == "cuda": + try: + torch.cuda.set_device(worker_device) + except Exception: + logger.warning("Failed to bind benchmark worker to cuda:%s", worker_device) + logger.debug("Error: %s", traceback.format_exc()) + + start_event.wait() + queue_poll_timeout_s = 0.1 + while True: + try: + item = worker_queue.get(timeout=queue_poll_timeout_s) + except queue.Empty: + continue + if item is None: + break + jit_kernel, config, idx = item + try: + if timeout > 0: + call_result_queue: queue.SimpleQueue = queue.SimpleQueue() + call_state = _BenchmarkWorkerState( + jit_input_tensors=worker_state.jit_input_tensors, + ref_input_tensors=worker_state.ref_input_tensors, + ref_latency_cache=worker_state.ref_latency_cache, + ) + + def _run_benchmark_target( + _jit_kernel: tilelang.JITKernel = jit_kernel, + _worker_state: _BenchmarkWorkerState = call_state, + _call_result_queue: queue.SimpleQueue = call_result_queue, + ): + try: + latency, worker_ref_latency = benchmark_target( + jit_kernel=_jit_kernel, + benchmark_state=_worker_state, + ) + _call_result_queue.put(("ok", latency, worker_ref_latency, "")) + except TimeoutException: + _call_result_queue.put(("timeout", None, None, "")) + except Exception: + _call_result_queue.put(("error", None, None, traceback.format_exc())) + + benchmark_call_thread = threading.Thread(target=_run_benchmark_target, daemon=True) + benchmark_call_thread.start() + benchmark_call_thread.join(timeout=timeout) + if benchmark_call_thread.is_alive(): + result_queue.put((idx, config, jit_kernel, None, None, "timeout", "")) + continue + + try: + status, latency, worker_ref_latency, error_text = call_result_queue.get_nowait() + except queue.Empty: + result_queue.put( + ( + idx, + config, + jit_kernel, + None, + None, + "error", + "Benchmark call thread exited without returning a result.", + ) + ) + continue + + if status == "ok": + worker_state.jit_input_tensors = call_state.jit_input_tensors + worker_state.ref_input_tensors = call_state.ref_input_tensors + worker_state.ref_latency_cache = call_state.ref_latency_cache + result_queue.put((idx, config, jit_kernel, latency, worker_ref_latency, None, "")) + elif status == "timeout": + result_queue.put((idx, config, jit_kernel, None, None, "timeout", "")) + else: + result_queue.put((idx, config, jit_kernel, None, None, "error", error_text)) + else: + latency, worker_ref_latency = benchmark_target( + jit_kernel=jit_kernel, + benchmark_state=worker_state, + ) + result_queue.put((idx, config, jit_kernel, latency, worker_ref_latency, None, "")) + except TimeoutException: + result_queue.put((idx, config, jit_kernel, None, None, "timeout", "")) + except Exception: + result_queue.put((idx, config, jit_kernel, None, None, "error", traceback.format_exc())) + + def _benchmark_target( + self, + jit_kernel: tilelang.JITKernel, + warmup: int, + rep: int, + benchmark_state: _BenchmarkWorkerState, + ) -> tuple[float, float | None]: + profile_args = self.profile_args + supply_type = profile_args.supply_type + skip_check = profile_args.skip_check + manual_check_prog = profile_args.manual_check_prog + 
cache_input_tensors = profile_args.cache_input_tensors + ref_prog = profile_args.ref_prog + supply_prog = profile_args.supply_prog + rtol = profile_args.rtol + atol = profile_args.atol + max_mismatched_ratio = profile_args.max_mismatched_ratio + backend = profile_args.backend + + profiler = jit_kernel.get_profiler(tensor_supply_type=supply_type) + + def get_input_tensors_supply(with_output: bool): + def func(): + if supply_prog is not None: + return supply_prog(profiler._get_params(with_output=with_output)) + else: + return profiler._get_inputs(with_output=with_output) + + return func + + jit_input_tensors_supply = get_input_tensors_supply(with_output=False) + ref_input_tensors_supply = get_input_tensors_supply(with_output=False) + + jit_input_tensors_cache = benchmark_state.jit_input_tensors + ref_input_tensors_cache = benchmark_state.ref_input_tensors + ref_latency_cache = benchmark_state.ref_latency_cache + + if cache_input_tensors: + params = profiler._get_params(with_output=False) + if jit_input_tensors_cache is None: + jit_input_tensors_cache = jit_input_tensors_supply() + else: + assert len(params) == len(jit_input_tensors_cache), "len(params) != len(jit_input_tensors_cache)" + for p, c in zip(params, jit_input_tensors_cache): + if not isinstance(c, torch.Tensor): + continue + + def shape_equal(a, b): + return all( + a_dim == b_dim or isinstance(a_dim, Var) or isinstance(b_dim, Var) for a_dim, b_dim in zip(a.shape, b.shape) + ) + + if p.dtype != c.dtype or not shape_equal(p, c): + logger.warning( + "\nIncompatible input tensor properties detected between cached tensors and " + "tensors regenerated for the current configuration trial. " + "This can happen if different tuning configurations require different input shapes/dtypes " + "and input tensor caching is enabled.\n" + "To ensure fresh, compatible inputs are generated for every trial " + "you can disable caching by setting:\n" + " `cache_input_tensors=False`\n" + "within your `.set_compile_args(...)` call.\n" + ) + jit_input_tensors_cache = jit_input_tensors_supply() + break + else: + jit_input_tensors_cache = jit_input_tensors_supply() + + if (not skip_check) and (ref_prog is not None): + if manual_check_prog is not None: + profiler.manual_assert_close(ref_prog, input_tensors=jit_input_tensors_cache, manual_check_prog=manual_check_prog) + else: + profiler.assert_allclose( + ref_prog, input_tensors=jit_input_tensors_cache, rtol=rtol, atol=atol, max_mismatched_ratio=max_mismatched_ratio + ) + latency = profiler.do_bench(n_warmup=warmup, n_repeat=rep, input_tensors=jit_input_tensors_cache, backend=backend) + + if ref_latency_cache is None and ref_prog is not None: + ref_input_tensors_cache = ref_input_tensors_supply() + ref_latency_cache = profiler.do_bench( + ref_prog, + n_warmup=warmup, + n_repeat=rep, + input_tensors=ref_input_tensors_cache, + backend=backend, + ) + + benchmark_state.jit_input_tensors = jit_input_tensors_cache + benchmark_state.ref_input_tensors = ref_input_tensors_cache + benchmark_state.ref_latency_cache = ref_latency_cache + + return latency, ref_latency_cache + + def _resolve_benchmark_devices( + self, + benchmark_multi_gpu: bool, + benchmark_devices: list[int] | None, + target_kind: str, + ) -> tuple[bool, list[int]]: + current_device = torch.cuda.current_device() if torch.cuda.is_available() else 0 + single_device = [current_device] + + if not benchmark_multi_gpu: + return False, single_device + + if target_kind != "cuda": + logger.warning( + "Multi-GPU benchmark requested but target is '%s'. 
Falling back to single-device benchmark on cuda:%s.", + target_kind, + current_device, + ) + return False, single_device + + if not torch.cuda.is_available(): + logger.warning("Multi-GPU benchmark requested but CUDA is unavailable. Falling back to single-device benchmark.") + return False, single_device + + visible_device_count = torch.cuda.device_count() + if visible_device_count <= 0: + logger.warning("Multi-GPU benchmark requested but no visible CUDA devices found. Falling back to single-device benchmark.") + return False, single_device + + requested_devices: list[int] = [] + if benchmark_devices: + requested_devices = list(dict.fromkeys(int(device) for device in benchmark_devices)) + else: + raw_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "") + parsed_visible_devices = [token.strip() for token in raw_visible_devices.split(",") if token.strip()] + if parsed_visible_devices: + requested_devices = list(range(len(parsed_visible_devices))) + else: + requested_devices = list(range(visible_device_count)) if visible_device_count > 0 else single_device + + valid_devices: list[int] = [] + invalid_devices: list[int] = [] + for device in requested_devices: + if 0 <= device < visible_device_count: + valid_devices.append(device) + else: + invalid_devices.append(device) + + valid_devices = list(dict.fromkeys(valid_devices)) + if invalid_devices: + logger.warning( + "Ignoring invalid benchmark device ids %s. Visible CUDA device ordinals are [0, %d].", + invalid_devices, + max(0, visible_device_count - 1), + ) + + if not valid_devices: + logger.warning( + "No valid benchmark devices resolved for multi-GPU benchmark. Falling back to single-device benchmark on cuda:%s.", + current_device, + ) + return False, single_device + + return len(valid_devices) > 1, valid_devices + + def run( + self, + warmup: int = 25, + rep: int = 100, + timeout: int = 180, + use_pipeline: bool = False, + enable_grouped_compile: bool = False, + group_compile_size: int = 2, + benchmark_devices: list[int] | None = None, + benchmark_multi_gpu: bool = False, + ): """Run the auto-tuning process. Args: warmup: Number of warmup iterations. rep: Number of repetitions for timing. timeout: Maximum time per configuration. + use_pipeline: Whether to pipeline benchmarking with compilation. + enable_grouped_compile: Whether to enable grouped compilation. + group_compile_size: Number of configurations in one compile unit. + benchmark_devices: CUDA device ordinals used for benchmark workers when benchmark_multi_gpu=True. + benchmark_multi_gpu: Whether to benchmark configurations across multiple CUDA GPUs. Returns: - AutotuneResult: Results of the auto-tuning process. + AutotuneResult: The best autotuned artifact. 
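        Raises:
            ValueError: If `group_compile_size` is not positive or no configurations
                are available to tune.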
""" + if group_compile_size <= 0: + raise ValueError("group_compile_size must be > 0") + _init_logger_handlers() sig = inspect.signature(self.fn) @@ -491,99 +926,9 @@ def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30): best_config: dict[str, Any] | None = None best_kernel: tilelang.JITKernel | None = None - def _compile(**config_arg) -> tilelang.JITKernel: - compile_args = self.compile_args - return compile_args.compile_program(self.fn(**config_arg)) - - if self.jit_compile is None: - self.jit_compile = _compile - - def target_fn(jit_kernel: tilelang.JITKernel): - # Unpack the context - profile_args = self.profile_args - supply_type = profile_args.supply_type - skip_check = profile_args.skip_check - manual_check_prog = profile_args.manual_check_prog - cache_input_tensors = profile_args.cache_input_tensors - ref_prog = profile_args.ref_prog - supply_prog = profile_args.supply_prog - rtol = profile_args.rtol - atol = profile_args.atol - max_mismatched_ratio = profile_args.max_mismatched_ratio - backend = profile_args.backend - - profiler = jit_kernel.get_profiler(tensor_supply_type=supply_type) - - # Factory functions for generating input tensors. - # This encapsulates the logic of using either a custom supply program (`supply_prog`) - # or the default profiler input generation (`profiler._get_inputs`). - def get_input_tensors_supply(with_output: bool): - def func(): - if supply_prog is not None: - return supply_prog(profiler._get_params(with_output=with_output)) - else: - return profiler._get_inputs(with_output=with_output) - - return func - - jit_input_tensors_supply = get_input_tensors_supply(with_output=False) - ref_input_tensors_supply = get_input_tensors_supply(with_output=False) - - if cache_input_tensors: - params = profiler._get_params(with_output=False) - if self.jit_input_tensors is None: - self.jit_input_tensors = jit_input_tensors_supply() - else: - # check if the cached tensors are compatible with the current configuration - assert len(params) == len(self.jit_input_tensors), "len(params) != len(self.jit_input_tensors)" - for p, c in zip(params, self.jit_input_tensors): - if not isinstance(c, torch.Tensor): - # skip non-tensor inputs checking - continue - - # Check tensor compatibility using generator expression - def shape_equal(a, b): - return all( - a_dim == b_dim or isinstance(a_dim, Var) or isinstance(b_dim, Var) for a_dim, b_dim in zip(a.shape, b.shape) - ) - - if p.dtype != c.dtype or not shape_equal(p, c): - logger.warning( - "\nIncompatible input tensor properties detected between cached tensors and " - "tensors regenerated for the current configuration trial. 
" - "This can happen if different tuning configurations require different input shapes/dtypes " - "and input tensor caching is enabled.\n" - "To ensure fresh, compatible inputs are generated for every trial " - "you can disable caching by setting:\n" - " `cache_input_tensors=False`\n" - "within your `.set_compile_args(...)` call.\n" - ) - # otherwise, regenerate the input tensors for safety - self.jit_input_tensors = jit_input_tensors_supply() - break - else: - self.jit_input_tensors = jit_input_tensors_supply() - - if (not skip_check) and (ref_prog is not None): - if manual_check_prog is not None: - profiler.manual_assert_close(ref_prog, input_tensors=self.jit_input_tensors, manual_check_prog=manual_check_prog) - else: - profiler.assert_allclose( - ref_prog, input_tensors=self.jit_input_tensors, rtol=rtol, atol=atol, max_mismatched_ratio=max_mismatched_ratio - ) - latency = profiler.do_bench(warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors, backend=backend) - - if self.ref_latency_cache is None and ref_prog is not None: - self.ref_input_tensors = ref_input_tensors_supply() - self.ref_latency_cache = profiler.do_bench( - ref_prog, - n_warmup=warmup, - n_repeat=rep, - input_tensors=self.ref_input_tensors, - backend=backend, - ) - - return latency, self.ref_latency_cache + compile_func, elaborate_func = self._ensure_jit_functions() + self.jit_compile = compile_func + self.jit_elaborate = elaborate_func config_args = [] for config in self.configs: @@ -600,6 +945,17 @@ def shape_equal(a, b): if len(config_args) == 0: raise ValueError("No configurations to tune, please check your `@autotune` decorator") + target_kind, _, grouped_compile_active, _ = self._resolve_grouped_compile_mode( + enable_grouped_compile=enable_grouped_compile, + group_compile_size=group_compile_size, + ) + + benchmark_multi_gpu_active, benchmark_device_list = self._resolve_benchmark_devices( + benchmark_multi_gpu=benchmark_multi_gpu, + benchmark_devices=benchmark_devices, + target_kind=target_kind, + ) + # check if the tunable arguments has been set. 
# get the back config argument top_config, *rest = config_args @@ -625,89 +981,169 @@ def check_tunable_argument_value(key, parameters, key_args_tuple) -> bool: autotuner_result = AutotuneResult(libcode=jit_kernel.get_kernel_source(), func=jit_kernel.prim_func, kernel=jit_kernel) self._memory_cache[key] = autotuner_result return autotuner_result - # get the cpu count - available_cpu_count = get_available_cpu_count() - cpu_utilizations = float(env.TILELANG_AUTO_TUNING_CPU_UTILITIES) - cpu_counts = int(env.TILELANG_AUTO_TUNING_CPU_COUNTS) - max_cpu_count = int(env.TILELANG_AUTO_TUNING_MAX_CPU_COUNT) - if cpu_counts > 0: - num_workers = min(cpu_counts, available_cpu_count) - logger.info(f"Auto-tuning with {cpu_counts} CPU counts, {available_cpu_count} CPUs available, {num_workers} CPUs will be used") - else: - num_workers = max(1, int(available_cpu_count * cpu_utilizations)) - logger.info( - f"Auto-tuning with {cpu_utilizations} CPU utilizations, {available_cpu_count} CPUs available, {num_workers} CPUs will be used" - ) - - if max_cpu_count > 0 and num_workers > max_cpu_count: - logger.warning( - f"Auto-tuning with {cpu_utilizations} CPU utilizations, {available_cpu_count} CPUs available, {num_workers} CPUs will be used, but the max CPU count is {max_cpu_count}, so we will use {max_cpu_count} CPUs" - ) - num_workers = max_cpu_count - pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) - futures = [] - future_to_index = {} + # Launch compile tasks + pool, futures, future_to_unit, compile_desc = self._prepare_compile_execution( + config_args=config_args, + grouped_compile_active=grouped_compile_active, + group_compile_size=group_compile_size, + compile_func=compile_func, + elaborate_func=elaborate_func, + ) - def cuda_device_wrapper(func, device): - def inner(**config_arg): - torch.cuda.set_device(device) - return func(**config_arg) + ref_latency = None + main_thread_benchmark_state = _BenchmarkWorkerState( + jit_input_tensors=self.jit_input_tensors, + ref_input_tensors=self.ref_input_tensors, + ref_latency_cache=self.ref_latency_cache, + ) - return inner + def _record_benchmark_result(latency: float, config: dict[str, Any], jit_kernel: tilelang.JITKernel, idx: int, progress_bar): + nonlocal best_latency, best_config, best_kernel + if latency < best_latency: + best_latency = latency + best_config = config + best_kernel = jit_kernel - for i, config_arg in enumerate(config_args): - compile_func = self.jit_compile + progress_bar.set_postfix({"best_latency": best_latency}) + tqdm.write(f"Tuned Latency {latency} with config {config} at index {idx}") - if torch.cuda.is_available(): - device = torch.cuda.current_device() + benchmark_worker_devices = benchmark_device_list if benchmark_multi_gpu_active else [benchmark_device_list[0]] + benchmark_task_queues = [queue.SimpleQueue() for _ in benchmark_worker_devices] + benchmark_result_queue: queue.SimpleQueue = queue.SimpleQueue() + benchmark_start_event = threading.Event() + benchmark_threads: list[threading.Thread] = [] + benchmark_expected_results = 0 + benchmark_processed_results = 0 - compile_func = cuda_device_wrapper(self.jit_compile, device) + if use_pipeline: + benchmark_start_event.set() - future = pool.submit( - compile_func, - **config_arg, + if timeout > 0: + logger.warning( + "Benchmark timeout is enforced in benchmark workers by running each benchmark call " + "in a daemon sub-thread and waiting up to the configured timeout." 
) - futures.append(future) - future_to_index[future] = i + benchmark_target = partial( + self._benchmark_target, + warmup=warmup, + rep=rep, + ) - results_with_configs = [] - for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Compiling configurations"): - idx = future_to_index[future] - config = config_args[idx] - try: - result = future.result() - results_with_configs.append((result, config)) - except Exception as e: - logger.debug(f"Compilation failed for config {config} at index {idx} with error: {e}") - continue + def _enqueue_benchmark_task(jit_kernel: tilelang.JITKernel, config: dict[str, Any], idx: int): + nonlocal benchmark_expected_results + queue_idx = idx % len(benchmark_task_queues) + benchmark_task_queues[queue_idx].put((jit_kernel, config, idx)) + benchmark_expected_results += 1 - ref_latency = None - progress_bar = tqdm(range(len(results_with_configs)), desc="Bench configurations") - for i in progress_bar: - jit_kernel, config = results_with_configs[i] - try: - # Cannot ThreadPoolExecutor to enforce timeout on target_fn execution - # Because tma init may behave strangely with one thread - # latency, ref_latency = target_fn(jit_kernel) - latency, ref_latency = run_with_timeout(target_fn, timeout, jit_kernel) - except TimeoutException: + def _process_benchmark_result(result_item, progress_bar): + nonlocal benchmark_processed_results, ref_latency + idx, config, jit_kernel, latency, worker_ref_latency, status, error_text = result_item + benchmark_processed_results += 1 + progress_bar.update(1) + + if status == "timeout": logger.warning(f"A timeout occurred while testing config {config}, checkout autotuner.log for more details") - continue - except Exception: + return + if status == "error": logger.warning(f"An error occurred while testing config {config}, checkout autotuner.log for more details") - logger.debug(f"Error: {traceback.format_exc()}") - continue - - if latency < best_latency: - best_latency = latency - best_config = config - best_kernel = jit_kernel - - progress_bar.set_postfix({"best_latency": best_latency}) - tqdm.write(f"Tuned Latency {latency} with config {config} at index {i}") + if error_text: + logger.debug(f"Error: {error_text}") + return + + if worker_ref_latency is not None: + ref_latency = worker_ref_latency + assert latency is not None + _record_benchmark_result(latency=latency, config=config, jit_kernel=jit_kernel, idx=idx, progress_bar=progress_bar) + + def _drain_benchmark_results(progress_bar, block: bool): + while benchmark_processed_results < benchmark_expected_results: + try: + if block: + result_item = benchmark_result_queue.get(timeout=0.1) + else: + result_item = benchmark_result_queue.get_nowait() + except queue.Empty: + break + _process_benchmark_result(result_item, progress_bar) + + # Start benchmark worker threads + for worker_idx, worker_device in enumerate(benchmark_worker_devices): + worker_state = _BenchmarkWorkerState() if benchmark_multi_gpu_active else main_thread_benchmark_state + worker_thread = threading.Thread( + target=self._benchmark_worker_loop, + args=( + worker_device, + benchmark_task_queues[worker_idx], + benchmark_result_queue, + benchmark_start_event, + target_kind, + benchmark_target, + timeout, + worker_state, + ), + daemon=True, + ) + worker_thread.start() + benchmark_threads.append(worker_thread) + + compile_progress = tqdm(total=len(config_args), desc=compile_desc) + progress_bar = tqdm(total=len(config_args), desc="Bench configurations") + pending_futures = set(futures) + + # Main 
thread loop to process compile results and feed benchmark tasks, end when all compile tasks are done. + try: + while pending_futures: + done, pending_futures = concurrent.futures.wait( + pending_futures, + return_when=concurrent.futures.FIRST_COMPLETED, + ) - pool.shutdown() + for future in done: + unit_items = future_to_unit[future] + try: + unit_results = future.result() + except Exception as e: + compile_progress.update(len(unit_items)) + unit_indexes = [idx for idx, _ in unit_items] + logger.debug("Compilation unit failed for indexes %s with error: %s", unit_indexes, e) + continue + + compile_progress.update(len(unit_results)) + for idx, config, jit_kernel, error in unit_results: + if error is not None: + logger.debug(f"Compilation failed for config {config} at index {idx} with error: {error}") + continue + assert jit_kernel is not None + _enqueue_benchmark_task(jit_kernel=jit_kernel, config=config, idx=idx) + + _drain_benchmark_results(progress_bar=progress_bar, block=False) + + benchmark_start_event.set() + for worker_queue in benchmark_task_queues: + worker_queue.put(None) + + while benchmark_processed_results < benchmark_expected_results: + _drain_benchmark_results(progress_bar=progress_bar, block=True) + + # Avoid misleading unfinished progress bars when compile failures happen. + progress_bar.total = max(progress_bar.n, benchmark_processed_results) + progress_bar.refresh() + finally: + benchmark_start_event.set() + for worker_queue in benchmark_task_queues: + worker_queue.put(None) + for worker_thread in benchmark_threads: + worker_thread.join(timeout=1.0) + if worker_thread.is_alive(): + logger.warning("Benchmark worker thread did not exit cleanly before shutdown.") + compile_progress.close() + progress_bar.close() + pool.shutdown() + + self.jit_input_tensors = main_thread_benchmark_state.jit_input_tensors + self.ref_input_tensors = main_thread_benchmark_state.ref_input_tensors + self.ref_latency_cache = main_thread_benchmark_state.ref_latency_cache if best_kernel is None: error_msg = "Auto-tuning failed: No configuration successfully compiled and passed benchmarking/validation." 
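For orientation, a hedged usage sketch of the new tuning options as driven through the GEMM example's `get_best_config` (all values illustrative):

```python
result = get_best_config(
    4096, 4096, 4096,
    execution_backend="tvm_ffi",   # grouped compilation currently requires CUDA + tvm_ffi
    use_pipeline=True,             # start benchmarking while later compile units are still running
    enable_grouped_compile=True,
    group_compile_size=2,
    benchmark_multi_gpu=True,
    benchmark_devices=[0, 1],      # benchmark tasks are distributed round-robin by config index
)
print(result.config)
kernel = result.kernel
```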
@@ -809,6 +1245,12 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel | _T: norm_kwargs = _normalize_value(kwargs, sort_dict_items=True) key = (norm_args, norm_kwargs) if key not in self._tuner_cache: + + def jit_elaborate(**config_arg): + merged = dict(kwargs) + merged.update(config_arg) + return self.jit_impl.get_tir(*args, **merged) + if mode == "lazy": def jit_compile(**config_arg): @@ -816,6 +1258,7 @@ def jit_compile(**config_arg): autotuner = self.get_tunner() autotuner.jit_compile = jit_compile + autotuner.jit_elaborate = jit_elaborate autotuner.set_kernel_parameters(key, self.jit_impl.signature.parameters) else: @@ -826,6 +1269,7 @@ def jit_compile(**config_arg): autotuner = self.get_tunner() autotuner.jit_compile = jit_compile + autotuner.jit_elaborate = jit_elaborate autotuner.set_kernel_parameters(key, self.jit_impl.signature.parameters) artifact = autotuner.run() diff --git a/tilelang/engine/lower.py b/tilelang/engine/lower.py index cd433c6e14..b5622963f3 100644 --- a/tilelang/engine/lower.py +++ b/tilelang/engine/lower.py @@ -270,20 +270,13 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) -> return device_mod -def lower( +def lower_to_host_device_ir( func_or_mod: tir.PrimFunc | tvm.IRModule, target: str | Target = "auto", target_host: str | Target | None = None, - runtime_only=False, - enable_host_codegen=False, - enable_device_compile=False, -) -> CompiledArtifact: - """ - enable_host_codegen: whether to enable host codegen, default is False, as we have our - own host codegen implementation in jit. - enable_device_compile: whether to enable device codegen, default is False, as we have our - own device codegen implementation in jit. - """ + runtime_only: bool = False, +) -> tuple[tvm.IRModule, tvm.IRModule, list[KernelParam] | None, Target, Target]: + """Lower input TIR to split host/device IRModules without backend codegen.""" mod = func_or_mod params = None @@ -314,6 +307,32 @@ def lower( host_mod = tir.transform.Filter(_is_host_call)(mod) device_mod = tir.transform.Filter(_is_device_call)(mod) + + return host_mod, device_mod, params, target, target_host + + +def lower( + func_or_mod: tir.PrimFunc | tvm.IRModule, + target: str | Target = "auto", + target_host: str | Target | None = None, + runtime_only=False, + enable_host_codegen=False, + enable_device_compile=False, +) -> CompiledArtifact: + """ + enable_host_codegen: whether to enable host codegen, default is False, as we have our + own host codegen implementation in jit. + enable_device_compile: whether to enable device codegen, default is False, as we have our + own device codegen implementation in jit. + """ + + host_mod, device_mod, params, target, target_host = lower_to_host_device_ir( + func_or_mod=func_or_mod, + target=target, + target_host=target_host, + runtime_only=runtime_only, + ) + codegen_mod = device_codegen(device_mod, target) if enable_device_compile else device_codegen_without_compile(device_mod, target) kernel_source = codegen_mod.inspect_source()
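For reference, a hedged sketch of how the split helper composes with the existing codegen entry points; `prim_func` and the target string are placeholders, and PassContext/target scoping plus error handling are omitted for brevity:

```python
from tilelang.engine.lower import lower_to_host_device_ir, device_codegen, host_codegen

# Split host/device IR without running backend codegen.
host_mod, device_mod, params, target, target_host = lower_to_host_device_ir(
    prim_func, target="cuda")

# Grouped compilation merges several device_mod objects before this single
# device_codegen call; each per-config host module then imports the shared
# device runtime module.
device_rt_mod = device_codegen(device_mod, target)
host_rt_mod = host_codegen(host_mod, target_host, target=target)
host_rt_mod.import_module(device_rt_mod)
```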