run_inference.py
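"""Evaluate a trained CharCNNLSTM classifier on the held-out test set.

Loads the vocabulary and model checkpoint from disk, batches the test
corpus, and prints the mean per-batch accuracy.
"""
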
import torch
from efficient.model import CharCNNLSTMModel, Dataset
from efficient.utils import read_corpus, read_labels
from efficient.vocab import Vocab


def infer(model_path, vocab_path, test_contents_path, test_label_path, **model_config):
    """Load a trained checkpoint and report accuracy on the test set."""
    vocab = Vocab.load(vocab_path)
    test_contents = read_corpus(test_contents_path)
    test_labels = read_labels(test_label_path)
    test_dataset = Dataset(
        test_contents, test_labels, vocab, model_config.get("max_word_length"), "cpu"
    )

    predictor = CharCNNLSTMModel(vocab, **model_config)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    predictor.model.load_state_dict(
        torch.load(model_path, map_location=torch.device(device))
    )
    batch_size = 20
    accuracies = []
    # Step through the test set in batches of batch_size; a final partial
    # batch (at most batch_size - 1 examples) is dropped.
    for batch_index in range(0, len(test_labels), batch_size):
        batch_contents, batch_labels, batch_content_lengths = test_dataset[batch_index]
        _, accuracy = predictor.predict(
            batch_contents, batch_labels, batch_content_lengths
        )
        accuracies.append(accuracy)

    print("test accuracy:", sum(accuracies) / len(accuracies))


if __name__ == "__main__":
    # Paths are relative to the repository root.
    vocab_path = "data/vocab.json"
    test_contents_path = "data/test_content.txt"
    test_label_path = "data/test_label.txt"
    model_path = "checkpoints/case_aware/best_model.pkl"

    # Model configuration; must match the settings used when the
    # checkpoint was trained.
    model_config = dict(
        char_embed_size=25,
        embed_size=500,
        hidden_size=500,
        max_word_length=30,
        batch_size=100,
        eta=0.001,
        max_grad_norm=1,
        max_iter=500,
        val_batch_size=500,
    )
    infer(model_path, vocab_path, test_contents_path, test_label_path, **model_config)