Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Magpie/modes/benchmark/gap_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,10 @@ def to_csv(
if find_kernel_sources:
from Magpie.tools.amd_kernel_finder import KernelSourceInfo
headers.extend(KernelSourceInfo.csv_headers())
f.write(
"# Warning: columns 7-19 from find_kernel_sources/testcase "

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these columns always going to be the same, i.e. always 7-19?

"are experimental.\n"
)

writer.writerow(headers)

Expand Down
2 changes: 1 addition & 1 deletion Magpie/tools/amd_kernel_finder/finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def search(self, kernel_name: str) -> KernelSourceInfo:
source_match = self.searcher.search_source(parsed)

# Search for test
test_match = self.searcher.search_test(parsed, source_match)
test_match = self.searcher.search_test(parsed, source_match, category=category)

# Search for PyTorch eager baseline reference. We hand over the
# already-computed test_match + category so the searcher does not
Expand Down
40 changes: 38 additions & 2 deletions Magpie/tools/amd_kernel_finder/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ class KernelNameParser:
INDUCTOR_PATTERN = re.compile(r'triton_\w+_fused_')
HIPBLASLT_PATTERN = re.compile(r'wvSplitK|wvSpltK|DeviceGemmWmma')
AITER_PATTERN = re.compile(r'^_ZN5aiter|aiter::')
SGLANG_PATTERN = re.compile(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any correlation with `_is_sglang_kernel` in searcher.py? The regex here differs from it in a few places.

r'^_ZN7sgl_hip|sgl_hip::|write_req_to_token_pool|'
r'create_flashinfer_kv_indices|flashinfer|future_token_ids|'
r'kn_get_mla_metadata|mla_metadata|clamp_position|compute_position|'
r'set_mla_kv_buffer'
)
ROCM_RUNTIME_PATTERN = re.compile(r'^__amd_rocclr_|^MEMORY_COPY_')

# Category keywords.
Expand Down Expand Up @@ -108,6 +114,11 @@ def _classify_kind(self, name: str) -> KernelKind:
# Check for aiter kernels (before CK and Triton checks)
if self.AITER_PATTERN.search(name):
return KernelKind.AITER

# SGLang owns HIP/C++ and Triton-emitted runtime kernels; keep the kind
# generic, like vLLM, and let source search route to $SGLANG_DIR.
if self.SGLANG_PATTERN.search(name):
return KernelKind.HIP_CPP

# Check for CK tile (handles both mangled and readable names)
if self.CK_PATTERN.search(name):
Expand Down Expand Up @@ -321,12 +332,37 @@ def _parse_hip(self, name: str) -> ParsedKernelName:
extra['namespace'] = match.group(1)
function_name = match.group(2)
else:
if 'act_and_mul_kernel' in name:
extra['namespace'] = 'sgl_hip'
function_name = 'act_and_mul_kernel'
elif 'write_req_to_token_pool_triton' in name:
function_name = 'write_req_to_token_pool_triton'
elif 'create_flashinfer_kv_indices_triton' in name:
function_name = 'create_flashinfer_kv_indices_triton'
elif 'kn_get_mla_metadata' in name:
function_name = 'kn_get_mla_metadata'
elif 'clamp_position_kernel' in name:
function_name = 'clamp_position_kernel'
elif 'compute_position_kernel' in name:
function_name = 'compute_position_kernel'
elif 'resolve_future_token_ids_kernel' in name:
function_name = 'resolve_future_token_ids_kernel'
elif 'set_mla_kv_buffer_kernel' in name:
function_name = 'set_mla_kv_buffer_kernel'
elif '_ZN7sgl_hip' in name:
extra['namespace'] = 'sgl_hip'
function_name = 'SGLang HIP kernel'
# Try without namespace
else:
match = re.search(r'void (?:\(anonymous namespace\)::)?(\w+)', name)
if match:
function_name = match.group(1)
else:
function_name = "HIP kernel"
if function_name == "HIP kernel":
match = re.search(r'void (\w+)<', name)
if match:
function_name = match.group(1)
else:
function_name = "HIP kernel"

# Extract dtype
if '__hip_bfloat16' in name:
Expand Down
29 changes: 29 additions & 0 deletions Magpie/tools/amd_kernel_finder/repo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,26 @@ class RepoConfig:
],
},
),
"sglang": RepoConfig(
name="sglang",
var_name="$SGLANG_DIR",
github_base="https://github.com/sgl-project/sglang",
source_paths={
"sglang": [
"python/sglang/",
"sglang/",
"sgl-kernel/",
"csrc/",
],
},
test_paths={
"sglang": [
"test/",
"tests/",
"sgl-kernel/tests/",
],
},
),
}


