-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
55 lines (45 loc) · 1.07 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import fire
import pickle
from model import GPT, ModelConfig
import torch
import tiktoken
import dataclasses
# Module-level GPT-2 BPE tokenizer from tiktoken, shared by encode() and decode().
tokenizer = tiktoken.get_encoding('gpt2')


def decode(tokens):
    """Return the text represented by the GPT-2 token ids in *tokens*."""
    text = tokenizer.decode(tokens)
    return text


def encode(prompt):
    """Return the GPT-2 token ids for *prompt*.

    Uses ``encode_ordinary``, so special-token strings appearing in *prompt*
    are not mapped to special token ids.
    """
    token_ids = tokenizer.encode_ordinary(prompt)
    return token_ids
def main(path, device='cuda:0', max_new_tokens=25):
    """Run an interactive text-generation loop against a GPT checkpoint.

    Builds a GPT with fixed hyperparameters, loads the weights stored under
    the ``'model'`` key of the checkpoint at *path*, then repeatedly reads a
    prompt from stdin, generates a continuation and prints both the raw
    ``generate`` output and its decoded text. Typing ``q`` (or sending
    EOF / Ctrl-C at the prompt) exits.

    Args:
        path: Checkpoint file readable by ``torch.load`` that holds a
            ``'model'`` state dict for ``GPT``.
        device: Torch device string used for the config, checkpoint loading
            and the model itself.
        max_new_tokens: Second argument passed to ``model.generate`` for each
            prompt (the original script hard-coded 25, kept as the default).
    """
    # Architecture hyperparameters. NOTE(review): hard-coded; they must match
    # the run that produced the checkpoint, otherwise load_state_dict will
    # typically reject the weights because of shape mismatches.
    vocab_size = 50304
    embedding_dim = 768
    block_size = 1024
    n_layers = 24
    internal_dim = 3072
    n_heads = 12
    dropout = 0.0
    bias = False
    # Passed positionally, so this order must match ModelConfig's parameter
    # order (defined in model.py, not visible here).
    config = ModelConfig(
        vocab_size,
        embedding_dim,
        block_size,
        n_layers,
        internal_dim,
        n_heads,
        dropout,
        device,
        bias,
    )

    # NOTE(security): torch.load unpickles the file (unless weights_only is
    # in effect for the installed torch version), so only load checkpoints
    # from trusted sources.
    checkpoint = torch.load(path, map_location=device)
    model = GPT(config)
    model.load_state_dict(checkpoint['model'])
    model.to(device)
    # Inference only: disable training-mode behaviour such as dropout. Dropout
    # is 0.0 above, but this keeps generation correct if that changes.
    model.eval()

    while True:
        try:
            prompt = input('Prompt: ')
        except (EOFError, KeyboardInterrupt):
            # Treat Ctrl-D / Ctrl-C at the prompt like 'q' rather than
            # exiting with a traceback; print() moves to a fresh line.
            print()
            break
        if prompt == 'q':
            break
        prompt_tokens = encode(prompt)
        if not prompt_tokens:
            # An empty prompt gives the model no context to condition on.
            continue
        # No gradients are needed for generation; avoid building autograd state.
        with torch.no_grad():
            out = model.generate(prompt_tokens, max_new_tokens)
        print(out)
        print(decode(out))


if __name__ == '__main__':
    # fire maps CLI arguments onto main(): positional `path`, plus optional
    # --device and --max_new_tokens flags.
    fire.Fire(main)