
Commit

Enable using aux_target in the train stage and add a small example in the main scope.

keli-wen committed May 29, 2024
1 parent b2ba5f9 commit 099d03f
Showing 1 changed file with 41 additions and 9 deletions.
50 changes: 41 additions & 9 deletions MoD/MoD.py
@@ -7,7 +7,7 @@ class MoDTransformerBlock(nn.Module):
"""Wrapper class for integrating a transformer block with Mixture-of-Depths routing.
Attributes:
transformer_block (nn.Module): Transformer block to be wrapped.
transformer_block (...): Transformer block to be wrapped.
router_mlp (nn.Linear): MLP layer for calculating router weights.
aux_mlp (nn.Linear): MLP layer for calculating auxiliary routing decision.
capacity (float): Capacity of the mixture-of-depths routing. Default is 0.125.
@@ -56,6 +56,10 @@ def forward(self, x: torch.Tensor):
        Returns:
            torch.Tensor: Output tensor of shape (batch_size, sequence_length, hidden_size).
+
+        Note:
+            Since we just want to demonstrate the MoD concept, we don't consider the extra parameters
+            for the transformer block.
        """
        B, sequence_length, _ = x.shape

@@ -109,26 +113,54 @@ def forward(self, x: torch.Tensor):
        # [Shape] output: (batch_size, sequence_length, hidden_size)
        output = x.clone()

-        # Assure that the auxiliary decision is a boolean tensor.
-        aux_decision = aux_decision.bool()
+        # In the train stage, we use the real top-k decision.
+        # In the eval stage, we use the auxiliary router prediction decision.
+        topk_decision = (
+            aux_targets.bool() if self.training else aux_decision.bool()
+        )

        # TODO[keli]: How to enable batch processing for the following loop?
        for b in range(B):
            # Extract tokens and router that need to go through the transformer block.
-            # [Shape] selected_tokens_emb: (selected_tokens_count, hidden_size)
-            selected_tokens_emb = x[b, aux_decision[b]]
+            # `unsqueeze(0)` is used to add the batch dimension back.
+            # [Shape] selected_tokens_emb: (1, selected_tokens_count, hidden_size)
+            selected_tokens_emb = (x[b, topk_decision[b]]).unsqueeze(0)
            # [Shape] selected_router_weights: (selected_tokens_count, 1)
            selected_router_weights = router_weights[
-                b, aux_decision[b]
+                b, topk_decision[b]
            ].unsqueeze(-1)

-            if selected_tokens_emb.shape[0] > 0:
+            if selected_tokens_emb.shape[1] > 0:
                # Apply the transformer block to the selected tokens.
                # [Shape] transformer_tokens_emb: (selected_tokens_count, hidden_size)
                transformer_tokens_emb = (
                    self.transformer_block(selected_tokens_emb)
                    * selected_router_weights
-                )
+                ).squeeze(0)

                # Scatter the tokens into output according to the auxiliary decision.
-                output[b, aux_decision[b]] = transformer_tokens_emb
+                output[b, topk_decision[b]] = transformer_tokens_emb

        return output


+if __name__ == "__main__":
+    # Set the seed for reproducibility.
+    torch.manual_seed(42)
+
+    # Define the transformer block.
+    transformer_block = nn.TransformerEncoderLayer(d_model=512, nhead=1)
+
+    # Wrap the transformer block with MoD.
+    mod_transformer_block = MoDTransformerBlock(
+        transformer_block, hidden_size=512, capacity=0.125
+    )
+
+    # Input tensor.
+    # [Shape] x: (batch_size, sequence_length, hidden_size)
+    x = torch.rand(2, 20, 512)
+
+    # Forward pass.
+    output = mod_transformer_block(x)
+
+    print(output.shape)
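
The diff above uses `aux_targets` without showing where it is computed; that part of MoD.py is elided from this hunk. In the Mixture-of-Depths setup, the training-time target is the boolean top-k mask over the router weights (a capacity fraction of the sequence), and the auxiliary MLP is trained to predict that mask so it can be applied causally at inference time. The snippet below is a minimal, self-contained sketch of that relationship; the tensor names, the sigmoid, and the 0.5 threshold are assumptions for illustration, not code from this repository.

```python
import torch

# Hypothetical stand-ins for the outputs of router_mlp and aux_mlp; these tensors
# are illustrative only and do not come from MoD.py.
B, S, capacity = 2, 20, 0.125
router_weights = torch.rand(B, S)
aux_logits = torch.randn(B, S)

k = max(1, int(capacity * S))

# Training-time target: a boolean mask marking the k highest router weights per sample.
# [Shape] aux_targets: (B, S)
_, topk_idx = torch.topk(router_weights, k, dim=1)
aux_targets = torch.zeros(B, S, dtype=torch.bool).scatter_(1, topk_idx, True)

# Eval-time decision: the auxiliary router's prediction of that mask (threshold assumed).
# [Shape] aux_decision: (B, S)
aux_decision = torch.sigmoid(aux_logits) > 0.5

# The selection rule introduced by this commit, paraphrased:
training = True
topk_decision = aux_targets if training else aux_decision
print(topk_decision.shape, topk_decision.sum(dim=1))  # each row selects exactly k tokens when training
```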
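
The `TODO[keli]` comment in the hunk above asks how to batch the per-sample loop. One way to do it, sketched under assumptions rather than taken from this repository: in the training-time top-k path every sample keeps exactly `k = capacity * sequence_length` tokens, so a single `topk`/`gather`/`scatter_` round trip can replace the Python loop. The helper name `mod_forward_batched` is hypothetical, and the wrapped block is assumed to accept `(batch, seq, hidden)` inputs (e.g. built with `batch_first=True`).

```python
import torch
import torch.nn as nn


def mod_forward_batched(
    x: torch.Tensor,               # [Shape] (batch_size, sequence_length, hidden_size)
    router_weights: torch.Tensor,  # [Shape] (batch_size, sequence_length)
    transformer_block: nn.Module,  # assumed to take (batch, seq, hidden) inputs
    capacity: float = 0.125,
) -> torch.Tensor:
    B, S, H = x.shape
    k = max(1, int(capacity * S))

    # Top-k token indices per sample; every sample keeps exactly k tokens.
    # [Shape] topk_idx: (B, k)
    _, topk_idx = torch.topk(router_weights, k, dim=1)
    idx_exp = topk_idx.unsqueeze(-1).expand(-1, -1, H)

    # Gather the selected token embeddings for the whole batch at once.
    # [Shape] selected: (B, k, H)
    selected = torch.gather(x, 1, idx_exp)

    # Run the wrapped block once and rescale by the selected router weights.
    # [Shape] processed: (B, k, H)
    processed = transformer_block(selected) * torch.gather(
        router_weights, 1, topk_idx
    ).unsqueeze(-1)

    # Scatter processed tokens back; unselected tokens keep their residual values.
    output = x.clone()
    output.scatter_(1, idx_exp, processed)
    return output


if __name__ == "__main__":
    torch.manual_seed(42)
    block = nn.TransformerEncoderLayer(d_model=512, nhead=1, batch_first=True)
    x = torch.rand(2, 20, 512)
    router_weights = torch.rand(2, 20)
    print(mod_forward_batched(x, router_weights, block).shape)  # torch.Size([2, 20, 512])
```

At eval time the auxiliary decision may select a different number of tokens per sample, so the per-sample loop (or padding/masking) is still needed there; this sketch only covers the fixed-k training path.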
