Merged three native APIs into one: get_server_info (#2152)
henryhmko authored Nov 24, 2024
1 parent 84a1698 commit dbe1729
Showing 10 changed files with 74 additions and 119 deletions.
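In short, the three GET endpoints /get_server_args, /get_memory_pool_size, and /get_max_total_num_tokens are folded into a single /get_server_info endpoint whose JSON response contains the server arguments together with memory_pool_size and max_total_num_tokens. A minimal client-side sketch of the migration (not part of the diff; it assumes a server is already running on localhost:30010, the example port used in the docs notebook below):

import requests

base_url = "http://localhost:30010"  # example port from the docs notebook

# Before this commit: three separate endpoints.
# server_args = requests.get(f"{base_url}/get_server_args").json()
# memory_pool_size = requests.get(f"{base_url}/get_memory_pool_size").json()
# max_tokens = requests.get(f"{base_url}/get_max_total_num_tokens").json()["max_total_num_tokens"]

# After this commit: one endpoint returns everything in a single dict.
info = requests.get(f"{base_url}/get_server_info").json()
print(info["tokenizer_path"])        # server args are flattened into the top level
print(info["memory_pool_size"])      # previously served by /get_memory_pool_size
print(info["max_total_num_tokens"])  # previously served by /get_max_total_num_tokens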
2 changes: 1 addition & 1 deletion benchmark/json_schema/bench_sglang.py
@@ -113,7 +113,7 @@ def main(args):

# Compute accuracy
tokenizer = get_tokenizer(
global_config.default_backend.get_server_args()["tokenizer_path"]
global_config.default_backend.get_server_info()["tokenizer_path"]
)
output_jsons = [state["json_output"] for state in states]
num_output_tokens = sum(len(tokenizer.encode(x)) for x in output_jsons)
82 changes: 20 additions & 62 deletions docs/backend/native_api.ipynb
@@ -9,13 +9,11 @@
"Apart from the OpenAI compatible APIs, the SGLang Runtime also provides its native server APIs. We introduce these following APIs:\n",
"\n",
"- `/generate` (text generation model)\n",
"- `/get_server_args`\n",
"- `/get_model_info`\n",
"- `/get_server_info`\n",
"- `/health`\n",
"- `/health_generate`\n",
"- `/flush_cache`\n",
"- `/get_memory_pool_size`\n",
"- `/get_max_total_num_tokens`\n",
"- `/update_weights`\n",
"- `/encode`(embedding model)\n",
"- `/classify`(reward model)\n",
@@ -75,26 +73,6 @@
"print_highlight(response.json())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get Server Args\n",
"Get the arguments of a server."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"url = \"http://localhost:30010/get_server_args\"\n",
"\n",
"response = requests.get(url)\n",
"print_highlight(response.json())"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -127,9 +105,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Health Check\n",
"- `/health`: Check the health of the server.\n",
"- `/health_generate`: Check the health of the server by generating one token."
"## Get Server Info\n",
"Gets the server information including CLI arguments, token limits, and memory pool sizes.\n",
"- Note: `get_server_info` merges the following deprecated endpoints:\n",
" - `get_server_args`\n",
" - `get_memory_pool_size` \n",
" - `get_max_total_num_tokens`"
]
},
{
@@ -138,19 +119,9 @@
"metadata": {},
"outputs": [],
"source": [
"url = \"http://localhost:30010/health_generate\"\n",
"# get_server_info\n",
"\n",
"response = requests.get(url)\n",
"print_highlight(response.text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"url = \"http://localhost:30010/health\"\n",
"url = \"http://localhost:30010/get_server_info\"\n",
"\n",
"response = requests.get(url)\n",
"print_highlight(response.text)"
@@ -160,9 +131,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Flush Cache\n",
"\n",
"Flush the radix cache. It will be automatically triggered when the model weights are updated by the `/update_weights` API."
"## Health Check\n",
"- `/health`: Check the health of the server.\n",
"- `/health_generate`: Check the health of the server by generating one token."
]
},
{
@@ -171,32 +142,19 @@
"metadata": {},
"outputs": [],
"source": [
"# flush cache\n",
"\n",
"url = \"http://localhost:30010/flush_cache\"\n",
"url = \"http://localhost:30010/health_generate\"\n",
"\n",
"response = requests.post(url)\n",
"response = requests.get(url)\n",
"print_highlight(response.text)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get Memory Pool Size\n",
"\n",
"Get the memory pool size in number of tokens.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# get_memory_pool_size\n",
"\n",
"url = \"http://localhost:30010/get_memory_pool_size\"\n",
"url = \"http://localhost:30010/health\"\n",
"\n",
"response = requests.get(url)\n",
"print_highlight(response.text)"
@@ -206,9 +164,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get Maximum Total Number of Tokens\n",
"## Flush Cache\n",
"\n",
"Exposes the maximum number of tokens SGLang can handle based on the current configuration."
"Flush the radix cache. It will be automatically triggered when the model weights are updated by the `/update_weights` API."
]
},
{
@@ -217,11 +175,11 @@
"metadata": {},
"outputs": [],
"source": [
"# get_max_total_num_tokens\n",
"# flush cache\n",
"\n",
"url = \"http://localhost:30010/get_max_total_num_tokens\"\n",
"url = \"http://localhost:30010/flush_cache\"\n",
"\n",
"response = requests.get(url)\n",
"response = requests.post(url)\n",
"print_highlight(response.text)"
]
},
4 changes: 2 additions & 2 deletions python/sglang/__init__.py
@@ -11,7 +11,7 @@
gen,
gen_int,
gen_string,
get_server_args,
get_server_info,
image,
select,
set_default_backend,
@@ -41,7 +41,7 @@
"gen",
"gen_int",
"gen_string",
"get_server_args",
"get_server_info",
"image",
"select",
"set_default_backend",
4 changes: 2 additions & 2 deletions python/sglang/api.py
@@ -65,15 +65,15 @@ def flush_cache(backend: Optional[BaseBackend] = None):
return backend.flush_cache()


def get_server_args(backend: Optional[BaseBackend] = None):
def get_server_info(backend: Optional[BaseBackend] = None):
backend = backend or global_config.default_backend
if backend is None:
return None

# If backend is Runtime
if hasattr(backend, "endpoint"):
backend = backend.endpoint
return backend.get_server_args()
return backend.get_server_info()


def gen(
2 changes: 1 addition & 1 deletion python/sglang/lang/backend/base_backend.py
@@ -78,5 +78,5 @@ def shutdown(self):
def flush_cache(self):
pass

def get_server_args(self):
def get_server_info(self):
pass
4 changes: 2 additions & 2 deletions python/sglang/lang/backend/runtime_endpoint.py
@@ -58,9 +58,9 @@ def flush_cache(self):
)
self._assert_success(res)

def get_server_args(self):
def get_server_info(self):
res = http_request(
self.base_url + "/get_server_args",
self.base_url + "/get_server_info",
api_key=self.api_key,
verify=self.verify,
)
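The __init__.py, api.py, and backend changes above cover the frontend-language side of the rename: the package export, the api.py helper, and the backend method it dispatches to. A hedged usage sketch of the new call (it assumes a server is already running on the notebook's example port and that RuntimeEndpoint.get_server_info returns the parsed JSON dict, as the bench_sglang.py change above implies):

import sglang as sgl

# Point the frontend at an already-running server (example URL; adjust to your deployment).
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30010"))

# Old call, removed in this commit: sgl.get_server_args()
info = sgl.get_server_info()  # server args plus memory_pool_size and max_total_num_tokens
print(info["tokenizer_path"], info["max_total_num_tokens"])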
67 changes: 27 additions & 40 deletions python/sglang/srt/server.py
@@ -146,10 +146,15 @@ async def get_model_info():
return result


@app.get("/get_server_args")
async def get_server_args():
"""Get the server arguments."""
return dataclasses.asdict(tokenizer_manager.server_args)
@app.get("/get_server_info")
async def get_server_info():
try:
return await _get_server_info()

except Exception as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)


@app.post("/flush_cache")
@@ -185,30 +190,6 @@ async def stop_profile():
)


@app.get("/get_max_total_num_tokens")
async def get_max_total_num_tokens():
try:
return {"max_total_num_tokens": _get_max_total_num_tokens()}

except Exception as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)


@app.api_route("/get_memory_pool_size", methods=["GET", "POST"])
async def get_memory_pool_size():
"""Get the memory pool size in number of tokens"""
try:
ret = await tokenizer_manager.get_memory_pool_size()

return ret
except Exception as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)


@app.post("/update_weights")
@time_func_latency
async def update_weights(obj: UpdateWeightReqInput, request: Request):
@@ -542,8 +523,12 @@ def launch_server(
t.join()


def _get_max_total_num_tokens():
return _max_total_num_tokens
async def _get_server_info():
return {
**dataclasses.asdict(tokenizer_manager.server_args), # server args
"memory_pool_size": await tokenizer_manager.get_memory_pool_size(), # memory pool size
"max_total_num_tokens": _max_total_num_tokens, # max total num tokens
}


def _set_envs_and_config(server_args: ServerArgs):
@@ -787,14 +772,16 @@ def encode(
response = requests.post(self.url + "/encode", json=json_data)
return json.dumps(response.json())

def get_max_total_num_tokens(self):
response = requests.get(f"{self.url}/get_max_total_num_tokens")
if response.status_code == 200:
return response.json()["max_total_num_tokens"]
else:
raise RuntimeError(
f"Failed to get max tokens. {response.json()['error']['message']}"
)
async def get_server_info(self):
async with aiohttp.ClientSession() as session:
async with session.get(f"{self.url}/get_server_info") as response:
if response.status == 200:
return await response.json()
else:
error_data = await response.json()
raise RuntimeError(
f"Failed to get server info. {error_data['error']['message']}"
)

def __del__(self):
self.shutdown()
@@ -946,5 +933,5 @@ def encode(
loop = asyncio.get_event_loop()
return loop.run_until_complete(encode_request(obj, None))

def get_max_total_num_tokens(self):
return _get_max_total_num_tokens()
async def get_server_info(self):
return await _get_server_info()
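One behavioral detail in this file: both new wrappers are coroutines (the aiohttp-based one, which appears to belong to the Runtime class given the self.url and shutdown usage around it, and the in-process one that awaits _get_server_info), whereas the old get_max_total_num_tokens helpers were synchronous, so callers now need an event loop. A hedged sketch under those assumptions; the model path is only an example:

import asyncio
import sglang as sgl

# Example only: substitute a model you actually have available.
runtime = sgl.Runtime(model_path="meta-llama/Llama-3.1-8B-Instruct")

# get_server_info is async in this revision, so drive it with an event loop.
info = asyncio.run(runtime.get_server_info())
print(info["memory_pool_size"], info["max_total_num_tokens"])

runtime.shutdown()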
8 changes: 4 additions & 4 deletions rust/src/server.rs
@@ -66,14 +66,14 @@ async fn health_generate(data: web::Data<AppState>) -> impl Responder {
forward_request(&data.client, worker_url, "/health_generate".to_string()).await
}

#[get("/get_server_args")]
async fn get_server_args(data: web::Data<AppState>) -> impl Responder {
#[get("/get_server_info")]
async fn get_server_info(data: web::Data<AppState>) -> impl Responder {
let worker_url = match data.router.get_first() {
Some(url) => url,
None => return HttpResponse::InternalServerError().finish(),
};

forward_request(&data.client, worker_url, "/get_server_args".to_string()).await
forward_request(&data.client, worker_url, "/get_server_info".to_string()).await
}

#[get("/v1/models")]
@@ -153,7 +153,7 @@ pub async fn startup(
.service(get_model_info)
.service(health)
.service(health_generate)
.service(get_server_args)
.service(get_server_info)
})
.bind((host, port))?
.run()
5 changes: 3 additions & 2 deletions test/srt/test_data_parallelism.py
@@ -63,12 +63,13 @@ def test_update_weight(self):
assert response.status_code == 200

def test_get_memory_pool_size(self):
response = requests.get(self.base_url + "/get_memory_pool_size")
# use `get_server_info` instead since `get_memory_pool_size` is merged into `get_server_info`
response = requests.get(self.base_url + "/get_server_info")
assert response.status_code == 200

time.sleep(5)

response = requests.get(self.base_url + "/get_memory_pool_size")
response = requests.get(self.base_url + "/get_server_info")
assert response.status_code == 200


15 changes: 12 additions & 3 deletions test/srt/test_srt_endpoint.py
@@ -154,9 +154,18 @@ def test_logprob_with_chunked_prefill(self):
self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)

def test_get_memory_pool_size(self):
response = requests.post(self.base_url + "/get_memory_pool_size")
self.assertIsInstance(response.json(), int)
def test_get_server_info(self):
response = requests.get(self.base_url + "/get_server_info")
response_json = response.json()

max_total_num_tokens = response_json["max_total_num_tokens"]
self.assertIsInstance(max_total_num_tokens, int)

memory_pool_size = response_json["memory_pool_size"]
self.assertIsInstance(memory_pool_size, int)

attention_backend = response_json["attention_backend"]
self.assertIsInstance(attention_backend, str)


if __name__ == "__main__":
