
Commit 2969103

cleanup conv-like attention

1 parent 3064403 · commit 2969103

2 files changed (+8, −8)

dalle_pytorch/attention.py  (+7 −7)

@@ -177,18 +177,18 @@ def forward(self, x, mask = None, rotary_pos_emb = None):
         dots_image = einsum('b i d, b i j d -> b i j', q_img, k_img)
         dots_image_to_text = einsum('b i d, b j d -> b i j', q_img, k_text)

-        # calculate causal attention for local convolution
+        # use padding of 0 on tensor of 1s and unfold for padding mask

         i, j = dots_image.shape[-2:]
-        img_seq = torch.arange(img_seq_len, device = device)
-        k_img_indices = rearrange(img_seq.float(), '(h w) -> () () h w', h = img_size)
-        k_img_indices = F.pad(k_img_indices, causal_padding, value = img_seq_len) # padding set to be max, so it is never attended to
-        k_img_indices = F.unfold(k_img_indices, kernel_size, dilation = dilation)
-        k_img_indices = rearrange(k_img_indices, 'b j i -> b i j')
+        ones = torch.ones((img_seq_len,), device = device)
+        ones = rearrange(ones, '(h w) -> () () h w', h = img_size)
+        ones = F.pad(ones, causal_padding, value = 0.)
+        ones = F.unfold(ones, kernel_size, dilation = dilation)
+        ones = rearrange(ones, 'b j i -> b i j')

         # mask image attention

-        padding_mask = ones == 0.
+        padding_mask = ones == 0.

         # concat text mask with image causal mask

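The rewritten mask no longer needs an explicit index tensor: a tensor of ones over the real image positions is zero-padded and unfolded with the same kernel size and dilation as the local attention windows, so any window slot that reads 0 must have come from padding and gets masked out. Below is a minimal standalone sketch of that idea, not the library's own code; the img_size, kernel_size, dilation and causal_padding values are illustrative assumptions.

    # Sketch: recover a padding mask for local, conv-like attention by
    # unfolding a zero-padded tensor of ones. Values below are assumptions
    # chosen for illustration, not taken from dalle-pytorch.
    import torch
    import torch.nn.functional as F
    from einops import rearrange

    img_size    = 4                 # image is a 4 x 4 grid of tokens
    img_seq_len = img_size ** 2     # 16 image tokens
    kernel_size = 3
    dilation    = 1
    # "causal" padding: pad only the left/top so a token only sees earlier positions
    causal_padding = (kernel_size - 1, 0, kernel_size - 1, 0)  # (left, right, top, bottom)

    # ones over the real image positions
    ones = torch.ones((img_seq_len,))
    ones = rearrange(ones, '(h w) -> () () h w', h = img_size)

    # pad with 0, so any kernel slot that lands outside the image reads 0
    ones = F.pad(ones, causal_padding, value = 0.)

    # unfold into local windows: (1, kernel_size ** 2, img_seq_len), then windows first
    ones = F.unfold(ones, kernel_size, dilation = dilation)
    ones = rearrange(ones, 'b j i -> b i j')   # (1, img_seq_len, kernel_size ** 2)

    # slots that read 0 came from padding and must never be attended to
    padding_mask = ones == 0.

    print(padding_mask.shape)   # torch.Size([1, 16, 9])
    print(padding_mask[0, 0])   # for the top-left token, every slot except its own is padding

Compared with the removed version, which padded an index tensor with img_seq_len and masked positions equal to that sentinel, the ones-based mask expresses the same thing with a simpler sentinel (0 after zero-padding) and no arange/float bookkeeping.
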
setup.py  (+1 −1)

@@ -4,7 +4,7 @@
   name = 'dalle-pytorch',
   packages = find_packages(),
   include_package_data = True,
-  version = '1.5.0',
+  version = '1.5.1',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
