From ade0f44aca75f5a5e9dd29141fb57013070e15fc Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Tue, 10 Dec 2024 16:46:55 +0100
Subject: [PATCH 01/33] add transformers_flash

---
 .../text_generation_server/models/__init__.py |  29 +-
 .../text_generation_server/models/globals.py  |   4 +
 .../models/transformers_flash_causal_lm.py    | 309 ++++++++++++++++++
 3 files changed, 341 insertions(+), 1 deletion(-)
 create mode 100644 server/text_generation_server/models/transformers_flash_causal_lm.py

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index fcc79608645..2f3ccc2dce1 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -20,6 +20,7 @@
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
+from text_generation_server.models.transformers_flash_causal_lm import TransformersFlashCausalLM
 from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
 from text_generation_server.models.custom_modeling.mpt_modeling import (
     MPTForCausalLM,
@@ -28,7 +29,7 @@
 from text_generation_server.models.custom_modeling.bloom_modeling import (
     BloomForCausalLM,
 )
-from text_generation_server.models.globals import ATTENTION
+from text_generation_server.models.globals import ATTENTION, USE_CUSTOM_MODELING
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import GalacticaCausalLMBatch
 from text_generation_server.models.custom_modeling.neox_modeling import (
@@ -366,12 +367,38 @@ def get_model(
     max_input_tokens: int,
 ) -> Model:
     global FLASH_ATTENTION
+    global USE_CUSTOM_MODELING
 
     config_dict, _ = PretrainedConfig.get_config_dict(
         model_id, revision=revision, trust_remote_code=trust_remote_code
     )
     model_type = config_dict.get("model_type", None)
 
+    transformers_causal_lm_class = CausalLM
+    if (
+        not USE_CUSTOM_MODELING
+        and model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+    ):
+        logger.info(
+            "TGI's flash enabled models could either not be loaded or are disabled, using Transformers fallback."
+        )
+        transformers_model_class = getattr(
+            transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
+        )
+
+        if (
+            transformers_model_class._supports_flash_attn_2
+            and transformers_model_class._supports_cache_class
+        ):
+            logger.info(
+                f"Transformers' {model_type} implementation supports custom cache and flash/paged attention. Using TransformersFlashCausalLM with ragged tensors (single dimension for batch and sequence length)."
+            )
+            transformers_causal_lm_class = TransformersFlashCausalLM
+        else:
+            logger.info(
+                f"Transformers' {model_type} implementation supports custom cache and flash/paged attention. Using TransformersCausalLM with classic tensors with padding (two dimensions for batch size and sequence length)."
+ ) + quantization_config = config_dict.get("quantization_config", None) if quantization_config is None: quantization_config = config_dict.get("compression_config", None) diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 8d988ad5870..7d6639f218b 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -67,3 +67,7 @@ def set_adapter_to_index(adapter_to_index: Dict[str, int]): def get_adapter_to_index(): global ADAPTER_TO_INDEX return ADAPTER_TO_INDEX + + +USE_CUSTOM_MODELING = os.getenv("USE_CUSTOM_MODELING", "true") +USE_CUSTOM_MODELING = USE_CUSTOM_MODELING == "true" or USE_CUSTOM_MODELING == "1" diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py new file mode 100644 index 00000000000..ff76b2cc203 --- /dev/null +++ b/server/text_generation_server/models/transformers_flash_causal_lm.py @@ -0,0 +1,309 @@ +import math +import sys +from typing import Optional, Tuple, Dict, Any + +import torch +from opentelemetry import trace +from loguru import logger +from transformers import AutoTokenizer, AutoModelForCausalLM + +from text_generation_server.models.flash_causal_lm import ( + FlashCausalLMBatch, + FlashCausalLM, +) +from text_generation_server.utils.import_utils import ( + empty_cache, + synchronize, + get_free_memory, +) +from text_generation_server.adapters import AdapterBatchData +from text_generation_server.layers.attention import paged_attention, attention, Seqlen +from text_generation_server.layers.attention.kv_cache import KVScales +from text_generation_server.models.globals import ATTENTION +from text_generation_server.models.metadata_kernels import block_tables_to_ragged + + +tracer = trace.get_tracer(__name__) + + +def patch_everywhere( + attribute_name: str, patch: Any, module_name_prefix: Optional[str] = None +): + """ + Finds all occurences of `attribute_name` in the loaded modules and patches them with `patch`. + + Args: + attribute_name (`str`): + The name of attribute to patch. + patch (`Any`): + The patch for the attribute. + module_name_prefix (`Optional[str]`, defaults to `None`): + If set, only module names starting with this prefix will be considered for patching. + """ + # sys.modules may be updated while being iterated over, hence the list copy. 
+ for name in list(sys.modules): + module = sys.modules[name] + if module_name_prefix is not None and not name.startswith(module_name_prefix): + continue + if hasattr(module, attribute_name): + setattr(module, attribute_name, patch) + + +def _flash_attention_forward_patched( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, + is_causal: bool, + softmax_scale: Optional[float] = None, + sliding_window: int = -1, + softcap: Optional[float] = None, + **kwargs, +): + + kv_cache = kwargs["kv_cache"][kwargs["layer_idx"]] + # This means no scale + kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device)) + + # Correctly reshape the states + _, _, num_heads, head_dim = query_states.size() + _, _, num_kv_heads, _ = key_states.size() + query_states = query_states.view(-1, num_heads, head_dim) + key_states = key_states.view(-1, num_kv_heads, head_dim) + value_states = value_states.view(-1, num_kv_heads, head_dim) + + # Take care of updating the cache in-place + kv_cache.store( + key=key_states, + value=value_states, + slots=kwargs["slots"], + kv_scales=kv_scales + ) + + softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale + + if kwargs["cu_seqlen_prefill"] is not None: + attn_output = attention( + query=query_states, + key=key_states, + value=value_states, + kv_cache=kv_cache, + kv_scales=kv_scales, + seqlen=kwargs["seqlen"], + block_tables=kwargs["block_tables"], + softmax_scale=softmax_scale, + window_size_left=sliding_window, + softcap=softcap, + ) + else: + attn_output = paged_attention( + query_states, + kv_cache, + kwargs["kv_head_mapping"], + softmax_scale, + kwargs["block_tables"], + kwargs["seqlen"], + kwargs["max_s"], + kv_scales=kv_scales, + softcap=softcap, + ) + + attn_output = attn_output.view(attn_output.shape[0], -1) + + return attn_output + + +class TransformersFlashCausalLM(FlashCausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + if speculator: + raise RuntimeError("Speculator decoding is not enabled for AutoModel") + + device_count = 0 + if torch.cuda.is_available(): + device = torch.device("cuda") + device_count = torch.cuda.device_count() + dtype = torch.float16 if dtype is None else dtype + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device("xpu") + device_count = torch.xpu.device_count() + dtype = torch.float16 if dtype is None else dtype + else: + if quantize: + raise ValueError("quantization is not available on CPU") + + device = torch.device("cpu") + dtype = torch.float32 if dtype is None else dtype + + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + model = AutoModelForCausalLM.from_pretrained( + model_id, + revision=revision, + torch_dtype=dtype, + device_map=("auto" if device_count > 1 else None), + load_in_8bit=quantize == "bitsandbytes", + trust_remote_code=trust_remote_code, + ) + if device_count == 1 and quantize != "bitsandbytes": + model = model.to(device) + + if tokenizer.pad_token_id is None: + if model.config.pad_token_id is not None: + tokenizer.pad_token_id = model.config.pad_token_id + elif model.config.eos_token_id is not None and isinstance( + 
model.config.eos_token_id, int + ): + tokenizer.pad_token_id = model.config.eos_token_id + elif tokenizer.eos_token_id is not None: + tokenizer.pad_token_id = tokenizer.eos_token_id + else: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + + + self.num_layers = len(model.model.layers) + self.num_kv_heads = model.config.num_key_value_heads + self.head_size = model.config.hidden_size // model.config.num_attention_heads + + # Skip FlashCausalLM init. + super(FlashCausalLM, self).__init__( + model_id=model_id, + model=model, + tokenizer=tokenizer, + requires_padding=False, + dtype=dtype, + device=device, + ) + + def warmup(self, batch: FlashCausalLMBatch): + patch_everywhere("_flash_attention_forward", _flash_attention_forward_patched) + super().warmup(batch) + + def forward( + self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + # NOTE: adapter_data: not supported + + input_ids = batch.input_ids + position_ids = batch.position_ids + cu_seqlen_prefill = batch.cu_seqlen_prefill + kv_cache = self.kv_cache + block_tables = batch.block_tables_tensor + slots = batch.slots[batch.slot_indices] + input_lengths = batch.input_lengths_tensor + cache_lengths_tensor = batch.cache_lengths_tensor + max_s = batch.max_current_length + lm_head_indices = batch.prefill_head_indices + + if cu_seqlen_prefill is None and self.max_past() is not None: + # In decode, not prefill, we're actually overwriting the KV-cache + # in a circular buffer mode. + # This makes sure the max_s for the decode pass is correct. + max_s = min(self.max_past(), max_s) + + bs = input_ids.shape[0] + sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) + if sorted_padded_bs: + # Get associated cuda graph + cuda_graph = self.cuda_graphs[sorted_padded_bs[0]] + else: + cuda_graph = None + + if cu_seqlen_prefill is not None or cuda_graph is None: + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + cache_lengths=batch.cache_lengths, + input_lengths_tensor=batch.input_lengths_tensor, + cache_lengths_tensor=batch.cache_lengths_tensor, + max_current_length=batch.max_current_length, + ) + with self._forward_context( + block_tables=block_tables, + cu_seqlen_prefill=cu_seqlen_prefill, + input_lengths_tensor=input_lengths, + cache_lengths_tensor=cache_lengths_tensor, + ): + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + max_q=batch.max_input_length, + max_k=batch.max_current_length, + ) + logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=None, + use_cache=False, # we use self.kv_cache instead of transformers cache object + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + prefill_cache_indices=batch.prefill_cache_indices, + lm_head_indices=lm_head_indices, + ) + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + return logits, None + + # Copy inputs to the static inputs of the cuda graph + # Static inputs are potentially padded + cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids + cuda_graph["position_ids"][: position_ids.shape[-1]] = position_ids + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + cache_lengths=batch.cache_lengths, + 
input_lengths_tensor=batch.input_lengths_tensor, + cache_lengths_tensor=batch.cache_lengths_tensor, + max_current_length=batch.max_current_length, + ) + # assert block_tables.shape[0] >= slots.shape[0] + cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables + else: + cuda_graph["block_tables"][ + : block_tables.shape[0], : block_tables.shape[1] + ] = block_tables + + # XXX: This is working only because block 0 is reserved for the healthcheck + # so it doesn't matter if we override it with bogus values. + cuda_graph["slots"].fill_(0) + cuda_graph["slots"][: slots.shape[0]] = slots + cuda_graph["input_lengths"].zero_() + cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths + cuda_graph["cache_lengths"].zero_() + cuda_graph["cache_lengths"][ + : cache_lengths_tensor.shape[0] + ] = cache_lengths_tensor + + with self._forward_context( + block_tables=cuda_graph["block_tables"], + cu_seqlen_prefill=None, + input_lengths_tensor=cuda_graph["input_lengths"], + cache_lengths_tensor=cuda_graph["cache_lengths"], + state=cuda_graph["state"], + ): + # Replay the graph + cuda_graph["graph"].replay() + + # Slice output to the correct shape + logits = cuda_graph["logits"][:bs] + return logits, None From da222900a1076733f052d9ffab72a607aca85375 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 10 Dec 2024 16:57:07 +0100 Subject: [PATCH 02/33] inits --- .../text_generation_server/models/__init__.py | 89 ++++++++++--------- .../models/transformers_flash_causal_lm.py | 12 +++ 2 files changed, 57 insertions(+), 44 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 2f3ccc2dce1..35ab8edec4e 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -16,6 +16,7 @@ from huggingface_hub import hf_hub_download, HfApi from typing import Optional, List, Dict from pathlib import Path +import transformers from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.models.model import Model @@ -617,7 +618,7 @@ def get_model( ) if model_type == DEEPSEEK_V2: - if FLASH_ATTENTION: + if FLASH_ATTENTION and USE_CUSTOM_MODELING: head_size = max( config_dict.get("qk_nope_dim", 128) + config_dict.get("qk_rope_dim", 64), @@ -642,7 +643,7 @@ def get_model( FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2") ) else: - return CausalLM.fallback( + return transformers_causal_lm_class.fallback( model_id, revision, quantize=quantize, @@ -682,7 +683,7 @@ def get_model( or model_type == GPT2 and model_id.startswith("bigcode/") ): - if FLASH_ATTENTION: + if FLASH_ATTENTION and USE_CUSTOM_MODELING: return FlashCausalLM( model_id=model_id, model_class=FlashSantacoderForCausalLM, @@ -701,7 +702,7 @@ def get_model( FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder") ) else: - return CausalLM.fallback( + return transformers_causal_lm_class.fallback( model_id=model_id, revision=revision, quantize=quantize, @@ -733,7 +734,7 @@ def get_model( batch_class=CausalLMBatchKeysLast, ) elif model_type == GPT2: - if FLASH_ATTENTION: + if FLASH_ATTENTION and USE_CUSTOM_MODELING: try: return FlashCausalLM( model_id=model_id, @@ -749,7 +750,7 @@ def get_model( except RuntimeError as e: # Lots of legacy models with various weight names. 
log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}") - return CausalLM.fallback( + return transformers_causal_lm_class.fallback( model_id, revision, quantize=quantize, @@ -760,7 +761,7 @@ def get_model( elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2")) else: - return CausalLM.fallback( + return transformers_causal_lm_class.fallback( model_id, revision, quantize=quantize, @@ -769,7 +770,7 @@ def get_model( trust_remote_code=trust_remote_code, ) elif model_type == GPTJ: - if FLASH_ATTENTION: + if FLASH_ATTENTION and USE_CUSTOM_MODELING: try: return FlashCausalLM( model_id=model_id, @@ -785,7 +786,7 @@ def get_model( except RuntimeError as e: # Lots of legacy models with various weight names. log_master(logger.warning, f"Couldn't load flash gptj variant: {e}") - return CausalLM.fallback( + return transformers_causal_lm_class.fallback( model_id, revision, quantize=quantize, @@ -796,7 +797,7 @@ def get_model( elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J")) else: - return CausalLM.fallback( + return transformers_causal_lm_class.fallback( model_id, revision, quantize=quantize, @@ -805,7 +806,7 @@ def get_model( trust_remote_code=trust_remote_code, ) elif model_type == GPT_NEOX: - if FLASH_ATTENTION: + if FLASH_ATTENTION and USE_CUSTOM_MODELING: from text_generation_server.models.custom_modeling.flash_neox_modeling import ( GPTNeoXConfig, ) @@ -833,7 +834,7 @@ def get_model( trust_remote_code=trust_remote_code, ) else: - return CausalLM.fallback( + return transformers_causal_lm_class.fallback( model_id, revision, quantize=quantize, @@ -843,7 +844,7 @@ def get_model( ) elif model_type == PHI: - if FLASH_ATTENTION: + if FLASH_ATTENTION and USE_CUSTOM_MODELING: return FlashCausalLM( model_id=model_id, model_class=FlashPhiForCausalLM, @@ -856,7 +857,7 @@ def get_model( lora_adapter_ids=lora_adapter_ids, ) else: - return CausalLM.fallback( + return transformers_causal_lm_class.fallback( model_id, revision, quantize=quantize, @@ -866,7 +867,7 @@ def get_model( ) elif model_type == PHI_MOE: - if FLASH_ATTENTION: + if FLASH_ATTENTION and USE_CUSTOM_MODELING: return FlashCausalLM( model_id=model_id, model_class=FlashLlamaForCausalLM, @@ -880,7 +881,7 @@ def get_model( lora_adapter_ids=lora_adapter_ids, ) else: - return CausalLM.fallback( + return transformers_causal_lm_class.fallback( model_id, revision, quantize=quantize, @@ -890,7 +891,7 @@ def get_model( ) elif model_type == "phi-msft": - if FLASH_ATTENTION: + if FLASH_ATTENTION and USE_CUSTOM_MODELING: raise NotImplementedError( "Legacy phi-msft is not supported with Flash Attention" ) @@ -912,7 +913,7 @@ def get_model( or model_type == PHI3 or model_type == GRANITE ): - if FLASH_ATTENTION: + if FLASH_ATTENTION and USE_CUSTOM_MODELING: return FlashCausalLM( model_id=model_id, model_class=FlashLlamaForCausalLM, @@ -929,7 +930,7 @@ def get_model( FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}") ) else: - return CausalLM.fallback( + return transformers_causal_lm_class.fallback( model_id, revision, quantize=quantize, @@ -938,7 +939,7 @@ def get_model( trust_remote_code=trust_remote_code, ) if model_type == GEMMA: - if FLASH_ATTENTION: + if FLASH_ATTENTION and USE_CUSTOM_MODELING: return FlashCausalLM( model_id=model_id, model_class=FlashGemmaForCausalLM, @@ -955,7 +956,7 @@ def get_model( elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma")) else: - return CausalLM.fallback( + return 
transformers_causal_lm_class.fallback( model_id, revision, quantize=quantize, @@ -964,7 +965,7 @@ def get_model( trust_remote_code=trust_remote_code, ) elif model_type == GEMMA2: - if FLASH_ATTENTION: + if FLASH_ATTENTION and USE_CUSTOM_MODELING: return FlashCausalLM( model_id=model_id, model_class=FlashGemma2ForCausalLM, @@ -981,7 +982,7 @@ def get_model( elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2")) else: - return CausalLM.fallback( + return transformers_causal_lm_class.fallback( model_id, revision, quantize=quantize, @@ -991,7 +992,7 @@ def get_model( ) if model_type == COHERE: - if FLASH_ATTENTION: + if FLASH_ATTENTION and USE_CUSTOM_MODELING: return FlashCausalLM( model_id=model_id, model_class=FlashCohereForCausalLM, @@ -1006,7 +1007,7 @@ def get_model( elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere")) else: - return CausalLM.fallback( + return transformers_causal_lm_class.fallback( model_id, revision, quantize=quantize, @@ -1016,7 +1017,7 @@ def get_model( ) if model_type == DBRX: - if FLASH_ATTENTION: + if FLASH_ATTENTION and USE_CUSTOM_MODELING: return FlashCausalLM( model_id=model_id, model_class=FlashDbrxForCausalLM, @@ -1034,7 +1035,7 @@ def get_model( elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX")) else: - return CausalLM.fallback( + return transformers_causal_lm_class.fallback( model_id, revision, quantize=quantize, @@ -1045,7 +1046,7 @@ def get_model( if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]: if sharded: - if FLASH_ATTENTION: + if FLASH_ATTENTION and USE_CUSTOM_MODELING: if config_dict.get("alibi", False): raise NotImplementedError("sharded is not supported for this model") return FlashCausalLM( @@ -1084,7 +1085,7 @@ def get_model( config_class=RWConfig, ) else: - return CausalLM.fallback( + return transformers_causal_lm_class.fallback( model_id, revision, quantize=quantize, @@ -1094,7 +1095,7 @@ def get_model( ) if model_type == MISTRAL: - if FLASH_ATTENTION: + if FLASH_ATTENTION and USE_CUSTOM_MODELING: return FlashCausalLM( model_id=model_id, model_class=FlashMistralForCausalLM, @@ -1109,7 +1110,7 @@ def get_model( elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral")) else: - return CausalLM.fallback( + return transformers_causal_lm_class.fallback( model_id, revision, quantize=quantize, @@ -1119,7 +1120,7 @@ def get_model( ) if model_type == MIXTRAL: - if FLASH_ATTENTION: + if FLASH_ATTENTION and USE_CUSTOM_MODELING: return FlashCausalLM( model_id=model_id, model_class=FlashMixtralForCausalLM, @@ -1134,7 +1135,7 @@ def get_model( elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral")) else: - return CausalLM.fallback( + return transformers_causal_lm_class.fallback( model_id, revision, quantize=quantize, @@ -1144,7 +1145,7 @@ def get_model( ) if model_type == STARCODER2: - if FLASH_ATTENTION: + if FLASH_ATTENTION and USE_CUSTOM_MODELING: return FlashCausalLM( model_id=model_id, model_class=FlashStarcoder2ForCausalLM, @@ -1161,7 +1162,7 @@ def get_model( FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2") ) else: - return CausalLM.fallback( + return transformers_causal_lm_class.fallback( model_id, revision, quantize=quantize, @@ -1171,7 +1172,7 @@ def get_model( ) if model_type == QWEN2: - if FLASH_ATTENTION: + if FLASH_ATTENTION and USE_CUSTOM_MODELING: return FlashCausalLM( model_id=model_id, model_class=Qwen2ForCausalLM, @@ -1186,7 +1187,7 @@ def get_model( 
elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2")) else: - return CausalLM.fallback( + return transformers_causal_lm_class.fallback( model_id, revision, quantize=quantize, @@ -1223,7 +1224,7 @@ def get_model( }, ) if model_type == IDEFICS: - if FLASH_ATTENTION: + if FLASH_ATTENTION and USE_CUSTOM_MODELING: return IdeficsCausalLM( model_id, revision, @@ -1247,7 +1248,7 @@ def get_model( lora_adapter_ids=lora_adapter_ids, ) if model_type == MLLAMA: - if FLASH_ATTENTION: + if FLASH_ATTENTION and USE_CUSTOM_MODELING: return MllamaCausalLM( model_id=model_id, model_class=MllamaForConditionalGeneration, @@ -1263,7 +1264,7 @@ def get_model( else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama")) if model_type == IDEFICS2: - if FLASH_ATTENTION: + if FLASH_ATTENTION and USE_CUSTOM_MODELING: return VlmCausalLM( model_id=model_id, model_class=Idefics2ForConditionalGeneration, @@ -1281,7 +1282,7 @@ def get_model( else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == PALIGEMMA: - if FLASH_ATTENTION: + if FLASH_ATTENTION and USE_CUSTOM_MODELING: return VlmCausalLM( model_id=model_id, model_class=PaliGemmaForConditionalGeneration, @@ -1300,7 +1301,7 @@ def get_model( raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == LLAVA_NEXT: - if FLASH_ATTENTION: + if FLASH_ATTENTION and USE_CUSTOM_MODELING: return VlmCausalLM( model_class=LlavaNextForConditionalGeneration, model_id=model_id, @@ -1329,7 +1330,7 @@ def get_model( elif quantize == "exl2": raise NotImplementedError("exl2 quantization is not supported for AutoModel") if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: - return CausalLM.fallback( + return transformers_causal_lm_class.fallback( model_id, revision, quantize=quantize, @@ -1350,7 +1351,7 @@ def get_model( auto_map = config_dict.get("auto_map", None) if trust_remote_code and auto_map is not None: if "AutoModelForCausalLM" in auto_map.keys(): - return CausalLM.fallback( + return transformers_causal_lm_class.fallback( model_id, revision, quantize=quantize, diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py index ff76b2cc203..de2570b0512 100644 --- a/server/text_generation_server/models/transformers_flash_causal_lm.py +++ b/server/text_generation_server/models/transformers_flash_causal_lm.py @@ -188,6 +188,18 @@ def __init__( device=device, ) + @classmethod + def fallback( + cls, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + return cls(model_id, revision, quantize, speculator, dtype, trust_remote_code) + def warmup(self, batch: FlashCausalLMBatch): patch_everywhere("_flash_attention_forward", _flash_attention_forward_patched) super().warmup(batch) From b3b0747432bad9cbc82eb0947f2f6f8ea05b9270 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 10 Dec 2024 18:14:22 +0100 Subject: [PATCH 03/33] switch version to make it work --- server/Makefile-flash-att-v2 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 index a9cdf782270..51b304bde9a 100644 --- a/server/Makefile-flash-att-v2 +++ b/server/Makefile-flash-att-v2 @@ -1,4 +1,4 @@ -flash_att_v2_commit_cuda := v2.6.1 +flash_att_v2_commit_cuda := v2.6.3 flash_att_v2_commit_rocm := 
2092111b9f975b3347c652ff7fabd431130256c4 build-flash-attention-v2-cuda: From 738f0b0e35ca1d4deccdc78786c49ca902b1554f Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 10 Dec 2024 18:37:22 +0100 Subject: [PATCH 04/33] Update Makefile-flash-att-v2 --- server/Makefile-flash-att-v2 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 index 51b304bde9a..a9cdf782270 100644 --- a/server/Makefile-flash-att-v2 +++ b/server/Makefile-flash-att-v2 @@ -1,4 +1,4 @@ -flash_att_v2_commit_cuda := v2.6.3 +flash_att_v2_commit_cuda := v2.6.1 flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4 build-flash-attention-v2-cuda: From a84ecf26aaede5981bc3a9ecf90e1b93deccc6cb Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 10 Dec 2024 18:43:44 +0100 Subject: [PATCH 05/33] Update Makefile-flash-att-v2 --- server/Makefile-flash-att-v2 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 index a9cdf782270..51b304bde9a 100644 --- a/server/Makefile-flash-att-v2 +++ b/server/Makefile-flash-att-v2 @@ -1,4 +1,4 @@ -flash_att_v2_commit_cuda := v2.6.1 +flash_att_v2_commit_cuda := v2.6.3 flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4 build-flash-attention-v2-cuda: From 372799a4212224b89b2183c213109b62e5ca469a Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 10 Dec 2024 20:24:25 +0100 Subject: [PATCH 06/33] Update Makefile-flash-att-v2 --- server/Makefile-flash-att-v2 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 index 51b304bde9a..a9cdf782270 100644 --- a/server/Makefile-flash-att-v2 +++ b/server/Makefile-flash-att-v2 @@ -1,4 +1,4 @@ -flash_att_v2_commit_cuda := v2.6.3 +flash_att_v2_commit_cuda := v2.6.1 flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4 build-flash-attention-v2-cuda: From a0035e660740d4f2dc70eda6066c7877f3f51a0b Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 11 Dec 2024 10:05:23 +0100 Subject: [PATCH 07/33] Update Makefile-flash-att-v2 --- server/Makefile-flash-att-v2 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 index a9cdf782270..51b304bde9a 100644 --- a/server/Makefile-flash-att-v2 +++ b/server/Makefile-flash-att-v2 @@ -1,4 +1,4 @@ -flash_att_v2_commit_cuda := v2.6.1 +flash_att_v2_commit_cuda := v2.6.3 flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4 build-flash-attention-v2-cuda: From e69a384dfbe2f5292c81b8c0df3c6fe9d7023907 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 11 Dec 2024 10:58:23 +0100 Subject: [PATCH 08/33] Update Makefile-flash-att-v2 --- server/Makefile-flash-att-v2 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 index 51b304bde9a..27d30c1a9ba 100644 --- a/server/Makefile-flash-att-v2 +++ b/server/Makefile-flash-att-v2 @@ -1,4 +1,4 @@ -flash_att_v2_commit_cuda := v2.6.3 +flash_att_v2_commit_cuda := v2.7.2.post1 flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4 build-flash-attention-v2-cuda: From 3a636ed1654a0bbdde819901335532aa1949ba5e Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 11 Dec 2024 11:25:12 +0100 Subject: [PATCH 09/33] Update Makefile-flash-att-v2 --- server/Makefile-flash-att-v2 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 index 27d30c1a9ba..a9cdf782270 100644 --- a/server/Makefile-flash-att-v2 +++ b/server/Makefile-flash-att-v2 @@ -1,4 +1,4 @@ -flash_att_v2_commit_cuda := v2.7.2.post1 +flash_att_v2_commit_cuda := v2.6.1 flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4 build-flash-attention-v2-cuda: From 649cb1f5f1a99049c45c03ecbca39efb66fe4133 Mon Sep 17 00:00:00 2001 From: System administrator Date: Thu, 12 Dec 2024 14:27:07 +0000 Subject: [PATCH 10/33] runnable version --- .../text_generation_server/models/globals.py | 2 +- .../models/transformers_flash_causal_lm.py | 87 ++++++++++++++++--- 2 files changed, 74 insertions(+), 15 deletions(-) diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 7d6639f218b..7c7e026e107 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -8,7 +8,7 @@ REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"} ATTENTION = os.environ["ATTENTION"] # default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0" -PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in { +PREFIX_CACHING = os.environ["USE_PREFIX_CACHING"].lower() in { "1", "true", } diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py index de2570b0512..f4f2474996b 100644 --- a/server/text_generation_server/models/transformers_flash_causal_lm.py +++ b/server/text_generation_server/models/transformers_flash_causal_lm.py @@ -5,12 +5,17 @@ import torch from opentelemetry import trace from loguru import logger -from transformers import AutoTokenizer, AutoModelForCausalLM +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig from text_generation_server.models.flash_causal_lm import ( FlashCausalLMBatch, FlashCausalLM, ) +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) from text_generation_server.utils.import_utils import ( empty_cache, synchronize, @@ -57,7 +62,7 @@ def _flash_attention_forward_patched( query_length: int, is_causal: bool, softmax_scale: Optional[float] = None, - sliding_window: int = -1, + sliding_window: Optional[int] = None, softcap: Optional[float] = None, **kwargs, ): @@ -67,11 +72,11 @@ def _flash_attention_forward_patched( kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device)) # Correctly reshape the states - _, _, num_heads, head_dim = query_states.size() - _, _, num_kv_heads, _ = key_states.size() - query_states = query_states.view(-1, num_heads, head_dim) - key_states = key_states.view(-1, num_kv_heads, head_dim) - value_states = value_states.view(-1, num_kv_heads, head_dim) + _, num_heads, head_dim = query_states.size() + # _, num_kv_heads, _ = key_states.size() + # query_states = query_states.view(-1, num_heads, head_dim) + # key_states = key_states.view(-1, num_kv_heads, head_dim) + # value_states = value_states.view(-1, num_kv_heads, head_dim) # Take care of updating the cache in-place kv_cache.store( @@ -82,6 +87,7 @@ def _flash_attention_forward_patched( ) softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale + sliding_window = -1 if sliding_window is None else sliding_window if kwargs["cu_seqlen_prefill"] is not None: attn_output = attention( @@ -109,7 +115,8 @@ def 
_flash_attention_forward_patched( softcap=softcap, ) - attn_output = attn_output.view(attn_output.shape[0], -1) + # attn_output = attn_output.view(attn_output.shape[0], -1) + attn_output = attn_output.view(-1, num_heads * head_dim) return attn_output @@ -122,14 +129,21 @@ def __init__( quantize: Optional[str] = None, speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, + default_dtype=torch.float16, trust_remote_code: bool = False, + tokenizer_class=AutoTokenizer, + config_class=AutoConfig, + kv_cache_dtype: Optional[torch.dtype] = None, ): + self.quantize = quantize + self.process_group, rank, world_size = initialize_torch_distributed() + if speculator: raise RuntimeError("Speculator decoding is not enabled for AutoModel") device_count = 0 if torch.cuda.is_available(): - device = torch.device("cuda") + device = torch.device("cuda:0") device_count = torch.cuda.device_count() dtype = torch.float16 if dtype is None else dtype elif hasattr(torch, "xpu") and torch.xpu.is_available(): @@ -157,6 +171,7 @@ def __init__( device_map=("auto" if device_count > 1 else None), load_in_8bit=quantize == "bitsandbytes", trust_remote_code=trust_remote_code, + attn_implementation="flash_attention_2" ) if device_count == 1 and quantize != "bitsandbytes": model = model.to(device) @@ -174,10 +189,44 @@ def __init__( tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - self.num_layers = len(model.model.layers) + self.num_layers = model.config.num_hidden_layers + self.num_heads = model.config.num_attention_heads // self.process_group.size() self.num_kv_heads = model.config.num_key_value_heads + self.num_kv_heads = ( + self.num_kv_heads // self.process_group.size() + if self.num_kv_heads > 1 + else self.num_kv_heads + ) self.head_size = model.config.hidden_size // model.config.num_attention_heads + self.cuda_graphs = {} + self.kv_cache = [] + self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype + + if ATTENTION == "flashinfer": + from text_generation_server.layers.attention.flashinfer import ( + create_prefill_state, + create_decode_state, + create_prefill_with_paged_kv_state, + ) + + self.prefill_state = create_prefill_state(device=device) + self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state( + device=device + ) + + self.decode_state = create_decode_state( + device=device, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + ) + + self.num_groups = self.num_heads // self.num_kv_heads + self.kv_head_mapping = torch.arange( + 0, self.num_kv_heads, dtype=torch.int32, device=device + ).repeat_interleave(self.num_groups) + + torch.distributed.barrier(group=self.process_group) # Skip FlashCausalLM init. 
super(FlashCausalLM, self).__init__( model_id=model_id, @@ -186,6 +235,8 @@ def __init__( requires_padding=False, dtype=dtype, device=device, + rank=rank, + world_size=world_size, ) @classmethod @@ -198,11 +249,18 @@ def fallback( dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): - return cls(model_id, revision, quantize, speculator, dtype, trust_remote_code) + return cls( + model_id=model_id, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) - def warmup(self, batch: FlashCausalLMBatch): + def warmup(self, batch: FlashCausalLMBatch, max_input_tokens: Optional[int], max_total_tokens: Optional[int],): patch_everywhere("_flash_attention_forward", _flash_attention_forward_patched) - super().warmup(batch) + return super().warmup(batch, max_input_tokens, max_total_tokens) def forward( self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData @@ -270,7 +328,8 @@ def forward( max_s=max_s, prefill_cache_indices=batch.prefill_cache_indices, lm_head_indices=lm_head_indices, - ) + kv_head_mapping=self.kv_head_mapping, + ).logits if batch.prefill_cache_indices is not None: batch.prefill_cache_indices = None return logits, None From 490ca0ef6ac0fa5fa34e1b019b62ece2708c7580 Mon Sep 17 00:00:00 2001 From: System administrator Date: Thu, 12 Dec 2024 15:48:56 +0000 Subject: [PATCH 11/33] working --- .../models/transformers_flash_causal_lm.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py index f4f2474996b..abfaa06ea33 100644 --- a/server/text_generation_server/models/transformers_flash_causal_lm.py +++ b/server/text_generation_server/models/transformers_flash_causal_lm.py @@ -72,11 +72,14 @@ def _flash_attention_forward_patched( kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device)) # Correctly reshape the states - _, num_heads, head_dim = query_states.size() - # _, num_kv_heads, _ = key_states.size() + _, _, num_heads, head_dim = query_states.size() + _, _, num_kv_heads, _ = key_states.size() # query_states = query_states.view(-1, num_heads, head_dim) # key_states = key_states.view(-1, num_kv_heads, head_dim) # value_states = value_states.view(-1, num_kv_heads, head_dim) + query_states = query_states.squeeze(dim=0) + key_states = key_states.squeeze(dim=0) + value_states = value_states.squeeze(dim=0) # Take care of updating the cache in-place kv_cache.store( @@ -316,8 +319,8 @@ def forward( max_k=batch.max_current_length, ) logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, + input_ids=input_ids[None, ...], + position_ids=position_ids[None, ...], past_key_values=None, use_cache=False, # we use self.kv_cache instead of transformers cache object cu_seqlen_prefill=cu_seqlen_prefill, @@ -329,7 +332,8 @@ def forward( prefill_cache_indices=batch.prefill_cache_indices, lm_head_indices=lm_head_indices, kv_head_mapping=self.kv_head_mapping, - ).logits + ).logits[0, ...] 
+        print("SUCCESSFUL FORWARD")
         if batch.prefill_cache_indices is not None:
             batch.prefill_cache_indices = None
         return logits, None

From f843b62a442eccbb8c5b5c6e07a40fc44e032d53 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Thu, 12 Dec 2024 18:25:11 +0000
Subject: [PATCH 12/33] push change

---
 .../text_generation_server/models/__init__.py |  3 +-
 .../models/transformers_flash_causal_lm.py    | 73 +++++++------------
 2 files changed, 29 insertions(+), 47 deletions(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 35ab8edec4e..54665083ff6 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -375,7 +375,8 @@ def get_model(
     )
     model_type = config_dict.get("model_type", None)
 
-    transformers_causal_lm_class = CausalLM
+    # transformers_causal_lm_class = CausalLM
+    transformers_causal_lm_class = TransformersFlashCausalLM
     if (
         not USE_CUSTOM_MODELING
         and model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py
index abfaa06ea33..d71e75b4458 100644
--- a/server/text_generation_server/models/transformers_flash_causal_lm.py
+++ b/server/text_generation_server/models/transformers_flash_causal_lm.py
@@ -6,6 +6,7 @@ import torch
 from opentelemetry import trace
 from loguru import logger
 from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+import transformers.modeling_utils
 
 from text_generation_server.models.flash_causal_lm import (
     FlashCausalLMBatch,
@@ -31,36 +32,12 @@ tracer = trace.get_tracer(__name__)
 
 
-def patch_everywhere(
-    attribute_name: str, patch: Any, module_name_prefix: Optional[str] = None
-):
-    """
-    Finds all occurences of `attribute_name` in the loaded modules and patches them with `patch`.
-
-    Args:
-        attribute_name (`str`):
-            The name of attribute to patch.
-        patch (`Any`):
-            The patch for the attribute.
-        module_name_prefix (`Optional[str]`, defaults to `None`):
-            If set, only module names starting with this prefix will be considered for patching.
-    """
-    # sys.modules may be updated while being iterated over, hence the list copy.
- for name in list(sys.modules): - module = sys.modules[name] - if module_name_prefix is not None and not name.startswith(module_name_prefix): - continue - if hasattr(module, attribute_name): - setattr(module, attribute_name, patch) - - -def _flash_attention_forward_patched( +def tgi_flash_attention_forward( + module, query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, attention_mask: torch.Tensor, - query_length: int, - is_causal: bool, softmax_scale: Optional[float] = None, sliding_window: Optional[int] = None, softcap: Optional[float] = None, @@ -71,15 +48,15 @@ def _flash_attention_forward_patched( # This means no scale kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device)) - # Correctly reshape the states - _, _, num_heads, head_dim = query_states.size() - _, _, num_kv_heads, _ = key_states.size() - # query_states = query_states.view(-1, num_heads, head_dim) - # key_states = key_states.view(-1, num_kv_heads, head_dim) - # value_states = value_states.view(-1, num_kv_heads, head_dim) - query_states = query_states.squeeze(dim=0) - key_states = key_states.squeeze(dim=0) - value_states = value_states.squeeze(dim=0) + query_states = query_states.transpose(1, 2).squeeze(dim=0) + key_states = key_states.transpose(1, 2).squeeze(dim=0) + value_states = value_states.transpose(1, 2).squeeze(dim=0) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) # Take care of updating the cache in-place kv_cache.store( @@ -89,6 +66,8 @@ def _flash_attention_forward_patched( kv_scales=kv_scales ) + + _, num_heads, head_dim = query_states.shape softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale sliding_window = -1 if sliding_window is None else sliding_window @@ -121,7 +100,10 @@ def _flash_attention_forward_patched( # attn_output = attn_output.view(attn_output.shape[0], -1) attn_output = attn_output.view(-1, num_heads * head_dim) - return attn_output + return attn_output, None + + +transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward class TransformersFlashCausalLM(FlashCausalLM): @@ -174,8 +156,9 @@ def __init__( device_map=("auto" if device_count > 1 else None), load_in_8bit=quantize == "bitsandbytes", trust_remote_code=trust_remote_code, - attn_implementation="flash_attention_2" + attn_implementation="tgi" ) + if device_count == 1 and quantize != "bitsandbytes": model = model.to(device) @@ -261,10 +244,6 @@ def fallback( trust_remote_code=trust_remote_code, ) - def warmup(self, batch: FlashCausalLMBatch, max_input_tokens: Optional[int], max_total_tokens: Optional[int],): - patch_everywhere("_flash_attention_forward", _flash_attention_forward_patched) - return super().warmup(batch, max_input_tokens, max_total_tokens) - def forward( self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -318,11 +297,13 @@ def forward( max_q=batch.max_input_length, max_k=batch.max_current_length, ) - logits = self.model.forward( - input_ids=input_ids[None, ...], + # Use only the Model, not ModelForCausalLM + hidden_states = self.model.model.forward( + input_ids=input_ids[None, ...], # expand dim to easily fit transformers position_ids=position_ids[None, ...], past_key_values=None, use_cache=False, # we use self.kv_cache instead of transformers cache object + 
return_dict=True, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, block_tables=block_tables, @@ -330,10 +311,10 @@ def forward( seqlen=seqlen, max_s=max_s, prefill_cache_indices=batch.prefill_cache_indices, - lm_head_indices=lm_head_indices, kv_head_mapping=self.kv_head_mapping, - ).logits[0, ...] - print("SUCCESSFUL FORWARD") + )[0].squeeze(dim=0) + # And compute logits from the lm_head, slicing correctly the indices + logits = self.model.lm_head.forward(hidden_states[lm_head_indices]) if batch.prefill_cache_indices is not None: batch.prefill_cache_indices = None return logits, None From 715b2d19edb98531cdc56a9de2bd90ef51dc4130 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 13 Dec 2024 13:25:56 +0000 Subject: [PATCH 13/33] fix high dim --- .../models/transformers_flash_causal_lm.py | 232 +++++++++++++++++- 1 file changed, 219 insertions(+), 13 deletions(-) diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py index d71e75b4458..7bcad8aa0ff 100644 --- a/server/text_generation_server/models/transformers_flash_causal_lm.py +++ b/server/text_generation_server/models/transformers_flash_causal_lm.py @@ -44,7 +44,7 @@ def tgi_flash_attention_forward( **kwargs, ): - kv_cache = kwargs["kv_cache"][kwargs["layer_idx"]] + kv_cache = kwargs["kv_cache"][module.layer_idx] # This means no scale kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device)) @@ -97,7 +97,6 @@ def tgi_flash_attention_forward( softcap=softcap, ) - # attn_output = attn_output.view(attn_output.shape[0], -1) attn_output = attn_output.view(-1, num_heads * head_dim) return attn_output, None @@ -244,6 +243,42 @@ def fallback( trust_remote_code=trust_remote_code, ) + + def _model_forward( + self, + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + prefill_cache_indices, + lm_head_indices, + ): + hidden_states = self.model.model.forward( + input_ids=input_ids[None, ...], # expand dim to easily fit transformers + position_ids=position_ids[None, ...], # expand dim to easily fit transformers + past_key_values=None, # we use self.kv_cache instead of transformers cache object + use_cache=False, # we use self.kv_cache instead of transformers cache object + return_dict=True, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + prefill_cache_indices=prefill_cache_indices, + kv_head_mapping=self.kv_head_mapping, + )[0].squeeze(dim=0) + # And compute logits from the lm_head, slicing correctly the indices + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.model.lm_head.forward(hidden_states) + return logits + + def forward( self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -297,13 +332,9 @@ def forward( max_q=batch.max_input_length, max_k=batch.max_current_length, ) - # Use only the Model, not ModelForCausalLM - hidden_states = self.model.model.forward( - input_ids=input_ids[None, ...], # expand dim to easily fit transformers - position_ids=position_ids[None, ...], - past_key_values=None, - use_cache=False, # we use self.kv_cache instead of transformers cache object - return_dict=True, + logits = self._model_forward( + input_ids=input_ids, + position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, 
block_tables=block_tables, @@ -311,10 +342,8 @@ def forward( seqlen=seqlen, max_s=max_s, prefill_cache_indices=batch.prefill_cache_indices, - kv_head_mapping=self.kv_head_mapping, - )[0].squeeze(dim=0) - # And compute logits from the lm_head, slicing correctly the indices - logits = self.model.lm_head.forward(hidden_states[lm_head_indices]) + lm_head_indices=lm_head_indices, + ) if batch.prefill_cache_indices is not None: batch.prefill_cache_indices = None return logits, None @@ -363,3 +392,180 @@ def forward( # Slice output to the correct shape logits = cuda_graph["logits"][:bs] return logits, None + + + def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): + max_bs = max(self.cuda_graphs.keys()) if self.cuda_graphs else None + input_lengths = [max_s] * bs + cache_lengths = [0] * bs + if max_bs is None: + input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) + position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) + slots = torch.arange(bs, dtype=torch.int64, device=self.device) + input_lengths_tensor = ( + torch.ones(bs, dtype=torch.int32, device=self.device) * max_s + ) + cache_lengths_tensor = torch.zeros( + bs, dtype=torch.int32, device=self.device + ) + block_tables = torch.arange( + max_bt, dtype=torch.int32, device=self.device + ).repeat(bs) + block_tables = block_tables.reshape((bs, max_bt)) + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=input_lengths, + cache_lengths=cache_lengths, + input_lengths_tensor=input_lengths_tensor, + cache_lengths_tensor=cache_lengths_tensor, + max_current_length=max_s, + ) + else: + if bs > max_bs: + raise RuntimeError( + "Cuda graphs should be generated in decreasing order size to reduce VRAM usage" + ) + input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs] + position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs] + if ATTENTION == "flashinfer": + block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt] + else: + block_tables = self.cuda_graphs[max_bs]["block_tables"][:bs] + slots = self.cuda_graphs[max_bs]["slots"][:bs] + input_lengths_tensor = self.cuda_graphs[max_bs]["input_lengths"][:bs] + cache_lengths_tensor = self.cuda_graphs[max_bs]["cache_lengths"][:bs] + + if ATTENTION == "flashinfer": + from text_generation_server.layers.attention.flashinfer import ( + create_decode_state_cuda_graphs, + ) + + block_tables_ptr = torch.zeros( + bs + 1, dtype=torch.int32, device=self.device + ) + last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device) + state = create_decode_state_cuda_graphs( + device=input_ids.device, + block_tables=block_tables, + block_tables_ptr=block_tables_ptr, + last_page_len=last_page_len, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + ) + else: + state = None + + if ( + hasattr(self.model, "config") + and hasattr(self.model.config, "model_type") + and self.model.config.model_type == "qwen2_vl" + ): + if position_ids.dim() == 1: + position_ids = self.model.get_position_ids(input_ids) + + graph = torch.cuda.CUDAGraph() + self.cuda_graphs[bs] = { + "input_ids": input_ids, + "position_ids": position_ids, + "kv_cache": self.kv_cache, + "block_tables": block_tables, + "slots": slots, + "input_lengths": input_lengths_tensor, + "cache_lengths": cache_lengths_tensor, + "state": state, + "graph": graph, + } + + torch.cuda.synchronize() + # Run once outside to warmup + with self._forward_context( + block_tables=block_tables, + cu_seqlen_prefill=None, + 
input_lengths_tensor=input_lengths_tensor, + state=state, + cache_lengths_tensor=cache_lengths_tensor, + ): + seqlen = Seqlen( + input_lengths=input_lengths_tensor, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=None, + max_q=1, + max_k=max_s, + ) + self._model_forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=self.kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + prefill_cache_indices=None, + lm_head_indices=None, + ) + del seqlen + + torch.cuda.synchronize() + + with torch.cuda.graph(graph, pool=MEM_POOL): + seqlen = Seqlen( + input_lengths=input_lengths_tensor, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=None, + max_q=1, + max_k=max_s, + ) + logits = self._model_forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=self.kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + prefill_cache_indices=None, + lm_head_indices=None, + ) + self.cuda_graphs[bs]["logits"] = logits + self.cuda_graphs[bs]["speculative_logits"] = None + torch.cuda.synchronize() + + + def tunableop_warmup(self, seqlen: int): + input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device) + position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device) + slots = torch.arange(seqlen, dtype=torch.int64, device=self.device) + + # Dummy value, some models (starcoder2) don't accept `None`. + input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) + cache_lengths_tensor = torch.zeros( + seqlen, dtype=torch.int32, device=self.device + ) + cu_seqlen_prefill = torch.tensor( + [0, seqlen], device=self.device, dtype=torch.int32 + ) + max_s = seqlen + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + max_q=1, + max_k=seqlen, + ) + + # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. 
+ self._model_forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=self.kv_cache, + block_tables=None, + seqlen=seqlen, + slots=slots, + max_s=max_s, + lm_head_indices=None, + prefill_cache_indices=None, + ) \ No newline at end of file From e93ab925f921c2119702ce83c5ab8d3420f81a83 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 13 Dec 2024 14:02:45 +0000 Subject: [PATCH 14/33] init --- .../text_generation_server/models/__init__.py | 28 ++++++++----------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 54665083ff6..bf481e297b8 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -12,7 +12,7 @@ from loguru import logger from transformers.configuration_utils import PretrainedConfig -from transformers.models.auto import modeling_auto +from transformers.models.auto import modeling_auto, modeling_task from huggingface_hub import hf_hub_download, HfApi from typing import Optional, List, Dict from pathlib import Path @@ -375,30 +375,26 @@ def get_model( ) model_type = config_dict.get("model_type", None) - # transformers_causal_lm_class = CausalLM - transformers_causal_lm_class = TransformersFlashCausalLM - if ( - not USE_CUSTOM_MODELING - and model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES - ): + transformers_causal_lm_class = CausalLM + if not USE_CUSTOM_MODELING: logger.info( "TGI's flash enabled models could either not be loaded or are disabled, using Transformers fallback." ) - transformers_model_class = getattr( - transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type] - ) + try: + transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]) + except KeyError: + transformers_model_class = modeling_task.AutoForCausalLM - if ( - transformers_model_class._supports_flash_attn_2 - and transformers_model_class._supports_cache_class - ): + if transformers_model_class._supports_flash_attn_2: logger.info( - f"Transformers' {model_type} implementation supports custom cache and flash/paged attention. Using TransformersFlashCausalLM with ragged tensors (single dimension for batch and sequence length)." + f"Transformers' {model_type} implementation supports ragged tensors format (single dimension for " + "batch and sequence length). All TGI's batching/caching optimizations are enabled." ) transformers_causal_lm_class = TransformersFlashCausalLM else: logger.info( - f"Transformers' {model_type} implementation supports custom cache and flash/paged attention. Using TransformersCausalLM with classic tensors with padding (two dimensions for batch size and sequence length)." + f"Transformers' {model_type} implementation does not supports ragged tensors format. Will use classic " + "format with padding (two dimensions for batch size and sequence length). This is expected to be slow." 
) quantization_config = config_dict.get("quantization_config", None) From f4c60ca522fb9286ea5ee270c3aef0d6c936af92 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 13 Dec 2024 14:13:47 +0000 Subject: [PATCH 15/33] default --- server/text_generation_server/models/globals.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 7c7e026e107..6373f7967b9 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -70,4 +70,4 @@ def get_adapter_to_index(): USE_CUSTOM_MODELING = os.getenv("USE_CUSTOM_MODELING", "true") -USE_CUSTOM_MODELING = USE_CUSTOM_MODELING == "true" or USE_CUSTOM_MODELING == "1" +USE_CUSTOM_MODELING = USE_CUSTOM_MODELING.lower() == "true" or USE_CUSTOM_MODELING == "1" From 2e2631e093c574242ba155ae8b2356c3a432378f Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 19 Dec 2024 17:37:45 +0000 Subject: [PATCH 16/33] latest transformers changes --- server/text_generation_server/models/__init__.py | 14 ++++++++------ .../models/transformers_flash_causal_lm.py | 10 ++-------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index bf481e297b8..a1359212658 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -12,7 +12,7 @@ from loguru import logger from transformers.configuration_utils import PretrainedConfig -from transformers.models.auto import modeling_auto, modeling_task +from transformers.models.auto import modeling_auto from huggingface_hub import hf_hub_download, HfApi from typing import Optional, List, Dict from pathlib import Path @@ -380,12 +380,14 @@ def get_model( logger.info( "TGI's flash enabled models could either not be loaded or are disabled, using Transformers fallback." ) - try: - transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]) - except KeyError: - transformers_model_class = modeling_task.AutoForCausalLM - if transformers_model_class._supports_flash_attn_2: + transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]) + # Ugly check but works in the meantime + model_path = os.path.join(os.path.dirname(transformers.__file__), "models", model_type, f"modeling_{model_type}.py") + with open(model_path) as file: + has_fa2_class = f"FlashAttention2(" in file.read() + + if transformers_model_class._supports_flash_attn_2 and not has_fa2_class: logger.info( f"Transformers' {model_type} implementation supports ragged tensors format (single dimension for " "batch and sequence length). All TGI's batching/caching optimizations are enabled." 
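The hunk above gates the fast path on a source-sniffing heuristic: the Transformers class must advertise FlashAttention 2 support and must not ship its own model-specific `FlashAttention2` attention class. A standalone sketch of that heuristic, assuming the usual `transformers/models/<model_type>/modeling_<model_type>.py` package layout (the helper name is illustrative, not part of the patch):

    import os
    import transformers

    def has_model_specific_fa2_class(model_type: str) -> bool:
        # Same check as in the hunk above: grep the modeling source for a
        # model-specific `...FlashAttention2(` definition.
        model_path = os.path.join(
            os.path.dirname(transformers.__file__),
            "models", model_type, f"modeling_{model_type}.py",
        )
        with open(model_path) as file:
            return "FlashAttention2(" in file.read()

Patch 18 later drops this file read in favor of a plain capability flag on the class.
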
diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py index 7bcad8aa0ff..49dcac6297c 100644 --- a/server/text_generation_server/models/transformers_flash_causal_lm.py +++ b/server/text_generation_server/models/transformers_flash_causal_lm.py @@ -52,12 +52,6 @@ def tgi_flash_attention_forward( key_states = key_states.transpose(1, 2).squeeze(dim=0) value_states = value_states.transpose(1, 2).squeeze(dim=0) - input_dtype = query_states.dtype - if input_dtype == torch.float32: - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - # Take care of updating the cache in-place kv_cache.store( key=key_states, @@ -66,7 +60,6 @@ def tgi_flash_attention_forward( kv_scales=kv_scales ) - _, num_heads, head_dim = query_states.shape softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale sliding_window = -1 if sliding_window is None else sliding_window @@ -155,7 +148,8 @@ def __init__( device_map=("auto" if device_count > 1 else None), load_in_8bit=quantize == "bitsandbytes", trust_remote_code=trust_remote_code, - attn_implementation="tgi" + attn_implementation="tgi", + tp_plan="auto" if world_size > 1 else None, ) if device_count == 1 and quantize != "bitsandbytes": From 44b367937b8b2b70ea301cf44106d9bb3a7c9b0f Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 19 Dec 2024 17:49:06 +0000 Subject: [PATCH 17/33] revert --- server/text_generation_server/models/globals.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 6373f7967b9..89f920bb7b3 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -8,7 +8,7 @@ REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"} ATTENTION = os.environ["ATTENTION"] # default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0" -PREFIX_CACHING = os.environ["USE_PREFIX_CACHING"].lower() in { +PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in { "1", "true", } From 266377b3282d95ae50ef8fa263a6df148cb10a26 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 15 Jan 2025 18:05:55 +0000 Subject: [PATCH 18/33] simplify check --- server/text_generation_server/models/__init__.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index a1359212658..c97b0006987 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -382,12 +382,8 @@ def get_model( ) transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]) - # Ugly check but works in the meantime - model_path = os.path.join(os.path.dirname(transformers.__file__), "models", model_type, f"modeling_{model_type}.py") - with open(model_path) as file: - has_fa2_class = f"FlashAttention2(" in file.read() - if transformers_model_class._supports_flash_attn_2 and not has_fa2_class: + if transformers_model_class._supports_flex_attn: logger.info( f"Transformers' {model_type} implementation supports ragged tensors format (single dimension for " "batch and sequence length). All TGI's batching/caching optimizations are enabled." 
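With the check simplified, the selection logic at this point reduces to a single capability flag on the Transformers class (the `USE_CUSTOM_MODELING` gate is still present and is only removed in the next patch). A condensed sketch of the dispatch as it stands here; the helper and its parameters are illustrative, and the fallback classes are passed in as arguments since their imports live elsewhere in `models/__init__.py`:

    import transformers
    from transformers.models.auto import modeling_auto

    def pick_causal_lm_class(model_type, use_custom_modeling, default_class, flash_class):
        """Choose the padded fallback by default, the ragged flash fallback
        when the Transformers implementation allows it."""
        if use_custom_modeling:
            # With custom modeling enabled, TGI's native flash models are
            # preferred; the plain padded class stays as the last-resort fallback.
            return default_class
        auto_name = modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
        model_class = getattr(transformers, auto_name)
        # Capability flag replaces the earlier file-grepping heuristic.
        return flash_class if model_class._supports_flex_attn else default_class

For example, `pick_causal_lm_class("llama", False, CausalLM, TransformersFlashCausalLM)` selects the flash fallback on Transformers versions where `LlamaForCausalLM` sets `_supports_flex_attn`.
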
From 32488c1a11f7593490bcfbe000fb4e1fc6c89400 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 17 Jan 2025 12:26:51 +0000 Subject: [PATCH 19/33] remove flag --- .../text_generation_server/models/__init__.py | 70 ++++++++----------- .../text_generation_server/models/globals.py | 3 - 2 files changed, 28 insertions(+), 45 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index c97b0006987..66be0be2613 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -30,7 +30,7 @@ from text_generation_server.models.custom_modeling.bloom_modeling import ( BloomForCausalLM, ) -from text_generation_server.models.globals import ATTENTION, USE_CUSTOM_MODELING +from text_generation_server.models.globals import ATTENTION from text_generation_server.models.seq2seq_lm import Seq2SeqLM from text_generation_server.models.galactica import GalacticaCausalLMBatch from text_generation_server.models.custom_modeling.neox_modeling import ( @@ -368,7 +368,6 @@ def get_model( max_input_tokens: int, ) -> Model: global FLASH_ATTENTION - global USE_CUSTOM_MODELING config_dict, _ = PretrainedConfig.get_config_dict( model_id, revision=revision, trust_remote_code=trust_remote_code @@ -376,24 +375,11 @@ def get_model( model_type = config_dict.get("model_type", None) transformers_causal_lm_class = CausalLM - if not USE_CUSTOM_MODELING: - logger.info( - "TGI's flash enabled models could either not be loaded or are disabled, using Transformers fallback." - ) - - transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]) - if transformers_model_class._supports_flex_attn: - logger.info( - f"Transformers' {model_type} implementation supports ragged tensors format (single dimension for " - "batch and sequence length). All TGI's batching/caching optimizations are enabled." - ) - transformers_causal_lm_class = TransformersFlashCausalLM - else: - logger.info( - f"Transformers' {model_type} implementation does not supports ragged tensors format. Will use classic " - "format with padding (two dimensions for batch size and sequence length). This is expected to be slow." 
- ) + # Fast transformers path + transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]) + if transformers_model_class._supports_flex_attn: + transformers_causal_lm_class = TransformersFlashCausalLM quantization_config = config_dict.get("quantization_config", None) if quantization_config is None: @@ -613,7 +599,7 @@ def get_model( ) if model_type == DEEPSEEK_V2: - if FLASH_ATTENTION and USE_CUSTOM_MODELING: + if FLASH_ATTENTION: head_size = max( config_dict.get("qk_nope_dim", 128) + config_dict.get("qk_rope_dim", 64), @@ -678,7 +664,7 @@ def get_model( or model_type == GPT2 and model_id.startswith("bigcode/") ): - if FLASH_ATTENTION and USE_CUSTOM_MODELING: + if FLASH_ATTENTION: return FlashCausalLM( model_id=model_id, model_class=FlashSantacoderForCausalLM, @@ -729,7 +715,7 @@ def get_model( batch_class=CausalLMBatchKeysLast, ) elif model_type == GPT2: - if FLASH_ATTENTION and USE_CUSTOM_MODELING: + if FLASH_ATTENTION: try: return FlashCausalLM( model_id=model_id, @@ -765,7 +751,7 @@ def get_model( trust_remote_code=trust_remote_code, ) elif model_type == GPTJ: - if FLASH_ATTENTION and USE_CUSTOM_MODELING: + if FLASH_ATTENTION: try: return FlashCausalLM( model_id=model_id, @@ -801,7 +787,7 @@ def get_model( trust_remote_code=trust_remote_code, ) elif model_type == GPT_NEOX: - if FLASH_ATTENTION and USE_CUSTOM_MODELING: + if FLASH_ATTENTION: from text_generation_server.models.custom_modeling.flash_neox_modeling import ( GPTNeoXConfig, ) @@ -839,7 +825,7 @@ def get_model( ) elif model_type == PHI: - if FLASH_ATTENTION and USE_CUSTOM_MODELING: + if FLASH_ATTENTION: return FlashCausalLM( model_id=model_id, model_class=FlashPhiForCausalLM, @@ -862,7 +848,7 @@ def get_model( ) elif model_type == PHI_MOE: - if FLASH_ATTENTION and USE_CUSTOM_MODELING: + if FLASH_ATTENTION: return FlashCausalLM( model_id=model_id, model_class=FlashLlamaForCausalLM, @@ -886,7 +872,7 @@ def get_model( ) elif model_type == "phi-msft": - if FLASH_ATTENTION and USE_CUSTOM_MODELING: + if FLASH_ATTENTION: raise NotImplementedError( "Legacy phi-msft is not supported with Flash Attention" ) @@ -908,7 +894,7 @@ def get_model( or model_type == PHI3 or model_type == GRANITE ): - if FLASH_ATTENTION and USE_CUSTOM_MODELING: + if FLASH_ATTENTION: return FlashCausalLM( model_id=model_id, model_class=FlashLlamaForCausalLM, @@ -934,7 +920,7 @@ def get_model( trust_remote_code=trust_remote_code, ) if model_type == GEMMA: - if FLASH_ATTENTION and USE_CUSTOM_MODELING: + if FLASH_ATTENTION: return FlashCausalLM( model_id=model_id, model_class=FlashGemmaForCausalLM, @@ -960,7 +946,7 @@ def get_model( trust_remote_code=trust_remote_code, ) elif model_type == GEMMA2: - if FLASH_ATTENTION and USE_CUSTOM_MODELING: + if FLASH_ATTENTION: return FlashCausalLM( model_id=model_id, model_class=FlashGemma2ForCausalLM, @@ -987,7 +973,7 @@ def get_model( ) if model_type == COHERE: - if FLASH_ATTENTION and USE_CUSTOM_MODELING: + if FLASH_ATTENTION: return FlashCausalLM( model_id=model_id, model_class=FlashCohereForCausalLM, @@ -1012,7 +998,7 @@ def get_model( ) if model_type == DBRX: - if FLASH_ATTENTION and USE_CUSTOM_MODELING: + if FLASH_ATTENTION: return FlashCausalLM( model_id=model_id, model_class=FlashDbrxForCausalLM, @@ -1041,7 +1027,7 @@ def get_model( if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]: if sharded: - if FLASH_ATTENTION and USE_CUSTOM_MODELING: + if FLASH_ATTENTION: if config_dict.get("alibi", False): raise NotImplementedError("sharded is not supported for this 
model") return FlashCausalLM( @@ -1090,7 +1076,7 @@ def get_model( ) if model_type == MISTRAL: - if FLASH_ATTENTION and USE_CUSTOM_MODELING: + if FLASH_ATTENTION: return FlashCausalLM( model_id=model_id, model_class=FlashMistralForCausalLM, @@ -1115,7 +1101,7 @@ def get_model( ) if model_type == MIXTRAL: - if FLASH_ATTENTION and USE_CUSTOM_MODELING: + if FLASH_ATTENTION: return FlashCausalLM( model_id=model_id, model_class=FlashMixtralForCausalLM, @@ -1140,7 +1126,7 @@ def get_model( ) if model_type == STARCODER2: - if FLASH_ATTENTION and USE_CUSTOM_MODELING: + if FLASH_ATTENTION: return FlashCausalLM( model_id=model_id, model_class=FlashStarcoder2ForCausalLM, @@ -1167,7 +1153,7 @@ def get_model( ) if model_type == QWEN2: - if FLASH_ATTENTION and USE_CUSTOM_MODELING: + if FLASH_ATTENTION: return FlashCausalLM( model_id=model_id, model_class=Qwen2ForCausalLM, @@ -1219,7 +1205,7 @@ def get_model( }, ) if model_type == IDEFICS: - if FLASH_ATTENTION and USE_CUSTOM_MODELING: + if FLASH_ATTENTION: return IdeficsCausalLM( model_id, revision, @@ -1243,7 +1229,7 @@ def get_model( lora_adapter_ids=lora_adapter_ids, ) if model_type == MLLAMA: - if FLASH_ATTENTION and USE_CUSTOM_MODELING: + if FLASH_ATTENTION: return MllamaCausalLM( model_id=model_id, model_class=MllamaForConditionalGeneration, @@ -1259,7 +1245,7 @@ def get_model( else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama")) if model_type == IDEFICS2: - if FLASH_ATTENTION and USE_CUSTOM_MODELING: + if FLASH_ATTENTION: return VlmCausalLM( model_id=model_id, model_class=Idefics2ForConditionalGeneration, @@ -1277,7 +1263,7 @@ def get_model( else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == PALIGEMMA: - if FLASH_ATTENTION and USE_CUSTOM_MODELING: + if FLASH_ATTENTION: return VlmCausalLM( model_id=model_id, model_class=PaliGemmaForConditionalGeneration, @@ -1296,7 +1282,7 @@ def get_model( raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == LLAVA_NEXT: - if FLASH_ATTENTION and USE_CUSTOM_MODELING: + if FLASH_ATTENTION: return VlmCausalLM( model_class=LlavaNextForConditionalGeneration, model_id=model_id, diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 89f920bb7b3..8a33fb32b59 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -68,6 +68,3 @@ def get_adapter_to_index(): global ADAPTER_TO_INDEX return ADAPTER_TO_INDEX - -USE_CUSTOM_MODELING = os.getenv("USE_CUSTOM_MODELING", "true") -USE_CUSTOM_MODELING = USE_CUSTOM_MODELING.lower() == "true" or USE_CUSTOM_MODELING == "1" From ac62bd1572b743847deca404ed459e6ed83a5adb Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 17 Jan 2025 13:09:52 +0000 Subject: [PATCH 20/33] improve type hints + required args --- .../models/transformers_flash_causal_lm.py | 66 +++++++++---------- 1 file changed, 31 insertions(+), 35 deletions(-) diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py index 49dcac6297c..18ab27c2c49 100644 --- a/server/text_generation_server/models/transformers_flash_causal_lm.py +++ b/server/text_generation_server/models/transformers_flash_causal_lm.py @@ -1,6 +1,6 @@ import math import sys -from typing import Optional, Tuple, Dict, Any +from typing import List, Optional, Tuple, Dict, Any import torch from opentelemetry import trace @@ -24,7 +24,7 @@ ) from 
text_generation_server.adapters import AdapterBatchData from text_generation_server.layers.attention import paged_attention, attention, Seqlen -from text_generation_server.layers.attention.kv_cache import KVScales +from text_generation_server.layers.attention.kv_cache import KVScales, KVCache from text_generation_server.models.globals import ATTENTION from text_generation_server.models.metadata_kernels import block_tables_to_ragged @@ -37,14 +37,20 @@ def tgi_flash_attention_forward( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - attention_mask: torch.Tensor, + kv_cache: List[KVCache], + kv_head_mapping: torch.Tensor, + slots: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + seqlen: Seqlen, + block_tables: torch.Tensor, + max_s: int, softmax_scale: Optional[float] = None, sliding_window: Optional[int] = None, softcap: Optional[float] = None, - **kwargs, + **kwargs, # This is needed to "absorb" other args passed by Transformers modeling ): - kv_cache = kwargs["kv_cache"][module.layer_idx] + kv_cache = kv_cache[module.layer_idx] # This means no scale kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device)) @@ -56,7 +62,7 @@ def tgi_flash_attention_forward( kv_cache.store( key=key_states, value=value_states, - slots=kwargs["slots"], + slots=slots, kv_scales=kv_scales ) @@ -64,15 +70,15 @@ def tgi_flash_attention_forward( softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale sliding_window = -1 if sliding_window is None else sliding_window - if kwargs["cu_seqlen_prefill"] is not None: + if cu_seqlen_prefill is not None: attn_output = attention( query=query_states, key=key_states, value=value_states, kv_cache=kv_cache, kv_scales=kv_scales, - seqlen=kwargs["seqlen"], - block_tables=kwargs["block_tables"], + seqlen=seqlen, + block_tables=block_tables, softmax_scale=softmax_scale, window_size_left=sliding_window, softcap=softcap, @@ -81,11 +87,11 @@ def tgi_flash_attention_forward( attn_output = paged_attention( query_states, kv_cache, - kwargs["kv_head_mapping"], + kv_head_mapping, softmax_scale, - kwargs["block_tables"], - kwargs["seqlen"], - kwargs["max_s"], + block_tables, + seqlen, + max_s, kv_scales=kv_scales, softcap=softcap, ) @@ -145,16 +151,13 @@ def __init__( model_id, revision=revision, torch_dtype=dtype, - device_map=("auto" if device_count > 1 else None), + device_map="auto", load_in_8bit=quantize == "bitsandbytes", trust_remote_code=trust_remote_code, attn_implementation="tgi", tp_plan="auto" if world_size > 1 else None, ) - if device_count == 1 and quantize != "bitsandbytes": - model = model.to(device) - if tokenizer.pad_token_id is None: if model.config.pad_token_id is not None: tokenizer.pad_token_id = model.config.pad_token_id @@ -237,23 +240,21 @@ def fallback( trust_remote_code=trust_remote_code, ) - def _model_forward( self, - input_ids, - position_ids, - cu_seqlen_prefill, - kv_cache, - block_tables, - slots, - seqlen, - max_s, - prefill_cache_indices, - lm_head_indices, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[KVCache], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + lm_head_indices: torch.Tensor, ): hidden_states = self.model.model.forward( - input_ids=input_ids[None, ...], # expand dim to easily fit transformers - position_ids=position_ids[None, ...], # expand dim to easily fit transformers + input_ids=input_ids.unsqueeze(0), # 
expand dim to easily fit transformers + position_ids=position_ids.unsqueeze(0), # expand dim to easily fit transformers past_key_values=None, # we use self.kv_cache instead of transformers cache object use_cache=False, # we use self.kv_cache instead of transformers cache object return_dict=True, @@ -263,7 +264,6 @@ def _model_forward( slots=slots, seqlen=seqlen, max_s=max_s, - prefill_cache_indices=prefill_cache_indices, kv_head_mapping=self.kv_head_mapping, )[0].squeeze(dim=0) # And compute logits from the lm_head, slicing correctly the indices @@ -335,7 +335,6 @@ def forward( slots=slots, seqlen=seqlen, max_s=max_s, - prefill_cache_indices=batch.prefill_cache_indices, lm_head_indices=lm_head_indices, ) if batch.prefill_cache_indices is not None: @@ -496,7 +495,6 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): slots=slots, seqlen=seqlen, max_s=max_s, - prefill_cache_indices=None, lm_head_indices=None, ) del seqlen @@ -520,7 +518,6 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): slots=slots, seqlen=seqlen, max_s=max_s, - prefill_cache_indices=None, lm_head_indices=None, ) self.cuda_graphs[bs]["logits"] = logits @@ -561,5 +558,4 @@ def tunableop_warmup(self, seqlen: int): slots=slots, max_s=max_s, lm_head_indices=None, - prefill_cache_indices=None, ) \ No newline at end of file From b03d7ae951ac69467be51073cefe16bf3e2fa5b2 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 17 Jan 2025 15:34:08 +0000 Subject: [PATCH 21/33] Update based on transformers PR --- server/text_generation_server/models/__init__.py | 2 +- .../models/transformers_flash_causal_lm.py | 15 ++++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 66be0be2613..92f3c51ceee 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -378,7 +378,7 @@ def get_model( # Fast transformers path transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]) - if transformers_model_class._supports_flex_attn: + if transformers_model_class.is_backend_compatible(): transformers_causal_lm_class = TransformersFlashCausalLM quantization_config = config_dict.get("quantization_config", None) diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py index 18ab27c2c49..21aa1f8bf35 100644 --- a/server/text_generation_server/models/transformers_flash_causal_lm.py +++ b/server/text_generation_server/models/transformers_flash_causal_lm.py @@ -250,14 +250,19 @@ def _model_forward( slots: torch.Tensor, seqlen: Seqlen, max_s: int, - lm_head_indices: torch.Tensor, + lm_head_indices: Optional[torch.Tensor], ): - hidden_states = self.model.model.forward( + # Transformers does not support None as a default + if lm_head_indices is None: + lm_head_indices = 0 + + logits = self.model.forward( input_ids=input_ids.unsqueeze(0), # expand dim to easily fit transformers position_ids=position_ids.unsqueeze(0), # expand dim to easily fit transformers past_key_values=None, # we use self.kv_cache instead of transformers cache object use_cache=False, # we use self.kv_cache instead of transformers cache object return_dict=True, + num_logits_to_keep=lm_head_indices, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, block_tables=block_tables, @@ -265,11 +270,7 @@ def _model_forward( seqlen=seqlen, 
max_s=max_s, kv_head_mapping=self.kv_head_mapping, - )[0].squeeze(dim=0) - # And compute logits from the lm_head, slicing correctly the indices - if lm_head_indices is not None: - hidden_states = hidden_states[lm_head_indices] - logits = self.model.lm_head.forward(hidden_states) + ).logits.squeeze(dim=0) return logits From b40c889360f1894fa178afece0823a1468c7d6db Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 17 Jan 2025 16:05:47 +0000 Subject: [PATCH 22/33] small fix --- .../models/transformers_flash_causal_lm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py index 21aa1f8bf35..42ec1b3f04c 100644 --- a/server/text_generation_server/models/transformers_flash_causal_lm.py +++ b/server/text_generation_server/models/transformers_flash_causal_lm.py @@ -37,6 +37,7 @@ def tgi_flash_attention_forward( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], # This needs to stay as it is passed as a positional arg in transformers kv_cache: List[KVCache], kv_head_mapping: torch.Tensor, slots: torch.Tensor, From 42ae6dea02b9a36afa23e6c378db61533d0482de Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 17 Jan 2025 16:39:49 +0000 Subject: [PATCH 23/33] Remove Warpers for Processor --- server/text_generation_server/utils/logits_process.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index 132e441be4f..5066de53e39 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -11,7 +11,6 @@ from outlines.fsm.guide import RegexGuide from transformers import ( - LogitsWarper, LogitsProcessor, PreTrainedTokenizerBase, TemperatureLogitsWarper, @@ -219,7 +218,7 @@ def filter(self, indices): return None -class HeterogeneousTopPLogitsWarper(LogitsWarper): +class HeterogeneousTopPLogitsWarper(LogitsProcessor): """ [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. This version allows for a separate value for each sample and runs inplace when possible. @@ -278,7 +277,7 @@ def filter(self, indices): return None -class HeterogeneousTopKLogitsWarper(LogitsWarper): +class HeterogeneousTopKLogitsWarper(LogitsProcessor): r""" [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. This version allows for a separate value for each sample and runs inplace when possible. @@ -359,7 +358,7 @@ def filter(self, indices): return None -class HeterogeneousTypicalLogitsWarper(LogitsWarper): +class HeterogeneousTypicalLogitsWarper(LogitsProcessor): r""" [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information. @@ -453,13 +452,13 @@ class HeterogeneousProcessorWrapper(LogitsProcessor): r""" A wrapper for logit warpers or processors without heterogeneous parameter support. Args: - processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`): + processors (`Dict[int, LogitsProcessor]`): A mapping of sample indices to logit warpers or processors, to be run sequentially. 
""" def __init__( self, - processors: Dict[int, Union[LogitsProcessor, LogitsWarper]], + processors: Dict[int, LogitsProcessor], ): self.processors = processors From f01014de37ea7cb807c328583e3bef10eb78f2f0 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 17 Jan 2025 17:04:56 +0000 Subject: [PATCH 24/33] fix compatibility version issue --- .../text_generation_server/layers/gptq/quantize.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/server/text_generation_server/layers/gptq/quantize.py b/server/text_generation_server/layers/gptq/quantize.py index 66fc15ec0e3..41dc867d313 100644 --- a/server/text_generation_server/layers/gptq/quantize.py +++ b/server/text_generation_server/layers/gptq/quantize.py @@ -956,15 +956,22 @@ def _unload(): pack(model, quantizers, bits, groupsize) from safetensors.torch import save_file - from transformers.modeling_utils import shard_checkpoint + from huggingface_hub import split_torch_state_dict_into_shards state_dict = model.state_dict() state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()} max_shard_size = "10GB" - shards, index = shard_checkpoint( - state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors" + state_dict_split = split_torch_state_dict_into_shards( + state_dict, filename_pattern="model.safetensors", max_shard_size=max_shard_size, ) + index = None + if state_dict_split.is_sharded: + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + shards = state_dict_split.filename_to_tensors os.makedirs(output_dir, exist_ok=True) for shard_file, shard in shards.items(): save_file( From 2659b5998b25928111397113dd06369cdfb4bd78 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 20 Jan 2025 11:29:51 +0100 Subject: [PATCH 25/33] raise error if needed --- server/text_generation_server/models/__init__.py | 2 ++ .../models/transformers_flash_causal_lm.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 92f3c51ceee..4b50653297f 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -380,6 +380,8 @@ def get_model( transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]) if transformers_model_class.is_backend_compatible(): transformers_causal_lm_class = TransformersFlashCausalLM + if not FLASH_ATTENTION and lora_adapter_ids is not None and len(lora_adapter_ids) > 0: + raise ValueError("Transformers backend AutoModel do not support `lora_adapter_ids`.") quantization_config = config_dict.get("quantization_config", None) if quantization_config is None: diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py index 42ec1b3f04c..98fbf9a2efc 100644 --- a/server/text_generation_server/models/transformers_flash_causal_lm.py +++ b/server/text_generation_server/models/transformers_flash_causal_lm.py @@ -48,7 +48,7 @@ def tgi_flash_attention_forward( softmax_scale: Optional[float] = None, sliding_window: Optional[int] = None, softcap: Optional[float] = None, - **kwargs, # This is needed to "absorb" other args passed by Transformers modeling + **_kwargs, # This is needed to "absorb" other args passed by Transformers modeling ): kv_cache = kv_cache[module.layer_idx] From a2fe842795c6230e878840f76e5af92db1c91717 Mon 
Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 20 Jan 2025 11:52:58 +0100 Subject: [PATCH 26/33] Simplify with monkey patch --- .../models/transformers_flash_causal_lm.py | 324 +----------------- 1 file changed, 15 insertions(+), 309 deletions(-) diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py index 98fbf9a2efc..cfc9c861abc 100644 --- a/server/text_generation_server/models/transformers_flash_causal_lm.py +++ b/server/text_generation_server/models/transformers_flash_causal_lm.py @@ -1,32 +1,17 @@ import math -import sys -from typing import List, Optional, Tuple, Dict, Any +from typing import List, Optional import torch from opentelemetry import trace -from loguru import logger from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig import transformers.modeling_utils -from text_generation_server.models.flash_causal_lm import ( - FlashCausalLMBatch, - FlashCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import ( - empty_cache, - synchronize, - get_free_memory, -) -from text_generation_server.adapters import AdapterBatchData +from text_generation_server.models.flash_causal_lm import FlashCausalLM +from text_generation_server.utils import initialize_torch_distributed + from text_generation_server.layers.attention import paged_attention, attention, Seqlen from text_generation_server.layers.attention.kv_cache import KVScales, KVCache from text_generation_server.models.globals import ATTENTION -from text_generation_server.models.metadata_kernels import block_tables_to_ragged tracer = trace.get_tracer(__name__) @@ -48,7 +33,7 @@ def tgi_flash_attention_forward( softmax_scale: Optional[float] = None, sliding_window: Optional[int] = None, softcap: Optional[float] = None, - **_kwargs, # This is needed to "absorb" other args passed by Transformers modeling + **kwargs, # This is needed to "absorb" other args passed by Transformers modeling ): kv_cache = kv_cache[module.layer_idx] @@ -222,6 +207,11 @@ def __init__( world_size=world_size, ) + # Monkey patch of `self.model.forward` to match `FlashCausalLM`. 
It avoids duplicating a lot of code + # We first copy the original model.forward because we still need it in the monkey patch + self.model.original_forward = self.model.forward + self.model.forward = self._model_forward + @classmethod def fallback( cls, @@ -252,12 +242,15 @@ def _model_forward( seqlen: Seqlen, max_s: int, lm_head_indices: Optional[torch.Tensor], + prefill_cache_indices = None, # not used, but passed to match original signature + adapter_data = None, # not supported, but passed to match original signature ): # Transformers does not support None as a default if lm_head_indices is None: lm_head_indices = 0 - logits = self.model.forward( + # Equivalent tp `self.model.forward`, see the monkey patch in __init__ + logits = self.model.original_forward( input_ids=input_ids.unsqueeze(0), # expand dim to easily fit transformers position_ids=position_ids.unsqueeze(0), # expand dim to easily fit transformers past_key_values=None, # we use self.kv_cache instead of transformers cache object @@ -272,292 +265,5 @@ def _model_forward( max_s=max_s, kv_head_mapping=self.kv_head_mapping, ).logits.squeeze(dim=0) - return logits - - - def forward( - self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - # NOTE: adapter_data: not supported - - input_ids = batch.input_ids - position_ids = batch.position_ids - cu_seqlen_prefill = batch.cu_seqlen_prefill - kv_cache = self.kv_cache - block_tables = batch.block_tables_tensor - slots = batch.slots[batch.slot_indices] - input_lengths = batch.input_lengths_tensor - cache_lengths_tensor = batch.cache_lengths_tensor - max_s = batch.max_current_length - lm_head_indices = batch.prefill_head_indices - - if cu_seqlen_prefill is None and self.max_past() is not None: - # In decode, not prefill, we're actually overwriting the KV-cache - # in a circular buffer mode. - # This makes sure the max_s for the decode pass is correct. 
- max_s = min(self.max_past(), max_s) - - bs = input_ids.shape[0] - sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) - if sorted_padded_bs: - # Get associated cuda graph - cuda_graph = self.cuda_graphs[sorted_padded_bs[0]] - else: - cuda_graph = None - - if cu_seqlen_prefill is not None or cuda_graph is None: - if ATTENTION == "flashinfer": - block_tables = block_tables_to_ragged( - block_tables=block_tables, - input_lengths=batch.input_lengths, - cache_lengths=batch.cache_lengths, - input_lengths_tensor=batch.input_lengths_tensor, - cache_lengths_tensor=batch.cache_lengths_tensor, - max_current_length=batch.max_current_length, - ) - with self._forward_context( - block_tables=block_tables, - cu_seqlen_prefill=cu_seqlen_prefill, - input_lengths_tensor=input_lengths, - cache_lengths_tensor=cache_lengths_tensor, - ): - seqlen = Seqlen( - input_lengths=input_lengths, - cache_lengths=cache_lengths_tensor, - cu_seqlen_q=cu_seqlen_prefill, - max_q=batch.max_input_length, - max_k=batch.max_current_length, - ) - logits = self._model_forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, - block_tables=block_tables, - slots=slots, - seqlen=seqlen, - max_s=max_s, - lm_head_indices=lm_head_indices, - ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None - return logits, None - - # Copy inputs to the static inputs of the cuda graph - # Static inputs are potentially padded - cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids - cuda_graph["position_ids"][: position_ids.shape[-1]] = position_ids - if ATTENTION == "flashinfer": - block_tables = block_tables_to_ragged( - block_tables=block_tables, - input_lengths=batch.input_lengths, - cache_lengths=batch.cache_lengths, - input_lengths_tensor=batch.input_lengths_tensor, - cache_lengths_tensor=batch.cache_lengths_tensor, - max_current_length=batch.max_current_length, - ) - # assert block_tables.shape[0] >= slots.shape[0] - cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables - else: - cuda_graph["block_tables"][ - : block_tables.shape[0], : block_tables.shape[1] - ] = block_tables - - # XXX: This is working only because block 0 is reserved for the healthcheck - # so it doesn't matter if we override it with bogus values. 
- cuda_graph["slots"].fill_(0) - cuda_graph["slots"][: slots.shape[0]] = slots - cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths - cuda_graph["cache_lengths"].zero_() - cuda_graph["cache_lengths"][ - : cache_lengths_tensor.shape[0] - ] = cache_lengths_tensor - - with self._forward_context( - block_tables=cuda_graph["block_tables"], - cu_seqlen_prefill=None, - input_lengths_tensor=cuda_graph["input_lengths"], - cache_lengths_tensor=cuda_graph["cache_lengths"], - state=cuda_graph["state"], - ): - # Replay the graph - cuda_graph["graph"].replay() - - # Slice output to the correct shape - logits = cuda_graph["logits"][:bs] - return logits, None - - - def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): - max_bs = max(self.cuda_graphs.keys()) if self.cuda_graphs else None - input_lengths = [max_s] * bs - cache_lengths = [0] * bs - if max_bs is None: - input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) - position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) - slots = torch.arange(bs, dtype=torch.int64, device=self.device) - input_lengths_tensor = ( - torch.ones(bs, dtype=torch.int32, device=self.device) * max_s - ) - cache_lengths_tensor = torch.zeros( - bs, dtype=torch.int32, device=self.device - ) - block_tables = torch.arange( - max_bt, dtype=torch.int32, device=self.device - ).repeat(bs) - block_tables = block_tables.reshape((bs, max_bt)) - if ATTENTION == "flashinfer": - block_tables = block_tables_to_ragged( - block_tables=block_tables, - input_lengths=input_lengths, - cache_lengths=cache_lengths, - input_lengths_tensor=input_lengths_tensor, - cache_lengths_tensor=cache_lengths_tensor, - max_current_length=max_s, - ) - else: - if bs > max_bs: - raise RuntimeError( - "Cuda graphs should be generated in decreasing order size to reduce VRAM usage" - ) - input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs] - position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs] - if ATTENTION == "flashinfer": - block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt] - else: - block_tables = self.cuda_graphs[max_bs]["block_tables"][:bs] - slots = self.cuda_graphs[max_bs]["slots"][:bs] - input_lengths_tensor = self.cuda_graphs[max_bs]["input_lengths"][:bs] - cache_lengths_tensor = self.cuda_graphs[max_bs]["cache_lengths"][:bs] - - if ATTENTION == "flashinfer": - from text_generation_server.layers.attention.flashinfer import ( - create_decode_state_cuda_graphs, - ) - block_tables_ptr = torch.zeros( - bs + 1, dtype=torch.int32, device=self.device - ) - last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device) - state = create_decode_state_cuda_graphs( - device=input_ids.device, - block_tables=block_tables, - block_tables_ptr=block_tables_ptr, - last_page_len=last_page_len, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - ) - else: - state = None - - if ( - hasattr(self.model, "config") - and hasattr(self.model.config, "model_type") - and self.model.config.model_type == "qwen2_vl" - ): - if position_ids.dim() == 1: - position_ids = self.model.get_position_ids(input_ids) - - graph = torch.cuda.CUDAGraph() - self.cuda_graphs[bs] = { - "input_ids": input_ids, - "position_ids": position_ids, - "kv_cache": self.kv_cache, - "block_tables": block_tables, - "slots": slots, - "input_lengths": input_lengths_tensor, - "cache_lengths": cache_lengths_tensor, - "state": state, - "graph": graph, - } - - torch.cuda.synchronize() - # Run once outside to warmup - with 
self._forward_context( - block_tables=block_tables, - cu_seqlen_prefill=None, - input_lengths_tensor=input_lengths_tensor, - state=state, - cache_lengths_tensor=cache_lengths_tensor, - ): - seqlen = Seqlen( - input_lengths=input_lengths_tensor, - cache_lengths=cache_lengths_tensor, - cu_seqlen_q=None, - max_q=1, - max_k=max_s, - ) - self._model_forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=None, - kv_cache=self.kv_cache, - block_tables=block_tables, - slots=slots, - seqlen=seqlen, - max_s=max_s, - lm_head_indices=None, - ) - del seqlen - - torch.cuda.synchronize() - - with torch.cuda.graph(graph, pool=MEM_POOL): - seqlen = Seqlen( - input_lengths=input_lengths_tensor, - cache_lengths=cache_lengths_tensor, - cu_seqlen_q=None, - max_q=1, - max_k=max_s, - ) - logits = self._model_forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=None, - kv_cache=self.kv_cache, - block_tables=block_tables, - slots=slots, - seqlen=seqlen, - max_s=max_s, - lm_head_indices=None, - ) - self.cuda_graphs[bs]["logits"] = logits - self.cuda_graphs[bs]["speculative_logits"] = None - torch.cuda.synchronize() - - - def tunableop_warmup(self, seqlen: int): - input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device) - position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device) - slots = torch.arange(seqlen, dtype=torch.int64, device=self.device) - - # Dummy value, some models (starcoder2) don't accept `None`. - input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) - cache_lengths_tensor = torch.zeros( - seqlen, dtype=torch.int32, device=self.device - ) - cu_seqlen_prefill = torch.tensor( - [0, seqlen], device=self.device, dtype=torch.int32 - ) - max_s = seqlen - seqlen = Seqlen( - input_lengths=input_lengths, - cache_lengths=cache_lengths_tensor, - cu_seqlen_q=cu_seqlen_prefill, - max_q=1, - max_k=seqlen, - ) - - # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. 
- self._model_forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=self.kv_cache, - block_tables=None, - seqlen=seqlen, - slots=slots, - max_s=max_s, - lm_head_indices=None, - ) \ No newline at end of file + return logits, None \ No newline at end of file From 6e0f37c0cacd8a54701ec6509fc99cc9e7505f5b Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 20 Jan 2025 15:13:24 +0100 Subject: [PATCH 27/33] revert + style + minor improvements --- .../layers/gptq/quantize.py | 4 +- .../text_generation_server/models/__init__.py | 20 ++++++-- .../text_generation_server/models/globals.py | 1 - .../models/transformers_flash_causal_lm.py | 51 ++++++++++--------- .../utils/logits_process.py | 2 +- 5 files changed, 45 insertions(+), 33 deletions(-) diff --git a/server/text_generation_server/layers/gptq/quantize.py b/server/text_generation_server/layers/gptq/quantize.py index 41dc867d313..aa664ea607a 100644 --- a/server/text_generation_server/layers/gptq/quantize.py +++ b/server/text_generation_server/layers/gptq/quantize.py @@ -963,7 +963,9 @@ def _unload(): max_shard_size = "10GB" state_dict_split = split_torch_state_dict_into_shards( - state_dict, filename_pattern="model.safetensors", max_shard_size=max_shard_size, + state_dict, + filename_pattern="model.safetensors", + max_shard_size=max_shard_size, ) index = None if state_dict_split.is_sharded: diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 4b50653297f..5069fff6d66 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -21,7 +21,9 @@ from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.models.model import Model from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast -from text_generation_server.models.transformers_flash_causal_lm import TransformersFlashCausalLM +from text_generation_server.models.transformers_flash_causal_lm import ( + TransformersFlashCausalLM, +) from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM from text_generation_server.models.custom_modeling.mpt_modeling import ( MPTForCausalLM, @@ -377,11 +379,19 @@ def get_model( transformers_causal_lm_class = CausalLM # Fast transformers path - transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]) - if transformers_model_class.is_backend_compatible(): + transformers_model_class = getattr( + transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type] + ) + if transformers_model_class._supports_flex_attn: transformers_causal_lm_class = TransformersFlashCausalLM - if not FLASH_ATTENTION and lora_adapter_ids is not None and len(lora_adapter_ids) > 0: - raise ValueError("Transformers backend AutoModel do not support `lora_adapter_ids`.") + if ( + not FLASH_ATTENTION + and lora_adapter_ids is not None + and len(lora_adapter_ids) > 0 + ): + raise ValueError( + "Transformers backend AutoModel do not support `lora_adapter_ids`." 
+ ) quantization_config = config_dict.get("quantization_config", None) if quantization_config is None: diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 8a33fb32b59..8d988ad5870 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -67,4 +67,3 @@ def set_adapter_to_index(adapter_to_index: Dict[str, int]): def get_adapter_to_index(): global ADAPTER_TO_INDEX return ADAPTER_TO_INDEX - diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py index cfc9c861abc..30ea4c8fcce 100644 --- a/server/text_generation_server/models/transformers_flash_causal_lm.py +++ b/server/text_generation_server/models/transformers_flash_causal_lm.py @@ -22,7 +22,7 @@ def tgi_flash_attention_forward( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - attention_mask: Optional[torch.Tensor], # This needs to stay as it is passed as a positional arg in transformers + attention_mask: Optional[torch.Tensor], # This is a positional arg in Transformers kv_cache: List[KVCache], kv_head_mapping: torch.Tensor, slots: torch.Tensor, @@ -30,6 +30,7 @@ def tgi_flash_attention_forward( seqlen: Seqlen, block_tables: torch.Tensor, max_s: int, + kv_scales: KVScales, softmax_scale: Optional[float] = None, sliding_window: Optional[int] = None, softcap: Optional[float] = None, @@ -37,20 +38,13 @@ def tgi_flash_attention_forward( ): kv_cache = kv_cache[module.layer_idx] - # This means no scale - kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device)) query_states = query_states.transpose(1, 2).squeeze(dim=0) key_states = key_states.transpose(1, 2).squeeze(dim=0) value_states = value_states.transpose(1, 2).squeeze(dim=0) # Take care of updating the cache in-place - kv_cache.store( - key=key_states, - value=value_states, - slots=slots, - kv_scales=kv_scales - ) + kv_cache.store(key=key_states, value=value_states, slots=slots, kv_scales=kv_scales) _, num_heads, head_dim = query_states.shape softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale @@ -110,14 +104,11 @@ def __init__( if speculator: raise RuntimeError("Speculator decoding is not enabled for AutoModel") - device_count = 0 if torch.cuda.is_available(): device = torch.device("cuda:0") - device_count = torch.cuda.device_count() dtype = torch.float16 if dtype is None else dtype elif hasattr(torch, "xpu") and torch.xpu.is_available(): device = torch.device("xpu") - device_count = torch.xpu.device_count() dtype = torch.float16 if dtype is None else dtype else: if quantize: @@ -156,7 +147,6 @@ def __init__( else: tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - self.num_layers = model.config.num_hidden_layers self.num_heads = model.config.num_attention_heads // self.process_group.size() self.num_kv_heads = model.config.num_key_value_heads @@ -190,9 +180,16 @@ def __init__( ) self.num_groups = self.num_heads // self.num_kv_heads + + # Those will never change and will be used in the forwards self.kv_head_mapping = torch.arange( 0, self.num_kv_heads, dtype=torch.int32, device=device ).repeat_interleave(self.num_groups) + # This means no scale + self.kv_scales = KVScales( + torch.tensor(1.0, device=device), + torch.tensor(1.0, device=device), + ) torch.distributed.barrier(group=self.process_group) # Skip FlashCausalLM init. 
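The constants added just above are worth unpacking: `kv_head_mapping` assigns each query head to the KV head whose cache entries it reads (grouped-query attention), and unit `KVScales` simply mean the KV cache is not quantized. A small numerical illustration of the mapping, using an example GQA shape that is not taken from any particular model:

    import torch

    num_heads, num_kv_heads = 32, 8          # example GQA configuration
    num_groups = num_heads // num_kv_heads   # 4 query heads share each KV head

    kv_head_mapping = torch.arange(
        0, num_kv_heads, dtype=torch.int32
    ).repeat_interleave(num_groups)

    # 32 entries: [0, 0, 0, 0, 1, 1, 1, 1, ..., 7, 7, 7, 7]
    assert kv_head_mapping.tolist()[:8] == [0, 0, 0, 0, 1, 1, 1, 1]

Query head `i` therefore looks up KV head `i // num_groups` during paged attention.
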
@@ -242,21 +239,17 @@ def _model_forward( seqlen: Seqlen, max_s: int, lm_head_indices: Optional[torch.Tensor], - prefill_cache_indices = None, # not used, but passed to match original signature - adapter_data = None, # not supported, but passed to match original signature + prefill_cache_indices=None, # not used, but passed to match original signature + adapter_data=None, # not supported, but passed to match original signature ): - # Transformers does not support None as a default - if lm_head_indices is None: - lm_head_indices = 0 - - # Equivalent tp `self.model.forward`, see the monkey patch in __init__ - logits = self.model.original_forward( + hidden_states = self.model.model.forward( input_ids=input_ids.unsqueeze(0), # expand dim to easily fit transformers - position_ids=position_ids.unsqueeze(0), # expand dim to easily fit transformers + position_ids=position_ids.unsqueeze( + 0 + ), # expand dim to easily fit transformers past_key_values=None, # we use self.kv_cache instead of transformers cache object use_cache=False, # we use self.kv_cache instead of transformers cache object return_dict=True, - num_logits_to_keep=lm_head_indices, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, block_tables=block_tables, @@ -264,6 +257,14 @@ def _model_forward( seqlen=seqlen, max_s=max_s, kv_head_mapping=self.kv_head_mapping, - ).logits.squeeze(dim=0) + kv_scales=self.kv_scales, + )[0].squeeze(dim=0) + + # And compute logits from the lm_head, slicing correctly the indices + # NOTE: some logits post-processing (e.g. in gemma2) may be absent here with the split of the modules + # To update with full Transformers support asap + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.model.lm_head.forward(hidden_states) - return logits, None \ No newline at end of file + return logits, None diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index 5066de53e39..64a285b93f8 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -5,7 +5,7 @@ from typing import List, Optional, DefaultDict from loguru import logger -from typing import Dict, Union +from typing import Dict from text_generation_server.pb.generate_pb2 import GrammarType from outlines.fsm.guide import RegexGuide From 52afdcc281187f2cb1602e553d70a72c1874fa45 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 20 Jan 2025 15:25:10 +0100 Subject: [PATCH 28/33] update comment --- .../models/transformers_flash_causal_lm.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py index 30ea4c8fcce..17f47e5e3c0 100644 --- a/server/text_generation_server/models/transformers_flash_causal_lm.py +++ b/server/text_generation_server/models/transformers_flash_causal_lm.py @@ -243,10 +243,8 @@ def _model_forward( adapter_data=None, # not supported, but passed to match original signature ): hidden_states = self.model.model.forward( - input_ids=input_ids.unsqueeze(0), # expand dim to easily fit transformers - position_ids=position_ids.unsqueeze( - 0 - ), # expand dim to easily fit transformers + input_ids=input_ids.unsqueeze(0), # expand dim to fit Transformers + position_ids=position_ids.unsqueeze(0), # expand dim to fit Transformers past_key_values=None, # we use self.kv_cache instead of transformers cache object 
use_cache=False, # we use self.kv_cache instead of transformers cache object return_dict=True, From 9af3ea4b70a5ea206904304fe9ee97576a83c301 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 20 Jan 2025 15:55:31 +0100 Subject: [PATCH 29/33] device check --- server/text_generation_server/models/__init__.py | 7 +++++-- .../models/transformers_flash_causal_lm.py | 8 +++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 5069fff6d66..612ad8b37f6 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -382,7 +382,10 @@ def get_model( transformers_model_class = getattr( transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type] ) - if transformers_model_class._supports_flex_attn: + accelerator_available = torch.cuda.is_available() or ( + hasattr(torch, "xpu") and torch.xpu.is_available() + ) + if transformers_model_class._supports_flex_attn and accelerator_available: transformers_causal_lm_class = TransformersFlashCausalLM if ( not FLASH_ATTENTION @@ -390,7 +393,7 @@ def get_model( and len(lora_adapter_ids) > 0 ): raise ValueError( - "Transformers backend AutoModel do not support `lora_adapter_ids`." + "Flash `Transformers` modeling backend does not support `lora_adapter_ids`." ) quantization_config = config_dict.get("quantization_config", None) diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py index 17f47e5e3c0..647fabc210a 100644 --- a/server/text_generation_server/models/transformers_flash_causal_lm.py +++ b/server/text_generation_server/models/transformers_flash_causal_lm.py @@ -111,11 +111,9 @@ def __init__( device = torch.device("xpu") dtype = torch.float16 if dtype is None else dtype else: - if quantize: - raise ValueError("quantization is not available on CPU") - - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype + raise ValueError( + "Flash `Transformers` modeling backend is not available on cpu." 
+ ) tokenizer = AutoTokenizer.from_pretrained( model_id, From 6d9c011f51b373160bf8a749de68227e33922598 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 20 Jan 2025 16:11:41 +0100 Subject: [PATCH 30/33] move the import to avoid device issue --- .../text_generation_server/models/__init__.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 612ad8b37f6..cfe9d02566a 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -21,9 +21,7 @@ from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.models.model import Model from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast -from text_generation_server.models.transformers_flash_causal_lm import ( - TransformersFlashCausalLM, -) + from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM from text_generation_server.models.custom_modeling.mpt_modeling import ( MPTForCausalLM, @@ -85,6 +83,9 @@ try: from text_generation_server.models.flash_causal_lm import FlashCausalLM + from text_generation_server.models.transformers_flash_causal_lm import ( + TransformersFlashCausalLM, + ) from text_generation_server.models.vlm_causal_lm import VlmCausalLM from text_generation_server.models.mllama_causal_lm import MllamaCausalLM from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import ( @@ -382,16 +383,10 @@ def get_model( transformers_model_class = getattr( transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type] ) - accelerator_available = torch.cuda.is_available() or ( - hasattr(torch, "xpu") and torch.xpu.is_available() - ) - if transformers_model_class._supports_flex_attn and accelerator_available: + + if FLASH_ATTENTION and transformers_model_class._supports_flex_attn: transformers_causal_lm_class = TransformersFlashCausalLM - if ( - not FLASH_ATTENTION - and lora_adapter_ids is not None - and len(lora_adapter_ids) > 0 - ): + if lora_adapter_ids is not None and len(lora_adapter_ids) > 0: raise ValueError( "Flash `Transformers` modeling backend does not support `lora_adapter_ids`." 
) From 2ef3002c2b1ec9e3c3347fd3ae61c87db00b6266 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 20 Jan 2025 16:37:41 +0100 Subject: [PATCH 31/33] Update __init__.py --- .../text_generation_server/models/__init__.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index cfe9d02566a..a7c8c4a7ac8 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -83,9 +83,6 @@ try: from text_generation_server.models.flash_causal_lm import FlashCausalLM - from text_generation_server.models.transformers_flash_causal_lm import ( - TransformersFlashCausalLM, - ) from text_generation_server.models.vlm_causal_lm import VlmCausalLM from text_generation_server.models.mllama_causal_lm import MllamaCausalLM from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import ( @@ -180,6 +177,14 @@ if MAMBA_AVAILABLE: __all__.append(Mamba) +FLASH_TRANSFORMERS_BACKEND = True +try: + from text_generation_server.models.transformers_flash_causal_lm import ( + TransformersFlashCausalLM, + ) +except ImportError: + FLASH_TRANSFORMERS_BACKEND = False + class ModelType(enum.Enum): DEEPSEEK_V2 = { @@ -384,12 +389,8 @@ def get_model( transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type] ) - if FLASH_ATTENTION and transformers_model_class._supports_flex_attn: + if FLASH_TRANSFORMERS_BACKEND and transformers_model_class._supports_flex_attn: transformers_causal_lm_class = TransformersFlashCausalLM - if lora_adapter_ids is not None and len(lora_adapter_ids) > 0: - raise ValueError( - "Flash `Transformers` modeling backend does not support `lora_adapter_ids`." 
- ) quantization_config = config_dict.get("quantization_config", None) if quantization_config is None: From 70ada578b92442d3893451c4920e77bdc372b6d6 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 20 Jan 2025 18:01:12 +0100 Subject: [PATCH 32/33] check for non-native models --- server/text_generation_server/models/__init__.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index a7c8c4a7ac8..f7f7a26ee7d 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -16,7 +16,6 @@ from huggingface_hub import hf_hub_download, HfApi from typing import Optional, List, Dict from pathlib import Path -import transformers from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.models.model import Model @@ -385,11 +384,14 @@ def get_model( transformers_causal_lm_class = CausalLM # Fast transformers path - transformers_model_class = getattr( - transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type] + transformers_model_class = modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get( + model_type, None ) - - if FLASH_TRANSFORMERS_BACKEND and transformers_model_class._supports_flex_attn: + if ( + FLASH_TRANSFORMERS_BACKEND + and transformers_model_class is not None + and transformers_model_class._supports_flex_attn + ): transformers_causal_lm_class = TransformersFlashCausalLM quantization_config = config_dict.get("quantization_config", None) From 0d9ec75f27aa82ad247d6b5b3638043a8763fc29 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 20 Jan 2025 18:42:12 +0100 Subject: [PATCH 33/33] oupsi --- server/text_generation_server/models/__init__.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index f7f7a26ee7d..160b45ada8b 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -16,6 +16,7 @@ from huggingface_hub import hf_hub_download, HfApi from typing import Optional, List, Dict from pathlib import Path +import transformers from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.models.model import Model @@ -384,8 +385,10 @@ def get_model( transformers_causal_lm_class = CausalLM # Fast transformers path - transformers_model_class = modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get( - model_type, None + transformers_model_class = getattr( + transformers, + modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, ""), + None, ) if ( FLASH_TRANSFORMERS_BACKEND