 from torch.distributed import all_reduce
 from importlib import import_module
 
-
 # Help functions for Rotary Embeddings
 # https://arxiv.org/pdf/2104.09864.pdf
 # too convoluted to make maxseqlen a parameter.
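The hunks below thread a configurable RoPE base (`rotary_theta`, defaulting to 1e4 as in the paper) through `rotaryembeddings()` and precompute cos/sin tables for flash-attn. The helper itself is not part of this diff; purely for orientation, here is a minimal sketch of what such a complex-valued RoPE table can look like, assuming the real `rotaryembeddings()` duplicates the half-dimension table so that `rope.size(1) == dim` (the name `rope_table` is hypothetical):

```python
import torch


def rope_table(dim: int, maxseqlen: int = 2048, base: float = 1e4) -> torch.Tensor:
    """Hypothetical stand-in for rotaryembeddings(): complex RoPE factors."""
    # Per-pair inverse frequencies theta_j = base^(-2j/dim).
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    # Angles m * theta_j for positions m = 0..maxseqlen-1, shape (maxseqlen, dim/2).
    angles = torch.outer(torch.arange(maxseqlen).float(), inv_freq)
    # Complex e^{i * m * theta_j}; duplicating along the channel dim gives a
    # (maxseqlen, dim) table, so slicing [:, : dim // 2] and taking .real / .imag
    # recovers the per-position cos / sin used by the diff below.
    rope = torch.polar(torch.ones_like(angles), angles)
    return torch.cat((rope, rope), dim=1)
```

With a table of this shape, `rope[:, : dim // 2].real` and `.imag` are exactly the cos/sin tensors that the diff hands to `flash_attn_with_kvcache` further down.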
@@ -258,6 +257,7 @@ def __init__(
         max_relative_positions: int = 0,
         relative_positions_buckets: int = 0,
         rotary_interleave: bool = True,
+        rotary_theta: int = 1e4,
         attn_type: str = None,
         self_attn_type: str = None,
         add_qkvbias=False,
@@ -352,9 +352,19 @@ def __init__(
         self.relative_attention_bias = None
 
         if max_relative_positions == -1:  # rotary embeddings
-            self.rope = rotaryembeddings(self.dim_per_head)
+            self.rope = rotaryembeddings(self.dim_per_head, base=rotary_theta)
+            self.cos = (
+                self.rope[:, : self.rope.size(1) // 2].real.contiguous().half()
+            )
+            self.sin = (
+                self.rope[:, : self.rope.size(1) // 2].imag.contiguous().half()
+            )
             self.rotary_interleave = rotary_interleave
-
+            self.rotary_theta = rotary_theta
+        else:
+            self.cos = None
+            self.sin = None
+            self.rotary_interleave = None
         if max_relative_positions == -2:  # alibi positional bias
             self.alibi = AlibiPositionalBias(head_count)
 
@@ -367,6 +377,9 @@ def __init__(
                 and torch.cuda.get_device_capability()[0] >= 8
             ):
                 self.flash_attn_func = getattr(flash_pack, "flash_attn_func")
+                self.flash_attn_with_kvcache = getattr(
+                    flash_pack, "flash_attn_with_kvcache"
+                )
                 self.flash2 = True
         except ImportError:
             self.flash2 = False
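For orientation, the kv-cache kernel is only resolved when flash-attn imports cleanly and the device is Ampere or newer (compute capability 8+); when the lookup fails the class keeps `self.flash2 = False` and takes the non-flash path. The same guarded lookup can be sketched standalone as below; the module-level variable names are illustrative and not part of the diff:

```python
from importlib import import_module

import torch

# Resolve the flash-attn kernels only if the package (flash-attn >= 2.2) is
# installed and the GPU supports them; otherwise leave them unset.
flash_attn_func = None
flash_attn_with_kvcache = None
try:
    flash_pack = import_module("flash_attn")
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
        flash_attn_func = getattr(flash_pack, "flash_attn_func")
        flash_attn_with_kvcache = getattr(flash_pack, "flash_attn_with_kvcache")
except ImportError:
    pass
```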
@@ -420,27 +433,104 @@ def forward(
                 key = shape(key, self.dim_per_head)
                 value = shape(value, self.dim_per_head)
 
-                if self.max_relative_positions == -1:  # Rotary Embeddings
-                    start_pos = step
-                    seqlen = query.size(2)
-                    if seqlen > self.rope.size(0):
-                        self.rope = rotaryembeddings(
-                            self.dim_per_head, maxseqlen=(seqlen + 2048)
-                        ).to(self.rope.device)
-                    rope = self.rope[start_pos : start_pos + seqlen]
-                    query, key = apply_rotary_emb(
-                        query, key, rope, interleave=self.rotary_interleave
-                    )
+                start_pos = step
+                seqlen = query.size(2)
+
+                if (
+                    step == 0
+                    or not self.flash2
+                    or self.max_relative_positions not in [0, -1]
+                    or query.size(0) > 128
+                    or query.dtype != torch.float16
+                ):
+                    if self.max_relative_positions == -1:  # Rotary Embeddings
+                        if seqlen > self.rope.size(0):
+                            self.rope = rotaryembeddings(
+                                self.dim_per_head,
+                                maxseqlen=(seqlen + 2048),
+                                base=self.rotary_theta,
+                            ).to(self.rope.device)
+                        rope = self.rope[start_pos : start_pos + seqlen]
+                        query, key = apply_rotary_emb(
+                            query, key, rope, interleave=self.rotary_interleave
+                        )
+
+                    if self.layer_cache[1]["keys"].numel() != 0:
+                        key = torch.cat((self.layer_cache[1]["keys"], key), dim=2)
+                        value = torch.cat((self.layer_cache[1]["values"], value), dim=2)
+                        if sliding_window > 0 and key.size(2) > sliding_window:
+                            key = key[:, :, 1:, :]
+                            value = value[:, :, 1:, :]
+
+                    self.layer_cache[1]["keys"] = key
+                    self.layer_cache[1]["values"] = value
 
-                if self.layer_cache[1]["keys"].numel() != 0:
-                    key = torch.cat((self.layer_cache[1]["keys"], key), dim=2)
-                    value = torch.cat((self.layer_cache[1]["values"], value), dim=2)
+                else:
+                    if self.max_relative_positions == -1:  # Rotary Embeddings
+                        if seqlen > self.rope.size(0):
+                            self.rope = rotaryembeddings(
+                                self.dim_per_head,
+                                maxseqlen=(seqlen + 2048),
+                                base=self.rotary_theta,
+                            ).to(self.rope.device)
+                            self.cos = (
+                                self.rope[:, : self.rope.size(1) // 2]
+                                .real.contiguous()
+                                .half()
+                            )
+                            self.sin = (
+                                self.rope[:, : self.rope.size(1) // 2]
+                                .imag.contiguous()
+                                .half()
+                            )
+                    if start_pos >= self.layer_cache[1]["keys"].size(2):
+                        self.layer_cache[1]["keys"] = torch.cat(
+                            [
+                                self.layer_cache[1]["keys"],
+                                torch.zeros(
+                                    self.layer_cache[1]["keys"].shape[:-2]
+                                    + (32,)
+                                    + self.layer_cache[1]["keys"].shape[-1:],
+                                    device=query.device,
+                                ).half(),
+                            ],
+                            dim=-2,
+                        )
+                        self.layer_cache[1]["values"] = torch.cat(
+                            [
+                                self.layer_cache[1]["values"],
+                                torch.zeros(
+                                    self.layer_cache[1]["values"].shape[:-2]
+                                    + (32,)
+                                    + self.layer_cache[1]["values"].shape[-1:],
+                                    device=query.device,
+                                ).half(),
+                            ],
+                            dim=-2,
+                        )
                     if sliding_window > 0 and key.size(2) > sliding_window:
-                        key = key[:, :, 1:, :]
-                        value = value[:, :, 1:, :]
+                        self.layer_cache[1]["keys"] = self.layer_cache[1]["keys"][
+                            :, :, 1:, :
+                        ]
+                        self.layer_cache[1]["values"] = self.layer_cache[1]["values"][
+                            :, :, 1:, :
+                        ]
+                    context = self.flash_attn_with_kvcache(
+                        query.transpose(1, 2),
+                        self.layer_cache[1]["keys"].transpose(1, 2),
+                        self.layer_cache[1]["values"].transpose(1, 2),
+                        key.transpose(1, 2),
+                        value.transpose(1, 2),
+                        rotary_cos=self.cos,
+                        rotary_sin=self.sin,
+                        cache_seqlens=step,
+                        rotary_interleaved=self.rotary_interleave,
+                    ).transpose(1, 2)
+                    attn_output = self.final_linear(unshape(context))
+                    if self.parallel_gpu > 1:
+                        all_reduce(attn_output)
+                    return attn_output, None
 
-                self.layer_cache[1]["keys"] = key
-                self.layer_cache[1]["values"] = value
 
             elif self.attn_type == "context":
                 query = self.linear_query(query)
                 query = shape(query, self.dim_per_head)
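The fast path above hands the decode step to `flash_attn_with_kvcache`, which expects (batch, seqlen, heads, head_dim) layout (hence the `transpose(1, 2)` calls), writes the new key/value into the preallocated cache at position `cache_seqlens`, and can apply RoPE on the fly from `rotary_cos`/`rotary_sin`; that is why the diff grows the cache in blocks of 32 positions and keeps half-precision cos/sin around. A minimal, self-contained usage sketch under assumed toy shapes (requires flash-attn 2.2+ and an fp16-capable Ampere+ GPU; every name except the library call is made up):

```python
import torch
from flash_attn import flash_attn_with_kvcache

# Toy single-token decoding step against a preallocated KV cache.
batch, heads, head_dim, max_len = 2, 8, 64, 256
device, dtype = "cuda", torch.float16

k_cache = torch.zeros(batch, max_len, heads, head_dim, device=device, dtype=dtype)
v_cache = torch.zeros_like(k_cache)

step = 5  # tokens already stored in the cache
q = torch.randn(batch, 1, heads, head_dim, device=device, dtype=dtype)
k_new = torch.randn_like(q)
v_new = torch.randn_like(q)

# k_new / v_new are written into k_cache / v_cache at position `step`, then the
# query attends over the valid cache prefix plus the new token; the output has
# the same (batch, 1, heads, head_dim) shape as q.
out = flash_attn_with_kvcache(q, k_cache, v_cache, k_new, v_new, cache_seqlens=step)
```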
@@ -484,7 +574,9 @@ def forward(
             seqlen = query.size(2)
             if seqlen > self.rope.size(0):
                 self.rope = rotaryembeddings(
-                    self.dim_per_head, maxseqlen=(seqlen + 2048)
+                    self.dim_per_head,
+                    maxseqlen=(seqlen + 2048),
+                    base=self.rotary_theta,
                 ).to(self.rope.device)
             rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
             query, key = apply_rotary_emb(
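The `rotary_interleave` flag used here (and forwarded as `rotary_interleaved` to flash-attn above) selects between the two common RoPE layouts: the interleaved convention rotates adjacent channel pairs (x0, x1), (x2, x3), ..., while the non-interleaved ("half-rotated") convention pairs channel j with channel j + dim/2. `apply_rotary_emb` itself is not shown in this diff; the following hedged sketch illustrates only the interleaved case, with a hypothetical function name and the complex table from the earlier sketch:

```python
import torch


def apply_rope_interleaved(x: torch.Tensor, rope: torch.Tensor) -> torch.Tensor:
    """Rotate adjacent channel pairs by position-dependent complex factors.
    x: (batch, heads, seqlen, dim); rope: complex, (maxseqlen, dim)."""
    b, h, t, d = x.shape
    # Pair up adjacent channels and view each pair as one complex number.
    x_c = torch.view_as_complex(x.float().reshape(b, h, t, d // 2, 2))
    # One rotation factor e^{i * m * theta_j} per position m and channel pair j.
    rot = rope[:t, : d // 2].reshape(1, 1, t, d // 2)
    return torch.view_as_real(x_c * rot).reshape(b, h, t, d).type_as(x)
```

The non-interleaved variant would instead split x into its first and second halves along the channel dimension before rotating; which of the two a checkpoint needs depends on how it was trained, which is exactly why the option is exposed.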