diff --git a/gpt.py b/gpt.py index e4fc68d6..8fed084c 100644 --- a/gpt.py +++ b/gpt.py @@ -162,7 +162,7 @@ def forward(self, idx, targets=None): # idx and targets are both (B,T) tensor of integers tok_emb = self.token_embedding_table(idx) # (B,T,C) - pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C) + pos_emb = self.position_embedding_table.weight[:T] # (T,C) x = tok_emb + pos_emb # (B,T,C) x = self.blocks(x) # (B,T,C) x = self.ln_f(x) # (B,T,C)