@@ -177,18 +177,18 @@ def forward(self, x, mask = None, rotary_pos_emb = None):
         dots_image = einsum('b i d, b i j d -> b i j', q_img, k_img)
         dots_image_to_text = einsum('b i d, b j d -> b i j', q_img, k_text)

-        # calculate causal attention for local convolution
+        # use padding of 0 on tensor of 1s and unfold for padding mask

         i, j = dots_image.shape[-2:]
-        img_seq = torch.arange(img_seq_len, device = device)
-        k_img_indices = rearrange(img_seq.float(), '(h w) -> () () h w', h = img_size)
-        k_img_indices = F.pad(k_img_indices, causal_padding, value = img_seq_len) # padding set to be max, so it is never attended to
-        k_img_indices = F.unfold(k_img_indices, kernel_size, dilation = dilation)
-        k_img_indices = rearrange(k_img_indices, 'b j i -> b i j')
+        ones = torch.ones((img_seq_len,), device = device)
+        ones = rearrange(ones, '(h w) -> () () h w', h = img_size)
+        ones = F.pad(ones, causal_padding, value = 0.)
+        ones = F.unfold(ones, kernel_size, dilation = dilation)
+        ones = rearrange(ones, 'b j i -> b i j')

         # mask image attention

-        padding_mask = k_img_indices == img_seq_len
+        padding_mask = ones == 0.

         # concat text mask with image causal mask
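For context, here is a minimal standalone sketch of the new masking scheme, runnable outside the module. The toy sizes and the value of causal_padding are assumptions chosen for illustration (padding only on the top/left, i.e. `(dilation * (kernel_size - 1), 0, dilation * (kernel_size - 1), 0)`), mirroring the variable names in the diff. The idea it demonstrates: unfolding a zero-padded tensor of ones marks exactly the key positions that fall outside the image, and those positions become the padding mask.

```python
import torch
import torch.nn.functional as F
from einops import rearrange

# toy sizes (assumptions for illustration only)
img_size = 4                    # image laid out as img_size x img_size tokens
img_seq_len = img_size ** 2
kernel_size = 3
dilation = 1
# assumed causal padding: pad only left/top so each query only sees keys at or before itself
causal_padding = (dilation * (kernel_size - 1), 0, dilation * (kernel_size - 1), 0)

ones = torch.ones((img_seq_len,))
ones = rearrange(ones, '(h w) -> () () h w', h = img_size)   # (1, 1, h, w)
ones = F.pad(ones, causal_padding, value = 0.)               # zeros mark padded positions
ones = F.unfold(ones, kernel_size, dilation = dilation)      # (1, k * k, h * w)
ones = rearrange(ones, 'b j i -> b i j')                     # (1, h * w, k * k)

padding_mask = ones == 0.     # True wherever a key slot falls outside the image
print(padding_mask.shape)     # torch.Size([1, 16, 9])
print(padding_mask[0, 0])     # first pixel: every key in its window except itself is padding
```

Compared to the removed index-based approach (padding with `img_seq_len` and comparing against it), comparing the unfolded ones against zero yields the same mask without building an index tensor.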