minor changes, added debd demo

smatmo committed Jul 14, 2020
1 parent 3fc46a3 commit 6eaa844
Showing 3 changed files with 105 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/EinsumNetwork/EinsumNetwork.py
@@ -11,7 +11,7 @@ class Args(object):
        RVs.
    num_input_distributions: number of distributions per input region (K in the paper).
    num_sums: number of sum nodes per internal region (K in the paper).
-   num_classes: number of outputs of the SPN.
+   num_classes: number of outputs of the PC.
    exponential_family: which exponential family to use; (sub-class ExponentialFamilyTensor).
    exponential_family_args: arguments for the exponential family, e.g. trial-number N for Binomial.
    use_em: determines if the internal em algorithm shall be used; otherwise you might use e.g. SGD.

@@ -261,21 +261,21 @@ def log_likelihoods(outputs, labels=None):
    return lls


-def eval_accuracy_batched(spn, x, labels, batch_size):
+def eval_accuracy_batched(einet, x, labels, batch_size):
    """Computes accuracy in batched way."""
    with torch.no_grad():
        idx_batches = torch.arange(0, x.shape[0], dtype=torch.int64, device=x.device).split(batch_size)
        n_correct = 0
        for batch_count, idx in enumerate(idx_batches):
            batch_x = x[idx, :]
            batch_labels = labels[idx]
-           outputs = spn.forward(batch_x)
+           outputs = einet.forward(batch_x)
            _, pred = outputs.max(1)
            n_correct += torch.sum(pred == batch_labels)
        return (n_correct.float() / x.shape[0]).item()


-def eval_loglikelihood_batched(spn, x, labels=None, batch_size=100):
+def eval_loglikelihood_batched(einet, x, labels=None, batch_size=100):
    """Computes log-likelihood in batched way."""
    with torch.no_grad():
        idx_batches = torch.arange(0, x.shape[0], dtype=torch.int64, device=x.device).split(batch_size)

@@ -286,7 +286,7 @@ def eval_loglikelihood_batched(spn, x, labels=None, batch_size=100):
                batch_labels = labels[idx]
            else:
                batch_labels = None
-           outputs = spn(batch_x)
+           outputs = einet(batch_x)
            ll_sample = log_likelihoods(outputs, batch_labels)
            ll_total += ll_sample.sum().item()
        return ll_total
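
Both helpers now take the einet as their first argument. A minimal usage sketch, assuming
an EiNet trained with labels (num_classes > 1) and test tensors test_x / test_labels already
on the same device (the tensor names and batch size are just examples):

    acc = EinsumNetwork.eval_accuracy_batched(einet, test_x, test_labels, batch_size=100)
    avg_test_ll = EinsumNetwork.eval_loglikelihood_batched(einet, test_x, batch_size=100) / test_x.shape[0]
    print("accuracy {:.4f}   average test LL {:.4f}".format(acc, avg_test_ll))
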
90 changes: 90 additions & 0 deletions src/demo_debd.py
@@ -0,0 +1,90 @@
import torch
from EinsumNetwork import Graph, EinsumNetwork
import datasets

device = 'cuda' if torch.cuda.is_available() else 'cpu'

demo_text = """
This demo loads one of the 20 binary datasets (DEBD) and quickly trains an EiNet for some epochs.
There are some parameters to play with, for example which dataset to use and some
structural parameters.
"""
print(demo_text)

##########################################################
dataset = 'accidents'

depth = 3                     # depth of the random binary-tree region graph
num_repetitions = 10          # number of random tree replicas mixed at the root
num_input_distributions = 20  # K for input (leaf) regions
num_sums = 20                 # K for internal sum regions

max_num_epochs = 10
batch_size = 100
online_em_frequency = 1       # trigger an online EM update after every batch
online_em_stepsize = 0.05     # stepsize of the online EM update

##########################################################

print(dataset)

train_x_orig, test_x_orig, valid_x_orig = datasets.load_debd(dataset, dtype='float32')

train_x = train_x_orig
test_x = test_x_orig
valid_x = valid_x_orig

# to torch
train_x = torch.from_numpy(train_x).to(torch.device(device))
valid_x = torch.from_numpy(valid_x).to(torch.device(device))
test_x = torch.from_numpy(test_x).to(torch.device(device))

train_N, num_dims = train_x.shape
valid_N = valid_x.shape[0]
test_N = test_x.shape[0]

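# Region graph: 'depth' random recursive splits of the variable set, replicated
# 'num_repetitions' times (random-binary-trees structure, as in RAT-SPNs).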
graph = Graph.random_binary_trees(num_var=train_x.shape[1], depth=depth, num_repetitions=num_repetitions)

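# Binary data, so categorical leaves with K=2 states; num_classes=1 gives a
# single root output, i.e. plain density estimation.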
args = EinsumNetwork.Args(
    num_classes=1,
    num_input_distributions=num_input_distributions,
    exponential_family=EinsumNetwork.CategoricalArray,
    exponential_family_args={'K': 2},
    num_sums=num_sums,
    num_var=train_x.shape[1],
    online_em_frequency=online_em_frequency,
    online_em_stepsize=online_em_stepsize)

einet = EinsumNetwork.EinsumNetwork(graph, args)
einet.initialize()
einet.to(device)
print(einet)

for epoch_count in range(max_num_epochs):

    # evaluate
    train_ll = EinsumNetwork.eval_loglikelihood_batched(einet, train_x)
    valid_ll = EinsumNetwork.eval_loglikelihood_batched(einet, valid_x)
    test_ll = EinsumNetwork.eval_loglikelihood_batched(einet, test_x)

    print("[{}] train LL {} valid LL {} test LL {}".format(epoch_count,
                                                           train_ll / train_N,
                                                           valid_ll / valid_N,
                                                           test_ll / test_N))

    # train
    idx_batches = torch.randperm(train_N).split(batch_size)
    for batch_count, idx in enumerate(idx_batches):
        batch_x = train_x[idx, :]
        outputs = einet.forward(batch_x)

        ll_sample = EinsumNetwork.log_likelihoods(outputs)
        log_likelihood = ll_sample.sum()

        objective = log_likelihood
        objective.backward()

        # EM is implemented via autodiff: backward() computes the expected
        # sufficient statistics, which em_process_batch() accumulates
        einet.em_process_batch()

    # EM parameter update (with online EM, updates are also triggered every
    # online_em_frequency batches)
    einet.em_update()
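
Since EinsumNetwork subclasses torch.nn.Module, the trained demo model can be checkpointed
with the usual PyTorch mechanics; a minimal sketch (the file name is just an example):

    torch.save(einet.state_dict(), 'einet_accidents.pt')
    einet.load_state_dict(torch.load('einet_accidents.pt'))
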
10 changes: 10 additions & 0 deletions src/demo_mnist.py
@@ -7,6 +7,15 @@

device = 'cuda' if torch.cuda.is_available() else 'cpu'

+demo_text = """
+This demo loads (fashion) mnist and quickly trains an EiNet for some epochs.
+There are some parameters to play with, for example which exponential family to
+use, which classes to pick, and structural parameters. An EiNet is then trained,
+the log-likelihoods are reported, and some (conditional and unconditional) samples are produced.
+"""
+print(demo_text)

############################################################################
fashion_mnist = False

@@ -15,6 +24,7 @@
# exponential_family = EinsumNetwork.NormalArray

classes = [7]
+# classes = [2, 3, 5, 7]
# classes = None

K = 10
