diff --git a/MoD/MoD.py b/MoD/MoD.py
index a128020..27cc0ed 100644
--- a/MoD/MoD.py
+++ b/MoD/MoD.py
@@ -7,7 +7,7 @@ class MoDTransformerBlock(nn.Module):
     """Wrapper class for integrating a transformer block with Mixture-of-Depths routing.
 
     Attributes:
-        transformer_block (nn.Module): Transformer block to be wrapped.
+        transformer_block (...): Transformer block to be wrapped.
         router_mlp (nn.Linear): MLP layer for calculating router weights.
         aux_mlp (nn.Linear): MLP layer for calculating auxiliary routing decision.
         capacity (float): Capacity of the mixture-of-depths routing. Default is 0.125.
@@ -56,6 +56,10 @@ def forward(self, x: torch.Tensor):
 
         Returns:
             torch.Tensor: Output tensor of shape (batch_size, sequence_length, hidden_size).
+
+        Note:
+            Since we just want to demonstrate the MoD concept, we don't consider the extra
+            parameters of the transformer block.
         """
         B, sequence_length, _ = x.shape
 
@@ -109,26 +113,54 @@ def forward(self, x: torch.Tensor):
         # [Shape] output: (batch_size, sequence_length, hidden_size)
         output = x.clone()
 
-        # Assure that the auxiliary decision is a boolean tensor.
-        aux_decision = aux_decision.bool()
+        # In the train stage, we use the real top-k decision.
+        # In the eval stage, we use the auxiliary router prediction decision.
+        topk_decision = (
+            aux_targets.bool() if self.training else aux_decision.bool()
+        )
 
+        # TODO[keli]: How to enable batch processing for the following loop?
         for b in range(B):
             # Extract tokens and router that need to go through the transformer block.
-            # [Shape] selected_tokens_emb: (selected_tokens_count, hidden_size)
-            selected_tokens_emb = x[b, aux_decision[b]]
+            # `unsqueeze(0)` is used to add the batch dimension back.
+            # [Shape] selected_tokens_emb: (1, selected_tokens_count, hidden_size)
+            selected_tokens_emb = (x[b, topk_decision[b]]).unsqueeze(0)
             # [Shape] selected_router_weights: (selected_tokens_count, 1)
             selected_router_weights = router_weights[
-                b, aux_decision[b]
+                b, topk_decision[b]
             ].unsqueeze(-1)
 
-            if selected_tokens_emb.shape[0] > 0:
+            if selected_tokens_emb.shape[1] > 0:
                 # Apply the transformer block to the selected tokens.
+                # [Shape] transformer_tokens_emb: (selected_tokens_count, hidden_size)
                 transformer_tokens_emb = (
                     self.transformer_block(selected_tokens_emb)
                     * selected_router_weights
-                )
+                ).squeeze(0)
 
                 # Scatter the tokens into output according to the auxiliary decision.
-                output[b, aux_decision[b]] = transformer_tokens_emb
+                output[b, topk_decision[b]] = transformer_tokens_emb
 
         return output
+
+
+if __name__ == "__main__":
+    # Set the seed for reproducibility.
+    torch.manual_seed(42)
+
+    # Define the transformer block (batch_first so it expects (batch, seq, hidden)).
+    transformer_block = nn.TransformerEncoderLayer(d_model=512, nhead=1, batch_first=True)
+
+    # Wrap the transformer block with MoD.
+    mod_transformer_block = MoDTransformerBlock(
+        transformer_block, hidden_size=512, capacity=0.125
+    )
+
+    # Input tensor.
+    # [Shape] x: (batch_size, sequence_length, hidden_size)
+    x = torch.rand(2, 20, 512)
+
+    # Forward pass.
+    output = mod_transformer_block(x)
+
+    print(output.shape)
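
On the `TODO[keli]` in the last hunk: below is a minimal sketch (not part of this diff) of one way to vectorize the gather/transform/scatter loop with `torch.topk`, `torch.gather`, and `Tensor.scatter_`. It assumes a fixed per-row token budget `k = int(capacity * sequence_length)` (which holds for a top-k router, since every row then keeps exactly `k` tokens) and a `batch_first` transformer block; `mod_forward_batched` is a hypothetical helper name, not code from this PR.

import torch
import torch.nn as nn


def mod_forward_batched(
    transformer_block: nn.Module,
    x: torch.Tensor,               # (B, S, H)
    router_weights: torch.Tensor,  # (B, S)
    k: int,                        # tokens kept per row, int(capacity * S)
) -> torch.Tensor:
    """Hypothetical batched variant of the per-example routing loop."""
    B, S, H = x.shape
    output = x.clone()

    # Every row keeps exactly k tokens, so top-k yields rectangular indices.
    # [Shape] topk_weights, topk_indices: (B, k)
    topk_weights, topk_indices = torch.topk(router_weights, k, dim=1)

    # Gather the selected token embeddings in one shot.
    # [Shape] gather_idx: (B, k, H); selected: (B, k, H)
    gather_idx = topk_indices.unsqueeze(-1).expand(-1, -1, H)
    selected = torch.gather(x, dim=1, index=gather_idx)

    # Single batched pass through the block, scaled by the router weights.
    transformed = transformer_block(selected) * topk_weights.unsqueeze(-1)

    # Scatter the transformed tokens back to their original positions.
    output.scatter_(dim=1, index=gather_idx, src=transformed)
    return output

Because every row selects the same `k`, the ragged-selection problem disappears and the transformer block runs once per batch instead of once per example.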
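One behavioral consequence of the new `topk_decision` switch worth calling out: the same module routes with the real top-k targets (`aux_targets`) in training mode and with the auxiliary router's predictions in eval mode, so the same input can produce different outputs across modes. A small, hypothetical extension of the `__main__` demo to make that visible:

# Training mode: routing uses the real top-k decision (aux_targets).
mod_transformer_block.train()
train_out = mod_transformer_block(x)

# Eval mode: routing uses the auxiliary router's predicted decision.
mod_transformer_block.eval()
with torch.no_grad():
    eval_out = mod_transformer_block(x)

# Outputs generally differ: both the routing branch and dropout inside the
# wrapped block depend on the train/eval mode.
print(torch.allclose(train_out, eval_out))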