Commit 2ef3002

Update __init__.py
1 parent 6d9c011 commit 2ef3002

File tree

1 file changed: +9 -8 lines changed

server/text_generation_server/models/__init__.py (+9 -8)

@@ -83,9 +83,6 @@
 
 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
-    from text_generation_server.models.transformers_flash_causal_lm import (
-        TransformersFlashCausalLM,
-    )
     from text_generation_server.models.vlm_causal_lm import VlmCausalLM
     from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
     from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
@@ -180,6 +177,14 @@
 if MAMBA_AVAILABLE:
     __all__.append(Mamba)
 
+FLASH_TRANSFORMERS_BACKEND = True
+try:
+    from text_generation_server.models.transformers_flash_causal_lm import (
+        TransformersFlashCausalLM,
+    )
+except ImportError:
+    FLASH_TRANSFORMERS_BACKEND = False
+
 
 class ModelType(enum.Enum):
     DEEPSEEK_V2 = {
@@ -384,12 +389,8 @@ def get_model(
         transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
     )
 
-    if FLASH_ATTENTION and transformers_model_class._supports_flex_attn:
+    if FLASH_TRANSFORMERS_BACKEND and transformers_model_class._supports_flex_attn:
         transformers_causal_lm_class = TransformersFlashCausalLM
-        if lora_adapter_ids is not None and len(lora_adapter_ids) > 0:
-            raise ValueError(
-                "Flash `Transformers` modeling backend does not support `lora_adapter_ids`."
-            )
 
     quantization_config = config_dict.get("quantization_config", None)
     if quantization_config is None:
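
For readers skimming the diff: the commit moves the TransformersFlashCausalLM import out of the shared try block into its own try/except ImportError, records the outcome in a FLASH_TRANSFORMERS_BACKEND flag, and has get_model branch on that flag instead of FLASH_ATTENTION. Below is a minimal standalone sketch of that guarded-import pattern, assuming placeholder names (optional_backend, FastModel, pick_causal_lm_class) that are not part of the repository.

# Sketch of the guarded-import pattern used above; names are illustrative only.
FAST_BACKEND_AVAILABLE = True
try:
    # Fails cleanly if the optional dependency (or anything it imports,
    # such as a flash-attention kernel) is missing at startup.
    from optional_backend import FastModel  # hypothetical optional import
except ImportError:
    FAST_BACKEND_AVAILABLE = False


def pick_causal_lm_class(default_class, supports_flex_attn: bool):
    # Mirrors the branch shown in get_model: switch to the fast class only
    # when the optional import succeeded and the model supports flex attention.
    if FAST_BACKEND_AVAILABLE and supports_flex_attn:
        return FastModel
    return default_class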
