Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 32 additions & 16 deletions train_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,8 +509,19 @@ def forward(self, x: Tensor) -> Tensor:
class CastedLinear(nn.Linear):
    """Linear layer with fp32 master weights and INT8 fake-quantized matmuls.

    Master weights stay in fp32 for optimizer/state quality. At matmul time the
    weights are fake-quantized to a symmetric signed INT8 grid and cast to the
    activation dtype (e.g. bf16). A Straight-Through Estimator (STE) lets
    gradients flow to the fp32 weights as if no quantization had happened.
    """

    def forward(self, x: Tensor) -> Tensor:
        # Simulate INT8 symmetric quantization using a Straight-Through Estimator.
        w = self.weight.float()
        # Per-tensor symmetric scale: the largest |w| maps to integer 127.
        # clamp(min=1e-8) guards against a divide-by-zero on an all-zero weight.
        clip_abs = w.abs().max().clamp(min=1e-8)
        scale = clip_abs / 127.0

        # Quantize onto the signed grid [-127, 127].
        w_q = torch.clamp(torch.round(w / scale), -127, 127)
        # STE: forward value is the dequantized weight (w_q * scale), but the
        # detach() makes the round/clamp invisible to autograd, so the gradient
        # passes through to w unchanged.
        w_fq = (w_q * scale - w).detach() + w

        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return F.linear(x, w_fq.to(x.dtype), bias)


def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
Expand Down Expand Up @@ -671,19 +682,24 @@ def __init__(
self.num_decoder_layers = num_layers - self.num_encoder_layers
self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
self.blocks = nn.ModuleList(
[
Block(
model_dim,
num_heads,
num_kv_heads,
mlp_mult,
rope_base,
qk_gain_init,
)
for i in range(num_layers)
]

self.shared_block = Block(
model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init
)

# self.blocks = nn.ModuleList(
# [
# Block(
# model_dim,
# num_heads,
# num_kv_heads,
# mlp_mult,
# rope_base,
# qk_gain_init,
# )
# for i in range(num_layers)
# ]
# )
self.final_norm = RMSNorm()
self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
if self.lm_head is not None:
Expand All @@ -705,12 +721,12 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:

# First half stores skips; second half reuses them in reverse order.
for i in range(self.num_encoder_layers):
x = self.blocks[i](x, x0)
x = self.shared_block(x, x0)
skips.append(x)
for i in range(self.num_decoder_layers):
if skips:
x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
x = self.blocks[self.num_encoder_layers + i](x, x0)
x = self.shared_block(x, x0)

x = self.final_norm(x).reshape(-1, x.size(-1))
targets = target_ids.reshape(-1)
Expand Down Expand Up @@ -848,7 +864,7 @@ def log0(msg: str, console: bool = True) -> None:
# - untied lm_head (Adam) uses HEAD_LR
# - matrix params in transformer blocks use MATRIX_LR via Muon
# - vectors/scalars use SCALAR_LR via Adam
block_named_params = list(base_model.blocks.named_parameters())
block_named_params = list(base_model.shared_block.named_parameters())
matrix_params = [
p
for name, p in block_named_params
Expand Down