diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 4421ea0da..6c18d3361 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -11,7 +11,8 @@ from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights from lightllm.common.basemodel.infer_struct import InferStateInfo -from lightllm.common.mem_manager import MemoryManager +from lightllm.common.kv_cache_mem_manager import MemoryManager +from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.common.req_manager import ReqManager from lightllm.common.infer_utils import init_req_to_token_indexes from lightllm.common.build_utils import repair_config @@ -22,7 +23,7 @@ from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_dp_world_size -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type from lightllm.distributed.communication_op import dist_group_manager from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput from lightllm.common.triton_utils.autotuner import AutotuneLevel @@ -68,7 +69,7 @@ def __init__(self, kvargs): self.is_token_healing = kvargs.get("is_token_healing", False) self.return_all_prompt_logics = kvargs.get("return_all_prompt_logics", False) assert not (self.is_token_healing and self.return_all_prompt_logics), "can not be true in same time" - self.data_type = kvargs.get("data_type", "float16") + self.data_type = get_llm_data_type() mtp_step = get_env_start_args().mtp_step self.graph_max_batch_size = kvargs.get("graph_max_batch_size", 16) self.graph_max_batch_size = ( @@ -89,7 +90,6 @@ def __init__(self, kvargs): self.is_deepseekv3_mtp_mode = self.args.mtp_mode in ["deepseekv3_vanilla", "deepseekv3_eagle"] - self._init_datatype() self._init_config() self._verify_must() self._verify_params() @@ -180,7 +180,7 @@ def _init_weights(self): def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 - self.mem_manager = MemoryManager( + self.mem_manager: MemoryManager = select_mem_manager_class()( self.max_total_token_num, dtype=self.data_type, head_num=self.config["num_attention_heads"] // self.tp_world_size_, @@ -230,16 +230,6 @@ def _init_some_value(self): self.vocab_size = self.config["vocab_size"] return - def _init_datatype(self): - if self.data_type in ["fp16", "float16"]: - self.data_type = torch.float16 - elif self.data_type in ["bf16", "bfloat16"]: - self.data_type = torch.bfloat16 - elif self.data_type in ["fp32", "float32"]: - self.data_type = torch.float32 - else: - raise ValueError(f"Unsupport datatype {self.data_type}!") - def _init_cudagraph(self): self.graph = ( None if self.disable_cudagraph else CudaGraph(self.graph_max_batch_size, self.graph_max_len_in_batch) diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index 7ce34585b..85d3d8c46 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -1,7 +1,7 @@ import torch import triton import collections -from lightllm.common.mem_manager import MemoryManager +from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.common.req_manager import ReqManager from lightllm.distributed import CustomProcessGroup from typing import Tuple, Any, Optional, List diff 
--git a/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py b/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py index c6098f950..0fdc43ab9 100644 --- a/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py +++ b/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py @@ -2,6 +2,7 @@ import triton import triton.language as tl +from typing import Optional @triton.jit @@ -12,16 +13,28 @@ def _offload_gpu_kv_to_cpu( gpu_stride1, gpu_stride2, gpu_stride3, + gpu_kv_cache_scale_ptr, + gpu_scale_stride0, + gpu_scale_stride1, + gpu_scale_stride2, + gpu_scale_stride3, cpu_kv_cache_ptr, cpu_stride0, cpu_stride1, cpu_stride2, cpu_stride3, cpu_stride4, + cpu_kv_cache_scale_ptr, + cpu_scale_stride0, + cpu_scale_stride1, + cpu_scale_stride2, + cpu_scale_stride3, + cpu_scale_stride4, page_indexes_ptr, page_readies_ptr, layer_num, head_dim, + scale_head_dim, block_num, cpu_k_start_head_index: tl.constexpr, cpu_k_head_num: tl.constexpr, @@ -33,6 +46,7 @@ def _offload_gpu_kv_to_cpu( gpu_v_head_num: tl.constexpr, BLOCK_HEAD_DIM: tl.constexpr, TOKEN_BLOCK: tl.constexpr, + HAS_SCALE: tl.constexpr, ): block_start_index = tl.program_id(0) block_split_size = tl.num_programs(axis=0) @@ -49,6 +63,7 @@ def _offload_gpu_kv_to_cpu( token_indexes = tl.load(token_indexes_ptr + token_range).to(tl.int64) head_dim_range = tl.arange(0, BLOCK_HEAD_DIM) head_dim_mask = head_dim_range < head_dim + scale_head_dim_mask = head_dim_range < scale_head_dim for layer_index in range(layer_num * mask_layer_num): for k_head_index in range(gpu_k_head_num): @@ -77,6 +92,29 @@ def _offload_gpu_kv_to_cpu( mask=head_dim_mask[None, :], cache_modifier=".wt", ) + if HAS_SCALE: + gpu_scale_ptr = ( + gpu_kv_cache_scale_ptr + + layer_index.to(tl.int64) * gpu_scale_stride0 + + token_indexes[:, None] * gpu_scale_stride1 + + gpu_k_head_index.to(tl.int64) * gpu_scale_stride2 + + head_dim_range[None, :] + ) + gpu_scale_data = tl.load(gpu_scale_ptr, mask=scale_head_dim_mask[None, :], other=0.0) + cpu_scale_ptr = ( + cpu_kv_cache_scale_ptr + + cpu_page_index * cpu_scale_stride0 + + layer_index.to(tl.int64) * cpu_scale_stride1 + + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_scale_stride2 + + cpu_k_head_index * cpu_scale_stride3 + + head_dim_range[None, :] + ) + tl.store( + cpu_scale_ptr, + gpu_scale_data, + mask=scale_head_dim_mask[None, :], + cache_modifier=".wt", + ) for v_head_index in range(gpu_v_head_num): gpu_v_head_index = v_head_index + gpu_v_start_head_index @@ -104,6 +142,30 @@ def _offload_gpu_kv_to_cpu( mask=head_dim_mask[None, :], cache_modifier=".wt", ) + if HAS_SCALE: + gpu_scale_ptr = ( + gpu_kv_cache_scale_ptr + + layer_index.to(tl.int64) * gpu_scale_stride0 + + token_indexes[:, None] * gpu_scale_stride1 + + gpu_v_head_index.to(tl.int64) * gpu_scale_stride2 + + head_dim_range[None, :] + ) + gpu_scale_data = tl.load(gpu_scale_ptr, mask=scale_head_dim_mask[None, :], other=0.0) + cpu_scale_ptr = ( + cpu_kv_cache_scale_ptr + + cpu_page_index * cpu_scale_stride0 + + layer_index.to(tl.int64) * cpu_scale_stride1 + + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_scale_stride2 + + cpu_v_head_index * cpu_scale_stride3 + + head_dim_range[None, :] + ) + tl.store( + cpu_scale_ptr, + gpu_scale_data, + mask=scale_head_dim_mask[None, :], + cache_modifier=".wt", + ) + return @@ -111,7 +173,9 @@ def _offload_gpu_kv_to_cpu( def offload_gpu_kv_to_cpu( token_indexes: torch.Tensor, gpu_kv_cache: torch.Tensor, + gpu_kv_cache_scale: Optional[torch.Tensor], cpu_kv_cache: torch.Tensor, + cpu_kv_cache_scale: Optional[torch.Tensor], page_indexes: 
torch.Tensor, page_readies: torch.Tensor, tp_index: int, @@ -234,6 +298,15 @@ def offload_gpu_kv_to_cpu( grid = (grid_num,) num_warps = 4 + HAS_SCALE = gpu_kv_cache_scale is not None and cpu_kv_cache_scale is not None + if HAS_SCALE: + scale_head_dim = gpu_kv_cache_scale.shape[-1] + gpu_scale_stride = gpu_kv_cache_scale.stride() + cpu_scale_stride = cpu_kv_cache_scale.stride() + else: + scale_head_dim = 0 + gpu_scale_stride = [0 for _ in range(10)] + cpu_scale_stride = [0 for _ in range(10)] _offload_gpu_kv_to_cpu[grid]( token_indexes_ptr=token_indexes, @@ -242,16 +315,28 @@ def offload_gpu_kv_to_cpu( gpu_stride1=gpu_kv_cache.stride(1), gpu_stride2=gpu_kv_cache.stride(2), gpu_stride3=gpu_kv_cache.stride(3), + gpu_kv_cache_scale_ptr=gpu_kv_cache_scale, + gpu_scale_stride0=gpu_scale_stride[0], + gpu_scale_stride1=gpu_scale_stride[1], + gpu_scale_stride2=gpu_scale_stride[2], + gpu_scale_stride3=gpu_scale_stride[3], cpu_kv_cache_ptr=cpu_kv_cache, cpu_stride0=cpu_kv_cache.stride(0), cpu_stride1=cpu_kv_cache.stride(1), cpu_stride2=cpu_kv_cache.stride(2), cpu_stride3=cpu_kv_cache.stride(3), cpu_stride4=cpu_kv_cache.stride(4), + cpu_kv_cache_scale_ptr=cpu_kv_cache_scale, + cpu_scale_stride0=cpu_scale_stride[0], + cpu_scale_stride1=cpu_scale_stride[1], + cpu_scale_stride2=cpu_scale_stride[2], + cpu_scale_stride3=cpu_scale_stride[3], + cpu_scale_stride4=cpu_scale_stride[4], page_indexes_ptr=page_indexes, page_readies_ptr=page_readies, layer_num=gpu_kv_cache.shape[0], head_dim=head_dim, + scale_head_dim=scale_head_dim, block_num=page_num, cpu_k_start_head_index=cpu_k_start_head_index, cpu_k_head_num=cpu_k_head_num, @@ -263,6 +348,7 @@ def offload_gpu_kv_to_cpu( gpu_v_head_num=gpu_v_head_num, BLOCK_HEAD_DIM=triton.next_power_of_2(head_dim), TOKEN_BLOCK=token_block_size, + HAS_SCALE=HAS_SCALE, num_warps=num_warps, num_stages=1, ) @@ -281,14 +367,26 @@ def _load_cpu_cache_to_gpu( gpu_stride1, gpu_stride2, gpu_stride3, + gpu_kv_cache_scale_ptr, + gpu_scale_stride0, + gpu_scale_stride1, + gpu_scale_stride2, + gpu_scale_stride3, cpu_kv_cache_ptr, cpu_stride0, cpu_stride1, cpu_stride2, cpu_stride3, cpu_stride4, + cpu_kv_cache_scale_ptr, + cpu_scale_stride0, + cpu_scale_stride1, + cpu_scale_stride2, + cpu_scale_stride3, + cpu_scale_stride4, layer_num, head_dim, + scale_head_dim, cpu_k_start_head_index: tl.constexpr, cpu_k_head_num: tl.constexpr, gpu_k_start_head_index: tl.constexpr, @@ -299,6 +397,7 @@ def _load_cpu_cache_to_gpu( gpu_v_head_num: tl.constexpr, BLOCK_HEAD_DIM: tl.constexpr, TOKEN_BLOCK: tl.constexpr, + HAS_SCALE: tl.constexpr, ): block_index_start = tl.program_id(0) split_block_num = tl.num_programs(0) @@ -311,9 +410,11 @@ def _load_cpu_cache_to_gpu( head_dim_range = tl.arange(0, BLOCK_HEAD_DIM) head_dim_mask = head_dim_range < head_dim + scale_head_dim_mask = head_dim_range < scale_head_dim for layer_index in range(layer_num): move_mask = token_mask[:, None] & head_dim_mask[None, :] + scale_move_mask = token_mask[:, None] & scale_head_dim_mask[None, :] for k_head_index in range(cpu_k_head_num): gpu_k_head_index = k_head_index + gpu_k_start_head_index @@ -342,6 +443,30 @@ def _load_cpu_cache_to_gpu( cpu_data, mask=move_mask, ) + if HAS_SCALE: + cpu_scale_ptr = ( + cpu_kv_cache_scale_ptr + + cpu_page_indexes[:, None] * cpu_scale_stride0 + + layer_index.to(tl.int64) * cpu_scale_stride1 + + cpu_mem_indexes[:, None] * cpu_scale_stride2 + + cpu_k_head_index * cpu_scale_stride3 + + head_dim_range[None, :] + ) + cpu_scale_data = tl.load(cpu_scale_ptr, mask=scale_move_mask, other=0.0) + + 
gpu_scale_ptr = ( + gpu_kv_cache_scale_ptr + + layer_index.to(tl.int64) * gpu_scale_stride0 + + gpu_mem_indexes[:, None] * gpu_scale_stride1 + + gpu_k_head_index * gpu_scale_stride2 + + head_dim_range[None, :] + ) + + tl.store( + gpu_scale_ptr, + cpu_scale_data, + mask=scale_move_mask, + ) for v_head_index in range(cpu_v_head_num): gpu_v_head_index = v_head_index + gpu_v_start_head_index @@ -370,6 +495,31 @@ def _load_cpu_cache_to_gpu( cpu_data, mask=move_mask, ) + if HAS_SCALE: + cpu_scale_ptr = ( + cpu_kv_cache_scale_ptr + + cpu_page_indexes[:, None] * cpu_scale_stride0 + + layer_index.to(tl.int64) * cpu_scale_stride1 + + cpu_mem_indexes[:, None] * cpu_scale_stride2 + + cpu_v_head_index * cpu_scale_stride3 + + head_dim_range[None, :] + ) + cpu_scale_data = tl.load(cpu_scale_ptr, mask=scale_move_mask, other=0.0) + + gpu_scale_ptr = ( + gpu_kv_cache_scale_ptr + + layer_index.to(tl.int64) * gpu_scale_stride0 + + gpu_mem_indexes[:, None] * gpu_scale_stride1 + + gpu_v_head_index * gpu_scale_stride2 + + head_dim_range[None, :] + ) + + tl.store( + gpu_scale_ptr, + cpu_scale_data, + mask=scale_move_mask, + ) + return @@ -377,7 +527,9 @@ def _load_cpu_cache_to_gpu( def load_cpu_kv_to_gpu( gpu_mem_indexes: torch.Tensor, gpu_kv_cache: torch.Tensor, + gpu_kv_cache_scale: Optional[torch.Tensor], cpu_kv_cache: torch.Tensor, + cpu_kv_cache_scale: Optional[torch.Tensor], page_indexes: torch.Tensor, tp_index: int, tp_world_size: int, @@ -496,6 +648,16 @@ def load_cpu_kv_to_gpu( grid = (grid_num,) num_warps = 4 + HAS_SCALE = gpu_kv_cache_scale is not None and cpu_kv_cache_scale is not None + if HAS_SCALE: + scale_head_dim = gpu_kv_cache_scale.shape[-1] + gpu_scale_stride = gpu_kv_cache_scale.stride() + cpu_scale_stride = cpu_kv_cache_scale.stride() + else: + scale_head_dim = 0 + gpu_scale_stride = [0 for _ in range(10)] + cpu_scale_stride = [0 for _ in range(10)] + _load_cpu_cache_to_gpu[grid]( gpu_mem_indexes_ptr=gpu_mem_indexes, copy_token_num=move_token_num, @@ -507,14 +669,26 @@ def load_cpu_kv_to_gpu( gpu_stride1=gpu_kv_cache.stride(1), gpu_stride2=gpu_kv_cache.stride(2), gpu_stride3=gpu_kv_cache.stride(3), + gpu_kv_cache_scale_ptr=gpu_kv_cache_scale, + gpu_scale_stride0=gpu_scale_stride[0], + gpu_scale_stride1=gpu_scale_stride[1], + gpu_scale_stride2=gpu_scale_stride[2], + gpu_scale_stride3=gpu_scale_stride[3], cpu_kv_cache_ptr=cpu_kv_cache, cpu_stride0=cpu_kv_cache.stride(0), cpu_stride1=cpu_kv_cache.stride(1), cpu_stride2=cpu_kv_cache.stride(2), cpu_stride3=cpu_kv_cache.stride(3), cpu_stride4=cpu_kv_cache.stride(4), + cpu_kv_cache_scale_ptr=cpu_kv_cache_scale, + cpu_scale_stride0=cpu_scale_stride[0], + cpu_scale_stride1=cpu_scale_stride[1], + cpu_scale_stride2=cpu_scale_stride[2], + cpu_scale_stride3=cpu_scale_stride[3], + cpu_scale_stride4=cpu_scale_stride[4], layer_num=gpu_kv_cache.shape[0], head_dim=head_dim, + scale_head_dim=scale_head_dim, cpu_k_start_head_index=cpu_k_start_head_index, cpu_k_head_num=cpu_k_head_num, gpu_k_start_head_index=gpu_k_start_head_index, @@ -525,6 +699,7 @@ def load_cpu_kv_to_gpu( gpu_v_head_num=gpu_v_head_num, BLOCK_HEAD_DIM=triton.next_power_of_2(head_dim), TOKEN_BLOCK=TOKEN_BLOCK, + HAS_SCALE=HAS_SCALE, num_warps=num_warps, num_stages=1, ) diff --git a/lightllm/common/kv_cache_mem_manager/__init__.py b/lightllm/common/kv_cache_mem_manager/__init__.py new file mode 100644 index 000000000..66caf5d78 --- /dev/null +++ b/lightllm/common/kv_cache_mem_manager/__init__.py @@ -0,0 +1,20 @@ +from .mem_manager import MemoryManager, ReadOnlyStaticsMemoryManager +from 
.int8kv_mem_manager import INT8KVMemoryManager +from .calibration_fp8kv_mem_manager import CalibrationFP8KVMemoryManager +from .export_calibration_mem_manager import ExportCalibrationMemoryManager +from .ppl_int8kv_mem_manager import PPLINT8KVMemoryManager +from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager +from .deepseek2_mem_manager import Deepseek2MemoryManager +from .deepseek2_fp8kv_mem_manager import Deepseek2FP8KVMemoryManager + +__all__ = [ + "MemoryManager", + "ReadOnlyStaticsMemoryManager", + "INT8KVMemoryManager", + "CalibrationFP8KVMemoryManager", + "ExportCalibrationMemoryManager", + "PPLINT4KVMemoryManager", + "PPLINT8KVMemoryManager", + "Deepseek2MemoryManager", + "Deepseek2FP8KVMemoryManager", +] diff --git a/lightllm/common/calibration_fp8kv_mem_manager.py b/lightllm/common/kv_cache_mem_manager/calibration_fp8kv_mem_manager.py similarity index 100% rename from lightllm/common/calibration_fp8kv_mem_manager.py rename to lightllm/common/kv_cache_mem_manager/calibration_fp8kv_mem_manager.py diff --git a/lightllm/common/deepseek2_fp8kv_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek2_fp8kv_mem_manager.py similarity index 100% rename from lightllm/common/deepseek2_fp8kv_mem_manager.py rename to lightllm/common/kv_cache_mem_manager/deepseek2_fp8kv_mem_manager.py diff --git a/lightllm/common/deepseek2_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py similarity index 96% rename from lightllm/common/deepseek2_mem_manager.py rename to lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py index 4f106bdcf..771173460 100644 --- a/lightllm/common/deepseek2_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py @@ -23,12 +23,6 @@ def get_cell_size(self): def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): self.kv_buffer = torch.empty((layer_num, size + 1, head_num, head_dim), dtype=dtype, device="cuda") - # todo, etp or edp use the same work buffer here - # also it can be used for any kernels for work buffer witout save info only - if os.environ.get("ETP_MODE_ENABLED") == "true": - self.work_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.bfloat16, device="cuda") - self.work_buffer.share_memory_() - def alloc_kv_move_buffer(self, max_req_total_len): self.kv_move_buffer = torch.empty( (1, max_req_total_len + 8, self.head_num, self.head_dim), dtype=self.dtype, device="cuda" diff --git a/lightllm/common/export_calibration_mem_manager.py b/lightllm/common/kv_cache_mem_manager/export_calibration_mem_manager.py similarity index 100% rename from lightllm/common/export_calibration_mem_manager.py rename to lightllm/common/kv_cache_mem_manager/export_calibration_mem_manager.py diff --git a/lightllm/common/int8kv_mem_manager.py b/lightllm/common/kv_cache_mem_manager/int8kv_mem_manager.py similarity index 100% rename from lightllm/common/int8kv_mem_manager.py rename to lightllm/common/kv_cache_mem_manager/int8kv_mem_manager.py diff --git a/lightllm/common/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py similarity index 99% rename from lightllm/common/mem_manager.py rename to lightllm/common/kv_cache_mem_manager/mem_manager.py index 57ae9838b..daa061b7b 100755 --- a/lightllm/common/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -93,7 +93,7 @@ def alloc_kv_move_buffer(self, max_req_total_len): """ pd 分离模式使用的特殊接口 """ - if isinstance(self, MemoryManager) and type(self) != MemoryManager: + if isinstance(self, MemoryManager) and type(self) is not 
MemoryManager: raise NotImplementedError("subclass need reimpl this method") self.kv_move_buffer = torch.empty( (1, max_req_total_len + 8, 2 * self.head_num, self.head_dim), dtype=self.dtype, device="cuda" @@ -103,7 +103,7 @@ def alloc_kv_move_buffer(self, max_req_total_len): return def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor: - if isinstance(self, MemoryManager) and type(self) != MemoryManager: + if isinstance(self, MemoryManager) and type(self) is not MemoryManager: raise NotImplementedError("subclass need reimpl this method") num_kv_head = get_num_key_value_heads(get_env_start_args().model_dir) diff --git a/lightllm/common/mem_utils.py b/lightllm/common/kv_cache_mem_manager/mem_utils.py similarity index 51% rename from lightllm/common/mem_utils.py rename to lightllm/common/kv_cache_mem_manager/mem_utils.py index d5e62dc8d..259c5a56f 100644 --- a/lightllm/common/mem_utils.py +++ b/lightllm/common/kv_cache_mem_manager/mem_utils.py @@ -1,15 +1,39 @@ -from lightllm.common.mem_manager import MemoryManager -from lightllm.common.int8kv_mem_manager import INT8KVMemoryManager -from lightllm.common.calibration_fp8kv_mem_manager import CalibrationFP8KVMemoryManager -from lightllm.common.export_calibration_mem_manager import ExportCalibrationMemoryManager -from lightllm.common.ppl_int8kv_mem_manager import PPLINT8KVMemoryManager -from lightllm.common.ppl_int4kv_mem_manager import PPLINT4KVMemoryManager +from . import ( + MemoryManager, + INT8KVMemoryManager, + CalibrationFP8KVMemoryManager, + ExportCalibrationMemoryManager, + PPLINT8KVMemoryManager, + PPLINT4KVMemoryManager, + Deepseek2MemoryManager, + Deepseek2FP8KVMemoryManager, +) from lightllm.utils.log_utils import init_logger +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.llm_utils import get_llm_model_class +from functools import lru_cache logger = init_logger(__name__) -def select_mem_manager_class(mode): +@lru_cache(maxsize=None) +def select_mem_manager_class(): + mode = get_env_start_args().mode + + # case 1 + # 先判断是否是 deepseek 系列的模型 + model_class = get_llm_model_class() + from lightllm.models import Deepseek2TpPartModel + + if issubclass(model_class, Deepseek2TpPartModel): + mem_class = Deepseek2MemoryManager + if "triton_fp8kv" in mode: + mem_class = Deepseek2FP8KVMemoryManager + + logger.info(f"Model kv cache using mode {mode}, mem_manager class: {mem_class}") + return mem_class + + # case normal logger.info(f"mode setting params: {mode}") if "ppl_int8kv" in mode or "ppl_int8kv_flashdecoding" in mode or "ppl_int8kv_flashdecoding_diverse" in mode: memory_manager_class = PPLINT8KVMemoryManager @@ -32,3 +56,9 @@ def select_mem_manager_class(mode): memory_manager_class = MemoryManager logger.info("Model kv cache using mode normal") return memory_manager_class + + +@lru_cache(maxsize=None) +def used_mem_manager_has_scale() -> bool: + mem_class = select_mem_manager_class() + return mem_class in [PPLINT8KVMemoryManager, PPLINT4KVMemoryManager, INT8KVMemoryManager] diff --git a/lightllm/common/offline_fp8_quant_mem_manager.py b/lightllm/common/kv_cache_mem_manager/offline_fp8_quant_mem_manager.py similarity index 100% rename from lightllm/common/offline_fp8_quant_mem_manager.py rename to lightllm/common/kv_cache_mem_manager/offline_fp8_quant_mem_manager.py diff --git a/lightllm/common/ppl_int4kv_mem_manager.py b/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py similarity index 100% rename from lightllm/common/ppl_int4kv_mem_manager.py rename to 
lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py diff --git a/lightllm/common/ppl_int8kv_mem_manager.py b/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py similarity index 100% rename from lightllm/common/ppl_int8kv_mem_manager.py rename to lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index dcd1b3072..40c8aa993 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -1,7 +1,7 @@ import torch import collections from lightllm.utils.log_utils import init_logger -from .mem_manager import MemoryManager +from .kv_cache_mem_manager import MemoryManager from typing import List, Optional from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index 96329eabe..5237f8fd2 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -36,4 +36,4 @@ Tarsier2LlamaTpPartModel, ) from lightllm.models.gpt_oss.model import GptOssTpPartModel -from .registry import get_model +from .registry import get_model, get_model_class diff --git a/lightllm/models/cohere/model.py b/lightllm/models/cohere/model.py index f3c5a1709..5b317c133 100644 --- a/lightllm/models/cohere/model.py +++ b/lightllm/models/cohere/model.py @@ -4,7 +4,7 @@ from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_cohere_template import ( TransformerLayerCohereInferTpl, ) -from lightllm.common.mem_manager import MemoryManager +from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.models.registry import ModelRegistry from lightllm.models.cohere.infer_struct import CohereInferStateInfo from lightllm.models.cohere.layer_infer.post_layer_infer import CoherePostLayerInfer diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index a08147769..6dfd88970 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -9,8 +9,7 @@ from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights from lightllm.models.llama.model import LlamaTpPartModel -from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager -from lightllm.common.deepseek2_fp8kv_mem_manager import Deepseek2FP8KVMemoryManager +from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.utils.log_utils import init_logger from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args @@ -94,9 +93,7 @@ def _verify_params(self): return super()._verify_params() def _init_mem_manager(self): - manager_class = Deepseek2MemoryManager - if "triton_fp8kv" in self.mode: - manager_class = Deepseek2FP8KVMemoryManager + manager_class = select_mem_manager_class() # mtp 模式下需要在mem manger上扩展draft model使用的layer added_mtp_layer_num = 0 diff --git a/lightllm/models/gemma3/model.py b/lightllm/models/gemma3/model.py index 262825b9d..42326169a 100644 --- a/lightllm/models/gemma3/model.py +++ b/lightllm/models/gemma3/model.py @@ -5,7 +5,7 @@ import torch from lightllm.models.registry import ModelRegistry from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer -from lightllm.common.mem_utils import select_mem_manager_class +from 
lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.models.gemma3.infer_struct import Gemma3InferStateInfo from lightllm.models.gemma3.layer_infer.post_layer_infer import Gemma3PostLayerInfer from lightllm.models.gemma3.layer_infer.pre_layer_infer import Gemma3PreLayerInfer @@ -143,7 +143,7 @@ def _init_custom(self): self._init_to_get_rotary() def _init_mem_manager(self): - self.mem_manager = select_mem_manager_class(self.mode)( + self.mem_manager = select_mem_manager_class()( self.max_total_token_num, dtype=torch.bfloat16, head_num=self.config["num_key_value_heads"] // self.tp_world_size_, diff --git a/lightllm/models/gemma_2b/model.py b/lightllm/models/gemma_2b/model.py index 11b1f9d08..4b425c9ce 100644 --- a/lightllm/models/gemma_2b/model.py +++ b/lightllm/models/gemma_2b/model.py @@ -6,7 +6,7 @@ from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.model import LlamaTpPartModel -from lightllm.common.mem_utils import select_mem_manager_class +from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class @ModelRegistry("gemma") @@ -38,7 +38,7 @@ def _verify_params(self): return def _init_mem_manager(self): - self.mem_manager = select_mem_manager_class(self.mode)( + self.mem_manager = select_mem_manager_class()( self.max_total_token_num, dtype=self.data_type, head_num=self.config["num_key_value_heads"], diff --git a/lightllm/models/gpt_oss/model.py b/lightllm/models/gpt_oss/model.py index b231e2bcc..34a017b31 100644 --- a/lightllm/models/gpt_oss/model.py +++ b/lightllm/models/gpt_oss/model.py @@ -1,15 +1,7 @@ -import torch -import numpy as np - from lightllm.models.gpt_oss.layer_infer.transformer_layer_infer import GptOssTransformerLayerInfer from lightllm.models.gpt_oss.layer_weights.transformer_layer_weight import GptOssTransformerLayerWeight -from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer -from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight from lightllm.models.llama.model import LlamaTpPartModel from lightllm.models.registry import ModelRegistry -from lightllm.common.basemodel.basemodel import TpPartBaseModel -from lightllm.common.mem_manager import MemoryManager from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.log_utils import init_logger diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index 8c5f203dd..5e9762ceb 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -15,7 +15,7 @@ from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo from lightllm.models.llama.flashinfer_struct import LlamaFlashInferStateInfo from lightllm.common.basemodel import TpPartBaseModel -from lightllm.common.mem_utils import select_mem_manager_class +from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id @@ -86,7 +86,7 @@ def _verify_params(self): def _init_mem_manager(self): head_dim_ = self.config["hidden_size"] // self.config["num_attention_heads"] head_dim_ = self.config.get("head_dim", head_dim_) - self.mem_manager = 
select_mem_manager_class(self.mode)( + self.mem_manager = select_mem_manager_class()( self.max_total_token_num, dtype=self.data_type, head_num=self.config["num_key_value_heads"] // self.tp_world_size_, diff --git a/lightllm/models/mistral/model.py b/lightllm/models/mistral/model.py index 6d77d42ae..ef7e5d695 100644 --- a/lightllm/models/mistral/model.py +++ b/lightllm/models/mistral/model.py @@ -9,7 +9,7 @@ from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.mistral.layer_infer.transformer_layer_infer import MistralTransformerLayerInfer -from lightllm.common.mem_utils import select_mem_manager_class +from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class @ModelRegistry("mistral") @@ -44,7 +44,7 @@ def _init_mem_manager(self): # Dealing with head_dim_!=n_embed // num_attention_heads scenarios, such as mistral 13B head_dim = self.config["hidden_size"] // self.config["num_attention_heads"] head_dim = self.config.get("head_dim", head_dim) - self.mem_manager = select_mem_manager_class(self.mode)( + self.mem_manager = select_mem_manager_class()( self.max_total_token_num, dtype=self.data_type, head_num=self.config["num_key_value_heads"] // self.tp_world_size_, diff --git a/lightllm/models/mixtral/model.py b/lightllm/models/mixtral/model.py index af4c05354..3c2d7b4e8 100644 --- a/lightllm/models/mixtral/model.py +++ b/lightllm/models/mixtral/model.py @@ -2,7 +2,7 @@ import numpy as np from lightllm.models.registry import ModelRegistry from lightllm.common.basemodel.basemodel import TpPartBaseModel -from lightllm.common.mem_manager import MemoryManager +from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight diff --git a/lightllm/models/qwen2/model.py b/lightllm/models/qwen2/model.py index e3d8de461..5b756aadf 100644 --- a/lightllm/models/qwen2/model.py +++ b/lightllm/models/qwen2/model.py @@ -2,7 +2,7 @@ from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight from lightllm.models.qwen2.layer_weights.transformer_layer_weight import Qwen2TransformerLayerWeight from lightllm.models.llama.model import LlamaTpPartModel -from lightllm.common.mem_utils import select_mem_manager_class +from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class @ModelRegistry("qwen2") @@ -41,7 +41,7 @@ def _init_mem_manager(self): head_dim_ = self.config["hidden_size"] // self.config["num_attention_heads"] head_dim_ = self.config.get("head_dim", head_dim_) tp_k_head_num_ = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) - self.mem_manager = select_mem_manager_class(self.mode)( + self.mem_manager = select_mem_manager_class()( self.max_total_token_num, dtype=self.data_type, head_num=tp_k_head_num_, diff --git a/lightllm/models/registry.py b/lightllm/models/registry.py index 8a774179d..e9568cc3d 100644 --- a/lightllm/models/registry.py +++ b/lightllm/models/registry.py @@ -67,6 +67,27 @@ def get_model(self, model_cfg: dict, model_kvargs: dict) -> tuple: is_multimodal = matches[0].is_multimodal return model, is_multimodal + def get_model_class(self, model_cfg: dict): + """Get model""" + model_type = model_cfg.get("model_type", 
"") + configs = self._registry.get(model_type, []) + matches = [] + for cfg in configs: + if cfg.condition is None or cfg.condition(model_cfg): + matches.append(cfg) + + if len(matches) == 0: + raise ValueError(f"Model type {model_type} is not supported.") + + if len(matches) > 1: + # Keep conditionally matched models + matches = [m for m in matches if m.condition is not None] + + assert ( + len(matches) == 1 + ), "Existence of coupled conditon, inability to determine the class of models instantiated" + return matches[0].model_class + ModelRegistry = _ModelRegistries() @@ -80,6 +101,15 @@ def get_model(model_cfg: dict, model_kvargs: dict): raise +def get_model_class(model_cfg: dict): + try: + model_class = ModelRegistry.get_model_class(model_cfg) + return model_class + except Exception as e: + logger.exception(str(e)) + raise + + def is_reward_model() -> Callable[[Dict[str, any]], bool]: """Predicate: whether the model is RewardModel.""" return lambda model_cfg: "RewardModel" in model_cfg.get("architectures", [""])[0] diff --git a/lightllm/models/starcoder/model.py b/lightllm/models/starcoder/model.py index cdf9541c0..ea2aeabbc 100644 --- a/lightllm/models/starcoder/model.py +++ b/lightllm/models/starcoder/model.py @@ -5,7 +5,7 @@ from lightllm.models.starcoder.layer_weights.pre_and_post_layer_weight import StarcoderPreAndPostLayerWeight from lightllm.models.bloom.layer_infer.post_layer_infer import BloomPostLayerInfer from lightllm.common.build_utils import repair_config -from lightllm.common.mem_utils import select_mem_manager_class +from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.common.basemodel import TpPartBaseModel from lightllm.common.basemodel import InferStateInfo @@ -41,7 +41,7 @@ def _verify_params(self): assert self.load_way == "HF", "StarCoder only support HF format to load Now!" 
def _init_mem_manager(self): - self.mem_manager = select_mem_manager_class(self.mode)( + self.mem_manager = select_mem_manager_class()( self.max_total_token_num, dtype=self.data_type, head_num=self.config["num_key_value_heads"], diff --git a/lightllm/models/starcoder2/model.py b/lightllm/models/starcoder2/model.py index cbe08f44c..a299c08be 100644 --- a/lightllm/models/starcoder2/model.py +++ b/lightllm/models/starcoder2/model.py @@ -8,7 +8,7 @@ from lightllm.models.bloom.layer_infer.post_layer_infer import BloomPostLayerInfer from lightllm.common.build_utils import repair_config -from lightllm.common.mem_utils import select_mem_manager_class +from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.common.basemodel import TpPartBaseModel @@ -47,7 +47,7 @@ def _init_custom(self): return def _init_mem_manager(self): - self.mem_manager = select_mem_manager_class(self.mode)( + self.mem_manager = select_mem_manager_class()( self.max_total_token_num, dtype=self.data_type, head_num=self.config["num_key_value_heads"] // self.tp_world_size_, diff --git a/lightllm/server/multi_level_kv_cache/cpu_cache_client.py b/lightllm/server/multi_level_kv_cache/cpu_cache_client.py index f7e45766f..8f475ca20 100644 --- a/lightllm/server/multi_level_kv_cache/cpu_cache_client.py +++ b/lightllm/server/multi_level_kv_cache/cpu_cache_client.py @@ -285,9 +285,11 @@ def _create_shm_cpu_kv_cache(self): self.kv_cache_tensor_meta.layer_num, self.kv_cache_tensor_meta.token_page_size, self.kv_cache_tensor_meta.num_heads, - self.kv_cache_tensor_meta.head_dim, + self.kv_cache_tensor_meta.get_merged_head_dim(), + ) + self.cpu_kv_cache_tensor = ( + torch.from_numpy(numpy_array).view(dtype=self.kv_cache_tensor_meta.data_type).view(shape) ) - self.cpu_kv_cache_tensor = torch.from_numpy(numpy_array).view(dtype=torch.bfloat16).view(shape) return def _attach_shm_cpu_kv_cache(self): @@ -301,9 +303,11 @@ def _attach_shm_cpu_kv_cache(self): self.kv_cache_tensor_meta.layer_num, self.kv_cache_tensor_meta.token_page_size, self.kv_cache_tensor_meta.num_heads, - self.kv_cache_tensor_meta.head_dim, + self.kv_cache_tensor_meta.get_merged_head_dim(), + ) + self.cpu_kv_cache_tensor = ( + torch.from_numpy(numpy_array).view(dtype=self.kv_cache_tensor_meta.data_type).view(shape) ) - self.cpu_kv_cache_tensor = torch.from_numpy(numpy_array).view(dtype=torch.bfloat16).view(shape) assert shm_ptr == self.cpu_kv_cache_tensor.data_ptr() # test code diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index 2bf0a4d5a..c51774898 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -104,7 +104,7 @@ class RadixCache: """ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None): - from lightllm.common.mem_manager import MemoryManager + from lightllm.common.kv_cache_mem_manager import MemoryManager self.mem_manager: MemoryManager = mem_manager self._key_dtype = torch.int64 diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 3c8ca2399..40607c2ce 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -29,7 +29,7 @@ from lightllm.server.router.token_load import TokenLoad from lightllm.server.metrics.manager import MetricClient from lightllm.common.basemodel.infer_lock import g_router_lock -from lightllm.common.mem_manager import ReadOnlyStaticsMemoryManager +from 
lightllm.common.kv_cache_mem_manager import ReadOnlyStaticsMemoryManager from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py index ec923d445..9d72bac6e 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py @@ -9,7 +9,7 @@ from datetime import timedelta from typing import List, Dict, Union from lightllm.utils.log_utils import init_logger -from lightllm.common.mem_manager import MemoryManager +from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.server.pd_io_struct import KVMoveTask, PDTransJoinInfo, PDTransLeaveInfo, KVMoveTaskGroup from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.utils.graceful_utils import graceful_registry diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py index 709c88968..1972c47b4 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py @@ -9,7 +9,7 @@ from datetime import timedelta from typing import List, Dict, Union from lightllm.utils.log_utils import init_logger -from lightllm.common.mem_manager import MemoryManager +from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.server.pd_io_struct import KVMoveTask, PDTransJoinInfo, PDTransLeaveInfo, KVMoveTaskGroup from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.utils.graceful_utils import graceful_registry diff --git a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py index 024c43965..5cde46308 100644 --- a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py +++ b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py @@ -50,11 +50,13 @@ def load_cpu_cache_to_reqs(self, reqs: List[InferReq]): match_tokens = len(page_list) * token_page_size # 更新命中的 cpu kv cache 长度, 减去radix cache和disk cache的部分. 
if is_master_in_dp: - req.shm_req.cpu_prompt_cache_len = match_tokens - req.cur_kv_len - req.shm_req.disk_prompt_cache_len + req.shm_req.cpu_prompt_cache_len = max( + 0, match_tokens - req.cur_kv_len - req.shm_req.disk_prompt_cache_len + ) need_token_num = match_tokens - req.cur_kv_len # 多匹配了一定数量的token同时请求长度大于一定的长度,才进行复制操作,不然操作效率不高,代价过高 - if need_token_num >= 256 and req.shm_req.input_len >= 512: + if need_token_num >= 128 and req.shm_req.input_len >= 256: if need_token_num <= idle_token_num: if self.backend.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(need_token_num=need_token_num) @@ -72,11 +74,28 @@ def load_cpu_cache_to_reqs(self, reqs: List[InferReq]): # TODO 更有效的分配策略。 grid_num = 16 if self.need_sync_compute_stream or (not self.args.enable_fa3) else 1 + mem_manager = self.backend.model.mem_manager + if hasattr(mem_manager, "scale_buffer") and mem_manager.scale_buffer is not None: + cpu_cache_meta = self.cpu_cache_client.kv_cache_tensor_meta + cpu_kv_cache = self.cpu_cache_client.cpu_kv_cache_tensor[ + :, :, :, :, 0 : cpu_cache_meta.head_dim + ] + cpu_kv_cache_scale = self.cpu_cache_client.cpu_kv_cache_tensor[ + :, :, :, :, cpu_cache_meta.head_dim : + ].view(mem_manager.scale_buffer.dtype) + gpu_kv_cache_scale = mem_manager.scale_buffer + else: + cpu_kv_cache = self.cpu_cache_client.cpu_kv_cache_tensor + cpu_kv_cache_scale = None + gpu_kv_cache_scale = None + # 将 cpu page 的内容拷贝到 gpu 页面中 load_cpu_kv_to_gpu( gpu_mem_indexes=mem_indexes.cuda(non_blocking=True), - gpu_kv_cache=self.backend.model.mem_manager.kv_buffer, - cpu_kv_cache=self.cpu_cache_client.cpu_kv_cache_tensor, + gpu_kv_cache=mem_manager.kv_buffer, + gpu_kv_cache_scale=gpu_kv_cache_scale, + cpu_kv_cache=cpu_kv_cache, + cpu_kv_cache_scale=cpu_kv_cache_scale, page_indexes=torch.tensor(need_pages, dtype=torch.int32, device="cpu").cuda(non_blocking=True), tp_index=self.backend.rank_in_dp, tp_world_size=self.backend.dp_world_size, @@ -209,11 +228,26 @@ def _start_kv_cache_offload_task( # TODO 更有效的分配策略。 grid_num = 16 if self.need_sync_compute_stream or (not self.args.enable_fa3) else 1 + mem_manager = self.backend.model.mem_manager + if hasattr(mem_manager, "scale_buffer") and mem_manager.scale_buffer is not None: + cpu_cache_meta = self.cpu_cache_client.kv_cache_tensor_meta + cpu_kv_cache = self.cpu_cache_client.cpu_kv_cache_tensor[:, :, :, :, 0 : cpu_cache_meta.head_dim] + cpu_kv_cache_scale = self.cpu_cache_client.cpu_kv_cache_tensor[ + :, :, :, :, cpu_cache_meta.head_dim : + ].view(mem_manager.scale_buffer.dtype) + gpu_kv_cache_scale = mem_manager.scale_buffer + else: + cpu_kv_cache = self.cpu_cache_client.cpu_kv_cache_tensor + cpu_kv_cache_scale = None + gpu_kv_cache_scale = None + # assert max(page_list) < self.cpu_cache_client.cpu_kv_cache_tensor.shape[0] offload_gpu_kv_to_cpu( token_indexes=token_indexes, - gpu_kv_cache=self.backend.model.mem_manager.kv_buffer, - cpu_kv_cache=self.cpu_cache_client.cpu_kv_cache_tensor, + gpu_kv_cache=mem_manager.kv_buffer, + gpu_kv_cache_scale=gpu_kv_cache_scale, + cpu_kv_cache=cpu_kv_cache, + cpu_kv_cache_scale=cpu_kv_cache_scale, page_indexes=page_indexes, page_readies=page_readies, tp_index=self.backend.rank_in_dp, diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py index b1fd0c7b9..e7dc30ad8 100644 --- 
a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py @@ -8,7 +8,7 @@ import pickle from typing import List, Dict, Union, Deque, Optional from lightllm.utils.log_utils import init_logger -from lightllm.common.mem_manager import MemoryManager +from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.server.pd_io_struct import ( NIXLChunckedTransTask, NIXLChunckedTransTaskGroup, diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py index 1ff8977f6..8265afc27 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py @@ -8,7 +8,7 @@ import pickle from typing import List, Dict, Union, Deque, Optional from lightllm.utils.log_utils import init_logger -from lightllm.common.mem_manager import MemoryManager +from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.server.pd_io_struct import NIXLChunckedTransTask, NIXLChunckedTransTaskRet from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.utils.graceful_utils import graceful_registry diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 186b8dde1..8995afbc5 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -49,6 +49,20 @@ def get_env_start_args(): return start_args +@lru_cache(maxsize=None) +def get_llm_data_type() -> torch.dtype: + data_type: str = get_env_start_args().data_type + if data_type in ["fp16", "float16"]: + data_type = torch.float16 + elif data_type in ["bf16", "bfloat16"]: + data_type = torch.bfloat16 + elif data_type in ["fp32", "float32"]: + data_type = torch.float32 + else: + raise ValueError(f"Unsupport datatype {data_type}!") + return data_type + + @lru_cache(maxsize=None) def enable_env_vars(args): return os.getenv(args, "False").upper() in ["ON", "TRUE", "1"] diff --git a/lightllm/utils/kv_cache_utils.py b/lightllm/utils/kv_cache_utils.py index 46ef98c6a..f3f00c5ad 100644 --- a/lightllm/utils/kv_cache_utils.py +++ b/lightllm/utils/kv_cache_utils.py @@ -8,9 +8,21 @@ import numpy as np import triton from functools import lru_cache -from lightllm.utils.envs_utils import get_env_start_args, enable_huge_page +from lightllm.utils.envs_utils import get_env_start_args, enable_huge_page, get_llm_data_type from lightllm.utils.log_utils import init_logger -from lightllm.utils.config_utils import get_num_key_value_heads, get_head_dim, get_layer_num, get_model_type +from lightllm.utils.config_utils import get_num_key_value_heads, get_head_dim, get_layer_num +from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class +from lightllm.common.kv_cache_mem_manager import ( + MemoryManager, + INT8KVMemoryManager, + CalibrationFP8KVMemoryManager, + ExportCalibrationMemoryManager, + PPLINT8KVMemoryManager, + PPLINT4KVMemoryManager, + Deepseek2MemoryManager, + Deepseek2FP8KVMemoryManager, +) + from typing import List, Tuple, Optional from tqdm import tqdm from lightllm.utils.auto_shm_cleanup import register_sysv_shm_for_cleanup @@ -49,33 +61,63 @@ def calcu_cpu_cache_meta() -> "CpuKVCacheMeta": args = get_env_start_args() assert args.enable_cpu_cache - if 
get_model_type(model_path=args.model_dir) in ["deepseek_v2", "deepseek_v3"]: - item_size = 2 - num_key_value_heads = 1 - head_dim = 512 + 64 - layer_num = get_layer_num(args.model_dir) + mem_manager_class = select_mem_manager_class() + if mem_manager_class is Deepseek2MemoryManager: + cpu_cache_meta = CpuKVCacheMeta( + page_num=0, + token_page_size=args.cpu_cache_token_page_size, + layer_num=get_layer_num(args.model_dir), + num_heads=1, + head_dim=512 + 64, + data_type=get_llm_data_type(), + scale_head_dim=0, + scale_data_type=get_llm_data_type(), + ) + elif mem_manager_class is Deepseek2FP8KVMemoryManager: + cpu_cache_meta = CpuKVCacheMeta( + page_num=0, + token_page_size=args.cpu_cache_token_page_size, + layer_num=get_layer_num(args.model_dir), + num_heads=1, + head_dim=512 + 64 + 2, + data_type=torch.uint8, + scale_head_dim=0, + scale_data_type=get_llm_data_type(), + ) + elif mem_manager_class is MemoryManager: + cpu_cache_meta = CpuKVCacheMeta( + page_num=0, + token_page_size=args.cpu_cache_token_page_size, + layer_num=get_layer_num(args.model_dir), + num_heads=get_num_key_value_heads(args.model_dir) * 2, + head_dim=get_head_dim(args.model_dir), + data_type=get_llm_data_type(), + scale_head_dim=0, + scale_data_type=get_llm_data_type(), + ) + elif mem_manager_class is PPLINT8KVMemoryManager: + cpu_cache_meta = CpuKVCacheMeta( + page_num=0, + token_page_size=args.cpu_cache_token_page_size, + layer_num=get_layer_num(args.model_dir), + num_heads=get_num_key_value_heads(args.model_dir) * 2, + head_dim=get_head_dim(args.model_dir), + data_type=torch.int8, + scale_head_dim=get_head_dim(args.model_dir) // 8, + scale_data_type=get_llm_data_type(), + ) else: - item_size = 2 - num_key_value_heads = get_num_key_value_heads(args.model_dir) * 2 - head_dim = get_head_dim(args.model_dir) - layer_num = get_layer_num(args.model_dir) + logger.error(f"not support mem manager: {mem_manager_class} for cpu kv cache") + raise Exception(f"not support mem manager: {mem_manager_class} for cpu kv cache") if args.mtp_mode is not None: # TODO 可能会存在不同mtp模式的精度问题 - layer_num += 1 - - one_token_byte_size = layer_num * num_key_value_heads * head_dim * item_size - one_page_byte_size = args.cpu_cache_token_page_size * one_token_byte_size - cpu_cache_page_num = int((args.cpu_cache_storage_size * 1024 * 1024 * 1024) / one_page_byte_size) - - cpu_cache_meta = CpuKVCacheMeta( - page_num=cpu_cache_page_num, - layer_num=layer_num, - token_page_size=args.cpu_cache_token_page_size, - num_heads=num_key_value_heads, - head_dim=head_dim, - item_size=item_size, + cpu_cache_meta.layer_num += 1 + + cpu_cache_page_num = int( + (args.cpu_cache_storage_size * 1024 * 1024 * 1024) / (cpu_cache_meta.calcu_one_page_size()) ) + cpu_cache_meta.page_num = cpu_cache_page_num logger.info(f"cpu kv cache page num: {cpu_cache_meta.page_num}") @@ -154,14 +196,35 @@ def _get_default_hugepage_size() -> int: @dataclasses.dataclass class CpuKVCacheMeta: page_num: int - layer_num: int token_page_size: int + layer_num: int num_heads: int head_dim: int - item_size: int + data_type: torch.dtype + scale_head_dim: int + scale_data_type: torch.dtype def calcu_size(self): - return self.page_num * self.layer_num * self.token_page_size * self.num_heads * self.head_dim * self.item_size + return self.page_num * self.calcu_one_page_size() + + def calcu_one_page_size(self): + return ( + self.token_page_size + * self.layer_num + * self.num_heads + * (self.head_dim * self.data_type.itemsize + self.scale_head_dim * self.scale_data_type.itemsize) + ) + + def 
get_merged_head_dim(self): + """ + 返回将head_dim 和 scale_head_dim 看成融合成一个head_dim时候, head_dim的长度。 + """ + assert ( + self.head_dim * self.data_type.itemsize + self.scale_head_dim * self.scale_data_type.itemsize + ) % self.data_type.itemsize == 0 + return ( + self.head_dim * self.data_type.itemsize + self.scale_head_dim * self.scale_data_type.itemsize + ) // self.data_type.itemsize @lru_cache(maxsize=None) diff --git a/lightllm/utils/llm_utils.py b/lightllm/utils/llm_utils.py new file mode 100644 index 000000000..ced75615d --- /dev/null +++ b/lightllm/utils/llm_utils.py @@ -0,0 +1,17 @@ +from functools import lru_cache +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.log_utils import init_logger + + +logger = init_logger(__name__) + + +@lru_cache(maxsize=None) +def get_llm_model_class(): + from transformers.configuration_utils import PretrainedConfig + + model_cfg, _ = PretrainedConfig.get_config_dict(get_env_start_args().model_dir) + from lightllm.models import get_model_class + + model_class = get_model_class(model_cfg=model_cfg) + return model_class diff --git a/test/start_scripts/draft.sh b/test/start_scripts/draft.sh new file mode 100644 index 000000000..866f5f2fa --- /dev/null +++ b/test/start_scripts/draft.sh @@ -0,0 +1,28 @@ +# 使能 cpu cache 功能,扩大kv cache 复用的可能性。 +LOADWORKER=18 python -m lightllm.server.api_server \ +--model_dir /mtc/models/qwen3-8b --tp 2 --dp 1 --enable_cpu_cache --cpu_cache_storage_size 66 --cpu_cache_token_page_size 128 \ +--batch_max_tokens 4096 --chunked_prefill_size 2048 \ +--max_total_token_num 20000 \ +--mode "ppl_int8kv_flashdecoding" | tee log.txt + + +# 精度评测命令 +HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions \ +--model_args '{"model":"Qwen/Qwen3-8B", "base_url":"http://localhost:8000/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code + + + +# H200 single node deepseek R1 tp mode +LOADWORKER=18 python -m lightllm.server.api_server \ +--model_dir /mtc/DeepSeek-R1 \ +--tp 8 \ +--enable_fa3 \ +--batch_max_tokens 4096 --chunked_prefill_size 2048 \ +--max_total_token_num 20000 \ +--enable_cpu_cache --cpu_cache_storage_size 66 --cpu_cache_token_page_size 128 + +# if you want to enable microbatch overlap, you can uncomment the following lines +#--enable_prefill_microbatch_overlap \ +#--enable_decode_microbatch_overlap \ +# 精度测试。 +HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8000/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code \ No newline at end of file diff --git a/test/start_scripts/single_node_tp_cpu_cache_enable.sh b/test/start_scripts/single_node_tp_cpu_cache_enable.sh new file mode 100644 index 000000000..3caabb59b --- /dev/null +++ b/test/start_scripts/single_node_tp_cpu_cache_enable.sh @@ -0,0 +1,11 @@ +# 使能 cpu cache 功能,扩大kv cache 复用的可能性。 +LOADWORKER=18 python -m lightllm.server.api_server \ +--model_dir /mtc/models/qwen3-8b --tp 2 --dp 1 --enable_cpu_cache --cpu_cache_storage_size 66 --cpu_cache_token_page_size 128 \ +--batch_max_tokens 4096 --chunked_prefill_size 2048 \ +--max_total_token_num 20000 \ +--mode "ppl_int8kv_flashdecoding" | tee log.txt + + +# 精度评测命令 +# HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions \ +# --model_args '{"model":"Qwen/Qwen3-8B", "base_url":"http://localhost:8000/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 
--confirm_run_unsafe_code
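
For reference, a rough sketch of how the new parameterless memory-manager dispatch behaves (an illustrative standalone function returning class names as strings; the real select_mem_manager_class() in lightllm/common/kv_cache_mem_manager/mem_utils.py returns the manager classes themselves and derives the mode and the model class from the start args and the model registry):

from functools import lru_cache

@lru_cache(maxsize=None)
def select_mem_manager_name(mode: str, is_deepseek2_model: bool) -> str:
    # Deepseek2-family models are resolved from the model class first,
    # before any of the generic mode-string branches are considered.
    if is_deepseek2_model:
        return "Deepseek2FP8KVMemoryManager" if "triton_fp8kv" in mode else "Deepseek2MemoryManager"
    if "ppl_int8kv" in mode or "ppl_int8kv_flashdecoding_diverse" in mode:
        return "PPLINT8KVMemoryManager"
    # ... the remaining quantized-kv mode branches are unchanged from the old mem_utils ...
    return "MemoryManager"  # default: unquantized kv cache

Because the real function takes no arguments and is lru_cache-d, every call site (basemodel, llama, qwen2, deepseek2, gemma, starcoder, ...) resolves to the same class per process, and select_mem_manager_class()(...) must only be invoked after the start args have been populated.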
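
The CPU page sizing now folds the per-head scale columns into one merged head dimension. As a worked example of CpuKVCacheMeta.calcu_one_page_size() and get_merged_head_dim() under the ppl_int8kv configuration from the start scripts above (model numbers are hypothetical, roughly a qwen3-8b-style config: 36 layers, 8 kv heads, head_dim 128; scale_head_dim = head_dim // 8 as in the patch's ppl_int8kv branch):

import torch

head_dim, scale_head_dim = 128, 128 // 8            # int8 payload, bf16 scales
data_itemsize = torch.int8.itemsize                 # 1 byte
scale_itemsize = torch.bfloat16.itemsize            # 2 bytes

# merged head_dim, expressed in units of the payload dtype
merged_head_dim = (head_dim * data_itemsize + scale_head_dim * scale_itemsize) // data_itemsize
assert merged_head_dim == 160

# one page = token_page_size * layer_num * (2 * kv_heads) * (payload bytes + scale bytes) per head
one_page_bytes = 128 * 36 * (2 * 8) * (head_dim * data_itemsize + scale_head_dim * scale_itemsize)
page_num = int(66 * 1024 ** 3 / one_page_bytes)     # ~6007 pages for --cpu_cache_storage_size 66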
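
On the offload/load path, the shared-memory CPU tensor keeps payload and scales packed along that merged last dimension, and the caller splits it back into two views before invoking the copy kernels. A condensed sketch of what multi_level_kv_cache.py now does (head_dim and the scale dtype come from the cache meta and the mem manager's scale_buffer):

import torch

def split_merged_cpu_cache(cpu_kv_cache_tensor: torch.Tensor,
                           head_dim: int, scale_dtype: torch.dtype):
    # cpu_kv_cache_tensor: (page_num, layer_num, token_page_size, num_heads, merged_head_dim)
    payload = cpu_kv_cache_tensor[:, :, :, :, 0:head_dim]
    # the tail of the last dim stores the scales, reinterpreted in their own dtype
    scales = cpu_kv_cache_tensor[:, :, :, :, head_dim:].view(scale_dtype)
    return payload, scales

When the mem manager has no scale_buffer, the callers pass the whole tensor as cpu_kv_cache and None for both scale arguments, which reduces to the pre-patch behavior.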
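
Inside the Triton kernels the optional scale tensors follow a compile-out pattern: a HAS_SCALE constexpr guards every scale access, and when no scales exist the callers pass None for the scale pointers plus zero-filled placeholder strides, so the dead branch never dereferences them. A minimal toy kernel showing the same pattern (not the patch's kernel):

import torch
import triton
import triton.language as tl

@triton.jit
def _copy_kernel(src_ptr, dst_ptr, src_scale_ptr, dst_scale_ptr,
                 n, HAS_SCALE: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)
    if HAS_SCALE:
        # compiled out entirely when HAS_SCALE is False, so None pointers are safe
        tl.store(dst_scale_ptr + offs, tl.load(src_scale_ptr + offs, mask=mask), mask=mask)

def copy_with_optional_scale(src, dst, src_scale=None, dst_scale=None):
    has_scale = src_scale is not None and dst_scale is not None
    grid = (triton.cdiv(src.numel(), 256),)
    _copy_kernel[grid](src, dst, src_scale, dst_scale, src.numel(),
                       HAS_SCALE=has_scale, BLOCK=256)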