diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index a3d30324c37..8deafefb408 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -263,6 +263,24 @@ jobs:
cd test/srt
python3 test_moe_eval_accuracy_large.py
+ weight-update-test-2-gpu:
+ if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
+ runs-on: 2-gpu-runner
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v3
+
+ - name: Install dependencies
+ run: |
+ bash scripts/ci_install_dependency.sh
+
+ - name: Test weight update (TP, DP = 2,1 or 1,2)
+ timeout-minutes: 20
+ run: |
+ cd test/srt
+ python3 test_update_parameter_from_distributed.py
+ python3 test_get_parameter_by_name.py
+
finish:
needs: [
unit-test-frontend, unit-test-backend-part-1, unit-test-backend-part-2, unit-test-backend-part-3, unit-test-backend-part-4,
diff --git a/3rdparty/amd/profiling/PROFILING.md b/3rdparty/amd/profiling/PROFILING.md
index 90ad8665e67..79bc75b503b 100644
--- a/3rdparty/amd/profiling/PROFILING.md
+++ b/3rdparty/amd/profiling/PROFILING.md
@@ -421,5 +421,5 @@ index 62d1ff9..6ecd78c 100644
3. Modify the included server.sh by removing "loadTracer.sh" before python command and launch script ./server.sh in one terminal inside the docker container.
4. Similar to step 6 in RPD profiling section, but remove the last 2 lines in client.sh, which converted rpd file into csv and json files. Run modified client.sh for PyTorch profiling.
-=======
+-------
- [Torch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)
diff --git a/docs/backend/native_api.ipynb b/docs/backend/native_api.ipynb
index 0b43c6a5ae9..9a97c15ca30 100644
--- a/docs/backend/native_api.ipynb
+++ b/docs/backend/native_api.ipynb
@@ -14,7 +14,7 @@
"- `/health`\n",
"- `/health_generate`\n",
"- `/flush_cache`\n",
- "- `/update_weights`\n",
+ "- `/update_weights_from_disk`\n",
"- `/encode`(embedding model)\n",
"- `/classify`(reward model)\n",
"\n",
@@ -98,7 +98,7 @@
"print_highlight(response_json)\n",
"assert response_json[\"model_path\"] == \"meta-llama/Llama-3.2-1B-Instruct\"\n",
"assert response_json[\"is_generation\"] is True\n",
- "assert response_json.keys() == {\"model_path\", \"is_generation\"}"
+ "assert list(response_json.keys()) == [\"model_path\", \"tokenizer_path\", \"is_generation\"]"
]
},
{
@@ -144,8 +144,7 @@
"source": [
"url = \"http://localhost:30010/health_generate\"\n",
"\n",
- "response = requests.get(url)\n",
- "print_highlight(response.text)"
+ "response = requests.get(url)"
]
},
{
@@ -156,8 +155,7 @@
"source": [
"url = \"http://localhost:30010/health\"\n",
"\n",
- "response = requests.get(url)\n",
- "print_highlight(response.text)"
+ "response = requests.get(url)"
]
},
{
@@ -187,9 +185,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "## Update Weights\n",
+ "## Update Weights From Disk\n",
"\n",
- "Update model weights without restarting the server. Use for continuous evaluation during training. Only applicable for models with the same architecture and parameter size."
+ "Update model weights from disk without restarting the server. Use for continuous evaluation during training. Only applicable for models with the same architecture and parameter size."
]
},
{
@@ -200,7 +198,7 @@
"source": [
"# successful update with same architecture and size\n",
"\n",
- "url = \"http://localhost:30010/update_weights\"\n",
+ "url = \"http://localhost:30010/update_weights_from_disk\"\n",
"data = {\"model_path\": \"meta-llama/Llama-3.2-1B\"}\n",
"\n",
"response = requests.post(url, json=data)\n",
@@ -218,7 +216,7 @@
"source": [
"# failed update with different parameter size\n",
"\n",
- "url = \"http://localhost:30010/update_weights\"\n",
+ "url = \"http://localhost:30010/update_weights_from_disk\"\n",
"data = {\"model_path\": \"meta-llama/Llama-3.2-3B\"}\n",
"\n",
"response = requests.post(url, json=data)\n",
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 8b1f88fa26b..a33ce7180ae 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -352,7 +352,7 @@ class FlushCacheReq:
@dataclass
-class UpdateWeightReqInput:
+class UpdateWeightFromDistReqInput:
# The model path with the new weights
model_path: str
# The format to load the weights
@@ -360,11 +360,58 @@ class UpdateWeightReqInput:
@dataclass
-class UpdateWeightReqOutput:
+class UpdateWeightFromDistReqOutput:
success: bool
message: str
+@dataclass
+class UpdateParameterFromDistributedReqInput:
+ # The name of the parameter to update
+ name: str
+ # The dtype of the incoming weight tensor
+ dtype: str
+ # The shape of the incoming weight tensor
+ shape: List[int]
+ # Whether to empty the cache after updating the parameter
+ empty_cache: bool
+
+
+@dataclass
+class UpdateParameterFromDistributedReqOutput:
+ success: bool
+ message: str
+
+
+@dataclass
+class InitParameterUpdateGroupReqInput:
+ # The master address
+ master_address: str
+ # The master port
+ master_port: int
+ # The rank offset
+ rank_offset: int
+ # The world size
+ world_size: int
+ # The group name
+ group_name: str
+ # The backend
+ backend: str = "nccl"
+
+
+@dataclass
+class InitParameterUpdateGroupReqOutput:
+ success: bool
+ message: str
+
+
+@dataclass
+class GetParameterByNameReqInput:
+ # The name of the parameter to fetch
+ name: str
+ # Only return the first `truncate_size` entries of the weight
+ truncate_size: int = 100
+
+
+@dataclass
+class GetParameterByNameReqOutput:
+ parameter: list
+
+
@dataclass
class AbortReq:
# The request id
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 663d4c4f935..bb4f2cc93cc 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -38,13 +38,19 @@
FlushCacheReq,
GetMemPoolSizeReq,
GetMemPoolSizeReqOutput,
+ GetParameterByNameReqInput,
+ GetParameterByNameReqOutput,
+ InitParameterUpdateGroupReqInput,
+ InitParameterUpdateGroupReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
- UpdateWeightReqInput,
- UpdateWeightReqOutput,
+ UpdateParameterFromDistributedReqInput,
+ UpdateParameterFromDistributedReqOutput,
+ UpdateWeightFromDistReqInput,
+ UpdateWeightFromDistReqOutput,
)
from sglang.srt.managers.schedule_batch import (
FINISH_ABORT,
@@ -500,10 +506,25 @@ def process_input_requests(self, recv_reqs: List):
self.flush_cache()
elif isinstance(recv_req, AbortReq):
self.abort_request(recv_req)
- elif isinstance(recv_req, UpdateWeightReqInput):
- success, message = self.update_weights(recv_req)
+ elif isinstance(recv_req, UpdateWeightFromDistReqInput):
+ success, message = self.update_weights_from_disk(recv_req)
self.send_to_tokenizer.send_pyobj(
- UpdateWeightReqOutput(success, message)
+ UpdateWeightFromDistReqOutput(success, message)
+ )
+ elif isinstance(recv_req, GetParameterByNameReqInput):
+ parameter = self.get_weights_by_parameter_name(recv_req)
+ self.send_to_tokenizer.send_pyobj(
+ GetParameterByNameReqOutput(parameter)
+ )
+ elif isinstance(recv_req, InitParameterUpdateGroupReqInput):
+ success, message = self.init_parameter_update_group(recv_req)
+ self.send_to_tokenizer.send_pyobj(
+ InitParameterUpdateGroupReqOutput(success, message)
+ )
+ elif isinstance(recv_req, UpdateParameterFromDistributedReqInput):
+ success, message = self.update_parameter_from_distributed(recv_req)
+ self.send_to_tokenizer.send_pyobj(
+ UpdateParameterFromDistributedReqOutput(success, message)
)
elif isinstance(recv_req, ProfileReq):
if recv_req == ProfileReq.START_PROFILE:
@@ -1353,9 +1374,9 @@ def abort_request(self, recv_req: AbortReq):
self.tree_cache.cache_finished_req(req)
break
- def update_weights(self, recv_req: UpdateWeightReqInput):
+ def update_weights_from_disk(self, recv_req: UpdateWeightFromDistReqInput):
"""In-place update of the weights."""
- success, message = self.tp_worker.update_weights(recv_req)
+ success, message = self.tp_worker.update_weights_from_disk(recv_req)
if success:
flash_cache_success = self.flush_cache()
assert flash_cache_success, "Cache flush failed after updating weights"
@@ -1363,6 +1384,27 @@ def update_weights(self, recv_req: UpdateWeightReqInput):
logger.error(message)
return success, message
+ def init_parameter_update_group(self, recv_req: InitParameterUpdateGroupReqInput):
+ """Initialize the online model parameter update group."""
+ success, message = self.tp_worker.init_parameter_update_group(recv_req)
+ return success, message
+
+ def update_parameter_from_distributed(
+ self, recv_req: UpdateParameterFromDistributedReqInput
+ ):
+ """Update the online model parameter."""
+ success, message = self.tp_worker.update_parameter_from_distributed(recv_req)
+ if success:
+ flash_cache_success = self.flush_cache()
+ assert flash_cache_success, "Cache flush failed after updating weights"
+ else:
+ logger.error(message)
+ return success, message
+
+ def get_weights_by_parameter_name(self, recv_req: GetParameterByNameReqInput):
+ parameter = self.tp_worker.get_weights_by_parameter_name(recv_req)
+ return parameter
+
def start_profile(self) -> None:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 001ecc1ebe7..11ed2ec2540 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -47,13 +47,19 @@
GenerateReqInput,
GetMemPoolSizeReq,
GetMemPoolSizeReqOutput,
+ GetParameterByNameReqInput,
+ GetParameterByNameReqOutput,
+ InitParameterUpdateGroupReqInput,
+ InitParameterUpdateGroupReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
- UpdateWeightReqInput,
- UpdateWeightReqOutput,
+ UpdateParameterFromDistributedReqInput,
+ UpdateParameterFromDistributedReqOutput,
+ UpdateWeightFromDistReqInput,
+ UpdateWeightFromDistReqOutput,
)
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
@@ -425,8 +431,10 @@ async def get_memory_pool_size(self):
ret = [r.size for r in res]
return ret
- async def update_weights(
- self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
+ async def update_weights_from_disk(
+ self,
+ obj: UpdateWeightFromDistReqInput,
+ request: Optional[fastapi.Request] = None,
):
if self.to_create_loop:
self.create_handle_loop()
@@ -471,6 +479,89 @@ async def update_weights(
else:
return False, "Another update is in progress. Please try again later."
+ async def init_parameter_update_group(
+ self,
+ obj: InitParameterUpdateGroupReqInput,
+ request: Optional[fastapi.Request] = None,
+ ) -> bool:
+ if self.to_create_loop:
+ self.create_handle_loop()
+
+ if obj.backend is None:
+ obj.backend = "nccl"
+ self.send_to_scheduler.send_pyobj(obj)
+
+ self.init_parameter_update_group_result = asyncio.Future()
+
+ if self.server_args.dp_size == 1:
+ result = await self.init_parameter_update_group_result
+ return result.success, result.message
+ else:
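+ # With dp_size > 1, each data-parallel scheduler replies separately, so
+ # collect one result per DP rank before reporting back to the caller.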
+ self.init_parameter_update_group_tmp = []
+ result = await self.init_parameter_update_group_result
+ all_success = all([r.success for r in result])
+ all_message = [r.message for r in result]
+ all_message = " | ".join(all_message)
+ return all_success, all_message
+
+ async def update_parameter_from_distributed(
+ self,
+ obj: UpdateParameterFromDistributedReqInput,
+ request: Optional[fastapi.Request] = None,
+ ):
+ if self.to_create_loop:
+ self.create_handle_loop()
+ if not self.model_update_lock.locked():
+
+ async with self.model_update_lock:
+ # wait for the previous update requests to finish
+ for i in range(3):
+ while len(self.rid_to_state) > 0:
+ await asyncio.sleep(0.001)
+ # FIXME: We add some sleep here to avoid some race conditions.
+ # We can use a read-write lock as a better fix.
+ await asyncio.sleep(0.01)
+
+ self.send_to_scheduler.send_pyobj(obj)
+ self.parameter_update_result = asyncio.Future()
+
+ if self.server_args.dp_size == 1:
+ result = await self.parameter_update_result
+ return result.success, result.message
+ else: # self.server_args.dp_size > 1
+ self.parameter_update_tmp = []
+ result = await self.parameter_update_result
+ all_success = all([r.success for r in result])
+ all_message = [r.message for r in result]
+ all_message = " | ".join(all_message)
+ return all_success, all_message
+
+ else:
+ logger.error(
+ "Another parameter update is in progress in the tokenizer manager."
+ )
+ return (
+ False,
+ "Another parameter update is in progress. Please try again later.",
+ )
+
+ async def get_weights_by_parameter_name(
+ self, obj: GetParameterByNameReqInput, request: Optional[fastapi.Request] = None
+ ):
+ if self.to_create_loop:
+ self.create_handle_loop()
+
+ self.send_to_scheduler.send_pyobj(obj)
+ self.get_weights_by_parameter_name_result = asyncio.Future()
+ if self.server_args.dp_size == 1:
+ result = await self.get_weights_by_parameter_name_result
+ return result.parameter
+ else:
+ self.get_weights_by_parameter_name_tmp = []
+ result = await self.get_weights_by_parameter_name_result
+ all_parameters = [r.parameter for r in result]
+ return all_parameters
+
async def open_session(
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
):
@@ -540,10 +631,16 @@ async def handle_loop(self):
while True:
recv_obj: Union[
- BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
+ BatchStrOut,
+ BatchEmbeddingOut,
+ BatchTokenIDOut,
+ UpdateWeightFromDistReqOutput,
+ UpdateParameterFromDistributedReqOutput,
+ GetParameterByNameReqOutput,
+ InitParameterUpdateGroupReqOutput,
] = await self.recv_from_detokenizer.recv_pyobj()
- if isinstance(recv_obj, UpdateWeightReqOutput):
+ if isinstance(recv_obj, UpdateWeightFromDistReqOutput):
if self.server_args.dp_size == 1:
self.model_update_result.set_result(recv_obj)
else: # self.server_args.dp_size > 1
@@ -552,6 +649,43 @@ async def handle_loop(self):
if len(self.model_update_tmp) == self.server_args.dp_size:
self.model_update_result.set_result(self.model_update_tmp)
continue
+ elif isinstance(recv_obj, UpdateParameterFromDistributedReqOutput):
+ if self.server_args.dp_size == 1:
+ self.parameter_update_result.set_result(recv_obj)
+ else: # self.server_args.dp_size > 1
+ self.parameter_update_tmp.append(recv_obj)
+ # set the future once the results from all DP ranks are received
+ if len(self.parameter_update_tmp) == self.server_args.dp_size:
+ self.parameter_update_result.set_result(
+ self.parameter_update_tmp
+ )
+ continue
+ elif isinstance(recv_obj, GetParameterByNameReqOutput):
+ if self.server_args.dp_size == 1:
+ self.get_weights_by_parameter_name_result.set_result(recv_obj)
+ else:
+ self.get_weights_by_parameter_name_tmp.append(recv_obj)
+ if (
+ len(self.get_weights_by_parameter_name_tmp)
+ == self.server_args.dp_size
+ ):
+ self.get_weights_by_parameter_name_result.set_result(
+ self.get_weights_by_parameter_name_tmp
+ )
+ continue
+ elif isinstance(recv_obj, InitParameterUpdateGroupReqOutput):
+ if self.server_args.dp_size == 1:
+ self.init_parameter_update_group_result.set_result(recv_obj)
+ else:
+ self.init_parameter_update_group_tmp.append(recv_obj)
+ if (
+ len(self.init_parameter_update_group_tmp)
+ == self.server_args.dp_size
+ ):
+ self.init_parameter_update_group_result.set_result(
+ self.init_parameter_update_group_tmp
+ )
+ continue
elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
if self.server_args.dp_size == 1:
self.mem_pool_size.set_result(recv_obj)
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index a5d694e77bc..d68dd78cc79 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -19,7 +19,12 @@
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.managers.io_struct import UpdateWeightReqInput
+from sglang.srt.managers.io_struct import (
+ GetParameterByNameReqInput,
+ InitParameterUpdateGroupReqInput,
+ UpdateParameterFromDistributedReqInput,
+ UpdateWeightFromDistReqInput,
+)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
@@ -155,8 +160,33 @@ def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
embeddings = logits_output.embeddings
return embeddings
- def update_weights(self, recv_req: UpdateWeightReqInput):
- success, message = self.model_runner.update_weights(
+ def update_weights_from_disk(self, recv_req: UpdateWeightFromDistReqInput):
+ success, message = self.model_runner.update_weights_from_disk(
recv_req.model_path, recv_req.load_format
)
return success, message
+
+ def init_parameter_update_group(self, recv_req: InitParameterUpdateGroupReqInput):
+ success, message = self.model_runner.init_parameter_update_group(
+ recv_req.master_address,
+ recv_req.master_port,
+ recv_req.rank_offset,
+ recv_req.world_size,
+ recv_req.group_name,
+ recv_req.backend,
+ )
+ return success, message
+
+ def update_parameter_from_distributed(
+ self, recv_req: UpdateParameterFromDistributedReqInput
+ ):
+ success, message = self.model_runner.update_parameter_from_distributed(
+ recv_req.name, recv_req.dtype, recv_req.shape, recv_req.empty_cache
+ )
+ return success, message
+
+ def get_weights_by_parameter_name(self, recv_req: GetParameterByNameReqInput):
+ parameter = self.model_runner.get_weights_by_parameter_name(
+ recv_req.name, recv_req.truncate_size
+ )
+ return parameter
diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py
index 3b53759a75f..02a6dd847e6 100644
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -21,7 +21,12 @@
import torch
-from sglang.srt.managers.io_struct import UpdateWeightReqInput
+from sglang.srt.managers.io_struct import (
+ GetParameterByNameReqInput,
+ InitParameterUpdateGroupReqInput,
+ UpdateParameterFromDistributedReqInput,
+ UpdateWeightFromDistReqInput,
+)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs
@@ -195,10 +200,23 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
) % self.future_token_ids_limit
return None, future_next_token_ids
- def update_weights(self, recv_req: UpdateWeightReqInput):
- success, message = self.worker.update_weights(recv_req)
+ def update_weights_from_disk(self, recv_req: UpdateWeightFromDistReqInput):
+ success, message = self.worker.update_weights_from_disk(recv_req)
return success, message
+ def init_parameter_update_group(self, recv_req: InitParameterUpdateGroupReqInput):
+ success, message = self.worker.init_parameter_update_group(recv_req)
+ return success, message
+
+ def update_parameter_from_distributed(
+ self, recv_req: UpdateParameterFromDistributedReqInput
+ ):
+ success, message = self.worker.update_parameter_from_distributed(recv_req)
+ return success, message
+
+ def get_weights_by_parameter_name(self, recv_req: GetParameterByNameReqInput):
+ return self.worker.get_weights_by_parameter_name(recv_req)
+
def __delete__(self):
self.input_queue.put((None, None))
self.copy_queue.put((None, None, None))
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 7c1c51a8fb2..43036c949ea 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -21,9 +21,10 @@
import logging
import pkgutil
from functools import lru_cache
-from typing import Optional, Type
+from typing import Any, Optional, Type, Union
import torch
+import torch.distributed as dist
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
@@ -58,6 +59,7 @@
crash_on_warnings,
enable_show_time_cost,
get_available_gpu_memory,
+ init_custom_process_group,
is_hip,
monkey_patch_vllm_model_config,
monkey_patch_vllm_p2p_access_check,
@@ -316,8 +318,8 @@ def load_model(self):
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
- def update_weights(self, model_path: str, load_format: str):
- """Update weights in-place."""
+ def update_weights_from_disk(self, model_path: str, load_format: str):
+ """Update engine weights online from disk."""
from vllm.model_executor.model_loader.loader import (
DefaultModelLoader,
device_loading_context,
@@ -326,7 +328,7 @@ def update_weights(self, model_path: str, load_format: str):
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
logger.info(
- f"Update weights begin. "
+ f"Update engine weights online from disk begin. "
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
@@ -397,6 +399,130 @@ def model_load_weights(model, iter):
logger.info("Update weights end.")
return True, "Succeeded to update model weights."
+ def init_parameter_update_group(
+ self,
+ master_address,
+ master_port,
+ rank_offset,
+ world_size,
+ group_name,
+ backend="nccl",
+ ):
+ """Initialize the Torch process group for model parameter updates.
+
+ `_model_update_group` is used in the RLHF workflow, where rank 0 is the actor model in
+ the training engine, and the other ranks are the inference engine, which is used for rollout.
+
+ In the RLHF workflow, the training engine updates the model weights/parameters online,
+ and broadcasts them to the inference engine through the `_model_update_group` process group.
+ """
+ assert (
+ torch.distributed.is_initialized()
+ ), "Default torch process group must be initialized"
+ assert group_name != "", "Group name cannot be empty"
+
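+ # Each TP worker joins the custom group at rank_offset + tp_rank, leaving the
+ # lower ranks (e.g. rank 0, the training actor) to the training engine.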
+ rank = rank_offset + self.tp_rank
+
+ logger.info(
+ f"init custom process group: master_address={master_address}, master_port={master_port}, "
+ f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
+ )
+
+ try:
+ self._model_update_group = init_custom_process_group(
+ backend=backend,
+ init_method=f"tcp://{master_address}:{master_port}",
+ world_size=world_size,
+ rank=rank,
+ group_name=group_name,
+ )
+
+ return True, "Succeeded to initialize custom process group."
+ except Exception as e:
+ message = f"Failed to initialize custom process group: {e}."
+ logger.error(message)
+ return False, message
+
+ def get_weights_by_parameter_name(
+ self, name: str, truncate_size: int = 100
+ ) -> Optional[list]:
+ try:
+ # Check whether the name refers to a stacked (merged) parameter
+ mapped_name = name
+ mapped_shard_id = None
+ for param_name, weight_name, shard_id in self.model.stacked_params_mapping:
+ if weight_name in name:
+ mapped_name = name.replace(weight_name, param_name)
+ mapped_shard_id = shard_id
+ break
+ params_dict = dict(self.model.named_parameters())
+ if mapped_name in params_dict:
+ param = params_dict[mapped_name]
+ if mapped_shard_id is not None:
+ # Handle the stacked-parameter case
+ if mapped_shard_id in ["q", "k", "v"]:
+ # Compute the offset and size within qkv_proj
+ num_heads = (
+ self.model.config.num_attention_heads // self.tp_size
+ )
+ num_kv_heads = (
+ self.model.config.num_key_value_heads // self.tp_size
+ )
+ head_dim = (
+ self.model.config.hidden_size
+ // self.model.config.num_attention_heads
+ )
+
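+ # The fused qkv_proj weight stacks [Q | K | V] along dim 0: Q spans
+ # num_heads * head_dim rows, then K and V span num_kv_heads * head_dim rows each.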
+ if mapped_shard_id == "q":
+ offset = 0
+ size = num_heads * head_dim
+ elif mapped_shard_id == "k":
+ offset = num_heads * head_dim
+ size = num_kv_heads * head_dim
+ elif mapped_shard_id == "v":
+ offset = (num_heads + num_kv_heads) * head_dim
+ size = num_kv_heads * head_dim
+
+ # Slice out the corresponding part of the weight
+ weight = param.data.narrow(0, offset, size)
+ elif mapped_shard_id in [0, 1]:
+ # Handle the gate_up_proj case
+ intermediate_size = self.model.config.intermediate_size
+ hidden_size = self.model.config.hidden_size
+ slice_size = intermediate_size // self.tp_size
+
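+ # gate_up_proj stacks [gate | up] along dim 0; each TP rank holds
+ # intermediate_size // tp_size rows per projection.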
+ if mapped_shard_id == 0: # gate_proj
+ offset = 0
+ size = slice_size
+ elif mapped_shard_id == 1: # up_proj
+ offset = slice_size
+ size = slice_size
+
+ # Slice out the corresponding part of the weight
+ weight = param.data.narrow(0, offset, size)
+ else:
+ weight = param.data
+ else:
+ weight = param.data
+
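+ # o_proj and down_proj are row-parallel, so each TP rank only holds a slice of
+ # the weight along dim 1; gather the slices and concatenate to rebuild it.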
+ if self.tp_size > 1 and ("o_proj" in name or "down_proj" in name):
+ gathered_weights = [
+ torch.zeros_like(weight) for _ in range(self.tp_size)
+ ]
+ torch.distributed.all_gather(gathered_weights, weight)
+ weight = torch.cat(gathered_weights, dim=1)
+
+ return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
+ else:
+ logger.warning(
+ f"Parameter {name} (mapped to {mapped_name}) not found in model"
+ )
+ return None
+
+ except Exception as e:
+ logger.error(f"Error when getting parameter {name}: {e}")
+ return None
+
def init_lora_manager(self):
self.lora_manager = LoRAManager(
base_model=self.model,
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index 7e9fd0f7267..67ae46a6440 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -305,6 +305,14 @@ def __init__(
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+ self.stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ (".qkv_proj", ".q_proj", "q"),
+ (".qkv_proj", ".k_proj", "k"),
+ (".qkv_proj", ".v_proj", "v"),
+ (".gate_up_proj", ".gate_proj", 0),
+ (".gate_up_proj", ".up_proj", 1),
+ ]
@torch.no_grad()
def forward(
@@ -349,15 +357,7 @@ def get_module_name(self, name):
return params_mapping.get(name, name)
def get_module_name_from_weight_name(self, name):
- stacked_params_mapping = [
- # (param_name, shard_name, shard_id, num_shard)
- ("qkv_proj", "q_proj", "q", 3),
- ("qkv_proj", "k_proj", "k", 3),
- ("qkv_proj", "v_proj", "v", 3),
- ("gate_up_proj", "gate_proj", 0, 2),
- ("gate_up_proj", "up_proj", 1, 2),
- ]
- for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
+ # The mapping stores (param_name, shard_name, shard_id); q/k/v stack three shards, gate/up stack two.
+ for param_name, weight_name, shard_id in self.stacked_params_mapping:
+ num_shard = 3 if shard_id in ("q", "k", "v") else 2
if weight_name in name:
return (
name.replace(weight_name, param_name)[: -len(".weight")],
@@ -370,6 +370,7 @@ def get_num_params(self):
return len(params_dict)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ embed_tokens_weight = None
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
@@ -378,6 +379,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
+
params_dict = dict(self.named_parameters())
load_tie_word_embeddings = (
@@ -425,7 +427,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
param = self.lm_head.weight
weight_loader = getattr(param, "weight_loader", default_weight_loader)
- weight_loader(param, embed_tokens_weight)
+ if embed_tokens_weight is not None:
+ weight_loader(param, embed_tokens_weight)
apply_torchao_config_(self, params_dict, set(["proj.weight"]))
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index a4753a13458..4030c4e9da9 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -28,12 +28,15 @@
from http import HTTPStatus
from typing import AsyncIterator, Dict, List, Optional, Union
+import torch
+
# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
import aiohttp
import orjson
import requests
+import torch.distributed as dist
import uvicorn
import uvloop
from fastapi import FastAPI, File, Form, Request, UploadFile
@@ -51,8 +54,11 @@
CloseSessionReqInput,
EmbeddingReqInput,
GenerateReqInput,
+ GetParameterByNameReqInput,
+ InitParameterUpdateGroupReqInput,
OpenSessionReqInput,
- UpdateWeightReqInput,
+ UpdateParameterFromDistributedReqInput,
+ UpdateWeightFromDistReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -78,6 +84,7 @@
assert_pkg_version,
configure_logger,
delete_directory,
+ init_custom_process_group,
is_port_available,
kill_child_process,
maybe_set_triton_cache_manager,
@@ -201,11 +208,24 @@ async def stop_profile_async():
)
-@app.post("/update_weights")
+@app.api_route("/get_memory_pool_size", methods=["GET", "POST"])
+async def get_memory_pool_size():
+ """Get the memory pool size in number of tokens"""
+ try:
+ ret = await tokenizer_manager.get_memory_pool_size()
+
+ return ret
+ except Exception as e:
+ return ORJSONResponse(
+ {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+ )
+
+
+@app.post("/update_weights_from_disk")
@time_func_latency
-async def update_weights(obj: UpdateWeightReqInput, request: Request):
- """Update the weights inplace without re-launching the server."""
- success, message = await tokenizer_manager.update_weights(obj, request)
+async def update_weights_from_disk(obj: UpdateWeightFromDistReqInput, request: Request):
+ """Update the weights from disk inplace without re-launching the server."""
+ success, message = await tokenizer_manager.update_weights_from_disk(obj, request)
content = {"success": success, "message": message}
if success:
return ORJSONResponse(
@@ -219,6 +239,54 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
)
+@app.post("/init_parameter_update_group")
+async def init_parameter_update_group(
+ obj: InitParameterUpdateGroupReqInput, request: Request
+):
+ """Initialize the parameter update group."""
+ success, message = await tokenizer_manager.init_parameter_update_group(obj, request)
+ content = {"success": success, "message": message}
+ if success:
+ return ORJSONResponse(content, status_code=200)
+ else:
+ return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
+
+
+@app.post("/update_parameter_from_distributed")
+async def update_parameter_from_distributed(
+ obj: UpdateParameterFromDistributedReqInput, request: Request
+):
+ """Update model parameter from distributed online."""
+ success, message = await tokenizer_manager.update_parameter_from_distributed(
+ obj, request
+ )
+ content = {"success": success, "message": message}
+ if success:
+ return ORJSONResponse(content, status_code=200)
+ else:
+ return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
+
+
+@app.api_route("/get_weights_by_parameter_name", methods=["GET", "POST"])
+async def get_weights_by_parameter_name(
+ obj: GetParameterByNameReqInput, request: Request
+):
+ """Get model parameter by name."""
+ try:
+ ret = await tokenizer_manager.get_weights_by_parameter_name(obj, request)
+ if ret is None:
+ return ORJSONResponse(
+ {"error": {"message": "Get parameter by name failed"}},
+ status_code=HTTPStatus.BAD_REQUEST,
+ )
+ else:
+ return ORJSONResponse(ret, status_code=200)
+ except Exception as e:
+ return ORJSONResponse(
+ {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+ )
+
+
@app.api_route("/open_session", methods=["GET", "POST"])
async def open_session(obj: OpenSessionReqInput, request: Request):
"""Open a session, and return its unique session id."""
@@ -276,6 +344,51 @@ async def stream_results() -> AsyncIterator[bytes]:
)
+@time_func_latency
+async def init_parameter_update_group_request(
+ obj: InitParameterUpdateGroupReqInput, request: Request
+):
+ """Handle an init parameter update group request."""
+ try:
+ ret = await tokenizer_manager.init_parameter_update_group(obj, request)
+ return ret
+ except ValueError as e:
+ return ORJSONResponse(
+ {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+ )
+
+
+@time_func_latency
+async def get_weights_by_parameter_name_request(
+ obj: GetParameterByNameReqInput, request: Request
+):
+ """Handle a get parameter by name request."""
+ try:
+ ret = await tokenizer_manager.get_weights_by_parameter_name(obj, request)
+ return ret
+ except ValueError as e:
+ return ORJSONResponse(
+ {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+ )
+
+
+@time_func_latency
+async def update_parameter_from_distributed_request(
+ obj: UpdateParameterFromDistributedReqInput, request: Request
+):
+ """Handle an update parameter from distributed request."""
+ try:
+ torch.cuda.synchronize()
+ ret = await tokenizer_manager.update_parameter_from_distributed(obj, request)
+ return ret
+ except ValueError as e:
+ return ORJSONResponse(
+ {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+ )
+
+
# fastapi implicitly converts json in the request to obj (dataclass)
app.post("/generate")(generate_request)
app.put("/generate")(generate_request)
@@ -946,3 +1059,99 @@ def encode(
async def get_server_info(self):
return await _get_server_info()
+
+ def generate(
+ self,
+ # The input prompt. It can be a single prompt or a batch of prompts.
+ prompt: Optional[Union[List[str], str]] = None,
+ sampling_params: Optional[Union[List[Dict], Dict]] = None,
+ # The token ids for text; one can either specify text or input_ids.
+ input_ids: Optional[Union[List[List[int]], List[int]]] = None,
+ return_logprob: Optional[Union[List[bool], bool]] = False,
+ logprob_start_len: Optional[Union[List[int], int]] = None,
+ top_logprobs_num: Optional[Union[List[int], int]] = None,
+ lora_path: Optional[List[Optional[str]]] = None,
+ stream: bool = False,
+ ):
+ obj = GenerateReqInput(
+ text=prompt,
+ input_ids=input_ids,
+ sampling_params=sampling_params,
+ return_logprob=return_logprob,
+ logprob_start_len=logprob_start_len,
+ top_logprobs_num=top_logprobs_num,
+ lora_path=lora_path,
+ stream=stream,
+ )
+
+ # get the current event loop
+ loop = asyncio.get_event_loop()
+ ret = loop.run_until_complete(generate_request(obj, None))
+
+ if stream is True:
+
+ def generator_wrapper():
+ offset = 0
+ loop = asyncio.get_event_loop()
+ generator = ret.body_iterator
+ while True:
+ chunk = loop.run_until_complete(generator.__anext__())
+
+ if chunk.startswith(STREAM_END_SYMBOL):
+ break
+ else:
+ data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
+ data["text"] = data["text"][offset:]
+ offset += len(data["text"])
+ yield data
+
+ # we cannot yield in the scope of generate() because python does not allow yield + return in the same function
+ # however, it allows to wrap the generator as a subfunction and return
+ return generator_wrapper()
+ else:
+ return ret
+
+ def init_parameter_update_group(
+ self,
+ master_address: str,
+ master_port: int,
+ rank_offset: int,
+ world_size: int,
+ group_name: str,
+ backend: str = "nccl",
+ ):
+ obj = InitParameterUpdateGroupReqInput(
+ master_address=master_address,
+ master_port=master_port,
+ rank_offset=rank_offset,
+ world_size=world_size,
+ group_name=group_name,
+ backend=backend,
+ )
+
+ loop = asyncio.get_event_loop()
+ return loop.run_until_complete(init_parameter_update_group_request(obj, None))
+
+ def update_parameter_from_distributed(self, name, dtype, shape, empty_cache=False):
+ print(f"update parameter from distributed request in engine before synchronize")
+ torch.cuda.synchronize()
+ print(f"update parameter from distributed request in engine after synchronize")
+ obj = UpdateParameterFromDistributedReqInput(
+ name=name,
+ dtype=dtype,
+ shape=shape,
+ empty_cache=empty_cache,
+ )
+ loop = asyncio.get_event_loop()
+ return loop.run_until_complete(
+ update_parameter_from_distributed_request(obj, None)
+ )
+
+ def get_weights_by_parameter_name(self, name, truncate_size=100):
+ obj = GetParameterByNameReqInput(name=name, truncate_size=truncate_size)
+ loop = asyncio.get_event_loop()
+ return loop.run_until_complete(get_weights_by_parameter_name_request(obj, None))
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 4a974e2e754..867b0fa61d7 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -30,6 +30,7 @@
import tempfile
import time
import warnings
+from datetime import timedelta
from importlib.metadata import PackageNotFoundError, version
from io import BytesIO
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
@@ -38,6 +39,7 @@
import psutil
import requests
import torch
+import torch.distributed
import torch.distributed as dist
import triton
import zmq
@@ -45,6 +47,16 @@
from packaging import version as pkg_version
from starlette.routing import Mount
from torch import nn
+from torch.distributed.distributed_c10d import (
+ Backend,
+ PrefixStore,
+ Store,
+ _new_process_group_helper,
+ _store_based_barrier,
+ _world,
+ default_pg_timeout,
+ rendezvous,
+)
from torch.func import functional_call
from torch.library import Library
from torch.profiler import ProfilerActivity, profile, record_function
@@ -934,6 +946,71 @@ def get_nvgpu_memory_capacity():
)
+# Adapted from PyTorch and OpenRLHF to allow creating multiple "main" process groups.
+# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
+def init_custom_process_group(
+ backend: Union[str, Backend] = None,
+ init_method: Optional[str] = None,
+ timeout: Optional[timedelta] = None,
+ world_size: int = -1,
+ rank: int = -1,
+ store: Optional[Store] = None,
+ group_name: str = None,
+ pg_options: Optional[Any] = None,
+):
+ assert (store is None) or (
+ init_method is None
+ ), "Cannot specify both init_method and store."
+
+ if store is not None:
+ assert world_size > 0, "world_size must be positive if using store"
+ assert rank >= 0, "rank must be non-negative if using store"
+ elif init_method is None:
+ init_method = "env://"
+
+ if backend:
+ backend = Backend(backend)
+ else:
+ backend = Backend("undefined")
+
+ if timeout is None:
+ timeout = default_pg_timeout
+
+ # backward compatible API
+ if store is None:
+ rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
+ store, rank, world_size = next(rendezvous_iterator)
+ store.set_timeout(timeout)
+
+ # Use a PrefixStore to avoid accidental overrides of keys used by
+ # different systems (e.g. RPC) in case the store is multi-tenant.
+ store = PrefixStore(group_name, store)
+
+ # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
+ # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
+ # We need to determine the appropriate parameter name based on PyTorch version
+ pg_options_param_name = (
+ "backend_options"
+ if pkg_version.parse(str(torch.__version__)).release >= (2, 6)
+ else "pg_options"
+ )
+ pg, _ = _new_process_group_helper(
+ world_size,
+ rank,
+ [],
+ backend,
+ store,
+ group_name=group_name,
+ **{pg_options_param_name: pg_options},
+ timeout=timeout,
+ )
+
+ logger.error(f"pg pass in init_custom_process_group world size: {world_size}")
+ logger.error(f"pg pass in init_custom_process_group rank: {rank}")
+
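+ # Record the rank mapping for the new group, mirroring what
+ # torch.distributed.init_process_group does for the default group.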
+ _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
+
+ return pg
+
+
def crash_on_warnings():
# Crash on warning if we are running CI tests
return get_bool_env_var("SGLANG_IS_IN_CI")
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 3f55eb25fdc..ba436ffd811 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -11,6 +11,7 @@
"models/test_reward_models.py",
"sampling/penaltylib",
"test_chunked_prefill.py",
+ "test_custom_process_group.py",
"test_double_sparsity.py",
"test_embedding_openai_server.py",
"test_eval_accuracy_mini.py",
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 2, #new-token: 7, #cached-token: 196, cache hit rate: 89.04%, token usage: 0.01, #running-req: 25, #queue-req: 102
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 83, cache hit rate: 89.06%, token usage: 0.01, #running-req: 26, #queue-req: 102
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 89, cache hit rate: 89.10%, token usage: 0.01, #running-req: 26, #queue-req: 102
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 95, cache hit rate: 89.13%, token usage: 0.01, #running-req: 26, #queue-req: 102
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 101, cache hit rate: 89.17%, token usage: 0.01, #running-req: 26, #queue-req: 102
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 107, cache hit rate: 89.21%, token usage: 0.01, #running-req: 26, #queue-req: 102
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 113, cache hit rate: 89.26%, token usage: 0.01, #running-req: 26, #queue-req: 102
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 119, cache hit rate: 89.31%, token usage: 0.01, #running-req: 26, #queue-req: 102
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 125, cache hit rate: 89.36%, token usage: 0.01, #running-req: 26, #queue-req: 102
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 131, cache hit rate: 89.42%, token usage: 0.01, #running-req: 26, #queue-req: 102
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 137, cache hit rate: 89.48%, token usage: 0.01, #running-req: 26, #queue-req: 102
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 143, cache hit rate: 89.55%, token usage: 0.01, #running-req: 26, #queue-req: 102
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 149, cache hit rate: 89.61%, token usage: 0.01, #running-req: 26, #queue-req: 102
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 155, cache hit rate: 89.68%, token usage: 0.01, #running-req: 26, #queue-req: 102
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 161, cache hit rate: 89.75%, token usage: 0.01, #running-req: 26, #queue-req: 102
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 2, #new-token: 6, #cached-token: 248, cache hit rate: 89.88%, token usage: 0.01, #running-req: 26, #queue-req: 101
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 5, #cached-token: 85, cache hit rate: 89.90%, token usage: 0.01, #running-req: 27, #queue-req: 101
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 5, #cached-token: 90, cache hit rate: 89.93%, token usage: 0.01, #running-req: 27, #queue-req: 101
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 5, #cached-token: 95, cache hit rate: 89.96%, token usage: 0.01, #running-req: 27, #queue-req: 101
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 5, #cached-token: 100, cache hit rate: 90.00%, token usage: 0.01, #running-req: 27, #queue-req: 101
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 5, #cached-token: 105, cache hit rate: 90.03%, token usage: 0.01, #running-req: 27, #queue-req: 101
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 5, #cached-token: 110, cache hit rate: 90.07%, token usage: 0.01, #running-req: 27, #queue-req: 101
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 5, #cached-token: 115, cache hit rate: 90.11%, token usage: 0.01, #running-req: 27, #queue-req: 101
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 5, #cached-token: 120, cache hit rate: 90.16%, token usage: 0.01, #running-req: 27, #queue-req: 101
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 5, #cached-token: 125, cache hit rate: 90.20%, token usage: 0.01, #running-req: 27, #queue-req: 101
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 5, #cached-token: 130, cache hit rate: 90.25%, token usage: 0.01, #running-req: 27, #queue-req: 101
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 5, #cached-token: 135, cache hit rate: 90.30%, token usage: 0.01, #running-req: 27, #queue-req: 101
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 2, #new-token: 5, #cached-token: 220, cache hit rate: 90.40%, token usage: 0.01, #running-req: 27, #queue-req: 100
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 4, #cached-token: 81, cache hit rate: 90.42%, token usage: 0.01, #running-req: 28, #queue-req: 100
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 4, #cached-token: 85, cache hit rate: 90.45%, token usage: 0.01, #running-req: 28, #queue-req: 100
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 4, #cached-token: 89, cache hit rate: 90.47%, token usage: 0.01, #running-req: 28, #queue-req: 100
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 4, #cached-token: 93, cache hit rate: 90.50%, token usage: 0.01, #running-req: 28, #queue-req: 100
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 4, #cached-token: 97, cache hit rate: 90.54%, token usage: 0.01, #running-req: 28, #queue-req: 100
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 4, #cached-token: 101, cache hit rate: 90.57%, token usage: 0.01, #running-req: 28, #queue-req: 100
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 4, #cached-token: 105, cache hit rate: 90.60%, token usage: 0.01, #running-req: 28, #queue-req: 100
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 4, #cached-token: 109, cache hit rate: 90.64%, token usage: 0.01, #running-req: 28, #queue-req: 100
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 4, #cached-token: 113, cache hit rate: 90.68%, token usage: 0.01, #running-req: 28, #queue-req: 100
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 4, #cached-token: 117, cache hit rate: 90.72%, token usage: 0.01, #running-req: 28, #queue-req: 100
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 2, #new-token: 4, #cached-token: 201, cache hit rate: 90.80%, token usage: 0.01, #running-req: 28, #queue-req: 99
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 81, cache hit rate: 90.82%, token usage: 0.01, #running-req: 29, #queue-req: 99
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 84, cache hit rate: 90.85%, token usage: 0.01, #running-req: 29, #queue-req: 99
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 87, cache hit rate: 90.88%, token usage: 0.01, #running-req: 29, #queue-req: 99
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 90, cache hit rate: 90.90%, token usage: 0.01, #running-req: 29, #queue-req: 99
+[2024-11-22 03:08:36 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 93, cache hit rate: 90.93%, token usage: 0.01, #running-req: 29, #queue-req: 99
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 96, cache hit rate: 90.97%, token usage: 0.01, #running-req: 29, #queue-req: 99
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 99, cache hit rate: 91.00%, token usage: 0.01, #running-req: 29, #queue-req: 99
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 102, cache hit rate: 91.03%, token usage: 0.01, #running-req: 29, #queue-req: 99
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 105, cache hit rate: 91.07%, token usage: 0.01, #running-req: 29, #queue-req: 99
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 108, cache hit rate: 91.10%, token usage: 0.01, #running-req: 29, #queue-req: 99
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 111, cache hit rate: 91.14%, token usage: 0.01, #running-req: 29, #queue-req: 99
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 114, cache hit rate: 91.17%, token usage: 0.01, #running-req: 29, #queue-req: 99
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 117, cache hit rate: 91.21%, token usage: 0.01, #running-req: 29, #queue-req: 99
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 2, #new-token: 3, #cached-token: 200, cache hit rate: 91.28%, token usage: 0.01, #running-req: 29, #queue-req: 98
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 82, cache hit rate: 91.31%, token usage: 0.01, #running-req: 30, #queue-req: 98
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 84, cache hit rate: 91.34%, token usage: 0.01, #running-req: 30, #queue-req: 98
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 86, cache hit rate: 91.37%, token usage: 0.01, #running-req: 30, #queue-req: 98
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 88, cache hit rate: 91.39%, token usage: 0.01, #running-req: 30, #queue-req: 98
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 90, cache hit rate: 91.42%, token usage: 0.01, #running-req: 30, #queue-req: 98
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 92, cache hit rate: 91.45%, token usage: 0.01, #running-req: 30, #queue-req: 98
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 94, cache hit rate: 91.48%, token usage: 0.01, #running-req: 30, #queue-req: 98
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 96, cache hit rate: 91.51%, token usage: 0.01, #running-req: 30, #queue-req: 98
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 98, cache hit rate: 91.54%, token usage: 0.01, #running-req: 30, #queue-req: 98
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 100, cache hit rate: 91.57%, token usage: 0.01, #running-req: 30, #queue-req: 98
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 102, cache hit rate: 91.61%, token usage: 0.01, #running-req: 30, #queue-req: 98
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 104, cache hit rate: 91.64%, token usage: 0.01, #running-req: 30, #queue-req: 98
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 106, cache hit rate: 91.67%, token usage: 0.01, #running-req: 30, #queue-req: 98
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 108, cache hit rate: 91.70%, token usage: 0.01, #running-req: 30, #queue-req: 98
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 110, cache hit rate: 91.74%, token usage: 0.01, #running-req: 30, #queue-req: 98
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 112, cache hit rate: 91.77%, token usage: 0.01, #running-req: 30, #queue-req: 98
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 114, cache hit rate: 91.81%, token usage: 0.01, #running-req: 30, #queue-req: 98
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 116, cache hit rate: 91.84%, token usage: 0.01, #running-req: 30, #queue-req: 98
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 118, cache hit rate: 91.88%, token usage: 0.01, #running-req: 30, #queue-req: 98
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 120, cache hit rate: 91.91%, token usage: 0.01, #running-req: 30, #queue-req: 98
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 2, #new-token: 2, #cached-token: 202, cache hit rate: 91.98%, token usage: 0.01, #running-req: 30, #queue-req: 97
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 81, cache hit rate: 92.00%, token usage: 0.01, #running-req: 31, #queue-req: 97
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 82, cache hit rate: 92.02%, token usage: 0.01, #running-req: 31, #queue-req: 97
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 83, cache hit rate: 92.05%, token usage: 0.01, #running-req: 30, #queue-req: 97
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 85, cache hit rate: 92.07%, token usage: 0.01, #running-req: 30, #queue-req: 97
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 87, cache hit rate: 92.09%, token usage: 0.01, #running-req: 30, #queue-req: 97
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 89, cache hit rate: 92.11%, token usage: 0.01, #running-req: 30, #queue-req: 97
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 91, cache hit rate: 92.14%, token usage: 0.01, #running-req: 30, #queue-req: 97
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 93, cache hit rate: 92.16%, token usage: 0.01, #running-req: 30, #queue-req: 97
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 95, cache hit rate: 92.18%, token usage: 0.01, #running-req: 30, #queue-req: 97
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 97, cache hit rate: 92.21%, token usage: 0.01, #running-req: 30, #queue-req: 97
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 99, cache hit rate: 92.23%, token usage: 0.01, #running-req: 30, #queue-req: 97
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 101, cache hit rate: 92.26%, token usage: 0.01, #running-req: 30, #queue-req: 97
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 103, cache hit rate: 92.28%, token usage: 0.01, #running-req: 30, #queue-req: 97
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 105, cache hit rate: 92.31%, token usage: 0.01, #running-req: 30, #queue-req: 97
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 107, cache hit rate: 92.34%, token usage: 0.01, #running-req: 30, #queue-req: 97
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 109, cache hit rate: 92.36%, token usage: 0.01, #running-req: 30, #queue-req: 97
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 111, cache hit rate: 92.39%, token usage: 0.01, #running-req: 30, #queue-req: 97
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 113, cache hit rate: 92.42%, token usage: 0.01, #running-req: 30, #queue-req: 97
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 115, cache hit rate: 92.45%, token usage: 0.01, #running-req: 30, #queue-req: 97
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 117, cache hit rate: 92.48%, token usage: 0.01, #running-req: 30, #queue-req: 97
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 119, cache hit rate: 92.50%, token usage: 0.01, #running-req: 30, #queue-req: 97
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 121, cache hit rate: 92.53%, token usage: 0.01, #running-req: 29, #queue-req: 97
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 80, cache hit rate: 92.55%, token usage: 0.01, #running-req: 30, #queue-req: 95
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 82, cache hit rate: 92.56%, token usage: 0.01, #running-req: 30, #queue-req: 96
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 84, cache hit rate: 92.58%, token usage: 0.01, #running-req: 30, #queue-req: 96
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 86, cache hit rate: 92.60%, token usage: 0.01, #running-req: 30, #queue-req: 96
+[2024-11-22 03:08:37 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 88, cache hit rate: 92.62%, token usage: 0.01, #running-req: 30, #queue-req: 96
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 90, cache hit rate: 92.64%, token usage: 0.01, #running-req: 30, #queue-req: 96
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 92, cache hit rate: 92.66%, token usage: 0.01, #running-req: 30, #queue-req: 96
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 94, cache hit rate: 92.68%, token usage: 0.01, #running-req: 30, #queue-req: 96
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 96, cache hit rate: 92.70%, token usage: 0.01, #running-req: 30, #queue-req: 96
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 98, cache hit rate: 92.72%, token usage: 0.01, #running-req: 30, #queue-req: 96
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 100, cache hit rate: 92.74%, token usage: 0.01, #running-req: 30, #queue-req: 96
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 102, cache hit rate: 92.76%, token usage: 0.01, #running-req: 30, #queue-req: 96
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 104, cache hit rate: 92.78%, token usage: 0.01, #running-req: 30, #queue-req: 96
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 106, cache hit rate: 92.80%, token usage: 0.01, #running-req: 30, #queue-req: 96
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 108, cache hit rate: 92.83%, token usage: 0.01, #running-req: 30, #queue-req: 96
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 110, cache hit rate: 92.85%, token usage: 0.01, #running-req: 30, #queue-req: 96
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 112, cache hit rate: 92.87%, token usage: 0.01, #running-req: 30, #queue-req: 96
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 114, cache hit rate: 92.90%, token usage: 0.01, #running-req: 30, #queue-req: 96
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 78, cache hit rate: 92.91%, token usage: 0.01, #running-req: 31, #queue-req: 94
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 79, cache hit rate: 92.93%, token usage: 0.01, #running-req: 31, #queue-req: 95
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 80, cache hit rate: 92.95%, token usage: 0.01, #running-req: 31, #queue-req: 95
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 81, cache hit rate: 92.97%, token usage: 0.01, #running-req: 31, #queue-req: 95
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 82, cache hit rate: 92.98%, token usage: 0.01, #running-req: 31, #queue-req: 95
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 83, cache hit rate: 93.00%, token usage: 0.01, #running-req: 31, #queue-req: 95
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 84, cache hit rate: 93.02%, token usage: 0.01, #running-req: 31, #queue-req: 95
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 85, cache hit rate: 93.04%, token usage: 0.01, #running-req: 31, #queue-req: 95
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 86, cache hit rate: 93.06%, token usage: 0.01, #running-req: 31, #queue-req: 95
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 87, cache hit rate: 93.08%, token usage: 0.01, #running-req: 31, #queue-req: 95
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 88, cache hit rate: 93.10%, token usage: 0.01, #running-req: 31, #queue-req: 95
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 89, cache hit rate: 93.11%, token usage: 0.01, #running-req: 30, #queue-req: 95
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 91, cache hit rate: 93.13%, token usage: 0.01, #running-req: 30, #queue-req: 95
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 93, cache hit rate: 93.14%, token usage: 0.01, #running-req: 30, #queue-req: 95
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 95, cache hit rate: 93.16%, token usage: 0.01, #running-req: 30, #queue-req: 95
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 97, cache hit rate: 93.18%, token usage: 0.01, #running-req: 30, #queue-req: 95
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 99, cache hit rate: 93.19%, token usage: 0.01, #running-req: 29, #queue-req: 95
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 102, cache hit rate: 93.21%, token usage: 0.01, #running-req: 29, #queue-req: 95
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 105, cache hit rate: 93.22%, token usage: 0.01, #running-req: 29, #queue-req: 95
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 108, cache hit rate: 93.24%, token usage: 0.01, #running-req: 29, #queue-req: 95
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 111, cache hit rate: 93.25%, token usage: 0.01, #running-req: 29, #queue-req: 95
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 114, cache hit rate: 93.27%, token usage: 0.01, #running-req: 29, #queue-req: 95
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 117, cache hit rate: 93.29%, token usage: 0.01, #running-req: 29, #queue-req: 95
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 120, cache hit rate: 93.31%, token usage: 0.01, #running-req: 29, #queue-req: 95
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 2, #new-token: 3, #cached-token: 200, cache hit rate: 93.34%, token usage: 0.01, #running-req: 29, #queue-req: 94
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 79, cache hit rate: 93.35%, token usage: 0.01, #running-req: 30, #queue-req: 94
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 81, cache hit rate: 93.37%, token usage: 0.01, #running-req: 30, #queue-req: 94
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 83, cache hit rate: 93.38%, token usage: 0.01, #running-req: 30, #queue-req: 94
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 85, cache hit rate: 93.39%, token usage: 0.01, #running-req: 30, #queue-req: 94
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 87, cache hit rate: 93.41%, token usage: 0.01, #running-req: 30, #queue-req: 94
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 89, cache hit rate: 93.42%, token usage: 0.02, #running-req: 30, #queue-req: 94
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 91, cache hit rate: 93.43%, token usage: 0.02, #running-req: 30, #queue-req: 94
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 93, cache hit rate: 93.45%, token usage: 0.02, #running-req: 30, #queue-req: 94
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 95, cache hit rate: 93.46%, token usage: 0.01, #running-req: 29, #queue-req: 94
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 98, cache hit rate: 93.47%, token usage: 0.01, #running-req: 29, #queue-req: 94
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 101, cache hit rate: 93.48%, token usage: 0.01, #running-req: 29, #queue-req: 94
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 104, cache hit rate: 93.50%, token usage: 0.01, #running-req: 29, #queue-req: 94
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 107, cache hit rate: 93.51%, token usage: 0.02, #running-req: 29, #queue-req: 94
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 110, cache hit rate: 93.52%, token usage: 0.02, #running-req: 29, #queue-req: 94
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 4, #cached-token: 113, cache hit rate: 93.54%, token usage: 0.01, #running-req: 28, #queue-req: 94
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 4, #cached-token: 117, cache hit rate: 93.55%, token usage: 0.01, #running-req: 28, #queue-req: 94
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 2, #new-token: 4, #cached-token: 197, cache hit rate: 93.58%, token usage: 0.01, #running-req: 28, #queue-req: 93
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 4, #cached-token: 77, cache hit rate: 93.58%, token usage: 0.01, #running-req: 28, #queue-req: 93
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 4, #cached-token: 81, cache hit rate: 93.59%, token usage: 0.01, #running-req: 28, #queue-req: 93
+[2024-11-22 03:08:38 TP0] Prefill batch. #new-seq: 1, #new-token: 5, #cached-token: 85, cache hit rate: 93.59%, token usage: 0.01, #running-req: 27, #queue-req: 93
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 5, #cached-token: 90, cache hit rate: 93.59%, token usage: 0.01, #running-req: 27, #queue-req: 93
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 5, #cached-token: 95, cache hit rate: 93.60%, token usage: 0.01, #running-req: 27, #queue-req: 93
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 5, #cached-token: 100, cache hit rate: 93.60%, token usage: 0.01, #running-req: 27, #queue-req: 93
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 5, #cached-token: 105, cache hit rate: 93.61%, token usage: 0.01, #running-req: 27, #queue-req: 93
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 110, cache hit rate: 93.61%, token usage: 0.01, #running-req: 26, #queue-req: 93
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 116, cache hit rate: 93.62%, token usage: 0.01, #running-req: 25, #queue-req: 93
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 123, cache hit rate: 93.62%, token usage: 0.01, #running-req: 25, #queue-req: 93
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 2, #new-token: 7, #cached-token: 207, cache hit rate: 93.64%, token usage: 0.01, #running-req: 25, #queue-req: 92
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 79, cache hit rate: 93.64%, token usage: 0.01, #running-req: 25, #queue-req: 92
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 86, cache hit rate: 93.63%, token usage: 0.01, #running-req: 25, #queue-req: 92
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 93, cache hit rate: 93.63%, token usage: 0.01, #running-req: 25, #queue-req: 92
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 100, cache hit rate: 93.63%, token usage: 0.01, #running-req: 25, #queue-req: 92
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 2, #new-token: 7, #cached-token: 184, cache hit rate: 93.65%, token usage: 0.01, #running-req: 25, #queue-req: 91
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 80, cache hit rate: 93.64%, token usage: 0.01, #running-req: 26, #queue-req: 91
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 86, cache hit rate: 93.64%, token usage: 0.01, #running-req: 26, #queue-req: 91
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 92, cache hit rate: 93.64%, token usage: 0.01, #running-req: 26, #queue-req: 91
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 98, cache hit rate: 93.65%, token usage: 0.01, #running-req: 26, #queue-req: 91
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 104, cache hit rate: 93.65%, token usage: 0.01, #running-req: 26, #queue-req: 91
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 2, #new-token: 6, #cached-token: 187, cache hit rate: 93.67%, token usage: 0.01, #running-req: 26, #queue-req: 90
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 80, cache hit rate: 93.67%, token usage: 0.01, #running-req: 26, #queue-req: 90
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 86, cache hit rate: 93.67%, token usage: 0.01, #running-req: 26, #queue-req: 90
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 92, cache hit rate: 93.67%, token usage: 0.01, #running-req: 26, #queue-req: 90
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 98, cache hit rate: 93.67%, token usage: 0.01, #running-req: 26, #queue-req: 90
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 104, cache hit rate: 93.67%, token usage: 0.01, #running-req: 25, #queue-req: 90
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 2, #new-token: 7, #cached-token: 188, cache hit rate: 93.68%, token usage: 0.01, #running-req: 25, #queue-req: 89
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 79, cache hit rate: 93.68%, token usage: 0.01, #running-req: 26, #queue-req: 89
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 85, cache hit rate: 93.68%, token usage: 0.01, #running-req: 24, #queue-req: 89
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 93, cache hit rate: 93.67%, token usage: 0.01, #running-req: 24, #queue-req: 89
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 101, cache hit rate: 93.67%, token usage: 0.01, #running-req: 24, #queue-req: 89
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 9, #cached-token: 109, cache hit rate: 93.66%, token usage: 0.01, #running-req: 23, #queue-req: 89
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 9, #cached-token: 118, cache hit rate: 93.66%, token usage: 0.01, #running-req: 23, #queue-req: 89
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 9, #cached-token: 127, cache hit rate: 93.66%, token usage: 0.01, #running-req: 23, #queue-req: 89
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 9, #cached-token: 136, cache hit rate: 93.66%, token usage: 0.01, #running-req: 23, #queue-req: 89
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 10, #cached-token: 145, cache hit rate: 93.66%, token usage: 0.01, #running-req: 22, #queue-req: 89
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 10, #cached-token: 155, cache hit rate: 93.66%, token usage: 0.01, #running-req: 22, #queue-req: 89
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 10, #cached-token: 165, cache hit rate: 93.66%, token usage: 0.01, #running-req: 22, #queue-req: 89
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 11, #cached-token: 175, cache hit rate: 93.67%, token usage: 0.01, #running-req: 21, #queue-req: 89
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 11, #cached-token: 186, cache hit rate: 93.67%, token usage: 0.01, #running-req: 21, #queue-req: 89
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 11, #cached-token: 197, cache hit rate: 93.68%, token usage: 0.01, #running-req: 21, #queue-req: 89
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 11, #cached-token: 208, cache hit rate: 93.68%, token usage: 0.01, #running-req: 21, #queue-req: 89
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 11, #cached-token: 219, cache hit rate: 93.69%, token usage: 0.01, #running-req: 21, #queue-req: 89
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 10, #cached-token: 77, cache hit rate: 93.68%, token usage: 0.01, #running-req: 22, #queue-req: 87
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 10, #cached-token: 87, cache hit rate: 93.67%, token usage: 0.01, #running-req: 22, #queue-req: 88
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 10, #cached-token: 97, cache hit rate: 93.66%, token usage: 0.01, #running-req: 22, #queue-req: 88
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 2, #new-token: 11, #cached-token: 184, cache hit rate: 93.67%, token usage: 0.01, #running-req: 21, #queue-req: 87
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 10, #cached-token: 83, cache hit rate: 93.66%, token usage: 0.01, #running-req: 22, #queue-req: 87
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 10, #cached-token: 93, cache hit rate: 93.65%, token usage: 0.01, #running-req: 22, #queue-req: 87
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 10, #cached-token: 103, cache hit rate: 93.64%, token usage: 0.01, #running-req: 22, #queue-req: 87
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 10, #cached-token: 113, cache hit rate: 93.63%, token usage: 0.01, #running-req: 22, #queue-req: 87
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 2, #new-token: 10, #cached-token: 200, cache hit rate: 93.64%, token usage: 0.01, #running-req: 22, #queue-req: 86
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 9, #cached-token: 84, cache hit rate: 93.63%, token usage: 0.01, #running-req: 23, #queue-req: 86
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 9, #cached-token: 93, cache hit rate: 93.63%, token usage: 0.01, #running-req: 23, #queue-req: 86
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 9, #cached-token: 102, cache hit rate: 93.62%, token usage: 0.01, #running-req: 23, #queue-req: 86
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 2, #new-token: 9, #cached-token: 190, cache hit rate: 93.63%, token usage: 0.01, #running-req: 23, #queue-req: 85
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 80, cache hit rate: 93.62%, token usage: 0.01, #running-req: 24, #queue-req: 85
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 88, cache hit rate: 93.62%, token usage: 0.01, #running-req: 24, #queue-req: 85
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 96, cache hit rate: 93.62%, token usage: 0.01, #running-req: 24, #queue-req: 85
+[2024-11-22 03:08:39 TP0] Prefill batch. #new-seq: 1, #new-token: 9, #cached-token: 104, cache hit rate: 93.61%, token usage: 0.01, #running-req: 23, #queue-req: 85
+[2024-11-22 03:08:40 TP0] Prefill batch. #new-seq: 2, #new-token: 9, #cached-token: 189, cache hit rate: 93.62%, token usage: 0.01, #running-req: 23, #queue-req: 84
+[2024-11-22 03:08:40 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 77, cache hit rate: 93.61%, token usage: 0.01, #running-req: 24, #queue-req: 84
+[2024-11-22 03:08:40 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 85, cache hit rate: 93.61%, token usage: 0.01, #running-req: 24, #queue-req: 84
+[2024-11-22 03:08:40 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 93, cache hit rate: 93.61%, token usage: 0.01, #running-req: 24, #queue-req: 84
+[2024-11-22 03:08:40 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 101, cache hit rate: 93.60%, token usage: 0.01, #running-req: 24, #queue-req: 84
+[2024-11-22 03:08:40 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 109, cache hit rate: 93.60%, token usage: 0.01, #running-req: 24, #queue-req: 84
+[2024-11-22 03:08:40 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 117, cache hit rate: 93.60%, token usage: 0.01, #running-req: 24, #queue-req: 84
+[2024-11-22 03:08:40 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 125, cache hit rate: 93.60%, token usage: 0.01, #running-req: 24, #queue-req: 84
+[2024-11-22 03:08:40 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 133, cache hit rate: 93.61%, token usage: 0.01, #running-req: 24, #queue-req: 84
+[2024-11-22 03:08:40 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 141, cache hit rate: 93.61%, token usage: 0.01, #running-req: 24, #queue-req: 84
+[2024-11-22 03:08:40 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 149, cache hit rate: 93.61%, token usage: 0.01, #running-req: 24, #queue-req: 84
+[2024-11-22 03:08:40 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 157, cache hit rate: 93.62%, token usage: 0.01, #running-req: 24, #queue-req: 84
+[2024-11-22 03:08:40 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 165, cache hit rate: 93.63%, token usage: 0.01, #running-req: 24, #queue-req: 84
+[2024-11-22 03:08:40 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 173, cache hit rate: 93.64%, token usage: 0.01, #running-req: 24, #queue-req: 84
+[2024-11-22 03:08:40 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 181, cache hit rate: 93.65%, token usage: 0.01, #running-req: 24, #queue-req: 84
+[2024-11-22 03:08:40 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 189, cache hit rate: 93.66%, token usage: 0.01, #running-req: 24, #queue-req: 84
+[2024-11-22 03:08:40 TP0] Prefill batch. #new-seq: 1, #new-token: 9, #cached-token: 197, cache hit rate: 93.67%, token usage: 0.01, #running-req: 23, #queue-req: 84
+[2024-11-22 03:08:40 TP0] Prefill batch. #new-seq: 1, #new-token: 9, #cached-token: 206, cache hit rate: 93.68%, token usage: 0.01, #running-req: 23, #queue-req: 84
+[2024-11-22 03:08:40 TP0] Prefill batch. #new-seq: 1, #new-token: 9, #cached-token: 215, cache hit rate: 93.69%, token usage: 0.01, #running-req: 23, #queue-req: 84
+[2024-11-22 03:08:40 TP0] Prefill batch. #new-seq: 1, #new-token: 9, #cached-token: 224, cache hit rate: 93.71%, token usage: 0.01, #running-req: 23, #queue-req: 84
+[2024-11-22 03:08:40 TP0] Prefill batch. #new-seq: 2, #new-token: 9, #cached-token: 311, cache hit rate: 93.73%, token usage: 0.01, #running-req: 23, #queue-req: 83
+[2024-11-22 03:08:40 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 86, cache hit rate: 93.73%, token usage: 0.01, #running-req: 24, #queue-req: 83
+[2024-11-22 03:08:40 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 94, cache hit rate: 93.72%, token usage: 0.01, #running-req: 24, #queue-req: 83
+[2024-11-22 03:08:40 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 102, cache hit rate: 93.72%, token usage: 0.01, #running-req: 24, #queue-req: 83
+[2024-11-22 03:08:40 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 110, cache hit rate: 93.72%, token usage: 0.01, #running-req: 24, #queue-req: 83
+Process Process-1:
+Process Process-2:
+Traceback (most recent call last):
+ File "/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
+ self.run()
+ File "/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/multiprocessing/process.py", line 108, in run
+ self._target(*self._args, **self._kwargs)
+ File "/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/site-packages/sglang/srt/managers/detokenizer_manager.py", line 204, in run_detokenizer_process
+ manager.event_loop()
+ File "/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/site-packages/sglang/srt/managers/detokenizer_manager.py", line 99, in event_loop
+ recv_obj = self.recv_from_scheduler.recv_pyobj()
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/site-packages/zmq/sugar/socket.py", line 972, in recv_pyobj
+ msg = self.recv(flags)
+ ^^^^^^^^^^^^^^^^
+ File "_zmq.py", line 1156, in zmq.backend.cython._zmq.Socket.recv
+ File "_zmq.py", line 1191, in zmq.backend.cython._zmq.Socket.recv
+ File "_zmq.py", line 1278, in zmq.backend.cython._zmq._recv_copy
+ File "_zmq.py", line 160, in zmq.backend.cython._zmq._check_rc
+KeyboardInterrupt
+Traceback (most recent call last):
+ File "/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
+ self.run()
+ File "/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/multiprocessing/process.py", line 108, in run
+ self._target(*self._args, **self._kwargs)
+ File "/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/site-packages/sglang/srt/managers/scheduler.py", line 1433, in run_scheduler_process
+ scheduler.event_loop_normal()
+ File "/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
+ return func(*args, **kwargs)
+ ^^^^^^^^^^^^^^^^^^^^^
+ File "/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/site-packages/sglang/srt/managers/scheduler.py", line 377, in event_loop_normal
+ result = self.run_batch(batch)
+ ^^^^^^^^^^^^^^^^^^^^^
+ File "/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/site-packages/sglang/srt/managers/scheduler.py", line 926, in run_batch
+ logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/site-packages/sglang/srt/managers/tp_worker.py", line 153, in forward_batch_generation
+ logits_output = self.model_runner.forward(forward_batch)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/site-packages/sglang/srt/model_executor/model_runner.py", line 756, in forward
+ return self.forward_extend(forward_batch)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/site-packages/sglang/srt/model_executor/model_runner.py", line 732, in forward_extend
+ return self.model.forward(
+ ^^^^^^^^^^^^^^^^^^^
+ File "/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
+ return func(*args, **kwargs)
+ ^^^^^^^^^^^^^^^^^^^^^
+ File "/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/site-packages/sglang/srt/models/llama.py", line 318, in forward
+ hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
+ return self._call_impl(*args, **kwargs)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
+ return forward_call(*args, **kwargs)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/site-packages/sglang/srt/models/llama.py", line 283, in forward
+ hidden_states, residual = layer(
+ ^^^^^^
+ File "/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
+ return self._call_impl(*args, **kwargs)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
+ return forward_call(*args, **kwargs)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/site-packages/sglang/srt/models/llama.py", line 233, in forward
+ hidden_states = self.self_attn(
+ ^^^^^^^^^^^^^^^
+ File "/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
+ return self._call_impl(*args, **kwargs)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
+ return forward_call(*args, **kwargs)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/site-packages/sglang/srt/models/llama.py", line 168, in forward
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/site-packages/torch/_tensor.py", line 917, in split
+ return torch._VF.split_with_sizes(self, split_size, dim)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+KeyboardInterrupt
+[2024-11-22 03:08:40] INFO: Shutting down
+[2024-11-22 03:08:40] INFO: Waiting for connections to close. (CTRL+C to force quit)
diff --git a/test/srt/stdout.txt b/test/srt/stdout.txt
new file mode 100644
index 00000000000..bc4f3035a3d
--- /dev/null
+++ b/test/srt/stdout.txt
@@ -0,0 +1,25 @@
+INFO 11-22 03:08:21 weight_utils.py:243] Using model weights format ['*.safetensors']
+[2024-11-22 03:08:33] INFO: 127.0.0.1:48830 - "GET /health_generate HTTP/1.1" 200 OK
+[2024-11-22 03:08:33] INFO: 127.0.0.1:48842 - "GET /get_model_info HTTP/1.1" 200 OK
+[2024-11-22 03:08:33] INFO: 127.0.0.1:48850 - "POST /generate HTTP/1.1" 200 OK
+[2024-11-22 03:08:37] INFO: 127.0.0.1:49414 - "POST /v1/chat/completions HTTP/1.1" 200 OK
+[2024-11-22 03:08:37] INFO: 127.0.0.1:49154 - "POST /v1/chat/completions HTTP/1.1" 200 OK
+[2024-11-22 03:08:38] INFO: 127.0.0.1:49516 - "POST /v1/chat/completions HTTP/1.1" 200 OK
+[2024-11-22 03:08:38] INFO: 127.0.0.1:49490 - "POST /v1/chat/completions HTTP/1.1" 200 OK
+[2024-11-22 03:08:38] INFO: 127.0.0.1:49246 - "POST /v1/chat/completions HTTP/1.1" 200 OK
+[2024-11-22 03:08:38] INFO: 127.0.0.1:48994 - "POST /v1/chat/completions HTTP/1.1" 200 OK
+[2024-11-22 03:08:38] INFO: 127.0.0.1:49828 - "POST /v1/chat/completions HTTP/1.1" 200 OK
+[2024-11-22 03:08:38] INFO: 127.0.0.1:49954 - "POST /v1/chat/completions HTTP/1.1" 200 OK
+[2024-11-22 03:08:39] INFO: 127.0.0.1:48882 - "POST /v1/chat/completions HTTP/1.1" 200 OK
+[2024-11-22 03:08:39] INFO: 127.0.0.1:48872 - "POST /v1/chat/completions HTTP/1.1" 200 OK
+[2024-11-22 03:08:39] INFO: 127.0.0.1:48938 - "POST /v1/chat/completions HTTP/1.1" 200 OK
+[2024-11-22 03:08:39] INFO: 127.0.0.1:49526 - "POST /v1/chat/completions HTTP/1.1" 200 OK
+[2024-11-22 03:08:39] INFO: 127.0.0.1:48866 - "POST /v1/chat/completions HTTP/1.1" 200 OK
+[2024-11-22 03:08:39] INFO: 127.0.0.1:49542 - "POST /v1/chat/completions HTTP/1.1" 200 OK
+[2024-11-22 03:08:39] INFO: 127.0.0.1:49606 - "POST /v1/chat/completions HTTP/1.1" 200 OK
+[2024-11-22 03:08:39] INFO: 127.0.0.1:48958 - "POST /v1/chat/completions HTTP/1.1" 200 OK
+[2024-11-22 03:08:39] INFO: 127.0.0.1:49742 - "POST /v1/chat/completions HTTP/1.1" 200 OK
+[2024-11-22 03:08:39] INFO: 127.0.0.1:48924 - "POST /v1/chat/completions HTTP/1.1" 200 OK
+[2024-11-22 03:08:39] INFO: 127.0.0.1:49126 - "POST /v1/chat/completions HTTP/1.1" 200 OK
+[2024-11-22 03:08:39] INFO: 127.0.0.1:49276 - "POST /v1/chat/completions HTTP/1.1" 200 OK
+[2024-11-22 03:08:40] INFO: 127.0.0.1:49608 - "POST /v1/chat/completions HTTP/1.1" 200 OK
diff --git a/test/srt/test.log2 b/test/srt/test.log2
new file mode 100644
index 00000000000..16795a7ddc4
--- /dev/null
+++ b/test/srt/test.log2
@@ -0,0 +1,36 @@
+INFO 11-27 07:20:39 weight_utils.py:243] Using model weights format ['*.safetensors']
+INFO 11-27 07:20:39 weight_utils.py:288] No model.safetensors.index.json found in remote.
+torch.Size([2048, 2048])
+rank: 0, try to barrier
+rank: 0, try to broadcast hf_instruct_param
+rank: 0, try to del hf_instruct_model
+rank: 0, try to del hf_base_model
+rank: 0, try to gc
+rank: 0, try to empty cache
+Queue get error:
+Got parameter: hf_instruct_param
+Got parameter: hf_base_param
+Queue get error:
+Queue get error:
+Queue get error:
+Got parameter: engine_instruct_param
+Queue get error:
+Child processes have terminated
+INFO 11-27 07:24:19 weight_utils.py:243] Using model weights format ['*.safetensors']
+INFO 11-27 07:24:19 weight_utils.py:288] No model.safetensors.index.json found in remote.
+torch.Size([2048, 2048])
+rank: 0, try to barrier
+rank: 0, try to broadcast hf_instruct_param
+rank: 0, try to del hf_instruct_model
+rank: 0, try to del hf_base_model
+rank: 0, try to gc
+rank: 0, try to empty cache
+Queue get error:
+Got parameter: hf_instruct_param
+Got parameter: hf_base_param
+Queue get error:
+Queue get error:
+Queue get error:
+Got parameter: engine_instruct_param
+Queue get error:
+Child processes have terminated
diff --git a/test/srt/test_custom_process_group.py b/test/srt/test_custom_process_group.py
new file mode 100644
index 00000000000..3506ee89f26
--- /dev/null
+++ b/test/srt/test_custom_process_group.py
@@ -0,0 +1,101 @@
+import os
+import time
+import unittest
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from transformers import AutoModelForCausalLM
+
+import sglang as sgl
+from sglang.srt.utils import init_custom_process_group, kill_child_process
+from sglang.test.test_utils import (
+ DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+ DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+ DEFAULT_URL_FOR_TEST,
+ popen_launch_server,
+)
+
+mp.set_start_method("spawn", force=True)  # CUDA cannot be reinitialized in forked workers
+
+
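+# Spawns two ranks that join a custom NCCL process group: rank 0 stands in for a
+# HuggingFace trainer and rank 1 for an SGLang engine (the actual model/engine
+# launches are commented out). A small all-reduce verifies the group works.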
+class TestParameterUpdateDistributed(unittest.TestCase):
+ @classmethod
+ def init_process(cls, rank, world_size, base_url, model_name, tensor_value):
+        # Initialize handles up front so the finally block can clean up even if setup fails
+        engine = None
+        group = None
+        try:
+            os.environ["MASTER_ADDR"] = "localhost"
+            os.environ["MASTER_PORT"] = "30000"
+            torch.cuda.set_device(rank)
+
+ group_name = "test_group_for_custom_process_group"
+ group = init_custom_process_group(
+ backend="nccl",
+ init_method="tcp://localhost:30000",
+ world_size=world_size,
+ rank=rank,
+ group_name=group_name,
+ )
+ print(f"Initialized custom process group on rank {rank}")
+ print(f"rank: {rank}, before barrier")
+ dist.barrier(group=group)
+ print(f"rank: {rank}, after barrier")
+
+ if rank == 0:
+ # hf_model = AutoModelForCausalLM.from_pretrained(
+ # model_name, torch_dtype="bfloat16"
+ # ).to(f"cuda:{rank}")
+ print("HF model loaded on rank 0")
+ elif rank == 1:
+ # engine = sgl.Runtime(
+ # model_path=model_name, random_seed=42, base_gpu_id=rank
+ # )
+ print("SGLang server launched on rank 1")
+ time.sleep(5)
+
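+            # Each rank contributes rank + 2 (2 and 3), so the all-reduce sum should be 5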
+ tensor = torch.ones(1).cuda() * (rank + 2)
+ dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
+ print(f"Rank {rank} sees sum: {tensor.item()}")
+
+ if rank == 0:
+ tensor_value.value = tensor.item()
+
+ dist.barrier(group=group)
+
+ finally:
+ if group is not None:
+ dist.destroy_process_group(group)
+ if engine is not None:
+ engine.shutdown()
+
+ @classmethod
+ def setUpClass(cls):
+ cls.world_size = 2
+ cls.model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+ cls.base_url = DEFAULT_URL_FOR_TEST
+
+ cls.tensor_value = mp.Value("d", 0.0)
+
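+        # One process per rank; rank 0 writes the all-reduce result into tensor_value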
+ mp.spawn(
+ cls.init_process,
+ args=(
+ cls.world_size,
+ cls.base_url,
+ cls.model_name,
+ cls.tensor_value,
+ ),
+ nprocs=cls.world_size,
+ join=True,
+ )
+
+ @classmethod
+ def tearDownClass(cls):
+ if torch.distributed.is_initialized():
+ torch.distributed.destroy_process_group()
+
+ def test_custom_process_group(self):
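+        # init_process ran in setUpClass; the all-reduce of 2 + 3 should give 5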
+ self.assertEqual(self.__class__.tensor_value.value, 5)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/test/srt/test_custom_process_group_wraper.py b/test/srt/test_custom_process_group_wraper.py
new file mode 100644
index 00000000000..2d0476b0ce5
--- /dev/null
+++ b/test/srt/test_custom_process_group_wraper.py
@@ -0,0 +1,132 @@
+import asyncio
+import json
+import os
+import time
+import unittest
+from multiprocessing import process
+from types import SimpleNamespace
+
+import requests
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from transformers import AutoModelForCausalLM
+
+import sglang as sgl
+from sglang.bench_offline_throughput import BenchArgs, throughput_test
+from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import init_custom_process_group, kill_child_process
+from sglang.test.few_shot_gsm8k_engine import run_eval
+from sglang.test.test_utils import (
+ DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
+ DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+ DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+ DEFAULT_URL_FOR_TEST,
+ popen_launch_server,
+)
+
+mp.set_start_method("spawn", force=True)
+
+
+def mock_init_parameter_update_group(
+ master_address,
+ master_port,
+ rank_offset,
+ world_size,
+ group_name,
+ backend="nccl",
+):
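+    # Mimic the server-side join of the group: as in ModelRunner.init_parameter_update_group,
+    # the joining rank is rank_offset + tp_rank (tp_rank is 0 here).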
+ rank = rank_offset + 0
+
+ _model_update_group = init_custom_process_group(
+ backend=backend,
+ init_method=f"tcp://{master_address}:{master_port}",
+ world_size=world_size,
+ rank=rank,
+ group_name=group_name,
+ )
+
+ return _model_update_group
+
+
+class TestParameterUpdateGroup(unittest.TestCase):
+ @classmethod
+ def init_process(cls, rank, world_size):
+ try:
+            # Set up the distributed environment
+ torch.cuda.set_device(rank)
+
+ print(
+ f"[Rank {rank}] Using GPU: {torch.cuda.current_device()} "
+ f"(CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']})"
+ )
+
+ if rank == 0:
+ os.environ["NCCL_CUMEM_ENABLE"] = "0"
+ os.environ["NCCL_NVLS_ENABLE"] = "0"
+ print(f"[Rank 0] Starting initialization")
+ group = init_custom_process_group(
+ backend="nccl",
+ init_method="tcp://localhost:29500",
+ world_size=world_size,
+ rank=rank,
+ group_name="test_parameter_update_group",
+ )
+ print(f"[Rank 0] Process group initialized")
+ print(f"[Rank 0] before barrier")
+ dist.barrier(group=group)
+ print(f"[Rank 0] after barrier")
+
+ elif rank == 1:
+ print(f"[Rank 1] Starting server launch")
+ time.sleep(2)
+ _model_update_group = mock_init_parameter_update_group(
+ master_address="localhost",
+ master_port="29500",
+ rank_offset=1,
+ world_size=world_size,
+ group_name="test_parameter_update_group",
+ backend="nccl",
+ )
+ print(f"[Rank 1] before barrier")
+ dist.barrier(group=_model_update_group)
+ print(f"[Rank 1] after barrier")
+
+ print(f"[Rank {rank}] Process initialization completed")
+
+ except Exception as e:
+ print(f"[Rank {rank}] Error occurred: {str(e)}")
+ raise
+
+ @classmethod
+ def setUpClass(cls):
+ cls.world_size = 2
+
+ print("Starting multiprocessing spawn")
+ mp.spawn(
+ cls.init_process,
+ args=(cls.world_size,),
+ nprocs=cls.world_size,
+ join=True,
+ )
+ print("Multiprocessing spawn completed")
+
+ @classmethod
+ def tearDownClass(cls):
+ print("Starting teardown")
+        # Destroy the distributed process group first
+ if torch.distributed.is_initialized():
+ torch.distributed.destroy_process_group()
+ print("Process group destroyed")
+
+        time.sleep(1)  # Give the processes some time to clean up
+
+ def test_init_parameter_update_group(self):
+ print(
+ "Successfully initialized parameter update group between huggingface and SGLang server."
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/test/srt/test_data_parallelism.py b/test/srt/test_data_parallelism.py
index f34313ea09a..8d2fa487f2f 100644
--- a/test/srt/test_data_parallelism.py
+++ b/test/srt/test_data_parallelism.py
@@ -44,7 +44,7 @@ def test_mmlu(self):
def test_update_weight(self):
response = requests.post(
- self.base_url + "/update_weights",
+ self.base_url + "/update_weights_from_disk",
json={"model_path": DEFAULT_MODEL_NAME_FOR_TEST},
)
@@ -55,7 +55,7 @@ def test_update_weight(self):
time.sleep(5)
response = requests.post(
- self.base_url + "/update_weights",
+ self.base_url + "/update_weights_from_disk",
json={"model_path": DEFAULT_MODEL_NAME_FOR_TEST},
)
diff --git a/test/srt/test_engine_init_parameter_update_group.py b/test/srt/test_engine_init_parameter_update_group.py
new file mode 100644
index 00000000000..01215c81fa1
--- /dev/null
+++ b/test/srt/test_engine_init_parameter_update_group.py
@@ -0,0 +1,85 @@
+import gc
+import os
+import time
+import unittest
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from transformers import AutoModelForCausalLM
+
+import sglang as sgl
+from sglang.srt.utils import init_custom_process_group
+from sglang.test.test_utils import (
+ DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+ DEFAULT_URL_FOR_TEST,
+)
+
+mp.set_start_method("spawn", force=True)
+
+
+class TestParameterUpdateGroup(unittest.TestCase):
+ @classmethod
+ def init_process(cls, rank, world_size, base_url, model_name):
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = "65500"
+ torch.cuda.set_device(rank)
+
+ if rank == 0:
+            # Rank 0: load the HF model
+ os.environ["NCCL_CUMEM_ENABLE"] = "0"
+ os.environ["NCCL_NVLS_ENABLE"] = "0"
+ hf_model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda:0")
+ group = init_custom_process_group(
+ backend="nccl",
+ init_method="tcp://localhost:65500",
+ world_size=world_size,
+ rank=rank,
+ group_name="test_parameter_update_group",
+ )
+ print(f"rank: {rank}, after init_custom_process_group")
+ del hf_model
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ elif rank == 1:
+            # Rank 1: launch the SGLang engine
+ engine = sgl.Engine(model_path=model_name, random_seed=42, base_gpu_id=rank)
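+            # Join the shared group; rank_offset=1 maps the engine's tp_rank 0 to rank 1
+            # (rank 0 is held by the HF model process).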
+ engine.init_parameter_update_group(
+ master_address="localhost",
+ master_port="65500",
+ rank_offset=1,
+ world_size=world_size,
+ group_name="test_parameter_update_group",
+ backend="nccl",
+ )
+ print(f"rank: {rank}, after init_parameter_update_group")
+ engine.shutdown()
+
+ @classmethod
+ def setUpClass(cls):
+ cls.world_size = 2
+ cls.model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+ cls.base_url = DEFAULT_URL_FOR_TEST
+
+ mp.spawn(
+ cls.init_process,
+ args=(cls.world_size, cls.base_url, cls.model_name),
+ nprocs=cls.world_size,
+ join=True,
+ )
+
+ @classmethod
+ def tearDownClass(cls):
+ if torch.distributed.is_initialized():
+ torch.distributed.destroy_process_group()
+ time.sleep(1)
+
+ def test_init_parameter_update_group(self):
+ print(
+ "Successfully initialized parameter update group between huggingface and SGLang server."
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/test/srt/test_engine_update_parameter_from_distributed.py b/test/srt/test_engine_update_parameter_from_distributed.py
new file mode 100644
index 00000000000..64f5d104ff1
--- /dev/null
+++ b/test/srt/test_engine_update_parameter_from_distributed.py
@@ -0,0 +1,83 @@
+import gc
+import os
+import time
+import unittest
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from transformers import AutoModelForCausalLM
+
+import sglang as sgl
+from sglang.srt.utils import init_custom_process_group
+from sglang.test.test_utils import (
+ DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+ DEFAULT_URL_FOR_TEST,
+)
+
+mp.set_start_method("spawn", force=True)
+
+
+class TestParameterUpdateGroup(unittest.TestCase):
+ @classmethod
+ def init_process(cls, rank, world_size):
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = "65500"
+
+ if rank == 0:
+ os.environ["NCCL_CUMEM_ENABLE"] = "0"
+ os.environ["NCCL_NVLS_ENABLE"] = "0"
+ hf_model = AutoModelForCausalLM.from_pretrained(
+ DEFAULT_SMALL_MODEL_NAME_FOR_TEST, torch_dtype="bfloat16"
+ ).to("cuda:0")
+ group = init_custom_process_group(
+ backend="nccl",
+ init_method="tcp://localhost:65500",
+ world_size=world_size,
+ rank=rank,
+ group_name="test_parameter_update_group",
+ )
+ del hf_model
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ elif rank == 1:
+ engine = sgl.Engine(
+ model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+ random_seed=42,
+ base_gpu_id=rank,
+ )
+ engine.init_parameter_update_group(
+ master_address="localhost",
+ master_port="65500",
+ rank_offset=1,
+ world_size=world_size,
+ group_name="test_parameter_update_group",
+ backend="nccl",
+ )
+ engine.shutdown()
+
+ @classmethod
+ def setUpClass(cls):
+ cls.world_size = 2
+
+ mp.spawn(
+ cls.init_process,
+ args=(cls.world_size,),
+ nprocs=cls.world_size,
+ join=True,
+ )
+
+ @classmethod
+ def tearDownClass(cls):
+ if torch.distributed.is_initialized():
+ torch.distributed.destroy_process_group()
+
+ def test_init_parameter_update_group(self):
+ print(
+ "Successfully initialized parameter update group between huggingface and SGLang server."
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/test/srt/test_get_parameter_by_name.py b/test/srt/test_get_parameter_by_name.py
new file mode 100644
index 00000000000..5bbd545f571
--- /dev/null
+++ b/test/srt/test_get_parameter_by_name.py
@@ -0,0 +1,137 @@
+import gc
+import unittest
+
+import numpy as np
+import requests
+import torch
+from transformers import AutoModelForCausalLM
+
+import sglang as sgl
+from sglang.test.test_utils import (
+ DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+ DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+ DEFAULT_URL_FOR_TEST,
+ popen_launch_server,
+)
+from sglang.utils import terminate_process
+
+
+class TestUpdateWeights(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+ cls.base_url = DEFAULT_URL_FOR_TEST
+ cls.hf_model = AutoModelForCausalLM.from_pretrained(
+ cls.model, torch_dtype="bfloat16"
+ ).to("cuda:0")
+
+ @classmethod
+ def init_engine_and_server(cls, engine_tp, server_tp, engine_dp, server_dp):
+ cls.engine = None
+ cls.process = None
+ cls.engine_dp = engine_dp
+ cls.server_dp = server_dp
+ cls.engine_tp = engine_tp
+ cls.server_tp = server_tp
+ if engine_dp != 0:
+ cls.engine = sgl.Engine(
+ model_path=cls.model,
+ random_seed=42,
+ tp_size=engine_tp,
+ dp_size=engine_dp,
+ base_gpu_id=0,
+ mem_fraction_static=0.85,
+ )
+ if server_dp != 0:
+ cls.process = popen_launch_server(
+ cls.model,
+ cls.base_url,
+ timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+ other_args=(
+ "--base-gpu-id",
+ str(engine_dp * engine_tp),
+ "--tp-size",
+ str(server_tp),
+ "--dp-size",
+ str(server_dp),
+ ),
+ )
+
+ @classmethod
+ def close_engine_and_server(cls):
+ if cls.engine:
+ cls.engine.shutdown()
+ if cls.process:
+ terminate_process(cls.process)
+
+ @classmethod
+ def tearDownClass(cls):
+ del cls.hf_model
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ @classmethod
+ def assert_update_weights_all_close(cls, param_name, truncate_size):
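+        # Compare a truncated slice of the HF ground-truth parameter against what the
+        # offline engine and/or the HTTP server report for the same parameter name.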
+ print(
+ f"param_name: {param_name}, engine_dp: {cls.engine_dp}, server_dp: {cls.server_dp}, engine_tp: {cls.engine_tp}, server_tp: {cls.server_tp}"
+ )
+ param = cls.hf_model.get_parameter(param_name)[:truncate_size]
+ param_np = param.cpu().detach().float().numpy()
+
+ if cls.engine:
+ engine_ret = cls.engine.get_weights_by_parameter_name(
+ param_name, truncate_size
+ )
+ engine_ret = cls._process_return(engine_ret)
+ np.testing.assert_allclose(engine_ret, param_np, rtol=1e-5, atol=1e-5)
+
+ if cls.process:
+ runtime_ret = requests.get(
+ f"{cls.base_url}/get_weights_by_parameter_name",
+ json={"name": param_name, "truncate_size": truncate_size},
+ ).json()
+ runtime_ret = cls._process_return(runtime_ret)
+ np.testing.assert_allclose(runtime_ret, param_np, rtol=1e-5, atol=1e-5)
+
+ @staticmethod
+ def _process_return(ret):
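+        # A two-element list means two copies were returned (e.g. one per data-parallel
+        # replica); check that they agree and keep the first.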
+ if isinstance(ret, list) and len(ret) == 2:
+ np.testing.assert_allclose(ret[0], ret[1])
+ return np.array(ret[0])
+ return np.array(ret)
+
+ @classmethod
+    def test_get_weights_by_parameter_name(cls):
+ assert torch.cuda.device_count() >= 2, "At least 2 GPUs are required"
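+        # Each tuple is (engine_tp, server_tp, engine_dp, server_dp); a dp of 0 skips that runner.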
+ test_suits = [(1, 1, 1, 1), (2, 0, 1, 0), (0, 2, 0, 1)]
+
+ if torch.cuda.device_count() >= 4:
+ test_suits.extend([(2, 2, 1, 1), (1, 1, 2, 2)])
+
+ if torch.cuda.device_count() >= 8:
+ test_suits.append((2, 2, 2, 2))
+
+ parameters = [
+ "model.embed_tokens.weight",
+ "model.layers.0.input_layernorm.weight",
+ "model.layers.1.self_attn.q_proj.weight",
+ "model.layers.2.self_attn.k_proj.weight",
+ "model.layers.3.self_attn.v_proj.weight",
+ "model.layers.4.self_attn.o_proj.weight",
+ "model.layers.5.mlp.gate_proj.weight",
+ "model.layers.6.mlp.up_proj.weight",
+ "model.layers.7.mlp.down_proj.weight",
+ "model.layers.8.post_attention_layernorm.weight",
+ "model.norm.weight",
+ "lm_head.weight",
+ ]
+
+ for test_suit in test_suits:
+ cls.init_engine_and_server(*test_suit)
+ for param_name in parameters:
+ cls.assert_update_weights_all_close(param_name, 100)
+ cls.close_engine_and_server()
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/test/srt/test_init_parameter_update_group.py b/test/srt/test_init_parameter_update_group.py
new file mode 100644
index 00000000000..016294ee21d
--- /dev/null
+++ b/test/srt/test_init_parameter_update_group.py
@@ -0,0 +1,130 @@
+import os
+import time
+import unittest
+
+import requests
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from transformers import AutoModelForCausalLM
+
+from sglang.srt.utils import init_custom_process_group, kill_child_process
+from sglang.test.test_utils import (
+ DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+ DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+ DEFAULT_URL_FOR_TEST,
+ popen_launch_server,
+)
+
+mp.set_start_method("spawn", force=True)
+
+
+class TestParameterUpdateGroup(unittest.TestCase):
+ @classmethod
+ def init_process(cls, rank, world_size, base_url, model_name, server_pid):
+ try:
+            # Set up the distributed environment
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = "29500"
+            torch.cuda.set_device(rank)  # Use the local GPU, since each process can only see one GPU
+
+ print(
+ f"[Rank {rank}] Using GPU: {torch.cuda.current_device()} "
+ f"(CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']})"
+ )
+
+ if rank == 0:
+ print(f"[Rank 0] Starting initialization")
+ hf_model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda:0")
+ print(f"[Rank 0] HF model loaded")
+
+ group = init_custom_process_group(
+ backend="nccl",
+ init_method="tcp://localhost:29500",
+ world_size=world_size,
+ rank=rank,
+ group_name="test_parameter_update_group",
+ )
+ print(f"[Rank 0] Process group initialized")
+ print(f"[Rank 0] before barrier")
+ dist.barrier(group=group)
+ print(f"[Rank 0] after barrier")
+
+ elif rank == 1:
+ print(f"[Rank 1] Starting server launch")
+ process = popen_launch_server(
+ model_name,
+ base_url,
+ timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+ other_args=("--base-gpu-id", str(rank)),
+ )
+ server_pid.value = process.pid
+ print(f"[Rank 1] Server launched with pid {process.pid}")
+
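+                # Ask the launched server to join the same group; with rank_offset=1 its
+                # tp_rank 0 becomes rank 1, while the HF model holds rank 0.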
+ response = requests.post(
+ f"{base_url}/init_parameter_update_group",
+ json={
+ "master_address": "localhost",
+ "master_port": "29500",
+ "rank_offset": 1,
+ "world_size": world_size,
+ "group_name": "test_parameter_update_group",
+ "backend": "nccl",
+ },
+ timeout=30,
+ )
+ print(
+ f"[Rank 1] Parameter update group initialized with response: {response.status_code}"
+ )
+
+ print(f"[Rank {rank}] Process initialization completed")
+
+ except Exception as e:
+ print(f"[Rank {rank}] Error occurred: {str(e)}")
+ raise
+
+ @classmethod
+ def setUpClass(cls):
+ cls.world_size = 2
+ cls.model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+ cls.base_url = DEFAULT_URL_FOR_TEST
+ cls.server_pid = mp.Value("i", 0)
+
+ print("Starting multiprocessing spawn")
+ mp.spawn(
+ cls.init_process,
+ args=(
+ cls.world_size,
+ cls.base_url,
+ cls.model_name,
+ cls.server_pid,
+ ),
+ nprocs=cls.world_size,
+ join=True,
+ )
+ print("Multiprocessing spawn completed")
+
+ @classmethod
+ def tearDownClass(cls):
+ print("Starting teardown")
+        # Destroy the distributed process group first
+ if torch.distributed.is_initialized():
+ torch.distributed.destroy_process_group()
+ print("Process group destroyed")
+
+        # Then clean up the server process
+ if cls.server_pid.value != 0:
+ print(f"Cleaning up server process {cls.server_pid.value}")
+ kill_child_process(cls.server_pid.value, include_self=True)
+ print("Server process cleaned up")
+
+        time.sleep(1)  # Give the processes some time to clean up
+
+ def test_init_parameter_update_group(self):
+ print(
+ "Successfully initialized parameter update group between huggingface and SGLang server."
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/test/srt/test_metrics.py b/test/srt/test_metrics.py
index 163a7cc0e06..94479ab02e9 100644
--- a/test/srt/test_metrics.py
+++ b/test/srt/test_metrics.py
@@ -1,3 +1,4 @@
+import json
import unittest
import requests
@@ -11,71 +12,99 @@
)
-class TestEnableMetrics(unittest.TestCase):
- def test_metrics_enabled(self):
- """Test that metrics endpoint returns data when enabled"""
- process = popen_launch_server(
- DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
- DEFAULT_URL_FOR_TEST,
+class TestUpdateWeights(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+ cls.base_url = DEFAULT_URL_FOR_TEST
+ cls.process = popen_launch_server(
+ cls.model,
+ cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
- other_args=["--enable-metrics"],
+ other_args=("--mem-",),
)
- try:
- # Make some requests to generate some metrics
- response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate")
- self.assertEqual(response.status_code, 200)
-
- response = requests.post(
- f"{DEFAULT_URL_FOR_TEST}/generate",
- json={
- "text": "The capital of France is",
- "sampling_params": {
- "temperature": 0,
- "max_new_tokens": 32,
- },
- "stream": True,
+ @classmethod
+ def tearDownClass(cls):
+ kill_child_process(cls.process.pid, include_self=True)
+
+ def run_decode(self):
+ response = requests.post(
+ self.base_url + "/generate",
+ json={
+ "text": "The capital of France is",
+ "sampling_params": {
+ "temperature": 0,
+ "max_new_tokens": 32,
},
- stream=True,
- )
- for _ in response.iter_lines(decode_unicode=False):
- pass
-
- # Get metrics
- metrics_response = requests.get(f"{DEFAULT_URL_FOR_TEST}/metrics")
- self.assertEqual(metrics_response.status_code, 200)
- metrics_content = metrics_response.text
-
- print(f"metrics_content=\n{metrics_content}")
-
- # Verify essential metrics are present
- essential_metrics = [
- "sglang:num_running_reqs",
- "sglang:token_usage",
- "sglang:gen_throughput",
- "sglang:cache_hit_rate",
- "sglang:func_latency_seconds",
- "sglang:prompt_tokens_total",
- "sglang:generation_tokens_total",
- "sglang:time_to_first_token_seconds",
- "sglang:time_per_output_token_seconds",
- "sglang:e2e_request_latency_seconds",
- ]
-
- for metric in essential_metrics:
- self.assertIn(metric, metrics_content, f"Missing metric: {metric}")
-
- # Verify model name label is present and correct
- expected_model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
- self.assertIn(f'model_name="{expected_model_name}"', metrics_content)
-
- # Verify metrics have values (not empty)
- self.assertIn("_sum{", metrics_content)
- self.assertIn("_count{", metrics_content)
- self.assertIn("_bucket{", metrics_content)
-
- finally:
- kill_child_process(process.pid, include_self=True)
+ },
+ )
+ print(json.dumps(response.json()))
+ print("=" * 100)
+ text = response.json()["text"]
+ return text
+
+ def get_model_info(self):
+ response = requests.get(self.base_url + "/get_model_info")
+ model_path = response.json()["model_path"]
+ print(json.dumps(response.json()))
+ return model_path
+
+ def run_update_weights(self, model_path):
+ response = requests.post(
+ self.base_url + "/update_weights_from_disk",
+ json={
+ "model_path": model_path,
+ },
+ )
+ ret = response.json()
+ print(json.dumps(response.json()))
+ return ret
+
+ def test_update_weights(self):
+ origin_model_path = self.get_model_info()
+ print(f"origin_model_path: {origin_model_path}")
+ origin_response = self.run_decode()
+
+ # update weights
+ new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
+ ret = self.run_update_weights(new_model_path)
+ assert ret["success"]
+
+ updated_model_path = self.get_model_info()
+ print(f"updated_model_path: {updated_model_path}")
+ assert updated_model_path == new_model_path
+ assert updated_model_path != origin_model_path
+
+ updated_response = self.run_decode()
+ assert origin_response[:32] != updated_response[:32]
+
+ # update weights back
+ ret = self.run_update_weights(origin_model_path)
+ assert ret["success"]
+
+ updated_model_path = self.get_model_info()
+ assert updated_model_path == origin_model_path
+
+ updated_response = self.run_decode()
+ assert origin_response[:32] == updated_response[:32]
+
+ def test_update_weights_unexist_model(self):
+ origin_model_path = self.get_model_info()
+ print(f"origin_model_path: {origin_model_path}")
+ origin_response = self.run_decode()
+
+ # update weights
+ new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "wrong")
+ ret = self.run_update_weights(new_model_path)
+ assert not ret["success"]
+
+ updated_model_path = self.get_model_info()
+ print(f"updated_model_path: {updated_model_path}")
+ assert updated_model_path == origin_model_path
+
+ updated_response = self.run_decode()
+ assert origin_response[:32] == updated_response[:32]
if __name__ == "__main__":
diff --git a/test/srt/test_update_parameter_from_distributed.py b/test/srt/test_update_parameter_from_distributed.py
new file mode 100644
index 00000000000..45c3008d841
--- /dev/null
+++ b/test/srt/test_update_parameter_from_distributed.py
@@ -0,0 +1,867 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ModelRunner runs the forward passes of the models."""
+
+import gc
+import importlib
+import importlib.resources
+import inspect
+import json
+import logging
+import pkgutil
+from functools import lru_cache
+from typing import Any, Optional, Type, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from vllm.config import DeviceConfig, LoadConfig
+from vllm.config import ModelConfig as VllmModelConfig
+from vllm.distributed import (
+ get_tp_group,
+ init_distributed_environment,
+ initialize_model_parallel,
+ set_custom_all_reduce,
+)
+from vllm.distributed.parallel_state import in_the_same_node_as
+from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.models import ModelRegistry
+
+from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
+from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.sampler import Sampler
+from sglang.srt.lora.lora_manager import LoRAManager
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.mem_cache.memory_pool import (
+ DoubleSparseTokenToKVPool,
+ MHATokenToKVPool,
+ MLATokenToKVPool,
+ ReqToTokenPool,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import (
+ crash_on_warnings,
+ enable_show_time_cost,
+ get_available_gpu_memory,
+ init_custom_process_group,
+ is_hip,
+ monkey_patch_vllm_model_config,
+ monkey_patch_vllm_p2p_access_check,
+ set_cpu_offload_max_bytes,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class ModelRunner:
+ """ModelRunner runs the forward passes of the models."""
+
+ def __init__(
+ self,
+ model_config: ModelConfig,
+ mem_fraction_static: float,
+ gpu_id: int,
+ tp_rank: int,
+ tp_size: int,
+ nccl_port: int,
+ server_args: ServerArgs,
+ ):
+ # Parse args
+ self.model_config = model_config
+ self.mem_fraction_static = mem_fraction_static
+ self.device = server_args.device
+ self.gpu_id = gpu_id
+ self.tp_rank = tp_rank
+ self.tp_size = tp_size
+ self.dist_port = nccl_port
+ self.server_args = server_args
+ self.is_generation = model_config.is_generation
+ self.is_multimodal = model_config.is_multimodal
+
+ # Model-specific adjustment
+ if (
+ self.model_config.attention_arch == AttentionArch.MLA
+ and not self.server_args.disable_mla
+ ):
+ logger.info("MLA optimization is turned on. Use triton backend.")
+ self.server_args.attention_backend = "triton"
+
+ if self.server_args.enable_double_sparsity:
+ logger.info(
+ "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
+ )
+ self.server_args.attention_backend = "triton"
+ self.server_args.disable_cuda_graph = True
+ if self.server_args.ds_heavy_channel_type is None:
+ raise ValueError(
+ "Please specify the heavy channel type for double sparsity optimization."
+ )
+ self.init_double_sparsity_channel_config(
+ self.server_args.ds_heavy_channel_type
+ )
+
+ if self.is_multimodal:
+ logger.info(
+ "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
+ )
+ server_args.chunked_prefill_size = None
+ self.mem_fraction_static *= 0.95
+ # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
+ if self.model_config.hf_config.architectures == [
+ "Qwen2VLForConditionalGeneration"
+ ]:
+ server_args.disable_radix_cache = True
+
+ # Global vars
+ if server_args.show_time_cost:
+ enable_show_time_cost()
+ if server_args.disable_disk_cache:
+ from outlines.caching import disable_cache
+
+ disable_cache()
+
+ global_server_args_dict.update(
+ {
+ "attention_backend": server_args.attention_backend,
+ "sampling_backend": server_args.sampling_backend,
+ "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
+ "disable_mla": server_args.disable_mla,
+ "torchao_config": server_args.torchao_config,
+ "enable_nan_detection": server_args.enable_nan_detection,
+ "enable_dp_attention": server_args.enable_dp_attention,
+ }
+ )
+
+ set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
+
+ # Init components
+ min_per_gpu_memory = self.init_torch_distributed()
+ self.sampler = Sampler()
+ self.load_model()
+
+ # Apply torch TP if model supports it
+ supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
+ if self.tp_size > 1 and supports_torch_tp:
+ self.apply_torch_tp()
+ self.torch_tp_applied = True
+ else:
+ self.torch_tp_applied = False
+
+ if server_args.lora_paths is not None:
+ self.init_lora_manager()
+ self.init_memory_pool(
+ min_per_gpu_memory,
+ server_args.max_running_requests,
+ server_args.max_total_tokens,
+ )
+ if self.device == "cuda":
+ self.init_cublas()
+ self.init_attention_backend()
+ self.init_cuda_graphs()
+ else:
+ self.cuda_graph_runner = None
+ self.init_attention_backend()
+
+ def init_torch_distributed(self):
+ logger.info("Init torch distributed begin.")
+ # Init torch distributed
+ torch.get_device_module(self.device).set_device(self.gpu_id)
+ if self.device == "cuda":
+ backend = "nccl"
+            # TODO(liangan1): Just use gloo to bypass the initialization failure.
+ # Need to use xccl for xpu backend in the future
+ elif self.device == "xpu":
+ backend = "gloo"
+ elif self.device == "hpu":
+ backend = "hccl"
+
+ if not self.server_args.enable_p2p_check:
+ monkey_patch_vllm_p2p_access_check(self.gpu_id)
+ if self.server_args.dist_init_addr:
+ dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
+ else:
+ dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
+ set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
+ init_distributed_environment(
+ backend=backend,
+ world_size=self.tp_size,
+ rank=self.tp_rank,
+ local_rank=self.gpu_id,
+ distributed_init_method=dist_init_method,
+ )
+ initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+ min_per_gpu_memory = get_available_gpu_memory(
+ self.device, self.gpu_id, distributed=self.tp_size > 1
+ )
+ self.tp_group = get_tp_group()
+
+        # Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
+ # so we disable padding in cuda graph.
+ if self.device == "cuda" and not all(
+ in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
+ ):
+ self.server_args.disable_cuda_graph_padding = True
+ logger.info(
+ "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
+ )
+
+ # Check memory for tensor parallelism
+ if self.tp_size > 1:
+ local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
+ if min_per_gpu_memory < local_gpu_memory * 0.9:
+ raise ValueError(
+ "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
+ )
+
+ return min_per_gpu_memory
+
+ def setup_model(self):
+ try:
+ from vllm.config import VllmConfig
+
+ vllm_config = VllmConfig()
+ vllm_config.model_config = self.vllm_model_config
+ vllm_config.load_config = self.load_config
+ vllm_config.device_config = DeviceConfig(self.device)
+ vllm_config.quant_config = VllmConfig._get_quantization_config(
+ vllm_config.model_config, vllm_config.load_config
+ )
+ return get_model(vllm_config=vllm_config)
+ except ImportError:
+ pass
+
+ return get_model(
+ model_config=self.vllm_model_config,
+ load_config=self.load_config,
+ device_config=DeviceConfig(self.device),
+ parallel_config=None,
+ scheduler_config=None,
+ lora_config=None,
+ cache_config=None,
+ )
+
+ def get_model_config_params(self):
+ sig = inspect.signature(VllmModelConfig.__init__)
+ params = {
+ "model": self.server_args.model_path,
+ "quantization": self.server_args.quantization,
+ "tokenizer": None,
+ "tokenizer_mode": None,
+ "trust_remote_code": self.server_args.trust_remote_code,
+ "dtype": self.server_args.dtype,
+ "seed": self.server_args.random_seed,
+ "skip_tokenizer_init": True,
+ }
+
+ if "task" in sig.parameters:
+ params["task"] = ""
+
+ return params
+
+ def load_model(self):
+ logger.info(
+ f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+ )
+
+ # This can reduce thread conflicts and speed up weight loading.
+ torch.set_num_threads(1)
+ if self.device == "cuda":
+ if torch.cuda.get_device_capability()[0] < 8:
+ logger.info(
+ "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
+ )
+ self.server_args.dtype = "float16"
+ if torch.cuda.get_device_capability()[1] < 5:
+ raise RuntimeError("SGLang only supports sm75 and above.")
+
+ # Prepare the vllm model config
+ self.load_config = LoadConfig(
+ load_format=self.server_args.load_format,
+ download_dir=self.server_args.download_dir,
+ )
+ monkey_patch_vllm_model_config()
+ self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
+ if self.model_config.model_override_args is not None:
+ self.vllm_model_config.hf_config.update(
+ self.model_config.model_override_args
+ )
+
+ self.model = self.setup_model()
+
+ self.sliding_window_size = (
+ self.model.get_attention_sliding_window_size()
+ if hasattr(self.model, "get_attention_sliding_window_size")
+ else None
+ )
+ self.dtype = self.vllm_model_config.dtype
+
+ logger.info(
+ f"Load weight end. "
+ f"type={type(self.model).__name__}, "
+ f"dtype={self.dtype}, "
+ f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+ )
+
+ def update_weights_from_disk(self, model_path: str, load_format: str):
+ """Update engine weights online from disk."""
+ from vllm.model_executor.model_loader.loader import (
+ DefaultModelLoader,
+ device_loading_context,
+ get_model_loader,
+ )
+ from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+
+ logger.info(
+ f"Update engine weights online from disk begin. "
+ f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+ )
+
+ target_device = torch.device(self.device)
+
+ try:
+ model_config_params = self.get_model_config_params()
+ model_config_params["model"] = model_path
+ vllm_model_config = VllmModelConfig(**model_config_params)
+ except Exception as e:
+ message = f"Failed to load model config: {e}."
+ return False, message
+
+ load_config = LoadConfig(load_format=load_format)
+
+ # Only support vllm DefaultModelLoader for now
+ loader = get_model_loader(load_config)
+ if not isinstance(loader, DefaultModelLoader):
+ message = f"Failed to get model loader: {loader}."
+ return False, message
+
+ def get_weight_iter(config):
+ iter = loader._get_weights_iterator(
+ DefaultModelLoader.Source(
+ config.model,
+ revision=config.revision,
+ fall_back_to_pt=getattr(
+ self.model, "fall_back_to_pt_during_load", True
+ ),
+ )
+ )
+ return iter
+
+ def model_load_weights(model, iter):
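+            # Load the new weights, then re-run each module's quantization
+            # post-processing so quantized layers stay consistent with them.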
+ model.load_weights(iter)
+ for _, module in self.model.named_modules():
+ quant_method = getattr(module, "quant_method", None)
+ if quant_method is not None:
+ with device_loading_context(module, target_device):
+ quant_method.process_weights_after_loading(module)
+ return model
+
+ with set_default_torch_dtype(vllm_model_config.dtype):
+ try:
+ iter = get_weight_iter(vllm_model_config)
+ except Exception as e:
+ message = f"Failed to get weights iterator: {e}."
+ return False, message
+ try:
+ model = model_load_weights(self.model, iter)
+ except Exception as e:
+ message = (
+ f"Failed to update weights: {e}.\nRolling back to original weights."
+ )
+ del iter
+ gc.collect()
+ iter = get_weight_iter(self.vllm_model_config)
+ self.model = model_load_weights(self.model, iter)
+ return False, message
+
+ self.model = model
+ self.server_args.model_path = model_path
+ self.server_args.load_format = load_format
+ self.vllm_model_config = vllm_model_config
+ self.load_config = load_config
+ self.model_config.path = model_path
+
+ logger.info("Update weights end.")
+ return True, "Succeeded to update model weights."
+
+ def init_parameter_update_group(
+ self,
+ master_address,
+ master_port,
+ rank_offset,
+ world_size,
+ group_name,
+ backend="nccl",
+ ):
+ """Initialize the Torch process group for model parameter updates.
+
+ `_model_update_group` is used in the RLHF workflow, where rank 0 is the actor model in
+ the training engine, and the other ranks are the inference engine, which is used for rollout.
+
+ In the RLHF workflow, the training engine updates the model weights/parameters online,
+ and broadcasts them to the inference engine through the `_model_update_group` process group.
+ """
+ assert (
+ torch.distributed.is_initialized()
+ ), "Default torch process group must be initialized"
+ assert group_name != "", "Group name cannot be empty"
+
+ rank = rank_offset + self.tp_rank
+
+ logger.info(
+ f"init custom process group: master_address={master_address}, master_port={master_port}, "
+ f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
+ )
+
+ try:
+ self._model_update_group = init_custom_process_group(
+ backend=backend,
+ init_method=f"tcp://{master_address}:{master_port}",
+ world_size=world_size,
+ rank=rank,
+ group_name=group_name,
+ )
+
+ return True, "Succeeded to initialize custom process group."
+ except Exception as e:
+ message = f"Failed to initialize custom process group: {e}."
+ logger.error(message)
+ return False, message
+
+ def get_weights_by_parameter_name(
+ self, name: str, truncate_size: int = 100
+ ) -> Optional[torch.Tensor]:
+ try:
+            # Check whether this is a fused (stacked) parameter
+ mapped_name = name
+ mapped_shard_id = None
+ for param_name, weight_name, shard_id in self.model.stacked_params_mapping:
+ if weight_name in name:
+ mapped_name = name.replace(weight_name, param_name)
+ mapped_shard_id = shard_id
+ break
+ params_dict = dict(self.model.named_parameters())
+ if mapped_name in params_dict:
+ param = params_dict[mapped_name]
+ if mapped_shard_id is not None:
+                    # Handle the fused-parameter case
+ if mapped_shard_id in ["q", "k", "v"]:
+                        # Compute the offset and size within qkv_proj
+ num_heads = (
+ self.model.config.num_attention_heads // self.tp_size
+ )
+ num_kv_heads = (
+ self.model.config.num_key_value_heads // self.tp_size
+ )
+ head_dim = (
+ self.model.config.hidden_size
+ // self.model.config.num_attention_heads
+ )
+
+ if mapped_shard_id == "q":
+ offset = 0
+ size = num_heads * head_dim
+ elif mapped_shard_id == "k":
+ offset = num_heads * head_dim
+ size = num_kv_heads * head_dim
+ elif mapped_shard_id == "v":
+ offset = (num_heads + num_kv_heads) * head_dim
+ size = num_kv_heads * head_dim
+
+                        # Slice out the corresponding part of the weight
+ weight = param.data.narrow(0, offset, size)
+ elif mapped_shard_id in [0, 1]:
+                        # Handle the gate_up_proj case
+ intermediate_size = self.model.config.intermediate_size
+ hidden_size = self.model.config.hidden_size
+
+ if mapped_shard_id == 0: # gate_proj
+ offset = 0
+ size = intermediate_size
+ elif mapped_shard_id == 1: # up_proj
+ offset = intermediate_size
+ size = intermediate_size
+
+                        # Slice out the corresponding part of the weight
+ weight = param.data.narrow(0, offset, size)
+ else:
+ weight = param.data
+ else:
+ weight = param.data
+
+            # Convert and truncate
+ return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
+ else:
+ logger.warning(
+ f"Parameter {name} (mapped to {mapped_name}) not found in model"
+ )
+ return None
+
+ except Exception as e:
+ logger.error(f"Error when getting parameter {name}: {e}")
+ return None
+
+ def init_lora_manager(self):
+ self.lora_manager = LoRAManager(
+ base_model=self.model,
+ lora_paths=self.server_args.lora_paths,
+ base_hf_config=self.model_config.hf_config,
+ max_loras_per_batch=self.server_args.max_loras_per_batch,
+ load_config=self.load_config,
+ dtype=self.dtype,
+ )
+ logger.info("LoRA manager ready.")
+
+ def profile_max_num_token(self, total_gpu_memory: int):
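+        # Estimate how many KV-cache tokens fit: take the memory still available beyond
+        # the reserved (1 - mem_fraction_static) share of the GPU and divide it by the
+        # per-token KV cell size computed below.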
+ available_gpu_memory = get_available_gpu_memory(
+ self.device, self.gpu_id, distributed=self.tp_size > 1
+ )
+ if (
+ self.model_config.attention_arch == AttentionArch.MLA
+ and not self.server_args.disable_mla
+ ):
+ cell_size = (
+ (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
+ * self.model_config.num_hidden_layers
+ * torch._utils._element_size(self.kv_cache_dtype)
+ )
+ else:
+ cell_size = (
+ self.model_config.get_num_kv_heads(self.tp_size)
+ * self.model_config.head_dim
+ * self.model_config.num_hidden_layers
+ * 2
+ * torch._utils._element_size(self.kv_cache_dtype)
+ )
+ rest_memory = available_gpu_memory - total_gpu_memory * (
+ 1 - self.mem_fraction_static
+ )
+ max_num_token = int(rest_memory * (1 << 30) // cell_size)
+ return max_num_token
+
+ def init_memory_pool(
+ self,
+ total_gpu_memory: int,
+ max_num_reqs: Optional[int] = None,
+ max_total_tokens: Optional[int] = None,
+ ):
+ if self.server_args.kv_cache_dtype == "auto":
+ self.kv_cache_dtype = self.dtype
+ elif self.server_args.kv_cache_dtype == "fp8_e5m2":
+ if is_hip(): # Using natively supported format
+ self.kv_cache_dtype = torch.float8_e5m2fnuz
+ else:
+ self.kv_cache_dtype = torch.float8_e5m2
+ else:
+ raise ValueError(
+ f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
+ )
+
+ self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
+ if max_total_tokens is not None:
+ if max_total_tokens > self.max_total_num_tokens:
+ logging.warning(
+ f"max_total_tokens={max_total_tokens} is larger than the profiled value "
+ f"{self.max_total_num_tokens}. "
+ f"Use the profiled value instead."
+ )
+ self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)
+
+ if self.max_total_num_tokens <= 0:
+ raise RuntimeError(
+ "Not enough memory. Please try to increase --mem-fraction-static."
+ )
+
+ if max_num_reqs is None:
+ max_num_reqs = min(
+ max(
+ int(
+ self.max_total_num_tokens / self.model_config.context_len * 512
+ ),
+ 2048,
+ ),
+ 4096,
+ )
+
+ self.req_to_token_pool = ReqToTokenPool(
+ size=max_num_reqs + 1,
+ max_context_len=self.model_config.context_len + 4,
+ device=self.device,
+ use_records=False,
+ )
+ if (
+ self.model_config.attention_arch == AttentionArch.MLA
+ and not self.server_args.disable_mla
+ ):
+ self.token_to_kv_pool = MLATokenToKVPool(
+ self.max_total_num_tokens,
+ dtype=self.kv_cache_dtype,
+ kv_lora_rank=self.model_config.kv_lora_rank,
+ qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+ layer_num=self.model_config.num_hidden_layers,
+ device=self.device,
+ )
+ elif self.server_args.enable_double_sparsity:
+ self.token_to_kv_pool = DoubleSparseTokenToKVPool(
+ self.max_total_num_tokens,
+ dtype=self.kv_cache_dtype,
+ head_num=self.model_config.get_num_kv_heads(self.tp_size),
+ head_dim=self.model_config.head_dim,
+ layer_num=self.model_config.num_hidden_layers,
+ device=self.device,
+ heavy_channel_num=self.server_args.ds_heavy_channel_num,
+ )
+ else:
+ self.token_to_kv_pool = MHATokenToKVPool(
+ self.max_total_num_tokens,
+ dtype=self.kv_cache_dtype,
+ head_num=self.model_config.get_num_kv_heads(self.tp_size),
+ head_dim=self.model_config.head_dim,
+ layer_num=self.model_config.num_hidden_layers,
+ device=self.device,
+ )
+ logger.info(
+ f"Memory pool end. "
+ f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+ )
+
+ def init_cublas(self):
+ """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
+ dtype = torch.float16
+ device = "cuda"
+ a = torch.ones((16, 16), dtype=dtype, device=device)
+ b = torch.ones((16, 16), dtype=dtype, device=device)
+ c = a @ b
+ return c
+
+ def init_attention_backend(self):
+ """Init attention kernel backend."""
+ if self.server_args.attention_backend == "flashinfer":
+ self.attn_backend = FlashInferAttnBackend(self)
+ elif self.server_args.attention_backend == "triton":
+ assert self.sliding_window_size is None, (
+ "Window attention is not supported in the triton attention backend. "
+ "Please use `--attention-backend flashinfer`."
+ )
+ assert not self.model_config.is_encoder_decoder, (
+ "Cross attention is not supported in the triton attention backend. "
+ "Please use `--attention-backend flashinfer`."
+ )
+ if self.server_args.enable_double_sparsity:
+ self.attn_backend = DoubleSparseAttnBackend(self)
+ else:
+ self.attn_backend = TritonAttnBackend(self)
+ else:
+ raise ValueError(
+ f"Invalid attention backend: {self.server_args.attention_backend}"
+ )
+
+ def init_double_sparsity_channel_config(self, selected_channel):
+
+ selected_channel = "." + selected_channel + "_proj"
+ self.sorted_channels = []
+ # load channel config
+ with open(self.server_args.ds_channel_config_path, "r") as f:
+ channel_config = json.load(f)
+
+ for i in range(self.model_config.num_hidden_layers):
+ key = "model.layers." + str(i) + ".self_attn" + selected_channel
+ self.sorted_channels.append(
+ torch.tensor(channel_config[key])[
+ :, : self.server_args.ds_heavy_channel_num
+ ]
+ .contiguous()
+ .cuda()
+ )
+
+ def init_cuda_graphs(self):
+ """Capture cuda graphs."""
+ from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
+
+ self.cuda_graph_runner = None
+
+ if not self.is_generation:
+            # TODO: Currently, cuda graph only captures decode steps, which only exist for generation models.
+ return
+
+ if self.server_args.disable_cuda_graph:
+ return
+
+ logger.info("Capture cuda graph begin. This can take up to several minutes.")
+ self.cuda_graph_runner = CudaGraphRunner(self)
+
+ def apply_torch_tp(self):
+ logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
+ from sglang.srt.model_parallel import tensor_parallel
+
+ device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
+ tensor_parallel(self.model, device_mesh)
+
+ def forward_decode(self, forward_batch: ForwardBatch):
+ if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
+ return self.cuda_graph_runner.replay(forward_batch)
+
+ forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
+ self.attn_backend.init_forward_metadata(forward_batch)
+ return self.model.forward(
+ forward_batch.input_ids, forward_batch.positions, forward_batch
+ )
+
+ def forward_extend(self, forward_batch: ForwardBatch):
+ self.attn_backend.init_forward_metadata(forward_batch)
+ if self.is_generation:
+ if forward_batch.input_embeds is None:
+ return self.model.forward(
+ forward_batch.input_ids, forward_batch.positions, forward_batch
+ )
+ else:
+ return self.model.forward(
+ forward_batch.input_ids,
+ forward_batch.positions,
+ forward_batch,
+ input_embeds=forward_batch.input_embeds.bfloat16(),
+ )
+ else:
+ # Only embedding models have get_embedding parameter
+ return self.model.forward(
+ forward_batch.input_ids,
+ forward_batch.positions,
+ forward_batch,
+ get_embedding=True,
+ )
+
+ def forward_idle(self, forward_batch: ForwardBatch):
+ if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
+ return self.cuda_graph_runner.replay(forward_batch)
+
+ return self.model.forward(
+ forward_batch.input_ids, forward_batch.positions, forward_batch
+ )
+
+ def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
+ if forward_batch.forward_mode.is_decode():
+ return self.forward_decode(forward_batch)
+ elif forward_batch.forward_mode.is_extend():
+ return self.forward_extend(forward_batch)
+ elif forward_batch.forward_mode.is_idle():
+ return self.forward_idle(forward_batch)
+ else:
+ raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")
+
+ def sample(
+ self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
+ ) -> torch.Tensor:
+ sampling_info = forward_batch.sampling_info
+ if sampling_info.sampling_info_done:
+ # Overlap mode: the function update_regex_vocab_mask was executed
+ # in process_batch_result of the last batch.
+ if sampling_info.grammars:
+ sampling_info.sampling_info_done.wait()
+ else:
+ # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
+ sampling_info.update_regex_vocab_mask()
+ sampling_info.update_penalties()
+ logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
+
+ # Sample the next tokens.
+ next_token_ids = self.sampler(logits, sampling_info)
+ return next_token_ids
+
+ def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
+ # Apply logit_bias
+ if sampling_info.logit_bias is not None:
+ logits.add_(sampling_info.logit_bias)
+
+ # min-token, presence, frequency
+ if sampling_info.linear_penalties is not None:
+ logits.add_(sampling_info.linear_penalties)
+
+ # repetition
+ if sampling_info.scaling_penalties is not None:
+ logits = torch.where(
+ logits > 0,
+ logits / sampling_info.scaling_penalties,
+ logits * sampling_info.scaling_penalties,
+ )
+
+ # Apply regex vocab_mask
+ if sampling_info.vocab_mask is not None:
+ sampling_info.apply_mask(logits=logits, vocab_mask=sampling_info.vocab_mask)
+
+ return logits
+
+ @property
+ def model_is_mrope(self) -> bool:
+ """Detect if the model has "mrope" rope_scaling type.
+        mrope requires keeping "rope_deltas" between the prompt and decoding phases."""
+ rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
+ if rope_scaling is None:
+ return False
+ return rope_scaling.get("type", None) == "mrope"
+
+
+@lru_cache()
+def import_model_classes():
+ model_arch_name_to_cls = {}
+ package_name = "sglang.srt.models"
+ package = importlib.import_module(package_name)
+ for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
+ if not ispkg:
+ try:
+ module = importlib.import_module(name)
+ except Exception as e:
+ logger.warning(f"Ignore import error when loading {name}. {e}")
+ if crash_on_warnings():
+ raise ValueError(f"Ignore import error when loading {name}. {e}")
+ continue
+ if hasattr(module, "EntryClass"):
+ entry = module.EntryClass
+ if isinstance(
+ entry, list
+ ): # To support multiple model classes in one module
+ for tmp in entry:
+ assert (
+ tmp.__name__ not in model_arch_name_to_cls
+ ), f"Duplicated model implementation for {tmp.__name__}"
+ model_arch_name_to_cls[tmp.__name__] = tmp
+ else:
+ assert (
+ entry.__name__ not in model_arch_name_to_cls
+ ), f"Duplicated model implementation for {entry.__name__}"
+ model_arch_name_to_cls[entry.__name__] = entry
+
+ return model_arch_name_to_cls
+
+
+def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
+ model_arch_name_to_cls = import_model_classes()
+
+ if model_arch not in model_arch_name_to_cls:
+ raise ValueError(
+ f"Unsupported architectures: {model_arch}. "
+ f"Supported list: {list(model_arch_name_to_cls.keys())}"
+ )
+ return model_arch_name_to_cls[model_arch]
+
+
+# Monkey patch model loader
+setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
+setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False)
+setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False)
+setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False)
+setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False)
diff --git a/test/srt/test_update_weights.py b/test/srt/test_update_weights.py
index 327da729aad..2f53f12dbd7 100644
--- a/test/srt/test_update_weights.py
+++ b/test/srt/test_update_weights.py
@@ -49,7 +49,7 @@ def get_model_info(self):
def run_update_weights(self, model_path):
response = requests.post(
- self.base_url + "/update_weights",
+ self.base_url + "/update_weights_from_disk",
json={
"model_path": model_path,
},