-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathloading.py
69 lines (59 loc) · 2.18 KB
/
loading.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import json
import os

import torch

from utils.dataloaders import mnist, celeba
from utils.init_models import initialize_model
def load_model(directory, model_version=None):
    """
    Returns model, data_loader and mask_descriptor of trained model.

    Parameters
    ----------
    directory : string
        Directory where experiment was saved. For example './experiment_1'.
        It must contain 'config.json' and the saved model weights.
    model_version : int or None
        If None loads final model ('model.pt'), otherwise loads model version
        determined by int ('model<model_version>.pt').

    Returns
    -------
    model
        Model built by ``initialize_model`` from the saved config, with the
        saved state dict loaded onto the CPU.
    data_loader
        For 'mnist', the test data loader (second value returned by
        ``mnist``); for 'celeba', the value returned by ``celeba``.
    mask_descriptor
        The ``"mask_descriptor"`` entry of the saved config.

    Raises
    ------
    ValueError
        If the config's ``"dataset"`` is neither 'mnist' nor 'celeba'.
    """
    path_to_config = os.path.join(directory, 'config.json')
    if model_version is None:
        path_to_model = os.path.join(directory, 'model.pt')
    else:
        path_to_model = os.path.join(directory,
                                     'model{}.pt'.format(model_version))

    # Open config file
    with open(path_to_config, encoding='utf-8') as config_file:
        config = json.load(config_file)

    # Load dataset info
    dataset = config["dataset"]
    resize = config["resize"]
    crop = config["crop"]
    batch_size = config["batch_size"]
    num_colors = config["num_colors"]
    # "grayscale" is optional in the config; absent means color images.
    grayscale = config.get("grayscale", False)

    # Get data
    if dataset == 'mnist':
        # Extract the test dataset (second argument)
        _, data_loader = mnist(batch_size, num_colors, resize)
        img_size = (1, resize, resize)
    elif dataset == 'celeba':
        data_loader = celeba(batch_size, num_colors, resize, crop, grayscale)
        img_size = (1 if grayscale else 3, resize, resize)
    else:
        # Without this branch data_loader / img_size would be unbound and
        # the failure would surface later as an unrelated KeyError or
        # UnboundLocalError.
        raise ValueError(
            "Unsupported dataset {!r} in {}; expected 'mnist' or 'celeba'."
            .format(dataset, path_to_config))

    # Load model info
    constrained = config["constrained"]
    depth = config["depth"]
    num_filters_cond = config["num_filters_cond"]
    num_filters_prior = config["num_filters_prior"]
    filter_size = config["filter_size"]

    model = initialize_model(img_size,
                             num_colors,
                             depth,
                             filter_size,
                             constrained,
                             num_filters_prior,
                             num_filters_cond)
    # Returning the storage unchanged deserializes every tensor onto the CPU,
    # so weights saved from a GPU run can be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(path_to_model,
                                     map_location=lambda storage, loc: storage))
    return model, data_loader, config["mask_descriptor"]