diff --git a/tests/autoqkeras_test.py b/tests/autoqkeras_test.py index 8d0f5239..43d2983c 100644 --- a/tests/autoqkeras_test.py +++ b/tests/autoqkeras_test.py @@ -29,11 +29,16 @@ from tensorflow.keras.layers import Dropout from tensorflow.keras.layers import Input from tensorflow.keras.models import Model -from tensorflow.keras.optimizers import Adam +# TODO: Update to new optimizer API +from tensorflow.keras.optimizers.legacy import Adam from tensorflow.keras.utils import to_categorical from qkeras.autoqkeras import AutoQKerasScheduler +np.random.seed(42) +tf.random.set_seed(42) +tf.config.experimental.enable_op_determinism() + def dense_model(): """Creates test dense model.""" @@ -52,13 +57,20 @@ def dense_model(): x = Activation("softmax", name="softmax")(x) model = Model(inputs=x_in, outputs=x) - return model + # Manually set the weights for each layer. Needed for test determinism. + for layer in model.layers: + if isinstance(layer, Dense): + weights_shape = layer.get_weights()[0].shape + bias_shape = layer.get_weights()[1].shape + weights = np.random.RandomState(42).randn(*weights_shape) + bias = np.random.RandomState(42).randn(*bias_shape) + layer.set_weights([weights, bias]) + + return model def test_autoqkeras(): """Tests AutoQKeras scheduler.""" - np.random.seed(42) - tf.random.set_seed(42) x_train, y_train = load_iris(return_X_y=True) @@ -104,7 +116,7 @@ def test_autoqkeras(): model = dense_model() model.summary() - optimizer = Adam(lr=0.01) + optimizer = Adam(learning_rate=0.015) model.compile(optimizer=optimizer, loss="categorical_crossentropy", metrics=["acc"]) @@ -140,14 +152,12 @@ def test_autoqkeras(): qmodel = autoqk.get_best_model() - optimizer = Adam(lr=0.01) + optimizer = Adam(learning_rate=0.015) qmodel.compile(optimizer=optimizer, loss="categorical_crossentropy", metrics=["acc"]) - history = qmodel.fit(x_train, y_train, epochs=5, batch_size=150, + _ = qmodel.fit(x_train, y_train, epochs=5, batch_size=150, validation_split=0.1) - quantized_acc = history.history["acc"][-1] - if __name__ == "__main__": pytest.main([__file__])