From 0a24eb850a8ed690c5ae5f3cbffc10f2b0c1c42e Mon Sep 17 00:00:00 2001
From: Byron Hsu
Date: Mon, 28 Oct 2024 12:02:23 -0700
Subject: [PATCH 1/7] Fix update_weights deadlock for DP (#1825)

---
 .../sglang/srt/managers/tokenizer_manager.py | 57 ++++++++++++++-----
 test/srt/test_data_parallelism.py            | 23 ++++++++
 2 files changed, 67 insertions(+), 13 deletions(-)

diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 347e7ad1d01..428bf10d75f 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -554,18 +554,43 @@ async def update_weights(
             obj.load_format = self.server_args.load_format

         if not self.model_update_lock.locked():
-            async with self.model_update_lock:
-                # wait for the previous generation requests to finish
-                while len(self.rid_to_state) > 0:
-                    await asyncio.sleep(0.001)
-                self.send_to_scheduler.send_pyobj(obj)
-                self.model_update_result = asyncio.Future()
-                result = await self.model_update_result
-                if result.success:
-                    self.server_args.model_path = obj.model_path
-                    self.server_args.load_format = obj.load_format
-                    self.model_path = obj.model_path
-                return result.success, result.message
+
+            if self.server_args.dp_size == 1:
+                async with self.model_update_lock:
+                    # wait for the previous generation requests to finish
+                    while len(self.rid_to_state) > 0:
+                        await asyncio.sleep(0.001)
+                    self.send_to_scheduler.send_pyobj(obj)
+                    self.model_update_result = asyncio.Future()
+                    result = await self.model_update_result
+                    if result.success:
+                        self.server_args.model_path = obj.model_path
+                        self.server_args.load_format = obj.load_format
+                        self.model_path = obj.model_path
+                    return result.success, result.message
+
+            else:  # self.server_args.dp_size > 1
+
+                # There will be dp_size number of responses from the detokenizer
+                async with self.model_update_lock:
+                    # wait for the previous generation requests to finish
+                    while len(self.rid_to_state) > 0:
+                        await asyncio.sleep(0.001)
+                    self.send_to_scheduler.send_pyobj(obj)
+                    self.model_update_result = asyncio.Future()
+                    self.model_update_tmp = []
+                    result = await self.model_update_result
+
+                    all_success = all([r.success for r in result])
+                    if all_success is True:
+                        self.server_args.model_path = obj.model_path
+                        self.server_args.load_format = obj.load_format
+                        self.model_path = obj.model_path
+                    all_message = [r.message for r in result]
+                    all_message = " | ".join(all_message)
+
+                    return all_success, all_message
+
         else:
             return False, "Another update is in progress. Please try again later."
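For readers skimming the patch: the `dp_size > 1` branch above avoids the deadlock because the detokenizer now forwards one `UpdateWeightReqOutput` per data-parallel rank, and the tokenizer manager only resolves its future after collecting all of them. A minimal, self-contained sketch of that fan-in pattern follows; the class and queue names are illustrative and are not SGLang's actual API.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class UpdateResult:
    success: bool
    message: str


async def wait_for_all_ranks(result_queue: asyncio.Queue, dp_size: int):
    """Collect exactly one result per data-parallel rank, then aggregate."""
    results = []
    while len(results) < dp_size:
        results.append(await result_queue.get())
    all_success = all(r.success for r in results)
    all_message = " | ".join(r.message for r in results)
    return all_success, all_message


async def main():
    queue: asyncio.Queue = asyncio.Queue()
    for rank in range(4):  # pretend the detokenizer forwarded 4 rank replies
        queue.put_nowait(UpdateResult(True, f"rank {rank} ok"))
    print(await wait_for_all_ranks(queue, dp_size=4))


if __name__ == "__main__":
    asyncio.run(main())
```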
@@ -600,7 +625,13 @@ async def handle_loop(self):
                 ] = await self.recv_from_detokenizer.recv_pyobj()

             if isinstance(recv_obj, UpdateWeightReqOutput):
-                self.model_update_result.set_result(recv_obj)
+                if self.server_args.dp_size == 1:
+                    self.model_update_result.set_result(recv_obj)
+                else:  # self.server_args.dp_size > 1
+                    self.model_update_tmp.append(recv_obj)
+                    # set the future if all the results are received
+                    if len(self.model_update_tmp) == self.server_args.dp_size:
+                        self.model_update_result.set_result(self.model_update_tmp)
                 continue
             elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
                 self.mem_pool_size.set_result(recv_obj)
diff --git a/test/srt/test_data_parallelism.py b/test/srt/test_data_parallelism.py
index 5f17994a2d5..00bae0a880b 100644
--- a/test/srt/test_data_parallelism.py
+++ b/test/srt/test_data_parallelism.py
@@ -1,6 +1,9 @@
+import time
 import unittest
 from types import SimpleNamespace

+import requests
+
 from sglang.srt.utils import kill_child_process
 from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
@@ -39,6 +42,26 @@ def test_mmlu(self):
         metrics = run_eval(args)
         assert metrics["score"] >= 0.65

+    def test_update_weight(self):
+        response = requests.post(
+            self.base_url + "/update_weights",
+            json={"model_path": DEFAULT_MODEL_NAME_FOR_TEST},
+        )
+
+        # check if the response is 200
+        assert response.status_code == 200
+
+        # pause a few seconds then send again
+        time.sleep(5)
+
+        response = requests.post(
+            self.base_url + "/update_weights",
+            json={"model_path": DEFAULT_MODEL_NAME_FOR_TEST},
+        )
+
+        # check if the response is 200
+        assert response.status_code == 200
+

 if __name__ == "__main__":
     unittest.main()

From 680cad20233be46da97e92db0ba29d2b8fa41c03 Mon Sep 17 00:00:00 2001
From: Byron Hsu
Date: Mon, 28 Oct 2024 23:07:14 -0700
Subject: [PATCH 2/7] fix get_memory_pool_size deadlock for DP (#1830)

---
 .../sglang/srt/managers/tokenizer_manager.py | 27 ++++++++++++++++---
 python/sglang/srt/server.py                  |  3 ++-
 test/srt/test_data_parallelism.py            |  9 +++++++
 3 files changed, 34 insertions(+), 5 deletions(-)

diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 428bf10d75f..9a3e9096952 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -539,9 +539,22 @@ async def get_memory_pool_size(self):
             self.create_handle_loop()

         req = GetMemPoolSizeReq()
-        self.send_to_scheduler.send_pyobj(req)
-        self.mem_pool_size = asyncio.Future()
-        return await self.mem_pool_size
+        ret = None
+
+        if self.server_args.dp_size == 1:
+            self.send_to_scheduler.send_pyobj(req)
+            self.mem_pool_size = asyncio.Future()
+            res = await self.mem_pool_size
+            ret = res.size
+
+        else:  # self.server_args.dp_size > 1
+            self.send_to_scheduler.send_pyobj(req)
+            self.mem_pool_size = asyncio.Future()
+            self.mem_pool_size_tmp = []
+            res = await self.mem_pool_size
+            ret = [r.size for r in res]
+
+        return ret

     async def update_weights(
         self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
@@ -634,7 +647,13 @@ async def handle_loop(self):
                         self.model_update_result.set_result(self.model_update_tmp)
                 continue
             elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
-                self.mem_pool_size.set_result(recv_obj)
+                if self.server_args.dp_size == 1:
+                    self.mem_pool_size.set_result(recv_obj)
+                else:  # self.server_args.dp_size > 1
+                    self.mem_pool_size_tmp.append(recv_obj)
+                    # set the future if all the results are received
+                    if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
+                        self.mem_pool_size.set_result(self.mem_pool_size_tmp)
                 continue

             assert isinstance(
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 64f6c6f5504..c9d9c7ee563 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -177,7 +177,8 @@ async def get_memory_pool_size():
     """Get the memory pool size in number of tokens"""
     try:
         ret = await tokenizer_manager.get_memory_pool_size()
-        return ret.size
+
+        return ret
     except Exception as e:
         return JSONResponse(
             {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
diff --git a/test/srt/test_data_parallelism.py b/test/srt/test_data_parallelism.py
index 00bae0a880b..0ac8b784c39 100644
--- a/test/srt/test_data_parallelism.py
+++ b/test/srt/test_data_parallelism.py
@@ -62,6 +62,15 @@ def test_update_weight(self):
         # check if the response is 200
         assert response.status_code == 200

+    def test_get_memory_pool_size(self):
+        response = requests.get(self.base_url + "/get_memory_pool_size")
+        assert response.status_code == 200
+
+        time.sleep(5)
+
+        response = requests.get(self.base_url + "/get_memory_pool_size")
+        assert response.status_code == 200
+

 if __name__ == "__main__":
     unittest.main()

From 5e6c32657e384b023faf03d79e06f7727feedb7c Mon Sep 17 00:00:00 2001
From: Yanyi Liu
Date: Tue, 29 Oct 2024 14:51:47 +0800
Subject: [PATCH 3/7] Support setting `use_thread` in the `run_program` for easier debugging. (#1823)

Co-authored-by: Byron Hsu
---
 python/sglang/lang/interpreter.py | 10 +++++++++-
 python/sglang/lang/ir.py          | 11 ++++++++++-
 2 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py
index 44ea17f666e..55a20336bc7 100644
--- a/python/sglang/lang/interpreter.py
+++ b/python/sglang/lang/interpreter.py
@@ -54,7 +54,14 @@ def run_internal(state, program, func_args, func_kwargs, sync):


 def run_program(
-    program, backend, func_args, func_kwargs, default_sampling_para, stream, sync=False
+    program,
+    backend,
+    func_args,
+    func_kwargs,
+    default_sampling_para,
+    stream,
+    sync=False,
+    use_thread=True,
 ):
     if hasattr(backend, "endpoint"):
         backend = backend.endpoint
@@ -67,6 +74,7 @@ def run_program(
         chat_template=None,
         stream=stream,
         num_api_spec_tokens=program.num_api_spec_tokens,
+        use_thread=use_thread,
     )
     state = ProgramState(stream_executor)

diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py
index 8164478ed39..d3c010108e3 100644
--- a/python/sglang/lang/ir.py
+++ b/python/sglang/lang/ir.py
@@ -168,6 +168,7 @@ def run(
         return_text_in_logprobs: Optional[bool] = None,
         stream: bool = False,
         backend=None,
+        use_thread: bool = True,
         **kwargs,
     ):
         from sglang.lang.interpreter import run_program
@@ -195,7 +196,15 @@ def run(
             return_text_in_logprobs=return_text_in_logprobs,
         )
         backend = backend or global_config.default_backend
-        return run_program(self, backend, args, kwargs, default_sampling_para, stream)
+        return run_program(
+            self,
+            backend,
+            args,
+            kwargs,
+            default_sampling_para,
+            stream,
+            use_thread=use_thread,
+        )

     def run_batch(
         self,

From 5010e0d2ca87716c872b6c78c0c754128812bd90 Mon Sep 17 00:00:00 2001
From: HAI
Date: Tue, 29 Oct 2024 10:51:02 -0700
Subject: [PATCH 4/7] [3rdparty, document] Add 3rdparty/amd, with profiling and tuning instructions to be added (#1822)

---
 3rdparty/amd/profiling/PROFILING.md | 10 ++++++++++
 3rdparty/amd/tuning/TUNING.md       | 13 +++++++++++++
 2 files changed, 23 insertions(+)
 create mode 100644 3rdparty/amd/profiling/PROFILING.md
 create mode 100644 3rdparty/amd/tuning/TUNING.md

diff --git a/3rdparty/amd/profiling/PROFILING.md b/3rdparty/amd/profiling/PROFILING.md
new file mode 100644
index 00000000000..de0c2ef71af
--- /dev/null
+++ b/3rdparty/amd/profiling/PROFILING.md
@@ -0,0 +1,10 @@
+## Profiling SGLang Infer System with AMD GPUs
+This AppNote describes the SGLang profiling techniques, code augmentation, and running steps for systems with AMD Instinct GPUs; the same procedure may also work with Nvidia GPUs.
+Examples and steps are provided in detail to make them easy to reproduce and to help localize performance problems for optimization.
+Two primary methods are covered:
+- [RPD](https://github.com/ROCm/rocmProfileData.git)
+
+
+- [Torch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)
+
+
diff --git a/3rdparty/amd/tuning/TUNING.md b/3rdparty/amd/tuning/TUNING.md
new file mode 100644
index 00000000000..00f995ebbc7
--- /dev/null
+++ b/3rdparty/amd/tuning/TUNING.md
@@ -0,0 +1,13 @@
+## Tuning SGLang Infer System with AMD GPUs
+This AppNote describes the SGLang performance tuning techniques, code harness, and running steps for systems with AMD Instinct GPUs.
+Harness code, examples, and steps are provided in detail to make it easy to reproduce the results and tune performance for your workloads.
+Three primary runtime areas are covered:
+- Triton Kernels
+
+
+- Torch Tunable Ops
+
+
+- Torch Compile
+
+
From d04899d7ca645671335db6876758f0062f239ebc Mon Sep 17 00:00:00 2001
From: yizhang2077 <1109276519@qq.com>
Date: Wed, 30 Oct 2024 04:30:41 +0800
Subject: [PATCH 5/7] stop_str of qwen2-vl template should be a tuple not a str (#1834)

Co-authored-by: Byron Hsu
---
 python/sglang/lang/chat_template.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py
index b8f9a533dee..ca5a7a26184 100644
--- a/python/sglang/lang/chat_template.py
+++ b/python/sglang/lang/chat_template.py
@@ -144,7 +144,7 @@ def get_chat_template_by_model_path(model_path):
             "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
         },
         style=ChatTemplateStyle.PLAIN,
-        stop_str=("<|im_end|>"),
+        stop_str=("<|im_end|>",),
         image_token="<|vision_start|><|image_pad|><|vision_end|>",
     )
 )

From 54dd3ea12277f782823c8067ed723279136c40bb Mon Sep 17 00:00:00 2001
From: HAI
Date: Tue, 29 Oct 2024 13:58:03 -0700
Subject: [PATCH 6/7] =?UTF-8?q?[FP8=20KV=20Cache,=20Mixtral]=20Avoid=20Key?=
 =?UTF-8?q?Error=20at=20loading=20pre-quantized=20FP8=20m=E2=80=A6=20(#183?=
 =?UTF-8?q?5)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 python/sglang/srt/models/mixtral.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py
index 6ad8023675e..dc4198b5245 100644
--- a/python/sglang/srt/models/mixtral.py
+++ b/python/sglang/srt/models/mixtral.py
@@ -369,6 +369,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip loading kv_scale from ckpts towards new design.
+                if name.endswith(".kv_scale") and name not in params_dict:
+                    continue
                 if name is None:
                     continue


From 5e00ddebc09e6919996e55407c45d89c50d6c522 Mon Sep 17 00:00:00 2001
From: DanielC12321 <73292458+DanielC12321@users.noreply.github.com>
Date: Tue, 29 Oct 2024 19:52:33 -0500
Subject: [PATCH 7/7] Add new model: Gpt2 (#1833)

---
 python/sglang/srt/models/gpt2.py          | 286 ++++++++++++++++++++++
 test/srt/models/test_generation_models.py |   1 +
 2 files changed, 287 insertions(+)
 create mode 100644 python/sglang/srt/models/gpt2.py

diff --git a/python/sglang/srt/models/gpt2.py b/python/sglang/srt/models/gpt2.py
new file mode 100644
index 00000000000..a5848210308
--- /dev/null
+++ b/python/sglang/srt/models/gpt2.py
@@ -0,0 +1,286 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
+# Copyright 2023 The vLLM team.
+# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only GPT-2 model compatible with HuggingFace weights."""
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import GPT2Config
+from vllm.config import CacheConfig
+from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+#from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+
+class GPT2Attention(nn.Module):
+
+    def __init__(
+        self,
+        layer_id: int,
+        config: GPT2Config,
+        cache_config = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        total_num_heads = config.num_attention_heads
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
+        assert total_num_heads % tensor_model_parallel_world_size == 0
+        self.num_heads = total_num_heads // tensor_model_parallel_world_size
+        self.head_dim = self.hidden_size // total_num_heads
+        self.scale = self.head_dim**-0.5
+
+        self.c_attn = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            total_num_heads,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.c_attn",
+        )
+        self.c_proj = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
+        )
+        self.attn = RadixAttention(self.num_heads,
+                                   self.head_dim,
+                                   scaling=self.scale,
+                                   num_kv_heads=total_num_heads,
+                                   layer_id=layer_id)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.c_attn(hidden_states)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        attn_output = self.attn(q, k, v, forward_batch)
+        attn_output, _ = self.c_proj(attn_output)
+        return attn_output
+
+
+class GPT2MLP(nn.Module):
+
+    def __init__(
+        self,
+        intermediate_size: int,
+        config: GPT2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+        self.c_fc = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.c_fc",
+        )
+        self.c_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
+        )
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              intermediate_size)
+
+    def forward(self, hidden_states: torch.Tensor,) -> torch.Tensor:
+        hidden_states, _ = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states, _ = self.c_proj(hidden_states)
+        return hidden_states
+
+
+class GPT2Block(nn.Module):
+
+    def __init__(
+        self,
+        layer_id: int,
+        config: GPT2Config,
+        cache_config = None,
+
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
+                     hidden_size)
+
+        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.attn = GPT2Attention(layer_id,
+                                  config,
+                                  cache_config,
+                                  quant_config,
+                                  prefix=f"{prefix}.attn")
+        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.mlp = GPT2MLP(inner_dim,
+                           config,
+                           quant_config,
+                           prefix=f"{prefix}.mlp")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.ln_1(hidden_states)
+        attn_output = self.attn(
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+        # residual connection
+        hidden_states = attn_output + residual
+
+        residual = hidden_states
+        hidden_states = self.ln_2(hidden_states)
+        feed_forward_hidden_states = self.mlp(hidden_states)
+        # residual connection
+        hidden_states = residual + feed_forward_hidden_states
+        return hidden_states
+
+
+
+class GPT2Model(nn.Module):
+
+    def __init__(
+        self,
+        config: GPT2Config,
+        cache_config = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+        assert not config.add_cross_attention
+        assert not config.scale_attn_by_inverse_layer_idx
+        assert not config.reorder_and_upcast_attn
+        self.embed_dim = config.hidden_size
+        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
+        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+        self.h = nn.ModuleList(
+            [
+                GPT2Block(i, config, cache_config, quant_config)
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+
+        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        inputs_embeds = self.wte(input_ids)
+        position_embeds = self.wpe(position_ids)
+        hidden_states = inputs_embeds + position_embeds
+
+        for i in range(len(self.h)):
+            layer = self.h[i]
+            hidden_states = layer(hidden_states, forward_batch)
+
+        hidden_states = self.ln_f(hidden_states)
+        return hidden_states
+
+
+class GPT2LMHeadModel(nn.Module):
+
+    def __init__(
+        self,
+        config: GPT2Config,
+        cache_config = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.transformer = GPT2Model(config,
+                                     cache_config,
+                                     quant_config,
+                                     prefix="transformer")
+        self.lm_head = self.transformer.wte
+
+        self.logits_processor = LogitsProcessor(config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(input_ids, positions, forward_batch)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
+        )
+
+
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in weights:
+            if "lm_head.weight" in name:
+                # GPT-2 ties the weights of the embedding layer and the final
+                # linear layer.
+                continue
+            if ".attn.bias" in name or ".attn.masked_bias" in name:
+                # Skip attention mask.
+                # NOTE: "c_attn.bias" should not be skipped.
+                continue
+            if not name.startswith("transformer."):
+                name = "transformer." + name
+
+            param = params_dict[name]
+            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
+            # Because of this, we need to transpose the weights.
+            # Note(zhuohan): the logic below might break quantized models.
+            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
+                if conv1d_weight_name not in name:
+                    continue
+                if not name.endswith(".weight"):
+                    continue
+                loaded_weight = loaded_weight.t()
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+EntryClass = GPT2LMHeadModel
\ No newline at end of file
diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py
index 9cd1f4207c7..1d32b8af123 100755
--- a/test/srt/models/test_generation_models.py
+++ b/test/srt/models/test_generation_models.py
@@ -56,6 +56,7 @@ class ModelCase:
     ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True),
     ModelCase("allenai/OLMo-1B-0724-hf", decode_tolerance=8e-2, skip_long_prompt=True),
     ModelCase("THUDM/glm-4-9b-chat"),
+    ModelCase("openai-community/gpt2")
 ]

 TORCH_DTYPES = [torch.float16]
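As a usage note for patch 3: passing `use_thread=False` runs the program in the calling thread, which makes breakpoints and stack traces easier to follow. A hedged sketch follows; the endpoint URL, prompt, and function body are illustrative assumptions, not part of the patches.

```python
import sglang as sgl


@sgl.function
def qa(s, question):
    s += "Q: " + question + "\n"
    s += "A: " + sgl.gen("answer", max_tokens=32)


# Assumes a local SGLang server is already running; adjust the URL for your deployment.
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

# use_thread=False executes in the caller's thread instead of a worker thread.
state = qa.run(question="What is the capital of France?", use_thread=False)
print(state["answer"])
```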