Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ experiments/
nccl_thread_sweep_*.log
nccl_thread_sweep_summary_*.txt

# Test data and outputs (tests/gemm_analysis/)
expected_outputs/
testdata/
actual_outputs/

# IDE/project-specific folders
.vscode/
.idea/
Expand Down
314 changes: 157 additions & 157 deletions docs/comprehensive_report.html

Large diffs are not rendered by default.

24 changes: 12 additions & 12 deletions scripts/gemm_analysis/run_tracelens_analysis.sh
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ fi

# Create output directory
OUTPUT_DIR="${SWEEP_DIR}/tracelens_analysis"
# Use patched TraceLens entrypoint with GEMM recognition.
# Assume caller activates desired venv; fall back to python3 on PATH.
PYTHON_BIN="${PYTHON_BIN:-python3}"
TL_WRAPPER="scripts/gemm_analysis/tracelense_with_gemm_path.py"
if ! mkdir -p "$OUTPUT_DIR" 2>/dev/null; then
echo "Error: Cannot create output directory: $OUTPUT_DIR"
echo ""
Expand Down Expand Up @@ -134,14 +138,14 @@ for thread in "${THREAD_CONFIGS[@]}"; do
OUTPUT="$OUTPUT_DIR/$thread/individual_reports/perf_${ch}ch_rank${rank}.xlsx"

echo " Rank ${rank}..."
TraceLens_generate_perf_report_pytorch \
"$PYTHON_BIN" "$TL_WRAPPER" generate_perf_report \
--profile_json_path "$TRACE" \
--output_xlsx_path "$OUTPUT" \
--include_unlinked_kernels \
--short_kernel_study \
--short_kernel_threshold_us 50 \
--topk_ops 100 \
--enable_kernel_summary \
--enable_kernel_summary \
--topk_roofline_ops 100

echo " [OK] $OUTPUT"
Expand Down Expand Up @@ -181,16 +185,12 @@ for thread in "${THREAD_CONFIGS[@]}"; do
echo "Processing $thread/${ch}ch (all 8 ranks)..."

# Use trace_pattern instead of trace_dir for better subdirectory support
# It is not guaranteed that trace files will have the exact same name in all the ranks.
# To avoid file not found errors with `--trace_pattern` flag in TraceLens, we first
# create a directory called `trace` in all rank folders and then mv the respective
# trace file in the rank folder to the canonical `trace/pt.trace.json` path.
# This will satisfy TraceLens's requirement of only one `*` being present in the trace pattern
# while also avoiding FileNotFoundErrors due to different filenames.
find $TRACE_DIR/rank* -name "*.json" -exec sh -c 'mkdir -p "$(dirname "$0")/trace" && mv "$0" "$(dirname "$0")/trace/pt.trace.json"' {} \;

TraceLens_generate_multi_rank_collective_report_pytorch \
--trace_pattern "$TRACE_DIR/rank*/trace/pt.trace.json" \
# Find the trace filename from rank0
SAMPLE_TRACE=$(find "$TRACE_DIR/rank0" -name "*.json" | head -1)
TRACE_FILENAME=$(basename "$SAMPLE_TRACE")

"$PYTHON_BIN" "$TL_WRAPPER" generate_multi_rank_collective \
--trace_pattern "$TRACE_DIR/rank*/$TRACE_FILENAME" \
--world_size 8 \
--output_xlsx_path "$OUTPUT" \
--detailed_analysis \
Expand Down
194 changes: 194 additions & 0 deletions scripts/gemm_analysis/tracelense_with_gemm_path.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
#!/usr/bin/env python3
"""
TraceLens with GEMM Recognition Patches

This script applies GEMM recognition patches and runs TraceLens commands.

Usage:
    python tracelense_with_gemm_path.py generate_perf_report [args...]
    python tracelense_with_gemm_path.py generate_multi_rank_collective [args...]
    python tracelense_with_gemm_path.py compare_perf_reports [args...]
"""

