17 changes: 3 additions & 14 deletions lightllm/common/basemodel/basemodel.py
@@ -11,7 +11,7 @@

 from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
 from lightllm.common.basemodel.infer_struct import InferStateInfo
-from lightllm.common.mem_manager import MemoryManager
+from lightllm.common.kv_cache_mem_manager import MemoryManager
 from lightllm.common.req_manager import ReqManager
 from lightllm.common.infer_utils import init_req_to_token_indexes
 from lightllm.common.build_utils import repair_config
@@ -22,7 +22,7 @@
 from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.dist_utils import get_dp_world_size
-from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type
 from lightllm.distributed.communication_op import dist_group_manager
 from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
 from lightllm.common.triton_utils.autotuner import AutotuneLevel
@@ -68,7 +68,7 @@ def __init__(self, kvargs):
         self.is_token_healing = kvargs.get("is_token_healing", False)
         self.return_all_prompt_logics = kvargs.get("return_all_prompt_logics", False)
         assert not (self.is_token_healing and self.return_all_prompt_logics), "can not be true in same time"
-        self.data_type = kvargs.get("data_type", "float16")
+        self.data_type = get_llm_data_type()
         mtp_step = get_env_start_args().mtp_step
         self.graph_max_batch_size = kvargs.get("graph_max_batch_size", 16)
         self.graph_max_batch_size = (
@@ -89,7 +89,6 @@ def __init__(self, kvargs):

         self.is_deepseekv3_mtp_mode = self.args.mtp_mode in ["deepseekv3_vanilla", "deepseekv3_eagle"]

-        self._init_datatype()
         self._init_config()
         self._verify_must()
         self._verify_params()
@@ -230,16 +229,6 @@ def _init_some_value(self):
         self.vocab_size = self.config["vocab_size"]
         return

-    def _init_datatype(self):
-        if self.data_type in ["fp16", "float16"]:
-            self.data_type = torch.float16
-        elif self.data_type in ["bf16", "bfloat16"]:
-            self.data_type = torch.bfloat16
-        elif self.data_type in ["fp32", "float32"]:
-            self.data_type = torch.float32
-        else:
-            raise ValueError(f"Unsupport datatype {self.data_type}!")
-
     def _init_cudagraph(self):
         self.graph = (
             None if self.disable_cudagraph else CudaGraph(self.graph_max_batch_size, self.graph_max_len_in_batch)
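Note on the dtype change: the removed _init_datatype() parsed a per-model data_type string from kvargs into a torch.dtype, so every model instance carried its own copy of that mapping. This PR resolves the dtype once through the shared get_llm_data_type() helper imported from lightllm.utils.envs_utils. The helper's body is not part of this diff; the following is a minimal sketch of what it presumably does, reconstructed from the removed mapping. Reading the string from get_env_start_args().data_type is an assumption, not confirmed by the diff.

import torch
from lightllm.utils.envs_utils import get_env_start_args

def get_llm_data_type() -> torch.dtype:
    # Assumption: the dtype string now comes from the global server start
    # args rather than per-model kvargs; the string-to-dtype mapping below
    # mirrors the removed _init_datatype() exactly.
    data_type = get_env_start_args().data_type
    if data_type in ["fp16", "float16"]:
        return torch.float16
    elif data_type in ["bf16", "bfloat16"]:
        return torch.bfloat16
    elif data_type in ["fp32", "float32"]:
        return torch.float32
    raise ValueError(f"Unsupported datatype {data_type}!")

Centralizing the lookup means the model, memory manager, and infer state all see the same resolved dtype instead of each call site re-parsing its own string.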
2 changes: 1 addition & 1 deletion lightllm/common/basemodel/infer_struct.py
@@ -1,7 +1,7 @@
 import torch
 import triton
 import collections
-from lightllm.common.mem_manager import MemoryManager
+from lightllm.common.kv_cache_mem_manager import MemoryManager
 from lightllm.common.req_manager import ReqManager
 from lightllm.distributed import CustomProcessGroup
 from typing import Tuple, Any, Optional, List