[Question]: Long training time of adapter networks / heads #28

@moghadas76


Describe the issue

Hi,

does it make sense that each epoch takes 30 minutes when training only adapter layers on top of the frozen 35M-parameter pretrained TiRex model? Are there any CUDA kernel requirements during adaptation?
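
For reference, here is a quick way to check where the time goes: benchmark a single frozen forward pass in isolation. This is a minimal sketch; the batch shape is a placeholder, and it assumes the loaded model behaves like a standard nn.Module (the reproduction code below iterates over its .parameters(), so it should):

import time

import torch
from tirex import load_model

model = load_model("NX-AI/TiRex")
# (move the model with .to("cuda") first if load_model does not place it there)

# Placeholder batch; adjust shape/device to match the real training input
x = torch.randn(32, 2048, device="cuda")

with torch.no_grad():
    model._forward_hidden(x)        # warm-up (kernel selection, caching)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(10):
        model._forward_hidden(x)
    torch.cuda.synchronize()

print(f"frozen TiRex forward: {(time.perf_counter() - t0) / 10 * 1e3:.1f} ms")

If this number, multiplied by num_layers and by the number of batches per epoch, already accounts for most of the 30 minutes, the backbone rather than the adapters or MoE layers is the bottleneck.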

To reproduce

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from tirex import ForecastModel, load_model

# MoiraiMoEConfig, SparseMoELayer, PatchEmbedding, RMSNorm, and
# MixtureDistributionHead are project-local modules (not shown here).


class MoETransformerBlock(nn.Module):

    def __init__(self, config: MoiraiMoEConfig):
        super().__init__()

        # NOTE: each block loads its own full copy of the pretrained TiRex
        # model and runs it as a frozen feature extractor
        model: ForecastModel = load_model("NX-AI/TiRex")
        self.pretrained_model = model
        for param in self.pretrained_model.parameters():
            param.requires_grad = False
        # Adapter: project the TiRex hidden size (512) down to d_model (64)
        self.adaptor = nn.Linear(512, 64)
        # MoE layer
        self.moe = SparseMoELayer(
            d_model=config.d_model,
            d_ff=config.d_ff,
            num_experts=config.num_experts,
            top_k=config.num_experts_per_tok,
            use_shared_expert=config.use_shared_expert,
            use_centroid_gating=config.use_centroid_gating,
            dropout=config.dropout
        )
        
        self.dropout = nn.Dropout(config.dropout)
    
    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        training: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        residual = x
        # Frozen backbone forward (private TiRex API); returns the final
        # hidden state plus all intermediate hidden states
        x, all_hidden_states = self.pretrained_model._forward_hidden(x)
        x = self.adaptor(self.dropout(x)) + residual
        
        residual = x
        # Sparse MoE sublayer; aux_loss is its load-balancing term
        x, aux_loss = self.moe(x, training)
        x = self.dropout(x) + residual
        
        return x, aux_loss


class MoiraiMoEModule(nn.Module):
    
    def __init__(self, config: MoiraiMoEConfig):
        super().__init__()
        self.config = config
        
        # Patch embedding (d_model hard-coded to 64 here; must match config.d_model)
        self.patch_embed = PatchEmbedding(
            patch_size=config.patch_size,
            d_model=64,
            dropout=config.dropout
        )
        
        self.layers = nn.ModuleList([
            MoETransformerBlock(config) for _ in range(config.num_layers)
        ])
        
        self.final_norm = RMSNorm(config.d_model)
        
        # Output distribution head
        self.output_head = MixtureDistributionHead(
            d_model=config.d_model,
            patch_size=config.patch_size,
            num_components=config.num_distributions
        )
        
        # Initialize weights
        self.apply(self._init_weights)
    
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    
    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_aux_loss: bool = True
    ) -> Union[dict, Tuple[dict, torch.Tensor]]:
      
        # Instance normalization
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True) + 1e-6
        x = (x - mean) / std
        
        # Patch embedding
        x = self.patch_embed(x)
        
        # Pass through transformer layers
        total_aux_loss = 0.0
        training = self.training
        
        for layer in self.layers:
            x, aux_loss = layer(x, attention_mask, training)
            total_aux_loss += aux_loss
        
        # Final normalization
        x = self.final_norm(x)
        
        # Get distribution parameters
        output_params = self.output_head(x)
        
        # Store normalization stats for denormalization
        output_params['norm_mean'] = mean
        output_params['norm_std'] = std
        
        if return_aux_loss:
            return output_params, total_aux_loss * self.config.aux_loss_coef
        return output_params
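
One thing worth noting about the timing: as written, each of the config.num_layers blocks loads and executes its own copy of the frozen 35M backbone, so every training batch runs num_layers full TiRex forwards, and autograd traces through all of them because gradients have to flow back to patch_embed. A common adapter pattern instead runs the frozen backbone once per batch under torch.no_grad() and stacks the trainable layers on its hidden states. Below is a minimal sketch of that pattern, not a drop-in replacement: it feeds the model input straight to the backbone (no patch embedding first) and reuses the SparseMoELayer and MoiraiMoEConfig classes from the snippet above.

class SharedBackboneModule(nn.Module):
    """Sketch: one frozen backbone shared by all trainable layers."""

    def __init__(self, config: MoiraiMoEConfig):
        super().__init__()
        self.backbone: ForecastModel = load_model("NX-AI/TiRex")  # loaded once
        for p in self.backbone.parameters():
            p.requires_grad = False
        self.adaptor = nn.Linear(512, config.d_model)
        self.blocks = nn.ModuleList([
            SparseMoELayer(
                d_model=config.d_model,
                d_ff=config.d_ff,
                num_experts=config.num_experts,
                top_k=config.num_experts_per_tok,
                use_shared_expert=config.use_shared_expert,
                use_centroid_gating=config.use_centroid_gating,
                dropout=config.dropout
            )
            for _ in range(config.num_layers)
        ])

    def forward(self, x: torch.Tensor, training: bool = True):
        # Nothing trainable runs before the backbone here, so no autograd
        # graph is needed through the frozen 35M model
        with torch.no_grad():
            h, _ = self.backbone._forward_hidden(x)
        x = self.adaptor(h)
        total_aux_loss = 0.0
        for moe in self.blocks:
            residual = x
            x, aux_loss = moe(x, training)
            x = x + residual
            total_aux_loss += aux_loss
        return x, total_aux_loss

If epochs stay slow even with a single no_grad backbone pass, that would point at the MoE layers or the data pipeline instead.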

Platform

Linux

OS Version

22.04

Tirex Version

latest

Architecture

ARM64

Tirex Backend

cuda

Tirex Device Type

cuda

CUDA Version

12.1

Priority

Critical
