diff --git a/train_gpt2.py b/train_gpt2.py
index b9dee8701..ad8cfee3f 100644
--- a/train_gpt2.py
+++ b/train_gpt2.py
@@ -146,6 +146,7 @@ def __init__(self, config):
         self.init_rng.manual_seed(42)
         self.apply(self._init_weights)

+    @torch.no_grad()
     def _init_weights(self, module):
         if isinstance(module, nn.Linear):
             # apply special scaled init to the residual projections, per GPT-2 paper
@@ -153,11 +154,11 @@ def _init_weights(self, module):
             # we want to skip initializing lm_head, which shares parameters with wte
             # and wte was already initialized down below during the Embedding init
             if not hasattr(module, 'LLMC_SKIP_INIT'):
-                torch.nn.init.normal_(module.weight, mean=0.0, std=std, generator=self.init_rng)
+                module.weight.normal_(mean=0.0, std=std, generator=self.init_rng)
             if module.bias is not None:
                 torch.nn.init.zeros_(module.bias)
         elif isinstance(module, nn.Embedding):
-            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02, generator=self.init_rng)
+            module.weight.normal_(mean=0.0, std=0.02, generator=self.init_rng)

     def forward(self, idx, targets=None, return_logits=True):
         device = idx.device
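
For context, a minimal standalone sketch of the pattern the diff moves to: a dedicated, seeded torch.Generator plus in-place Tensor.normal_(..., generator=...) calls executed under torch.no_grad(). This uses only the public PyTorch API; TinyModel and its layer sizes are illustrative and not part of the patch.

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 16)
        self.emb = nn.Embedding(32, 16)
        # dedicated RNG so init is reproducible regardless of global seed state
        self.init_rng = torch.Generator()
        self.init_rng.manual_seed(42)
        self.apply(self._init_weights)

    @torch.no_grad()  # in-place writes to leaf parameters require no_grad
    def _init_weights(self, module):
        # Tensor.normal_ accepts a generator= argument, so the draw can be tied
        # to the module's own RNG rather than the global one
        if isinstance(module, nn.Linear):
            module.weight.normal_(mean=0.0, std=0.02, generator=self.init_rng)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            module.weight.normal_(mean=0.0, std=0.02, generator=self.init_rng)

model = TinyModel()  # weights come out identical across runs for the same seed

Without the no_grad decorator, the in-place .normal_ call on a parameter that requires grad would raise an autograd error, which is presumably why the decorator is added alongside the switch to Tensor.normal_.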