Encoder-decoder multihead attention CPU optimization #43

Open · wants to merge 1 commit into base: main
74 changes: 46 additions & 28 deletions fastseq/optimizer/fairseq/beam_search_optimizer_v1.py
@@ -195,8 +195,6 @@ def forward(
],
dim=1)

q = q.contiguous().view(tgt_len, bsz * self.num_heads,
self.head_dim).transpose(0, 1)
if k is not None:
kv_bsz = k.size(1)
k = k.contiguous().view(-1, kv_bsz * self.num_heads,
@@ -283,14 +281,16 @@ def forward(
dim=1)

if self.encoder_decoder_attention and bsz != kv_bsz:
attn_weights = torch.einsum(
'bxhtd,bhsd->bxhts',
q.view(kv_bsz, -1, self.num_heads,
*q.size()[1:]),
k.view(kv_bsz, self.num_heads,
*k.size()[1:]))
attn_weights = attn_weights.reshape(-1, *attn_weights.size()[-2:])
#query size (1, B*b*h, c_embed) => (B*h, b, c)
q = q.view(tgt_len,-1, self.beam_size, self.num_heads,
self.head_dim).permute(1,3,2,0,4).contiguous(
).view(kv_bsz*self.num_heads, self.beam_size, self.head_dim)
attn_weights = torch.bmm(q, k.transpose(1, 2))
attn_weights = attn_weights.view(-1, tgt_len,
*attn_weights.size()[-1:])
else:
q = q.contiguous().view(tgt_len, bsz * self.num_heads,
Contributor:

Why is contiguous needed here?

Contributor Author:

This was present in the earlier implementation; I didn't touch it since my changes only affect encoder-decoder attention. I agree it is redundant and I'll remove it.

Contributor:

There are other places that use contiguous. Please check whether those can be removed as well.

Contributor Author:

I just checked this. In all the other places, it appears after permute/transpose operations, where it is essential.

self.head_dim).transpose(0, 1)
attn_weights = torch.bmm(q, k.transpose(1, 2))
attn_weights = self.apply_sparse_mask(
attn_weights, tgt_len, src_len, bsz)
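On the contiguous question in the thread above: a minimal standalone sketch (illustrative shapes, not part of the diff) of why .contiguous() is needed before .view() when it follows a transpose or permute:

import torch

x = torch.randn(2, 6, 4)          # e.g. (tgt_len, bsz * num_heads, head_dim)
y = x.transpose(0, 1)             # shares storage, but is no longer contiguous
# y.view(-1, 4)                   # would raise a RuntimeError: view not compatible with layout
z = y.contiguous().view(-1, 4)    # copy into contiguous memory first, then view works
w = y.reshape(-1, 4)              # reshape copies only when a zero-copy view is impossible

This is why dropping contiguous is only safe where no transpose/permute precedes the view.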
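And a standalone sketch (illustrative sizes and variable names, not from the diff) checking that the reshape-plus-bmm path above matches the original einsum when tgt_len == 1, up to the (B, b, h) versus (B, h, b) batch ordering:

import torch

B, b, h, d, s = 2, 3, 4, 8, 5      # kv_bsz, beam_size, num_heads, head_dim, src_len
bsz = B * b

q = torch.randn(1, bsz, h * d)     # (tgt_len=1, B*b, embed_dim)
k = torch.randn(B * h, s, d)       # keys computed once per source sentence

# original einsum path: batch laid out as (B, b, h)
q_e = q.view(1, bsz * h, d).transpose(0, 1)                       # (B*b*h, 1, d)
w_einsum = torch.einsum('bxhtd,bhsd->bxhts',
                        q_e.reshape(B, b, h, 1, d),
                        k.view(B, h, s, d)).reshape(-1, 1, s)     # (B*b*h, 1, s)

# bmm path from this PR: batch laid out as (B, h, b); valid because tgt_len == 1
q_b = q.view(1, -1, b, h, d).permute(1, 3, 2, 0, 4).contiguous().view(B * h, b, d)
w_bmm = torch.bmm(q_b, k.transpose(1, 2)).view(-1, 1, s)          # (B*h*b, 1, s)

# same values, different flattening order of the beam and head axes
assert torch.allclose(w_einsum.view(B, b, h, 1, s),
                      w_bmm.view(B, h, b, 1, s).transpose(1, 2))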
@@ -306,15 +306,14 @@ def forward(

if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
src_len)
#attn_weights size (B*b*h/B*h*b, 1, S) => (B,h*b, S)
attn_weights = attn_weights.view(kv_bsz, -1, src_len)
if not self.tpu:
attn_weights = attn_weights.view(kv_bsz, -1, self.num_heads,
tgt_len, src_len)
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(
key_padding_mask.unsqueeze(1).to(
torch.bool), float("-inf"))
else:
#Not supported
Member:

add "assert False, reason"

attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.masked_fill(
key_padding_mask, float('-inf'))
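A small sketch (assumed shapes, not part of the diff) of the non-TPU masking above: with attn_weights flattened as (kv_bsz, heads*beam*tgt_len, src_len), a single unsqueeze lets key_padding_mask broadcast over all heads and beams:

import torch

B, b, h, s = 2, 3, 4, 5                       # kv_bsz, beam_size, num_heads, src_len
attn_weights = torch.randn(B * h * b, 1, s)   # tgt_len == 1 during decoding
key_padding_mask = torch.zeros(B, s, dtype=torch.bool)
key_padding_mask[:, -1] = True                # pretend the last source token is padding

w = attn_weights.view(B, -1, s)               # (B, h*b, s)
w = w.masked_fill(key_padding_mask.unsqueeze(1).to(torch.bool), float('-inf'))
attn_weights = w.view(B * h * b, 1, s)        # -inf becomes zero probability after softmax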
@@ -323,6 +322,11 @@ def forward(
src_len)

if before_softmax:
#attn_weights size (B*h*b, 1, S) => (B*b*h, 1, S)
if self.encoder_decoder_attention and bsz != kv_bsz:
attn_weights = attn_weights.view(kv_bsz,self.num_heads,
self.beam_size, tgt_len, src_len).permute(0,2,1,3,4
).contiguous().view(-1, tgt_len, src_len)
return attn_weights, v

attn_weights_float = utils.softmax(attn_weights,
@@ -335,18 +339,26 @@ def forward(
assert v is not None

if self.encoder_decoder_attention and bsz != kv_bsz:
attn = torch.einsum(
'bxhts,bhsd->bxhtd',
attn_probs.view(kv_bsz, -1, self.num_heads,
*attn_probs.size()[1:]),
v.view(kv_bsz, self.num_heads,
*v.size()[1:]))
attn = attn.reshape(-1, *attn.size()[-2:])
#attn_probs size (B*h*b, 1, S) => (B*h, b, S)
attn_probs = attn_probs.view(-1, self.beam_size, src_len)
attn = torch.bmm(attn_probs, v)

if self.encoder_decoder_attention and bsz != kv_bsz:
assert list(
attn.size()) == [kv_bsz * self.num_heads,
self.beam_size, self.head_dim]
else:
attn = torch.bmm(attn_probs, v)
assert list(
attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
if (self.onnx_trace and attn.size(1) == 1):
assert list(
attn.size()) == [bsz * self.num_heads,
tgt_len, self.head_dim]

if self.encoder_decoder_attention and bsz != kv_bsz:
#attn size (B*h, b, c) => (1, B*b, c_embed)
attn = attn.view(kv_bsz, self.num_heads,
self.beam_size,self.head_dim).permute(0, 2, 1, 3
).contiguous().view(tgt_len, bsz, embed_dim)
#.view(tgt_len, -1, self.head_dim
elif (self.onnx_trace and attn.size(1) == 1):
# when ONNX tracing a single decoder step (sequence length == 1)
# the transpose is a no-op copy before view, thus unnecessary
attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
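A standalone sketch (illustrative sizes, not from the diff) of the value aggregation above and the final reshape back to (tgt_len, bsz, embed_dim) for the output projection:

import torch

B, b, h, d, s = 2, 3, 4, 8, 5                 # kv_bsz, beam_size, num_heads, head_dim, src_len
tgt_len, bsz, embed_dim = 1, B * b, h * d

attn_probs = torch.rand(B * h * b, tgt_len, s)    # softmaxed weights, (B*h*b, 1, s)
v = torch.randn(B * h, s, d)                      # values shared across beams

attn = torch.bmm(attn_probs.view(-1, b, s), v)    # (B*h, b, d)
attn = (attn.view(B, h, b, d)
            .permute(0, 2, 1, 3)                  # (B, b, h, d): beam-major, head-minor
            .contiguous()
            .view(tgt_len, bsz, embed_dim))       # (1, B*b, h*d), ready for out_proj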
@@ -356,9 +368,15 @@ def forward(
attn = self.out_proj(attn)

if need_weights:
attn_weights = attn_weights_float.view(bsz, self.num_heads,
tgt_len,
src_len).transpose(1, 0)
#attn_weights size (B*h*b,1, S) => (h,B*b,1,S)
if self.encoder_decoder_attention and bsz != kv_bsz:
attn_weights = attn_weights_float.view(kv_bsz, self.num_heads,
self.beam_size, tgt_len, src_len).permute(1,0,2,3,4).contiguous(
).view(self.num_heads, bsz, tgt_len, src_len)
else:
attn_weights = attn_weights_float.view(bsz, self.num_heads,
tgt_len,
src_len).transpose(1, 0)
if not need_head_weights:
# average attention weights over heads
attn_weights = attn_weights.mean(dim=0)
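Finally, a sketch (assumed shapes) of the need_weights path above, which reorders the (B, h, b) layout so per-head weights come out as (num_heads, bsz, tgt_len, src_len) before averaging over heads:

import torch

B, b, h, s = 2, 3, 4, 5                       # kv_bsz, beam_size, num_heads, src_len
tgt_len, bsz = 1, B * b

attn_weights_float = torch.rand(B * h * b, tgt_len, s)
attn_weights = (attn_weights_float.view(B, h, b, tgt_len, s)
                                  .permute(1, 0, 2, 3, 4)      # heads first
                                  .contiguous()
                                  .view(h, bsz, tgt_len, s))
head_avg = attn_weights.mean(dim=0)                            # (bsz, tgt_len, src_len)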