From 628e00b9169e63d83e7c89679135524a36f7803e Mon Sep 17 00:00:00 2001 From: zhanglei335 Date: Thu, 31 Oct 2024 14:06:34 +0800 Subject: [PATCH 1/9] support prometheus metrics --- python/sglang/srt/managers/schedule_batch.py | 14 + python/sglang/srt/managers/scheduler.py | 108 ++++++- .../sglang/srt/metrics/metrics_collector.py | 297 ++++++++++++++++++ python/sglang/srt/metrics/metrics_types.py | 57 ++++ python/sglang/srt/server.py | 37 +++ python/sglang/srt/server_args.py | 7 + 6 files changed, 519 insertions(+), 1 deletion(-) create mode 100644 python/sglang/srt/metrics/metrics_collector.py create mode 100644 python/sglang/srt/metrics/metrics_types.py diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 85ca560a926..5d34a893586 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -31,6 +31,7 @@ import dataclasses import logging +import time from typing import List, Optional, Tuple, Union import torch @@ -251,6 +252,16 @@ def __init__( # For Qwen2-VL self.mrope_position_delta = [] # use mutable object + # Lifetime traces + # time when request is created and added to waitlist + self.created_time = None + # time when request is added to prefill batech + self.queued_time = None + # time when request is being processed + self.started_time = None + # time when request is finished + self.finished_time = None + # whether request reached finished condition def finished(self) -> bool: return self.finished_reason is not None @@ -1023,6 +1034,9 @@ def __str__(self): f"#req={(len(self.reqs))})" ) + def mark_reqs_started(self): + for req in self.reqs: + req.started_time = time.time() @dataclasses.dataclass class ModelWorkerBatch: diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 47f0b7d4413..e9f15163bbd 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -76,6 +76,8 @@ suppress_other_loggers, ) from sglang.utils import get_exception_traceback +from sglang.srt.metrics.metrics_types import Stats +from sglang.srt.metrics.metrics_collector import PrometheusMetricsCollector logger = logging.getLogger(__name__) @@ -295,6 +297,15 @@ def __init__( ], with_stack=True, ) + # Init metrics stats + self.stats = Stats() + self.metrics_collector = PrometheusMetricsCollector( + labels={ + "name": self.model_config.path, + # TODO: Add lora name/path in the future, + }, + max_model_len=self.max_total_num_tokens, + ) def watchdog_thread(self): self.watchdog_last_forward_ct = 0 @@ -342,6 +353,11 @@ def event_loop_normal(self): else: self.check_memory() self.new_token_ratio = self.init_new_token_ratio + # log stats + if self.is_generation and not self.server_args.disable_log_stats: + stats = self.get_stats(batch) + self.log_stats(stats) + self.last_stats_tic = time.time() self.last_batch = batch @@ -482,6 +498,7 @@ def handle_generate_request( self.max_req_len - len(req.origin_input_ids) - 1, ) + req.created_time = time.time() self.waiting_queue.append(req) def handle_embedding_request( @@ -512,7 +529,9 @@ def print_decode_stats(self): ) throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic) self.num_generated_tokens = 0 - self.last_stats_tic = time.time() + # self.last_stats_tic = time.time() + # set system stats + self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2) num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0 logger.info( f"Decode 
batch. " @@ -651,6 +670,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: if res == AddReqResult.NO_TOKEN: self.batch_is_full = True break + req.queued_time = time.time() # Update waiting queue can_run_list = adder.can_run_list @@ -684,6 +704,9 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() ) + # set system stats + self.stats.cache_hit_rate = round(100.0 * tree_cache_hit_rate, 2) + self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2) if num_mixed_running > 0: logger.info( @@ -776,6 +799,7 @@ def run_batch(self, batch: ScheduleBatch): if self.is_generation: if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0: model_worker_batch = batch.get_model_worker_batch() + batch.mark_reqs_started() logits_output, next_token_ids = self.tp_worker.forward_batch_generation( model_worker_batch ) @@ -795,6 +819,88 @@ def run_batch(self, batch: ScheduleBatch): embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch) ret = embeddings, model_worker_batch.bid return ret + def get_stats(self,batch: ScheduleBatch): + # TODO: get stats for chunked prefill + + now = time.time() + # system stats + # Scheduler State + new_seq: int = 0 + num_running_req = len(self.running_batch.reqs) if self.running_batch else 0 + num_waiting_req = len(self.waiting_queue) + # Cache State + cache_hit_rate: float = 0.0 + token_usage: float = 0.0 + + # set stats from prefill + if self.stats is not None: + # new_seq=self.stats.new_seq + cache_hit_rate=self.stats.cache_hit_rate + token_usage=self.stats.token_usage + # Iteration stats + num_prompt_tokens_iter = 0 + num_generation_tokens_iter = 0 + time_to_first_tokens_iter: List[float] = [] + time_per_output_tokens_iter: List[float] = [] + + # Request stats + # Decode + gen_throughput: float = 0.0 + # Latency + time_e2e_requests: List[float] = [] + time_waiting_requests: List[float] = [] + # Metadata + num_prompt_tokens_requests: List[int] = [] + num_generation_tokens_requests: List[int] = [] + finished_reason_requests: List[str] = [] + + # _, next_token_ids, _ = result + if batch is not None: + num_generation_tokens_iter = len(batch.output_ids) + gen_throughput = round(num_generation_tokens_iter / (now - self.last_stats_tic), 2) + + for i, req in enumerate(batch.reqs): + # NOTE: Batch forward mode is extend befor start decode, + if batch.forward_mode.is_extend(): + num_prompt_tokens_iter=len(batch.input_ids)+sum(batch.prefix_lens) + time_to_first_tokens_iter.append(now - req.started_time) + else: + time_per_output_tokens_iter.append(now-self.last_stats_tic) + + if req.finished(): + time_e2e_requests.append(now - req.created_time) + time_waiting_requests.append(req.queued_time - req.created_time) + num_prompt_tokens_requests.append(len(req.origin_input_ids)) + num_generation_tokens_requests.append(len(req.output_ids)) + finished_reason_requests.append( + req.finished_reason.to_json() + if req.finished_reason is not None + else None) + + return Stats( + new_seq=new_seq, + num_running_req=num_running_req, + num_waiting_req=num_waiting_req, + cache_hit_rate=cache_hit_rate, + token_usage=token_usage, + num_prompt_tokens_iter=num_prompt_tokens_iter, + num_generation_tokens_iter=num_generation_tokens_iter, + time_to_first_tokens_iter=time_to_first_tokens_iter, + time_per_output_tokens_iter=time_per_output_tokens_iter, + gen_throughput=gen_throughput, + time_e2e_requests=time_e2e_requests, + time_waiting_requests=time_waiting_requests, + 
num_prompt_tokens_requests=num_prompt_tokens_requests, + num_generation_tokens_requests=num_generation_tokens_requests, + finished_reason_requests=finished_reason_requests, + context_len=self.model_config.context_len, + max_total_num_tokens=self.max_total_num_tokens, + max_prefill_tokens=self.max_prefill_tokens, + max_running_requests=self.max_running_requests, + ) + + def log_stats(self,stats:Stats): + self.metrics_collector.log_stats(stats) def process_batch_result(self, batch: ScheduleBatch, result): if batch.forward_mode.is_decode(): diff --git a/python/sglang/srt/metrics/metrics_collector.py b/python/sglang/srt/metrics/metrics_collector.py new file mode 100644 index 00000000000..bb1eca88b14 --- /dev/null +++ b/python/sglang/srt/metrics/metrics_collector.py @@ -0,0 +1,297 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Utilities for Prometheus Metrics Collection.""" + +import logging +from abc import ABC, abstractmethod +from typing import Counter as CollectionsCounter +from typing import Dict, List, Union + +import numpy as np +from prometheus_client import Counter, Gauge, Histogram + +from sglang.srt.metrics.metrics_types import Stats + + +class Metrics: + """ + SGLang Metrics + """ + + def __init__(self, labelnames: List[str], max_model_len): + + # Configuration Stats + self.max_total_num_tokens = Gauge( + name="sglang:max_total_num_tokens", + documentation="Maximum total number of tokens", + labelnames=labelnames, + multiprocess_mode="min", + ) # static across processes + + self.max_prefill_tokens = Gauge( + name="sglang:max_prefill_tokens", + documentation="Maximum prefill tokens", + labelnames=labelnames, + multiprocess_mode="min", + ) # static across processes + + self.max_running_requests = Gauge( + name="sglang:max_running_requests", + documentation="Maximum running requests", + labelnames=labelnames, + multiprocess_mode="min", + ) # static across processes + + self.context_len = Gauge( + name="sglang:context_len", + documentation="Context length", + labelnames=labelnames, + multiprocess_mode="min", + ) # static across processes + # Decode Stats + self.num_running_sys = Gauge( + name="sglang:num_requests_running", + documentation="Number of requests currently running on GPU", + labelnames=labelnames, + multiprocess_mode="sum", + ) + self.num_waiting_sys = Gauge( + name="sglang:num_requests_waiting", + documentation="Number of requests waiting to be processed.", + labelnames=labelnames, + multiprocess_mode="sum", + ) + self.gen_throughput = Gauge( + name="sglang:gen_throughput", + documentation="Gen token throughput (token/s)", + labelnames=labelnames, + multiprocess_mode="sum", + ) + self.token_usage = Gauge( + name="sglang:token_usage", + documentation="Total token usage", + labelnames=labelnames, + multiprocess_mode="sum", + ) + # System Stats + # KV Cache Usage in % + # self.gpu_cache_usage_sys = Gauge( + # "gpu_cache_usage_perc", + # "GPU KV-cache usage. 
1 means 100 percent usage.", + # labelnames=labelnames, + # multiprocess_mode="sum") + + self.new_seq = Gauge( + name="sglang:new_seq", + documentation="Number of new sequences", + labelnames=labelnames, + multiprocess_mode="sum", + ) + self.new_token = Gauge( + name="sglang:new_token", + documentation="Number of new token", + labelnames=labelnames, + multiprocess_mode="sum", + ) + # Prefix caching block hit rate + self.cached_token = Gauge( + name="sglang:cached_token", + documentation="Number of cached token", + labelnames=labelnames, + multiprocess_mode="sum", + ) + self.cache_hit_rate = Gauge( + name="sglang:cache_hit_rate", + documentation="Cache hit rate", + labelnames=labelnames, + multiprocess_mode="sum", + ) + self.queue_req = Gauge( + name="sglang:queue_req", + documentation="Number of queued requests", + labelnames=labelnames, + multiprocess_mode="sum", + ) + + # Iteration stats + self.counter_prompt_tokens = Counter( + name="sglang:prompt_tokens_total", + documentation="Number of prefill tokens processed.", + labelnames=labelnames) + self.counter_generation_tokens = Counter( + name="sglang:generation_tokens_total", + documentation="Number of generation tokens processed.", + labelnames=labelnames) + self.histogram_time_to_first_token = Histogram( + name="sglang:time_to_first_token_seconds", + documentation="Histogram of time to first token in seconds.", + labelnames=labelnames, + buckets=[ + 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, + 0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 15.0, 20.0, 25.0, 30.0 + ]) + self.histogram_time_per_output_token = Histogram( + name="sglang:time_per_output_token_seconds", + documentation="Histogram of time per output token in seconds.", + labelnames=labelnames, + buckets=[ + 0.005, 0.01, 0.015, 0.02, 0.025, 0.03, 0.04, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, + 1.0, 2.5 + ]) + + # Request Stats + # Metadata + self.num_prompt_tokens_requests = Histogram( + name="sglang:request_prompt_tokens", + documentation="Number of prefill tokens processed", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len), + ) + self.num_generation_tokens_requests = Histogram( + name="sglang:request_generation_tokens", + documentation="Number of generation tokens processed.", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len), + ) + self.finished_reason_requests = Counter( + name="sglang:request_success_total", + documentation="Count of successfully processed requests.", + labelnames=labelnames + ["finished_reason"], + ) + self.histogram_time_e2e_requests = Histogram( + name="sglang:e2e_request_latency_seconds", + documentation="Histogram of End-to-end request latency in seconds", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len), + ) + self.histogram_time_waiting_requests = Histogram( + name="sglang:waiting_request_latency_seconds", + documentation="Histogram of request waiting time in seconds", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len), + ) + self.histogram_time_decode_requests = Histogram( + name="sglang:decode_request_latency_seconds", + documentation="Histogram of request decoding time in seconds", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len), + ) + + +class MetricsCollector(ABC): + """ + SGLang Metrics Collector + """ + + @abstractmethod + def log_stats(self, stats: Stats) -> None: + pass + + +class PrometheusMetricsCollector(MetricsCollector): + """ + SGLang Metrics Collector + """ + + def __init__(self, labels: Dict[str, str], max_model_len: 
int) -> None: + self.labels = labels + self.metrics = Metrics( + labelnames=list(labels.keys()), max_model_len=max_model_len + ) + + def _log_gauge(self, gauge, data: Union[int, float]) -> None: + # Convenience function for logging to gauge. + gauge.labels(**self.labels).set(data) + + def _log_counter(self, counter, data: Union[int, float]) -> None: + # Convenience function for logging to counter. + counter.labels(**self.labels).inc(data) + + def _log_counter_labels( + self, counter, data: CollectionsCounter, label_key: str + ) -> None: + # Convenience function for collection counter of labels. + for label, count in data.items(): + counter.labels(**{**self.labels, label_key: label}).inc(count) + + def _log_histogram(self, histogram, data: Union[List[int], List[float]]) -> None: + # Convenience function for logging list to histogram. + for datum in data: + histogram.labels(**self.labels).observe(datum) + + def log_stats(self, stats: Stats) -> None: + self._log_gauge(self.metrics.max_total_num_tokens, stats.max_total_num_tokens) + self._log_gauge(self.metrics.max_prefill_tokens, stats.max_prefill_tokens) + self._log_gauge(self.metrics.max_running_requests, stats.max_running_requests) + self._log_gauge(self.metrics.context_len, stats.context_len) + self._log_histogram( + self.metrics.num_prompt_tokens_requests, stats.num_prompt_tokens_requests + ) + self._log_histogram( + self.metrics.num_generation_tokens_requests, + stats.num_generation_tokens_requests, + ) + + self._log_counter(self.metrics.counter_prompt_tokens, + stats.num_prompt_tokens_iter) + self._log_counter(self.metrics.counter_generation_tokens, + stats.num_generation_tokens_iter) + self._log_histogram(self.metrics.histogram_time_to_first_token, + stats.time_to_first_tokens_iter) + self._log_histogram(self.metrics.histogram_time_per_output_token, + stats.time_per_output_tokens_iter) + + # self._log_gauge(self.metrics.gpu_cache_usage_sys, stats.gpu_cache_usage_sys) + self._log_gauge(self.metrics.num_running_sys, stats.num_running_req) + self._log_gauge(self.metrics.num_waiting_sys, stats.num_waiting_req) + self._log_gauge(self.metrics.gen_throughput, stats.gen_throughput) + self._log_gauge(self.metrics.token_usage, stats.token_usage) + self._log_histogram( + self.metrics.histogram_time_e2e_requests, stats.time_e2e_requests + ) + self._log_histogram( + self.metrics.histogram_time_waiting_requests, stats.time_waiting_requests + ) + self._log_histogram( + self.metrics.histogram_time_decode_requests, stats.time_decode_requests + ) + self._log_gauge(self.metrics.new_seq, stats.new_seq) + self._log_gauge(self.metrics.new_token, stats.new_token) + self._log_gauge(self.metrics.cached_token, stats.cached_token) + self._log_gauge(self.metrics.cache_hit_rate, stats.cache_hit_rate) + self._log_gauge(self.metrics.queue_req, stats.queue_req) + + +def build_1_2_5_buckets(max_value: int) -> List[int]: + """ + Builds a list of buckets with increasing powers of 10 multiplied by + mantissa values (1, 2, 5) until the value exceeds the specified maximum. 
+ + Example: + >>> build_1_2_5_buckets(100) + [1, 2, 5, 10, 20, 50, 100] + """ + mantissa_lst = [1, 2, 5] + exponent = 0 + buckets: List[int] = [] + while True: + for m in mantissa_lst: + value = m * 10**exponent + if value <= max_value: + buckets.append(value) + else: + return buckets + exponent += 1 \ No newline at end of file diff --git a/python/sglang/srt/metrics/metrics_types.py b/python/sglang/srt/metrics/metrics_types.py new file mode 100644 index 00000000000..f1b357f403d --- /dev/null +++ b/python/sglang/srt/metrics/metrics_types.py @@ -0,0 +1,57 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Metrics Types""" + +from dataclasses import dataclass, field +from typing import List + + +@dataclass +class Stats: + # config + max_total_num_tokens: int = 0 + max_prefill_tokens: int = 0 + max_running_requests: int = 0 + context_len: int = 0 + # request stats + num_prompt_tokens_requests: List[int] = field(default_factory=list) + num_generation_tokens_requests: List[int] = field(default_factory=list) + finished_reason_requests: List[str] = field(default_factory=list) + # decode stats + num_running_req: int = 0 + num_waiting_req: int = 0 + gen_throughput: float = 0.0 + num_token: int = 0 + token_usage: float = 0.0 + waiting_queue: int = 0 + time_e2e_requests: List[float] = field(default_factory=list) + time_waiting_requests: List[float] = field(default_factory=list) + time_decode_requests: List[float] = field(default_factory=list) + # system stats + token_usage: float = 0.0 + is_mixed_chunk: bool = False + new_seq: int = 0 + new_token: int = 0 + cached_token: int = 0 + cache_hit_rate: float = 0.0 + running_req: int = 0 + queue_req: int = 0 + + # Iteration stats (should have _iter suffix) + num_prompt_tokens_iter: int = 0 + num_generation_tokens_iter: int = 0 + time_to_first_tokens_iter: List[float] = field(default_factory=list) + time_per_output_tokens_iter: List[float] = field(default_factory=list) \ No newline at end of file diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 81d86f5dcb5..958c59178de 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -27,8 +27,11 @@ import os import threading import time +import re +import tempfile from http import HTTPStatus from typing import AsyncIterator, Dict, List, Optional, Union +from starlette.routing import Mount import orjson @@ -87,6 +90,10 @@ logger = logging.getLogger(__name__) +# Temporary directory for prometheus multiprocess mode +# Cleaned up automatically when this object is garbage collected +prometheus_multiproc_dir: tempfile.TemporaryDirectory + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -413,6 +420,18 @@ def launch_engine( for i in range(len(scheduler_pipe_readers)): scheduler_pipe_readers[i].recv() +def add_prometheus_middleware(app: FastAPI): + # Adopted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216 + from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess + + registry 
= CollectorRegistry() + multiprocess.MultiProcessCollector(registry) + metrics_route = Mount("/metrics", make_asgi_app(registry=registry)) + + # Workaround for 307 Redirect for /metrics + metrics_route.path_regex = re.compile("^/metrics(?P.*)$") + app.routes.append(metrics_route) + def launch_server( server_args: ServerArgs, @@ -440,6 +459,9 @@ def launch_server( if server_args.api_key: add_api_key_middleware(app, server_args.api_key) + # add prometheus middleware + add_prometheus_middleware(app) + # Send a warmup request t = threading.Thread( target=_wait_and_warmup, args=(server_args, pipe_finish_writer) @@ -476,6 +498,21 @@ def _set_envs_and_config(server_args: ServerArgs): os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" + # Set prometheus multiprocess directory + # sglang uses prometheus multiprocess mode + # we need to set this before importing prometheus_client + # https://prometheus.github.io/client_python/multiprocess/ + global prometheus_multiproc_dir + if "PROMETHEUS_MULTIPROC_DIR" in os.environ: + logger.debug(f"User set PROMETHEUS_MULTIPROC_DIR detected.") + prometheus_multiproc_dir = tempfile.TemporaryDirectory( + dir=os.environ["PROMETHEUS_MULTIPROC_DIR"] + ) + else: + prometheus_multiproc_dir = tempfile.TemporaryDirectory() + os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name + logger.debug(f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}") + # Set ulimit set_ulimit() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 501c2e326df..97452ad5105 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -70,6 +70,7 @@ class ServerArgs: log_level_http: Optional[str] = None log_requests: bool = False show_time_cost: bool = False + disable_log_stats: bool = False # Other api_key: Optional[str] = None @@ -414,6 +415,12 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Show time cost of custom marks.", ) + parser.add_argument( + "--disable-log-stats", + action="store_true", + help="Disable log stats for prometheus metrics.", + ) + parser.add_argument( "--api-key", type=str, From 81a071e02d1457d96e225e97fe1fa32b9e856988 Mon Sep 17 00:00:00 2001 From: zhanglei335 Date: Thu, 31 Oct 2024 14:24:13 +0800 Subject: [PATCH 2/9] fix Format --- python/sglang/srt/managers/scheduler.py | 4 ++-- python/sglang/srt/server.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e9f15163bbd..2c664e9bb2a 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -64,6 +64,8 @@ from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.radix_cache import RadixCache +from sglang.srt.metrics.metrics_collector import PrometheusMetricsCollector +from sglang.srt.metrics.metrics_types import Stats from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( broadcast_pyobj, @@ -76,8 +78,6 @@ suppress_other_loggers, ) from sglang.utils import get_exception_traceback -from sglang.srt.metrics.metrics_types import Stats -from sglang.srt.metrics.metrics_collector import PrometheusMetricsCollector logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 958c59178de..6589d876116 100644 --- 
a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -25,15 +25,15 @@ import logging import multiprocessing as mp import os -import threading -import time import re import tempfile +import threading +import time from http import HTTPStatus from typing import AsyncIterator, Dict, List, Optional, Union -from starlette.routing import Mount import orjson +from starlette.routing import Mount # Fix a bug of Python threading setattr(threading, "_register_atexit", lambda *args, **kwargs: None) From 6b7caa25cd823ae260c0f4aef684da4498219e5b Mon Sep 17 00:00:00 2001 From: zhanglei335 Date: Sat, 2 Nov 2024 23:18:15 +0800 Subject: [PATCH 3/9] fix req queue_time is None --- python/sglang/srt/managers/schedule_policy.py | 3 +++ python/sglang/srt/managers/scheduler.py | 1 - 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index 6ea6ff194da..14ee50155ed 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -17,6 +17,7 @@ import os import random +import time from collections import defaultdict from contextlib import contextmanager from enum import Enum, auto @@ -300,6 +301,7 @@ def add_one_req(self, req: Req): ): # Non-chunked prefill self.can_run_list.append(req) + req.queued_time = time.time() self.tree_cache.inc_lock_ref(req.last_node) self._prefill_one_req( prefix_len, @@ -318,6 +320,7 @@ def add_one_req(self, req: Req): req.extend_input_len = trunc_len req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len] self.can_run_list.append(req) + req.queued_time = time.time() self.new_inflight_req = req self.tree_cache.inc_lock_ref(req.last_node) self._prefill_one_req(prefix_len, trunc_len, 0) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index f2e0cb8e13c..40dab98e01c 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -669,7 +669,6 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: if res == AddReqResult.NO_TOKEN: self.batch_is_full = True break - req.queued_time = time.time() # Update waiting queue can_run_list = adder.can_run_list From ab40e387a6b62089c57cd19ff277a1a1c148b0cb Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Sun, 3 Nov 2024 17:02:40 -0800 Subject: [PATCH 4/9] Update python/sglang/srt/server.py --- python/sglang/srt/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 4aff5098efe..fb727888cee 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -420,7 +420,7 @@ def launch_engine( scheduler_pipe_readers[i].recv() def add_prometheus_middleware(app: FastAPI): - # Adopted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216 + # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216 from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess registry = CollectorRegistry() From 3672b1a13ac8170e01cbb7115338821829cc80bf Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Sun, 3 Nov 2024 17:02:45 -0800 Subject: [PATCH 5/9] Update python/sglang/srt/managers/schedule_batch.py --- python/sglang/srt/managers/schedule_batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 
3205a065d30..742b91398fd 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -258,7 +258,7 @@ def __init__( # Lifetime traces # time when request is created and added to waitlist self.created_time = None - # time when request is added to prefill batech + # time when request is added to prefill batch self.queued_time = None # time when request is being processed self.started_time = None From df1283f7d4f158f5c5db886d10f88fcc60c4f911 Mon Sep 17 00:00:00 2001 From: zhanglei335 Date: Mon, 4 Nov 2024 15:05:53 +0800 Subject: [PATCH 6/9] default no log stats and add last_log_tic for print gen throughput --- python/sglang/srt/managers/scheduler.py | 9 +++++---- python/sglang/srt/server_args.py | 6 +++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e25a09f5070..6ae95dd8f57 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -228,7 +228,8 @@ def __init__( self.forward_ct = 0 self.forward_ct_decode = 0 self.num_generated_tokens = 0 - self.last_stats_tic = time.time() + self.last_stats_tic = time.time() # time of last stats for every iter + self.last_log_tic = time.time() # time of last log for print decode log self.stream_interval = server_args.stream_interval # Init chunked prefill @@ -354,7 +355,7 @@ def event_loop_normal(self): self.check_memory() self.new_token_ratio = self.init_new_token_ratio # log stats - if self.is_generation and not self.server_args.disable_log_stats: + if self.is_generation and self.server_args.enable_metrics: stats = self.get_stats(batch) self.log_stats(stats) self.last_stats_tic = time.time() @@ -525,9 +526,9 @@ def print_decode_stats(self): num_used = self.max_total_num_tokens - ( self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() ) - throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic) + throughput = self.num_generated_tokens / (time.time() - self.last_log_tic) self.num_generated_tokens = 0 - # self.last_stats_tic = time.time() + self.last_log_tic = time.time() # set system stats self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2) num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0 diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 97452ad5105..84d1afbd5f6 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -70,7 +70,7 @@ class ServerArgs: log_level_http: Optional[str] = None log_requests: bool = False show_time_cost: bool = False - disable_log_stats: bool = False + enable_metrics: bool = False # Other api_key: Optional[str] = None @@ -416,9 +416,9 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Show time cost of custom marks.", ) parser.add_argument( - "--disable-log-stats", + "--enable-metrics", action="store_true", - help="Disable log stats for prometheus metrics.", + help="Enable log prometheus metrics.", ) parser.add_argument( From 506773ce5f39b56d9a9f054f0348841507d540c4 Mon Sep 17 00:00:00 2001 From: zhanglei335 Date: Mon, 4 Nov 2024 15:38:22 +0800 Subject: [PATCH 7/9] change prometheus labels model_name --- python/sglang/srt/managers/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 542da344f61..6d3d419d1a6 100644 --- a/python/sglang/srt/managers/scheduler.py +++ 
b/python/sglang/srt/managers/scheduler.py @@ -298,7 +298,7 @@ def __init__( self.stats = Stats() self.metrics_collector = PrometheusMetricsCollector( labels={ - "name": self.model_config.path, + "model_name": self.server_args.served_model_name, # TODO: Add lora name/path in the future, }, max_model_len=self.max_total_num_tokens, From 902066e44f1bbdc74e4e097d4ee03088c9a63b38 Mon Sep 17 00:00:00 2001 From: zhanglei335 Date: Tue, 5 Nov 2024 14:16:40 +0800 Subject: [PATCH 8/9] move prometheus env init to launch_server --- python/sglang/srt/server.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index c3ea21a176c..b67942b3218 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -459,6 +459,7 @@ def launch_server( add_api_key_middleware(app, server_args.api_key) # add prometheus middleware + _set_prometheus_env() add_prometheus_middleware(app) # Send a warmup request @@ -488,15 +489,7 @@ def launch_server( finally: t.join() - -def _set_envs_and_config(server_args: ServerArgs): - # Set global environments - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - os.environ["NCCL_CUMEM_ENABLE"] = "0" - os.environ["NCCL_NVLS_ENABLE"] = "0" - os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" - +def _set_prometheus_env(): # Set prometheus multiprocess directory # sglang uses prometheus multiprocess mode # we need to set this before importing prometheus_client @@ -512,6 +505,14 @@ def _set_envs_and_config(server_args: ServerArgs): os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name logger.debug(f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}") +def _set_envs_and_config(server_args: ServerArgs): + # Set global environments + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + os.environ["NCCL_CUMEM_ENABLE"] = "0" + os.environ["NCCL_NVLS_ENABLE"] = "0" + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" + # Set ulimit set_ulimit() From d9e557fb2a9b6af9b0cdb9d5e5f8d371f4637f26 Mon Sep 17 00:00:00 2001 From: zhanglei335 Date: Wed, 6 Nov 2024 10:13:36 +0800 Subject: [PATCH 9/9] enable prometheus according to the server_args --- python/sglang/srt/server.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index b67942b3218..37882703fd9 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -459,8 +459,9 @@ def launch_server( add_api_key_middleware(app, server_args.api_key) # add prometheus middleware - _set_prometheus_env() - add_prometheus_middleware(app) + if server_args.enable_metrics: + _set_prometheus_env() + add_prometheus_middleware(app) # Send a warmup request t = threading.Thread(
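
As a quick way to exercise the new endpoint once this series is applied: launch the server with --enable-metrics and scrape the /metrics mount added in server.py. The sketch below is not part of the patch; it assumes the server is reachable at http://localhost:30000 (adjust to your --host/--port).

    # Minimal scrape of the /metrics endpoint added by this series.
    # Assumptions (not from the patch): server started with --enable-metrics
    # and listening on http://localhost:30000.
    from urllib.request import urlopen

    def dump_sglang_metrics(base_url: str = "http://localhost:30000") -> None:
        # The ASGI mount serves the aggregated multiprocess registry as plain text.
        body = urlopen(f"{base_url}/metrics").read().decode("utf-8")
        for line in body.splitlines():
            # All metrics introduced in this series use the "sglang:" prefix,
            # e.g. sglang:num_requests_running{model_name="..."}.
            if line.startswith("sglang:"):
                print(line)

    if __name__ == "__main__":
        dump_sglang_metrics()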
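
For reference, a worked example of the build_1_2_5_buckets helper that backs the request token-count histograms in metrics_collector.py. It assumes sglang is installed with this series applied; the 32768 context length is only an illustrative value.

    # Worked example of build_1_2_5_buckets from metrics_collector.py.
    from sglang.srt.metrics.metrics_collector import build_1_2_5_buckets

    # Buckets grow as 1, 2, 5 times increasing powers of ten and stop once the
    # next value would exceed max_value, which keeps histogram cardinality low
    # while still covering the full range of prompt/generation lengths.
    print(build_1_2_5_buckets(32768))
    # [1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000, 20000]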