diff --git a/gpt.py b/gpt.py index e4fc68d6..2fcb2a0f 100644 --- a/gpt.py +++ b/gpt.py @@ -96,11 +96,10 @@ def __init__(self, num_heads, head_size): super().__init__() self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)]) self.proj = nn.Linear(head_size * num_heads, n_embd) - self.dropout = nn.Dropout(dropout) def forward(self, x): out = torch.cat([h(x) for h in self.heads], dim=-1) - out = self.dropout(self.proj(out)) + out = self.proj(out) return out class FeedFoward(nn.Module): @@ -111,8 +110,8 @@ def __init__(self, n_embd): self.net = nn.Sequential( nn.Linear(n_embd, 4 * n_embd), nn.ReLU(), - nn.Linear(4 * n_embd, n_embd), nn.Dropout(dropout), + nn.Linear(4 * n_embd, n_embd), ) def forward(self, x): @@ -129,10 +128,12 @@ def __init__(self, n_embd, n_head): self.ffwd = FeedFoward(n_embd) self.ln1 = nn.LayerNorm(n_embd) self.ln2 = nn.LayerNorm(n_embd) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) def forward(self, x): - x = x + self.sa(self.ln1(x)) - x = x + self.ffwd(self.ln2(x)) + x = x + self.dropout1(self.sa(self.ln1(x))) + x = x + self.dropout2(self.ffwd(self.ln2(x))) return x class GPTLanguageModel(nn.Module):