diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index ec422f1..56dedb5 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -6,7 +6,7 @@ on:
 
 jobs:
   build:
-    name: Build and Test Colossal-AI
+    name: Build and Test Titans
     if: |
       github.event.pull_request.draft == false &&
       github.base_ref == 'main' &&
@@ -23,7 +23,7 @@ jobs:
           ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
       - name: Install Colossal-AI
         run: |
-          pip install colossalai==0.1.4+torch1.10cu11.3 -f https://release.colossalai.org
+          pip install colossalai==0.1.7+torch1.10cu11.3 -f https://release.colossalai.org
           pip install -v .
           pip install -r requirements/requirements-test.txt
       - name: Unit Testing
diff --git a/titans/layer/attention/__init__.py b/titans/layer/attention/__init__.py
index 5022573..6b8b478 100644
--- a/titans/layer/attention/__init__.py
+++ b/titans/layer/attention/__init__.py
@@ -1,5 +1,5 @@
 from .gpt_attention import GPTSelfAttention
-from .detr_attention import DeTrCrossAttention
+from .detr_attention import DeTrAttention
 from .vit_attention import ViTSelfAttention
 from .vit_moe_attention import SelfAttentionForMoe
 from .transformer_attention import TransformerSelfAttention, TransformerMultiHeadAttention
diff --git a/titans/layer/attention/detr_attention.py b/titans/layer/attention/detr_attention.py
index a06085e..06d48bf 100644
--- a/titans/layer/attention/detr_attention.py
+++ b/titans/layer/attention/detr_attention.py
@@ -6,10 +6,10 @@
 from colossalai import nn as col_nn
 from ..init_rules import init_rules
 from titans.decorator import no_support
-# This part needs to work together with the col_nn.Linear (row, col) in order to better parallelize.
+
 
 @no_support(['sp'])
-class DeTrCrossAttention(nn.Module):
+class DeTrAttention(nn.Module):
 
     def __init__(self,
                  hidden_size: int,
@@ -25,46 +25,57 @@ def __init__(self,
                                          hidden_size,
                                          dtype=dtype,
                                          bias=bias,
-                                         )
-        self.key_value = col_nn.Linear1D_Col(hidden_size,
-                                             2 * hidden_size,
+                                         **init_rules[init_method]['transformer'])
+        self.key = col_nn.Linear1D_Col(hidden_size,
+                                       hidden_size,
+                                       dtype=dtype,
+                                       bias=bias,
+                                       **init_rules[init_method]['transformer'])
+        self.value = col_nn.Linear1D_Col(hidden_size,
+                                         hidden_size,
                                          dtype=dtype,
                                          bias=bias,
-                                         )
+                                         **init_rules[init_method]['transformer'])
         self.attention_dropout = col_nn.Dropout(attention_dropout)
         self.dense = col_nn.Linear1D_Row(hidden_size, hidden_size, dtype=dtype, bias=True)
         self.dropout = col_nn.Dropout(dropout)
         self.softmax = nn.Softmax(dim=-1)
 
-    def forward(self, x, memory):
-        q = self.query(x)
-        kv = self.key_value(memory)
-        all_head_size = kv.shape[-1] // 2
+    def forward(self, q, k, v, attn_mask=None, key_padding_mask=None):
+        q = self.query(q)
+        k = self.key(k)
+        v = self.value(v)
+
+        all_head_size = q.shape[-1]
         num_attention_heads = all_head_size // self.attention_head_size
 
         new_q_shape = q.shape[:-1] + (num_attention_heads, self.attention_head_size)
         q = q.view(new_q_shape)
         q = q.permute((0, 2, 1, 3))
-        q = q.permute((2, 3, 0, 1))    # ?
 
-        new_kv_shape = kv.shape[:-1] + (num_attention_heads, 2 * self.attention_head_size)
-        kv = kv.view(new_kv_shape)
-        kv = kv.permute((0, 2, 1, 3))
-        k, v = torch.chunk(kv, 2, dim=-1)
-        k = k.permute((2, 3, 0, 1))    # ?
-        v = v.permute((2, 3, 0, 1))    # ?
+        new_k_shape = k.shape[:-1] + (num_attention_heads, self.attention_head_size)
+        k = k.view(new_k_shape)
+        k = k.permute((0, 2, 1, 3))
+
+        new_v_shape = v.shape[:-1] + (num_attention_heads, self.attention_head_size)
+        v = v.view(new_v_shape)
+        v = v.permute((0, 2, 1, 3))
 
         x = torch.matmul(q, k.transpose(-1, -2))
         x = x / math.sqrt(self.attention_head_size)
+
+        # if attn_mask is not None:
+        #     x += attn_mask
+
         x = self.softmax(x)
         x = self.attention_dropout(x)
 
         x = torch.matmul(x, v)
         x = x.transpose(1, 2)
         new_context_layer_shape = x.size()[:-2] + (all_head_size,)
+        # the size of x after reshape is (BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE)
         x = x.reshape(new_context_layer_shape)
-        x = x.transpose(0, 1)
-
+        # the size of x after dense is (BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE)
         x = self.dense(x)
         x = self.dropout(x)
 
diff --git a/titans/layer/block/detr_block.py b/titans/layer/block/detr_block.py
index f396911..fa5625e 100644
--- a/titans/layer/block/detr_block.py
+++ b/titans/layer/block/detr_block.py
@@ -6,7 +6,7 @@
 from colossalai.nn.layer.utils import CheckpointModule
 from torch import dtype, nn
 
-from titans.layer.attention import ViTSelfAttention, DeTrCrossAttention
+from titans.layer.attention import DeTrAttention
 from titans.layer.mlp import ViTMLP
 from titans.decorator import support_tp_pp_only
 
@@ -29,7 +29,7 @@ def __init__(self,
                  init_method: str = 'torch'):
         super().__init__(checkpoint)
         self.norm1 = col_nn.LayerNorm(normalized_shape=hidden_size, eps=layernorm_epsilon, dtype=dtype)
-        self.attn = ViTSelfAttention(hidden_size=hidden_size,
+        self.attn = DeTrAttention(hidden_size=hidden_size,
                                      num_heads=num_heads,
                                      attention_dropout=attention_dropout,
                                      dropout=dropout,
@@ -46,10 +46,12 @@ def __init__(self,
                           bias=bias,
                           init_method=init_method)
 
-    def _forward(self, x):
-        x = x + self.drop_path(self.norm1(self.attn(x)))
+    def _forward(self, x, attn_mask=None, key_padding_mask=None):
+        # input dimension [b,s,h]
+        x = x.transpose(0,1)
+        x = x + self.drop_path(self.norm1(self.attn(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)))
         x = x + self.drop_path(self.norm2(self.mlp(x)))
-        return x
+        return x.transpose(0,1)
 
 
 @support_tp_pp_only()
@@ -73,7 +75,7 @@ def __init__(self,
         self.norm2 = col_nn.LayerNorm(normalized_shape=hidden_size, eps=layernorm_epsilon, dtype=dtype)
         self.norm3 = col_nn.LayerNorm(normalized_shape=hidden_size, eps=layernorm_epsilon, dtype=dtype)
 
-        self.attn1 = ViTSelfAttention(hidden_size=hidden_size,
+        self.attn1 = DeTrAttention(hidden_size=hidden_size,
                                       num_heads=num_heads,
                                       attention_dropout=attention_dropout,
                                       dropout=dropout,
@@ -81,7 +83,7 @@ def __init__(self,
                                       dtype=dtype,
                                       init_method=init_method)
 
-        self.attn2 = DeTrCrossAttention(hidden_size=hidden_size,
+        self.attn2 = DeTrAttention(hidden_size=hidden_size,
                                         num_heads=num_heads,
                                         attention_dropout=attention_dropout,
                                         dropout=dropout,
@@ -99,8 +101,11 @@ def __init__(self,
                           bias=bias,
                           init_method=init_method)
 
-    def _forward(self, x, memory):
-        x = x + self.drop_path(self.norm1(self.attn1(x)))
-        x = x + self.drop_path(self.norm2(self.attn2(x, memory)))
+    def _forward(self, x, memory, self_attn_mask=None, self_attn_key_padding_mask=None, multihead_attn_mask=None, multihead_attn_key_padding_mask=None):
+        # input dimension [b,s,h] [q,s,h]
+        x = x.transpose(0,1)
+        memory = memory.transpose(0,1)
+        x = x + self.drop_path(self.norm1(self.attn1(x, x, x, attn_mask=self_attn_mask, key_padding_mask=self_attn_key_padding_mask)))
+        x = x + self.drop_path(self.norm2(self.attn2(x, memory, memory, attn_mask=multihead_attn_mask, key_padding_mask=multihead_attn_key_padding_mask)))
         x = x + self.drop_path(self.mlp(self.norm3(x)))
-        return x
+        return x.transpose(0,1)
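
For reference, a minimal usage sketch of the renamed DeTrAttention follows. It is not part of the diff: the launch config, the torchrun entry point, the hidden size, head count and tensor shapes are illustrative assumptions, and the constructor is called only with the arguments that the detr_block.py call sites pass explicitly (bias, dtype and init_method are assumed to keep their defaults).

    import colossalai
    import torch

    from titans.layer.attention import DeTrAttention

    # Assumed single-process launch (start the script with torchrun); a 1D
    # tensor-parallel group of size 1 is configured so that the Linear1D_Col /
    # Linear1D_Row layers used inside DeTrAttention have a parallel context.
    colossalai.launch_from_torch(config=dict(parallel=dict(tensor=dict(size=1, mode='1d'))))

    # Illustrative sizes; attention_dropout and dropout as in the block configs.
    attn = DeTrAttention(hidden_size=256, num_heads=8, attention_dropout=0.1, dropout=0.1).cuda()

    x = torch.rand(2, 100, 256).cuda()       # (batch, query length, hidden), e.g. object queries
    memory = torch.rand(2, 576, 256).cuda()  # (batch, source length, hidden), e.g. encoder output

    out = attn(x, x, x)              # self-attention: q, k and v share one source
    out = attn(x, memory, memory)    # cross-attention: k and v come from the memory

Because the query, key and value projections now run on separate inputs, one module covers both the self-attention and the former DeTrCrossAttention role, which is why detr_block.py instantiates DeTrAttention for attn, attn1 and attn2.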