From 4eb7c944a267c43a1a23fa0c9bbfe157449a02be Mon Sep 17 00:00:00 2001
From: Boxiang Wang
Date: Mon, 4 Jul 2022 11:47:31 +0800
Subject: [PATCH 1/3] Update detr attention

---
 titans/layer/attention/__init__.py       |  2 +-
 titans/layer/attention/detr_attention.py | 73 ++++++++++++++++++------
 titans/layer/block/detr_block.py         | 27 +++++----
 3 files changed, 71 insertions(+), 31 deletions(-)

diff --git a/titans/layer/attention/__init__.py b/titans/layer/attention/__init__.py
index 5022573..6b8b478 100644
--- a/titans/layer/attention/__init__.py
+++ b/titans/layer/attention/__init__.py
@@ -1,5 +1,5 @@
 from .gpt_attention import GPTSelfAttention
-from .detr_attention import DeTrCrossAttention
+from .detr_attention import DeTrAttention
 from .vit_attention import ViTSelfAttention
 from .vit_moe_attention import SelfAttentionForMoe
 from .transformer_attention import TransformerSelfAttention, TransformerMultiHeadAttention
diff --git a/titans/layer/attention/detr_attention.py b/titans/layer/attention/detr_attention.py
index a06085e..01b903a 100644
--- a/titans/layer/attention/detr_attention.py
+++ b/titans/layer/attention/detr_attention.py
@@ -6,10 +6,10 @@
 from colossalai import nn as col_nn
 from ..init_rules import init_rules
 from titans.decorator import no_support
-# This part need to work together with the col_nn.Linear (row, col) in order to better parallelize.
+
 
 @no_support(['sp'])
-class DeTrCrossAttention(nn.Module):
+class DeTrAttention(nn.Module):
 
     def __init__(self,
                  hidden_size: int,
@@ -25,46 +25,81 @@ def __init__(self,
                                          hidden_size,
                                          dtype=dtype,
                                          bias=bias,
-                                         )
-        self.key_value = col_nn.Linear1D_Col(hidden_size,
-                                         2 * hidden_size,
+                                         **init_rules[init_method]['transformer'])
+        self.key = col_nn.Linear1D_Col(hidden_size,
+                                         hidden_size,
+                                         dtype=dtype,
+                                         bias=bias,
+                                         **init_rules[init_method]['transformer'])
+        self.value = col_nn.Linear1D_Col(hidden_size,
+                                         hidden_size,
                                          dtype=dtype,
                                          bias=bias,
-                                         )
+                                         **init_rules[init_method]['transformer'])
         self.attention_dropout = col_nn.Dropout(attention_dropout)
         self.dense = col_nn.Linear1D_Row(hidden_size, hidden_size, dtype=dtype, bias=True)
         self.dropout = col_nn.Dropout(dropout)
         self.softmax = nn.Softmax(dim=-1)
 
-    def forward(self, x, memory):
-        q = self.query(x)
-        kv = self.key_value(memory)
-        all_head_size = kv.shape[-1] // 2
+    def forward(self, q, k, v, attn_mask=None, key_padding_mask=None):
+        # bsz, tgt_len, all_head_size = q.shape
+        # _, src_len, _ = k.shape
+
+        # num_attention_heads = all_head_size // self.attention_head_size
+
+        # if key_padding_mask is not None:
+        #     assert key_padding_mask.shape == (bsz, src_len), \
+        #         f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
+        #     key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
+        #         expand(-1, num_attention_heads, -1, -1).reshape(bsz * num_attention_heads, 1, src_len)
+        #     if attn_mask is None:
+        #         attn_mask = key_padding_mask
+        #     elif attn_mask.dtype == torch.bool:
+        #         attn_mask = attn_mask.logical_or(key_padding_mask)
+        #     else:
+        #         attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))
+
+        # # convert mask to float
+        # if attn_mask is not None and attn_mask.dtype == torch.bool:
+        #     new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
+        #     new_attn_mask.masked_fill_(attn_mask, float("-inf"))
+        #     attn_mask = new_attn_mask
+
+
+        q = self.query(q)
+        k = self.key(k)
+        v = self.value(v)
+
+        all_head_size = q.shape[-1]
         num_attention_heads = all_head_size // self.attention_head_size
 
         new_q_shape = q.shape[:-1] + (num_attention_heads, self.attention_head_size)
         q = q.view(new_q_shape)
         q = q.permute((0, 2, 1, 3))
-        q = q.permute((2, 3, 0, 1))    # ?
 
-        new_kv_shape = kv.shape[:-1] + (num_attention_heads, 2 * self.attention_head_size)
-        kv = kv.view(new_kv_shape)
-        kv = kv.permute((0, 2, 1, 3))
-        k, v = torch.chunk(kv, 2, dim=-1)
-        k = k.permute((2, 3, 0, 1))    # ?
-        v = v.permute((2, 3, 0, 1))    # ?
+        new_k_shape = k.shape[:-1] + (num_attention_heads, self.attention_head_size)
+        k = k.view(new_k_shape)
+        k = k.permute((0, 2, 1, 3))
+
+        new_v_shape = v.shape[:-1] + (num_attention_heads, self.attention_head_size)
+        v = v.view(new_v_shape)
+        v = v.permute((0, 2, 1, 3))
 
         x = torch.matmul(q, k.transpose(-1, -2))
         x = x / math.sqrt(self.attention_head_size)
+
+        # if attn_mask is not None:
+        #     x += attn_mask
+
         x = self.softmax(x)
         x = self.attention_dropout(x)
 
         x = torch.matmul(x, v)
         x = x.transpose(1, 2)
         new_context_layer_shape = x.size()[:-2] + (all_head_size,)
+        # the size of x after reshape is (BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE)
         x = x.reshape(new_context_layer_shape)
-        x = x.transpose(0, 1)
-
+        # the size of x after dense is (BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE)
         x = self.dense(x)
         x = self.dropout(x)
diff --git a/titans/layer/block/detr_block.py b/titans/layer/block/detr_block.py
index f396911..fa5625e 100644
--- a/titans/layer/block/detr_block.py
+++ b/titans/layer/block/detr_block.py
@@ -6,7 +6,7 @@
 from colossalai.nn.layer.utils import CheckpointModule
 from torch import dtype, nn
 
-from titans.layer.attention import ViTSelfAttention, DeTrCrossAttention
+from titans.layer.attention import DeTrAttention
 from titans.layer.mlp import ViTMLP
 from titans.decorator import support_tp_pp_only
 
@@ -29,7 +29,7 @@ def __init__(self,
                  init_method: str = 'torch'):
         super().__init__(checkpoint)
         self.norm1 = col_nn.LayerNorm(normalized_shape=hidden_size, eps=layernorm_epsilon, dtype=dtype)
-        self.attn = ViTSelfAttention(hidden_size=hidden_size,
+        self.attn = DeTrAttention(hidden_size=hidden_size,
                                      num_heads=num_heads,
                                      attention_dropout=attention_dropout,
                                      dropout=dropout,
@@ -46,10 +46,12 @@ def __init__(self,
                           bias=bias,
                           init_method=init_method)
 
-    def _forward(self, x):
-        x = x + self.drop_path(self.norm1(self.attn(x)))
+    def _forward(self, x, attn_mask=None, key_padding_mask=None):
+        # input dimension [b,s,h]
+        x = x.transpose(0,1)
+        x = x + self.drop_path(self.norm1(self.attn(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)))
         x = x + self.drop_path(self.norm2(self.mlp(x)))
-        return x
+        return x.transpose(0,1)
 
 
 @support_tp_pp_only()
@@ -73,7 +75,7 @@ def __init__(self,
         self.norm2 = col_nn.LayerNorm(normalized_shape=hidden_size, eps=layernorm_epsilon, dtype=dtype)
         self.norm3 = col_nn.LayerNorm(normalized_shape=hidden_size, eps=layernorm_epsilon, dtype=dtype)
 
-        self.attn1 = ViTSelfAttention(hidden_size=hidden_size,
+        self.attn1 = DeTrAttention(hidden_size=hidden_size,
                                       num_heads=num_heads,
                                       attention_dropout=attention_dropout,
                                       dropout=dropout,
@@ -81,7 +83,7 @@ def __init__(self,
                                       dtype=dtype,
                                       init_method=init_method)
 
-        self.attn2 = DeTrCrossAttention(hidden_size=hidden_size,
+        self.attn2 = DeTrAttention(hidden_size=hidden_size,
                                         num_heads=num_heads,
                                         attention_dropout=attention_dropout,
                                         dropout=dropout,
@@ -99,8 +101,11 @@ def __init__(self,
                           bias=bias,
                           init_method=init_method)
 
-    def _forward(self, x, memory):
-        x = x + self.drop_path(self.norm1(self.attn1(x)))
-        x = x + self.drop_path(self.norm2(self.attn2(x, memory)))
+    def _forward(self, x, memory, self_attn_mask=None, self_attn_key_padding_mask=None, multihead_attn_mask=None, multihead_attn_key_padding_mask=None):
+        # input dimension [b,s,h] [q,s,h]
+        x = x.transpose(0,1)
+        memory = memory.transpose(0,1)
+        x = x + self.drop_path(self.norm1(self.attn1(x, x, x, attn_mask=self_attn_mask, key_padding_mask=self_attn_key_padding_mask)))
+        x = x + self.drop_path(self.norm2(self.attn2(x, memory, memory, attn_mask=multihead_attn_mask, key_padding_mask=multihead_attn_key_padding_mask)))
         x = x + self.drop_path(self.mlp(self.norm3(x)))
-        return x
+        return x.transpose(0,1)

From 887d98d63652c9c1d6798dd8d35e1b13fc37c7c2 Mon Sep 17 00:00:00 2001
From: Boxiang Wang
Date: Mon, 4 Jul 2022 15:57:01 +0800
Subject: [PATCH 2/3] update build setting

---
 .github/workflows/build.yml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index ec422f1..56dedb5 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -6,7 +6,7 @@ on:
 
 jobs:
   build:
-    name: Build and Test Colossal-AI
+    name: Build and Test Titans
     if: |
       github.event.pull_request.draft == false &&
       github.base_ref == 'main' &&
@@ -23,7 +23,7 @@ jobs:
         ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
     - name: Install Colossal-AI
       run: |
-        pip install colossalai==0.1.4+torch1.10cu11.3 -f https://release.colossalai.org
+        pip install colossalai==0.1.7+torch1.10cu11.3 -f https://release.colossalai.org
         pip install -v .
         pip install -r requirements/requirements-test.txt
     - name: Unit Testing

From d7fd44d1d0c4c6e0ebffe4adfc72a162a947b7bb Mon Sep 17 00:00:00 2001
From: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
Date: Wed, 6 Jul 2022 17:49:21 +0800
Subject: [PATCH 3/3] Update detr_attention.py

---
 titans/layer/attention/detr_attention.py | 24 ------------------------
 1 file changed, 24 deletions(-)

diff --git a/titans/layer/attention/detr_attention.py b/titans/layer/attention/detr_attention.py
index 01b903a..06d48bf 100644
--- a/titans/layer/attention/detr_attention.py
+++ b/titans/layer/attention/detr_attention.py
@@ -42,30 +42,6 @@ def __init__(self,
         self.softmax = nn.Softmax(dim=-1)
 
     def forward(self, q, k, v, attn_mask=None, key_padding_mask=None):
-        # bsz, tgt_len, all_head_size = q.shape
-        # _, src_len, _ = k.shape
-
-        # num_attention_heads = all_head_size // self.attention_head_size
-
-        # if key_padding_mask is not None:
-        #     assert key_padding_mask.shape == (bsz, src_len), \
-        #         f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
-        #     key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
-        #         expand(-1, num_attention_heads, -1, -1).reshape(bsz * num_attention_heads, 1, src_len)
-        #     if attn_mask is None:
-        #         attn_mask = key_padding_mask
-        #     elif attn_mask.dtype == torch.bool:
-        #         attn_mask = attn_mask.logical_or(key_padding_mask)
-        #     else:
-        #         attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))
-
-        # # convert mask to float
-        # if attn_mask is not None and attn_mask.dtype == torch.bool:
-        #     new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
-        #     new_attn_mask.masked_fill_(attn_mask, float("-inf"))
-        #     attn_mask = new_attn_mask
-
-
         q = self.query(q)
         k = self.key(k)
         v = self.value(v)
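
For readers skimming the series: the sketch below is an illustrative, plain-PyTorch rendering of the attention flow that the new DeTrAttention.forward(q, k, v, ...) implements (project, split heads, scaled dot-product, merge heads, output projection). It is not part of the patches. The class name PlainDeTrAttention and the use of torch.nn.Linear in place of the Colossal-AI Linear1D_Col/Linear1D_Row layers are assumptions made for a self-contained, single-process example; the attn_mask/key_padding_mask arguments and the attention-probability dropout are omitted here, since the mask handling in the patched code is still commented out.

import math

import torch
from torch import nn


class PlainDeTrAttention(nn.Module):
    """Single-process stand-in for DeTrAttention: same tensor flow, no tensor parallelism."""

    def __init__(self, hidden_size: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
        self.attention_head_size = hidden_size // num_heads
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v):
        # project inputs, then split the hidden dimension into (num_heads, head_size)
        q, k, v = self.query(q), self.key(k), self.value(v)
        all_head_size = q.shape[-1]
        num_heads = all_head_size // self.attention_head_size

        def split_heads(t):
            t = t.view(t.shape[:-1] + (num_heads, self.attention_head_size))
            return t.permute(0, 2, 1, 3)    # (batch, heads, seq_len, head_size)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        # scaled dot-product attention per head
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
        context = torch.matmul(self.softmax(scores), v)    # (batch, heads, tgt_len, head_size)

        # merge heads back into the hidden dimension and apply the output projection
        context = context.transpose(1, 2).reshape(context.shape[0], -1, all_head_size)
        return self.dropout(self.dense(context))


# self-attention usage, mirroring DeTrEncoder's attn(x, x, x):
#   x has shape (batch, seq_len, hidden_size)
#   attn = PlainDeTrAttention(hidden_size=256, num_heads=8)
#   out = attn(x, x, x)
# cross-attention, mirroring DeTrDecoder's attn2(x, memory, memory):
#   out = attn(x, memory, memory)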