vizier/_src/algorithms/designers/gp_bandit.py (68 changes: 21 additions & 47 deletions)
@@ -38,13 +38,11 @@
 from vizier._src.algorithms.designers.gp import output_warpers
 from vizier._src.algorithms.optimizers import eagle_strategy as es
 from vizier._src.algorithms.optimizers import vectorized_base as vb
-from vizier._src.jax import gp_bandit_utils
 from vizier._src.jax import stochastic_process_model as sp
 from vizier._src.jax import types
 from vizier._src.jax.models import tuned_gp_models
 from vizier.jax import optimizers
 from vizier.pyvizier import converters
-from vizier.pyvizier.converters import feature_mapper
 from vizier.pyvizier.converters import padding
 from vizier.utils import profiler

@@ -124,10 +122,6 @@ class VizierGPBandit(vza.Designer, vza.Predictor):
   _output_warper_pipeline: output_warpers.OutputWarperPipeline = attr.field(
       init=False
   )
-  # TODO: Remove this.
-  _feature_mapper: Optional[
-      feature_mapper.ContinuousCategoricalFeatureMapper
-  ] = attr.field(init=False, default=None)
   _last_computed_state: _GPBanditState = attr.field(init=False)

   default_acquisition_optimizer_factory = vb.VectorizedOptimizerFactory(
@@ -155,31 +149,8 @@ def __attrs_post_init__(self):
     )
     self._output_warper_pipeline = output_warpers.create_default_warper()

-    # TODO: Get rid of this when Vectorized optimizers operate on CACV.
-    self._one_hot_converter = (
-        converters.TrialToArrayConverter.from_study_config(
-            self._problem,
-            scale=True,
-            pad_oovs=True,
-            max_discrete_indices=0,
-            flip_sign_for_minimization_metrics=True,
-        )
-    )
-    self._padded_one_hot_converter = (
-        converters.PaddedTrialToArrayConverter.from_study_config(
-            self._problem,
-            scale=True,
-            pad_oovs=True,
-            padding_schedule=self._padding_schedule,
-            max_discrete_indices=0,
-            flip_sign_for_minimization_metrics=True,
-        )
-    )
-    self._feature_mapper = feature_mapper.ContinuousCategoricalFeatureMapper(
-        self._one_hot_converter
-    )
     self._acquisition_optimizer = self._acquisition_optimizer_factory(
-        self._padded_one_hot_converter
+        self._converter
     )
     acquisition_problem = copy.deepcopy(self._problem)
     if isinstance(
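Note on the converter swap above: the two one-hot converters are replaced by the single `self._converter`, which (per the test change below) is a `converters.TrialToModelInputConverter` built with `from_problem`. A minimal sketch of what that converter produces, assuming the `ModelInput` structure used later in this diff (`continuous`/`categorical` fields, each a `PaddedArray`); the toy problem and trial here are illustrative, not from this PR:

from vizier import pyvizier as vz
from vizier.pyvizier import converters

# Illustrative two-parameter problem: one float, one categorical.
problem = vz.ProblemStatement()
problem.search_space.root.add_float_param('x', 0.0, 1.0)
problem.search_space.root.add_categorical_param('c', ['red', 'blue'])
problem.metric_information.append(
    vz.MetricInformation(name='obj', goal=vz.ObjectiveMetricGoal.MAXIMIZE)
)

converter = converters.TrialToModelInputConverter.from_problem(problem)
trial = vz.Trial(parameters={'x': 0.3, 'c': 'red'})
features = converter.to_features([trial])
# features is a types.ModelInput: features.continuous holds the scaled float
# features, features.categorical holds integer category indices, each wrapped
# in a PaddedArray.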
@@ -356,30 +327,33 @@ def _optimize_acquisition(
   ) -> list[vz.Trial]:
     start_time = datetime.datetime.now()
     # Set up optimizer and run
-    seed_features = vb.trials_to_sorted_array(
-        self._trials, self._padded_one_hot_converter
-    )
-    seed_features_unpad = vb.trials_to_sorted_array(
-        self._trials, self._one_hot_converter
-    )
+    seed_features = vb.trials_to_sorted_array(self._trials, self._converter)
     acq_rng, self._rng = jax.random.split(self._rng)

-    # TODO: Remove this when Vectorized Optimizer works on CACV.
-    cacpa = self._converter.to_features(self._trials)
-    one_hot_to_modelinput = gp_bandit_utils.make_one_hot_to_modelinput_fn(
-        seed_features_unpad, self._feature_mapper, cacpa
-    )
+    score = scoring_fn.score
+    score_with_aux = scoring_fn.score_with_aux

-    score = lambda xs: scoring_fn.score(one_hot_to_modelinput(xs))
-    score_with_aux = lambda xs: scoring_fn.score_with_aux(
-        one_hot_to_modelinput(xs)
-    )
+    prior_features = None
+    if seed_features is not None:
+      continuous = seed_features.continuous.unpad()
+      continuous = types.PaddedArray.from_array(
+          continuous,
+          (continuous.shape[0], seed_features.continuous.shape[1]),
+          fill_value=np.nan,
+      )
+      categorical = seed_features.categorical.unpad()
+      categorical = types.PaddedArray.from_array(
+          categorical,
+          (categorical.shape[0], seed_features.categorical.shape[1]),
+          fill_value=-1,
+      )
+      prior_features = types.ModelInput(continuous, categorical)

     best_candidates: vb.VectorizedStrategyResults = eqx.filter_jit(
         self._acquisition_optimizer
     )(
         score,
-        prior_features=seed_features,
+        prior_features=prior_features,
         count=count,
         seed=acq_rng,
         score_with_aux_fn=score_with_aux,
@@ -410,7 +384,7 @@ def _optimize_acquisition(
     # space); also append debug information like model predictions.
     logging.info('Converting the optimization result into suggestions...')
     return vb.best_candidates_to_trials(
-        best_candidates, self._one_hot_converter
+        best_candidates, self._converter
     )  # [N, D]

   @profiler.record_runtime
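The `prior_features` block above strips trial-count padding but keeps feature-dimension padding: `unpad()` recovers the real rows, then `PaddedArray.from_array` re-pads only the trailing feature columns (NaN fill for continuous, -1 for categorical). A hedged sketch of that pattern with made-up shapes, assuming `from_array(array, target_shape, fill_value=...)` behaves as it is invoked in this diff; the test-file hunks follow it:

import numpy as np
from vizier._src.jax import types

raw = np.random.rand(3, 2)  # 3 real trials, 2 real continuous features.
padded = types.PaddedArray.from_array(raw, (8, 4), fill_value=np.nan)
# padded.padded_array: shape (8, 4); the extra rows/columns hold NaN fill.

unpadded = padded.unpad()  # back to the real (3, 2) array
# Re-pad the feature dimension only, keeping the true trial count:
prior = types.PaddedArray.from_array(
    unpadded, (unpadded.shape[0], padded.shape[1]), fill_value=np.nan
)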
@@ -33,22 +33,6 @@
 from absl.testing import parameterized


-def randomize_array(converter: converters.TrialToArrayConverter) -> jax.Array:
-  """Generate a random array of features to be used as score_fn shift."""
-  features_arrays = []
-  for spec in converter.output_specs:
-    if spec.type == converters.NumpyArraySpecType.ONEHOT_EMBEDDING:
-      dim = spec.num_dimensions - spec.num_oovs
-      features_arrays.append(
-          jnp.eye(spec.num_dimensions)[np.random.randint(0, dim)]
-      )
-    elif spec.type == converters.NumpyArraySpecType.CONTINUOUS:
-      features_arrays.append(np.random.uniform(0.4, 0.6, size=(1,)))
-    else:
-      raise ValueError(f'The type {spec.type} is not supported!')
-  return jnp.hstack(features_arrays)
-
-
 def create_continuous_problem(
     n_features: int,
     problem: Optional[vz.ProblemStatement] = None) -> vz.ProblemStatement:
@@ -80,16 +64,25 @@ def create_mix_problem(n_features: int,


 # TODO: Change to bbob functions when they can support batching.
-def sphere(x: types.Array) -> jax.Array:
-  return -jnp.sum(jnp.square(x), axis=-1)
+def sphere(x: types.ModelInput) -> jax.Array:
+  return -(
+      jnp.sum(jnp.square(x.continuous.padded_array), axis=-1)
+      + 0.1 * jnp.sum(jnp.square(x.categorical.padded_array), axis=-1)
+  )


-def rastrigin_d10(x: types.Array) -> jax.Array:
+def _rastrigin_d10_part(x: types.Array) -> jax.Array:
   return 10 * jnp.sum(jnp.cos(2 * np.pi * x), axis=-1) - jnp.sum(
       jnp.square(x), axis=-1
   )


+def rastrigin_d10(x: types.ModelInput) -> jax.Array:
+  return _rastrigin_d10_part(x.continuous.padded_array) + _rastrigin_d10_part(
+      0.01 * x.categorical.padded_array
+  )
+
+
 class EagleOptimizerConvegenceTest(parameterized.TestCase):
   """Test optimizing an acquisition functions using vectorized Eagle Strategy.
   """
@@ -108,11 +101,9 @@ def test_converges(self, create_problem_fn, n_features, score_fn):
     logging.info('Starting a new convergence test (n_features: %s)', n_features)
     evaluations = 20_000
     problem = create_problem_fn(n_features)
-    converter = converters.TrialToArrayConverter.from_study_config(problem)
+    converter = converters.TrialToModelInputConverter.from_problem(problem)
     eagle_strategy_factory = eagle_strategy.VectorizedEagleStrategyFactory(
         eagle_config=eagle_strategy.EagleStrategyConfig())
-    optimum_features = randomize_array(converter)
-    shifted_score_fn = lambda x, shift=optimum_features: score_fn(x - shift)
+    shifted_score_fn = score_fn
     random_strategy_factory = rvo.random_strategy_factory
     # Run simple regret convergence test.