Expand All @@ -268,6 +288,7 @@ class RepoConfig:
"$VLLM_DIR": ("$VLLM_DIR", ""), # vllm is its own repo
"$PYTORCH_DIR": ("$PYTORCH_DIR", ""), # pytorch is its own repo
"$AITER_DIR": ("$AITER_DIR", ""), # aiter is its own repo
"$SGLANG_DIR": ("$SGLANG_DIR", ""), # sglang is its own repo
"$ROCM_SYSTEMS_DIR": ("$ROCM_SYSTEMS_DIR", ""), # rocm-systems super-repo (clr, hip, rocprofiler, etc)
}

Expand All @@ -280,6 +301,7 @@ class RepoConfig:
"vllm": "https://github.com/vllm-project/vllm/blob/main/{path}",
"pytorch": "https://github.com/pytorch/pytorch/blob/main/{path}",
"aiter": "https://github.com/ROCm/aiter/blob/main/{path}",
"sglang": "https://github.com/sgl-project/sglang/blob/main/{path}",
}


Expand Down Expand Up @@ -313,6 +335,13 @@ def detect_repo_type(repo_path: str) -> Optional[str]:
# Check for pytorch
if (path / "aten").exists() and (path / "torch").exists():
return "pytorch"

# Check for SGLang
if (
(path / "python" / "sglang").exists()
or (path / "sglang").exists()
) and ((path / "sgl-kernel").exists() or (path / "test").exists() or (path / "tests").exists()):
return "sglang"

return None

Expand Down
23 changes: 22 additions & 1 deletion Magpie/tools/amd_kernel_finder/repo_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"pytorch": "https://github.com/pytorch/pytorch.git",
"rocm-systems": "https://github.com/ROCm/rocm-systems.git", # ROCm super-repo (clr, hip, rocprofiler, etc)
"aiter": "https://github.com/ROCm/aiter.git",
"sglang": "https://github.com/sgl-project/sglang.git",
}

KERNEL_REPO_MAP = {
Expand All @@ -38,7 +39,7 @@
}

# All repos to clone when force_all is True
ALL_REPOS = ["rocm-libraries", "triton", "vllm", "pytorch", "aiter", "rocm-systems"]
ALL_REPOS = ["rocm-libraries", "triton", "vllm", "pytorch", "aiter", "sglang", "rocm-systems"]



