Skip to content

Commit e55dfef

Browse files
authored
Merge pull request #34 from XanaduAI/generative_models
Add generative models and datasets
2 parents 8be28e6 + cb4e53e commit e55dfef

13 files changed

+1512
-143
lines changed

README.md

Lines changed: 98 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Benchmarking for quantum machine learning models
22

33
This repository contains tools to compare the performance of near-term quantum machine learning (QML)
4-
as well as standard classical machine learning models on supervised learning tasks.
4+
as well as standard classical machine learning models on supervised and generative learning tasks.
55

66
It is based on pipelines using [Pennylane](https://pennylane.ai/) for the simulation of quantum circuits,
77
[JAX](https://jax.readthedocs.io/en/latest/index.html) for training,
@@ -39,12 +39,12 @@ Dependencies of this package can be installed in your environment by running
3939
pip install -r requirements.txt
4040
```
4141

42-
## Adding a custom model
42+
## Adding a custom classifier
4343

4444
We use the [Scikit-learn API](https://scikit-learn.org/stable/developers/develop.html) to create
4545
models and perform hyperparameter search.
4646

47-
A minimal template for a new quantum model is as follows, and can be stored
47+
A minimal template for a new quantum classifier is as follows, and can be stored
4848
in `qml_benchmarks/models/my_model.py`:
4949

5050
```python
@@ -61,18 +61,23 @@ class MyModel(BaseEstimator, ClassifierMixin):
6161

6262
# reproducibility is ensured by creating a numpy PRNG and using it for all
6363
# subsequent random functions.
64-
self._random_state = random_state
65-
self._rng = np.random.default_rng(random_state)
64+
self.random_state = random_state
65+
self.rng = np.random.default_rng(random_state)
6666

6767
# define data-dependent attributes
6868
self.params_ = None
6969
self.n_qubits_ = None
70+
71+
def initialize(self, args):
72+
"""
73+
initialize the model if necessary
74+
"""
75+
# ... your code here ...
7076

7177
def fit(self, X, y):
7278
"""Fit the model to data X and labels y.
7379
7480
Add your custom training loop here and store the trained model parameters in `self.params_`.
75-
Set the data-dependent attributes, such as `self.n_qubits_`.
7681
7782
Args:
7883
X (array_like): Data of shape (n_samples, n_features)
@@ -146,9 +151,86 @@ model.fit(X_train, y_train)
146151
print(model.score(X_test, y_test))
147152
```
148153

154+
155+
## Adding a custom generative model
156+
157+
The minimal template for a new generative model closely follows that of the classifier models.
158+
Labels are set to `None` throughout to maintain sci-kit learn functionality.
159+
160+
```python
161+
import numpy as np
162+
163+
from sklearn.base import BaseEstimator
164+
165+
166+
class MyModel(BaseEstimator):
167+
def __init__(self, hyperparam1="some_value", random_state=42):
168+
169+
# store hyperparameters as attributes
170+
self.hyperparam1 = hyperparam1
171+
172+
# reproducibility is ensured by creating a numpy PRNG and using it for all
173+
# subsequent random functions.
174+
self.random_state = random_state
175+
self.rng = np.random.default_rng(random_state)
176+
177+
# define data-dependent attributes
178+
self.params_ = None
179+
self.n_qubits_ = None
180+
181+
def initialize(self, args):
182+
"""
183+
initialize the model if necessary
184+
"""
185+
# ... your code here ...
186+
187+
def fit(self, X, y=None):
188+
"""Fit the model to data X.
189+
190+
Add your custom training loop here and store the trained model parameters in `self.params_`.
191+
192+
Args:
193+
X (array_like): Data of shape (n_samples, n_features)
194+
y (array_like): not used (no labels)
195+
"""
196+
# ... your code here ...
197+
198+
def sample(self, num_samples):
199+
"""sample from the generative model
200+
201+
Args:
202+
num_samples (int): number of points to sample
203+
204+
Returns:
205+
array_like: sampled points
206+
"""
207+
# ... your code here ...
208+
209+
return samples
210+
211+
def score(self, X, y=None):
212+
"""A optional custom score function to be used with hyperparameter optimization
213+
Args:
214+
X (array_like): Data of shape (n_samples, n_features)
215+
y: unused (no labels for generative models)
216+
217+
Returns:
218+
(float): score for the dataset X
219+
"""
220+
# ... your code here ...
221+
return score
222+
```
223+
224+
If the model samples binary data, it is recommended to construct models that sample binary strings (rather than $\pm1$ valued strings)
225+
to align with the datasets designed for generative models.
226+
Energy based models can easily be constructed by replacing the multilayer perceptron neural network in `DeepEBM` by
227+
any other differentiable network written in `flax`.
228+
149229
## Datasets
150230

151-
The `qml_benchmarks.data` module provides generating functions to create datasets for binary classification.
231+
The `qml_benchmarks.data` module provides generating functions to create datasets for binary classification and
232+
generative learning.
233+
152234
A generating function can be used like this:
153235

154236
```python
@@ -158,7 +240,7 @@ X, y = generate_two_curves(n_samples=200, n_features=4, degree=3, noise=0.1, off
158240
```
159241

160242
Note that some datasets might have different return data structures, for example if the train/test split
161-
is performed by the generating function.
243+
is performed by the generating function. If the dataset does not include labels, `y = None` is returned.
162244

163245
The original datasets used in the paper can be generated by running the scripts in the `paper/benchmarks` folder,
164246
such as:
@@ -172,15 +254,18 @@ This will create a new folder in `paper/benchmarks` containing the datasets.
172254
## Running hyperparameter optimization
173255

174256
In the folder `scripts` we provide an example that can be used to
175-
generate results for a hyperparameter search for any model and dataset. The script
257+
generate results for a hyperparameter search for any model and dataset. The script functions
258+
for both classifier and generative models. The script
176259
can be run as
177260

178261
```
179-
python run_hyperparameter_search.py --classifier-name "DataReuploadingClassifier" --dataset-path "my_dataset.csv"
262+
python run_hyperparameter_search.py --model "DataReuploadingClassifier" --dataset-path "my_dataset.csv"
180263
```
181264

182-
where `my_dataset.csv` is a CSV file containing the training data such that each column is a feature
183-
and the last column is the target.
265+
where`my_dataset.csv` is a CSV file containing the training data. For classification problems, each column should
266+
correspond to a feature and the last column to the target. For generative learning, each row
267+
should correspond to a binary string that specifies a unique data sample, and the model should implement a `score`
268+
method.
184269

185270
Unless otherwise specified, the hyperparameter grid is loaded from `qml_benchmarks/hyperparameter_settings.py`.
186271
One can override the default grid of hyperparameters by specifying the hyperparameter list,
@@ -189,7 +274,7 @@ For example, for the `DataReuploadingClassifier` we can run:
189274

190275
```
191276
python run_hyperparameter_search.py \
192-
--classifier-name DataReuploadingClassifier \
277+
--model DataReuploadingClassifier \
193278
--dataset-path "my_dataset.csv" \
194279
--n_layers 1 2 \
195280
--observable_type "single" "full"\

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ pyyaml~=6.0
1010
pennyLane~=0.34
1111
scipy~=1.11
1212
pandas~=2.2
13+
numpyro~=0.14.0

scripts/run_hyperparameter_search.py

Lines changed: 54 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -20,28 +20,33 @@
2020
import time
2121
import argparse
2222
import logging
23+
2324
logging.getLogger().setLevel(logging.INFO)
2425
from importlib import import_module
2526
import pandas as pd
2627
from pathlib import Path
2728
import matplotlib.pyplot as plt
2829
from sklearn.model_selection import GridSearchCV
30+
from sklearn.metrics import make_scorer
31+
from qml_benchmarks.models.base import BaseGenerator
2932
from qml_benchmarks.hyperparam_search_utils import read_data, construct_hyperparameter_grid
3033
from qml_benchmarks.hyperparameter_settings import hyper_parameter_settings
3134

3235
np.random.seed(42)
3336

34-
logging.info('cpu count:' + str(os.cpu_count()))
37+
def custom_scorer(estimator, X, y=None):
38+
return estimator.score(X, y)
3539

40+
logging.info('cpu count:' + str(os.cpu_count()))
3641

3742
if __name__ == "__main__":
3843
# Create an argument parser
3944
parser = argparse.ArgumentParser(description="Run experiments with hyperparameter search.",
40-
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
45+
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
4146

4247
parser.add_argument(
43-
"--classifier-name",
44-
help="Classifier to run",
48+
"--model",
49+
help="Model to run",
4550
)
4651

4752
parser.add_argument(
@@ -91,27 +96,28 @@
9196
# Parse the arguments along with any extra arguments that might be model specific
9297
args, unknown_args = parser.parse_known_args()
9398

94-
if any(arg is None for arg in [args.classifier_name,
99+
if any(arg is None for arg in [args.model,
95100
args.dataset_path]):
96101
msg = "\n================================================================================"
97-
msg += "\nA classifier from qml.benchmarks.model and dataset path are required. E.g., \n \n"
98-
msg += "python run_hyperparameter_search \ \n--classifier DataReuploadingClassifier \ \n--dataset-path train.csv\n"
102+
msg += "\nA model from qml.benchmarks.models and dataset path are required. E.g., \n \n"
103+
msg += "python run_hyperparameter_search \ \n--model DataReuploadingClassifier \ \n--dataset-path train.csv\n"
99104
msg += "\nCheck all arguments for the script with \n"
100105
msg += "python run_hyperparameter_search --help\n"
101106
msg += "================================================================================"
102107
raise ValueError(msg)
103-
108+
104109
# Add model specific arguments to override the default hyperparameter grid
105110
hyperparam_grid = construct_hyperparameter_grid(
106-
hyper_parameter_settings, args.classifier_name
111+
hyper_parameter_settings, args.model
107112
)
113+
108114
for hyperparam in hyperparam_grid:
109115
hp_type = type(hyperparam_grid[hyperparam][0])
110116
parser.add_argument(f'--{hyperparam}',
111117
type=hp_type,
112118
nargs="+",
113119
default=hyperparam_grid[hyperparam],
114-
help=f'{hyperparam} grid values for {args.classifier_name}')
120+
help=f'{hyperparam} grid values for {args.model}')
115121

116122
args = parser.parse_args(unknown_args, namespace=args)
117123

@@ -122,11 +128,12 @@
122128
logging.info(
123129
"Running hyperparameter search experiment with the following settings\n"
124130
)
125-
logging.info(args.classifier_name)
131+
logging.info(args.model)
126132
logging.info(args.dataset_path)
127133
logging.info(" ".join(args.hyperparameter_scoring))
128134
logging.info(args.hyperparameter_refit)
129-
logging.info("Hyperparam grid:"+" ".join([(str(key)+str(":")+str(hyperparam_grid[key])) for key in hyperparam_grid.keys()]))
135+
logging.info("Hyperparam grid:" + " ".join(
136+
[(str(key) + str(":") + str(hyperparam_grid[key])) for key in hyperparam_grid.keys()]))
130137

131138
experiment_path = args.results_path
132139
results_path = os.path.join(experiment_path, "results")
@@ -135,22 +142,25 @@
135142
os.makedirs(results_path)
136143

137144
###################################################################
138-
# Get the classifier, dataset and search methods from the arguments
145+
# Get the model, dataset and search methods from the arguments
139146
###################################################################
140-
Classifier = getattr(
147+
Model = getattr(
141148
import_module("qml_benchmarks.models"),
142-
args.classifier_name
149+
args.model
143150
)
144-
classifier_name = Classifier.__name__
151+
model_name = Model.__name__
152+
153+
is_generative = isinstance(Model(), BaseGenerator)
154+
use_labels = False if is_generative else True
145155

146156
# Run the experiments save the results
147157
train_dataset_filename = os.path.join(args.dataset_path)
148-
X, y = read_data(train_dataset_filename)
158+
X, y = read_data(train_dataset_filename, labels=use_labels)
149159

150160
dataset_path_obj = Path(args.dataset_path)
151161
results_filename_stem = " ".join(
152-
[Classifier.__name__ + "_" + dataset_path_obj.stem
153-
+ "_GridSearchCV"])
162+
[Model.__name__ + "_" + dataset_path_obj.stem
163+
+ "_GridSearchCV"])
154164

155165
# If we have already run this experiment then continue
156166
if os.path.isfile(os.path.join(results_path, results_filename_stem + ".csv")):
@@ -162,44 +172,48 @@
162172
logging.warning(msg)
163173
sys.exit(msg)
164174
else:
165-
logging.warning("Cleaning existing results for ", os.path.join(results_path, results_filename_stem + ".csv"))
166-
175+
logging.warning("Cleaning existing results for ",
176+
os.path.join(results_path, results_filename_stem + ".csv"))
167177

168178
###########################################################################
169179
# Single fit to check everything works
170180
###########################################################################
171-
classifier = Classifier()
181+
model = Model()
172182
a = time.time()
173-
classifier.fit(X, y)
183+
model.fit(X, y)
174184
b = time.time()
175-
acc_train = classifier.score(X, y)
185+
default_score = model.score(X, y)
176186
logging.info(" ".join(
177-
[classifier_name,
178-
"Dataset path",
179-
args.dataset_path,
180-
"Train acc:",
181-
str(acc_train),
182-
"Time single run",
183-
str(b - a)])
187+
[model_name,
188+
"Dataset path",
189+
args.dataset_path,
190+
"Train score:",
191+
str(default_score),
192+
"Time single run",
193+
str(b - a)])
184194
)
185-
if hasattr(classifier, "loss_history_"):
195+
if hasattr(model, "loss_history_"):
186196
if args.plot_loss:
187-
plt.plot(classifier.loss_history_)
197+
plt.plot(model.loss_history_)
188198
plt.xlabel("Iterations")
189199
plt.ylabel("Loss")
190200
plt.show()
191201

192-
if hasattr(classifier, "n_qubits_"):
193-
logging.info(" ".join(["Num qubits", f"{classifier.n_qubits_}"]))
202+
if hasattr(model, "n_qubits_"):
203+
logging.info(" ".join(["Num qubits", f"{model.n_qubits_}"]))
194204

195205
###########################################################################
196206
# Hyperparameter search
197207
###########################################################################
198-
gs = GridSearchCV(estimator=classifier, param_grid=hyperparam_grid,
199-
scoring=args.hyperparameter_scoring,
200-
refit=args.hyperparameter_refit,
201-
verbose=3,
202-
n_jobs=-1).fit(
208+
209+
scorer = args.hyperparameter_scoring if not is_generative else custom_scorer
210+
refit = args.hyperparameter_refit if not is_generative else False
211+
212+
gs = GridSearchCV(estimator=model, param_grid=hyperparam_grid,
213+
scoring=scorer,
214+
refit=refit,
215+
verbose=3,
216+
n_jobs=args.n_jobs).fit(
203217
X, y
204218
)
205219
logging.info("Best hyperparams")

scripts/score_with_best_hyperparameters.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Score a model using the best hyperparameters, using a command-line script."""
15+
"""
16+
Score a model using the best hyperparameters, using a command-line script.
17+
Note this is only compatible with supervised models for classification.
18+
"""
19+
1620

1721
import numpy as np
1822
import sys

src/qml_benchmarks/data/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,5 @@
1919
from qml_benchmarks.data.hyperplanes import generate_hyperplanes_parity
2020
from qml_benchmarks.data.linearly_separable import generate_linearly_separable
2121
from qml_benchmarks.data.two_curves import generate_two_curves
22-
22+
from qml_benchmarks.data.spin_blobs import generate_spin_blobs, generate_8blobs
23+
from qml_benchmarks.data.ising import generate_ising

0 commit comments

Comments
 (0)