# megatile

A mega kernel implementation based on TileLang, inspired by Triton-distributed's `mega_triton_kernel`. It fuses multiple GPU kernels into a single mega kernel using TileLang.
## Installation

```shell
pip install tilelang
pip install -e .
```

## Quick Start

```python
import torch
from megatile import ModelBuilder

builder = ModelBuilder(num_warps=4)

input = torch.randn(1024, 512, dtype=torch.bfloat16, device="cuda").contiguous()
weight = torch.randn(1024, 512, dtype=torch.bfloat16, device="cuda").contiguous()
output = torch.empty(1024, 1024, dtype=torch.bfloat16, device="cuda").contiguous()

builder.make_linear(input, weight, output, layer_id=0)
builder.compile()
builder.run()
```

## Project Structure

- `core/` - task management, scheduling, code generation
- `tasks/` - task builders (e.g., linear)
- `kernels/` - TileLang kernel implementations
- `models/` - ModelBuilder API
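The shapes in the Quick Start (input `(1024, 512)`, weight `(1024, 512)`, output `(1024, 1024)`) suggest the usual `nn.Linear` convention, y = x·Wᵀ. A minimal NumPy sketch of the reference computation the fused linear task would perform, assuming that convention holds for `make_linear`:

```python
import numpy as np

# Shapes mirror the Quick Start example: x is (1024, 512), w is (1024, 512).
x = np.random.randn(1024, 512).astype(np.float32)
w = np.random.randn(1024, 512).astype(np.float32)

# Assumed nn.Linear semantics: y = x @ w.T, giving a (1024, 1024) result
# that matches the pre-allocated `output` buffer in the Quick Start.
y = x @ w.T
print(y.shape)
```

This is only a host-side reference for checking shapes and numerics; the actual computation runs inside the compiled TileLang mega kernel.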
## Acknowledgments

This project is based on Triton-distributed's `mega_triton_kernel` implementation. Thanks to the ByteDance Seed team for their excellent work.