Skip to content

Commit f794ba6

Browse files
committed
stability measure number 2
1 parent 27079dd commit f794ba6

File tree

2 files changed: +4 −2 lines changed

dalle_pytorch/attention.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ def forward(self, x, mask = None):
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

-        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
+        q = q * self.scale
+
+        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
         mask_value = max_neg_value(dots)

         if exists(mask):

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
     name = 'dalle-pytorch',
     packages = find_packages(),
     include_package_data = True,
-    version = '0.12.1',
+    version = '0.12.2',
     license='MIT',
     description = 'DALL-E - Pytorch',
     author = 'Phil Wang',

0 commit comments

Comments
 (0)