Skip to content

Commit

Permalink
[Profile] Add pytorch profiler (#1604)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ying1123 authored Oct 7, 2024
1 parent ebbc42d commit c5325ab
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 0 deletions.
12 changes: 12 additions & 0 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
is_generation_model,
is_multimodal_model,
kill_parent_process,
pytorch_profile,
set_random_seed,
suppress_other_loggers,
)
Expand Down Expand Up @@ -409,6 +410,10 @@ def run_step(self):
new_batch = self.get_new_batch_prefill()
if new_batch is not None:
# Run a new prefill batch
# replace run_batch with the uncommented line to use pytorch profiler
# result = pytorch_profile(
# "profile_prefill_step", self.run_batch, new_batch, data_size=len(new_batch.reqs)
# )
result = self.run_batch(new_batch)
self.process_batch_result(new_batch, result)
else:
Expand All @@ -418,6 +423,13 @@ def run_step(self):
batch = self.get_new_batch_decode()

if batch:
# replace run_batch with the uncommented line to use pytorch profiler
# result = pytorch_profile(
# "profile_decode_step",
# self.run_batch,
# batch,
# data_size=len(batch.reqs),
# )
result = self.run_batch(batch)
self.process_batch_result(batch, result)

Expand Down
33 changes: 33 additions & 0 deletions python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import base64
import ipaddress
import json
import logging
import os
import pickle
Expand All @@ -37,6 +38,7 @@
from fastapi.responses import JSONResponse
from packaging import version as pkg_version
from torch import nn
from torch.profiler import ProfilerActivity, profile, record_function
from triton.runtime.cache import (
FileCacheManager,
default_cache_dir,
Expand Down Expand Up @@ -642,3 +644,34 @@ def broadcast_pyobj(
serialized_data = bytes(tensor_data.cpu().numpy())
data = pickle.loads(serialized_data)
return data


step_counter = 0


def pytorch_profile(name, func, *args, data_size=-1):
"""
Args:
name (string): the name of recorded function.
func: the function to be profiled.
args: the arguments of the profiled function.
data_size (int): some measurement of the computation complexity.
Usually, it could be the batch size.
"""
global step_counter
os.makedirs("trace", exist_ok=True)
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
# schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
# on_trace_ready=tensorboard_trace_handler('./log_dir'),
record_shapes=True,
profile_memory=True,
with_stack=True,
) as prof:
with record_function(name):
with open(f"trace/size_{step_counter}.json", "w") as f:
json.dump({"size": data_size}, f)
result = func(*args)
prof.export_chrome_trace(f"trace/{name}_{step_counter}.json")
step_counter += 1
return result
40 changes: 40 additions & 0 deletions scripts/fix_corrupted_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import json
import re
import sys


def clean_json_file(input_file, output_file):
try:
# Open the input file with 'replace' option for handling bad characters
with open(input_file, "r", encoding="utf-8", errors="replace") as f:
data = f.read()

# Replace bad characters (represented by '�' after decoding) with a space
cleaned_data = data.replace("�", " ")

# Remove control characters (e.g., ASCII control characters like \x00 to \x1F)
# These can cause issues in JSON parsing.
cleaned_data = re.sub(r"[\x00-\x1F]+", " ", cleaned_data)

# Parse cleaned data as JSON
json_data = json.loads(cleaned_data)

# Write the cleaned JSON to a new output file
with open(output_file, "w", encoding="utf-8") as f:
json.dump(json_data, f, ensure_ascii=False, indent=4)

print(f"Cleaned JSON file has been saved to {output_file}")

except Exception as e:
print(f"Error: {e}")


if __name__ == "__main__":
assert len(sys.argv) > 1, "please give the input file path"
if len(sys.argv) == 3:
input_file = sys.argv[1]
output_file = sys.argv[2]
else:
input_file = output_file = sys.argv[1]

clean_json_file(input_file, output_file)
77 changes: 77 additions & 0 deletions scripts/playground/lora/analyzer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import glob
import json
import os
import re
import sys

from tqdm import tqdm

sys.path.append("../../")
from fix_corrupted_json import clean_json_file

dirpath = "/Users/ying"
output_file_prefix = "analyzed_log"

time = {}
tot_time = {}
size = {}

os.system(f"rm {output_file_prefix}*")

for dirname in glob.glob(os.path.join(dirpath, "trace*")):
print(dirname)
trace_name = dirname.split("/")[-1]
time[trace_name] = {}
size[trace_name] = {}
total_time = 0
for filename in tqdm(glob.glob(os.path.join(dirname, "*.json"))):
step_name = filename.split("/")[-1].split(".")[0]
step_name = "_".join(step_name.split("_")[1:])
if "prefill" not in filename and "decode" not in filename:
continue

match = re.search(r"(prefill|decode)_step_(\d+)\.json", filename)
if match:
phase = match.group(1)
step = match.group(2)
else:
raise Exception(f"Cannot parse {filename}")

try:
with open(filename, "r") as f:
trace = json.load(f)
except:
clean_json_file(filename, filename)
with open(filename, "r") as f:
trace = json.load(f)

for event in trace["traceEvents"]:
name = event["name"]
if name in ["profile_prefill_step", "profile_decode_step"]:
dur = event["dur"] / 1e3
time[trace_name][step_name] = dur
break
total_time += dur

step = int(step_name.split("_")[-1])
with open(os.path.join(dirname, f"size_{step}.json"), "r") as f:
size_info = json.load(f)
size[trace_name][step_name] = size_info["size"]

tot_time[trace_name] = total_time
time[trace_name] = dict(
sorted(time[trace_name].items(), key=lambda x: int(x[0].split("_")[-1]))
)
size[trace_name] = dict(
sorted(size[trace_name].items(), key=lambda x: int(x[0].split("_")[-1]))
)

with open(f"{output_file_prefix}_{trace_name}", "a") as f:
for k, v in time[trace_name].items():
size_v = size[trace_name][k]
print(f"{k:>15}{v:10.2f}\t{size_v}")
f.write(f"{k:>15}{v:10.2f}\t{size_v}\n")

with open(f"{output_file_prefix}_total_time", "w") as f:
print(tot_time)
json.dump(tot_time, f)

0 comments on commit c5325ab

Please sign in to comment.