-
Notifications
You must be signed in to change notification settings - Fork 63
/
Copy pathmha.py
1088 lines (931 loc) · 48.3 KB
/
mha.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import inspect
import math
from functools import partial
from typing import Callable, Dict, Optional
import torch
from einops import rearrange
from torch import nn
from torch.nn import functional as F
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.parallel.comm.utils import (
gather_forward_split_backward,
split_forward_gather_backward,
)
from internlm.model.modules.embedding import new_rotary_embedding
from internlm.model.modules.linear import new_linear
from internlm.model.modules.utils import update_kv_cache
from internlm.model.ops.attention import CrossAttention, SelfAttention
from internlm.utils.logger import get_logger
# Module-level logger for warnings emitted by the attention modules in this file.
logger = get_logger(__file__)
def _convert_cu_seqlens_for_qksplited(kwargs: Dict):
    """Turn shared packed-sequence kwargs into separate query/key kwargs.

    ``cu_seqlens`` and ``max_seqlen`` are removed from ``kwargs``. Each one that is not ``None`` is
    re-inserted under both its ``*_q`` and ``*_k`` name, because q and k are handled separately downstream.

    Args:
        kwargs: keyword arguments forwarded to the attention call; mutated in place.

    Returns:
        The same ``kwargs`` dict object.
    """
    renames = (
        ("cu_seqlens", "cu_seqlens_q", "cu_seqlens_k"),
        ("max_seqlen", "max_seqlen_q", "max_seqlen_k"),
    )
    for shared_key, query_key, key_key in renames:
        shared_value = kwargs.pop(shared_key, None)
        if shared_value is None:
            continue
        kwargs[query_key] = shared_value
        kwargs[key_key] = shared_value
    return kwargs
def split_fused_wqkv_weight(wqkv, *args, **kwargs):  # pylint: disable=W0613
    """Split a fused ``wqkv`` weight along dim 0 into ``(wq, wk, wv)``.

    The full fused width is ``q_dim + 2 * kv_dim``. ``wqkv`` may hold only an evenly divided part of it
    (for example a tensor-parallel shard); each section is then shrunk by the same factor.

    Args:
        wqkv: fused weight tensor; rows are split along dim 0.
        *args: ignored; accepted so this can be called with state-dict hook arguments.
        **kwargs: must contain ``q_dim`` and ``kv_dim`` (full, unsharded widths).

    Returns:
        Tuple ``(wq, wk, wv)`` of views into ``wqkv``.
    """
    q_dim, kv_dim = kwargs["q_dim"], kwargs["kv_dim"]
    full_rows = q_dim + 2 * kv_dim
    local_rows = wqkv.size(0)
    assert full_rows % local_rows == 0
    shard_factor = full_rows // local_rows
    sections = [dim // shard_factor for dim in (q_dim, kv_dim, kv_dim)]
    wq, wk, wv = torch.split(wqkv, sections, dim=0)
    return wq, wk, wv
def _qkv_pre_load_convert(module: "GQA", state_dict, prefix: str, *args, **kwargs) -> None:  # pylint: disable=W0613
    """Load-state-dict pre hook converting between fused and separate q/k/v weights.

    - Fused module, checkpoint has separate ``wq``/``wk``/``wv`` weights: concatenate them along dim 0
      into ``wqkv.weight``.
    - Unfused module, checkpoint is missing any of them but has ``wqkv.weight``: split the fused weight
      with ``split_fused_wqkv_weight``.

    A conversion runs only when all of its source keys are present. Otherwise ``state_dict`` is left
    untouched, so ``load_state_dict`` reports missing/unexpected keys itself. Without this check, a
    partial state dict would be half-mutated by ``pop`` and then fail with a ``KeyError`` here.

    Args:
        module: the module being loaded; only ``module.enable_qkv_fusion`` is read.
        state_dict: the state dict being loaded; mutated in place.
        prefix: key prefix of ``module`` inside ``state_dict``.
        *args, **kwargs: remaining hook arguments; forwarded to ``split_fused_wqkv_weight``
            (``q_dim``/``kv_dim`` are expected in ``kwargs``).
    """
    wq_name, wk_name, wv_name, fused_name = (
        f"{prefix}wq.weight",
        f"{prefix}wk.weight",
        f"{prefix}wv.weight",
        f"{prefix}wqkv.weight",
    )
    split_names = (wq_name, wk_name, wv_name)
    has_all_split = all(name in state_dict for name in split_names)

    if module.enable_qkv_fusion:
        if fused_name not in state_dict and has_all_split:
            wq, wk, wv = (state_dict.pop(name) for name in split_names)
            state_dict[fused_name] = torch.cat([wq, wk, wv], dim=0)
    elif not has_all_split and fused_name in state_dict:
        state_dict[wq_name], state_dict[wk_name], state_dict[wv_name] = split_fused_wqkv_weight(
            state_dict.pop(fused_name), *args, **kwargs
        )
def _qkv_save_convert(module: "GQA", state_dict, prefix: str, *args, **kwargs) -> Dict:  # pylint: disable=W0613
    """State-dict hook that saves a fused ``wqkv`` weight as separate ``wq``/``wk``/``wv`` entries.

    When ``module.enable_qkv_fusion`` is false, ``state_dict`` is returned unchanged. Otherwise
    ``{prefix}wqkv.weight`` is removed and replaced by the three pieces returned by
    ``split_fused_wqkv_weight``.

    Args:
        module: the module being saved; only ``module.enable_qkv_fusion`` is read.
        state_dict: the state dict being produced; mutated in place.
        prefix: key prefix of ``module`` inside ``state_dict``.
        *args, **kwargs: forwarded to ``split_fused_wqkv_weight`` (``q_dim``/``kv_dim`` expected).

    Returns:
        The same ``state_dict`` object.
    """
    if not module.enable_qkv_fusion:
        return state_dict

    fused_weight = state_dict.pop(f"{prefix}wqkv.weight")
    q_part, k_part, v_part = split_fused_wqkv_weight(fused_weight, *args, **kwargs)
    state_dict[f"{prefix}wq.weight"] = q_part
    state_dict[f"{prefix}wk.weight"] = k_part
    state_dict[f"{prefix}wv.weight"] = v_part
    return state_dict
class MHA(nn.Module):
    """
    Multi-head self-attention and cross-attention.

    Args:
        embed_dim (int): The dimension of hidden state.
        num_heads (int): The number of attention heads.
        max_position_embeddings (int): max position embeddings, 2048 by default.
        bias (bool): Whether the bias is needed for the q/k/v linears. True by default.
        dropout (float): The dropout rate for cross attention and self attention. 0.0 by default.
        softmax_scale (float): The temperature to use for the softmax attention.
        causal (boolean): Whether to apply causal attention mask. False by default.
        layer_idx (int): The index of current layer. None by default.
        use_dynamic_ntk_rope (bool): whether use dynamic ntk rope, false by default.
        rotary_emb_dim (int): The dimension of Rotary Embedding. 0 by default, which disables rotary embedding.
        rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements
                                    XPos(Sun et al., https://arxiv.org/abs/2212.10554). 0 by default.
        rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
        device (Optional[Union[str, torch.device]]): The device will be used.
        dtype (Optional[torch.dtype]): The type of data.
        qk_interleaved (Optional[bool]): whether the odd and even columns of wq and wk are interleaved. True by default.
        enable_qkv_fusion (bool): whether the wq, wk and wv linears are fused. True by default.
        out_bias (bool): Whether the output projection has a bias. True by default.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        max_position_embeddings: int = 2048,
        bias: bool = True,
        dropout: float = 0.0,
        softmax_scale: float = None,
        causal: bool = False,
        layer_idx: int = None,
        use_dynamic_ntk_rope: bool = False,
        rotary_emb_dim: int = 0,
        rotary_emb_scale_base: int = 0,
        rope_base: int = 10000,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        qk_interleaved: Optional[bool] = True,
        enable_qkv_fusion: bool = True,
        out_bias: bool = True,
    ) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        self.causal = causal
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = self.embed_dim // num_heads
        self.kv_dim = self.head_dim * num_heads  # num_kv_heads equals to num_heads in MHA
        self.enable_qkv_fusion = enable_qkv_fusion
        self.use_dynamic_ntk_rope = use_dynamic_ntk_rope
        self.rotary_emb_dim = rotary_emb_dim
        self.max_position_embeddings = max_position_embeddings
        self.interleaved = qk_interleaved

        factory_kwargs = {"device": device, "dtype": dtype}

        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        # self.rotary_emb only exists when rotary_emb_dim > 0; every use below is guarded accordingly.
        if self.rotary_emb_dim > 0:
            self.rotary_emb = new_rotary_embedding(
                self.rotary_emb_dim,
                base=rope_base,
                scale_base=rotary_emb_scale_base,
                device=device,
                max_position_embeddings=max_position_embeddings,
                scaling_factor=1.0,
                rotary_type="dynamic_ntk" if self.use_dynamic_ntk_rope else "native",
            )

        if self.enable_qkv_fusion:
            # bias=True is according to https://spaces.ac.cn/archives/9577
            self.wqkv = new_linear("wqkv", embed_dim, 3 * embed_dim, bias, **factory_kwargs)
        else:
            self.wq = new_linear("wq", embed_dim, embed_dim, bias, **factory_kwargs)
            self.wk = new_linear("wk", embed_dim, self.kv_dim, bias, **factory_kwargs)
            self.wv = new_linear("wv", embed_dim, self.kv_dim, bias, **factory_kwargs)

        self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
        self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)

        # The output projection bias is controlled by out_bias (True by default; baichuan2 disables it).
        self.out_proj = new_linear("out_proj", embed_dim, embed_dim, bias=out_bias, **factory_kwargs)

    def register_checkpoint_compatibility_hooks(
        self, pre_load_hook: Optional[Callable] = None, pre_save_hook: Optional[Callable] = None
    ):
        """Register optional hooks that adapt checkpoints when loading/saving this module.

        Args:
            pre_load_hook: registered as a load-state-dict pre hook, called with the module as first argument.
            pre_save_hook: registered as a state-dict hook, run when ``state_dict()`` is produced.

        Hooks left as ``None`` are skipped. Registering ``None`` would make the next
        ``load_state_dict``/``state_dict`` call fail when it tries to call the hook.
        """
        # Here we explicitly expose the checkpoint compatibility interface of the module,
        # hoping that model developers will make good use of it when adapting.
        if pre_load_hook is not None:
            self._register_load_state_dict_pre_hook(pre_load_hook, with_module=True)
        if pre_save_hook is not None:
            self._register_state_dict_hook(pre_save_hook)

    def forward(self, x, inference_params=None, **kwargs):
        if inference_params is None:
            return self._training(x=x, **kwargs)
        else:
            return self._inference(x=x, inference_params=inference_params, **kwargs)

    def _training(self, x, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim)
        """
        # wqkv
        if self.enable_qkv_fusion:
            qkv = self.wqkv(x)
            qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
            # unbind instead of index + squeeze(2): squeeze would also drop the head dim when h == 1.
            q, k, v = qkv.unbind(dim=2)
        else:
            q, k, v = self.wq(x), self.wk(x), self.wv(x)
            q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim)
            k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim)
            v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim)

        # rotary embedding; "indexes" is always removed from kwargs before the attention call.
        indexes = kwargs.pop("indexes", 0)
        if self.rotary_emb_dim > 0:
            max_seqlen = kwargs.get("max_seqlen", None)
            q = self.rotary_emb(
                q, offsets=indexes, cache_type="query", interleaved=self.interleaved, max_seqlen=max_seqlen
            )
            k = self.rotary_emb(
                k, offsets=indexes, cache_type="key", interleaved=self.interleaved, max_seqlen=max_seqlen
            )

        # self attention
        kwargs = _convert_cu_seqlens_for_qksplited(kwargs)
        if gpc.config.data.use_packed_dataset is False or self.training is False:
            kwargs.pop("max_seqlen_q", None)
            kwargs.pop("max_seqlen_k", None)
        context = self.inner_attn(q, k, v, **kwargs)

        # wo
        return self.out_proj(rearrange(context, "b s h d -> b s (h d)"))

    def _convert_unpacked_qkv_to_packed(
        self, q: torch.Tensor, kv: torch.Tensor, batch_size: int, attention_mask: torch.Tensor
    ):
        """Pack the positions selected by ``attention_mask`` into a single varlen batch.

        Returns ``(q_packed, kv_packed, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)``.
        The cumulative sequence lengths are int32 and start at 0. The max lengths are the padded
        length ``attention_mask.shape[-1]``.
        """
        cu_seqlens = torch.concat(
            [
                torch.tensor([0], dtype=torch.int32, device=attention_mask.device),
                attention_mask.sum(dim=-1).to(dtype=torch.int32),
            ],
            dim=0,
        ).cumsum(dim=0, dtype=torch.int32)

        cu_seqlens_q = cu_seqlens
        cu_seqlens_k = cu_seqlens

        max_seqlen_q = attention_mask.shape[-1]
        max_seqlen_k = attention_mask.shape[-1]

        q_packed = (
            q.masked_select(attention_mask.view(batch_size, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]).unsqueeze(0)
        )
        kv_packed = (
            kv.masked_select(attention_mask.view(batch_size, -1, 1, 1, 1))
            .view(-1, kv.shape[-3], kv.shape[-2], kv.shape[-1])
            .unsqueeze(0)
        )

        return q_packed, kv_packed, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k

    def _inference(self, x, inference_params, **kwargs):  # pylint: disable=W0613
        assert inference_params is not None, "inference_params is required for inference"
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        attention_mask = inference_params.attention_mask
        sequence_len_offset = inference_params.sequence_len_offset
        batch_size = x.shape[0]

        # wqkv, output: q, kv
        if self.enable_qkv_fusion:
            qkv = self.wqkv(x)
            qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
            # Plain indexing already drops the "three" dim; no squeeze, so h == 1 keeps its head dim.
            q = qkv[:, :, 0]
            kv = qkv[:, :, 1:]
        else:
            q, k, v = self.wq(x), self.wk(x), self.wv(x)
            q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim)
            k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim)
            v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim)
            kv = torch.stack([k, v], dim=2)

        # rotary embedding, output: q, kv
        # q shape: [bsz, seqlen, nheads, head_dim]
        # kv shape: [bsz, seqlen, 2, nheads, head_dim]
        if self.use_dynamic_ntk_rope:
            # update kv cache first when enable dynamic ntk rope.
            kv = update_kv_cache(kv, inference_params, self.layer_idx)

            if sequence_len_offset != 0:
                if sequence_len_offset > self.max_position_embeddings:
                    logger.warning(
                        "Notice your prompt's length is longer than model's max_position_embeddings: "
                        f"{self.max_position_embeddings}, which will cause deviations in dynamic ntk calculations."
                    )

                if self.rotary_emb_dim > 0:
                    q = self.rotary_emb(
                        q, offsets=sequence_len_offset, cache_type="query", interleaved=self.interleaved
                    )
                    # k is a view into kv, so the in-place rotary update is visible through kv.
                    k = kv[:, :, 0]
                    self.rotary_emb(
                        k, offsets=0, cache_type="key", interleaved=self.interleaved, in_place=True
                    )  # in-place is important
            else:
                if self.rotary_emb_dim > 0:
                    q = self.rotary_emb(q, offsets=0, cache_type="query", interleaved=self.interleaved)
                    k = kv[:, :, 0]
                    self.rotary_emb(
                        k, offsets=0, cache_type="key", interleaved=self.interleaved, in_place=True
                    )  # in-place is important
        else:
            assert self.rotary_emb_dim > 0, "You should use rotary_emb."

            k, v = kv[:, :, 0], kv[:, :, 1]

            if attention_mask is None:
                q = self.rotary_emb(q, offsets=sequence_len_offset, cache_type="query", interleaved=self.interleaved)
                k = self.rotary_emb(k, offsets=sequence_len_offset, cache_type="key", interleaved=self.interleaved)
            else:
                if sequence_len_offset == 0:
                    q = self.rotary_emb(
                        q, offsets=0, cache_type="query", interleaved=self.interleaved, left_padding_mask=attention_mask
                    )
                    k = self.rotary_emb(
                        k, offsets=0, cache_type="key", interleaved=self.interleaved, left_padding_mask=attention_mask
                    )
                else:
                    if sequence_len_offset > self.max_position_embeddings:
                        logger.warning(
                            "Notice your prompt's length is longer than model's max_position_embeddings: "
                            f"{self.max_position_embeddings}, which will cause deviations in dynamic ntk calculations."
                        )

                    # Per-sample position = offset minus that sample's count of masked (padding) positions.
                    empties = attention_mask[..., -1].sum(dim=-1)
                    indexes4q = sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device) - empties
                    indexes4k = sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device) - empties

                    # TODO To fit flash_attn apis, we rearrange q&k to pack them here and
                    # calculate rope for this batch input. Waiting to be optimized
                    q = rearrange(q, "b s h d -> s b h d", d=self.head_dim)  # pack input
                    k = rearrange(k, "b s h d -> s b h d", d=self.head_dim)
                    q = self.rotary_emb(q, offsets=indexes4q, cache_type="query", interleaved=self.interleaved)
                    k = self.rotary_emb(k, offsets=indexes4k, cache_type="key", interleaved=self.interleaved)
                    q = rearrange(q, "s b h d -> b s h d", d=self.head_dim)  # unpack
                    k = rearrange(k, "s b h d -> b s h d", d=self.head_dim)

            kv = torch.stack([k, v], dim=2)
            # update kv cache after rotary embedding when disable dynamic ntk rope.
            kv = update_kv_cache(kv, inference_params, self.layer_idx)

        # self-attention
        if attention_mask is None:
            context = self.inner_cross_attn(q, kv)
        else:
            if sequence_len_offset == 0:  # First entrance, attnmask (bs*seqlen*seqlen)
                attn_mask = attention_mask[:, None, ...]
                attn_mask = torch.logical_or(torch.ones_like(attn_mask, dtype=torch.bool).triu(diagonal=1), attn_mask)
                attn_mask4flsh = ~attn_mask[:, :, -1, :].view(batch_size, -1)

                output = self.inner_attn(*self._convert_unpacked_qkv_to_packed(q, kv, batch_size, attn_mask4flsh))
                output = output.to(x.dtype)

                # Scatter the packed output back into the padded (b, s, h, d) layout.
                context = torch.zeros_like(q).masked_scatter_(attn_mask4flsh.view(batch_size, -1, 1, 1), output)
            else:
                attn_mask = attention_mask[:, -1, :].view(batch_size, 1, 1, -1)

                k, v = torch.chunk(kv, 2, dim=2)
                k = k.squeeze(2)
                v = v.squeeze(2)
                sp = k.shape
                scores = torch.einsum(
                    "blhd,bnhd->bhln",
                    q,
                    k.reshape(sp[0], sp[1], q.size(2), sp[3]),
                ) / math.sqrt(q.size(-1))
                scores = scores.masked_fill(attn_mask, -65000.0)
                scores = F.softmax(scores, dim=-1)  # bsz x h x L x L
                context = torch.einsum(
                    "bhmn,bnhd->bmhd",
                    scores,
                    v.reshape(sp[0], sp[1], q.size(2), sp[3]),
                )

        # wo
        return self.out_proj(rearrange(context, "b s h d -> b s (h d)"))
class ChameleonLayerNorm(nn.LayerNorm):
    """
    LayerNorm that computes statistics only over the last dim. Chameleon applies gamma and beta from each
    shard separately to each head instead of reducing over them. Each head gets its own gamma/beta by
    repeat-interleaving the per-shard weights, while the statistics are computed over the last dimension
    only. This module therefore applies gamma/beta manually.

    Args:
        hidden_size (int or tuple): size of the normalized dimension; for a tuple its last entry is
            used as the normalized shape.
        head_group_num (int): number of parameter groups (shards); gamma/beta get a leading dim of this size.
        n_heads_per_group (int): number of consecutive heads sharing one group's gamma/beta.
        *args, **kwargs: forwarded to ``nn.LayerNorm`` (e.g. ``eps``, ``elementwise_affine``).

    For an int ``hidden_size``, the repeated gamma/beta have shape
    ``(head_group_num * n_heads_per_group, hidden_size)``. They broadcast against the last two dims of
    the input, so the input's second-to-last dim is expected to be that number of heads.
    """

    def __init__(self, hidden_size, head_group_num, n_heads_per_group, *args, **kwargs):
        if isinstance(hidden_size, int):
            hidden_size = (hidden_size,)
        super().__init__([head_group_num, *hidden_size], *args, **kwargs)
        # Parameters are created with shape (head_group_num, *hidden_size); statistics use the last dim only.
        self.normalized_shape = (hidden_size[-1],)
        self.n_heads_per_group = n_heads_per_group

    def repeat_param(self, param):
        """Expand a per-group parameter to per-head by repeating each group ``n_heads_per_group`` times."""
        return param.repeat_interleave(self.n_heads_per_group, dim=0)

    def forward(self, hidden_states):
        # Use the configured eps (nn.LayerNorm defaults it to 1e-5) instead of a hard-coded value.
        hidden_states = F.layer_norm(hidden_states, self.normalized_shape, None, None, eps=self.eps)
        # weight/bias are None when the affine transform is disabled via nn.LayerNorm's kwargs.
        if self.weight is not None:
            hidden_states = hidden_states * self.repeat_param(self.weight)
        if self.bias is not None:
            hidden_states = hidden_states + self.repeat_param(self.bias)
        return hidden_states
class GQA(nn.Module):
    """
    Grouped-query self-attention and cross-attention: ``num_kv_heads`` key/value heads are each shared by
    ``num_heads // num_kv_heads`` query heads.

    Args:
        embed_dim (int): The dimension of hidden state.
        num_heads (int): The number of attention heads.
        num_kv_heads (int): The number of attention heads for key and value.
        max_position_embeddings (int): max position embeddings, 2048 by default.
        head_dim (int): The dimension of each attention head. If not set, ``embed_dim // num_heads`` is used and
                        the query projection width equals ``embed_dim``. None by default.
        bias (bool): Whether the bias is needed for linears. Will be used when initializing QKV matrix and
                     output projection. False by default.
        dropout (float): The dropout rate for cross attention and self attention. 0.0 by default.
        softmax_scale (float): The temperature to use for the softmax attention.
        causal (boolean): Whether to apply causal attention mask. False by default.
        layer_idx (int): The index of current layer. None by default.
        use_dynamic_ntk_rope (bool): must be False; dynamic ntk rope is not supported by GQA yet.
        rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
        rotary_emb_dim (int): The dimension of Rotary Embedding. 0 by default.
        rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements
                                    XPos(Sun et al., https://arxiv.org/abs/2212.10554). 0 by default.
        device (Optional[Union[str, torch.device]]): The device will be used.
        dtype (Optional[torch.dtype]): The type of data.
        qk_interleaved (Optional[bool]): whether the odd and even columns of wq and wk are interleaved. True by default.
        enable_qkv_fusion (bool): whether the wq, wk and wv linears are fused. True by default.
        qk_norm (Optional[bool]): if set, the query and key will be applied by layer norm after qk_linear.
                                  Requires enable_qkv_fusion=False. False by default.
        chameleon_mp_size (int): number of parameter groups of the q/k ChameleonLayerNorm used when qk_norm is set;
                                 must divide both num_heads and num_kv_heads. 1 by default.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 2048,
        head_dim: int = None,
        bias: bool = False,
        dropout: float = 0.0,
        softmax_scale: float = None,
        causal: bool = False,
        layer_idx: int = None,
        use_dynamic_ntk_rope: bool = False,
        rope_base: int = 10000,
        rotary_emb_dim: int = 0,
        rotary_emb_scale_base: int = 0,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        qk_interleaved: Optional[bool] = True,
        enable_qkv_fusion: bool = True,
        qk_norm: bool = False,
        chameleon_mp_size: int = 1,
    ) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        self.causal = causal
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        if head_dim:
            self.head_dim = head_dim
            q_dim = head_dim * num_heads
        else:
            self.head_dim = self.embed_dim // num_heads
            q_dim = embed_dim
        self.num_kv_heads = num_kv_heads
        self.q_per_kv = num_heads // num_kv_heads
        self.kv_dim = self.head_dim * num_kv_heads
        self.enable_qkv_fusion = enable_qkv_fusion
        self.use_dynamic_ntk_rope = use_dynamic_ntk_rope
        self.rotary_emb_dim = rotary_emb_dim
        self.max_position_embeddings = max_position_embeddings
        self.interleaved = qk_interleaved

        factory_kwargs = {"device": device, "dtype": dtype}

        assert self.use_dynamic_ntk_rope is False, "Not support dynamic ntk rope yet."
        assert self.embed_dim % num_heads == 0, "embedding dim must be divisible by num_heads"

        if self.rotary_emb_dim > 0:
            self.rotary_emb = new_rotary_embedding(
                self.rotary_emb_dim,
                base=rope_base,
                scale_base=rotary_emb_scale_base,
                device=device,
                max_position_embeddings=max_position_embeddings,
                scaling_factor=1.0,
                rotary_type="dynamic_ntk" if self.use_dynamic_ntk_rope else "native",
            )

        self.qk_norm = qk_norm
        if qk_norm:
            assert enable_qkv_fusion is False, "qk_norm cannot be applied when fused wqkv"

        if enable_qkv_fusion:
            assert bias is False, "Fuesd wqkv only support bias is False."
            self.wqkv = new_linear("wqkv", embed_dim, q_dim + 2 * self.kv_dim, bias, **factory_kwargs)
            # NOTE(review): these hooks fuse/split wq/wk/wv by plain concatenation along dim 0. The forward pass
            # instead reads the fused output grouped per kv head ("(h gs d)" with gs = q_per_kv + 2). The two
            # layouts only coincide with a single (local) kv head; confirm before mixing fused/unfused checkpoints.
            self._register_load_state_dict_pre_hook(
                partial(_qkv_pre_load_convert, q_dim=q_dim, kv_dim=self.kv_dim), with_module=True
            )
            self._register_state_dict_hook(partial(_qkv_save_convert, q_dim=q_dim, kv_dim=self.kv_dim))
        else:
            self.wq = new_linear("wq", embed_dim, q_dim, bias, **factory_kwargs)
            self.wk = new_linear("wk", embed_dim, self.kv_dim, bias, **factory_kwargs)
            self.wv = new_linear("wv", embed_dim, self.kv_dim, bias, **factory_kwargs)

        if qk_norm:
            assert num_heads % chameleon_mp_size == 0, "num_heads%chameleon_mp_size != 0 in GQA"
            assert num_kv_heads % chameleon_mp_size == 0, "num_kv_heads%chameleon_mp_size != 0 in GQA"
            self.q_norm = ChameleonLayerNorm(self.head_dim, chameleon_mp_size, num_heads // chameleon_mp_size)
            self.k_norm = ChameleonLayerNorm(self.head_dim, chameleon_mp_size, num_kv_heads // chameleon_mp_size)

        self.inner_attn = SelfAttention(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout, layer_idx=layer_idx
        )
        self.inner_cross_attn = CrossAttention(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout, layer_idx=layer_idx
        )

        self.wo = new_linear("wo", q_dim, embed_dim, bias, **factory_kwargs)

    def register_checkpoint_compatibility_hooks(
        self, pre_load_hook: Optional[Callable] = None, pre_save_hook: Optional[Callable] = None
    ):
        """Register optional hooks that adapt checkpoints when loading/saving this module.

        Args:
            pre_load_hook: registered as a load-state-dict pre hook, called with the module as first argument.
            pre_save_hook: registered as a state-dict hook, run when ``state_dict()`` is produced.

        Hooks left as ``None`` are skipped. Registering ``None`` would make the next
        ``load_state_dict``/``state_dict`` call fail when it tries to call the hook.
        """
        # Here we explicitly expose the checkpoint compatibility interface of the module,
        # hoping that model developers will make good use of it when adapting.
        if pre_load_hook is not None:
            self._register_load_state_dict_pre_hook(pre_load_hook, with_module=True)
        if pre_save_hook is not None:
            self._register_state_dict_hook(pre_save_hook)

    def forward(self, x, inference_params=None, **kwargs):
        if inference_params is None:
            return self._training(x=x, **kwargs)
        else:
            return self._inference(x=x, inference_params=inference_params, **kwargs)

    def _training(self, x, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim)
        """
        # wqkv
        if self.enable_qkv_fusion:
            qkv = self.wqkv(x)
            # Fused output is grouped per kv head as [q_per_kv query heads, 1 key head, 1 value head].
            qkv = rearrange(qkv, "b s (h gs d) -> b s h gs d", gs=self.q_per_kv + 2, d=self.head_dim)
            q, k, v = (qkv[..., : self.q_per_kv, :], qkv[..., -2, :], qkv[..., -1, :])
            q = rearrange(q, "b s h gs d -> b s (h gs) d")
        else:
            q, k, v = self.wq(x), self.wk(x), self.wv(x)
            if self.qk_norm:
                q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim)
                k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim)
                # Gather heads across the tensor-parallel group so the norm sees every head, then re-split.
                q_all = gather_forward_split_backward(q, ParallelMode.TENSOR, dim=-2)
                q_norm_out = self.q_norm(q_all)
                q = split_forward_gather_backward(q_norm_out, ParallelMode.TENSOR, dim=-2)
                k_all = gather_forward_split_backward(k, ParallelMode.TENSOR, dim=-2)
                k_norm_out = self.k_norm(k_all)
                k = split_forward_gather_backward(k_norm_out, ParallelMode.TENSOR, dim=-2)
                v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim)
            else:
                q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim)
                k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim)
                v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim)

        kwargs = _convert_cu_seqlens_for_qksplited(kwargs)

        # rotary embedding
        if self.rotary_emb_dim > 0:
            indexes = kwargs.pop("indexes", 0)
            max_seqlen_q = kwargs.get("max_seqlen_q", None)
            max_seqlen_k = kwargs.get("max_seqlen_k", None)

            q = self.rotary_emb(
                q, offsets=indexes, max_seqlen=max_seqlen_q, cache_type="query", interleaved=self.interleaved
            )
            k = self.rotary_emb(
                k, offsets=indexes, max_seqlen=max_seqlen_k, cache_type="key", interleaved=self.interleaved
            )

        kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2)

        if gpc.config.data.use_packed_dataset is False or self.training is False:
            kwargs.pop("max_seqlen_q", None)
            kwargs.pop("max_seqlen_k", None)

        # self attention
        context = self.inner_attn(q, kv, **kwargs)

        # wo
        return self.wo(rearrange(context, "b s h d -> b s (h d)"))

    def _convert_unpacked_qkv_to_packed(
        self, q: torch.Tensor, kv: torch.Tensor, batch_size: int, attention_mask: torch.Tensor
    ):
        """Pack the positions selected by ``attention_mask`` into a single varlen batch.

        Returns ``(q_packed, kv_packed, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)``.
        The cumulative sequence lengths are int32 and start at 0. The max lengths are the padded
        length ``attention_mask.shape[-1]``.
        """
        cu_seqlens = torch.concat(
            [
                torch.tensor([0], dtype=torch.int32, device=attention_mask.device),
                attention_mask.sum(dim=-1).to(dtype=torch.int32),
            ],
            dim=0,
        ).cumsum(dim=0, dtype=torch.int32)

        cu_seqlens_q = cu_seqlens
        cu_seqlens_k = cu_seqlens

        max_seqlen_q = attention_mask.shape[-1]
        max_seqlen_k = attention_mask.shape[-1]

        q_packed = (
            q.masked_select(attention_mask.view(batch_size, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]).unsqueeze(0)
        )
        kv_packed = (
            kv.masked_select(attention_mask.view(batch_size, -1, 1, 1, 1))
            .view(-1, kv.shape[-3], kv.shape[-2], kv.shape[-1])
            .unsqueeze(0)
        )

        return q_packed, kv_packed, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k

    def _inference(self, x, inference_params, **kwargs):  # pylint: disable=W0613
        assert inference_params is not None, "inference_params is required for inference"
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        attention_mask = inference_params.attention_mask
        sequence_len_offset = inference_params.sequence_len_offset
        window_size = inference_params.window_size

        batch_size = x.shape[0]

        # wqkv, output: q, k, v
        if self.enable_qkv_fusion:
            qkv = self.wqkv(x)
            qkv = rearrange(qkv, "b s (h gs d) -> b s h gs d", gs=self.q_per_kv + 2, d=self.head_dim)
            q, k, v = (qkv[..., : self.q_per_kv, :], qkv[..., -2, :], qkv[..., -1, :])
            q = rearrange(q, "b s h gs d -> b s (h gs) d")
        else:
            q, k, v = self.wq(x), self.wk(x), self.wv(x)
            if self.qk_norm:
                q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim)
                k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim)
                q_all = gather_forward_split_backward(q, ParallelMode.TENSOR, dim=-2)
                q_norm_out = self.q_norm(q_all)
                q = split_forward_gather_backward(q_norm_out, ParallelMode.TENSOR, dim=-2)
                k_all = gather_forward_split_backward(k, ParallelMode.TENSOR, dim=-2)
                k_norm_out = self.k_norm(k_all)
                k = split_forward_gather_backward(k_norm_out, ParallelMode.TENSOR, dim=-2)
                v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim)
            else:
                q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim)
                k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim)
                v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim)

        # rotary embedding, output: q, kv
        assert self.rotary_emb_dim > 0
        if attention_mask is None:
            raise NotImplementedError(
                "You should make sure you are aware that you are changing the method of generating."
                "According to your generation function instead of inference/seq_generator_module.py, "
                "You may implement here for normal running."
            )
        else:
            if inference_params.sequence_len_offset == 0:
                q = self.rotary_emb(
                    q, offsets=0, cache_type="query", interleaved=self.interleaved, left_padding_mask=attention_mask
                )
                k = self.rotary_emb(
                    k, offsets=0, cache_type="key", interleaved=self.interleaved, left_padding_mask=attention_mask
                )
            else:
                # Per-sample position = offset minus that sample's count of masked (padding) positions.
                empties = attention_mask[..., -1].sum(dim=-1)
                indexes4q = sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device) - empties
                indexes4k = sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device) - empties

                # TODO To fit flash_attn apis, we rearrange q&k to pack them here and
                # calculate rope for this batch input. Waiting to be optimized
                q = rearrange(q, "b s h d -> s b h d", d=self.head_dim)  # pack input
                k = rearrange(k, "b s h d -> s b h d", d=self.head_dim)
                q = self.rotary_emb(q, offsets=indexes4q, cache_type="query", interleaved=self.interleaved)
                k = self.rotary_emb(k, offsets=indexes4k, cache_type="key", interleaved=self.interleaved)
                q = rearrange(q, "s b h d -> b s h d", d=self.head_dim)  # unpack
                k = rearrange(k, "s b h d -> b s h d", d=self.head_dim)

        kv = torch.stack([k, v], dim=2)

        if window_size is None or window_size > sequence_len_offset:
            kv = update_kv_cache(kv, inference_params, self.layer_idx)
        else:  # window_size <= sequence_len_offset
            assert kv.size(1) == 1, "update kv length more than 1"

            # Sliding window: keep the first `keep_first` entries, shift the most recent entries down,
            # and temporarily point sequence_len_offset at the last window slot for the cache write.
            inference_params.key_value_memory_dict[self.layer_idx][
                :, inference_params.keep_first : inference_params.window_size - 1, ...
            ] = inference_params.key_value_memory_dict[self.layer_idx][
                :, -(inference_params.window_size - 1 - inference_params.keep_first) :, ...
            ].clone()
            inference_params.real_sequence_len_offset = inference_params.sequence_len_offset
            inference_params.sequence_len_offset = inference_params.window_size - 1

            kv = update_kv_cache(kv, inference_params, self.layer_idx)

            inference_params.sequence_len_offset = inference_params.real_sequence_len_offset

        # When using FP16, there is a high probability of NAN in the KV.
        # Since NAN cannot be removed by multiplying with and 0, it needs
        # to be removed manually here.
        kv = torch.where(torch.isnan(kv), 0, kv)

        # attention
        if attention_mask is None:
            context = self.inner_cross_attn(q, kv)
        else:
            if sequence_len_offset == 0:  # First entrance, attnmask (bs*seqlen*seqlen)
                attn_mask = attention_mask[:, None, ...]
                attn_mask = torch.logical_or(torch.ones_like(attn_mask, dtype=torch.bool).triu(diagonal=1), attn_mask)
                attn_mask4flsh = ~attn_mask[:, :, -1, :].view(batch_size, -1)

                output = self.inner_attn(*self._convert_unpacked_qkv_to_packed(q, kv, batch_size, attn_mask4flsh))
                output = output.to(x.dtype)

                # Scatter the packed output back into the padded (b, s, h, d) layout.
                context = torch.zeros_like(q).masked_scatter_(attn_mask4flsh.view(batch_size, -1, 1, 1), output)
            else:
                attn_mask = attention_mask[:, -1, :].view(batch_size, 1, 1, -1)
                if window_size is not None and window_size <= sequence_len_offset:
                    attn_mask = torch.concat(
                        [
                            attn_mask[..., : inference_params.keep_first],
                            attn_mask[..., -(window_size - inference_params.keep_first) :],
                        ],
                        dim=-1,
                    )

                k, v = torch.chunk(kv, 2, dim=2)
                k = k.squeeze(2)
                v = v.squeeze(2)
                sp = k.shape
                # Repeat each kv head `expansion` times so it lines up with its group of query heads.
                expansion = q.size(2) // k.size(2)
                scores = torch.einsum(
                    "blhd,bnhd->bhln",
                    q,
                    k.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]),
                ) / math.sqrt(q.size(-1))
                scores = scores.masked_fill(attn_mask, -65000.0)
                scores = F.softmax(scores, dim=-1)  # bsz x h x L x L
                context = torch.einsum(
                    "bhmn,bnhd->bmhd",
                    scores,
                    v.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]),
                )

        # wo
        return self.wo(rearrange(context, "b s h d -> b s (h d)"))
# Probe whether the installed flash_attn exposes ``window_size`` on ``flash_attn_func``
# (flash_attn >= v2.3.0). Inspecting the signature avoids parsing version strings.
try:
    from flash_attn import flash_attn_func
except (ModuleNotFoundError, ImportError):
    _flash_supports_window_size = False
else:
    _flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters
class SWA(nn.Module):
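"""
Sliding window attention (SWA) with a separately configurable number of key/value heads.
Args:
embed_dim (int): The dimension of the hidden state (input of wq/wk/wv and output of wo).
num_heads (int): The number of query attention heads. `embed_dim` must be divisible by it.
num_kv_heads (int): The number of key/value heads; wk and wv project to `head_dim * num_kv_heads`.
qkv_bias (boolean): Whether the Q/K/V projections use a bias. True by default.
o_bias (boolean): Whether the output projection uses a bias. False by default.
max_position_embeddings (int): Passed to the rotary embedding. 2048 by default.
dropout (float): The dropout rate for cross attention and self attention. 0.0 by default.
softmax_scale (float): The temperature to use for the softmax attention. None by default.
causal (boolean): Whether to apply causal attention mask. False by default.
layer_idx (int): The index of current layer; required for generation. None by default.
use_dynamic_ntk_rope (bool): Dynamic NTK rope is not supported yet and must be False.
rope_type (str): Stored on the module as `rope_type`. "normal" by default.
rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
rope_scaling_factor (float): Scaling factor passed to the rotary embedding. 1.0 by default.
rotary_emb_dim (int): The dimension of Rotary Embedding; rotary is only built when > 0. 0 by default.
rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements
XPos(Sun et al., https://arxiv.org/abs/2212.10554). 0 by default.
device (Optional[Union[str, torch.device]]): The device will be used.
dtype (Optional[torch.dtype]): The type of data.
use_sliding_window (bool): Whether to use sliding-window attention. False by default.
sliding_window (int): The window size; must be set when `use_sliding_window` is True. None by default.
tp_mode (str): The string value of tensor parallel mode, should be in ["mtp", "msp", "fsp", "isp"],
"mtp" by default.
qk_interleaved (Optional[bool]): Passed as `interleaved` to the rotary embedding. True by default.
use_logn_attn (bool): Stored on the module as `use_logn_attn` (Qwen1). False by default.
"""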
def __init__(
    self,
    embed_dim: int,
    num_heads: int,
    num_kv_heads: int,
    qkv_bias: bool = True,
    o_bias: bool = False,
    max_position_embeddings: int = 2048,
    dropout: float = 0.0,
    softmax_scale: float = None,
    causal: bool = False,
    layer_idx: int = None,
    use_dynamic_ntk_rope: bool = False,
    rope_type: str = "normal",
    rope_base: int = 10000,
    rope_scaling_factor: float = 1.0,
    rotary_emb_dim: int = 0,
    rotary_emb_scale_base: int = 0,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    use_sliding_window: bool = False,
    sliding_window: int = None,
    tp_mode: str = "mtp",
    qk_interleaved: Optional[bool] = True,
    use_logn_attn: bool = False,  # Qwen1
) -> None:
    """Build the Q/K/V/O projections, the optional rotary embedding and the attention kernels.

    Args:
        embed_dim: hidden size; input size of wq/wk/wv and output size of wo.
        num_heads: number of query heads (``head_dim = embed_dim // num_heads``).
        num_kv_heads: number of key/value heads; wk/wv project to ``head_dim * num_kv_heads``.
        qkv_bias / o_bias: whether the Q/K/V and output projections carry a bias.
        rotary_emb_dim: rotary embedding is only created when this is > 0.
        use_sliding_window / sliding_window: enable sliding-window attention with the given size.
        Remaining arguments are stored on the module or forwarded to the rotary embedding
        and attention kernels as shown below.

    Raises:
        AssertionError: if ``embed_dim`` is not divisible by ``num_heads``, if
            ``use_sliding_window`` is set without ``sliding_window``, or if
            ``use_dynamic_ntk_rope`` is requested (not supported yet).
    """
    # Validate up front, before any submodule is constructed.
    assert embed_dim % num_heads == 0, "embedding dim must be divisible by num_heads"
    assert (not use_sliding_window) or (
        sliding_window is not None
    ), "Must set `sliding windows` size when `use_sliding_window` is True."
    assert use_dynamic_ntk_rope is False, "Not support dynamic ntk rope yet."

    factory_kwargs = {"device": device, "dtype": dtype}
    super().__init__()

    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.head_dim = self.embed_dim // num_heads
    self.num_kv_heads = num_kv_heads
    self.kv_dim = self.head_dim * num_kv_heads
    self.causal = causal
    self.layer_idx = layer_idx
    self.use_dynamic_ntk_rope = use_dynamic_ntk_rope
    self.rotary_emb_dim = rotary_emb_dim
    self.use_sliding_window = use_sliding_window
    self.sliding_window = sliding_window
    self.dtype = dtype
    self.tp_mode = tp_mode
    self.rope_type = rope_type
    self.use_logn_attn = use_logn_attn
    self.interleaved = qk_interleaved

    if self.rotary_emb_dim > 0:
        self.rotary_emb = new_rotary_embedding(
            self.rotary_emb_dim,
            base=rope_base,
            scale_base=rotary_emb_scale_base,
            device=device,
            max_position_embeddings=max_position_embeddings,
            scaling_factor=rope_scaling_factor,
            # Always "native" today (dynamic NTK is rejected above); kept so enabling
            # dynamic NTK later only requires dropping the assertion.
            rotary_type="dynamic_ntk" if self.use_dynamic_ntk_rope else "native",
        )

    # Q/K/V projections share `qkv_bias`; K and V project to kv_dim (num_kv_heads heads).
    self.wq = new_linear(
        "wq",
        embed_dim,
        embed_dim,
        qkv_bias,
        **factory_kwargs,
    )
    self.wk = new_linear(
        "wk",
        embed_dim,
        self.kv_dim,
        qkv_bias,
        **factory_kwargs,
    )
    self.wv = new_linear(
        "wv",
        embed_dim,
        self.kv_dim,
        qkv_bias,
        **factory_kwargs,
    )

    self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
    self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
    self.inner_cross_attn_causal = causal
    self.inner_cross_attn_softmax_scale = softmax_scale
    self.inner_cross_attn_dropout = dropout

    self.wo = new_linear(
        "wo",
        embed_dim,
        embed_dim,
        o_bias,
        **factory_kwargs,
    )
def forward(self, x, inference_params=None, **kwargs):
    """Route the call to the generation path or the training path.

    Args:
        x: input hidden states.
        inference_params: generation state; when ``None`` the training path is used.
        **kwargs: forwarded unchanged to whichever path is selected.

    Returns:
        Whatever ``self._inference`` or ``self._training`` returns.
    """
    if inference_params is not None:
        return self._inference(x=x, inference_params=inference_params, **kwargs)
    return self._training(x=x, **kwargs)
def _training(self, x, **kwargs):
    """Training-time attention (no KV cache).

    Args:
        x: hidden states. The projections are split with ``"b t (h d) -> b t h d"``,
            so ``x`` is expected to be ``(batch, seqlen, hidden)``.
        **kwargs: normalized by ``_convert_cu_seqlens_for_qksplited`` and then passed to
            ``self.inner_attn``. When rotary embedding is enabled, ``indexes`` is
            consumed here as the rotary position offsets.

    Returns:
        ``self.wo`` applied to the attention context with heads merged back.
    """
    q, k, v = self.wq(x), self.wk(x), self.wv(x)
    q = rearrange(q, "b t (h d) -> b t h d", d=self.head_dim)
    k = rearrange(k, "b t (h d) -> b t h d", d=self.head_dim)
    v = rearrange(v, "b t (h d) -> b t h d", d=self.head_dim)

    # k is laid out as (batch, seqlen, heads, head_dim): the sequence length is dim 1.
    # Reading dim 0 (the batch size) made the sliding-window branch effectively dead.
    kv_seq_len = k.size(1)
    use_window_circumstance = (
        _flash_supports_window_size
        and self.use_sliding_window
        and self.sliding_window
        and kv_seq_len > self.sliding_window
    )

    kwargs = _convert_cu_seqlens_for_qksplited(kwargs)

    # rotary embedding
    if self.rotary_emb_dim > 0:
        indexes = kwargs.pop("indexes", 0)
        max_seqlen_q = kwargs.get("max_seqlen_q", None)
        max_seqlen_k = kwargs.get("max_seqlen_k", None)
        q = self.rotary_emb(
            q, offsets=indexes, max_seqlen=max_seqlen_q, cache_type="query", interleaved=self.interleaved
        )
        k = self.rotary_emb(
            k, offsets=indexes, max_seqlen=max_seqlen_k, cache_type="key", interleaved=self.interleaved
        )

    # Stack K and V along a new dim 2: (batch, seqlen, 2, kv_heads, head_dim).
    kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2)

    if use_window_circumstance:
        # flash_attn (left, right) window; the right side is 0 because only past tokens are attended.
        kwargs["window_size"] = (self.sliding_window, 0)

    # self attention
    context = self.inner_attn(q, kv, **kwargs)

    # wo
    return self.wo(rearrange(context, "b s h d -> b s (h d)"))
def _convert_unpacked_qkv_to_packed(
    self, q: torch.Tensor, kv: torch.Tensor, batch_size: int, attention_mask: torch.Tensor
):
    """Drop masked-out tokens and pack the batch into the variable-length layout.

    Args:
        q: query tensor with tokens along dim 1 and ``(heads, head_dim)`` as its last two
            dims (the mask is broadcast as ``(batch_size, -1, 1, 1)``).
        kv: stacked key/value tensor with tokens along dim 1 (mask broadcast as
            ``(batch_size, -1, 1, 1, 1)``).
        batch_size: number of rows in ``attention_mask``.
        attention_mask: boolean ``(batch_size, seqlen)`` mask; True marks tokens to keep.

    Returns:
        ``(q_packed, kv_packed, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)``.
        The packed tensors get a leading dim of size 1. ``cu_seqlens`` is an int32 prefix sum
        of per-row kept-token counts starting at 0 (shared by q and k). ``max_seqlen`` is the
        padded length ``attention_mask.shape[-1]``.
    """
    kept_per_row = attention_mask.sum(dim=-1).to(dtype=torch.int32)
    leading_zero = torch.zeros(1, dtype=torch.int32, device=attention_mask.device)
    cu_seqlens = torch.cat([leading_zero, kept_per_row], dim=0).cumsum(dim=0, dtype=torch.int32)
    padded_len = attention_mask.shape[-1]

    token_mask = attention_mask.view(batch_size, -1, 1, 1)

    num_heads, head_dim = q.shape[-2], q.shape[-1]
    q_packed = q.masked_select(token_mask).view(-1, num_heads, head_dim).unsqueeze(0)

    kv_trailing = kv.shape[-3:]
    kv_packed = kv.masked_select(token_mask.unsqueeze(-1)).view(-1, *kv_trailing).unsqueeze(0)

    return q_packed, kv_packed, cu_seqlens, cu_seqlens, padded_len, padded_len
def _inference(self, x, inference_params=None, **kwargs): # pylint: disable=W0613
assert inference_params is not None, "inference_params is required for inference"
assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
attention_mask = inference_params.attention_mask
sequence_len_offset = inference_params.sequence_len_offset
window_size = inference_params.window_size
bsz = x.shape[0]
q, k, v = self.wq(x), self.wk(x), self.wv(x)
q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim)
k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim)
v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim)
kv_seq_len = k.size(0)
use_window_circumstance = (
_flash_supports_window_size
and self.use_sliding_window
and self.sliding_window
and kv_seq_len > self.sliding_window
)
assert self.rotary_emb_dim > 0
if attention_mask is None:
raise NotImplementedError(
"You should make sure you are aware that you are changing the method of generating."
"According to your generation function instead of inference/seq_generator_module.py, "
"You may implement here for normal running."
)
else:
if inference_params.sequence_len_offset == 0: