diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index a3d30324c37..8deafefb408 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -263,6 +263,24 @@ jobs: cd test/srt python3 test_moe_eval_accuracy_large.py + weight-update-test-2-gpu: + if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + runs-on: 2-gpu-runner + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Install dependencies + run: | + bash scripts/ci_install_dependency.sh + + - name: Test weight update (TP, DP = 2,1 or 1,2) + timeout-minutes: 20 + run: | + cd test/srt + python3 test_update_parameter_from_distributed.py + python3 test_get_parameter_by_name.py + finish: needs: [ unit-test-frontend, unit-test-backend-part-1, unit-test-backend-part-2, unit-test-backend-part-3, unit-test-backend-part-4, diff --git a/3rdparty/amd/profiling/PROFILING.md b/3rdparty/amd/profiling/PROFILING.md index 90ad8665e67..79bc75b503b 100644 --- a/3rdparty/amd/profiling/PROFILING.md +++ b/3rdparty/amd/profiling/PROFILING.md @@ -421,5 +421,5 @@ index 62d1ff9..6ecd78c 100644 3. Modify the included server.sh by removing "loadTracer.sh" before python command and launch script ./server.sh in one terminal inside the docker container. 4. Similar to step 6 in RPD profiling section, but remove the last 2 lines in client.sh, which converted rpd file into csv and json files. Run modified client.sh for PyTorch profiling. -======= +------- - [Torch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html) diff --git a/docs/backend/native_api.ipynb b/docs/backend/native_api.ipynb index 0b43c6a5ae9..9a97c15ca30 100644 --- a/docs/backend/native_api.ipynb +++ b/docs/backend/native_api.ipynb @@ -14,7 +14,7 @@ "- `/health`\n", "- `/health_generate`\n", "- `/flush_cache`\n", - "- `/update_weights`\n", + "- `/update_weights_from_disk`\n", "- `/encode`(embedding model)\n", "- `/classify`(reward model)\n", "\n", @@ -98,7 +98,7 @@ "print_highlight(response_json)\n", "assert response_json[\"model_path\"] == \"meta-llama/Llama-3.2-1B-Instruct\"\n", "assert response_json[\"is_generation\"] is True\n", - "assert response_json.keys() == {\"model_path\", \"is_generation\"}" + "assert list(response_json.keys()) == [\"model_path\", \"tokenizer_path\", \"is_generation\"]" ] }, { @@ -144,8 +144,7 @@ "source": [ "url = \"http://localhost:30010/health_generate\"\n", "\n", - "response = requests.get(url)\n", - "print_highlight(response.text)" + "response = requests.get(url)" ] }, { @@ -156,8 +155,7 @@ "source": [ "url = \"http://localhost:30010/health\"\n", "\n", - "response = requests.get(url)\n", - "print_highlight(response.text)" + "response = requests.get(url)" ] }, { @@ -187,9 +185,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Update Weights\n", + "## Update Weights From Disk\n", "\n", - "Update model weights without restarting the server. Use for continuous evaluation during training. Only applicable for models with the same architecture and parameter size." + "Update model weights from disk without restarting the server. Use for continuous evaluation during training. Only applicable for models with the same architecture and parameter size." 
] }, { @@ -200,7 +198,7 @@ "source": [ "# successful update with same architecture and size\n", "\n", - "url = \"http://localhost:30010/update_weights\"\n", + "url = \"http://localhost:30010/update_weights_from_disk\"\n", "data = {\"model_path\": \"meta-llama/Llama-3.2-1B\"}\n", "\n", "response = requests.post(url, json=data)\n", @@ -218,7 +216,7 @@ "source": [ "# failed update with different parameter size\n", "\n", - "url = \"http://localhost:30010/update_weights\"\n", + "url = \"http://localhost:30010/update_weights_from_disk\"\n", "data = {\"model_path\": \"meta-llama/Llama-3.2-3B\"}\n", "\n", "response = requests.post(url, json=data)\n", @@ -340,7 +338,19 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n", + "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n", + "\u001b[1;31mClick here for more info. \n", + "\u001b[1;31mView Jupyter log for further details." + ] + } + ], "source": [ "terminate_process(reward_process)" ] diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 8b1f88fa26b..a33ce7180ae 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -352,7 +352,7 @@ class FlushCacheReq: @dataclass -class UpdateWeightReqInput: +class UpdateWeightFromDistReqInput: # The model path with the new weights model_path: str # The format to load the weights @@ -360,11 +360,58 @@ class UpdateWeightReqInput: @dataclass -class UpdateWeightReqOutput: +class UpdateWeightFromDistReqOutput: success: bool message: str +@dataclass +class UpdateParameterFromDistributedReqInput: + name: str + dtype: str + shape: List[int] + empty_cache: bool + + +@dataclass +class UpdateParameterFromDistributedReqOutput: + success: bool + message: str + + +@dataclass +class InitParameterUpdateGroupReqInput: + # The master address + master_address: str + # The master port + master_port: int + # The rank offset + rank_offset: int + # The world size + world_size: int + # The group name + group_name: str + # The backend + backend: str = "nccl" + + +@dataclass +class InitParameterUpdateGroupReqOutput: + success: bool + message: str + + +@dataclass +class GetParameterByNameReqInput: + name: str + truncate_size: int = 100 + + +@dataclass +class GetParameterByNameReqOutput: + parameter: list + + @dataclass class AbortReq: # The request id diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 663d4c4f935..bb4f2cc93cc 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -38,13 +38,19 @@ FlushCacheReq, GetMemPoolSizeReq, GetMemPoolSizeReqOutput, + GetParameterByNameReqInput, + GetParameterByNameReqOutput, + InitParameterUpdateGroupReqInput, + InitParameterUpdateGroupReqOutput, OpenSessionReqInput, OpenSessionReqOutput, ProfileReq, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, - UpdateWeightReqInput, - UpdateWeightReqOutput, + UpdateParameterFromDistributedReqInput, + UpdateParameterFromDistributedReqOutput, + UpdateWeightFromDistReqInput, + UpdateWeightFromDistReqOutput, ) from sglang.srt.managers.schedule_batch import ( FINISH_ABORT, @@ -500,10 +506,25 @@ def process_input_requests(self, recv_reqs: List): self.flush_cache() elif isinstance(recv_req, AbortReq): 
self.abort_request(recv_req) - elif isinstance(recv_req, UpdateWeightReqInput): - success, message = self.update_weights(recv_req) + elif isinstance(recv_req, UpdateWeightFromDistReqInput): + success, message = self.update_weights_from_disk(recv_req) self.send_to_tokenizer.send_pyobj( - UpdateWeightReqOutput(success, message) + UpdateWeightFromDistReqOutput(success, message) + ) + elif isinstance(recv_req, GetParameterByNameReqInput): + parameter = self.get_weights_by_parameter_name(recv_req) + self.send_to_tokenizer.send_pyobj( + GetParameterByNameReqOutput(parameter) + ) + elif isinstance(recv_req, InitParameterUpdateGroupReqInput): + success, message = self.init_parameter_update_group(recv_req) + self.send_to_tokenizer.send_pyobj( + InitParameterUpdateGroupReqOutput(success, message) + ) + elif isinstance(recv_req, UpdateParameterFromDistributedReqInput): + success, message = self.update_parameter_from_distributed(recv_req) + self.send_to_tokenizer.send_pyobj( + UpdateParameterFromDistributedReqOutput(success, message) ) elif isinstance(recv_req, ProfileReq): if recv_req == ProfileReq.START_PROFILE: @@ -1353,9 +1374,9 @@ def abort_request(self, recv_req: AbortReq): self.tree_cache.cache_finished_req(req) break - def update_weights(self, recv_req: UpdateWeightReqInput): + def update_weights_from_disk(self, recv_req: UpdateWeightFromDistReqInput): """In-place update of the weights.""" - success, message = self.tp_worker.update_weights(recv_req) + success, message = self.tp_worker.update_weights_from_disk(recv_req) if success: flash_cache_success = self.flush_cache() assert flash_cache_success, "Cache flush failed after updating weights" @@ -1363,6 +1384,27 @@ def update_weights(self, recv_req: UpdateWeightReqInput): logger.error(message) return success, message + def init_parameter_update_group(self, recv_req: InitParameterUpdateGroupReqInput): + """Initialize the online model parameter update group.""" + success, message = self.tp_worker.init_parameter_update_group(recv_req) + return success, message + + def update_parameter_from_distributed( + self, recv_req: UpdateParameterFromDistributedReqInput + ): + """Update the online model parameter.""" + success, message = self.tp_worker.update_parameter_from_distributed(recv_req) + if success: + flash_cache_success = self.flush_cache() + assert flash_cache_success, "Cache flush failed after updating weights" + else: + logger.error(message) + return success, message + + def get_weights_by_parameter_name(self, recv_req: GetParameterByNameReqInput): + parameter = self.tp_worker.get_weights_by_parameter_name(recv_req) + return parameter + def start_profile(self) -> None: if self.profiler is None: raise RuntimeError("Profiler is not enabled.") diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 001ecc1ebe7..11ed2ec2540 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -47,13 +47,19 @@ GenerateReqInput, GetMemPoolSizeReq, GetMemPoolSizeReqOutput, + GetParameterByNameReqInput, + GetParameterByNameReqOutput, + InitParameterUpdateGroupReqInput, + InitParameterUpdateGroupReqOutput, OpenSessionReqInput, OpenSessionReqOutput, ProfileReq, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, - UpdateWeightReqInput, - UpdateWeightReqOutput, + UpdateParameterFromDistributedReqInput, + UpdateParameterFromDistributedReqOutput, + UpdateWeightFromDistReqInput, + UpdateWeightFromDistReqOutput, ) from sglang.srt.metrics.collector 
import TokenizerMetricsCollector from sglang.srt.sampling.sampling_params import SamplingParams @@ -425,8 +431,10 @@ async def get_memory_pool_size(self): ret = [r.size for r in res] return ret - async def update_weights( - self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None + async def update_weights_from_disk( + self, + obj: UpdateWeightFromDistReqInput, + request: Optional[fastapi.Request] = None, ): if self.to_create_loop: self.create_handle_loop() @@ -471,6 +479,89 @@ async def update_weights( else: return False, "Another update is in progress. Please try again later." + async def init_parameter_update_group( + self, + obj: InitParameterUpdateGroupReqInput, + request: Optional[fastapi.Request] = None, + ) -> bool: + if self.to_create_loop: + self.create_handle_loop() + + if obj.backend is None: + obj.backend = "nccl" + self.send_to_scheduler.send_pyobj(obj) + + self.init_parameter_update_group_result = asyncio.Future() + + if self.server_args.dp_size == 1: + result = await self.init_parameter_update_group_result + return result.success, result.message + else: + self.init_parameter_update_group_tmp = [] + result = await self.init_parameter_update_group_result + all_success = all([r.success for r in result]) + all_message = [r.message for r in result] + all_message = " | ".join(all_message) + return all_success, all_message + + async def update_parameter_from_distributed( + self, + obj: UpdateParameterFromDistributedReqInput, + request: Optional[fastapi.Request] = None, + ): + if self.to_create_loop: + self.create_handle_loop() + if not self.model_update_lock.locked(): + + async with self.model_update_lock: + # wait for the previous update requests to finish + for i in range(3): + while len(self.rid_to_state) > 0: + await asyncio.sleep(0.001) + # FIXME: We add some sleep here to avoid some race conditions. + # We can use a read-write lock as a better fix. + await asyncio.sleep(0.01) + + self.send_to_scheduler.send_pyobj(obj) + self.parameter_update_result = asyncio.Future() + + if self.server_args.dp_size == 1: + result = await self.parameter_update_result + return result.success, result.message + else: # self.server_args.dp_size > 1 + self.parameter_update_tmp = [] + result = await self.parameter_update_result + all_success = all([r.success for r in result]) + all_message = [r.message for r in result] + all_message = " | ".join(all_message) + return all_success, all_message + + else: + logger.error( + f"Another parameter update is in progress in tokenizer manager" + ) + return ( + False, + "Another parameter update is in progress. 
Please try again later.", + ) + + async def get_weights_by_parameter_name( + self, obj: GetParameterByNameReqInput, request: Optional[fastapi.Request] = None + ): + if self.to_create_loop: + self.create_handle_loop() + + self.send_to_scheduler.send_pyobj(obj) + self.get_weights_by_parameter_name_result = asyncio.Future() + if self.server_args.dp_size == 1: + result = await self.get_weights_by_parameter_name_result + return result.parameter + else: + self.get_weights_by_parameter_name_tmp = [] + result = await self.get_weights_by_parameter_name_result + all_parameters = [r.parameter for r in result] + return all_parameters + async def open_session( self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None ): @@ -540,10 +631,16 @@ async def handle_loop(self): while True: recv_obj: Union[ - BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput + BatchStrOut, + BatchEmbeddingOut, + BatchTokenIDOut, + UpdateWeightFromDistReqOutput, + UpdateParameterFromDistributedReqOutput, + GetParameterByNameReqOutput, + InitParameterUpdateGroupReqOutput, ] = await self.recv_from_detokenizer.recv_pyobj() - if isinstance(recv_obj, UpdateWeightReqOutput): + if isinstance(recv_obj, UpdateWeightFromDistReqOutput): if self.server_args.dp_size == 1: self.model_update_result.set_result(recv_obj) else: # self.server_args.dp_size > 1 @@ -552,6 +649,43 @@ async def handle_loop(self): if len(self.model_update_tmp) == self.server_args.dp_size: self.model_update_result.set_result(self.model_update_tmp) continue + elif isinstance(recv_obj, UpdateParameterFromDistributedReqOutput): + if self.server_args.dp_size == 1: + self.parameter_update_result.set_result(recv_obj) + else: # self.server_args.dp_size > 1 + self.parameter_update_tmp.append(recv_obj) + # set future if the all results are recevied + if len(self.parameter_update_tmp) == self.server_args.dp_size: + self.parameter_update_result.set_result( + self.parameter_update_tmp + ) + continue + elif isinstance(recv_obj, GetParameterByNameReqOutput): + if self.server_args.dp_size == 1: + self.get_weights_by_parameter_name_result.set_result(recv_obj) + else: + self.get_weights_by_parameter_name_tmp.append(recv_obj) + if ( + len(self.get_weights_by_parameter_name_tmp) + == self.server_args.dp_size + ): + self.get_weights_by_parameter_name_result.set_result( + self.get_weights_by_parameter_name_tmp + ) + continue + elif isinstance(recv_obj, InitParameterUpdateGroupReqOutput): + if self.server_args.dp_size == 1: + self.init_parameter_update_group_result.set_result(recv_obj) + else: + self.init_parameter_update_group_tmp.append(recv_obj) + if ( + len(self.init_parameter_update_group_tmp) + == self.server_args.dp_size + ): + self.init_parameter_update_group_result.set_result( + self.init_parameter_update_group_tmp + ) + continue elif isinstance(recv_obj, GetMemPoolSizeReqOutput): if self.server_args.dp_size == 1: self.mem_pool_size.set_result(recv_obj) diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index a5d694e77bc..d68dd78cc79 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -19,7 +19,12 @@ from sglang.srt.configs.model_config import ModelConfig from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer -from sglang.srt.managers.io_struct import UpdateWeightReqInput +from sglang.srt.managers.io_struct import ( + GetParameterByNameReqInput, + InitParameterUpdateGroupReqInput, + UpdateParameterFromDistributedReqInput, + 
UpdateWeightFromDistReqInput, +) from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.model_runner import ModelRunner @@ -155,8 +160,33 @@ def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch): embeddings = logits_output.embeddings return embeddings - def update_weights(self, recv_req: UpdateWeightReqInput): - success, message = self.model_runner.update_weights( + def update_weights_from_disk(self, recv_req: UpdateWeightFromDistReqInput): + success, message = self.model_runner.update_weights_from_disk( recv_req.model_path, recv_req.load_format ) return success, message + + def init_parameter_update_group(self, recv_req: InitParameterUpdateGroupReqInput): + success, message = self.model_runner.init_parameter_update_group( + recv_req.master_address, + recv_req.master_port, + recv_req.rank_offset, + recv_req.world_size, + recv_req.group_name, + recv_req.backend, + ) + return success, message + + def update_parameter_from_distributed( + self, recv_req: UpdateParameterFromDistributedReqInput + ): + success, message = self.model_runner.update_parameter_from_distributed( + recv_req.name, recv_req.dtype, recv_req.shape, recv_req.empty_cache + ) + return success, message + + def get_weights_by_parameter_name(self, recv_req: GetParameterByNameReqInput): + parameter = self.model_runner.get_weights_by_parameter_name( + recv_req.name, recv_req.truncate_size + ) + return parameter diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 3b53759a75f..02a6dd847e6 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -21,7 +21,12 @@ import torch -from sglang.srt.managers.io_struct import UpdateWeightReqInput +from sglang.srt.managers.io_struct import ( + GetParameterByNameReqInput, + InitParameterUpdateGroupReqInput, + UpdateParameterFromDistributedReqInput, + UpdateWeightFromDistReqInput, +) from sglang.srt.managers.schedule_batch import ModelWorkerBatch from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.server_args import ServerArgs @@ -195,10 +200,23 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): ) % self.future_token_ids_limit return None, future_next_token_ids - def update_weights(self, recv_req: UpdateWeightReqInput): - success, message = self.worker.update_weights(recv_req) + def update_weights_from_disk(self, recv_req: UpdateWeightFromDistReqInput): + success, message = self.worker.update_weights_from_disk(recv_req) return success, message + def init_parameter_update_group(self, recv_req: InitParameterUpdateGroupReqInput): + success, message = self.worker.init_parameter_update_group(recv_req) + return success, message + + def update_parameter_from_distributed( + self, recv_req: UpdateParameterFromDistributedReqInput + ): + success, message = self.worker.update_parameter_from_distributed(recv_req) + return success, message + + def get_weights_by_parameter_name(self, recv_req: GetParameterByNameReqInput): + return self.worker.get_weights_by_parameter_name(recv_req) + def __delete__(self): self.input_queue.put((None, None)) self.copy_queue.put((None, None, None)) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 7c1c51a8fb2..43036c949ea 100644 --- 
a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -21,9 +21,10 @@ import logging import pkgutil from functools import lru_cache -from typing import Optional, Type +from typing import Any, Optional, Type, Union import torch +import torch.distributed as dist import torch.nn as nn from vllm.config import DeviceConfig, LoadConfig from vllm.config import ModelConfig as VllmModelConfig @@ -58,6 +59,7 @@ crash_on_warnings, enable_show_time_cost, get_available_gpu_memory, + init_custom_process_group, is_hip, monkey_patch_vllm_model_config, monkey_patch_vllm_p2p_access_check, @@ -316,8 +318,8 @@ def load_model(self): f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" ) - def update_weights(self, model_path: str, load_format: str): - """Update weights in-place.""" + def update_weights_from_disk(self, model_path: str, load_format: str): + """Update engine weights online from disk.""" from vllm.model_executor.model_loader.loader import ( DefaultModelLoader, device_loading_context, @@ -326,7 +328,7 @@ def update_weights(self, model_path: str, load_format: str): from vllm.model_executor.model_loader.utils import set_default_torch_dtype logger.info( - f"Update weights begin. " + f"Update engine weights online from disk begin. " f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" ) @@ -397,6 +399,130 @@ def model_load_weights(model, iter): logger.info("Update weights end.") return True, "Succeeded to update model weights." + def init_parameter_update_group( + self, + master_address, + master_port, + rank_offset, + world_size, + group_name, + backend="nccl", + ): + """Initialize the Torch process group for model parameter updates. + + `_model_update_group` is used in the RLHF workflow, where rank 0 is the actor model in + the training engine, and the other ranks are the inference engine, which is used for rollout. + + In the RLHF workflow, the training engine updates the model weights/parameters online, + and broadcasts them to the inference engine through the `_model_update_group` process group. + """ + assert ( + torch.distributed.is_initialized() + ), "Default torch process group must be initialized" + assert group_name != "", "Group name cannot be empty" + + rank = rank_offset + self.tp_rank + + logger.info( + f"init custom process group: master_address={master_address}, master_port={master_port}, " + f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}" + ) + + try: + self._model_update_group = init_custom_process_group( + backend=backend, + init_method=f"tcp://{master_address}:{master_port}", + world_size=world_size, + rank=rank, + group_name=group_name, + ) + + return True, "Succeeded to initialize custom process group." + except Exception as e: + message = f"Failed to initialize custom process group: {e}." 
+ logger.error(message) + return False, message + + def get_weights_by_parameter_name( + self, name: str, truncate_size: int = 100 + ) -> Optional[torch.Tensor]: + try: + # Check whether this is a stacked (merged) parameter + mapped_name = name + mapped_shard_id = None + for param_name, weight_name, shard_id in self.model.stacked_params_mapping: + if weight_name in name: + mapped_name = name.replace(weight_name, param_name) + mapped_shard_id = shard_id + break + params_dict = dict(self.model.named_parameters()) + if mapped_name in params_dict: + param = params_dict[mapped_name] + if mapped_shard_id is not None: + # Handle the stacked-parameter case + if mapped_shard_id in ["q", "k", "v"]: + # Compute the offset and size within qkv_proj + num_heads = ( + self.model.config.num_attention_heads // self.tp_size + ) + num_kv_heads = ( + self.model.config.num_key_value_heads // self.tp_size + ) + head_dim = ( + self.model.config.hidden_size + // self.model.config.num_attention_heads + ) + + if mapped_shard_id == "q": + offset = 0 + size = num_heads * head_dim + elif mapped_shard_id == "k": + offset = num_heads * head_dim + size = num_kv_heads * head_dim + elif mapped_shard_id == "v": + offset = (num_heads + num_kv_heads) * head_dim + size = num_kv_heads * head_dim + + # Extract the corresponding slice of the weight + weight = param.data.narrow(0, offset, size) + elif mapped_shard_id in [0, 1]: + # Handle the gate_up_proj case + intermediate_size = self.model.config.intermediate_size + hidden_size = self.model.config.hidden_size + slice_size = intermediate_size // self.tp_size + + if mapped_shard_id == 0: # gate_proj + offset = 0 + size = slice_size + elif mapped_shard_id == 1: # up_proj + offset = slice_size + size = slice_size + + # Extract the corresponding slice of the weight + weight = param.data.narrow(0, offset, size) + else: + weight = param.data + else: + weight = param.data + + if self.tp_size > 1 and ("o_proj" in name or "down_proj" in name): + gathered_weights = [ + torch.zeros_like(weight) for _ in range(self.tp_size) + ] + torch.distributed.all_gather(gathered_weights, weight) + weight = torch.cat(gathered_weights, dim=1) + + return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size] + else: + logger.warning( + f"Parameter {name} (mapped to {mapped_name}) not found in model" + ) + return None + + except Exception as e: + logger.error(f"Error when getting parameter {name}: {e}") + return None + + def init_lora_manager(self): self.lora_manager = LoRAManager( base_model=self.model, diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 7e9fd0f7267..67ae46a6440 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -305,6 +305,14 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + self.stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] @torch.no_grad() def forward( @@ -349,15 +357,7 @@ def get_module_name(self, name): return params_mapping.get(name, name) def get_module_name_from_weight_name(self, name): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id, num_shard) - ("qkv_proj", "q_proj", "q", 3), - ("qkv_proj", "k_proj", "k", 3), - ("qkv_proj", "v_proj", "v", 3), - ("gate_up_proj", "gate_proj", 0, 2), - ("gate_up_proj", "up_proj", 1, 2), - ] - for param_name, weight_name, shard_id, num_shard in stacked_params_mapping: + 
for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping: if weight_name in name: return ( name.replace(weight_name, param_name)[: -len(".weight")], @@ -370,6 +370,7 @@ def get_num_params(self): return len(params_dict) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + embed_tokens_weight = None stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -378,6 +379,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), ] + params_dict = dict(self.named_parameters()) load_tie_word_embeddings = ( @@ -425,7 +427,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing param = self.lm_head.weight weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, embed_tokens_weight) + if embed_tokens_weight is not None: + weight_loader(param, embed_tokens_weight) apply_torchao_config_(self, params_dict, set(["proj.weight"])) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index a4753a13458..4030c4e9da9 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -28,12 +28,15 @@ from http import HTTPStatus from typing import AsyncIterator, Dict, List, Optional, Union +import torch + # Fix a bug of Python threading setattr(threading, "_register_atexit", lambda *args, **kwargs: None) import aiohttp import orjson import requests +import torch.distributed as dist import uvicorn import uvloop from fastapi import FastAPI, File, Form, Request, UploadFile @@ -51,8 +54,11 @@ CloseSessionReqInput, EmbeddingReqInput, GenerateReqInput, + GetParameterByNameReqInput, + InitParameterUpdateGroupReqInput, OpenSessionReqInput, - UpdateWeightReqInput, + UpdateParameterFromDistributedReqInput, + UpdateWeightFromDistReqInput, ) from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.managers.tokenizer_manager import TokenizerManager @@ -78,6 +84,7 @@ assert_pkg_version, configure_logger, delete_directory, + init_custom_process_group, is_port_available, kill_child_process, maybe_set_triton_cache_manager, @@ -201,11 +208,24 @@ async def stop_profile_async(): ) -@app.post("/update_weights") +@app.api_route("/get_memory_pool_size", methods=["GET", "POST"]) +async def get_memory_pool_size(): + """Get the memory pool size in number of tokens""" + try: + ret = await tokenizer_manager.get_memory_pool_size() + + return ret + except Exception as e: + return ORJSONResponse( + {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST + ) + + +@app.post("/update_weights_from_disk") @time_func_latency -async def update_weights(obj: UpdateWeightReqInput, request: Request): - """Update the weights inplace without re-launching the server.""" - success, message = await tokenizer_manager.update_weights(obj, request) +async def update_weights_from_disk(obj: UpdateWeightFromDistReqInput, request: Request): + """Update the weights from disk inplace without re-launching the server.""" + success, message = await tokenizer_manager.update_weights_from_disk(obj, request) content = {"success": success, "message": message} if success: return ORJSONResponse( @@ -219,6 +239,54 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request): ) +@app.post("/init_parameter_update_group") +async def init_parameter_update_group( + obj: 
InitParameterUpdateGroupReqInput, request: Request +): + """Initialize the parameter update group.""" + success, message = await tokenizer_manager.init_parameter_update_group(obj, request) + content = {"success": success, "message": message} + if success: + return ORJSONResponse(content, status_code=200) + else: + return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) + + +@app.post("/update_parameter_from_distributed") +async def update_parameter_from_distributed( + obj: UpdateParameterFromDistributedReqInput, request: Request +): + """Update model parameter from distributed online.""" + success, message = await tokenizer_manager.update_parameter_from_distributed( + obj, request + ) + content = {"success": success, "message": message} + if success: + return ORJSONResponse(content, status_code=200) + else: + return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) + + +@app.api_route("/get_weights_by_parameter_name", methods=["GET", "POST"]) +async def get_weights_by_parameter_name( + obj: GetParameterByNameReqInput, request: Request +): + """Get model parameter by name.""" + try: + ret = await tokenizer_manager.get_weights_by_parameter_name(obj, request) + if ret is None: + return ORJSONResponse( + {"error": {"message": "Get parameter by name failed"}}, + status_code=HTTPStatus.BAD_REQUEST, + ) + else: + return ORJSONResponse(ret, status_code=200) + except Exception as e: + return ORJSONResponse( + {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST + ) + + @app.api_route("/open_session", methods=["GET", "POST"]) async def open_session(obj: OpenSessionReqInput, request: Request): """Open a session, and return its unique session id.""" @@ -276,6 +344,51 @@ async def stream_results() -> AsyncIterator[bytes]: ) +@time_func_latency +async def init_parameter_update_group_request( + obj: InitParameterUpdateGroupReqInput, request: Request +): + """Handle an init parameter update group request.""" + try: + ret = await tokenizer_manager.init_parameter_update_group(obj, request) + print(f"init_parameter_update_group_request in server: {ret}") + return ret + except ValueError as e: + return ORJSONResponse( + {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST + ) + + +@time_func_latency +async def get_weights_by_parameter_name_request( + obj: GetParameterByNameReqInput, request: Request +): + """Handle a get parameter by name request.""" + try: + ret = await tokenizer_manager.get_weights_by_parameter_name(obj, request) + return ret + except ValueError as e: + return ORJSONResponse( + {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST + ) + + +@time_func_latency +async def update_parameter_from_distributed_request( + obj: UpdateParameterFromDistributedReqInput, request: Request +): + """Handle an update parameter from distributed request.""" + try: + torch.cuda.synchronize() + print(f"try to update parameter from distributed in server") + ret = await tokenizer_manager.update_parameter_from_distributed(obj, request) + return ret + except ValueError as e: + return ORJSONResponse( + {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST + ) + + # fastapi implicitly converts json in the request to obj (dataclass) app.post("/generate")(generate_request) app.put("/generate")(generate_request) @@ -946,3 +1059,99 @@ def encode( async def get_server_info(self): return await _get_server_info() + + def generate( + self, + # The input prompt. It can be a single prompt or a batch of prompts. 
+ prompt: Optional[Union[List[str], str]] = None, + sampling_params: Optional[Union[List[Dict], Dict]] = None, + # The token ids for text; one can either specify text or input_ids. + input_ids: Optional[Union[List[List[int]], List[int]]] = None, + return_logprob: Optional[Union[List[bool], bool]] = False, + logprob_start_len: Optional[Union[List[int], int]] = None, + top_logprobs_num: Optional[Union[List[int], int]] = None, + lora_path: Optional[List[Optional[str]]] = None, + stream: bool = False, + ): + obj = GenerateReqInput( + text=prompt, + input_ids=input_ids, + sampling_params=sampling_params, + return_logprob=return_logprob, + logprob_start_len=logprob_start_len, + top_logprobs_num=top_logprobs_num, + lora_path=lora_path, + stream=stream, + ) + + # get the current event loop + loop = asyncio.get_event_loop() + ret = loop.run_until_complete(generate_request(obj, None)) + + if stream is True: + + def generator_wrapper(): + offset = 0 + loop = asyncio.get_event_loop() + generator = ret.body_iterator + while True: + chunk = loop.run_until_complete(generator.__anext__()) + + if chunk.startswith(STREAM_END_SYMBOL): + break + else: + data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :]) + data["text"] = data["text"][offset:] + offset += len(data["text"]) + yield data + + # we cannot yield in the scope of generate() because python does not allow yield + return in the same function + # however, it allows to wrap the generator as a subfunction and return + return generator_wrapper() + else: + return ret + + def init_parameter_update_group( + self, + master_address: str, + master_port: int, + rank_offset: int, + world_size: int, + group_name: str, + backend: str = "nccl", + ): + obj = InitParameterUpdateGroupReqInput( + master_address=master_address, + master_port=master_port, + rank_offset=rank_offset, + world_size=world_size, + group_name=group_name, + backend=backend, + ) + + loop = asyncio.get_event_loop() + return loop.run_until_complete(init_parameter_update_group_request(obj, None)) + + def update_parameter_from_distributed(self, name, dtype, shape, empty_cache=False): + print(f"update parameter from distributed request in engine before synchronize") + torch.cuda.synchronize() + print(f"update parameter from distributed request in engine after synchronize") + obj = UpdateParameterFromDistributedReqInput( + name=name, + dtype=dtype, + shape=shape, + empty_cache=empty_cache, + ) + torch.cuda.synchronize() + print(f"update parameter from distributed request in engine") + loop = asyncio.get_event_loop() + torch.cuda.synchronize() + print(f"try to update parameter from distributed request in engine") + return loop.run_until_complete( + update_parameter_from_distributed_request(obj, None) + ) + + def get_weights_by_parameter_name(self, name, truncate_size=100): + obj = GetParameterByNameReqInput(name=name, truncate_size=truncate_size) + loop = asyncio.get_event_loop() + return loop.run_until_complete(get_weights_by_parameter_name_request(obj, None)) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 4a974e2e754..867b0fa61d7 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -30,6 +30,7 @@ import tempfile import time import warnings +from datetime import timedelta from importlib.metadata import PackageNotFoundError, version from io import BytesIO from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union @@ -38,6 +39,7 @@ import psutil import requests import torch +import torch.distributed import torch.distributed as 
dist import triton import zmq @@ -45,6 +47,16 @@ from packaging import version as pkg_version from starlette.routing import Mount from torch import nn +from torch.distributed.distributed_c10d import ( + Backend, + PrefixStore, + Store, + _new_process_group_helper, + _store_based_barrier, + _world, + default_pg_timeout, + rendezvous, +) from torch.func import functional_call from torch.library import Library from torch.profiler import ProfilerActivity, profile, record_function @@ -934,6 +946,71 @@ def get_nvgpu_memory_capacity(): ) +# Copy from pytorch and OpenRLHF to allow creating multiple main groups. +# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py +def init_custom_process_group( + backend: Union[str, Backend] = None, + init_method: Optional[str] = None, + timeout: Optional[timedelta] = None, + world_size: int = -1, + rank: int = -1, + store: Optional[Store] = None, + group_name: str = None, + pg_options: Optional[Any] = None, +): + assert (store is None) or ( + init_method is None + ), "Cannot specify both init_method and store." + + if store is not None: + assert world_size > 0, "world_size must be positive if using store" + assert rank >= 0, "rank must be non-negative if using store" + elif init_method is None: + init_method = "env://" + + if backend: + backend = Backend(backend) + else: + backend = Backend("undefined") + + if timeout is None: + timeout = default_pg_timeout + + # backward compatible API + if store is None: + rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) + store, rank, world_size = next(rendezvous_iterator) + store.set_timeout(timeout) + + # Use a PrefixStore to avoid accidental overrides of keys used by + # different systems (e.g. RPC) in case the store is multi-tenant. 
+ store = PrefixStore(group_name, store) + + # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0 + # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844 + # We need to determine the appropriate parameter name based on PyTorch version + pg_options_param_name = ( + "backend_options" if str(torch.__version__) >= "2.6" else "pg_options" + ) + pg, _ = _new_process_group_helper( + world_size, + rank, + [], + backend, + store, + group_name=group_name, + **{pg_options_param_name: pg_options}, + timeout=timeout, + ) + + logger.error(f"pg pass in init_custom_process_group world size: {world_size}") + logger.error(f"pg pass in init_custom_process_group rank: {rank}") + + _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} + + return pg + + def crash_on_warnings(): # Crash on warning if we are running CI tests return get_bool_env_var("SGLANG_IS_IN_CI") diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 3089668443e..41906e51a3a 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -424,6 +424,10 @@ def popen_launch_server( port, *other_args, ] + + if api_key: + command += ["--api-key", api_key] + if api_key: command += ["--api-key", api_key] diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 3f55eb25fdc..ba436ffd811 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -11,6 +11,7 @@ "models/test_reward_models.py", "sampling/penaltylib", "test_chunked_prefill.py", + "test_custom_process_group.py", "test_double_sparsity.py", "test_embedding_openai_server.py", "test_eval_accuracy_mini.py", diff --git a/test/srt/stderr.txt b/test/srt/stderr.txt new file mode 100644 index 00000000000..25e5f8d74a1 --- /dev/null +++ b/test/srt/stderr.txt @@ -0,0 +1,447 @@ +/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/site-packages/transformers/utils/hub.py:128: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead. 
+ warnings.warn( +[2024-11-22 03:08:11] server_args=ServerArgs(model_path='meta-llama/Llama-3.1-8B-Instruct', tokenizer_path='meta-llama/Llama-3.1-8B-Instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, context_length=None, device='cuda', served_model_name='meta-llama/Llama-3.1-8B-Instruct', chat_template=None, is_embedding=False, host='127.0.0.1', port=4254, mem_fraction_static=0.88, max_running_requests=None, max_total_tokens=None, chunked_prefill_size=32, max_prefill_tokens=16384, schedule_policy='lpm', schedule_conservativeness=1.0, tp_size=1, stream_interval=1, random_seed=668455020, constrained_json_whitespace_pattern=None, watchdog_timeout=300, download_dir=None, log_level='info', log_level_http=None, log_requests=False, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_pth='SGLang_storage', enable_cache_report=False, dp_size=1, load_balance_method='round_robin', dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, lora_paths=None, max_loras_per_batch=8, attention_backend='flashinfer', sampling_backend='flashinfer', grammar_backend='outlines', disable_radix_cache=False, disable_jump_forward=False, disable_cuda_graph=False, disable_cuda_graph_padding=False, disable_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, disable_overlap_schedule=True, enable_mixed_chunk=True, enable_dp_attention=False, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, num_continuous_decode_steps=1, delete_ckpt_after_loading=False) +/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/site-packages/transformers/utils/hub.py:128: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead. + warnings.warn( +/opt/dlami/nvme/chenyang/miniconda3/envs/sglang/lib/python3.11/site-packages/transformers/utils/hub.py:128: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead. + warnings.warn( +[2024-11-22 03:08:19 TP0] Init torch distributed begin. +[2024-11-22 03:08:20 TP0] Load weight begin. 
avail mem=78.50 GB +[2024-11-22 03:08:21 TP0] lm_eval is not installed, GPTQ may not be usable + Loading safetensors checkpoint shards: 0% Completed | 0/4 [00:00= 2, "At least 2 GPUs are required" + test_suits = [(1, 1, 1, 1), (2, 0, 1, 0), (0, 2, 0, 1)] + + if torch.cuda.device_count() >= 4: + test_suits.extend([(2, 2, 1, 1), (1, 1, 2, 2)]) + + if torch.cuda.device_count() >= 8: + test_suits.append((2, 2, 2, 2)) + + parameters = [ + "model.embed_tokens.weight", + "model.layers.0.input_layernorm.weight", + "model.layers.1.self_attn.q_proj.weight", + "model.layers.2.self_attn.k_proj.weight", + "model.layers.3.self_attn.v_proj.weight", + "model.layers.4.self_attn.o_proj.weight", + "model.layers.5.mlp.gate_proj.weight", + "model.layers.6.mlp.up_proj.weight", + "model.layers.7.mlp.down_proj.weight", + "model.layers.8.post_attention_layernorm.weight", + "model.norm.weight", + "lm_head.weight", + ] + + for test_suit in test_suits: + cls.init_engine_and_server(*test_suit) + for param_name in parameters: + cls.assert_update_weights_all_close(param_name, 100) + cls.close_engine_and_server() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_init_parameter_update_group.py b/test/srt/test_init_parameter_update_group.py new file mode 100644 index 00000000000..016294ee21d --- /dev/null +++ b/test/srt/test_init_parameter_update_group.py @@ -0,0 +1,130 @@ +import os +import time +import unittest + +import requests +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from transformers import AutoModelForCausalLM + +from sglang.srt.utils import init_custom_process_group, kill_child_process +from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + +mp.set_start_method("spawn", force=True) + + +class TestParameterUpdateGroup(unittest.TestCase): + @classmethod + def init_process(cls, rank, world_size, base_url, model_name, server_pid): + try: + # Set up the distributed environment + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "29500" + torch.cuda.set_device(rank) # Use local GPU ID 0, since each process can only see one GPU + + print( + f"[Rank {rank}] Using GPU: {torch.cuda.current_device()} " + f"(CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']})" + ) + + if rank == 0: + print(f"[Rank 0] Starting initialization") + hf_model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda:0") + print(f"[Rank 0] HF model loaded") + + group = init_custom_process_group( + backend="nccl", + init_method="tcp://localhost:29500", + world_size=world_size, + rank=rank, + group_name="test_parameter_update_group", + ) + print(f"[Rank 0] Process group initialized") + print(f"[Rank 0] before barrier") + dist.barrier(group=group) + print(f"[Rank 0] after barrier") + + elif rank == 1: + print(f"[Rank 1] Starting server launch") + process = popen_launch_server( + model_name, + base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=("--base-gpu-id", str(rank)), + ) + server_pid.value = process.pid + print(f"[Rank 1] Server launched with pid {process.pid}") + + response = requests.post( + f"{base_url}/init_parameter_update_group", + json={ + "master_address": "localhost", + "master_port": "29500", + "rank_offset": 1, + "world_size": world_size, + "group_name": "test_parameter_update_group", + "backend": "nccl", + }, + timeout=30, + ) + print( + f"[Rank 1] Parameter update group initialized with response: {response.status_code}" + ) + + print(f"[Rank {rank}] Process initialization 
completed") + + except Exception as e: + print(f"[Rank {rank}] Error occurred: {str(e)}") + raise + + @classmethod + def setUpClass(cls): + cls.world_size = 2 + cls.model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.server_pid = mp.Value("i", 0) + + print("Starting multiprocessing spawn") + mp.spawn( + cls.init_process, + args=( + cls.world_size, + cls.base_url, + cls.model_name, + cls.server_pid, + ), + nprocs=cls.world_size, + join=True, + ) + print("Multiprocessing spawn completed") + + @classmethod + def tearDownClass(cls): + print("Starting teardown") + # 先清理分布式进程组 + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + print("Process group destroyed") + + # 然后清理服务器进程 + if cls.server_pid.value != 0: + print(f"Cleaning up server process {cls.server_pid.value}") + kill_child_process(cls.server_pid.value, include_self=True) + print("Server process cleaned up") + + time.sleep(1) # 给进程一些清理的时间 + + def test_init_parameter_update_group(self): + print( + "Successfully initialized parameter update group between huggingface and SGLang server." + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_metrics.py b/test/srt/test_metrics.py index 163a7cc0e06..94479ab02e9 100644 --- a/test/srt/test_metrics.py +++ b/test/srt/test_metrics.py @@ -1,3 +1,4 @@ +import json import unittest import requests @@ -11,71 +12,99 @@ ) -class TestEnableMetrics(unittest.TestCase): - def test_metrics_enabled(self): - """Test that metrics endpoint returns data when enabled""" - process = popen_launch_server( - DEFAULT_SMALL_MODEL_NAME_FOR_TEST, - DEFAULT_URL_FOR_TEST, +class TestUpdateWeights(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--enable-metrics"], + other_args=("--mem-",), ) - try: - # Make some requests to generate some metrics - response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate") - self.assertEqual(response.status_code, 200) - - response = requests.post( - f"{DEFAULT_URL_FOR_TEST}/generate", - json={ - "text": "The capital of France is", - "sampling_params": { - "temperature": 0, - "max_new_tokens": 32, - }, - "stream": True, + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid, include_self=True) + + def run_decode(self): + response = requests.post( + self.base_url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": 32, }, - stream=True, - ) - for _ in response.iter_lines(decode_unicode=False): - pass - - # Get metrics - metrics_response = requests.get(f"{DEFAULT_URL_FOR_TEST}/metrics") - self.assertEqual(metrics_response.status_code, 200) - metrics_content = metrics_response.text - - print(f"metrics_content=\n{metrics_content}") - - # Verify essential metrics are present - essential_metrics = [ - "sglang:num_running_reqs", - "sglang:token_usage", - "sglang:gen_throughput", - "sglang:cache_hit_rate", - "sglang:func_latency_seconds", - "sglang:prompt_tokens_total", - "sglang:generation_tokens_total", - "sglang:time_to_first_token_seconds", - "sglang:time_per_output_token_seconds", - "sglang:e2e_request_latency_seconds", - ] - - for metric in essential_metrics: - self.assertIn(metric, metrics_content, f"Missing metric: {metric}") - - # Verify model name label is present and correct - 
expected_model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST - self.assertIn(f'model_name="{expected_model_name}"', metrics_content) - - # Verify metrics have values (not empty) - self.assertIn("_sum{", metrics_content) - self.assertIn("_count{", metrics_content) - self.assertIn("_bucket{", metrics_content) - - finally: - kill_child_process(process.pid, include_self=True) + }, + ) + print(json.dumps(response.json())) + print("=" * 100) + text = response.json()["text"] + return text + + def get_model_info(self): + response = requests.get(self.base_url + "/get_model_info") + model_path = response.json()["model_path"] + print(json.dumps(response.json())) + return model_path + + def run_update_weights(self, model_path): + response = requests.post( + self.base_url + "/update_weights_from_disk", + json={ + "model_path": model_path, + }, + ) + ret = response.json() + print(json.dumps(response.json())) + return ret + + def test_update_weights(self): + origin_model_path = self.get_model_info() + print(f"origin_model_path: {origin_model_path}") + origin_response = self.run_decode() + + # update weights + new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "") + ret = self.run_update_weights(new_model_path) + assert ret["success"] + + updated_model_path = self.get_model_info() + print(f"updated_model_path: {updated_model_path}") + assert updated_model_path == new_model_path + assert updated_model_path != origin_model_path + + updated_response = self.run_decode() + assert origin_response[:32] != updated_response[:32] + + # update weights back + ret = self.run_update_weights(origin_model_path) + assert ret["success"] + + updated_model_path = self.get_model_info() + assert updated_model_path == origin_model_path + + updated_response = self.run_decode() + assert origin_response[:32] == updated_response[:32] + + def test_update_weights_unexist_model(self): + origin_model_path = self.get_model_info() + print(f"origin_model_path: {origin_model_path}") + origin_response = self.run_decode() + + # update weights + new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "wrong") + ret = self.run_update_weights(new_model_path) + assert not ret["success"] + + updated_model_path = self.get_model_info() + print(f"updated_model_path: {updated_model_path}") + assert updated_model_path == origin_model_path + + updated_response = self.run_decode() + assert origin_response[:32] == updated_response[:32] if __name__ == "__main__": diff --git a/test/srt/test_update_parameter_from_distributed.py b/test/srt/test_update_parameter_from_distributed.py new file mode 100644 index 00000000000..45c3008d841 --- /dev/null +++ b/test/srt/test_update_parameter_from_distributed.py @@ -0,0 +1,867 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""ModelRunner runs the forward passes of the models.""" + +import gc +import importlib +import importlib.resources +import inspect +import json +import logging +import pkgutil +from functools import lru_cache +from typing import Any, Optional, Type, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from vllm.config import DeviceConfig, LoadConfig +from vllm.config import ModelConfig as VllmModelConfig +from vllm.distributed import ( + get_tp_group, + init_distributed_environment, + initialize_model_parallel, + set_custom_all_reduce, +) +from vllm.distributed.parallel_state import in_the_same_node_as +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.models import ModelRegistry + +from sglang.srt.configs.model_config import AttentionArch, ModelConfig +from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend +from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend +from sglang.srt.layers.attention.triton_backend import TritonAttnBackend +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.sampler import Sampler +from sglang.srt.lora.lora_manager import LoRAManager +from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.mem_cache.memory_pool import ( + DoubleSparseTokenToKVPool, + MHATokenToKVPool, + MLATokenToKVPool, + ReqToTokenPool, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import ( + crash_on_warnings, + enable_show_time_cost, + get_available_gpu_memory, + init_custom_process_group, + is_hip, + monkey_patch_vllm_model_config, + monkey_patch_vllm_p2p_access_check, + set_cpu_offload_max_bytes, +) + +logger = logging.getLogger(__name__) + + +class ModelRunner: + """ModelRunner runs the forward passes of the models.""" + + def __init__( + self, + model_config: ModelConfig, + mem_fraction_static: float, + gpu_id: int, + tp_rank: int, + tp_size: int, + nccl_port: int, + server_args: ServerArgs, + ): + # Parse args + self.model_config = model_config + self.mem_fraction_static = mem_fraction_static + self.device = server_args.device + self.gpu_id = gpu_id + self.tp_rank = tp_rank + self.tp_size = tp_size + self.dist_port = nccl_port + self.server_args = server_args + self.is_generation = model_config.is_generation + self.is_multimodal = model_config.is_multimodal + + # Model-specific adjustment + if ( + self.model_config.attention_arch == AttentionArch.MLA + and not self.server_args.disable_mla + ): + logger.info("MLA optimization is turned on. Use triton backend.") + self.server_args.attention_backend = "triton" + + if self.server_args.enable_double_sparsity: + logger.info( + "Double sparsity optimization is turned on. Use triton backend without CUDA graph." + ) + self.server_args.attention_backend = "triton" + self.server_args.disable_cuda_graph = True + if self.server_args.ds_heavy_channel_type is None: + raise ValueError( + "Please specify the heavy channel type for double sparsity optimization." + ) + self.init_double_sparsity_channel_config( + self.server_args.ds_heavy_channel_type + ) + + if self.is_multimodal: + logger.info( + "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models." 
+            )
+            server_args.chunked_prefill_size = None
+            self.mem_fraction_static *= 0.95
+            # TODO: qwen2-vl does not support radix cache now; set disable_radix_cache=True automatically
+            if self.model_config.hf_config.architectures == [
+                "Qwen2VLForConditionalGeneration"
+            ]:
+                server_args.disable_radix_cache = True
+
+        # Global vars
+        if server_args.show_time_cost:
+            enable_show_time_cost()
+        if server_args.disable_disk_cache:
+            from outlines.caching import disable_cache
+
+            disable_cache()
+
+        global_server_args_dict.update(
+            {
+                "attention_backend": server_args.attention_backend,
+                "sampling_backend": server_args.sampling_backend,
+                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
+                "disable_mla": server_args.disable_mla,
+                "torchao_config": server_args.torchao_config,
+                "enable_nan_detection": server_args.enable_nan_detection,
+                "enable_dp_attention": server_args.enable_dp_attention,
+            }
+        )
+
+        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
+
+        # Init components
+        min_per_gpu_memory = self.init_torch_distributed()
+        self.sampler = Sampler()
+        self.load_model()
+
+        # Apply torch TP if the model supports it
+        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
+        if self.tp_size > 1 and supports_torch_tp:
+            self.apply_torch_tp()
+            self.torch_tp_applied = True
+        else:
+            self.torch_tp_applied = False
+
+        if server_args.lora_paths is not None:
+            self.init_lora_manager()
+        self.init_memory_pool(
+            min_per_gpu_memory,
+            server_args.max_running_requests,
+            server_args.max_total_tokens,
+        )
+        if self.device == "cuda":
+            self.init_cublas()
+            self.init_attention_backend()
+            self.init_cuda_graphs()
+        else:
+            self.cuda_graph_runner = None
+            self.init_attention_backend()
+
+    def init_torch_distributed(self):
+        logger.info("Init torch distributed begin.")
+        # Init torch distributed
+        torch.get_device_module(self.device).set_device(self.gpu_id)
+        if self.device == "cuda":
+            backend = "nccl"
+        # TODO(liangan1): Use gloo for now to bypass the initialization failure.
+        # The xpu backend should switch to xccl in the future.
+        elif self.device == "xpu":
+            backend = "gloo"
+        elif self.device == "hpu":
+            backend = "hccl"
+
+        if not self.server_args.enable_p2p_check:
+            monkey_patch_vllm_p2p_access_check(self.gpu_id)
+        if self.server_args.dist_init_addr:
+            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
+        else:
+            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
+        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
+        init_distributed_environment(
+            backend=backend,
+            world_size=self.tp_size,
+            rank=self.tp_rank,
+            local_rank=self.gpu_id,
+            distributed_init_method=dist_init_method,
+        )
+        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+        min_per_gpu_memory = get_available_gpu_memory(
+            self.device, self.gpu_id, distributed=self.tp_size > 1
+        )
+        self.tp_group = get_tp_group()
+
+        # Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
+        # so we disable padding in cuda graph.
+        if self.device == "cuda" and not all(
+            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
+        ):
+            self.server_args.disable_cuda_graph_padding = True
+            logger.info(
+                "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
+            )
+
+        # Check memory for tensor parallelism
+        if self.tp_size > 1:
+            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
+            if min_per_gpu_memory < local_gpu_memory * 0.9:
+                raise ValueError(
+                    "The memory capacity is unbalanced. 
Some GPUs may be occupied by other processes." + ) + + return min_per_gpu_memory + + def setup_model(self): + try: + from vllm.config import VllmConfig + + vllm_config = VllmConfig() + vllm_config.model_config = self.vllm_model_config + vllm_config.load_config = self.load_config + vllm_config.device_config = DeviceConfig(self.device) + vllm_config.quant_config = VllmConfig._get_quantization_config( + vllm_config.model_config, vllm_config.load_config + ) + return get_model(vllm_config=vllm_config) + except ImportError: + pass + + return get_model( + model_config=self.vllm_model_config, + load_config=self.load_config, + device_config=DeviceConfig(self.device), + parallel_config=None, + scheduler_config=None, + lora_config=None, + cache_config=None, + ) + + def get_model_config_params(self): + sig = inspect.signature(VllmModelConfig.__init__) + params = { + "model": self.server_args.model_path, + "quantization": self.server_args.quantization, + "tokenizer": None, + "tokenizer_mode": None, + "trust_remote_code": self.server_args.trust_remote_code, + "dtype": self.server_args.dtype, + "seed": self.server_args.random_seed, + "skip_tokenizer_init": True, + } + + if "task" in sig.parameters: + params["task"] = "" + + return params + + def load_model(self): + logger.info( + f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" + ) + + # This can reduce thread conflicts and speed up weight loading. + torch.set_num_threads(1) + if self.device == "cuda": + if torch.cuda.get_device_capability()[0] < 8: + logger.info( + "Compute capability below sm80. Use float16 due to lack of bfloat16 support." + ) + self.server_args.dtype = "float16" + if torch.cuda.get_device_capability()[1] < 5: + raise RuntimeError("SGLang only supports sm75 and above.") + + # Prepare the vllm model config + self.load_config = LoadConfig( + load_format=self.server_args.load_format, + download_dir=self.server_args.download_dir, + ) + monkey_patch_vllm_model_config() + self.vllm_model_config = VllmModelConfig(**self.get_model_config_params()) + if self.model_config.model_override_args is not None: + self.vllm_model_config.hf_config.update( + self.model_config.model_override_args + ) + + self.model = self.setup_model() + + self.sliding_window_size = ( + self.model.get_attention_sliding_window_size() + if hasattr(self.model, "get_attention_sliding_window_size") + else None + ) + self.dtype = self.vllm_model_config.dtype + + logger.info( + f"Load weight end. " + f"type={type(self.model).__name__}, " + f"dtype={self.dtype}, " + f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" + ) + + def update_weights_from_disk(self, model_path: str, load_format: str): + """Update engine weights online from disk.""" + from vllm.model_executor.model_loader.loader import ( + DefaultModelLoader, + device_loading_context, + get_model_loader, + ) + from vllm.model_executor.model_loader.utils import set_default_torch_dtype + + logger.info( + f"Update engine weights online from disk begin. " + f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" + ) + + target_device = torch.device(self.device) + + try: + model_config_params = self.get_model_config_params() + model_config_params["model"] = model_path + vllm_model_config = VllmModelConfig(**model_config_params) + except Exception as e: + message = f"Failed to load model config: {e}." 
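+            # NOTE: the (success, message) pair returned below is what the
+            # /update_weights_from_disk endpoint reports back to the client
+            # (see test/srt/test_update_weights.py), so the reason for a rejected
+            # update is visible to the caller.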
+ return False, message + + load_config = LoadConfig(load_format=load_format) + + # Only support vllm DefaultModelLoader for now + loader = get_model_loader(load_config) + if not isinstance(loader, DefaultModelLoader): + message = f"Failed to get model loader: {loader}." + return False, message + + def get_weight_iter(config): + iter = loader._get_weights_iterator( + DefaultModelLoader.Source( + config.model, + revision=config.revision, + fall_back_to_pt=getattr( + self.model, "fall_back_to_pt_during_load", True + ), + ) + ) + return iter + + def model_load_weights(model, iter): + model.load_weights(iter) + for _, module in self.model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) + return model + + with set_default_torch_dtype(vllm_model_config.dtype): + try: + iter = get_weight_iter(vllm_model_config) + except Exception as e: + message = f"Failed to get weights iterator: {e}." + return False, message + try: + model = model_load_weights(self.model, iter) + except Exception as e: + message = ( + f"Failed to update weights: {e}.\nRolling back to original weights." + ) + del iter + gc.collect() + iter = get_weight_iter(self.vllm_model_config) + self.model = model_load_weights(self.model, iter) + return False, message + + self.model = model + self.server_args.model_path = model_path + self.server_args.load_format = load_format + self.vllm_model_config = vllm_model_config + self.load_config = load_config + self.model_config.path = model_path + + logger.info("Update weights end.") + return True, "Succeeded to update model weights." + + def init_parameter_update_group( + self, + master_address, + master_port, + rank_offset, + world_size, + group_name, + backend="nccl", + ): + """Initialize the Torch process group for model parameter updates. + + `_model_update_group` is used in the RLHF workflow, where rank 0 is the actor model in + the training engine, and the other ranks are the inference engine, which is used for rollout. + + In the RLHF workflow, the training engine updates the model weights/parameters online, + and broadcasts them to the inference engine through the `_model_update_group` process group. + """ + assert ( + torch.distributed.is_initialized() + ), "Default torch process group must be initialized" + assert group_name != "", "Group name cannot be empty" + + rank = rank_offset + self.tp_rank + + logger.info( + f"init custom process group: master_address={master_address}, master_port={master_port}, " + f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}" + ) + + try: + self._model_update_group = init_custom_process_group( + backend=backend, + init_method=f"tcp://{master_address}:{master_port}", + world_size=world_size, + rank=rank, + group_name=group_name, + ) + + return True, "Succeeded to initialize custom process group." + except Exception as e: + message = f"Failed to initialize custom process group: {e}." 
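+            # A failure here usually means the TCP rendezvous at master_address:master_port
+            # did not complete, e.g. the training side (rank 0, the actor) never created a
+            # matching group with the same world_size/backend. Once the group exists, the
+            # trainer pushes updated weights with collectives over it, conceptually
+            # (illustrative sketch only; the names below are assumptions):
+            #   for name, p in actor_model.named_parameters():
+            #       torch.distributed.broadcast(p.data, src=0, group=model_update_group)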
+            logger.error(message)
+            return False, message
+
+    def get_weights_by_parameter_name(
+        self, name: str, truncate_size: int = 100
+    ) -> Optional[torch.Tensor]:
+        try:
+            # Check whether the name refers to a fused (stacked) parameter
+            mapped_name = name
+            mapped_shard_id = None
+            for param_name, weight_name, shard_id in self.model.stacked_params_mapping:
+                if weight_name in name:
+                    mapped_name = name.replace(weight_name, param_name)
+                    mapped_shard_id = shard_id
+                    break
+            params_dict = dict(self.model.named_parameters())
+            if mapped_name in params_dict:
+                param = params_dict[mapped_name]
+                if mapped_shard_id is not None:
+                    # Handle fused parameters
+                    if mapped_shard_id in ["q", "k", "v"]:
+                        # Compute the offset and size of this shard inside the fused qkv_proj weight
+                        num_heads = (
+                            self.model.config.num_attention_heads // self.tp_size
+                        )
+                        num_kv_heads = (
+                            self.model.config.num_key_value_heads // self.tp_size
+                        )
+                        head_dim = (
+                            self.model.config.hidden_size
+                            // self.model.config.num_attention_heads
+                        )
+
+                        if mapped_shard_id == "q":
+                            offset = 0
+                            size = num_heads * head_dim
+                        elif mapped_shard_id == "k":
+                            offset = num_heads * head_dim
+                            size = num_kv_heads * head_dim
+                        elif mapped_shard_id == "v":
+                            offset = (num_heads + num_kv_heads) * head_dim
+                            size = num_kv_heads * head_dim
+
+                        # Slice out the corresponding part of the weight
+                        weight = param.data.narrow(0, offset, size)
+                    elif mapped_shard_id in [0, 1]:
+                        # Handle the fused gate_up_proj case
+                        intermediate_size = self.model.config.intermediate_size
+                        hidden_size = self.model.config.hidden_size
+
+                        if mapped_shard_id == 0:  # gate_proj
+                            offset = 0
+                            size = intermediate_size
+                        elif mapped_shard_id == 1:  # up_proj
+                            offset = intermediate_size
+                            size = intermediate_size
+
+                        # Slice out the corresponding part of the weight
+                        weight = param.data.narrow(0, offset, size)
+                    else:
+                        weight = param.data
+                else:
+                    weight = param.data
+
+                # Convert to float32 on CPU and truncate
+                return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
+            else:
+                logger.warning(
+                    f"Parameter {name} (mapped to {mapped_name}) not found in model"
+                )
+                return None
+
+        except Exception as e:
+            logger.error(f"Error when getting parameter {name}: {e}")
+            return None
+
+    def init_lora_manager(self):
+        self.lora_manager = LoRAManager(
+            base_model=self.model,
+            lora_paths=self.server_args.lora_paths,
+            base_hf_config=self.model_config.hf_config,
+            max_loras_per_batch=self.server_args.max_loras_per_batch,
+            load_config=self.load_config,
+            dtype=self.dtype,
+        )
+        logger.info("LoRA manager ready.")
+
+    def profile_max_num_token(self, total_gpu_memory: int):
+        available_gpu_memory = get_available_gpu_memory(
+            self.device, self.gpu_id, distributed=self.tp_size > 1
+        )
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and not self.server_args.disable_mla
+        ):
+            cell_size = (
+                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
+                * self.model_config.num_hidden_layers
+                * torch._utils._element_size(self.kv_cache_dtype)
+            )
+        else:
+            cell_size = (
+                self.model_config.get_num_kv_heads(self.tp_size)
+                * self.model_config.head_dim
+                * self.model_config.num_hidden_layers
+                * 2
+                * torch._utils._element_size(self.kv_cache_dtype)
+            )
+        rest_memory = available_gpu_memory - total_gpu_memory * (
+            1 - self.mem_fraction_static
+        )
+        max_num_token = int(rest_memory * (1 << 30) // cell_size)
+        return max_num_token
+
+    def init_memory_pool(
+        self,
+        total_gpu_memory: int,
+        max_num_reqs: Optional[int] = None,
+        max_total_tokens: Optional[int] = None,
+    ):
+        if self.server_args.kv_cache_dtype == "auto":
+            self.kv_cache_dtype = self.dtype
+        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
+            if is_hip():  # Using natively supported format
+                self.kv_cache_dtype = 
torch.float8_e5m2fnuz + else: + self.kv_cache_dtype = torch.float8_e5m2 + else: + raise ValueError( + f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}." + ) + + self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory) + if max_total_tokens is not None: + if max_total_tokens > self.max_total_num_tokens: + logging.warning( + f"max_total_tokens={max_total_tokens} is larger than the profiled value " + f"{self.max_total_num_tokens}. " + f"Use the profiled value instead." + ) + self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens) + + if self.max_total_num_tokens <= 0: + raise RuntimeError( + "Not enough memory. Please try to increase --mem-fraction-static." + ) + + if max_num_reqs is None: + max_num_reqs = min( + max( + int( + self.max_total_num_tokens / self.model_config.context_len * 512 + ), + 2048, + ), + 4096, + ) + + self.req_to_token_pool = ReqToTokenPool( + size=max_num_reqs + 1, + max_context_len=self.model_config.context_len + 4, + device=self.device, + use_records=False, + ) + if ( + self.model_config.attention_arch == AttentionArch.MLA + and not self.server_args.disable_mla + ): + self.token_to_kv_pool = MLATokenToKVPool( + self.max_total_num_tokens, + dtype=self.kv_cache_dtype, + kv_lora_rank=self.model_config.kv_lora_rank, + qk_rope_head_dim=self.model_config.qk_rope_head_dim, + layer_num=self.model_config.num_hidden_layers, + device=self.device, + ) + elif self.server_args.enable_double_sparsity: + self.token_to_kv_pool = DoubleSparseTokenToKVPool( + self.max_total_num_tokens, + dtype=self.kv_cache_dtype, + head_num=self.model_config.get_num_kv_heads(self.tp_size), + head_dim=self.model_config.head_dim, + layer_num=self.model_config.num_hidden_layers, + device=self.device, + heavy_channel_num=self.server_args.ds_heavy_channel_num, + ) + else: + self.token_to_kv_pool = MHATokenToKVPool( + self.max_total_num_tokens, + dtype=self.kv_cache_dtype, + head_num=self.model_config.get_num_kv_heads(self.tp_size), + head_dim=self.model_config.head_dim, + layer_num=self.model_config.num_hidden_layers, + device=self.device, + ) + logger.info( + f"Memory pool end. " + f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" + ) + + def init_cublas(self): + """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later.""" + dtype = torch.float16 + device = "cuda" + a = torch.ones((16, 16), dtype=dtype, device=device) + b = torch.ones((16, 16), dtype=dtype, device=device) + c = a @ b + return c + + def init_attention_backend(self): + """Init attention kernel backend.""" + if self.server_args.attention_backend == "flashinfer": + self.attn_backend = FlashInferAttnBackend(self) + elif self.server_args.attention_backend == "triton": + assert self.sliding_window_size is None, ( + "Window attention is not supported in the triton attention backend. " + "Please use `--attention-backend flashinfer`." + ) + assert not self.model_config.is_encoder_decoder, ( + "Cross attention is not supported in the triton attention backend. " + "Please use `--attention-backend flashinfer`." + ) + if self.server_args.enable_double_sparsity: + self.attn_backend = DoubleSparseAttnBackend(self) + else: + self.attn_backend = TritonAttnBackend(self) + else: + raise ValueError( + f"Invalid attention backend: {self.server_args.attention_backend}" + ) + + def init_double_sparsity_channel_config(self, selected_channel): + + selected_channel = "." 
+ selected_channel + "_proj"
+        self.sorted_channels = []
+        # load channel config
+        with open(self.server_args.ds_channel_config_path, "r") as f:
+            channel_config = json.load(f)
+
+        for i in range(self.model_config.num_hidden_layers):
+            key = "model.layers." + str(i) + ".self_attn" + selected_channel
+            self.sorted_channels.append(
+                torch.tensor(channel_config[key])[
+                    :, : self.server_args.ds_heavy_channel_num
+                ]
+                .contiguous()
+                .cuda()
+            )
+
+    def init_cuda_graphs(self):
+        """Capture cuda graphs."""
+        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
+
+        self.cuda_graph_runner = None
+
+        if not self.is_generation:
+            # TODO: Currently, cuda graph only captures decode steps, which only exist for generation models
+            return
+
+        if self.server_args.disable_cuda_graph:
+            return
+
+        logger.info("Capture cuda graph begin. This can take up to several minutes.")
+        self.cuda_graph_runner = CudaGraphRunner(self)
+
+    def apply_torch_tp(self):
+        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
+        from sglang.srt.model_parallel import tensor_parallel
+
+        device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
+        tensor_parallel(self.model, device_mesh)
+
+    def forward_decode(self, forward_batch: ForwardBatch):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
+            return self.cuda_graph_runner.replay(forward_batch)
+
+        forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
+        self.attn_backend.init_forward_metadata(forward_batch)
+        return self.model.forward(
+            forward_batch.input_ids, forward_batch.positions, forward_batch
+        )
+
+    def forward_extend(self, forward_batch: ForwardBatch):
+        self.attn_backend.init_forward_metadata(forward_batch)
+        if self.is_generation:
+            if forward_batch.input_embeds is None:
+                return self.model.forward(
+                    forward_batch.input_ids, forward_batch.positions, forward_batch
+                )
+            else:
+                return self.model.forward(
+                    forward_batch.input_ids,
+                    forward_batch.positions,
+                    forward_batch,
+                    input_embeds=forward_batch.input_embeds.bfloat16(),
+                )
+        else:
+            # Only embedding models have the get_embedding parameter
+            return self.model.forward(
+                forward_batch.input_ids,
+                forward_batch.positions,
+                forward_batch,
+                get_embedding=True,
+            )
+
+    def forward_idle(self, forward_batch: ForwardBatch):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
+            return self.cuda_graph_runner.replay(forward_batch)
+
+        return self.model.forward(
+            forward_batch.input_ids, forward_batch.positions, forward_batch
+        )
+
+    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
+        if forward_batch.forward_mode.is_decode():
+            return self.forward_decode(forward_batch)
+        elif forward_batch.forward_mode.is_extend():
+            return self.forward_extend(forward_batch)
+        elif forward_batch.forward_mode.is_idle():
+            return self.forward_idle(forward_batch)
+        else:
+            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
+
+    def sample(
+        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        sampling_info = forward_batch.sampling_info
+        if sampling_info.sampling_info_done:
+            # Overlap mode: the function update_regex_vocab_mask was executed
+            # in process_batch_result of the last batch.
+            if sampling_info.grammars:
+                sampling_info.sampling_info_done.wait()
+        else:
+            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
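+            # (Here that means rebuilding the grammar/regex vocab mask and the penalty
+            # tensors consumed by apply_logits_bias below.)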
+ sampling_info.update_regex_vocab_mask() + sampling_info.update_penalties() + logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info) + + # Sample the next tokens. + next_token_ids = self.sampler(logits, sampling_info) + return next_token_ids + + def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo): + # Apply logit_bias + if sampling_info.logit_bias is not None: + logits.add_(sampling_info.logit_bias) + + # min-token, presence, frequency + if sampling_info.linear_penalties is not None: + logits.add_(sampling_info.linear_penalties) + + # repetition + if sampling_info.scaling_penalties is not None: + logits = torch.where( + logits > 0, + logits / sampling_info.scaling_penalties, + logits * sampling_info.scaling_penalties, + ) + + # Apply regex vocab_mask + if sampling_info.vocab_mask is not None: + sampling_info.apply_mask(logits=logits, vocab_mask=sampling_info.vocab_mask) + + return logits + + @property + def model_is_mrope(self) -> bool: + """Detect if the model has "mrope" rope_scaling type. + mrope requires keep "rope_deltas" between prompt and decoding phases.""" + rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {}) + if rope_scaling is None: + return False + return rope_scaling.get("type", None) == "mrope" + + +@lru_cache() +def import_model_classes(): + model_arch_name_to_cls = {} + package_name = "sglang.srt.models" + package = importlib.import_module(package_name) + for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."): + if not ispkg: + try: + module = importlib.import_module(name) + except Exception as e: + logger.warning(f"Ignore import error when loading {name}. {e}") + if crash_on_warnings(): + raise ValueError(f"Ignore import error when loading {name}. {e}") + continue + if hasattr(module, "EntryClass"): + entry = module.EntryClass + if isinstance( + entry, list + ): # To support multiple model classes in one module + for tmp in entry: + assert ( + tmp.__name__ not in model_arch_name_to_cls + ), f"Duplicated model implementation for {tmp.__name__}" + model_arch_name_to_cls[tmp.__name__] = tmp + else: + assert ( + entry.__name__ not in model_arch_name_to_cls + ), f"Duplicated model implementation for {entry.__name__}" + model_arch_name_to_cls[entry.__name__] = entry + + return model_arch_name_to_cls + + +def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]: + model_arch_name_to_cls = import_model_classes() + + if model_arch not in model_arch_name_to_cls: + raise ValueError( + f"Unsupported architectures: {model_arch}. 
" + f"Supported list: {list(model_arch_name_to_cls.keys())}" + ) + return model_arch_name_to_cls[model_arch] + + +# Monkey patch model loader +setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt) +setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False) +setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False) +setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False) +setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False) diff --git a/test/srt/test_update_weights.py b/test/srt/test_update_weights.py index 327da729aad..2f53f12dbd7 100644 --- a/test/srt/test_update_weights.py +++ b/test/srt/test_update_weights.py @@ -49,7 +49,7 @@ def get_model_info(self): def run_update_weights(self, model_path): response = requests.post( - self.base_url + "/update_weights", + self.base_url + "/update_weights_from_disk", json={ "model_path": model_path, },