diff --git a/exoplanet-ml/astronet/predict.py b/exoplanet-ml/astronet/predict.py index d2e3ccf..932c611 100644 --- a/exoplanet-ml/astronet/predict.py +++ b/exoplanet-ml/astronet/predict.py @@ -112,12 +112,12 @@ def _process_tce(feature_config): features = {} if "global_view" in feature_config: - global_view = preprocess.global_view(time, flux, FLAGS.period) + global_view = preprocess.global_view(time, flux, FLAGS.period).astype(np.float32) # Add a batch dimension. features["global_view"] = np.expand_dims(global_view, 0) if "local_view" in feature_config: - local_view = preprocess.local_view(time, flux, FLAGS.period, FLAGS.duration) + local_view = preprocess.local_view(time, flux, FLAGS.period, FLAGS.duration).astype(np.float32) # Add a batch dimension. features["local_view"] = np.expand_dims(local_view, 0)