support prometheus metrics #1853

Merged · 14 commits · Nov 6, 2024
14 changes: 14 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -31,6 +31,7 @@

import dataclasses
import logging
import time
from typing import List, Optional, Tuple, Union

import torch
@@ -254,6 +255,16 @@ def __init__(
# For Qwen2-VL
self.mrope_position_delta = [] # use mutable object

# Lifetime traces
# time when the request is created and added to the waiting queue
self.created_time = None
# time when the request is added to a prefill batch
self.queued_time = None
# time when the request starts being processed
self.started_time = None
# time when the request is finished
self.finished_time = None

# whether request reached finished condition
def finished(self) -> bool:
return self.finished_reason is not None
@@ -1028,6 +1039,9 @@ def __str__(self):
f"#req={(len(self.reqs))})"
)

def mark_reqs_started(self):
for req in self.reqs:
req.started_time = time.time()

@dataclasses.dataclass
class ModelWorkerBatch:
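The four lifetime timestamps added above are what Scheduler.get_stats() (further down in this diff) turns into request-level latency metrics. As an illustration only — this helper is hypothetical and not part of the PR — the mapping is roughly:

import time

def request_latencies(req, now=None):
    """Hypothetical helper (not in this PR): shows how the lifetime
    timestamps above map onto the latencies reported by Scheduler.get_stats()."""
    now = now if now is not None else time.time()
    return {
        "waiting": req.queued_time - req.created_time,  # time spent in the waiting queue
        "ttft": now - req.started_time,                 # time to first token, sampled while the batch is extending
        "e2e": now - req.created_time,                  # end-to-end latency, sampled once req.finished() is True
    }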
3 changes: 3 additions & 0 deletions python/sglang/srt/managers/schedule_policy.py
@@ -17,6 +17,7 @@

import os
import random
import time
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum, auto
@@ -306,6 +307,7 @@ def add_one_req(self, req: Req):
):
# Non-chunked prefill
self.can_run_list.append(req)
req.queued_time = time.time()
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(
prefix_len,
@@ -324,6 +326,7 @@
req.extend_input_len = trunc_len
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
self.can_run_list.append(req)
req.queued_time = time.time()
self.new_inflight_req = req
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(prefix_len, trunc_len, 0)
112 changes: 109 additions & 3 deletions python/sglang/srt/managers/scheduler.py
@@ -62,6 +62,8 @@
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.metrics_collector import PrometheusMetricsCollector
from sglang.srt.metrics.metrics_types import Stats
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
broadcast_pyobj,
@@ -222,7 +224,8 @@ def __init__(
self.forward_ct = 0
self.forward_ct_decode = 0
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
self.last_stats_tic = time.time()  # time of the last metrics stats collection (every iteration)
self.last_log_tic = time.time()  # time of the last decode log (print_decode_stats)
self.stream_interval = server_args.stream_interval

# Init chunked prefill
@@ -291,6 +294,15 @@ def __init__(
],
with_stack=True,
)
# Init metrics stats
self.stats = Stats()
self.metrics_collector = PrometheusMetricsCollector(
labels={
"model_name": self.server_args.served_model_name,
# TODO: Add lora name/path in the future,
},
max_model_len=self.max_total_num_tokens,
)
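PrometheusMetricsCollector and Stats come from the new sglang.srt.metrics package, whose files are not included in this excerpt. As a rough sketch of the collector's likely shape only, assuming the standard prometheus_client library — the metric names and the subset of fields shown are illustrative, not the PR's actual implementation:

from typing import Dict
from prometheus_client import Gauge, Histogram

class PrometheusMetricsCollector:  # sketch only; the real class lives in sglang/srt/metrics/metrics_collector.py
    def __init__(self, labels: Dict[str, str], max_model_len: int):
        self.labels = labels
        self.max_model_len = max_model_len  # e.g. for histogram bucket sizing; unused in this sketch
        label_names = list(labels.keys())
        self.num_running_req = Gauge(
            "sglang_num_running_req", "Number of requests currently running", label_names
        )
        self.time_e2e_requests = Histogram(
            "sglang_e2e_request_latency_seconds", "End-to-end request latency in seconds", label_names
        )

    def log_stats(self, stats) -> None:
        self.num_running_req.labels(**self.labels).set(stats.num_running_req)
        for latency in stats.time_e2e_requests:
            self.time_e2e_requests.labels(**self.labels).observe(latency)

Exposing the metrics would then only require prometheus_client's standard exporters (start_http_server or make_asgi_app); how the PR actually wires that into the server is outside this excerpt.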

def watchdog_thread(self):
self.watchdog_last_forward_ct = 0
@@ -338,6 +350,11 @@ def event_loop_normal(self):
else:
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
# log stats
if self.is_generation and self.server_args.enable_metrics:
stats = self.get_stats(batch)
self.log_stats(stats)
self.last_stats_tic = time.time()

self.last_batch = batch

@@ -476,6 +493,7 @@ def handle_generate_request(
self.max_req_len - len(req.origin_input_ids) - 1,
)

req.created_time = time.time()
self.waiting_queue.append(req)

def handle_embedding_request(
@@ -504,9 +522,11 @@ def print_decode_stats(self):
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
)
throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
throughput = self.num_generated_tokens / (time.time() - self.last_log_tic)
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
self.last_log_tic = time.time()
# set system stats
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
logger.info(
f"Decode batch. "
@@ -676,6 +696,9 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
# set system stats
self.stats.cache_hit_rate = round(100.0 * tree_cache_hit_rate, 2)
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)

if num_mixed_running > 0:
logger.info(
@@ -770,6 +793,7 @@ def run_batch(self, batch: ScheduleBatch):
if self.is_generation:
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
model_worker_batch = batch.get_model_worker_batch()
batch.mark_reqs_started()
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
model_worker_batch
)
@@ -789,6 +813,88 @@
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
ret = embeddings, model_worker_batch.bid
return ret

def get_stats(self, batch: ScheduleBatch):
# TODO: get stats for chunked prefill

now = time.time()
# system stats
# Scheduler State
new_seq: int = 0
num_running_req = len(self.running_batch.reqs) if self.running_batch else 0
num_waiting_req = len(self.waiting_queue)
# Cache State
cache_hit_rate: float = 0.0
token_usage: float = 0.0

# set stats from prefill
if self.stats is not None:
# new_seq = self.stats.new_seq
cache_hit_rate = self.stats.cache_hit_rate
token_usage = self.stats.token_usage
# Iteration stats
num_prompt_tokens_iter = 0
num_generation_tokens_iter = 0
time_to_first_tokens_iter: List[float] = []
time_per_output_tokens_iter: List[float] = []

# Request stats
# Decode
gen_throughput: float = 0.0
# Latency
time_e2e_requests: List[float] = []
time_waiting_requests: List[float] = []
# Metadata
num_prompt_tokens_requests: List[int] = []
num_generation_tokens_requests: List[int] = []
finished_reason_requests: List[str] = []

# _, next_token_ids, _ = result
if batch is not None:
num_generation_tokens_iter = len(batch.output_ids)
gen_throughput = round(num_generation_tokens_iter / (now - self.last_stats_tic), 2)

for i, req in enumerate(batch.reqs):
# NOTE: the batch forward mode is extend before decoding starts,
if batch.forward_mode.is_extend():
num_prompt_tokens_iter = len(batch.input_ids) + sum(batch.prefix_lens)
time_to_first_tokens_iter.append(now - req.started_time)
else:
time_per_output_tokens_iter.append(now - self.last_stats_tic)

if req.finished():
time_e2e_requests.append(now - req.created_time)
time_waiting_requests.append(req.queued_time - req.created_time)
num_prompt_tokens_requests.append(len(req.origin_input_ids))
num_generation_tokens_requests.append(len(req.output_ids))
finished_reason_requests.append(
req.finished_reason.to_json()
if req.finished_reason is not None
else None)

return Stats(
new_seq=new_seq,
num_running_req=num_running_req,
num_waiting_req=num_waiting_req,
cache_hit_rate=cache_hit_rate,
token_usage=token_usage,
num_prompt_tokens_iter=num_prompt_tokens_iter,
num_generation_tokens_iter=num_generation_tokens_iter,
time_to_first_tokens_iter=time_to_first_tokens_iter,
time_per_output_tokens_iter=time_per_output_tokens_iter,
gen_throughput=gen_throughput,
time_e2e_requests=time_e2e_requests,
time_waiting_requests=time_waiting_requests,
num_prompt_tokens_requests=num_prompt_tokens_requests,
num_generation_tokens_requests=num_generation_tokens_requests,
finished_reason_requests=finished_reason_requests,
context_len=self.model_config.context_len,
max_total_num_tokens=self.max_total_num_tokens,
max_prefill_tokens=self.max_prefill_tokens,
max_running_requests=self.max_running_requests,
)

def log_stats(self, stats: Stats):
self.metrics_collector.log_stats(stats)
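
Stats, imported from sglang.srt.metrics.metrics_types, is also outside this excerpt. Judging from the bare Stats() constructed in __init__ and the keyword arguments passed in get_stats() above, it is presumably a plain dataclass along these lines (a sketch; the defaults are assumptions):

from dataclasses import dataclass, field
from typing import List

@dataclass
class Stats:  # sketch reconstructed from the fields used in Scheduler.get_stats(); defaults are guesses
    # system / scheduler state
    new_seq: int = 0
    num_running_req: int = 0
    num_waiting_req: int = 0
    cache_hit_rate: float = 0.0
    token_usage: float = 0.0
    # iteration stats
    num_prompt_tokens_iter: int = 0
    num_generation_tokens_iter: int = 0
    time_to_first_tokens_iter: List[float] = field(default_factory=list)
    time_per_output_tokens_iter: List[float] = field(default_factory=list)
    gen_throughput: float = 0.0
    # per-request stats
    time_e2e_requests: List[float] = field(default_factory=list)
    time_waiting_requests: List[float] = field(default_factory=list)
    num_prompt_tokens_requests: List[int] = field(default_factory=list)
    num_generation_tokens_requests: List[int] = field(default_factory=list)
    finished_reason_requests: List[str] = field(default_factory=list)
    # static configuration
    context_len: int = 0
    max_total_num_tokens: int = 0
    max_prefill_tokens: int = 0
    max_running_requests: int = 0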

def process_batch_result(self, batch: ScheduleBatch, result):
if batch.forward_mode.is_decode():