Description
Describe the issue
Hi,
Does it make sense that each epoch takes about 30 minutes when adapting the 35M-parameter pretrained TiRex model (kept frozen, as in the code under "To reproduce")? Are there any CUDA kernel requirements during adaptation?
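For reference, a minimal timing sketch (plain PyTorch, nothing TiRex-specific; the batch shape, device placement and iteration count are placeholders, and config / MoETransformerBlock refer to the code under "To reproduce" below) that should show whether the frozen TiRex forward pass dominates each step:

import time

import torch

def avg_ms(fn, iters=20):
    # Synchronize before and after so GPU work is included in the wall-clock time.
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters * 1e3

block = MoETransformerBlock(config).cuda()
x = torch.randn(32, 128, 64, device="cuda")  # (batch, patches, d_model) - placeholder sizes

with torch.no_grad():
    backbone = avg_ms(lambda: block.pretrained_model._forward_hidden(x))
    full = avg_ms(lambda: block(x, training=False))
print(f"backbone only: {backbone:.1f} ms, full block: {full:.1f} ms")

If the backbone call already accounts for most of the time, the 30-minute epochs are probably dominated by the frozen TiRex forward passes; note that each of the config.num_layers blocks below instantiates and runs its own TiRex backbone, so that cost is paid once per layer per step.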
To reproduce
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from tirex import ForecastModel, load_model

# MoiraiMoEConfig, SparseMoELayer, PatchEmbedding, RMSNorm and
# MixtureDistributionHead are defined elsewhere in my code (omitted here).


class MoETransformerBlock(nn.Module):
    def __init__(self, config: MoiraiMoEConfig):
        super().__init__()
        # Frozen TiRex backbone; only the adaptor and MoE layers are trained.
        model: ForecastModel = load_model("NX-AI/TiRex")
        self.pretrained_model = model
        for param in self.pretrained_model.parameters():
            param.requires_grad = False
        # Project TiRex hidden states (512) down to the block width (64).
        self.adaptor = nn.Linear(512, 64)
        # MoE layer
        self.moe = SparseMoELayer(
            d_model=config.d_model,
            d_ff=config.d_ff,
            num_experts=config.num_experts,
            top_k=config.num_experts_per_tok,
            use_shared_expert=config.use_shared_expert,
            use_centroid_gating=config.use_centroid_gating,
            dropout=config.dropout
        )
        self.dropout = nn.Dropout(config.dropout)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        training: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        residual = x
        x, all_hidden_states = self.pretrained_model._forward_hidden(x)
        x = self.adaptor(self.dropout(x)) + residual
        residual = x
        x, aux_loss = self.moe(x, training)
        x = self.dropout(x) + residual
        return x, aux_loss


class MoiraiMoEModule(nn.Module):
    def __init__(self, config: MoiraiMoEConfig):
        super().__init__()
        self.config = config
        # Patch embedding
        self.patch_embed = PatchEmbedding(
            patch_size=config.patch_size,
            d_model=64,
            dropout=config.dropout
        )
        self.layers = nn.ModuleList([
            MoETransformerBlock(config) for _ in range(config.num_layers)
        ])
        self.final_norm = RMSNorm(config.d_model)
        # Output distribution head
        self.output_head = MixtureDistributionHead(
            d_model=config.d_model,
            patch_size=config.patch_size,
            num_components=config.num_distributions
        )
        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_aux_loss: bool = True
    ) -> Union[dict, Tuple[dict, torch.Tensor]]:
        # Instance normalization
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True) + 1e-6
        x = (x - mean) / std
        # Patch embedding
        x = self.patch_embed(x)
        # Pass through transformer layers
        total_aux_loss = 0.0
        training = self.training
        for layer in self.layers:
            x, aux_loss = layer(x, attention_mask, training)
            total_aux_loss += aux_loss
        # Final normalization
        x = self.final_norm(x)
        # Get distribution parameters
        output_params = self.output_head(x)
        # Store normalization stats for denormalization
        output_params['norm_mean'] = mean
        output_params['norm_std'] = std
        if return_aux_loss:
            return output_params, total_aux_loss * self.config.aux_loss_coef
        return output_params

Platform
Linux
OS Version
22.04
Tirex Version
latest
Architecture
ARM64
Tirex Backend
cuda
Tirex Device Type
cuda
CUDA Version
12.1
Priority
Critical