Commit af3a5e1

Updates nanoGPT to not store certain tensors as members
Parent: c15f675

3 files changed: +15 -16 lines

tripy/examples/nanogpt/example.py

+1 -1
@@ -103,7 +103,7 @@ def main():
         1,
         # We can specify dynamic dimensions by using a sequence indicating the min/opt/max values that
         # a dimension should support:
-        [1, len(input_ids), padded_seq_len],
+        (1, len(input_ids), padded_seq_len),
     )
     model = tp.compile(model, args=[tp.InputInfo(input_shape, dtype=tp.int32)])
     compile_end_time = time.perf_counter()
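
The only functional change above is that the min/opt/max bounds for the dynamic sequence dimension are now given as a tuple rather than a list. A minimal sketch of the resulting compile setup, assuming `tripy` is imported as `tp` and that `model`, `input_ids`, and `padded_seq_len` are defined as in example.py (the values below are placeholders):

```python
import tripy as tp

input_ids = [0, 1, 2, 3, 4]   # placeholder token IDs
padded_seq_len = 64           # placeholder upper bound on the sequence length

input_shape = (
    1,  # fixed batch dimension
    # Dynamic dimension described by (min, opt, max) bounds -- now a tuple:
    (1, len(input_ids), padded_seq_len),
)
input_info = tp.InputInfo(input_shape, dtype=tp.int32)
# compiled_model = tp.compile(model, args=[input_info])  # `model` comes from example.py
```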

tripy/examples/nanogpt/model.py

+6 -11
@@ -67,7 +67,6 @@ def __init__(self, config):
             tp.tril(tp.ones((config.block_size, config.block_size), dtype=config.dtype)),
             (1, 1, config.block_size, config.block_size),
         )
-        self.zeros = tp.zeros((1, 1, self.seq_len, self.seq_len), dtype=config.dtype)

     def __call__(self, x: tp.Tensor):
         B, T = x.shape[0:2]
@@ -87,11 +86,7 @@ def __call__(self, x: tp.Tensor):

         k_t = tp.transpose(k, -2, -1)
         att = (q @ k_t) * (1.0 / math.sqrt(self.embedding_size // self.num_heads))
-        att = tp.masked_fill(
-            att,
-            self.bias[:, :, :T, :T] == self.zeros[:, :, :T, :T],
-            float("-inf"),
-        )
+        att = tp.masked_fill(att, self.bias[:, :, :T, :T] == 0, float("-inf"))

         att = tp.softmax(att, dim=-1)

@@ -135,18 +130,18 @@ def __call__(self, x):
 class Transformer(tp.Module):
     def __init__(self, config):
         super().__init__()
+        self.seq_len = config.seq_len
         self.wte = tp.Embedding(config.vocab_size, config.embedding_size, dtype=config.dtype)
         self.wpe = tp.Embedding(config.block_size, config.embedding_size, dtype=config.dtype)
-        self.h = [Block(config) for _ in range(config.num_layers)]
+        self.h = tp.Sequential(*[Block(config) for _ in range(config.num_layers)])
         self.ln_f = tp.LayerNorm(config.embedding_size)
-        self.pos = tp.reshape(tp.arange(0, config.seq_len, dtype=tp.int32), (1, config.seq_len))

     def __call__(self, idx):
         tok_emb = self.wte(idx)  # token embeddings of shape (batch_size, seq_len, embedding_size)
-        pos_emb = self.wpe(self.pos[:, : idx.shape[1]])  # position embeddings of shape (seq_len, embedding_size)
+        pos = tp.unsqueeze(tp.arange(self.seq_len, dtype=tp.int32)[: idx.shape[1]], 0)
+        pos_emb = self.wpe(pos)  # position embeddings of shape (seq_len, embedding_size)
         x = tok_emb + pos_emb  # (batch_size, seq_len, embedding_size)
-        for block in self.h:
-            x = block(x)
+        x = self.h(x)
         x = tp.cast(self.ln_f(tp.cast(x, self.ln_f.dtype)), x.dtype)
         return x
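
Together, the model.py edits remove two tensor members: `self.zeros` (replaced by a literal `0` in the mask comparison) and `self.pos` (now built inside `__call__` from `self.seq_len`), and wrap the blocks in `tp.Sequential`. A rough, free-standing sketch of the new masking step, with the module's attributes passed in as plain arguments (the function and argument names here are assumptions for illustration):

```python
import math
import tripy as tp

def masked_attention_scores(q, k, bias, T, embedding_size, num_heads):
    """Sketch of the post-commit attention masking from model.py.

    `bias` is the (1, 1, block_size, block_size) lower-triangular causal mask
    built in __init__; `T` is the current sequence length.
    """
    k_t = tp.transpose(k, -2, -1)
    att = (q @ k_t) * (1.0 / math.sqrt(embedding_size // num_heads))
    # The stored `self.zeros` member is gone: compare the sliced mask to 0 directly.
    att = tp.masked_fill(att, bias[:, :, :T, :T] == 0, float("-inf"))
    return tp.softmax(att, dim=-1)
```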

tripy/examples/nanogpt/weight_loader.py

+8 -4
@@ -26,16 +26,20 @@ def load_weights_from_hf(model, model_type, dtype):

     tripy_state_dict = model.state_dict()
     # attention biases are initialized in the model based on block size.
-    tripy_keys = [key for key in tripy_state_dict.keys() if not key.endswith(".attn.bias")]
+    tripy_keys = {key for key in tripy_state_dict.keys() if not key.endswith(".attn.bias")}

     # Load huggingface/transformers model
     model_hf = GPT2LMHeadModel.from_pretrained(model_type)
     hf_state_dict = model_hf.state_dict()
     # We ignore some of the keys in the HF checkpoint:
-    hf_keys = [
+    hf_keys = {
         key for key in hf_state_dict.keys() if not key.endswith(".attn.masked_bias") and not key.endswith(".attn.bias")
-    ]
-    assert len(hf_keys) == len(tripy_keys), f"Mismatched keys: {hf_keys} != {tripy_keys}"
+    }
+    assert hf_keys == tripy_keys, (
+        f"Mismatched keys. Note:\n"
+        f"`hf_keys` extra keys: {hf_keys - tripy_keys}\n"
+        f"`tripy_keys` extra keys: {tripy_keys - hf_keys}"
+    )

     # See https://paperswithcode.com/method/weight-tying for details on why we do this:
     hf_state_dict["transformer.wte.weight"] = hf_state_dict["lm_head.weight"]
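
Turning the key collections into sets lets the loader assert exact set equality and, on failure, report which keys are unique to each side instead of only comparing counts. A small self-contained sketch of that check, using made-up key names:

```python
# Hypothetical key sets standing in for the Tripy and HF state dicts:
tripy_keys = {"transformer.wte.weight", "transformer.h.0.attn.c_attn.weight"}
hf_keys = {"transformer.wte.weight", "transformer.h.0.attn.c_attn.weight", "lm_head.weight"}

try:
    assert hf_keys == tripy_keys, (
        f"Mismatched keys. Note:\n"
        f"`hf_keys` extra keys: {hf_keys - tripy_keys}\n"
        f"`tripy_keys` extra keys: {tripy_keys - hf_keys}"
    )
except AssertionError as err:
    print(err)  # with the sets above, names `lm_head.weight` as the extra HF key
```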