import re
import sys


def apply_gemm_patches():
    """Monkey-patch TraceLens so ROCm Tensile GEMM kernels are recognised.

    Four independent patches are attempted; a failure in one does not stop
    the others:

    1. ``TraceLens.PerfModel.kernel_name_parser``: ``is_rocm_gemm``,
       ``parse_rocm_gemm`` and ``gemm_name_parser`` are replaced.
    2. ``TraceLens.Trace2Tree.util``: ``is_gemm_kernel`` is replaced.
    3. ``TraceLens.util.TraceEventUtils.JaxOpKeys.GemmKeys``: extended with
       additional keys.
    4. ``TraceLens.PerfModel.torch_op_mapping``: ``categorize_torch_op`` is
       wrapped so that rows categorised as ``"other"`` are re-labelled
       ``"GEMM"`` when their first kernel has a Tensile GEMM name.

    Progress is printed to stdout.  A patch whose target cannot be imported
    (or lacks the expected attribute) is reported as a warning instead of
    raising, and the closing summary line names every patch that was not
    applied.

    Returns:
        None.
    """
    print("Applying TraceLens GEMM recognition patches...")

    # Tensile kernel signature such as "Cijk_Alik_Bljk".  Used with
    # ``search`` so arbitrary prefixes/suffixes around it are accepted.
    tensile_gemm_re = re.compile(r"C[a-z]{3}_A[a-z]{3}_B[a-z]{3}")
    # Macro tile token such as "MT64x16x64".
    macro_tile_re = re.compile(r"MT(\d+)x(\d+)x(\d+)")

    # Names of patch targets that could not be patched; drives the summary.
    not_applied = []

    def report_failure(target, reason):
        """Record a patch that could not be applied and warn about it."""
        not_applied.append(target)
        print(f" [WARN] Could not patch {target}: {reason}")

    # Patch kernel_name_parser for enhanced ROCm GEMM recognition
    try:
        from TraceLens.PerfModel import kernel_name_parser
    except ImportError as e:
        report_failure("kernel_name_parser", e)
    else:

        def patched_is_rocm_gemm(kernel_name):
            """
            Enhanced ROCm GEMM pattern matching for Tensile kernels.
            Recognizes: Cijk_Alik_Bljk_... and variants with arbitrary prefixes.
            """
            return tensile_gemm_re.search(kernel_name) is not None

        def patched_parse_rocm_gemm(kernel_name):
            """Parse transpose flags and macro tile sizes from a kernel name.

            Returns a dict with keys ``transpose`` (a ``(trans_a, trans_b)``
            tuple), ``mt_m``, ``mt_n`` and ``depth_u``.  Any value that
            cannot be found in the name is ``None``.
            """
            # Parse transpose flags
            trans_a, trans_b = None, None
            if "_Ailk_" in kernel_name:
                trans_a = False
            elif "_Alik_" in kernel_name:
                trans_a = True
            if "_Bljk_" in kernel_name:
                trans_b = False
            elif "_Bjlk_" in kernel_name:
                trans_b = True

            # Parse macro tile size (MT64x16x64)
            macro_tile_match = macro_tile_re.search(kernel_name)
            if macro_tile_match:
                mt_m, mt_n, depth_u = (int(g) for g in macro_tile_match.groups())
            else:
                mt_m, mt_n, depth_u = None, None, None

            return {
                "transpose": (trans_a, trans_b),
                "mt_m": mt_m,
                "mt_n": mt_n,
                "depth_u": depth_u,
            }

        def patched_gemm_name_parser(kernel_name):
            """Parse a GEMM kernel name; return ``None`` for non-GEMM names.

            ROCm names are handled by the patched parser above; CUDA names
            are delegated to the module's own ``is_cuda_gemm`` /
            ``parse_cuda_gemm``.
            """
            if patched_is_rocm_gemm(kernel_name):
                return patched_parse_rocm_gemm(kernel_name)
            if kernel_name_parser.is_cuda_gemm(kernel_name):
                return kernel_name_parser.parse_cuda_gemm(kernel_name)
            return None

        kernel_name_parser.is_rocm_gemm = patched_is_rocm_gemm
        kernel_name_parser.parse_rocm_gemm = patched_parse_rocm_gemm
        kernel_name_parser.gemm_name_parser = patched_gemm_name_parser

        print(" [OK] Patched kernel_name_parser (ROCm GEMM recognition)")

    # Patch Trace2Tree util for is_gemm_kernel function
    try:
        from TraceLens.Trace2Tree import util as trace_util
    except ImportError as e:
        report_failure("Trace2Tree.util", e)
    else:

        def patched_is_gemm_kernel(kernel_event: dict) -> bool:
            """Return True when a kernel event looks like a GEMM kernel.

            ``kernel_event`` must have ``cat == "kernel"`` and a ``name``.
            """
            # Kept as an assert exactly as in the original patch.
            assert kernel_event["cat"] == "kernel"
            kernel_name = kernel_event["name"]

            # ROCm Tensile GEMM pattern: C[xyz]_A[xyz]_B[xyz]
            is_rocm_gemm = tensile_gemm_re.search(kernel_name) is not None

            # CUDA GEMM pattern
            is_cuda_gemm = kernel_name.startswith("nvjet") or "cublasLt" in kernel_name

            return is_rocm_gemm or is_cuda_gemm

        trace_util.is_gemm_kernel = patched_is_gemm_kernel
        print(" [OK] Patched Trace2Tree.util (is_gemm_kernel)")

    # Patch TraceEventUtils to enhance GEMM keys
    try:
        from TraceLens import util as tracelens_util

        # A missing TraceEventUtils / JaxOpKeys / GemmKeys raises
        # AttributeError here and is reported below instead of being
        # skipped silently.
        jax_op_keys = tracelens_util.TraceEventUtils.JaxOpKeys
        original_gemm_keys = jax_op_keys.GemmKeys
        # NOTE(review): the last entry is a regex; whether TraceLens treats
        # these keys as regexes or plain substrings is not visible here.
        enhanced_gemm_keys = ["Cijk", "gemm", "nvjet", "cublasLt", "C[a-z]{3}_A[a-z]{3}_B[a-z]{3}"]

        # dict.fromkeys de-duplicates while keeping a stable order (original
        # keys first), and unpacking accepts a list or a tuple.
        jax_op_keys.GemmKeys = list(dict.fromkeys([*original_gemm_keys, *enhanced_gemm_keys]))

        print(" [OK] Patched TraceEventUtils.JaxOpKeys (GEMM keys enhanced)")
    except (ImportError, AttributeError) as e:
        report_failure("TraceEventUtils", e)

    # Patch torch_op_mapping for better categorization
    try:
        from TraceLens.PerfModel import torch_op_mapping

        original_categorize = torch_op_mapping.categorize_torch_op
    except (ImportError, AttributeError) as e:
        report_failure("torch_op_mapping", e)
    else:

        def patched_categorize_torch_op(row):
            """Categorise ``row``, upgrading ``"other"`` to ``"GEMM"``.

            The original categoriser runs first.  Only when it answers
            ``"other"`` is the name of the row's first kernel checked
            against the Tensile GEMM pattern.
            """
            result = original_categorize(row)
            if result != "other" or "kernel_details" not in row:
                return result

            kernel_details = row["kernel_details"]
            if len(kernel_details) > 0 and tensile_gemm_re.search(kernel_details[0]["name"]):
                return "GEMM"
            return result

        torch_op_mapping.categorize_torch_op = patched_categorize_torch_op
        print(" [OK] Patched torch_op_mapping (categorize_torch_op)")

    # Only claim success when every patch really was applied.
    if not_applied:
        print(f"[WARN] GEMM patches not applied: {', '.join(not_applied)}\n")
    else:
        print("[OK] All GEMM patches applied successfully!\n")


