Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flash Transformers modeling backend support #2913

Merged
merged 33 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
ade0f44
add transformers_flash
Cyrilvallez Dec 10, 2024
da22290
inits
Cyrilvallez Dec 10, 2024
b3b0747
switch version to make it work
Cyrilvallez Dec 10, 2024
738f0b0
Update Makefile-flash-att-v2
Cyrilvallez Dec 10, 2024
a84ecf2
Update Makefile-flash-att-v2
Cyrilvallez Dec 10, 2024
372799a
Update Makefile-flash-att-v2
Cyrilvallez Dec 10, 2024
a0035e6
Update Makefile-flash-att-v2
Cyrilvallez Dec 11, 2024
e69a384
Update Makefile-flash-att-v2
Cyrilvallez Dec 11, 2024
3a636ed
Update Makefile-flash-att-v2
Cyrilvallez Dec 11, 2024
649cb1f
runnable version
Dec 12, 2024
490ca0e
working
Dec 12, 2024
f843b62
push change
Cyrilvallez Dec 12, 2024
715b2d1
fix high dim
Cyrilvallez Dec 13, 2024
e93ab92
init
Cyrilvallez Dec 13, 2024
f4c60ca
default
Cyrilvallez Dec 13, 2024
2e2631e
latest transformers changes
Cyrilvallez Dec 19, 2024
44b3679
revert
Cyrilvallez Dec 19, 2024
266377b
simplify check
Cyrilvallez Jan 15, 2025
32488c1
remove flag
Cyrilvallez Jan 17, 2025
ac62bd1
improve type hints + required args
Cyrilvallez Jan 17, 2025
b03d7ae
Update based on transformers PR
Cyrilvallez Jan 17, 2025
b40c889
small fix
Cyrilvallez Jan 17, 2025
42ae6de
Remove Warpers for Processor
Cyrilvallez Jan 17, 2025
f01014d
fix compatibility version issue
Cyrilvallez Jan 17, 2025
2659b59
raise error if needed
Cyrilvallez Jan 20, 2025
a2fe842
Simplify with monkey patch
Cyrilvallez Jan 20, 2025
6e0f37c
revert + style + minor improvements
Cyrilvallez Jan 20, 2025
52afdcc
update comment
Cyrilvallez Jan 20, 2025
9af3ea4
device check
Cyrilvallez Jan 20, 2025
6d9c011
move the import to avoid device issue
Cyrilvallez Jan 20, 2025
2ef3002
Update __init__.py
Cyrilvallez Jan 20, 2025
70ada57
check for non-native models
Cyrilvallez Jan 20, 2025
0d9ec75
oupsi
Cyrilvallez Jan 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions server/text_generation_server/layers/gptq/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,15 +956,24 @@ def _unload():

pack(model, quantizers, bits, groupsize)
from safetensors.torch import save_file
from transformers.modeling_utils import shard_checkpoint
from huggingface_hub import split_torch_state_dict_into_shards

state_dict = model.state_dict()
state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}

max_shard_size = "10GB"
shards, index = shard_checkpoint(
state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors"
state_dict_split = split_torch_state_dict_into_shards(
state_dict,
filename_pattern="model.safetensors",
max_shard_size=max_shard_size,
)
index = None
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
shards = state_dict_split.filename_to_tensors
os.makedirs(output_dir, exist_ok=True)
for shard_file, shard in shards.items():
save_file(
Expand Down
67 changes: 46 additions & 21 deletions server/text_generation_server/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional, List, Dict
from pathlib import Path
import transformers

from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast

from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.custom_modeling.mpt_modeling import (
MPTForCausalLM,
Expand Down Expand Up @@ -175,6 +177,14 @@
if MAMBA_AVAILABLE:
__all__.append(Mamba)

FLASH_TRANSFORMERS_BACKEND = True
try:
from text_generation_server.models.transformers_flash_causal_lm import (
TransformersFlashCausalLM,
)
except ImportError:
FLASH_TRANSFORMERS_BACKEND = False


class ModelType(enum.Enum):
DEEPSEEK_V2 = {
Expand Down Expand Up @@ -372,6 +382,21 @@ def get_model(
)
model_type = config_dict.get("model_type", None)

transformers_causal_lm_class = CausalLM

# Fast transformers path
transformers_model_class = getattr(
transformers,
modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, ""),
None,
)
if (
FLASH_TRANSFORMERS_BACKEND
and transformers_model_class is not None
and transformers_model_class._supports_flex_attn
):
transformers_causal_lm_class = TransformersFlashCausalLM

quantization_config = config_dict.get("quantization_config", None)
if quantization_config is None:
quantization_config = config_dict.get("compression_config", None)
Expand Down Expand Up @@ -615,7 +640,7 @@ def get_model(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
)
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
Expand Down Expand Up @@ -674,7 +699,7 @@ def get_model(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
)
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id=model_id,
revision=revision,
quantize=quantize,
Expand Down Expand Up @@ -722,7 +747,7 @@ def get_model(
except RuntimeError as e:
# Lots of legacy models with various weight names.
log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
Expand All @@ -733,7 +758,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
Expand All @@ -758,7 +783,7 @@ def get_model(
except RuntimeError as e:
# Lots of legacy models with various weight names.
log_master(logger.warning, f"Couldn't load flash gptj variant: {e}")
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
Expand All @@ -769,7 +794,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J"))
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
Expand Down Expand Up @@ -806,7 +831,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
Expand All @@ -829,7 +854,7 @@ def get_model(
lora_adapter_ids=lora_adapter_ids,
)
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
Expand All @@ -853,7 +878,7 @@ def get_model(
lora_adapter_ids=lora_adapter_ids,
)
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
Expand Down Expand Up @@ -902,7 +927,7 @@ def get_model(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
)
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
Expand All @@ -928,7 +953,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
Expand All @@ -954,7 +979,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
Expand All @@ -979,7 +1004,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
Expand Down Expand Up @@ -1007,7 +1032,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
Expand Down Expand Up @@ -1057,7 +1082,7 @@ def get_model(
config_class=RWConfig,
)
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
Expand All @@ -1082,7 +1107,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
Expand All @@ -1107,7 +1132,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
Expand All @@ -1134,7 +1159,7 @@ def get_model(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
)
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
Expand All @@ -1159,7 +1184,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
Expand Down Expand Up @@ -1302,7 +1327,7 @@ def get_model(
elif quantize == "exl2":
raise NotImplementedError("exl2 quantization is not supported for AutoModel")
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
Expand All @@ -1323,7 +1348,7 @@ def get_model(
auto_map = config_dict.get("auto_map", None)
if trust_remote_code and auto_map is not None:
if "AutoModelForCausalLM" in auto_map.keys():
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
Expand Down
Loading
Loading