Skip to content

Commit

Permalink
load/save routines added in mnist_demo
Browse files Browse the repository at this point in the history
  • Loading branch information
smatmo committed Jul 14, 2020
1 parent 6eaa844 commit 7fa3739
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 5 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@ src/.idea
src/scribble*
venv
results
demo_results
samples
auxiliary
__pycache__
1 change: 1 addition & 0 deletions src/EinsumNetwork/Graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import networkx as nx
from itertools import count
from networkx import read_gpickle, write_gpickle


class EiNetAddress:
Expand Down
4 changes: 3 additions & 1 deletion src/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,13 +211,15 @@ def maybe_download_svhn():
maybe_download('../data/svhn', 'http://ufldl.stanford.edu/housenumbers/', file)


def load_svhn(data_dir, dtype=np.uint8):
def load_svhn(dtype=np.uint8):
"""
Load the SVHN dataset.
"""

maybe_download_svhn()

data_dir = '../data/svhn'

data_train = sp.loadmat(os.path.join(data_dir, "train_32x32.mat"))
data_test = sp.loadmat(os.path.join(data_dir, "test_32x32.mat"))
data_extra = sp.loadmat(os.path.join(data_dir, "extra_32x32.mat"))
Expand Down
50 changes: 46 additions & 4 deletions src/demo_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,17 @@

einet.em_update()


# Draw some samples
if fashion_mnist:
samples_dir = os.path.join("../samples/demo_fashion_mnist/")
result_dir = '../demo_results/fashion_mnist/'
else:
samples_dir = os.path.join("../samples/demo_mnist/")
result_dir = '../demo_results/mnist/'
utils.mkdir_p(result_dir)

#####################
# draw some samples #
#####################

samples_dir = os.path.join(result_dir, "samples/")
utils.mkdir_p(samples_dir)

samples = einet.sample(num_samples=25).cpu().numpy()
Expand Down Expand Up @@ -187,3 +191,41 @@

print()
print('Saved samples to {}'.format(samples_dir))

####################
# save and re-load #
####################

# evaluate log-likelihoods
einet.eval()
train_ll_before = EinsumNetwork.eval_loglikelihood_batched(einet, train_x, batch_size=batch_size)
valid_ll_before = EinsumNetwork.eval_loglikelihood_batched(einet, valid_x, batch_size=batch_size)
test_ll_before = EinsumNetwork.eval_loglikelihood_batched(einet, test_x, batch_size=batch_size)

# save model
graph_file = os.path.join(result_dir, "einet.pc")
Graph.write_gpickle(graph, graph_file)
print("Saved PC graph to {}".format(graph_file))
model_file = os.path.join(result_dir, "einet.mdl")
torch.save(einet, model_file)
print("Saved model to {}".format(model_file))

del einet

# reload model
einet = torch.load(model_file)
print("Loaded model from {}".format(model_file))

# evaluate log-likelihoods on re-loaded model
train_ll = EinsumNetwork.eval_loglikelihood_batched(einet, train_x, batch_size=batch_size)
valid_ll = EinsumNetwork.eval_loglikelihood_batched(einet, valid_x, batch_size=batch_size)
test_ll = EinsumNetwork.eval_loglikelihood_batched(einet, test_x, batch_size=batch_size)
print()
print("Log-likelihoods before saving --- train LL {} valid LL {} test LL {}".format(
train_ll / train_N,
valid_ll / valid_N,
test_ll / test_N))
print("Log-likelihoods after saving --- train LL {} valid LL {} test LL {}".format(
train_ll / train_N,
valid_ll / valid_N,
test_ll / test_N))

0 comments on commit 7fa3739

Please sign in to comment.