def main():
    """Command-line entry point: patch TraceLens, then run one of its tools.

    ``sys.argv[1]`` selects the TraceLens command; the remaining arguments
    are handed to that command unchanged.  The command name is validated
    before any TraceLens code is patched or imported, and only the module
    for the requested command is imported.

    Exits with status 1 when no command, or an unknown command, is given.
    """
    import importlib
    import os

    # command name -> (module providing ``main``, one-line description)
    commands = {
        "generate_perf_report": (
            "TraceLens.Reporting.generate_perf_report_pytorch",
            "Generate individual performance report",
        ),
        "generate_multi_rank_collective": (
            "TraceLens.Reporting.generate_multi_rank_collective_report_pytorch",
            "Generate multi-rank collective report",
        ),
        "compare_perf_reports": (
            "TraceLens.Reporting.compare_perf_reports_pytorch",
            "Compare performance reports",
        ),
    }

    if len(sys.argv) < 2:
        # Show the name the script was actually invoked with, so the usage
        # line can be copied and run as-is.
        script_name = os.path.basename(sys.argv[0]) if sys.argv and sys.argv[0] else "tracelense_with_gemm_path.py"
        print(f"Usage: {script_name} <command> [args...]")
        print("")
        print("Commands:")
        for name, (_, description) in commands.items():
            print(f" {name} - {description}")
        sys.exit(1)

    command = sys.argv[1]
    if command not in commands:
        print(f"Error: Unknown command '{command}'")
        print("")
        print("Available commands:")
        for name in commands:
            print(f" {name}")
        sys.exit(1)

    # Apply patches FIRST, before the TraceLens reporting module is imported
    apply_gemm_patches()

    # Import only the requested TraceLens module - it will use the patched functions
    module_path, _ = commands[command]
    command_main = importlib.import_module(module_path).main

    # Verify the categorize_torch_op patch is still in effect after the
    # import.  The wrapper installed by apply_gemm_patches is a function
    # named "patched_categorize_torch_op".
    from TraceLens.PerfModel import torch_op_mapping

    active_name = getattr(torch_op_mapping.categorize_torch_op, "__name__", None)
    if active_name is not None:
        print(f"[VERIFY] categorize_torch_op function name: {active_name}")
    if active_name == "patched_categorize_torch_op":
        print("[VERIFY] categorize_torch_op patch confirmed active!")
    else:
        print("[VERIFY] categorize_torch_op patch may not be active!")

    # Remove the command from argv so TraceLens sees only its args
    sys.argv = [sys.argv[0]] + sys.argv[2:]

    command_main()


if __name__ == "__main__":
main()
24 changes: 24 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,30 @@
import pytest


def pytest_configure(config):
    """Register the custom markers used across the test suite.

    Appends one line per marker (``integration``, ``slow``, ``gpu``) to the
    ``markers`` ini setting of *config*, in that order.
    """
    marker_lines = (
        "integration: mark test as integration test (deselect with '-m \"not integration\"')",
        "slow: mark test as slow running",
        "gpu: mark test as requiring GPU",
    )
    for marker_line in marker_lines:
        config.addinivalue_line("markers", marker_line)


def pytest_addoption(parser):
    """Register the project's custom command-line options.

    Adds a single boolean flag, ``--generate-baseline``, which is off unless
    given on the command line.
    """
    # Option used by the GEMM regression tests.
    baseline_flag = {
        "action": "store_true",
        "default": False,
        "help": "Generate baseline expected outputs for GEMM regression tests",
    }
    parser.addoption("--generate-baseline", **baseline_flag)


@pytest.fixture
def sample_trace_event():
"""Create a sample trace event for testing."""
Expand Down
Loading