Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 53 additions & 10 deletions flashinfer/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,22 +739,57 @@ def func_buffer_requested():
max_num_records = 0
return buffer_size, max_num_records

def set_kernel_name(activity):
    """Return the label used to identify a CUPTI activity record.

    Kernel records are labelled with their own ``name``; memcpy and memset
    records get the fixed labels ``"MEMCPY"`` and ``"MEMSET"``. Any other
    activity kind yields ``None``.
    """
    kind = activity.kind
    if kind == cupti.ActivityKind.CONCURRENT_KERNEL:
        return activity.name
    if kind == cupti.ActivityKind.MEMCPY:
        return "MEMCPY"
    if kind == cupti.ActivityKind.MEMSET:
        return "MEMSET"
    # Unrecognised kinds have no label.
    return None

def get_bytes(activity):
    """Return ``activity.bytes`` for memcpy/memset records, otherwise 0."""
    sized_kinds = (cupti.ActivityKind.MEMCPY, cupti.ActivityKind.MEMSET)
    # Only memcpy/memset records are read for a byte count here.
    return activity.bytes if activity.kind in sized_kinds else 0

def get_copy_kind(activity):
    """Return ``activity.copy_kind`` for memcpy records, otherwise 0."""
    if activity.kind != cupti.ActivityKind.MEMCPY:
        # Non-memcpy records use 0 as a placeholder.
        return 0
    return activity.copy_kind

def get_value(activity):
    """Return ``activity.value`` for memset records, otherwise 0."""
    is_memset = activity.kind == cupti.ActivityKind.MEMSET
    return activity.value if is_memset else 0

def collect_kernel_info(activity):
    """Flatten a CUPTI activity record into a plain 8-tuple.

    Tuple layout (by index):
        0: label from ``set_kernel_name``
        1: ``activity.start``
        2: ``activity.end``
        3: ``activity.correlation_id``
        4: copy kind from ``get_copy_kind`` (0 unless memcpy)
        5: byte count from ``get_bytes`` (0 unless memcpy/memset)
        6: value from ``get_value`` (0 unless memset)
        7: ``activity.kind``
    """
    label = set_kernel_name(activity)
    begin = activity.start
    finish = activity.end
    corr_id = activity.correlation_id
    copy_kind = get_copy_kind(activity)
    num_bytes = get_bytes(activity)
    memset_value = get_value(activity)
    record = (
        label,
        begin,
        finish,
        corr_id,
        copy_kind,
        num_bytes,
        memset_value,
        activity.kind,
    )
    return record

def func_buffer_completed(
launches: list[tuple[float, float, int, int, int]],
kernels: list[tuple[str, float, float, int]],
activities: list,
):
for activity in activities:
if activity.kind == cupti.ActivityKind.CONCURRENT_KERNEL:
if activity.kind in (
cupti.ActivityKind.CONCURRENT_KERNEL,
cupti.ActivityKind.MEMCPY,
cupti.ActivityKind.MEMSET,
):
# Kernel activity
kernels.append(
(
activity.name,
activity.start,
activity.end,
activity.correlation_id,
)
)
kernels.append(collect_kernel_info(activity))
elif activity.kind in (
cupti.ActivityKind.RUNTIME,
cupti.ActivityKind.DRIVER,
Expand Down Expand Up @@ -832,6 +867,8 @@ def func_buffer_completed(
cupti.activity_enable(cupti.ActivityKind.RUNTIME)
cupti.activity_enable(cupti.ActivityKind.CONCURRENT_KERNEL)
cupti.activity_enable(cupti.ActivityKind.DRIVER)
cupti.activity_enable(cupti.ActivityKind.MEMCPY)
cupti.activity_enable(cupti.ActivityKind.MEMSET)
cupti.activity_register_callbacks(
func_buffer_requested, partial(func_buffer_completed, launches, kernels)
)
Expand All @@ -849,8 +886,14 @@ def func_buffer_completed(
cupti.activity_disable(cupti.ActivityKind.RUNTIME)
cupti.activity_disable(cupti.ActivityKind.CONCURRENT_KERNEL)
cupti.activity_disable(cupti.ActivityKind.DRIVER)
cupti.activity_disable(cupti.ActivityKind.MEMCPY)
cupti.activity_disable(cupti.ActivityKind.MEMSET)
cupti.finalize()

def generate_kernel_string(kernel):
    """Build an identity string for a collected kernel-info tuple.

    Only the fields at indices 0, 4, 5, 6 and 7 are used. The fields at
    indices 1-3 (start, end, correlation_id) are deliberately left out so
    that they do not affect the resulting string.
    """
    label = kernel[0]
    copy_kind = kernel[4]
    num_bytes = kernel[5]
    value = kernel[6]
    kind = kernel[7]
    return f"{label}_{copy_kind}_{num_bytes}_{value}_{kind}"
Comment on lines +893 to +895
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using integer indices to access tuple elements (e.g., kernel[0], kernel[4]) makes the code hard to read and brittle to changes in the tuple structure. This pattern is also used elsewhere, such as k[3] on line 905.

To improve readability and maintainability, I recommend using a typing.NamedTuple or dataclasses.dataclass to represent the kernel information. This would allow accessing fields by name (e.g., kernel.name, kernel.copy_kind), making the code self-documenting and more robust.

For example, you could define a NamedTuple within bench_gpu_time_with_cupti:

from typing import NamedTuple, Any

class KernelInfo(NamedTuple):
    name: str
    start: float
    end: float
    correlation_id: int
    copy_kind: int
    bytes: int
    value: int
    kind: Any

Then collect_kernel_info would return a KernelInfo instance, and this function could be rewritten as:

def generate_kernel_string(kernel: KernelInfo):
    # No start, end, correlation_id is considered in the kernel string
    return f"{kernel.name}_{kernel.copy_kind}_{kernel.bytes}_{kernel.value}_{kernel.kind}"


# Process activities
measured_times = []
kernel_names = None
Expand All @@ -862,7 +905,7 @@ def func_buffer_completed(
iter_kernels = [k for k in kernels if k[3] in corr_ids]
if not iter_kernels:
raise ValueError(f"No kernel activities recorded for iteration {idx}")
current_kernel_names = set(k[0] for k in iter_kernels)
current_kernel_names = set(generate_kernel_string(k) for k in iter_kernels)
# check if the kernel names are consistent
if kernel_names is None:
kernel_names = current_kernel_names
Expand Down