Expand Down Expand Up @@ -158,13 +159,31 @@ def get_repos_for_kernels(self, kernel_names: List[str], force_all: bool = False
parser = KernelNameParser()
kinds = set()
has_vllm = False
has_sglang = False

for name in kernel_names:
parsed = parser.parse(name)
kinds.add(parsed.kind.value)

if "vllm::" in name or "vllm" in name.lower():
has_vllm = True
name_lc = name.lower()
if any(
token in name_lc
for token in (
"sglang",
"sgl_hip",
"write_req_to_token_pool",
"create_flashinfer_kv_indices",
"flashinfer",
"future_token_ids",
"mla_metadata",
"clamp_position",
"compute_position",
"set_mla_kv_buffer",
)
):
has_sglang = True

repos = set()
for kind in kinds:
Expand All @@ -173,6 +192,8 @@ def get_repos_for_kernels(self, kernel_names: List[str], force_all: bool = False

if has_vllm:
repos.add("vllm")
if has_sglang:
repos.add("sglang")

return list(repos)

Expand Down
144 changes: 135 additions & 9 deletions Magpie/tools/amd_kernel_finder/searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,34 +801,48 @@ def search_source(self, parsed: ParsedKernelName) -> Optional[SourceMatch]:

return None

def search_test(
    self,
    parsed: ParsedKernelName,
    source: Optional[SourceMatch] = None,
    category: Optional[KernelCategory] = None,
) -> Optional[TestMatch]:
    """
    Search for test files and generate test command.

    The kernel kind selects a kind-specific test searcher first. When that
    searcher finds nothing (or the kind has no searcher), the lookup falls
    back to the category-level test mapping.

    Args:
        parsed: Parsed kernel name information
        source: Optional source match for context
        category: Optional category fallback for kernels whose kind cannot
            route to a repo-specific test searcher.

    Returns:
        TestMatch if found, None otherwise
    """
    # Annotations never get a test, not even a category-level one.
    if parsed.kind == KernelKind.ANNOTATION:
        return None

    match = None
    if parsed.kind == KernelKind.TRITON_JIT:
        match = self._search_triton_test(parsed, source)
    elif parsed.kind == KernelKind.TENSILE_GEMM:
        match = self._search_tensile_test(parsed)
    elif parsed.kind == KernelKind.CK_TILE:
        match = self._search_ck_test(parsed)
    elif parsed.kind == KernelKind.ATEN_NATIVE:
        match = self._search_aten_test(parsed)
    elif parsed.kind == KernelKind.HIP_CPP:
        match = self._search_hip_test(parsed, source)
    elif parsed.kind == KernelKind.AITER:
        match = self._search_aiter_test(parsed, source)

    if match is not None:
        return match

    # Some profiler names lose the suffix/namespace that lets the parser
    # identify their KernelKind. Still route them to a category-level test
    # when we have a stable op category such as MoE GEMM or layernorm.
    return self._search_category_test(category)

# ------------------------------------------------------------------
# Baseline (PyTorch / Triton) reference search
Expand Down Expand Up @@ -930,6 +944,46 @@ def search_baseline_ref(

return None

def _search_category_test(self, category: Optional[KernelCategory]) -> Optional[TestMatch]:
"""Route unknown-kind kernels to the category's canonical test file."""
if category is None:
return None

for display_path in CATEGORY_TO_TEST_FILES.get(category, []):
match = self._test_match_from_display_path(display_path)
if match is not None:
return match

return None

@staticmethod
def _test_match_from_display_path(display_path: str) -> Optional[TestMatch]:
"""Create a TestMatch from a `$REPO_DIR/path` category mapping."""
if not display_path:
return None

repo_var = ""
test_file = display_path
if display_path.startswith("$"):
repo_var, _, test_file = display_path.partition("/")
if not test_file:
return None

if repo_var == "$AITER_DIR":
test_cmd = f"cd {repo_var} && pytest {test_file} -v"
elif repo_var == "$VLLM_DIR":
test_cmd = f"cd {repo_var} && pytest {test_file} -q"
elif repo_var == "$PYTORCH_DIR":
test_cmd = f"pytest {display_path} -q"
else:
test_cmd = f"pytest {display_path} -q"

return TestMatch(
test_file=test_file,
test_cmd=test_cmd,
repo_var=repo_var,
)

def search_triton_ref(
self,
parsed: ParsedKernelName,
Expand Down Expand Up @@ -1460,6 +1514,12 @@ def _search_hip_source(self, parsed: ParsedKernelName) -> Optional[SourceMatch]:
repo_name="vllm",
repo_var="$VLLM_DIR",
)

# Check SGLang kernels by namespace/name.
if namespace == "sgl_hip" or self._is_sglang_kernel(original_name):
match = self._search_sglang_source(function_name, original_name)
if match is not None:
return match

# Search in rocm-libraries
rocm_libs = self._repo_var_map.get("$ROCM_LIBRARIES_DIR")
Expand All @@ -1476,6 +1536,65 @@ def _search_hip_source(self, parsed: ParsedKernelName) -> Optional[SourceMatch]:
)

return None

@staticmethod
def _is_sglang_kernel(name: str) -> bool:
name_lc = name.lower()
return any(
token in name_lc
for token in (
"sgl_hip",
"write_req_to_token_pool",
"create_flashinfer_kv_indices",
"flashinfer",

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm do you think there is a chance that this might misroute FlashInfer kernels into SGLang/$SGLANG_DIR

"future_token_ids",
"mla_metadata",
"clamp_position",
"compute_position",
"set_mla_kv_buffer",
)
)

def _search_sglang_source(self, function_name: str, original_name: str) -> Optional[SourceMatch]:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we make this consistent with rest? Every other search*_source returns None when nothing is found

"""Search SGLang source for runtime HIP/Triton helper kernels."""
sglang_path = self._repo_var_map.get("$SGLANG_DIR")
if not sglang_path:
return None

candidates = [function_name]
for token in (
"write_req_to_token_pool_triton",
"create_flashinfer_kv_indices_triton",
"resolve_future_token_ids_kernel",
"kn_get_mla_metadata",
"clamp_position_kernel",
"compute_position_kernel",
"set_mla_kv_buffer_kernel",
"act_and_mul_kernel",
):
if token in original_name and token not in candidates:
candidates.append(token)

for candidate in candidates:
if not candidate or candidate == "HIP kernel":
continue
pattern = rf"(def|void|__global__|template).*{re.escape(candidate)}|{re.escape(candidate)}"
files = self._search_files(pattern, sglang_path, ["py", "cpp", "cu", "hip", "hpp"])
if files:
rel_path = os.path.relpath(files[0], sglang_path)
return SourceMatch(
file_path=rel_path,
symbol=candidate,
repo_name="sglang",
repo_var="$SGLANG_DIR",
)

return SourceMatch(
file_path="python/sglang/",
symbol=function_name,
repo_name="sglang",
repo_var="$SGLANG_DIR",
)

def _search_inductor_source(self, parsed: ParsedKernelName) -> Optional[SourceMatch]:
"""Search for torch.inductor generated kernel."""
Expand Down Expand Up @@ -1669,6 +1788,13 @@ def _search_hip_test(self, parsed: ParsedKernelName,
"""Search for HIP/CUDA kernel tests."""
namespace = parsed.namespace
function_name = parsed.function_name

if (source and source.repo_name == "sglang") or namespace == "sgl_hip" or self._is_sglang_kernel(parsed.original_name):
return TestMatch(
test_file="test/",
test_cmd="cd $SGLANG_DIR && pytest test/ -q",
repo_var="$SGLANG_DIR",
)
original_name = parsed.original_name

# Known HIP kernel test mappings
Expand Down
Loading