@@ -47,8 +47,6 @@ def _attn(

         attn_weights = torch.matmul(query, key.transpose(-1, -2))

-        attn_weights = attn_weights / self.scale_attn
-
         # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
         # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
         mask_value = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=attn_weights.dtype).to(attn_weights.device)
@@ -57,6 +55,7 @@ def _attn(
         # Apply the attention mask
         attn_weights = torch.where(attention_mask, mask_value, attn_weights)

+        attn_weights = attn_weights / self.scale_attn
         attn_weights = nn.Softmax(dim=-1)(attn_weights)
         attn_weights = attn_weights.to(value.dtype)
         attn_weights = self.attn_dropout(attn_weights)
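For reference, a minimal standalone sketch (not the QEfficient code itself) of why scaling after masking still drives masked positions to near-zero probability under the softmax: a large negative constant divided by scale_attn remains large and negative. MIN_MASKED_ATTENTION_VALUE and scale_attn below are stand-in values, not the repository's constants.

# Sketch only: mask first, scale second, then softmax.
import torch

MIN_MASKED_ATTENTION_VALUE = -1e4  # assumed stand-in for the constant used above
scale_attn = 8.0                   # e.g. sqrt(head_dim) for head_dim = 64

scores = torch.tensor([[2.0, 1.0, 0.5, 0.25]])
mask = torch.tensor([[False, False, True, True]])  # True = masked out

masked = torch.where(mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE), scores)
probs = torch.softmax(masked / scale_attn, dim=-1)
print(probs)  # masked positions receive (near-)zero probability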
@@ -124,36 +123,16 @@ def forward(
         query = query.permute(0, 2, 1, 3)

         if layer_past is not None:
-            # Update the cache_kwargs with position_ids for Cloud AI 100
-            past_key_value = layer_past
-            cache_kwargs = {
-                "position_ids": position_ids,
-                "batch_index": batch_index,
-            }
-            pkv = QEffDynamicCache()
-            pkv.key_cache.append(past_key_value[0])
-            pkv.value_cache.append(past_key_value[1])
-            key, value = pkv.update(key, value, 0, cache_kwargs)
-
-        if use_cache is True:
-            # Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32.
-            # Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38
-            present = (pkv.key_cache[0].to(hidden_states.dtype), pkv.value_cache[0])
-        else:
-            present = None
+            cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index}
+            key, value = layer_past.update(key.to(hidden_states.dtype), value, self.layer_idx, cache_kwargs)

         # compute self-attention: V x Softmax(QK^T)
         attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

         attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
         attn_output = self.out_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
-
-        outputs = (attn_output, present)
-        if output_attentions:
-            outputs += (attn_weights,)
-
-        return outputs  # a, present, (attentions)
+        return attn_output, attn_weights


 class QEffCodeGenModel(CodeGenModel):
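A rough sketch of the Cache.update interface that the single-line `layer_past.update(...)` call above relies on, using the stock transformers DynamicCache. QEffDynamicCache in this repo layers position_ids/batch_index handling on top of that interface, so the cache_kwargs semantics here are only assumed for illustration.

# Sketch only: append new key/value states for one layer and read back the cached tensors.
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
key = torch.randn(1, 16, 4, 64)    # (batch, heads, seq, head_dim)
value = torch.randn(1, 16, 4, 64)

# update() stores the states for layer 0 and returns the full cached key/value
# tensors that attention should run over.
cached_key, cached_value = cache.update(key, value, layer_idx=0, cache_kwargs={})
print(cached_key.shape, cache.get_seq_length())  # torch.Size([1, 16, 4, 64]) 4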
@@ -167,7 +146,7 @@ class QEffCodeGenModel(CodeGenModel):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor]]]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         token_type_ids: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
@@ -179,7 +158,8 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        **kwargs,  # NOOP kwargs, for now
+    ) -> Union[tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -200,20 +180,21 @@ def forward(
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        device = input_ids.device if input_ids is not None else inputs_embeds.device
+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids)

-        if token_type_ids is not None:
-            token_type_ids = token_type_ids.view(-1, input_shape[-1])
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            return_legacy_cache = True
+            past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)

-        if past_key_values is None:
-            past_length = 0
-            past_key_values = tuple([None] * len(self.h))
-        else:
-            past_length = past_key_values[0][0].size(-2)
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        seq_length = inputs_embeds.shape[1]
+        if cache_position is None:
+            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device)

         if position_ids is None:
-            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0)
+            position_ids = cache_position.unsqueeze(0)

         # Attention mask.
         if attention_mask is not None:
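A minimal sketch of the legacy-cache bridge pattern used above (from_legacy_cache / get_seq_length / to_legacy_cache) together with the cache_position and position_ids derivation, again with the stock DynamicCache standing in for QEffDynamicCache — an assumption made only for illustration.

# Sketch only: first forward pass with no past, then hand the cache back in legacy form.
import torch
from transformers.cache_utils import DynamicCache

legacy = None  # no past on the first pass
cache = DynamicCache.from_legacy_cache(legacy)

past_seen_tokens = cache.get_seq_length()          # 0 on the first pass
seq_length = 5
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length)
position_ids = cache_position.unsqueeze(0)
print(position_ids)  # tensor([[0, 1, 2, 3, 4]])

# Once the model has filled the cache, it can be converted back to the old tuple format:
legacy_again = cache.to_legacy_cache()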
@@ -237,40 +218,33 @@ def forward(

         elif attention_mask is None:
             # 4d mask is passed through the layers
-            attention_mask = _create_causal_mask(position_ids=position_ids, target_length=past_length)
+            attention_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens)

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
         # attention_probs has shape bsz x num_attention_heads x N x N
         # head_mask has shape n_layer x batch x num_attention_heads x N x N
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)

-        if inputs_embeds is None:
-            inputs_embeds = self.wte(input_ids)
-
         hidden_states = inputs_embeds

         if token_type_ids is not None:
+            token_type_ids = token_type_ids.view(-1, seq_length)
             token_type_embeds = self.wte(token_type_ids)
             hidden_states = hidden_states + token_type_embeds

         hidden_states = self.drop(hidden_states)
+        output_shape = (-1, seq_length, hidden_states.size(-1))

-        output_shape = input_shape + (hidden_states.size(-1),)
-
-        if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
-
-        presents = () if use_cache else None
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
-        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+        for i, block in enumerate(self.h):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

             outputs = block(
-                hidden_states=hidden_states,
-                layer_past=layer_past,
+                hidden_states,
+                layer_past=past_key_values,
                 batch_index=batch_index,
                 attention_mask=attention_mask,
                 position_ids=position_ids,
@@ -281,11 +255,9 @@ def forward(
             )

             hidden_states = outputs[0]
-            if use_cache is True:
-                presents = presents + (outputs[1],)

             if output_attentions:
-                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+                all_self_attentions = all_self_attentions + (outputs[1],)

         hidden_states = self.ln_f(hidden_states)
@@ -294,12 +266,17 @@ def forward(
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)

+        if return_legacy_cache:
+            past_key_values = past_key_values.to_legacy_cache()
+
         if not return_dict:
-            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+            return tuple(
+                v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None
+            )

         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=presents,
+            past_key_values=past_key_values,
             hidden_states=all_hidden_states,
             attentions=all_self_attentions,
         )
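For callers, a small sketch of the two return shapes of this forward: a BaseModelOutputWithPast when return_dict is truthy, otherwise a tuple with the None entries dropped. The tensors and sizes below are made up; only the container behaviour is being shown.

# Sketch only: how the return containers behave.
import torch
from transformers.cache_utils import DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast

hidden = torch.zeros(1, 5, 1024)
cache = DynamicCache()

out = BaseModelOutputWithPast(last_hidden_state=hidden, past_key_values=cache)
print(out.last_hidden_state.shape)   # torch.Size([1, 5, 1024])
print(out.past_key_values is cache)  # True

# With return_dict=False the same values come back positionally, Nones dropped:
tup = tuple(v for v in [hidden, cache, None, None] if v is not None)
print(len(tup))  # 2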
@@ -330,12 +307,6 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
-            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
-            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
-        """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         transformer_outputs = self.transformer(
@@ -372,9 +343,7 @@ def forward(
         )


-class QeffCodeGenBlock(CodeGenBlock):
-    # Ignore copy
-
+class QEffCodeGenBlock(CodeGenBlock):
     def forward(
         self,
         hidden_states: Optional[torch.FloatTensor],
@@ -389,7 +358,7 @@ def forward(
     ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
-        attn_outputs = self.attn(
+        attn_outputs, attn_weights = self.attn(
             hidden_states=hidden_states,
             layer_past=layer_past,
             attention_mask=attention_mask,
@@ -400,15 +369,8 @@ def forward(
             output_attentions=output_attentions,
             cache_position=cache_position,
         )
-        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
-        outputs = attn_outputs[1:]

         feed_forward_hidden_states = self.mlp(hidden_states)
-        hidden_states = attn_output + feed_forward_hidden_states + residual
-
-        if use_cache:
-            outputs = (hidden_states,) + outputs
-        else:
-            outputs = (hidden_states,) + outputs[1:]
+        hidden_states = attn_outputs + feed_forward_hidden_states + residual

-        return outputs  # hidden_states, present, (attentions)
+        return hidden_states, attn_weights
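A toy, self-contained illustration (assumed shapes, stock PyTorch modules) of the parallel-residual pattern the block above now returns as a plain (hidden_states, attn_weights) pair: attention and MLP both read the same LayerNorm output, and their results are summed with the residual.

# Sketch only: parallel attention/MLP residual block with a two-tuple return.
import torch
import torch.nn as nn

class ToyParallelBlock(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.ln_1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, hidden_states):
        residual = hidden_states
        normed = self.ln_1(hidden_states)
        attn_out, attn_weights = self.attn(normed, normed, normed)
        hidden_states = attn_out + self.mlp(normed) + residual  # parallel residual
        return hidden_states, attn_weights

x = torch.randn(1, 5, 64)
y, w = ToyParallelBlock()(x)
print(y.shape, w.shape)  # torch.Size([1, 5, 64]) torch.Size([1, 5, 5])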