-
Notifications
You must be signed in to change notification settings - Fork 6
Add sglang repo in to find_kernel_sources sources #37
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,6 +24,12 @@ class KernelNameParser: | |
| INDUCTOR_PATTERN = re.compile(r'triton_\w+_fused_') | ||
| HIPBLASLT_PATTERN = re.compile(r'wvSplitK|wvSpltK|DeviceGemmWmma') | ||
| AITER_PATTERN = re.compile(r'^_ZN5aiter|aiter::') | ||
| SGLANG_PATTERN = re.compile( | ||
|
Collaborator
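There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this pattern meant to stay in sync with `_is_sglang_kernel` in searcher.py? The two currently differ: this regex is case-sensitive and only matches `sgl_hip` as `^_ZN7sgl_hip` or `sgl_hip::`, whereas `_is_sglang_kernel` lowercases the name and matches a bare `sgl_hip` substring. |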
||
| r'^_ZN7sgl_hip|sgl_hip::|write_req_to_token_pool|' | ||
| r'create_flashinfer_kv_indices|flashinfer|future_token_ids|' | ||
| r'kn_get_mla_metadata|mla_metadata|clamp_position|compute_position|' | ||
| r'set_mla_kv_buffer' | ||
| ) | ||
| ROCM_RUNTIME_PATTERN = re.compile(r'^__amd_rocclr_|^MEMORY_COPY_') | ||
|
|
||
| # Category keywords. | ||
|
|
@@ -108,6 +114,11 @@ def _classify_kind(self, name: str) -> KernelKind: | |
| # Check for aiter kernels (before CK and Triton checks) | ||
| if self.AITER_PATTERN.search(name): | ||
| return KernelKind.AITER | ||
|
|
||
| # SGLang owns HIP/C++ and Triton-emitted runtime kernels; keep the kind | ||
| # generic, like vLLM, and let source search route to $SGLANG_DIR. | ||
| if self.SGLANG_PATTERN.search(name): | ||
| return KernelKind.HIP_CPP | ||
|
|
||
| # Check for CK tile (handles both mangled and readable names) | ||
| if self.CK_PATTERN.search(name): | ||
|
|
@@ -321,12 +332,37 @@ def _parse_hip(self, name: str) -> ParsedKernelName: | |
| extra['namespace'] = match.group(1) | ||
| function_name = match.group(2) | ||
| else: | ||
| if 'act_and_mul_kernel' in name: | ||
| extra['namespace'] = 'sgl_hip' | ||
| function_name = 'act_and_mul_kernel' | ||
| elif 'write_req_to_token_pool_triton' in name: | ||
| function_name = 'write_req_to_token_pool_triton' | ||
| elif 'create_flashinfer_kv_indices_triton' in name: | ||
| function_name = 'create_flashinfer_kv_indices_triton' | ||
| elif 'kn_get_mla_metadata' in name: | ||
| function_name = 'kn_get_mla_metadata' | ||
| elif 'clamp_position_kernel' in name: | ||
| function_name = 'clamp_position_kernel' | ||
| elif 'compute_position_kernel' in name: | ||
| function_name = 'compute_position_kernel' | ||
| elif 'resolve_future_token_ids_kernel' in name: | ||
| function_name = 'resolve_future_token_ids_kernel' | ||
| elif 'set_mla_kv_buffer_kernel' in name: | ||
| function_name = 'set_mla_kv_buffer_kernel' | ||
| elif '_ZN7sgl_hip' in name: | ||
| extra['namespace'] = 'sgl_hip' | ||
| function_name = 'SGLang HIP kernel' | ||
| # Try without namespace | ||
| else: | ||
| match = re.search(r'void (?:\(anonymous namespace\)::)?(\w+)', name) | ||
| if match: | ||
| function_name = match.group(1) | ||
| else: | ||
| function_name = "HIP kernel" | ||
| if function_name == "HIP kernel": | ||
| match = re.search(r'void (\w+)<', name) | ||
| if match: | ||
| function_name = match.group(1) | ||
| else: | ||
| function_name = "HIP kernel" | ||
|
|
||
| # Extract dtype | ||
| if '__hip_bfloat16' in name: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -801,34 +801,48 @@ def search_source(self, parsed: ParsedKernelName) -> Optional[SourceMatch]: | |
|
|
||
| return None | ||
|
|
||
| def search_test(self, parsed: ParsedKernelName, source: Optional[SourceMatch] = None) -> Optional[TestMatch]: | ||
| def search_test( | ||
| self, | ||
| parsed: ParsedKernelName, | ||
| source: Optional[SourceMatch] = None, | ||
| category: Optional[KernelCategory] = None, | ||
| ) -> Optional[TestMatch]: | ||
| """ | ||
| Search for test files and generate test command. | ||
|
|
||
| Args: | ||
| parsed: Parsed kernel name information | ||
| source: Optional source match for context | ||
| category: Optional category fallback for kernels whose kind cannot | ||
| route to a repo-specific test searcher. | ||
|
|
||
| Returns: | ||
| TestMatch if found, None otherwise | ||
| """ | ||
| if parsed.kind == KernelKind.ANNOTATION: | ||
| return None | ||
|
|
||
| match = None | ||
| if parsed.kind == KernelKind.TRITON_JIT: | ||
| return self._search_triton_test(parsed, source) | ||
| match = self._search_triton_test(parsed, source) | ||
| elif parsed.kind == KernelKind.TENSILE_GEMM: | ||
| return self._search_tensile_test(parsed) | ||
| match = self._search_tensile_test(parsed) | ||
| elif parsed.kind == KernelKind.CK_TILE: | ||
| return self._search_ck_test(parsed) | ||
| match = self._search_ck_test(parsed) | ||
| elif parsed.kind == KernelKind.ATEN_NATIVE: | ||
| return self._search_aten_test(parsed) | ||
| match = self._search_aten_test(parsed) | ||
| elif parsed.kind == KernelKind.HIP_CPP: | ||
| return self._search_hip_test(parsed, source) | ||
| match = self._search_hip_test(parsed, source) | ||
| elif parsed.kind == KernelKind.AITER: | ||
| return self._search_aiter_test(parsed, source) | ||
|
|
||
| return None | ||
| match = self._search_aiter_test(parsed, source) | ||
|
|
||
| if match is not None: | ||
| return match | ||
|
|
||
| # Some profiler names lose the suffix/namespace that lets the parser | ||
| # identify their KernelKind. Still route them to a category-level test | ||
| # when we have a stable op category such as MoE GEMM or layernorm. | ||
| return self._search_category_test(category) | ||
|
|
||
| # ------------------------------------------------------------------ | ||
| # Baseline (PyTorch / Triton) reference search | ||
|
|
@@ -930,6 +944,46 @@ def search_baseline_ref( | |
|
|
||
| return None | ||
|
|
||
| def _search_category_test(self, category: Optional[KernelCategory]) -> Optional[TestMatch]: | ||
| """Route unknown-kind kernels to the category's canonical test file.""" | ||
| if category is None: | ||
| return None | ||
|
|
||
| for display_path in CATEGORY_TO_TEST_FILES.get(category, []): | ||
| match = self._test_match_from_display_path(display_path) | ||
| if match is not None: | ||
| return match | ||
|
|
||
| return None | ||
|
|
||
| @staticmethod | ||
| def _test_match_from_display_path(display_path: str) -> Optional[TestMatch]: | ||
| """Create a TestMatch from a `$REPO_DIR/path` category mapping.""" | ||
| if not display_path: | ||
| return None | ||
|
|
||
| repo_var = "" | ||
| test_file = display_path | ||
| if display_path.startswith("$"): | ||
| repo_var, _, test_file = display_path.partition("/") | ||
| if not test_file: | ||
| return None | ||
|
|
||
| if repo_var == "$AITER_DIR": | ||
| test_cmd = f"cd {repo_var} && pytest {test_file} -v" | ||
| elif repo_var == "$VLLM_DIR": | ||
| test_cmd = f"cd {repo_var} && pytest {test_file} -q" | ||
| elif repo_var == "$PYTORCH_DIR": | ||
| test_cmd = f"pytest {display_path} -q" | ||
| else: | ||
| test_cmd = f"pytest {display_path} -q" | ||
|
|
||
| return TestMatch( | ||
| test_file=test_file, | ||
| test_cmd=test_cmd, | ||
| repo_var=repo_var, | ||
| ) | ||
|
|
||
| def search_triton_ref( | ||
| self, | ||
| parsed: ParsedKernelName, | ||
|
|
@@ -1460,6 +1514,12 @@ def _search_hip_source(self, parsed: ParsedKernelName) -> Optional[SourceMatch]: | |
| repo_name="vllm", | ||
| repo_var="$VLLM_DIR", | ||
| ) | ||
|
|
||
| # Check SGLang kernels by namespace/name. | ||
| if namespace == "sgl_hip" or self._is_sglang_kernel(original_name): | ||
| match = self._search_sglang_source(function_name, original_name) | ||
| if match is not None: | ||
| return match | ||
|
|
||
| # Search in rocm-libraries | ||
| rocm_libs = self._repo_var_map.get("$ROCM_LIBRARIES_DIR") | ||
|
|
@@ -1476,6 +1536,65 @@ def _search_hip_source(self, parsed: ParsedKernelName) -> Optional[SourceMatch]: | |
| ) | ||
|
|
||
| return None | ||
|
|
||
| @staticmethod | ||
| def _is_sglang_kernel(name: str) -> bool: | ||
| name_lc = name.lower() | ||
| return any( | ||
| token in name_lc | ||
| for token in ( | ||
| "sgl_hip", | ||
| "write_req_to_token_pool", | ||
| "create_flashinfer_kv_indices", | ||
| "flashinfer", | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hmm, do you think there is a chance that this bare `flashinfer` token might misroute genuine FlashInfer kernels to SGLang/$SGLANG_DIR? |
||
| "future_token_ids", | ||
| "mla_metadata", | ||
| "clamp_position", | ||
| "compute_position", | ||
| "set_mla_kv_buffer", | ||
| ) | ||
| ) | ||
|
|
||
| def _search_sglang_source(self, function_name: str, original_name: str) -> Optional[SourceMatch]: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Shall we make this consistent with the rest? Every other `_search_*_source` method returns None when nothing is found. |
||
| """Search SGLang source for runtime HIP/Triton helper kernels.""" | ||
| sglang_path = self._repo_var_map.get("$SGLANG_DIR") | ||
| if not sglang_path: | ||
| return None | ||
|
|
||
| candidates = [function_name] | ||
| for token in ( | ||
| "write_req_to_token_pool_triton", | ||
| "create_flashinfer_kv_indices_triton", | ||
| "resolve_future_token_ids_kernel", | ||
| "kn_get_mla_metadata", | ||
| "clamp_position_kernel", | ||
| "compute_position_kernel", | ||
| "set_mla_kv_buffer_kernel", | ||
| "act_and_mul_kernel", | ||
| ): | ||
| if token in original_name and token not in candidates: | ||
| candidates.append(token) | ||
|
|
||
| for candidate in candidates: | ||
| if not candidate or candidate == "HIP kernel": | ||
| continue | ||
| pattern = rf"(def|void|__global__|template).*{re.escape(candidate)}|{re.escape(candidate)}" | ||
| files = self._search_files(pattern, sglang_path, ["py", "cpp", "cu", "hip", "hpp"]) | ||
| if files: | ||
| rel_path = os.path.relpath(files[0], sglang_path) | ||
| return SourceMatch( | ||
| file_path=rel_path, | ||
| symbol=candidate, | ||
| repo_name="sglang", | ||
| repo_var="$SGLANG_DIR", | ||
| ) | ||
|
|
||
| return SourceMatch( | ||
| file_path="python/sglang/", | ||
| symbol=function_name, | ||
| repo_name="sglang", | ||
| repo_var="$SGLANG_DIR", | ||
| ) | ||
|
|
||
| def _search_inductor_source(self, parsed: ParsedKernelName) -> Optional[SourceMatch]: | ||
| """Search for torch.inductor generated kernel.""" | ||
|
|
@@ -1669,6 +1788,13 @@ def _search_hip_test(self, parsed: ParsedKernelName, | |
| """Search for HIP/CUDA kernel tests.""" | ||
| namespace = parsed.namespace | ||
| function_name = parsed.function_name | ||
|
|
||
| if (source and source.repo_name == "sglang") or namespace == "sgl_hip" or self._is_sglang_kernel(parsed.original_name): | ||
| return TestMatch( | ||
| test_file="test/", | ||
| test_cmd="cd $SGLANG_DIR && pytest test/ -q", | ||
| repo_var="$SGLANG_DIR", | ||
| ) | ||
| original_name = parsed.original_name | ||
|
|
||
| # Known HIP kernel test mappings | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are these columns always going to be the same, i.e. always 7-19?