Commit 0436cdd

use flash_attn_with_kvcache for faster inference (#2539)
* use flash_attn_with_kvcache
* patch rmsnorm for multiexperts
* rope theta as an option
1 parent 05cde4d commit 0436cdd

10 files changed: +183 −39 lines

onmt/decoders/transformer.py

+15 −5
@@ -11,11 +11,7 @@
 from onmt.modules.position_ffn import ActivationFunction
 from onmt.modules.moe import MoE
 from onmt.utils.misc import sequence_mask
-
-try:
-    from apex.normalization import FusedRMSNorm as RMSNorm
-except ImportError:
-    from onmt.modules.rmsnorm import RMSNorm
+from onmt.modules.rmsnorm import RMSNorm


 class TransformerDecoderLayerBase(nn.Module):
@@ -44,6 +40,7 @@ def __init__(
         parallel_gpu=1,
         sliding_window=0,
         rotary_interleave=True,
+        rotary_theta=1e4,
         num_experts=0,
         num_experts_per_tok=2,
     ):
@@ -89,6 +86,7 @@ def __init__(
             sliding_window (int): Width of the band mask and KV cache (cf Mistral Model)
             rotary_interleave (bool): Interleave the head dimensions when rotary
                 embeddings are applied
+            rotary_theta (int): rotary base theta
         """
         super(TransformerDecoderLayerBase, self).__init__()

@@ -100,6 +98,7 @@ def __init__(
             max_relative_positions=max_relative_positions,
             relative_positions_buckets=relative_positions_buckets,
             rotary_interleave=rotary_interleave,
+            rotary_theta=rotary_theta,
             attn_type="self",
             self_attn_type=self_attn_type,
             add_qkvbias=add_qkvbias,
@@ -280,6 +279,7 @@ def __init__(
         parallel_gpu=1,
         sliding_window=0,
         rotary_interleave=True,
+        rotary_theta=1e4,
         num_experts=0,
         num_experts_per_tok=2,
     ):
@@ -311,6 +311,7 @@ def __init__(
             parallel_gpu=parallel_gpu,
             sliding_window=sliding_window,
             rotary_interleave=rotary_interleave,
+            rotary_theta=rotary_theta,
             num_experts=num_experts,
             num_experts_per_tok=num_experts_per_tok,
         )
@@ -473,6 +474,7 @@ def from_opt(cls, opt, embeddings):
             else 1,
             sliding_window=opt.sliding_window,
             rotary_interleave=opt.rotary_interleave,
+            rotary_theta=opt.rotary_theta,
             num_experts=opt.num_experts,
             num_experts_per_tok=opt.num_experts_per_tok,
         )
@@ -563,6 +565,7 @@ class TransformerDecoder(TransformerDecoderBase):
         parallel_gpu (int): Number of gpu for tensor parallelism
         sliding_window (int): Width of the band mask and KV cache (cf Mistral Model)
         rotary_interleave (bool): Interleave the head dimensions when rotary embeddings are applied
+        rotary_theta (int): rotary base theta
     """

     def __init__(
@@ -594,6 +597,7 @@ def __init__(
         parallel_gpu=1,
         sliding_window=0,
         rotary_interleave=True,
+        rotary_theta=1e4,
         num_experts=0,
         num_experts_per_tok=2,
     ):
@@ -627,6 +631,7 @@ def __init__(
                 parallel_gpu=parallel_gpu,
                 sliding_window=sliding_window,
                 rotary_interleave=rotary_interleave,
+                rotary_theta=rotary_theta,
                 num_experts=num_experts,
                 num_experts_per_tok=num_experts_per_tok,
             )
@@ -834,6 +839,7 @@ class TransformerLMDecoder(TransformerDecoderBase):
         parallel_gpu (int): Number of gpu for tensor parallelism
         sliding_window (int): Width of the band mask and KV cache (cf Mistral Model)
         rotary_interleave (bool): Interleave the head dimensions when rotary embeddings are applied
+        rotary_theta (int): rotary base theta
     """

     def __init__(
@@ -865,6 +871,7 @@ def __init__(
         parallel_gpu=1,
         sliding_window=0,
         rotary_interleave=True,
+        rotary_theta=1e4,
         num_experts=0,
         num_experts_per_tok=2,
     ):
@@ -897,6 +904,7 @@ def __init__(
             parallel_gpu=parallel_gpu,
             sliding_window=sliding_window,
             rotary_interleave=rotary_interleave,
+            rotary_theta=rotary_theta,
             num_experts=num_experts,
             num_experts_per_tok=num_experts_per_tok,
         )
@@ -976,3 +984,5 @@ def _init_cache(self, tgt=None):
             )
             if hasattr(layer.self_attn, "rope"):
                 layer.self_attn.rope = layer.self_attn.rope.to(tgt.device)
+                layer.self_attn.cos = layer.self_attn.cos.to(tgt.device)
+                layer.self_attn.sin = layer.self_attn.sin.to(tgt.device)
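
The last hunk above keeps the new cos/sin tables on the same device as the rope buffer when the KV cache is initialized: the fused decode path in multi_headed_attn.py (below) consumes the rotary embedding as two real-valued fp16 tensors rather than the complex rope table. A minimal sketch of how such tables are derived, mirroring the real/imag split used in the diff; the rotaryembeddings helper here is a simplified stand-in written for illustration, not the repo's implementation:

import torch

def rotaryembeddings(dim: int, maxseqlen: int = 2048, base: float = 1e4) -> torch.Tensor:
    # Illustrative stand-in (not the repo's helper): complex table exp(i * t * base**(-2k/dim)).
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(maxseqlen).float()
    freqs = torch.outer(t, inv_freq)                   # (maxseqlen, dim/2)
    rope = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return torch.cat((rope, rope), dim=1)              # (maxseqlen, dim)

rope = rotaryembeddings(dim=64, base=1e4)
# Same split as the diff: first half of the table, real/imag parts, contiguous fp16.
cos = rope[:, : rope.size(1) // 2].real.contiguous().half()
sin = rope[:, : rope.size(1) // 2].imag.contiguous().half()
print(cos.shape, sin.shape)  # torch.Size([2048, 32]) torch.Size([2048, 32])

Per the flash-attn docstring, this (seqlen, rotary_dim / 2) layout is what flash_attn_with_kvcache expects for rotary_cos and rotary_sin, which is why the tables are cached on the attention module and moved together with rope in _init_cache.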

onmt/encoders/transformer.py

+6
@@ -40,6 +40,7 @@ class TransformerEncoderLayer(nn.Module):
         parallel_gpu (int): Number of gpu for tensor parallelism
         rotary_interleave (bool): Interleave the head dimensions when rotary
             embeddings are applied
+        rotary_theta (int): rotary base theta
     """

     def __init__(
@@ -61,6 +62,7 @@ def __init__(
         use_ckpting=[],
         parallel_gpu=1,
         rotary_interleave=True,
+        rotary_theta=1e4,
     ):
         super(TransformerEncoderLayer, self).__init__()

@@ -72,6 +74,7 @@ def __init__(
             max_relative_positions=max_relative_positions,
             relative_positions_buckets=relative_positions_buckets,
             rotary_interleave=rotary_interleave,
+            rotary_theta=rotary_theta,
             attn_type="self",
             add_qkvbias=add_qkvbias,
             num_kv=num_kv,
@@ -177,6 +180,7 @@ def __init__(
         use_ckpting=[],
         parallel_gpu=1,
         rotary_interleave=True,
+        rotary_theta=1e4,
     ):
         super(TransformerEncoder, self).__init__()

@@ -201,6 +205,7 @@ def __init__(
                 use_ckpting=use_ckpting,
                 parallel_gpu=parallel_gpu,
                 rotary_interleave=rotary_interleave,
+                rotary_theta=rotary_theta,
             )
             for i in range(num_layers)
         ]
@@ -239,6 +244,7 @@ def from_opt(cls, opt, embeddings):
             if opt.parallel_mode == "tensor_parallel"
             else 1,
             rotary_interleave=opt.rotary_interleave,
+            rotary_theta=opt.rotary_theta,
         )

     def forward(self, src, src_len=None):
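
rotary_theta is the base of the rotary position embeddings (1e4 in the original RoPE paper); exposing it as an option lets converted checkpoints that were trained with a larger base keep their intended positional scaling. A small illustrative sketch of how the base changes the per-dimension rotation frequencies; rotary_inv_freq is a hypothetical helper, not part of the repo, and the numbers are only illustrative:

import math
import torch

def rotary_inv_freq(dim: int, base: float) -> torch.Tensor:
    # Hypothetical helper: RoPE frequency of channel pair k is base ** (-2k / dim).
    return 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

dim = 64
for base in (1e4, 1e6):
    inv_freq = rotary_inv_freq(dim, base)
    # Wavelength (in positions) of the slowest-rotating pair: 2 * pi / smallest frequency.
    longest = 2 * math.pi / inv_freq[-1].item()
    print(f"base={base:g}  longest wavelength ~ {longest:.0f} positions")

A larger base stretches every wavelength, so positional phases repeat much later; nothing else in the encoder changes.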

onmt/modules/moe.py

+4
@@ -2,6 +2,7 @@
 import torch
 import torch.nn as nn
 from onmt.modules.position_ffn import PositionwiseFeedForward
+from torch.distributed import all_reduce


 class MoE(nn.Module):
@@ -40,12 +41,15 @@ def __init__(
         )
         self.gate = nn.Linear(d_model, num_experts, bias=False)
         self.num_experts_per_tok = num_experts_per_tok
+        self.parallel_gpu = parallel_gpu

     def forward(self, x):
         orig_shape = x.shape
         x = x.view(-1, x.shape[-1])

         scores = self.gate(x)
+        if self.parallel_gpu > 1:
+            all_reduce(scores)
         expert_weights, expert_indices = torch.topk(
             scores, self.num_experts_per_tok, dim=-1
         )
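
With tensor parallelism the gate scores are now summed across ranks before the top-k routing, presumably because each rank holds only a shard of the weights and therefore partial logits; the reduce guarantees every rank picks the same experts for each token. A hedged sketch of that routing step (route_tokens is an illustrative helper, not part of the repo); it runs on a single process and only calls all_reduce when torch.distributed has been initialized:

import torch
import torch.nn as nn
import torch.distributed as dist

def route_tokens(x, gate: nn.Linear, num_experts_per_tok: int, parallel_gpu: int = 1):
    # Illustrative helper, not the repo's MoE.forward.
    scores = gate(x)                        # (tokens, num_experts)
    if parallel_gpu > 1 and dist.is_initialized():
        dist.all_reduce(scores)             # same routing decision on every rank
    expert_weights, expert_indices = torch.topk(scores, num_experts_per_tok, dim=-1)
    # Softmax normalization is the usual Mixtral-style step; it is not shown in this diff.
    expert_weights = expert_weights.softmax(dim=-1)
    return expert_weights, expert_indices

gate = nn.Linear(16, 8, bias=False)         # toy sizes: d_model=16, 8 experts
weights, indices = route_tokens(torch.randn(4, 16), gate, num_experts_per_tok=2)
print(indices.shape, weights.shape)         # torch.Size([4, 2]) torch.Size([4, 2])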

onmt/modules/multi_headed_attn.py

+114 −22
@@ -11,7 +11,6 @@
 from torch.distributed import all_reduce
 from importlib import import_module

-
 # Help functions for Rotary Embeddings
 # https://arxiv.org/pdf/2104.09864.pdf
 # too convoluted to make maxseqlen a parameter.
@@ -258,6 +257,7 @@ def __init__(
         max_relative_positions: int = 0,
         relative_positions_buckets: int = 0,
         rotary_interleave: bool = True,
+        rotary_theta: int = 1e4,
         attn_type: str = None,
         self_attn_type: str = None,
         add_qkvbias=False,
@@ -352,9 +352,19 @@ def __init__(
         self.relative_attention_bias = None

         if max_relative_positions == -1:  # rotary embeddings
-            self.rope = rotaryembeddings(self.dim_per_head)
+            self.rope = rotaryembeddings(self.dim_per_head, base=rotary_theta)
+            self.cos = (
+                self.rope[:, : self.rope.size(1) // 2].real.contiguous().half()
+            )
+            self.sin = (
+                self.rope[:, : self.rope.size(1) // 2].imag.contiguous().half()
+            )
             self.rotary_interleave = rotary_interleave
-
+            self.rotary_theta = rotary_theta
+        else:
+            self.cos = None
+            self.sin = None
+            self.rotary_interleave = None
         if max_relative_positions == -2:  # alibi positional bias
             self.alibi = AlibiPositionalBias(head_count)

@@ -367,6 +377,9 @@ def __init__(
                     and torch.cuda.get_device_capability()[0] >= 8
                 ):
                     self.flash_attn_func = getattr(flash_pack, "flash_attn_func")
+                    self.flash_attn_with_kvcache = getattr(
+                        flash_pack, "flash_attn_with_kvcache"
+                    )
                     self.flash2 = True
             except ImportError:
                 self.flash2 = False
@@ -420,27 +433,104 @@ def forward(
                 key = shape(key, self.dim_per_head)
                 value = shape(value, self.dim_per_head)

-                if self.max_relative_positions == -1:  # Rotary Embeddings
-                    start_pos = step
-                    seqlen = query.size(2)
-                    if seqlen > self.rope.size(0):
-                        self.rope = rotaryembeddings(
-                            self.dim_per_head, maxseqlen=(seqlen + 2048)
-                        ).to(self.rope.device)
-                    rope = self.rope[start_pos : start_pos + seqlen]
-                    query, key = apply_rotary_emb(
-                        query, key, rope, interleave=self.rotary_interleave
-                    )
+                start_pos = step
+                seqlen = query.size(2)
+
+                if (
+                    step == 0
+                    or not self.flash2
+                    or self.max_relative_positions not in [0, -1]
+                    or query.size(0) > 128
+                    or query.dtype != torch.float16
+                ):
+                    if self.max_relative_positions == -1:  # Rotary Embeddings
+                        if seqlen > self.rope.size(0):
+                            self.rope = rotaryembeddings(
+                                self.dim_per_head,
+                                maxseqlen=(seqlen + 2048),
+                                base=self.rotary_theta,
+                            ).to(self.rope.device)
+                        rope = self.rope[start_pos : start_pos + seqlen]
+                        query, key = apply_rotary_emb(
+                            query, key, rope, interleave=self.rotary_interleave
+                        )
+
+                    if self.layer_cache[1]["keys"].numel() != 0:
+                        key = torch.cat((self.layer_cache[1]["keys"], key), dim=2)
+                        value = torch.cat((self.layer_cache[1]["values"], value), dim=2)
+                        if sliding_window > 0 and key.size(2) > sliding_window:
+                            key = key[:, :, 1:, :]
+                            value = value[:, :, 1:, :]
+
+                    self.layer_cache[1]["keys"] = key
+                    self.layer_cache[1]["values"] = value

-                if self.layer_cache[1]["keys"].numel() != 0:
-                    key = torch.cat((self.layer_cache[1]["keys"], key), dim=2)
-                    value = torch.cat((self.layer_cache[1]["values"], value), dim=2)
+                else:
+                    if self.max_relative_positions == -1:  # Rotary Embeddings
+                        if seqlen > self.rope.size(0):
+                            self.rope = rotaryembeddings(
+                                self.dim_per_head,
+                                maxseqlen=(seqlen + 2048),
+                                base=self.rotary_theta,
+                            ).to(self.rope.device)
+                            self.cos = (
+                                self.rope[:, : self.rope.size(1) // 2]
+                                .real.contiguous()
+                                .half()
+                            )
+                            self.sin = (
+                                self.rope[:, : self.rope.size(1) // 2]
+                                .imag.contiguous()
+                                .half()
+                            )
+                    if start_pos >= self.layer_cache[1]["keys"].size(2):
+                        self.layer_cache[1]["keys"] = torch.cat(
+                            [
+                                self.layer_cache[1]["keys"],
+                                torch.zeros(
+                                    self.layer_cache[1]["keys"].shape[:-2]
+                                    + (32,)
+                                    + self.layer_cache[1]["keys"].shape[-1:],
+                                    device=query.device,
+                                ).half(),
+                            ],
+                            dim=-2,
+                        )
+                        self.layer_cache[1]["values"] = torch.cat(
+                            [
+                                self.layer_cache[1]["values"],
+                                torch.zeros(
+                                    self.layer_cache[1]["values"].shape[:-2]
+                                    + (32,)
+                                    + self.layer_cache[1]["values"].shape[-1:],
+                                    device=query.device,
+                                ).half(),
+                            ],
+                            dim=-2,
+                        )
                     if sliding_window > 0 and key.size(2) > sliding_window:
-                        key = key[:, :, 1:, :]
-                        value = value[:, :, 1:, :]
+                        self.layer_cache[1]["keys"] = self.layer_cache[1]["keys"][
+                            :, :, 1:, :
+                        ]
+                        self.layer_cache[1]["values"] = self.layer_cache[1]["values"][
+                            :, :, 1:, :
+                        ]
+                    context = self.flash_attn_with_kvcache(
+                        query.transpose(1, 2),
+                        self.layer_cache[1]["keys"].transpose(1, 2),
+                        self.layer_cache[1]["values"].transpose(1, 2),
+                        key.transpose(1, 2),
+                        value.transpose(1, 2),
+                        rotary_cos=self.cos,
+                        rotary_sin=self.sin,
+                        cache_seqlens=step,
+                        rotary_interleaved=self.rotary_interleave,
+                    ).transpose(1, 2)
+                    attn_output = self.final_linear(unshape(context))
+                    if self.parallel_gpu > 1:
+                        all_reduce(attn_output)
+                    return attn_output, None

-                self.layer_cache[1]["keys"] = key
-                self.layer_cache[1]["values"] = value

         elif self.attn_type == "context":
             query = self.linear_query(query)
@@ -484,7 +574,9 @@ def forward(
             seqlen = query.size(2)
             if seqlen > self.rope.size(0):
                 self.rope = rotaryembeddings(
-                    self.dim_per_head, maxseqlen=(seqlen + 2048)
+                    self.dim_per_head,
+                    maxseqlen=(seqlen + 2048),
+                    base=self.rotary_theta,
                 ).to(self.rope.device)
             rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
             query, key = apply_rotary_emb(
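
The new fast path is taken for incremental decoding only: step > 0, flash-attn 2 available, absolute or rotary positions (max_relative_positions in [0, -1]), batch size at most 128, and fp16 activations. Instead of applying rotary embeddings in Python and torch.cat-ing the cache at every step, the KV cache is preallocated (grown in 32-position blocks) and flash_attn_with_kvcache applies the rotation from the fp16 cos/sin tables, writes the new key/value into the cache at cache_seqlens, and computes attention in one fused kernel. A minimal, hedged sketch of a single decode step (decode_step is an illustrative helper, not repo code); it needs a CUDA device with flash-attn 2.x installed and is skipped otherwise:

import torch

try:
    from flash_attn import flash_attn_with_kvcache
except ImportError:
    flash_attn_with_kvcache = None

def decode_step(q, k_new, v_new, k_cache, v_cache, step, cos, sin, interleaved=True):
    # Illustrative helper. q, k_new, v_new: (batch, 1, heads, dim) fp16 for the current token;
    # k_cache, v_cache: (batch, max_len, heads, dim) preallocated caches, updated in place
    # by the kernel at position `step`.
    return flash_attn_with_kvcache(
        q, k_cache, v_cache, k_new, v_new,
        rotary_cos=cos, rotary_sin=sin,
        cache_seqlens=step,              # how many positions are already in the cache
        rotary_interleaved=interleaved,
    )

if flash_attn_with_kvcache is not None and torch.cuda.is_available():
    b, h, d, max_len = 2, 8, 64, 256
    q = torch.randn(b, 1, h, d, device="cuda", dtype=torch.float16)
    k_new, v_new = torch.randn_like(q), torch.randn_like(q)
    k_cache = torch.zeros(b, max_len, h, d, device="cuda", dtype=torch.float16)
    v_cache = torch.zeros_like(k_cache)
    # Rotary tables of shape (max_len, d // 2), like self.cos / self.sin above.
    inv_freq = 1.0 / (1e4 ** (torch.arange(0, d, 2, device="cuda").float() / d))
    angles = torch.outer(torch.arange(max_len, device="cuda").float(), inv_freq)
    out = decode_step(q, k_new, v_new, k_cache, v_cache, 0,
                      angles.cos().half(), angles.sin().half())
    print(out.shape)  # torch.Size([2, 1, 8, 64])

Note the layout difference: the kernel works on (batch, seqlen, heads, dim) tensors, hence the transpose(1, 2) calls around the cache and the output in the diff.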

onmt/modules/position_ffn.py

+1 −5
@@ -3,11 +3,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint
-
-try:
-    from apex.normalization import FusedRMSNorm as RMSNorm
-except ImportError:
-    from onmt.modules.rmsnorm import RMSNorm
+from onmt.modules.rmsnorm import RMSNorm
 from torch.nn.utils import skip_init
 from torch.distributed import all_reduce

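
Both this file and the decoder now import RMSNorm from onmt.modules.rmsnorm unconditionally instead of preferring apex's FusedRMSNorm, apparently the "patch rmsnorm for multiexperts" item of the commit message. For reference, RMSNorm rescales by the root mean square of the features without centering or bias; a generic sketch of the computation, not claimed to be line-for-line identical to onmt.modules.rmsnorm:

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    # Generic sketch: y = x / sqrt(mean(x ** 2) + eps) * weight  (no mean subtraction, no bias).
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dtype = x.dtype
        x = x.float()
        x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (self.weight * x).to(dtype)

norm = RMSNorm(512)
print(norm(torch.randn(2, 10, 512)).shape)  # torch.Size([2, 10, 512])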
