
Commit b980848

Cyrilvallez and System administrator authored Jan 21, 2025
Flash Transformers modeling backend support (#2913)
* add transformers_flash
* inits
* switch version to make it work
* Update Makefile-flash-att-v2
* Update Makefile-flash-att-v2
* Update Makefile-flash-att-v2
* Update Makefile-flash-att-v2
* Update Makefile-flash-att-v2
* Update Makefile-flash-att-v2
* runnable version
* working
* push change
* fix high dim
* init
* default
* latest transformers changes
* revert
* simplify check
* remove flag
* improve type hints + required args
* Update based on transformers PR
* small fix
* Remove Warpers for Processor
* fix compatibility version issue
* raise error if needed
* Simplify with monkey patch
* revert + style + minor improvements
* update comment
* device check
* move the import to avoid device issue
* Update __init__.py
* check for non-native models
* oupsi

Co-authored-by: System administrator <root@ip-10-90-0-159.ec2.internal>
1 parent 447a5b2 commit b980848

File tree: 4 files changed (+330, -31 lines)


server/text_generation_server/layers/gptq/quantize.py (+12, -3)
@@ -956,15 +956,24 @@ def _unload():

     pack(model, quantizers, bits, groupsize)
     from safetensors.torch import save_file
-    from transformers.modeling_utils import shard_checkpoint
+    from huggingface_hub import split_torch_state_dict_into_shards

     state_dict = model.state_dict()
     state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}

     max_shard_size = "10GB"
-    shards, index = shard_checkpoint(
-        state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors"
+    state_dict_split = split_torch_state_dict_into_shards(
+        state_dict,
+        filename_pattern="model.safetensors",
+        max_shard_size=max_shard_size,
     )
+    index = None
+    if state_dict_split.is_sharded:
+        index = {
+            "metadata": state_dict_split.metadata,
+            "weight_map": state_dict_split.tensor_to_filename,
+        }
+    shards = state_dict_split.filename_to_tensors
     os.makedirs(output_dir, exist_ok=True)
     for shard_file, shard in shards.items():
         save_file(
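
For reference, a minimal sketch of the replacement API (based on the huggingface_hub documentation, not code from this commit): split_torch_state_dict_into_shards only plans the split, returning shard file names and the tensor names each shard should contain, so the caller rebuilds each shard dict and writes the optional index itself. The helper name and directory handling below are illustrative.

import json
import os

from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file


def save_sharded(state_dict: dict, output_dir: str) -> None:
    # Plan the split; filename_to_tensors maps each shard file name to a list of tensor *names*.
    split = split_torch_state_dict_into_shards(state_dict, max_shard_size="10GB")
    os.makedirs(output_dir, exist_ok=True)
    for shard_file, tensor_names in split.filename_to_tensors.items():
        shard = {name: state_dict[name].contiguous() for name in tensor_names}
        save_file(shard, os.path.join(output_dir, shard_file))
    # Only sharded checkpoints need an index mapping tensors to shard files.
    if split.is_sharded:
        index = {"metadata": split.metadata, "weight_map": split.tensor_to_filename}
        with open(os.path.join(output_dir, "model.safetensors.index.json"), "w") as f:
            json.dump(index, f, indent=2)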

server/text_generation_server/models/__init__.py (+46, -21)
@@ -16,10 +16,12 @@
 from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional, List, Dict
 from pathlib import Path
+import transformers

 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
+
 from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
 from text_generation_server.models.custom_modeling.mpt_modeling import (
     MPTForCausalLM,
@@ -178,6 +180,14 @@
 if MAMBA_AVAILABLE:
     __all__.append(Mamba)

+FLASH_TRANSFORMERS_BACKEND = True
+try:
+    from text_generation_server.models.transformers_flash_causal_lm import (
+        TransformersFlashCausalLM,
+    )
+except ImportError:
+    FLASH_TRANSFORMERS_BACKEND = False
+

 class ModelType(enum.Enum):
     DEEPSEEK_V2 = {
@@ -381,6 +391,21 @@ def get_model(
     )
     model_type = config_dict.get("model_type", None)

+    transformers_causal_lm_class = CausalLM
+
+    # Fast transformers path
+    transformers_model_class = getattr(
+        transformers,
+        modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, ""),
+        None,
+    )
+    if (
+        FLASH_TRANSFORMERS_BACKEND
+        and transformers_model_class is not None
+        and transformers_model_class._supports_flex_attn
+    ):
+        transformers_causal_lm_class = TransformersFlashCausalLM
+
     quantization_config = config_dict.get("quantization_config", None)
     if quantization_config is None:
         quantization_config = config_dict.get("compression_config", None)
@@ -624,7 +649,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -683,7 +708,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id=model_id,
                 revision=revision,
                 quantize=quantize,
@@ -731,7 +756,7 @@ def get_model(
         except RuntimeError as e:
             # Lots of legacy models with various weight names.
             log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -742,7 +767,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -767,7 +792,7 @@ def get_model(
         except RuntimeError as e:
             # Lots of legacy models with various weight names.
             log_master(logger.warning, f"Couldn't load flash gptj variant: {e}")
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -778,7 +803,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -815,7 +840,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -838,7 +863,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -862,7 +887,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -911,7 +936,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -937,7 +962,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -963,7 +988,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -988,7 +1013,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1016,7 +1041,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1066,7 +1091,7 @@ def get_model(
                 config_class=RWConfig,
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1091,7 +1116,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1116,7 +1141,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1143,7 +1168,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1168,7 +1193,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1329,7 +1354,7 @@ def get_model(
     elif quantize == "exl2":
         raise NotImplementedError("exl2 quantization is not supported for AutoModel")
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        return CausalLM.fallback(
+        return transformers_causal_lm_class.fallback(
             model_id,
             revision,
             quantize=quantize,
@@ -1350,7 +1375,7 @@ def get_model(
     auto_map = config_dict.get("auto_map", None)
     if trust_remote_code and auto_map is not None:
         if "AutoModelForCausalLM" in auto_map.keys():
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
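
The hunks above, from the Deepseek V2 branch through the trust_remote_code path, are all the same one-line substitution: the call sites keep their arguments and only the receiving class changes, because fallback is resolved on whichever class get_model selected up front. A toy illustration of that dispatch pattern (simplified classes and signatures, not the real TGI ones):

from typing import Optional


class CausalLM:
    @classmethod
    def fallback(cls, model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None):
        # The real method loads a model; here we only report which class handled the call.
        return f"{cls.__name__}(model_id={model_id!r}, revision={revision!r}, quantize={quantize!r})"


class TransformersFlashCausalLM(CausalLM):
    pass


# Chosen once near the top of get_model(); every fallback branch reuses it unchanged.
transformers_causal_lm_class = TransformersFlashCausalLM
print(transformers_causal_lm_class.fallback("some/model-id", revision="main"))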
