diff --git a/benchmark/json_schema/bench_sglang.py b/benchmark/json_schema/bench_sglang.py
index 72e2f0b3dc3..5a5fd66c63a 100644
--- a/benchmark/json_schema/bench_sglang.py
+++ b/benchmark/json_schema/bench_sglang.py
@@ -113,7 +113,7 @@ def main(args):
 
     # Compute accuracy
     tokenizer = get_tokenizer(
-        global_config.default_backend.get_server_args()["tokenizer_path"]
+        global_config.default_backend.get_server_info()["tokenizer_path"]
     )
     output_jsons = [state["json_output"] for state in states]
     num_output_tokens = sum(len(tokenizer.encode(x)) for x in output_jsons)
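The benchmark only reads `tokenizer_path`, and `/get_server_info` spreads the old server args into the top level of its response (see `_get_server_info` in `python/sglang/srt/server.py` below), so this one-line rename is the whole migration. A minimal raw-HTTP sketch of the same lookup, assuming a server is already listening on `localhost:30010` as in the notebook below:

```python
import requests

# Assumed local server address; match it to your deployment.
url = "http://localhost:30010/get_server_info"
info = requests.get(url).json()

# Server args such as `tokenizer_path` remain top-level keys in the
# merged payload, next to the new `memory_pool_size` and
# `max_total_num_tokens` entries.
print(info["tokenizer_path"])
```
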
diff --git a/docs/backend/native_api.ipynb b/docs/backend/native_api.ipynb
index 4a27d1f7f15..0b43c6a5ae9 100644
--- a/docs/backend/native_api.ipynb
+++ b/docs/backend/native_api.ipynb
@@ -9,13 +9,11 @@
     "Apart from the OpenAI compatible APIs, the SGLang Runtime also provides its native server APIs. We introduce these following APIs:\n",
     "\n",
     "- `/generate` (text generation model)\n",
-    "- `/get_server_args`\n",
     "- `/get_model_info`\n",
+    "- `/get_server_info`\n",
     "- `/health`\n",
     "- `/health_generate`\n",
     "- `/flush_cache`\n",
-    "- `/get_memory_pool_size`\n",
-    "- `/get_max_total_num_tokens`\n",
     "- `/update_weights`\n",
     "- `/encode`(embedding model)\n",
     "- `/classify`(reward model)\n",
@@ -75,26 +73,6 @@
     "print_highlight(response.json())"
    ]
   },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Get Server Args\n",
-    "Get the arguments of a server."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "url = \"http://localhost:30010/get_server_args\"\n",
-    "\n",
-    "response = requests.get(url)\n",
-    "print_highlight(response.json())"
-   ]
-  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -127,9 +105,12 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Health Check\n",
-    "- `/health`: Check the health of the server.\n",
-    "- `/health_generate`: Check the health of the server by generating one token."
+    "## Get Server Info\n",
+    "Get the server information, including CLI arguments, token limits, and memory pool sizes.\n",
+    "- Note: `get_server_info` merges the following deprecated endpoints:\n",
+    "  - `get_server_args`\n",
+    "  - `get_memory_pool_size`\n",
+    "  - `get_max_total_num_tokens`"
    ]
   },
   {
@@ -138,19 +119,9 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "url = \"http://localhost:30010/health_generate\"\n",
+    "# get_server_info\n",
     "\n",
-    "response = requests.get(url)\n",
-    "print_highlight(response.text)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "url = \"http://localhost:30010/health\"\n",
+    "url = \"http://localhost:30010/get_server_info\"\n",
     "\n",
     "response = requests.get(url)\n",
     "print_highlight(response.text)"
    ]
   },
   {
@@ -160,9 +131,9 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Flush Cache\n",
-    "\n",
-    "Flush the radix cache. It will be automatically triggered when the model weights are updated by the `/update_weights` API."
+    "## Health Check\n",
+    "- `/health`: Check the health of the server.\n",
+    "- `/health_generate`: Check the health of the server by generating one token."
    ]
   },
   {
@@ -171,32 +142,19 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# flush cache\n",
-    "\n",
-    "url = \"http://localhost:30010/flush_cache\"\n",
+    "url = \"http://localhost:30010/health_generate\"\n",
     "\n",
-    "response = requests.post(url)\n",
+    "response = requests.get(url)\n",
     "print_highlight(response.text)"
    ]
   },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Get Memory Pool Size\n",
-    "\n",
-    "Get the memory pool size in number of tokens.\n"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "# get_memory_pool_size\n",
-    "\n",
-    "url = \"http://localhost:30010/get_memory_pool_size\"\n",
+    "url = \"http://localhost:30010/health\"\n",
     "\n",
     "response = requests.get(url)\n",
     "print_highlight(response.text)"
@@ -206,9 +164,9 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Get Maximum Total Number of Tokens\n",
+    "## Flush Cache\n",
     "\n",
-    "Exposes the maximum number of tokens SGLang can handle based on the current configuration."
+    "Flush the radix cache. It will be automatically triggered when the model weights are updated by the `/update_weights` API."
    ]
   },
   {
@@ -217,11 +175,11 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# get_max_total_num_tokens\n",
+    "# flush cache\n",
     "\n",
-    "url = \"http://localhost:30010/get_max_total_num_tokens\"\n",
+    "url = \"http://localhost:30010/flush_cache\"\n",
     "\n",
-    "response = requests.get(url)\n",
+    "response = requests.post(url)\n",
     "print_highlight(response.text)"
    ]
   },
diff --git a/python/sglang/__init__.py b/python/sglang/__init__.py
index 3c4457c983a..40fbf17bc9d 100644
--- a/python/sglang/__init__.py
+++ b/python/sglang/__init__.py
@@ -11,7 +11,7 @@
     gen,
     gen_int,
     gen_string,
-    get_server_args,
+    get_server_info,
     image,
     select,
     set_default_backend,
@@ -41,7 +41,7 @@
     "gen",
     "gen_int",
     "gen_string",
-    "get_server_args",
+    "get_server_info",
     "image",
     "select",
     "set_default_backend",
diff --git a/python/sglang/api.py b/python/sglang/api.py
index 28c6783a33f..9a30ad492da 100644
--- a/python/sglang/api.py
+++ b/python/sglang/api.py
@@ -65,7 +65,7 @@ def flush_cache(backend: Optional[BaseBackend] = None):
     return backend.flush_cache()
 
 
-def get_server_args(backend: Optional[BaseBackend] = None):
+def get_server_info(backend: Optional[BaseBackend] = None):
     backend = backend or global_config.default_backend
     if backend is None:
         return None
@@ -73,7 +73,7 @@ def get_server_args(backend: Optional[BaseBackend] = None):
     # If backend is Runtime
     if hasattr(backend, "endpoint"):
         backend = backend.endpoint
-    return backend.get_server_args()
+    return backend.get_server_info()
 
 
 def gen(
diff --git a/python/sglang/lang/backend/base_backend.py b/python/sglang/lang/backend/base_backend.py
index 185f2e297ae..725c0a91da7 100644
--- a/python/sglang/lang/backend/base_backend.py
+++ b/python/sglang/lang/backend/base_backend.py
@@ -78,5 +78,5 @@ def shutdown(self):
     def flush_cache(self):
         pass
 
-    def get_server_args(self):
+    def get_server_info(self):
         pass
diff --git a/python/sglang/lang/backend/runtime_endpoint.py b/python/sglang/lang/backend/runtime_endpoint.py
index f43ae240aaf..779bf988d20 100644
--- a/python/sglang/lang/backend/runtime_endpoint.py
+++ b/python/sglang/lang/backend/runtime_endpoint.py
@@ -58,9 +58,9 @@ def flush_cache(self):
         )
         self._assert_success(res)
 
-    def get_server_args(self):
+    def get_server_info(self):
         res = http_request(
-            self.base_url + "/get_server_args",
+            self.base_url + "/get_server_info",
             api_key=self.api_key,
             verify=self.verify,
         )
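On the frontend side the rename is mechanical: `sgl.get_server_info()` resolves the default backend, unwraps a `Runtime` to its `endpoint`, and lands in `RuntimeEndpoint.get_server_info()`, which issues the GET shown above. A minimal usage sketch, assuming a server already listening on `localhost:30010` and that `get_server_info` returns the decoded JSON body (the hunk above shows only the request half):

```python
import sglang as sgl
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint

# Assumed server address; point this at any running SGLang server.
sgl.set_default_backend(RuntimeEndpoint("http://localhost:30010"))

# Forwards to RuntimeEndpoint.get_server_info(), i.e. GET /get_server_info.
info = sgl.get_server_info()

# The merged payload carries the old /get_server_args fields plus the two
# values that used to sit behind their own endpoints.
print(info["max_total_num_tokens"], info["memory_pool_size"])
```
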
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index c2aa73e3647..9d67a92a5e1 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -146,10 +146,15 @@ async def get_model_info():
     return result
 
 
-@app.get("/get_server_args")
-async def get_server_args():
-    """Get the server arguments."""
-    return dataclasses.asdict(tokenizer_manager.server_args)
+@app.get("/get_server_info")
+async def get_server_info():
+    try:
+        return await _get_server_info()
+
+    except Exception as e:
+        return ORJSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
 
 
 @app.post("/flush_cache")
@@ -185,30 +190,6 @@ async def stop_profile():
     )
 
 
-@app.get("/get_max_total_num_tokens")
-async def get_max_total_num_tokens():
-    try:
-        return {"max_total_num_tokens": _get_max_total_num_tokens()}
-
-    except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
-
-
-@app.api_route("/get_memory_pool_size", methods=["GET", "POST"])
-async def get_memory_pool_size():
-    """Get the memory pool size in number of tokens"""
-    try:
-        ret = await tokenizer_manager.get_memory_pool_size()
-
-        return ret
-    except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
-
-
 @app.post("/update_weights")
 @time_func_latency
 async def update_weights(obj: UpdateWeightReqInput, request: Request):
@@ -542,8 +523,12 @@ def launch_server(
         t.join()
 
 
-def _get_max_total_num_tokens():
-    return _max_total_num_tokens
+async def _get_server_info():
+    return {
+        **dataclasses.asdict(tokenizer_manager.server_args),  # server args
+        "memory_pool_size": await tokenizer_manager.get_memory_pool_size(),  # memory pool size
+        "max_total_num_tokens": _max_total_num_tokens,  # max total num tokens
+    }
 
 
 def _set_envs_and_config(server_args: ServerArgs):
@@ -787,14 +772,16 @@ def encode(
         response = requests.post(self.url + "/encode", json=json_data)
         return json.dumps(response.json())
 
-    def get_max_total_num_tokens(self):
-        response = requests.get(f"{self.url}/get_max_total_num_tokens")
-        if response.status_code == 200:
-            return response.json()["max_total_num_tokens"]
-        else:
-            raise RuntimeError(
-                f"Failed to get max tokens. {response.json()['error']['message']}"
-            )
+    async def get_server_info(self):
+        async with aiohttp.ClientSession() as session:
+            async with session.get(f"{self.url}/get_server_info") as response:
+                if response.status == 200:
+                    return await response.json()
+                else:
+                    error_data = await response.json()
+                    raise RuntimeError(
+                        f"Failed to get server info. {error_data['error']['message']}"
+                    )
 
     def __del__(self):
         self.shutdown()
@@ -946,5 +933,5 @@ def encode(
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(encode_request(obj, None))
 
-    def get_max_total_num_tokens(self):
-        return _get_max_total_num_tokens()
+    async def get_server_info(self):
+        return await _get_server_info()
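Both wrappers now expose `get_server_info` as a coroutine: the `Runtime` client awaits an `aiohttp` round trip, and the in-process engine awaits `_get_server_info()` directly, since building the payload awaits `tokenizer_manager.get_memory_pool_size()`. Callers migrating from the synchronous `get_max_total_num_tokens()` therefore need an event loop. A minimal sketch, assuming the `Runtime` wrapper from this file and a placeholder model path:

```python
import asyncio

import sglang as sgl


async def main():
    # Placeholder checkpoint; substitute whatever model you serve.
    runtime = sgl.Runtime(model_path="meta-llama/Llama-3.1-8B-Instruct")
    try:
        # get_server_info() is async now, so it must be awaited.
        info = await runtime.get_server_info()
        print(info["max_total_num_tokens"], info["memory_pool_size"])
    finally:
        runtime.shutdown()


asyncio.run(main())
```
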
diff --git a/rust/src/server.rs b/rust/src/server.rs
index 93dd9e0b9f9..fb9fdcfac82 100644
--- a/rust/src/server.rs
+++ b/rust/src/server.rs
@@ -66,14 +66,14 @@ async fn health_generate(data: web::Data<AppState>) -> impl Responder {
     forward_request(&data.client, worker_url, "/health_generate".to_string()).await
 }
 
-#[get("/get_server_args")]
-async fn get_server_args(data: web::Data<AppState>) -> impl Responder {
+#[get("/get_server_info")]
+async fn get_server_info(data: web::Data<AppState>) -> impl Responder {
     let worker_url = match data.router.get_first() {
         Some(url) => url,
         None => return HttpResponse::InternalServerError().finish(),
     };
 
-    forward_request(&data.client, worker_url, "/get_server_args".to_string()).await
+    forward_request(&data.client, worker_url, "/get_server_info".to_string()).await
 }
 
 #[get("/v1/models")]
@@ -153,7 +153,7 @@ pub async fn startup(
             .service(get_model_info)
             .service(health)
             .service(health_generate)
-            .service(get_server_args)
+            .service(get_server_info)
     })
     .bind((host, port))?
     .run()
diff --git a/test/srt/test_data_parallelism.py b/test/srt/test_data_parallelism.py
index 393e44468ca..f34313ea09a 100644
--- a/test/srt/test_data_parallelism.py
+++ b/test/srt/test_data_parallelism.py
@@ -63,12 +63,13 @@ def test_update_weight(self):
         assert response.status_code == 200
 
     def test_get_memory_pool_size(self):
-        response = requests.get(self.base_url + "/get_memory_pool_size")
+        # use `get_server_info` instead since `get_memory_pool_size` is merged into `get_server_info`
+        response = requests.get(self.base_url + "/get_server_info")
         assert response.status_code == 200
 
         time.sleep(5)
 
-        response = requests.get(self.base_url + "/get_memory_pool_size")
+        response = requests.get(self.base_url + "/get_server_info")
         assert response.status_code == 200
 
diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py
index 4ca17adb616..98c124fee1a 100644
--- a/test/srt/test_srt_endpoint.py
+++ b/test/srt/test_srt_endpoint.py
@@ -154,9 +154,18 @@ def test_logprob_with_chunked_prefill(self):
         self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
         self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)
 
-    def test_get_memory_pool_size(self):
-        response = requests.post(self.base_url + "/get_memory_pool_size")
-        self.assertIsInstance(response.json(), int)
+    def test_get_server_info(self):
+        response = requests.get(self.base_url + "/get_server_info")
+        response_json = response.json()
+
+        max_total_num_tokens = response_json["max_total_num_tokens"]
+        self.assertIsInstance(max_total_num_tokens, int)
+
+        memory_pool_size = response_json["memory_pool_size"]
+        self.assertIsInstance(memory_pool_size, int)
+
+        attention_backend = response_json["attention_backend"]
+        self.assertIsInstance(attention_backend, str)
 
 
 if __name__ == "__main__":
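For downstream scripts that still probe the removed endpoints, the migration mirrors these tests: one GET now answers everything the three old endpoints did. A quick sketch, with the base URL as a placeholder:

```python
import requests

# Placeholder base URL; match it to your test server.
base_url = "http://127.0.0.1:30000"

# /get_server_args, /get_memory_pool_size, and /get_max_total_num_tokens
# are gone; their data now arrives in a single payload.
info = requests.get(base_url + "/get_server_info").json()

assert isinstance(info["max_total_num_tokens"], int)
assert isinstance(info["memory_pool_size"], int)
assert isinstance(info["attention_backend"], str)  # a ServerArgs field
```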