From 282ee9e4428f3389cdc57c0074cef82c59a1abd4 Mon Sep 17 00:00:00 2001
From: "Andrei.Aksionov"
Date: Thu, 9 Feb 2023 16:57:11 +0300
Subject: [PATCH] Loss calculation should not permanently change shapes of logits and targets

---
 bigram.py | 4 +---
 gpt.py    | 8 +++-----
 2 files changed, 4 insertions(+), 8 deletions(-)

diff --git a/bigram.py b/bigram.py
index faa3a14a..ba54641e 100644
--- a/bigram.py
+++ b/bigram.py
@@ -74,9 +74,7 @@ def forward(self, idx, targets=None):
             loss = None
         else:
             B, T, C = logits.shape
-            logits = logits.view(B*T, C)
-            targets = targets.view(B*T)
-            loss = F.cross_entropy(logits, targets)
+            loss = F.cross_entropy(logits.view(B*T, C), targets.view(B*T))
 
         return logits, loss
 
diff --git a/gpt.py b/gpt.py
index e4fc68d6..24d69525 100644
--- a/gpt.py
+++ b/gpt.py
@@ -103,7 +103,7 @@ def forward(self, x):
         out = self.dropout(self.proj(out))
         return out
 
-class FeedFoward(nn.Module):
+class FeedForward(nn.Module):
     """ a simple linear layer followed by a non-linearity """
 
     def __init__(self, n_embd):
@@ -126,7 +126,7 @@ def __init__(self, n_embd, n_head):
         super().__init__()
         head_size = n_embd // n_head
         self.sa = MultiHeadAttention(n_head, head_size)
-        self.ffwd = FeedFoward(n_embd)
+        self.ffwd = FeedForward(n_embd)
         self.ln1 = nn.LayerNorm(n_embd)
         self.ln2 = nn.LayerNorm(n_embd)
 
@@ -172,9 +172,7 @@ def forward(self, idx, targets=None):
             loss = None
         else:
             B, T, C = logits.shape
-            logits = logits.view(B*T, C)
-            targets = targets.view(B*T)
-            loss = F.cross_entropy(logits, targets)
+            loss = F.cross_entropy(logits.view(B*T, C), targets.view(B*T))
 
         return logits, loss
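
Notes (not part of the patch): a minimal sketch of why the change matters. When forward() rebinds logits and targets to their flattened views, the caller receives a (B*T, C) tensor instead of the original (B, T, C) one; passing the views directly to F.cross_entropy keeps the returned shapes intact while producing the same loss. The tensor values and shapes below are illustrative assumptions, not code from the repository.

    import torch
    import torch.nn.functional as F

    B, T, C = 2, 4, 8                   # example batch, time, vocab-size values
    logits = torch.randn(B, T, C)
    targets = torch.randint(0, C, (B, T))

    # Old behaviour: rebinding the name changes what the caller would receive.
    logits_flat = logits.view(B*T, C)
    loss_old = F.cross_entropy(logits_flat, targets.view(B*T))

    # New behaviour: views are passed only to the loss; `logits` keeps (B, T, C).
    loss_new = F.cross_entropy(logits.view(B*T, C), targets.view(B*T))

    assert logits.shape == (B, T, C)    # shape preserved for the caller
    assert torch.allclose(loss_old, loss_new)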