@@ -1,7 +1,7 @@
 import torch
 from torch import nn, einsum
 import torch.nn.functional as F
-from einops import rearrange
+from einops import rearrange, repeat

 from dalle_pytorch.reversible import ReversibleSequence, SequentialSequence

@@ -10,6 +10,9 @@
 def exists(val):
     return val is not None

+def cast_tuple(val, depth):
+    return val if isinstance(val, tuple) else (val,) * depth
+
 # classes

 class PreNorm(nn.Module):
@@ -40,10 +43,11 @@ def forward(self, x):
         return self.net(x)

 class Attention(nn.Module):
-    def __init__(self, dim, causal = True, heads = 8, dim_head = 64, dropout = 0.):
+    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0.):
         super().__init__()
         inner_dim = dim_head * heads
         self.heads = heads
+        self.seq_len = seq_len
         self.scale = dim ** -0.5
         self.causal = causal

@@ -78,26 +82,67 @@ def forward(self, x, mask = None):
         out = self.to_out(out)
         return out

+class SparseAttention(Attention):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig
+
+        self.attn_fn = SparseSelfAttention(
+            sparsity_config = VariableSparsityConfig(
+                num_heads = self.heads,
+                block = 16,
+                attention = 'unidirectional' if self.causal else 'bidirectional'
+            ),
+            max_seq_length = self.seq_len,
+            attn_mask_mode = 'add'
+        )
+
+    def forward(self, x, mask = None):
+        b, n, _, h, device = *x.shape, self.heads, x.device
+        qkv = self.to_qkv(x).chunk(3, dim = -1)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
+
+        key_pad_mask = None
+        if exists(mask):
+            key_pad_mask = ~mask
+
+        attn_mask = None
+        if self.causal:
+            i, j = q.shape[-2], k.shape[-2]
+            mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
+            attn_mask = torch.zeros(i, j, device = device).to(q)
+            mask_value = -(torch.finfo(q.dtype).max / 2)
+            attn_mask.masked_fill_(mask, mask_value)
+
+        out = self.attn_fn(q, k, v, attn_mask = attn_mask, key_padding_mask = key_pad_mask)
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        return self.to_out(out)
+
 class Transformer(nn.Module):
     def __init__(
         self,
         *,
         dim,
         depth,
+        seq_len,
         reversible = False,
         causal = True,
         heads = 8,
         dim_head = 64,
         ff_mult = 4,
         attn_dropout = 0.,
-        ff_dropout = 0.
+        ff_dropout = 0.,
+        sparse_attn = True
     ):
         super().__init__()
         layers = nn.ModuleList([])
+        sparse_layer = cast_tuple(sparse_attn, depth)
+
+        for _, sparse_attn in zip(range(depth), sparse_layer):
+            attn_class = Attention if not sparse_attn else SparseAttention

-        for _ in range(depth):
             layers.append(nn.ModuleList([
-                PreNorm(dim, Attention(dim, causal = causal, heads = heads, dim_head = dim_head, dropout = attn_dropout)),
+                PreNorm(dim, attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)),
                 PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout))
             ]))
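For context, a minimal usage sketch of the new per-layer sparse attention switch (not part of the commit). It assumes the module above is importable as dalle_pytorch.transformer, that DeepSpeed is installed with its sparse attention kernels, and that the model runs on a CUDA device, which those kernels require. seq_len is now a required keyword (forwarded to SparseSelfAttention as max_seq_length) and should be a multiple of the 16-token block size; sparse_attn accepts either a single bool or a per-layer tuple, with cast_tuple broadcasting a single value across the depth.

import torch
from dalle_pytorch.transformer import Transformer  # assumed import path for the module above

# Alternate sparse and dense attention across four layers; passing a single bool
# instead would apply the same choice to every layer.
model = Transformer(
    dim = 512,
    depth = 4,
    seq_len = 1024,                           # forwarded to SparseSelfAttention as max_seq_length
    heads = 8,
    dim_head = 64,
    sparse_attn = (True, False, True, False)  # one flag per layer
).cuda()

x = torch.randn(1, 1024, 512).cuda()          # (batch, seq_len, dim)
out = model(x)                                # expected shape: (1, 1024, 512)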