from huggingface_hub import hf_hub_download, HfApi
from typing import Optional, List, Dict
from pathlib import Path
+import transformers

from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
+
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.custom_modeling.mpt_modeling import (
    MPTForCausalLM,
if MAMBA_AVAILABLE:
    __all__.append(Mamba)

+FLASH_TRANSFORMERS_BACKEND = True
+try:
+    from text_generation_server.models.transformers_flash_causal_lm import (
+        TransformersFlashCausalLM,
+    )
+except ImportError:
+    FLASH_TRANSFORMERS_BACKEND = False
+

class ModelType(enum.Enum):
    DEEPSEEK_V2 = {
@@ -381,6 +391,21 @@ def get_model(
    )
    model_type = config_dict.get("model_type", None)

+    transformers_causal_lm_class = CausalLM
+
+    # Fast transformers path
+    transformers_model_class = getattr(
+        transformers,
+        modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, ""),
+        None,
+    )
+    if (
+        FLASH_TRANSFORMERS_BACKEND
+        and transformers_model_class is not None
+        and transformers_model_class._supports_flex_attn
+    ):
+        transformers_causal_lm_class = TransformersFlashCausalLM
+
    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is None:
        quantization_config = config_dict.get("compression_config", None)
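Taken together with the guarded import above, this hunk is the core of the change: every `CausalLM.fallback(...)` call site further down now dispatches through `transformers_causal_lm_class`, which stays `CausalLM` unless the flash backend imported cleanly and the `transformers` auto-model class resolved for `model_type` reports flex-attention support. A minimal standalone sketch of that selection, using a hypothetical helper and stand-in arguments (none of this exists in the PR itself), might look like:

import transformers
from transformers.models.auto import modeling_auto


def pick_fallback_class(model_type, flash_backend_available, flash_cls, plain_cls):
    # Resolve the class transformers registers for this model_type,
    # e.g. "llama" -> transformers.LlamaForCausalLM; None if unknown.
    architecture = modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, "")
    transformers_model_class = getattr(transformers, architecture, None)

    # Prefer the flash-attention backend only when it imported successfully
    # and the upstream class advertises flex-attention support.
    if (
        flash_backend_available
        and transformers_model_class is not None
        and getattr(transformers_model_class, "_supports_flex_attn", False)
    ):
        return flash_cls
    return plain_cls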
@@ -624,7 +649,7 @@ def get_model(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
            )
        else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -683,7 +708,7 @@ def get_model(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id=model_id,
                revision=revision,
                quantize=quantize,
@@ -731,7 +756,7 @@ def get_model(
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
                log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
-                return CausalLM.fallback(
+                return transformers_causal_lm_class.fallback(
                    model_id,
                    revision,
                    quantize=quantize,
@@ -742,7 +767,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -767,7 +792,7 @@ def get_model(
            except RuntimeError as e:
                # Lots of legacy models with various weight names.
                log_master(logger.warning, f"Couldn't load flash gptj variant: {e}")
-                return CausalLM.fallback(
+                return transformers_causal_lm_class.fallback(
                    model_id,
                    revision,
                    quantize=quantize,
@@ -778,7 +803,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J"))
        else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -815,7 +840,7 @@ def get_model(
                trust_remote_code=trust_remote_code,
            )
        else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -838,7 +863,7 @@ def get_model(
                lora_adapter_ids=lora_adapter_ids,
            )
        else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -862,7 +887,7 @@ def get_model(
                lora_adapter_ids=lora_adapter_ids,
            )
        else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -911,7 +936,7 @@ def get_model(
                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
            )
        else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -937,7 +962,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
        else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -963,7 +988,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
        else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -988,7 +1013,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -1016,7 +1041,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
        else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -1066,7 +1091,7 @@ def get_model(
                config_class=RWConfig,
            )
        else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -1091,7 +1116,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
        else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -1116,7 +1141,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
        else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -1143,7 +1168,7 @@ def get_model(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
            )
        else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -1168,7 +1193,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
        else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id,
                revision,
                quantize=quantize,
@@ -1329,7 +1354,7 @@ def get_model(
    elif quantize == "exl2":
        raise NotImplementedError("exl2 quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        return CausalLM.fallback(
+        return transformers_causal_lm_class.fallback(
            model_id,
            revision,
            quantize=quantize,
@@ -1350,7 +1375,7 @@ def get_model(
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                model_id,
                revision,
                quantize=quantize,