Skip to content

Commit

Permalink
Update streamlit widgets
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrianSosic committed Aug 30, 2024
1 parent 3565252 commit e40d7c7
Showing 1 changed file with 24 additions and 4 deletions.
28 changes: 24 additions & 4 deletions streamlit/surrogate_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,20 @@
import torch
from funcy import rpartial

from baybe.acquisition.acqfs import qLogExpectedImprovement
from baybe.acquisition.base import AcquisitionFunction
from baybe.parameters.numerical import NumericalDiscreteParameter
from baybe.recommenders.pure.bayesian.botorch import BotorchRecommender
from baybe.searchspace import SearchSpace
from baybe.surrogates import CustomONNXSurrogate
from baybe.surrogates.base import Surrogate
from baybe.surrogates.gaussian_process.core import GaussianProcessSurrogate
from baybe.targets.numerical import NumericalTarget
from baybe.utils.basic import get_subclasses
from baybe.utils.random import set_random_seed

# Number of values used for the input parameter
N_PARAMETER_VALUES = 1000
N_PARAMETER_VALUES = 200


def cubic(
Expand Down Expand Up @@ -79,6 +82,13 @@ def main():
for cls in get_subclasses(Surrogate)
if not issubclass(cls, CustomONNXSurrogate)
}
surrogate_model_names = list(surrogate_model_classes.keys())

# Collect all available acquisition functions
acquisition_function_classes = {
cls.__name__: cls for cls in get_subclasses(AcquisitionFunction)
}
acquisition_function_names = list(acquisition_function_classes.keys())

# Streamlit simulation parameters
st.sidebar.markdown("# Domain")
Expand All @@ -95,7 +105,14 @@ def main():
st.sidebar.markdown("---")
st.sidebar.markdown("# Model")
st_surrogate_name = st.sidebar.selectbox(
"Surrogate model", list(surrogate_model_classes.keys())
"Surrogate model",
surrogate_model_names,
surrogate_model_names.index(GaussianProcessSurrogate.__name__),
)
st_acqf_name = st.sidebar.selectbox(
"Acquisition function",
acquisition_function_names,
acquisition_function_names.index(qLogExpectedImprovement.__name__),
)
st_n_training_points = st.sidebar.slider("Number of training points", 1, 20, 5)
st_n_recommendations = st.sidebar.slider("Number of recommendations", 1, 20, 5)
Expand Down Expand Up @@ -152,9 +169,12 @@ def main():
searchspace = SearchSpace.from_product(parameters=[parameter])
objective = NumericalTarget(name="y", mode=st_target_mode).to_objective()

# Create the surrogate model and the recommender
# Create the surrogate model, acquisition function, and the recommender
surrogate_model = surrogate_model_classes[st_surrogate_name]()
recommender = BotorchRecommender(surrogate_model=surrogate_model)
acqf = acquisition_function_classes[st_acqf_name]()
recommender = BotorchRecommender(
surrogate_model=surrogate_model, acquisition_function=acqf
)

# Get the recommendations and extract the posterior mean / standard deviation
recommendations = recommender.recommend(
Expand Down

0 comments on commit e40d7c7

Please sign in to comment.