FIX TypeError: normal_() got an unexpected keyword argument 'generator'
earlytobed committed Dec 24, 2024
1 parent 7ecd890 commit ab40bdf
Showing 1 changed file with 3 additions and 2 deletions.
train_gpt2.py (5 changes: 3 additions & 2 deletions)
@@ -146,18 +146,19 @@ def __init__(self, config):
         self.init_rng.manual_seed(42)
         self.apply(self._init_weights)
 
+    @torch.no_grad()
     def _init_weights(self, module):
         if isinstance(module, nn.Linear):
             # apply special scaled init to the residual projections, per GPT-2 paper
             std = 0.02 if not hasattr(module, 'LLMC_RESIDUAL_SCALE_FLAG') else 0.02/math.sqrt(2 * self.config.n_layer)
             # we want to skip initializing lm_head, which shares parameters with wte
             # and wte was already initialized down below during the Embedding init
             if not hasattr(module, 'LLMC_SKIP_INIT'):
-                torch.nn.init.normal_(module.weight, mean=0.0, std=std, generator=self.init_rng)
+                module.weight.normal_(mean=0.0, std=std, generator=self.init_rng)
             if module.bias is not None:
                 torch.nn.init.zeros_(module.bias)
         elif isinstance(module, nn.Embedding):
-            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02, generator=self.init_rng)
+            module.weight.normal_(mean=0.0, std=0.02, generator=self.init_rng)
 
     def forward(self, idx, targets=None, return_logits=True):
         device = idx.device
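For context: torch.nn.init.normal_ only accepts a generator keyword in recent PyTorch releases, so on older versions the original call raises the TypeError named in the commit title. The tensor method Tensor.normal_ has taken a generator argument for much longer, but an in-place write to a parameter (a leaf tensor with requires_grad=True) must run under torch.no_grad(), which is why the decorator is added alongside the two call-site changes. A minimal standalone sketch of the portable pattern (illustrative only, not the repository's code; the exact version in which the init functions gained generator support is not pinned down here):

import torch
import torch.nn as nn

# seeded generator for reproducible init, mirroring self.init_rng in the commit
rng = torch.Generator()
rng.manual_seed(42)

linear = nn.Linear(4, 4)

# On older PyTorch this raises:
#   TypeError: normal_() got an unexpected keyword argument 'generator'
# torch.nn.init.normal_(linear.weight, mean=0.0, std=0.02, generator=rng)

# Portable alternative: the Tensor method accepts a generator, but the
# in-place write to a grad-requiring leaf tensor must happen under no_grad.
with torch.no_grad():
    linear.weight.normal_(mean=0.0, std=0.02, generator=rng)
    if linear.bias is not None:
        linear.bias.zero_()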
