
Commit ddbbbaa

Authored by flybird11111, pre-commit-ci[bot], BurkeHulk, and wangbluo
[upgrade]Upgrade transformers (#6320)
* fix for async io
* test for upgrading transformers
* add ci machine
* Update test_fp16_torch.py
* Update build_on_pr.yml
* upgrade llama
* upgrade_bert
* upgrade_bloom
* [upgrade] upgrade gpt2 (#6291)
* upgrade command
* add explanation
* [upgrade] Upgrade qwen2 (#6302)
* update_bloom
* upgrade_sam
* upgrade_t
* upgrade_gptj
* [upgrade] upgrade opt (#6307)
* [upgrade] Upgrade mixtral (#6317)
* upgrade infer
* upgrade drafter
* upgrade lazy
* [upgrade] Upgrade vit (#6308)
* fix rotate embedding test
* [upgrade] upgrade mistral (#6296)
* fix falcon
* Update test_shard_deepseek.py
* Update requirements.txt
* Update bert.py
* fix (#6327)
* fix (#6328)
* fix (#6329)
* assorted fixes and pre-commit.ci auto-fixes (see https://pre-commit.ci)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Hanks <[email protected]>
Co-authored-by: wangbluo <[email protected]>
Co-authored-by: Wang Binluo <[email protected]>
1 parent 46ed5d8 commit ddbbbaa

File tree

40 files changed: +817 additions, -839 deletions


.github/workflows/build_on_pr.yml

Lines changed: 2 additions & 2 deletions
@@ -34,7 +34,7 @@ jobs:
       anyExtensionFileChanged: ${{ steps.find-extension-change.outputs.any_changed }}
       changedLibraryFiles: ${{ steps.find-lib-change.outputs.all_changed_files }}
       anyLibraryFileChanged: ${{ steps.find-lib-change.outputs.any_changed }}
-    runs-on: ubuntu-latest
+    runs-on: [self-hosted,ubuntu-latest]
     concurrency:
       group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-detect-change
       cancel-in-progress: true
@@ -87,7 +87,7 @@ jobs:
     name: Build and Test Colossal-AI
     needs: detect
     if: needs.detect.outputs.anyLibraryFileChanged == 'true'
-    runs-on: ubuntu-latest
+    runs-on: [self-hosted,ubuntu-latest]
     container:
       image: image-cloud.luchentech.com/hpcaitech/pytorch-cuda:2.2.2-12.1.0
       options: --gpus all --shm-size=2g --rm -v /dev/shm -v /data/scratch:/data/scratch

colossalai/inference/modeling/models/glide_llama.py

Lines changed: 20 additions & 59 deletions
@@ -6,19 +6,16 @@
 
 import torch
 import torch.nn as nn
-from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.cache_utils import DynamicCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
     LlamaConfig,
     LlamaDecoderLayer,
-    LlamaDynamicNTKScalingRotaryEmbedding,
     LlamaForCausalLM,
-    LlamaLinearScalingRotaryEmbedding,
     LlamaMLP,
     LlamaModel,
     LlamaRMSNorm,
-    LlamaRotaryEmbedding,
 )
 
 from colossalai.inference.spec import GlideInput
@@ -156,31 +153,29 @@ def glide_llama_model_forward(
     if inputs_embeds is None:
         inputs_embeds = self.embed_tokens(input_ids)
 
-    past_seen_tokens = 0
-    if use_cache:  # kept for BC (cache positions)
-        if not isinstance(past_key_values, StaticCache):
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_seen_tokens = past_key_values.get_seq_length()
+    if use_cache and past_key_values is None:
+        past_key_values = DynamicCache()
 
     if cache_position is None:
-        if isinstance(past_key_values, StaticCache):
-            raise ValueError("cache_position is a required argument when using StaticCache.")
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )
 
     if position_ids is None:
         position_ids = cache_position.unsqueeze(0)
 
-    attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+    attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
+    if hasattr(glide_input, "n_spec_tokens"):
+        position_ids = position_ids + glide_input.n_spec_tokens
 
     # embed positions
     hidden_states = inputs_embeds
+    position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
     # decoder layers
     all_hidden_states = () if output_hidden_states else None
     all_self_attns = () if output_attentions else None
-    next_decoder_cache = None
 
     for decoder_layer in self.layers:
         if output_hidden_states:
@@ -189,9 +184,9 @@ def glide_llama_model_forward(
         # GlideLlamaDecoderLayer
         layer_outputs = decoder_layer(
             hidden_states,
+            position_embeddings=position_embeddings,
             glide_input=glide_input,
             attention_mask=attention_mask,
-            position_ids=position_ids,
             past_key_value=past_key_values,
             output_attentions=output_attentions,
             use_cache=use_cache,
@@ -200,9 +195,6 @@ def glide_llama_model_forward(
 
         hidden_states = layer_outputs[0]
 
-        if use_cache:
-            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
         if output_attentions:
             all_self_attns += (layer_outputs[1],)
 
@@ -212,16 +204,11 @@ def glide_llama_model_forward(
     if output_hidden_states:
         all_hidden_states += (hidden_states,)
 
-    next_cache = None
-    if use_cache:
-        next_cache = (
-            next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
-        )
     if not return_dict:
-        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return tuple(v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None)
     return BaseModelOutputWithPast(
         last_hidden_state=hidden_states,
-        past_key_values=next_cache,
+        past_key_values=past_key_values,
         hidden_states=all_hidden_states,
         attentions=all_self_attns,
     )
@@ -267,41 +254,17 @@ def __init__(self, config: GlideLlamaConfig):
 
         self.q_proj = nn.Linear(self.hidden_size, self.large_num_heads * self.large_head_dim, bias=False)
         self.o_proj = nn.Linear(self.large_num_heads * self.large_head_dim, self.hidden_size, bias=False)
-        self._init_rope()
-
-    def _init_rope(self):
-        if self.config.rope_scaling is None:
-            self.rotary_emb = LlamaRotaryEmbedding(
-                self.large_head_dim,
-                max_position_embeddings=self.max_position_embeddings,
-            )
-        else:
-            scaling_type = self.config.rope_scaling["type"]
-            scaling_factor = self.config.rope_scaling["factor"]
-            if scaling_type == "linear":
-                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
-                    self.large_head_dim,
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                )
-            elif scaling_type == "dynamic":
-                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
-                    self.large_head_dim,
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                )
-            else:
-                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
 
     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+        position_ids: Optional[torch.LongTensor] = None,
         glide_input: GlideInput = None,  # Used for glimpsing main model's KV caches
         attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
     ) -> Optional[torch.Tensor]:
@@ -319,8 +282,7 @@ def forward(
         query_states = query_states.view(bsz, -1, self.large_num_heads, self.large_head_dim).transpose(1, 2)
 
         # for RoPE
-        position_ids = position_ids + glide_input.n_spec_tokens
-        cos, sin = self.rotary_emb(query_states, position_ids)
+        cos, sin = position_embeddings
         query_states = apply_single_rotary_pos_emb(query_states, cos, sin, position_ids)
         query_states = query_states.transpose(1, 2)
         query_states = query_states.reshape(-1, self.large_num_heads, self.large_head_dim)
@@ -367,9 +329,10 @@ def from_native_module(module: LlamaDecoderLayer, *args, **kwargs) -> "GlideLlam
     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: torch.Tensor = None,
+        position_ids: Optional[torch.LongTensor] = None,
         glide_input: GlideInput = None,
         attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
@@ -399,10 +362,10 @@ def forward(
         hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states, self_attn_weights = self.self_attn(
             hidden_states=hidden_states,
+            position_embeddings=position_embeddings,
             attention_mask=attention_mask,
-            position_ids=position_ids,
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
@@ -425,9 +388,10 @@ def forward(
 
         hidden_states = self.cross_attn(
             hidden_states=hidden_states,
+            position_embeddings=position_embeddings,
+            position_ids=position_ids,
             glide_input=glide_input,
             attention_mask=attention_mask,
-            position_ids=position_ids,
             output_attentions=output_attentions,
             use_cache=True,
         )
@@ -441,9 +405,6 @@ def forward(
 
         outputs = (hidden_states,)
 
-        if use_cache:
-            outputs += (present_key_value,)
-
         return outputs
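The core of this change is that recent transformers releases compute rotary position embeddings once at the model level and hand the resulting (cos, sin) pair to every decoder layer as position_embeddings, rather than each attention module owning its own rotary module via _init_rope. A minimal sketch of that flow, assuming a recent transformers version; the config values below are illustrative and not taken from this commit:

import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb

# Illustrative config; real values come from the loaded checkpoint.
config = LlamaConfig(hidden_size=128, num_attention_heads=4, num_key_value_heads=4)
rotary_emb = LlamaRotaryEmbedding(config=config)

bsz, seq_len = 1, 5
head_dim = config.hidden_size // config.num_attention_heads
hidden_states = torch.randn(bsz, seq_len, config.hidden_size)
position_ids = torch.arange(seq_len).unsqueeze(0)

# Computed once per model forward, then passed to every layer as `position_embeddings`.
cos, sin = rotary_emb(hidden_states, position_ids)

q = torch.randn(bsz, config.num_attention_heads, seq_len, head_dim)
k = torch.randn(bsz, config.num_key_value_heads, seq_len, head_dim)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)  # each layer reuses the shared (cos, sin)

The glide model above follows the same pattern: it computes position_embeddings once in glide_llama_model_forward and threads them through both the self-attention and the cross-attention paths of each GlideLlamaDecoderLayer.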

colossalai/inference/modeling/models/nopadding_llama.py

Lines changed: 3 additions & 3 deletions
@@ -478,9 +478,9 @@ def from_native_module(
             attn_oproj=attn_oproj,
             process_group=process_group,
             model_shard_infer_config=model_shard_infer_config,
-            num_heads=module.num_heads,
-            hidden_size=module.hidden_size,
-            num_key_value_heads=module.num_key_value_heads,
+            num_heads=module.config.num_attention_heads,
+            hidden_size=module.config.hidden_size,
+            num_key_value_heads=module.config.num_key_value_heads,
         )
 
         return attn_layer
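This change is needed because recent transformers versions no longer expose convenience attributes such as num_heads, hidden_size, or num_key_value_heads directly on LlamaAttention; the values now live only on the config. A small sketch of the pattern, using an illustrative config rather than anything from this repository:

from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention

# Illustrative config; a real `module` would come from a loaded LlamaForCausalLM.
config = LlamaConfig(hidden_size=128, num_attention_heads=4, num_key_value_heads=2, intermediate_size=256)
module = LlamaAttention(config, layer_idx=0)

num_heads = module.config.num_attention_heads            # was: module.num_heads
hidden_size = module.config.hidden_size                  # was: module.hidden_size
num_key_value_heads = module.config.num_key_value_heads  # was: module.num_key_value_heads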

colossalai/inference/spec/drafter.py

Lines changed: 4 additions & 2 deletions
@@ -3,6 +3,7 @@
 import torch
 import torch.nn as nn
 from transformers import PreTrainedTokenizer
+from transformers.cache_utils import DynamicCache
 
 from colossalai.utils import get_current_device
 
@@ -93,9 +94,8 @@ def speculate(
 
         for _ in range(n_spec_tokens):
             # update past key values
-            kwargs["past_key_values"] = past_key_values
 
-            outputs = self._drafter_model(input_ids, **kwargs)
+            outputs = self._drafter_model(input_ids, past_key_values=past_key_values, **kwargs)
             next_token_logits = outputs.logits[:, -1, :]
 
             # NOTE Only use greedy search for speculating.
@@ -114,6 +114,8 @@ def speculate(
         speculated_length = len(token_ids)  # For now, only support bsz 1
         logits = torch.concat(logits, dim=0)
         token_ids = torch.concat(token_ids, dim=-1)
+        if isinstance(past_key_values, DynamicCache):
+            past_key_values = past_key_values.to_legacy_cache()
 
         out = DrafterOutput(
             speculated_length=speculated_length, logits=logits, next_tokens=token_ids, past_key_values=past_key_values
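The conversion above exists because newer transformers models return a DynamicCache object, while consumers of DrafterOutput here still expect the legacy tuple-of-tuples layout. A minimal round-trip sketch, assuming a transformers version that still ships the legacy-cache helpers:

import torch
from transformers.cache_utils import DynamicCache

# One (key, value) pair per layer, each of shape (batch, num_heads, seq_len, head_dim).
legacy = tuple((torch.zeros(1, 4, 3, 8), torch.zeros(1, 4, 3, 8)) for _ in range(2))

cache = DynamicCache.from_legacy_cache(legacy)  # wrap legacy tuples in a Cache object
assert cache.get_seq_length() == 3

roundtrip = cache.to_legacy_cache()             # back to tuples for legacy consumers
assert roundtrip[0][0].shape == (1, 4, 3, 8)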

colossalai/lazy/pretrained.py

Lines changed: 2 additions & 2 deletions
@@ -69,7 +69,7 @@ def new_from_pretrained(
     _ = kwargs.pop("mirror", None)
     from_pipeline = kwargs.pop("_from_pipeline", None)
     from_auto_class = kwargs.pop("_from_auto", False)
-    _fast_init = kwargs.pop("_fast_init", True)
+    kwargs.pop("_fast_init", True)
     torch_dtype = kwargs.pop("torch_dtype", None)
     subfolder = kwargs.pop("subfolder", "")
     commit_hash = kwargs.pop("_commit_hash", None)
@@ -286,7 +286,7 @@ def new_from_pretrained(
     config.name_or_path = pretrained_model_name_or_path
 
     # Instantiate model.
-    init_contexts = [no_init_weights(_enable=_fast_init)]
+    init_contexts = [no_init_weights()]
 
     with ContextManagers(init_contexts):
         model = cls(config, *model_args, **model_kwargs)
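The _enable argument is dropped here because newer transformers versions turned no_init_weights into an argument-free context manager, so passing the old flag would fail. A hedged sketch of how the bare context manager is used on its own; the model and config below are illustrative and not part of this commit:

from transformers import LlamaConfig, LlamaForCausalLM
from transformers.modeling_utils import no_init_weights

config = LlamaConfig(hidden_size=128, intermediate_size=256, num_attention_heads=4, num_hidden_layers=2)

# Skip the usual (slow) weight initialization; the weights are expected to be
# overwritten by a checkpoint load afterwards.
with no_init_weights():
    model = LlamaForCausalLM(config)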

colossalai/shardformer/modeling/bert.py

Lines changed: 84 additions & 1 deletion
@@ -58,7 +58,7 @@ def bert_model_forward(
     hidden_states: Optional[torch.FloatTensor] = None,  # this is from the previous stage
     stage_index: Optional[List[int]] = None,
     shard_config: ShardConfig = None,
-):
+) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
     # TODO(jianghai): add explaination of the output here.
     r"""
     encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1037,6 +1037,89 @@ def forward(self: BertOutput, hidden_states: torch.Tensor, input_tensor: torch.T
     return forward
 
 
+# Fix the tgt_len size in sequence parallel attention:
+# same with the one in BertSdpaSelfAttention forward in v4.51.3 transformers except the
+def get_bert_sequence_parallel_attention_forward(shard_config: ShardConfig):
+    from transformers.models.bert.modeling_bert import BertSdpaSelfAttention
+
+    def forward(
+        self: BertSdpaSelfAttention,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor]:
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        query_layer = self.transpose_for_scores(self.query(hidden_states))
+
+        # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
+        # mask needs to be such that the encoder's padding tokens are not attended to.
+        is_cross_attention = encoder_hidden_states is not None
+
+        current_states = encoder_hidden_states if is_cross_attention else hidden_states
+        attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
+
+        # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
+        if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
+            key_layer, value_layer = past_key_value
+        else:
+            key_layer = self.transpose_for_scores(self.key(current_states))
+            value_layer = self.transpose_for_scores(self.value(current_states))
+            if past_key_value is not None and not is_cross_attention:
+                key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+                value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_layer, value_layer)
+
+        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
+        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
+        # Reference: https://github.com/pytorch/pytorch/issues/112577
+        if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
+            query_layer = query_layer.contiguous()
+            key_layer = key_layer.contiguous()
+            value_layer = value_layer.contiguous()
+
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
+        # a causal mask in case tgt_len == 1.
+        is_causal = (
+            True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
+        )
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_layer,
+            key_layer,
+            value_layer,
+            attn_mask=attention_mask,
+            dropout_p=self.dropout_prob if self.training else 0.0,
+            is_causal=is_causal,
+        )
+
+        attn_output = attn_output.transpose(1, 2)
+        _, _, tgt_len, _ = query_layer.shape
+        attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)
+
+        outputs = (attn_output,)
+        if self.is_decoder:
+            outputs = outputs + (past_key_value,)
+        return outputs
+
+    return forward
+
+
 def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
     def forward(
         self,
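For reference, the heart of the forward added above is the scaled_dot_product_attention dispatch: is_causal is enabled only for decoder self-attention with no explicit mask and tgt_len > 1, and the output is reshaped back to (batch, tgt_len, all_head_size). A self-contained toy version of just that step, with illustrative shapes:

import torch
import torch.nn.functional as F

bsz, num_heads, tgt_len, head_dim = 2, 4, 6, 16
query_layer = torch.randn(bsz, num_heads, tgt_len, head_dim)
key_layer = torch.randn(bsz, num_heads, tgt_len, head_dim)
value_layer = torch.randn(bsz, num_heads, tgt_len, head_dim)

attention_mask = None          # e.g. no padding mask was supplied
is_decoder, is_cross_attention = True, False

is_causal = is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1
attn_output = F.scaled_dot_product_attention(
    query_layer, key_layer, value_layer,
    attn_mask=attention_mask, dropout_p=0.0, is_causal=is_causal,
)
# (bsz, num_heads, tgt_len, head_dim) -> (bsz, tgt_len, num_heads * head_dim)
attn_output = attn_output.transpose(1, 2).reshape(bsz, tgt_len, num_heads * head_dim)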
