Commit dc1c6d5

add sparse attention to DALL-E
1 parent: 89c99e1 · commit: dc1c6d5

4 files changed: 90 additions, 9 deletions


README.md  (+30)

@@ -145,6 +145,36 @@ dalle = DALLE(
 )
 ```
 
+## Sparse Attention
+
+You can also train with Microsoft DeepSpeed's Sparse Attention, with any combination of dense and sparse attention that you'd like. However, you will have to endure the installation process.
+
+First, you need to install DeepSpeed with Sparse Attention:
+
+```bash
+$ sh install_deepspeed.sh
+```
+
+Next, you need to install the pip package `triton`:
+
+```bash
+$ pip install triton
+```
+
+If both of the above succeeded, you can now train with Sparse Attention!
+
+```python
+dalle = DALLE(
+    dim = 512,
+    vae = vae,
+    num_text_tokens = 10000,
+    text_seq_len = 256,
+    depth = 64,
+    heads = 8,
+    sparse_attn = (True, False) * 32  # interleave sparse and dense attention for all 64 layers
+)
+```
+
 ## Citations
 
 ```bibtex
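A quick way to confirm that the DeepSpeed build actually picked up sparse attention support (a small sketch, not part of this commit) is to try the same imports that the new `SparseAttention` class performs in `dalle_pytorch/transformer.py` below:

```python
# Sanity-check sketch: these are the exact ops SparseAttention imports,
# so they should import cleanly once install_deepspeed.sh and `pip install triton` succeed.
from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig

print('DeepSpeed sparse attention ops are available')
```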

dalle_pytorch/dalle_pytorch.py  (+7 -4)

@@ -179,7 +179,7 @@ def __init__(
         super().__init__()
         self.text_emb = nn.Embedding(num_text_tokens, dim_text)
         self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
-        self.text_transformer = Transformer(causal = False, dim = dim_text, depth = text_enc_depth, heads = text_heads)
+        self.text_transformer = Transformer(causal = False, seq_len = text_seq_len, dim = dim_text, depth = text_enc_depth, heads = text_heads)
         self.to_text_latent = nn.Linear(dim_text, dim_latent, bias = False)
 
         assert visual_image_size % visual_patch_size == 0, 'Image dimensions must be divisible by the patch size.'
@@ -189,7 +189,7 @@ def __init__(
         self.visual_patch_size = visual_patch_size
         self.to_visual_embedding = nn.Linear(patch_dim, dim_image)
         self.visual_pos_emb = nn.Embedding(num_patches, dim_image)
-        self.visual_transformer = Transformer(causal = False, dim = dim_image, depth = visual_enc_depth, heads = visual_heads)
+        self.visual_transformer = Transformer(causal = False, seq_len = num_patches, dim = dim_image, depth = visual_enc_depth, heads = visual_heads)
         self.to_visual_latent = nn.Linear(dim_image, dim_latent, bias = False)
 
         self.temperature = nn.Parameter(torch.tensor(1.))
@@ -251,7 +251,8 @@ def __init__(
         dim_head = 64,
         reversible = False,
         attn_dropout = 0.,
-        ff_dropout = 0
+        ff_dropout = 0,
+        sparse_attn = False
     ):
         super().__init__()
         assert isinstance(vae, DiscreteVAE), 'vae must be an instance of DiscreteVAE'
@@ -284,12 +285,14 @@ def __init__(
         self.transformer = Transformer(
             dim = dim,
             causal = True,
+            seq_len = seq_len,
             depth = depth,
             heads = heads,
             dim_head = dim_head,
             reversible = reversible,
             attn_dropout = attn_dropout,
-            ff_dropout = ff_dropout
+            ff_dropout = ff_dropout,
+            sparse_attn = sparse_attn
         )
 
         self.to_logits = nn.Sequential(
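These changes simply thread two new keyword arguments through to `Transformer`: `seq_len`, which is needed because DeepSpeed's `SparseSelfAttention` is configured with a maximum sequence length at construction time, and `sparse_attn`, which defaults to off. As a minimal sketch of the new interface (illustrative values, not code from this commit), the transformer can now be built directly like so, with every layer left dense so no DeepSpeed install is required:

```python
# Sketch of the new Transformer keywords only; parameter values are illustrative.
from dalle_pytorch.transformer import Transformer

transformer = Transformer(
    dim = 512,
    depth = 2,
    seq_len = 256,        # newly required; forwarded to every attention layer
    heads = 8,
    sparse_attn = False   # a single bool, or a tuple with one flag per layer
)
print(transformer)        # two dense attention + feedforward blocks
```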

dalle_pytorch/transformer.py  (+50 -5)

@@ -1,7 +1,7 @@
 import torch
 from torch import nn, einsum
 import torch.nn.functional as F
-from einops import rearrange
+from einops import rearrange, repeat
 
 from dalle_pytorch.reversible import ReversibleSequence, SequentialSequence
 
@@ -10,6 +10,9 @@
 def exists(val):
     return val is not None
 
+def cast_tuple(val, depth):
+    return val if isinstance(val, tuple) else (val,) * depth
+
 # classes
 
 class PreNorm(nn.Module):
@@ -40,10 +43,11 @@ def forward(self, x):
         return self.net(x)
 
 class Attention(nn.Module):
-    def __init__(self, dim, causal = True, heads = 8, dim_head = 64, dropout = 0.):
+    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0.):
         super().__init__()
         inner_dim = dim_head * heads
         self.heads = heads
+        self.seq_len = seq_len
         self.scale = dim ** -0.5
         self.causal = causal
 
@@ -78,26 +82,67 @@ def forward(self, x, mask = None):
         out = self.to_out(out)
         return out
 
+class SparseAttention(Attention):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig
+
+        self.attn_fn = SparseSelfAttention(
+            sparsity_config = VariableSparsityConfig(
+                num_heads = self.heads,
+                block = 16,
+                attention = 'unidirectional' if self.causal else 'bidirectional'
+            ),
+            max_seq_length = self.seq_len,
+            attn_mask_mode = 'add'
+        )
+
+    def forward(self, x, mask = None):
+        b, n, _, h, device = *x.shape, self.heads, x.device
+        qkv = self.to_qkv(x).chunk(3, dim = -1)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
+
+        key_pad_mask = None
+        if exists(mask):
+            key_pad_mask = ~mask
+
+        attn_mask = None
+        if self.causal:
+            i, j = q.shape[-2], k.shape[-2]
+            mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
+            attn_mask = torch.zeros(i, j, device = device).to(q)
+            mask_value = -(torch.finfo(q.dtype).max / 2)
+            attn_mask.masked_fill_(mask, mask_value)
+
+        out = self.attn_fn(q, k, v, attn_mask = attn_mask, key_padding_mask = key_pad_mask)
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        return self.to_out(out)
+
 class Transformer(nn.Module):
     def __init__(
         self,
         *,
         dim,
         depth,
+        seq_len,
         reversible = False,
         causal = True,
         heads = 8,
         dim_head = 64,
         ff_mult = 4,
         attn_dropout = 0.,
-        ff_dropout = 0.
+        ff_dropout = 0.,
+        sparse_attn = True
     ):
         super().__init__()
         layers = nn.ModuleList([])
+        sparse_layer = cast_tuple(sparse_attn, depth)
+
+        for _, sparse_attn in zip(range(depth), sparse_layer):
+            attn_class = Attention if not sparse_attn else SparseAttention
 
-        for _ in range(depth):
             layers.append(nn.ModuleList([
-                PreNorm(dim, Attention(dim, causal = causal, heads = heads, dim_head = dim_head, dropout = attn_dropout)),
+                PreNorm(dim, attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)),
                 PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout))
             ]))
 
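For reference, a small standalone sketch (not part of the commit) of what the new `cast_tuple` helper and layer loop above do: a single boolean applies to every layer, while a tuple supplies one sparse-or-dense flag per layer, which is exactly what the README's `sparse_attn = (True, False) * 32` relies on.

```python
# Sketch: normalizing sparse_attn to one flag per layer
# (same logic as the cast_tuple helper and layer loop added above).
def cast_tuple(val, depth):
    return val if isinstance(val, tuple) else (val,) * depth

depth = 4
print(cast_tuple(False, depth))              # (False, False, False, False) -> all dense layers
print(cast_tuple((True, False) * 2, depth))  # (True, False, True, False)   -> interleaved
```

Similarly, since `SparseSelfAttention` is constructed with `attn_mask_mode = 'add'`, the causal mask in `SparseAttention.forward` is built as an additive bias rather than a boolean mask. A tiny illustration of the tensor it produces, assuming a 4-token sequence:

```python
# Sketch of the additive causal mask from SparseAttention.forward (i = j = 4).
import torch

i = j = 4
mask = torch.ones(i, j).triu_(j - i + 1).bool()     # True strictly above the diagonal
attn_mask = torch.zeros(i, j)
mask_value = -(torch.finfo(torch.float32).max / 2)
attn_mask.masked_fill_(mask, mask_value)
print(attn_mask)  # row t has a huge negative bias at columns > t, so each query only attends to past and current positions
```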

install_deepspeed.sh  (new file, +3)

@@ -0,0 +1,3 @@
+sudo apt-get -y install llvm-9-dev cmake
+git clone https://github.com/microsoft/DeepSpeed.git /tmp/Deepspeed
+cd /tmp/Deepspeed && DS_BUILD_SPARSE_ATTN=1 ./install.sh -s
