diff --git a/vizier/_src/algorithms/designers/gp_bandit.py b/vizier/_src/algorithms/designers/gp_bandit.py
index 6b70328d7..877fc8df5 100644
--- a/vizier/_src/algorithms/designers/gp_bandit.py
+++ b/vizier/_src/algorithms/designers/gp_bandit.py
@@ -38,13 +38,11 @@
 from vizier._src.algorithms.designers.gp import output_warpers
 from vizier._src.algorithms.optimizers import eagle_strategy as es
 from vizier._src.algorithms.optimizers import vectorized_base as vb
-from vizier._src.jax import gp_bandit_utils
 from vizier._src.jax import stochastic_process_model as sp
 from vizier._src.jax import types
 from vizier._src.jax.models import tuned_gp_models
 from vizier.jax import optimizers
 from vizier.pyvizier import converters
-from vizier.pyvizier.converters import feature_mapper
 from vizier.pyvizier.converters import padding
 from vizier.utils import profiler

@@ -124,10 +122,6 @@ class VizierGPBandit(vza.Designer, vza.Predictor):
   _output_warper_pipeline: output_warpers.OutputWarperPipeline = attr.field(
       init=False
   )
-  # TODO: Remove this.
-  _feature_mapper: Optional[
-      feature_mapper.ContinuousCategoricalFeatureMapper
-  ] = attr.field(init=False, default=None)
   _last_computed_state: _GPBanditState = attr.field(init=False)

   default_acquisition_optimizer_factory = vb.VectorizedOptimizerFactory(
@@ -155,31 +149,8 @@ def __attrs_post_init__(self):
     )
     self._output_warper_pipeline = output_warpers.create_default_warper()
-    # TODO: Get rid of this when Vectorized optimizers operate on CACV.
-    self._one_hot_converter = (
-        converters.TrialToArrayConverter.from_study_config(
-            self._problem,
-            scale=True,
-            pad_oovs=True,
-            max_discrete_indices=0,
-            flip_sign_for_minimization_metrics=True,
-        )
-    )
-    self._padded_one_hot_converter = (
-        converters.PaddedTrialToArrayConverter.from_study_config(
-            self._problem,
-            scale=True,
-            pad_oovs=True,
-            padding_schedule=self._padding_schedule,
-            max_discrete_indices=0,
-            flip_sign_for_minimization_metrics=True,
-        )
-    )
-    self._feature_mapper = feature_mapper.ContinuousCategoricalFeatureMapper(
-        self._one_hot_converter
-    )
     self._acquisition_optimizer = self._acquisition_optimizer_factory(
-        self._padded_one_hot_converter
+        self._converter
     )
     acquisition_problem = copy.deepcopy(self._problem)
     if isinstance(
@@ -356,30 +327,33 @@ def _optimize_acquisition(
   ) -> list[vz.Trial]:
     start_time = datetime.datetime.now()
     # Set up optimizer and run
-    seed_features = vb.trials_to_sorted_array(
-        self._trials, self._padded_one_hot_converter
-    )
-    seed_features_unpad = vb.trials_to_sorted_array(
-        self._trials, self._one_hot_converter
-    )
+    seed_features = vb.trials_to_sorted_array(self._trials, self._converter)
     acq_rng, self._rng = jax.random.split(self._rng)
-    # TODO: Remove this when Vectorized Optimizer works on CACV.
-    cacpa = self._converter.to_features(self._trials)
-    one_hot_to_modelinput = gp_bandit_utils.make_one_hot_to_modelinput_fn(
-        seed_features_unpad, self._feature_mapper, cacpa
-    )
+    score = scoring_fn.score
+    score_with_aux = scoring_fn.score_with_aux

-    score = lambda xs: scoring_fn.score(one_hot_to_modelinput(xs))
-    score_with_aux = lambda xs: scoring_fn.score_with_aux(
-        one_hot_to_modelinput(xs)
-    )
+    prior_features = None
+    if seed_features is not None:
+      continuous = seed_features.continuous.unpad()
+      continuous = types.PaddedArray.from_array(
+          continuous,
+          (continuous.shape[0], seed_features.continuous.shape[1]),
+          fill_value=np.nan,
+      )
+      categorical = seed_features.categorical.unpad()
+      categorical = types.PaddedArray.from_array(
+          categorical,
+          (categorical.shape[0], seed_features.categorical.shape[1]),
+          fill_value=-1,
+      )
+      prior_features = types.ModelInput(continuous, categorical)

     best_candidates: vb.VectorizedStrategyResults = eqx.filter_jit(
         self._acquisition_optimizer
     )(
         score,
-        prior_features=seed_features,
+        prior_features=prior_features,
         count=count,
         seed=acq_rng,
         score_with_aux_fn=score_with_aux,
@@ -410,7 +384,7 @@ def _optimize_acquisition(
     # space); also append debug information like model predictions.
     logging.info('Converting the optimization result into suggestions...')
     return vb.best_candidates_to_trials(
-        best_candidates, self._one_hot_converter
+        best_candidates, self._converter
     )  # [N, D]

   @profiler.record_runtime
diff --git a/vizier/_src/algorithms/optimizers/eagle_optimizer_convergence_test.py b/vizier/_src/algorithms/optimizers/eagle_optimizer_convergence_test.py
index 1a4d6c3ff..de74eb3ca 100644
--- a/vizier/_src/algorithms/optimizers/eagle_optimizer_convergence_test.py
+++ b/vizier/_src/algorithms/optimizers/eagle_optimizer_convergence_test.py
@@ -33,22 +33,6 @@
 from absl.testing import parameterized


-def randomize_array(converter: converters.TrialToArrayConverter) -> jax.Array:
-  """Generate a random array of features to be used as score_fn shift."""
-  features_arrays = []
-  for spec in converter.output_specs:
-    if spec.type == converters.NumpyArraySpecType.ONEHOT_EMBEDDING:
-      dim = spec.num_dimensions - spec.num_oovs
-      features_arrays.append(
-          jnp.eye(spec.num_dimensions)[np.random.randint(0, dim)]
-      )
-    elif spec.type == converters.NumpyArraySpecType.CONTINUOUS:
-      features_arrays.append(np.random.uniform(0.4, 0.6, size=(1,)))
-    else:
-      raise ValueError(f'The type {spec.type} is not supported!')
-  return jnp.hstack(features_arrays)
-
-
 def create_continuous_problem(
     n_features: int,
     problem: Optional[vz.ProblemStatement] = None) -> vz.ProblemStatement:
@@ -80,16 +64,25 @@ def create_mix_problem(n_features: int,


 # TODO: Change to bbob functions when they can support batching.
-def sphere(x: types.Array) -> jax.Array:
-  return -jnp.sum(jnp.square(x), axis=-1)
+def sphere(x: types.ModelInput) -> jax.Array:
+  return -(
+      jnp.sum(jnp.square(x.continuous.padded_array), axis=-1)
+      + 0.1 * jnp.sum(jnp.square(x.categorical.padded_array), axis=-1)
+  )


-def rastrigin_d10(x: types.Array) -> jax.Array:
+def _rastrigin_d10_part(x: types.Array) -> jax.Array:
   return 10 * jnp.sum(jnp.cos(2 * np.pi * x), axis=-1) - jnp.sum(
       jnp.square(x), axis=-1
   )


+def rastrigin_d10(x: types.ModelInput) -> jax.Array:
+  return _rastrigin_d10_part(x.continuous.padded_array) + _rastrigin_d10_part(
+      0.01 * x.categorical.padded_array
+  )
+
+
 class EagleOptimizerConvegenceTest(parameterized.TestCase):
   """Test optimizing an acquisition functions using vectorized Eagle Strategy.
""" @@ -108,11 +101,9 @@ def test_converges(self, create_problem_fn, n_features, score_fn): logging.info('Starting a new convergence test (n_features: %s)', n_features) evaluations = 20_000 problem = create_problem_fn(n_features) - converter = converters.TrialToArrayConverter.from_study_config(problem) + converter = converters.TrialToModelInputConverter.from_problem(problem) eagle_strategy_factory = eagle_strategy.VectorizedEagleStrategyFactory( eagle_config=eagle_strategy.EagleStrategyConfig()) - optimum_features = randomize_array(converter) - shifted_score_fn = lambda x, shift=optimum_features: score_fn(x - shift) shifted_score_fn = score_fn random_strategy_factory = rvo.random_strategy_factory # Run simple regret convergence test. diff --git a/vizier/_src/algorithms/optimizers/eagle_param_handler.py b/vizier/_src/algorithms/optimizers/eagle_param_handler.py deleted file mode 100644 index a2df0c04d..000000000 --- a/vizier/_src/algorithms/optimizers/eagle_param_handler.py +++ /dev/null @@ -1,254 +0,0 @@ -# Copyright 2023 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -"""Utils for vectorized eagle strategy.""" - -from typing import Optional - -from flax import struct -import jax -from jax import numpy as jnp -from vizier.pyvizier import converters - - -@struct.dataclass -class EagleParamHandler: - """Vectorized eagle strategy utils. - - The class is used to account for the different types of Vizier parameters and - incorporate their naunces into the vectorized eagle strategy. - - Attributes: - n_feature_dimensions: The total number of feature indices. - n_categorical: The number of CATEGORICAL associated indices. - has_categorical: A flag indicating if at least one feature is categorical. - perturbation_factors: Array of continuous and categorical perturbation - factors. - _categorical_param_mask: A 2D array (n_categorical, n_feature_dimensions) - that for each categorical parameters (row) has 1s in its associated - feature indices. The array is used for sampling categorical values. - _categorical_mask: A 1D array (n_feature_dimensions,) with 1s in the indices - of categorical features and otherwise 0s. The array is used for sampling - categorical values. - _tiebreak_mask: A 1D array (n_feature_dimensions,) with multiplies of - epsilons used to tie breaking. The array is used for sampling categorical - values. - _oov_mask: A 1D array (n_feature_dimensions,) with 1s in the non-oov - indices. The array is used to generate random features with 0 value in the - OOV indices. 
-    _epsilon: A small value used in tie-breaker
-  """
-  # Internal variables
-  perturbation_factors: jax.Array
-  _categorical_params_mask: jax.Array
-  _categorical_mask: jax.Array
-  _tiebreak_array: jax.Array
-  _oov_mask: Optional[jax.Array]
-  _epsilon: float
-  # Public variables created by the class
-  n_feature_dimensions: int = struct.field(pytree_node=False)
-  n_categorical: int = struct.field(pytree_node=False)
-  has_categorical: bool = struct.field(pytree_node=False)
-
-  @classmethod
-  def build(
-      cls,
-      converter: converters.TrialToArrayConverter,
-      categorical_perturbation_factor: float,
-      pure_categorical_perturbation_factor: float,
-      epsilon: float = 1e-5,
-  ) -> 'EagleParamHandler':
-    """Docstring."""
-
-    valid_types = [
-        converters.NumpyArraySpecType.ONEHOT_EMBEDDING,
-        converters.NumpyArraySpecType.CONTINUOUS
-    ]
-    unsupported_params = sum(
-        1 for spec in converter.output_specs if spec.type not in valid_types
-    )
-    if unsupported_params:
-      raise ValueError('Only CATEGORICAL/CONTINUOUS parameters are supported!')
-
-    n_feature_dimensions = converter.to_features([]).shape[-1]
-    n_categorical = sum(
-        1
-        for spec in converter.output_specs
-        if spec.type == converters.NumpyArraySpecType.ONEHOT_EMBEDDING
-    )
-    has_categorical = n_categorical > 0
-    all_features_categorical = n_feature_dimensions == n_categorical
-    oov_mask = None
-    tiebreak_array = None
-    categorical_mask = None
-    categorical_params_mask = None
-    if has_categorical:
-      categorical_params_mask = jnp.zeros((n_categorical, n_feature_dimensions))
-      oov_mask = jnp.ones((n_feature_dimensions,))
-      row = 0
-      col = 0
-      # Create a flag to indicate if the converter uses OOV padding. If none of
-      # the CATEGORICAL params use padding the 'oov_mask' is set to None.
-      is_pad_oov = False
-      for spec in converter.output_specs:
-        if spec.type == converters.NumpyArraySpecType.ONEHOT_EMBEDDING:
-          n_dim = spec.num_dimensions
-          categorical_params_mask = categorical_params_mask.at[
-              row, col : col + n_dim
-          ].set(1.0)
-          if spec.num_oovs:
-            oov_mask = oov_mask.at[col + n_dim - 1].set(0.0)
-            is_pad_oov = True
-          row += 1
-          col += n_dim
-        else:
-          col += 1
-
-      oov_mask = oov_mask if is_pad_oov else None
-      tiebreak_array = -epsilon * jnp.arange(1, n_feature_dimensions + 1)
-      categorical_mask = jnp.sum(categorical_params_mask, axis=0)
-
-    perturbation_factors = []
-
-    if all_features_categorical:
-      for spec in converter.output_specs:
-        perturbation_factors.extend(
-            [pure_categorical_perturbation_factor] * spec.num_dimensions
-        )
-    else:
-      for spec in converter.output_specs:
-        if spec.type == converters.NumpyArraySpecType.ONEHOT_EMBEDDING:
-          perturbation_factors.extend(
-              [categorical_perturbation_factor] * spec.num_dimensions
-          )
-
-        elif spec.type == converters.NumpyArraySpecType.CONTINUOUS:
-          perturbation_factors.append(1.0)
-      # Add any extra dimensions at the end.
-      perturbation_factors.extend(
-          [0.0] * (n_feature_dimensions - len(perturbation_factors))
-      )
-    perturbation_factors = jnp.array(perturbation_factors)
-    return EagleParamHandler(
-        n_feature_dimensions=n_feature_dimensions,
-        n_categorical=n_categorical,
-        has_categorical=has_categorical,
-        perturbation_factors=perturbation_factors,
-        _categorical_params_mask=categorical_params_mask,
-        _categorical_mask=categorical_mask,
-        _tiebreak_array=tiebreak_array,
-        _oov_mask=oov_mask,
-        _epsilon=epsilon,
-    )
-
-  def sample_categorical(
-      self, features: jax.Array, seed: jax.random.KeyArray
-  ) -> jax.Array:
-    """Sample categorical features.
-
-    The categorical sampling is used before returning suggestion to ensure that
-    only actual categorical values are suggested. Non categorical features are
-    left unchanged.
-
-    For example: If 'features' has one categorical parameter with 3 values and
-    one float parameter the conversion will be take the form of:
-    (0.1, 0.3, 0.5, 0.4578) -> (0, 1, 0, 0.4578).
-
-    Implementation details:
-    ----------------------
-    1. For each categorical parameter, isolate its indices in a separate row.
-    2. Normalize new row values to probabilities through dividing by the sum.
-    3. Create cummulatitive sum of probabilites to generat a CDF.
-    4. Add small incremental values to CDF to tie-break (explained more later).
-    5. Randomize uniform values for each row to determine which value to sample.
-    6. Find the minimum index that its CDF > uniform. See code around
-    'sampled_categorical_params' below for more detail and for why we need to
-    add the tie-breaking array values.
-    7. Flatten each row sampled values and combine with original features.
-
-    Arguments:
-      features: (batch_size, n_parallel, n_feature_dimensions)
-      seed: Random seed.
-
-    Returns:
-      The features with sampled categorical parameters.
-      (batch_size, n_parallel, n_feature_dimensions)
-    """
-    if not self.has_categorical:
-      return features
-    batch_shape = features.shape[:-1]
-
-    # Mask each row (which represents a categorical param) to remove values in
-    # indices that aren't associated with the parameter indices.
-    param_features = (
-        features[..., jnp.newaxis, :] * self._categorical_params_mask
-    )
-    # Create probabilities from non-normalized parameter features values.
-    probs = param_features / jnp.sum(param_features, axis=-1, keepdims=True)
-    # Generate random uniform values to use for sampling.
-    # TODO: Pre-compute random values in batch to improve performance.
-    unifs = jax.random.uniform(
-        seed, shape=batch_shape + (self.n_categorical, 1)
-    )
-    # Find the locations of the indices that exceed random values.
-    locs = jnp.cumsum(probs, axis=-1) >= unifs
-    # Multiply by 'categorical_mask' to mask off cumsum in non-categorical
-    # indices, and add 'tiebreak_mask' to find the first index.
-    masked_locs = locs * self._categorical_params_mask + self._tiebreak_array
-    # Generate the samples so that each parameter features has a single 1 value.
-    sampled_categorical_params = jnp.trunc(
-        masked_locs / jnp.max(masked_locs, axis=-1, keepdims=True)
-    )
-    # Flatten all the categories features to dimension
-    # (batch_size, n_parallel, n_feature_dimensions)
-    sampled_features = jnp.sum(sampled_categorical_params, axis=-2)
-    # Mask categorical features and add the new sampled categorical values.
-    return sampled_features + features * (1 - self._categorical_mask)
-
-  def random_features(
-      self,
-      batch_size: int,
-      n_parallel: int,
-      seed: jax.random.KeyArray,
-  ) -> jax.Array:
-    """Create random features with uniform distribution.
-
-    In case there are CATEGORICAL features with OOV we use the 'oov_mask' which
-    is 1D numpy array (n_feature_dimensions,) with 0s in OOV indices and
-    otherwise 1s. After multiplying (and broadcasting) the randomly generated
-    features with the mask we're guaranteed that no features will be created in
-    OOV indices. Therefore when mutating fireflies' features (index by index),
-    the final suggested features will also have 0s in the OOV indices as
-    desired.
-
-    Arguments:
-      batch_size:
-      n_parallel:
-      seed: Random seed.
-
-    Returns:
-      The random features with out of vocabulary indices zeroed out.
- """ - features = jax.random.uniform( - seed, shape=(batch_size, n_parallel, self.n_feature_dimensions) - ) - if self._oov_mask is not None: - # Don't create random values for CATEGORICAL features OOV indices. - # Broadcasting: - # (batch_size, n_parallel, n_feature_dimensions) - # x (n_feature_dimensions,) - features = features * self._oov_mask - return features diff --git a/vizier/_src/algorithms/optimizers/eagle_strategy.py b/vizier/_src/algorithms/optimizers/eagle_strategy.py index 7f4dac782..d438444dc 100644 --- a/vizier/_src/algorithms/optimizers/eagle_strategy.py +++ b/vizier/_src/algorithms/optimizers/eagle_strategy.py @@ -73,17 +73,19 @@ import enum import logging import math -from typing import Optional, Tuple, Union +from typing import Optional, Tuple import attr from flax import struct import jax from jax import numpy as jnp -from vizier._src.algorithms.optimizers import eagle_param_handler +from tensorflow_probability.substrates import jax as tfp from vizier._src.algorithms.optimizers import vectorized_base as vb from vizier._src.jax import types from vizier.pyvizier import converters +tfd = tfp.distributions + @enum.unique class MutateNormalizationType(enum.IntEnum): @@ -105,6 +107,8 @@ class EagleStrategyConfig: perturbation: The default amount of noise for perturbation. categorical_perturbation_factor: A factor to apply on categorical params. pure_categorical_perturbation_factor: A factor on purely categorical space. + prob_same_category_without_perturbation: Baseline probability of selecting + the same category. perturbation_lower_bound: The threshold below flies are removed from pool. penalize_factor: The perturbation decrease for unsuccessful flies. pool_size_exponent: An exponent for computing pool size based on search @@ -125,8 +129,9 @@ class EagleStrategyConfig: negative_gravity: float = 0.008 # Perturbation perturbation: float = 0.16 - categorical_perturbation_factor: float = 25 + categorical_perturbation_factor: float = 1.0 pure_categorical_perturbation_factor: float = 30 + prob_same_category_without_perturbation: float = 0.98 # Penalty perturbation_lower_bound: float = 7e-5 penalize_factor: float = 7e-1 @@ -153,10 +158,7 @@ class VectorizedEagleStrategyFactory(vb.VectorizedStrategyFactory): def __call__( self, - converter: Union[ - converters.TrialToArrayConverter, - converters.PaddedTrialToArrayConverter, - ], + converter: converters.TrialToModelInputConverter, suggestion_batch_size: Optional[int] = None, ) -> "VectorizedEagleStrategy": """Create a new vectorized eagle strategy. @@ -175,18 +177,42 @@ def __call__( Returns: A new instance of VectorizedEagleStrategy. 
""" - param_handler = eagle_param_handler.EagleParamHandler.build( - converter, - self.eagle_config.categorical_perturbation_factor, - self.eagle_config.pure_categorical_perturbation_factor, + valid_types = [ + converters.NumpyArraySpecType.DISCRETE, + converters.NumpyArraySpecType.CONTINUOUS, + ] + if any( + spec.type not in valid_types + for spec in ( + list(converter.output_specs.continuous) + + list(converter.output_specs.categorical) + ) + ): + raise ValueError("Only DISCRETE/CONTINUOUS parameters are supported!") + + empty_features = converter.to_features([]) + n_feature_dimensions_with_padding = types.ContinuousAndCategorical[int]( + continuous=empty_features.continuous.shape[-1], + categorical=empty_features.categorical.shape[-1], + ) + n_feature_dimensions = types.ContinuousAndCategorical( + continuous=len(converter.output_specs.continuous), + categorical=len(converter.output_specs.categorical), ) - n_feature_dimensions_with_padding = converter.to_features([]).shape[-1] - n_features = len(converter.output_specs) - if self.eagle_config.pool_size > 0: - # This allow to override the pool size computation. - pool_size = self.eagle_config.pool_size + categorical_sizes = [] + for spec in converter.output_specs.categorical: + categorical_sizes.append(spec.bounds[1]) + if categorical_sizes: + max_categorical_size = max(categorical_sizes) else: + max_categorical_size = 0 + + n_features = ( + n_feature_dimensions.continuous + n_feature_dimensions.categorical + ) + pool_size = self.eagle_config.pool_size + if pool_size == 0: pool_size = 10 + int( 0.5 * n_features + n_features**self.eagle_config.pool_size_exponent ) @@ -202,15 +228,14 @@ def __call__( suggestion_batch_size = pool_size # Use priors to populate Eagle state return VectorizedEagleStrategy( - param_handler=param_handler, - n_features=n_features, - n_feature_dimensions=sum( - spec.num_dimensions for spec in converter.output_specs - ), + n_feature_dimensions=n_feature_dimensions, n_feature_dimensions_with_padding=n_feature_dimensions_with_padding, batch_size=suggestion_batch_size, config=self.eagle_config, pool_size=pool_size, + categorical_sizes=tuple(categorical_sizes), + max_categorical_size=max_categorical_size, + dtype=converter._impl.dtype, ) @@ -219,58 +244,98 @@ class VectorizedEagleStrategyState: """Container for Eagle strategy state.""" iterations: jax.Array # Scalar integer. - features: jax.Array # Shape (pool_size, n_parallel, n_features). + features: vb.VectorizedOptimizerInput # (pool_size, n_parallel, n_features). rewards: jax.Array # Shape (pool_size,). best_reward: jax.Array # Scalar float. perturbations: jax.Array # Shape (pool_size,). 
+def _compute_features_dist(
+    x_batch: vb.VectorizedOptimizerInput, x_pool: vb.VectorizedOptimizerInput
+) -> jax.Array:
+  """Computes distance between features (or parallel feature batches)."""
+  dist = jnp.zeros([], dtype=x_batch.continuous.dtype)
+  if x_batch.continuous.size > 0:
+    x_batch_cont = jnp.reshape(
+        x_batch.continuous, (x_batch.continuous.shape[0], -1)
+    )
+    x_pool_cont = jnp.reshape(
+        x_pool.continuous, (x_pool.continuous.shape[0], -1)
+    )
+    continuous_dists = (
+        jnp.sum(x_batch_cont**2, axis=-1, keepdims=True)
+        + jnp.sum(x_pool_cont**2, axis=-1)
+        - 2.0 * jnp.matmul(x_batch_cont, x_pool_cont.T)
+    )  # shape (batch_size, pool_size)
+    dist = dist + continuous_dists
+
+  if x_batch.categorical.size > 0:
+    x_batch_cat = jnp.reshape(
+        x_batch.categorical, (x_batch.categorical.shape[0], -1)
+    )
+    x_pool_cat = jnp.reshape(
+        x_pool.categorical, (x_pool.categorical.shape[0], -1)
+    )
+    categorical_diffs = (x_batch_cat[..., jnp.newaxis, :] != x_pool_cat).astype(
+        x_batch.continuous.dtype
+    )
+    dist = dist + jnp.sum(categorical_diffs, axis=-1)
+  return dist
+
+
 @struct.dataclass
-class VectorizedEagleStrategy(vb.VectorizedStrategy):
+class VectorizedEagleStrategy(
+    vb.VectorizedStrategy[VectorizedEagleStrategyState]
+):
   """Eagle strategy implementation for maximization problem based on Numpy.

   Attributes:
-    converter: The converter used for the optimization problem.
     config: The Eagle strategy configuration.
     n_features: The number of features.
     batch_size: The number of suggestions generated at each suggestion call.
     pool_size: The total number of flies in the pool.
   """

-  param_handler: eagle_param_handler.EagleParamHandler
-  n_features: int
-  n_feature_dimensions: int
-  n_feature_dimensions_with_padding: int = struct.field(pytree_node=False)
+  n_feature_dimensions: types.ContinuousAndCategorical[int]
+  n_feature_dimensions_with_padding: types.ContinuousAndCategorical[int] = (
+      struct.field(pytree_node=False)
+  )
+  categorical_sizes: Tuple[int] = struct.field(pytree_node=False)
+  max_categorical_size: int = struct.field(pytree_node=False)
   pool_size: int = struct.field(pytree_node=False)
+  dtype: jnp.dtype = struct.field(pytree_node=False)
   batch_size: Optional[int] = struct.field(pytree_node=False, default=None)
   config: EagleStrategyConfig = struct.field(
       default_factory=EagleStrategyConfig
   )

-  def __post_init__(self):
-    logging.info("Eagle class attributes:\n%s", self)
-    logging.info("Eagle configuration:\n%s", self.config)
-
   def init_state(
       self,
       seed: jax.random.KeyArray,
       n_parallel: int = 1,
       *,
-      prior_features: Optional[types.Array] = None,
+      prior_features: Optional[vb.VectorizedOptimizerInput] = None,
       prior_rewards: Optional[types.Array] = None,
   ) -> VectorizedEagleStrategyState:
     """Initializes the state."""
     if prior_features is not None and prior_rewards is not None:
-      if prior_features.shape[1] != n_parallel:
+      if prior_features.continuous.shape[1] != n_parallel:
         raise ValueError(
-            f"`prior_features` dimension 1 ({prior_features.shape[1]}) "
+            "`prior_features.continuous` dimension 1 "
+            f"({prior_features.continuous.shape[1]}) "
+            f"doesn't match n_parallel ({n_parallel})!"
+        )
+      if prior_features.categorical.shape[1] != n_parallel:
+        raise ValueError(
+            "`prior_features.categorical` dimension 1 "
+            f"({prior_features.categorical.shape[1]}) "
             f"doesn't match n_parallel ({n_parallel})!"
        )
       init_features = self._populate_pool_with_prior_trials(
           seed, prior_features, prior_rewards
       )
     else:
-      init_features = self.param_handler.random_features(
+      init_features = self._sample_random_features(
           self.pool_size, n_parallel=n_parallel, seed=seed
       )
     return VectorizedEagleStrategyState(
@@ -281,12 +346,43 @@ def init_state(
         perturbations=jnp.ones(self.pool_size) * self.config.perturbation,
     )

+  def _sample_random_features(
+      self, num_samples: int, n_parallel: int, seed: jax.random.KeyArray
+  ) -> vb.VectorizedOptimizerInput:
+    cont_seed, cat_seed = jax.random.split(seed)
+
+    if self.max_categorical_size > 0:
+      sizes = jnp.array(self.categorical_sizes)[:, jnp.newaxis]
+      logits = jnp.where(
+          jnp.arange(self.max_categorical_size) < sizes, 0.0, -jnp.inf
+      )
+      random_categorical_features = (
+          tfd.Categorical(logits=logits)
+          .sample((num_samples, n_parallel), seed=cat_seed)
+          .astype(types.INT_DTYPE)
+      )
+    else:
+      random_categorical_features = jnp.zeros(
+          [num_samples, n_parallel, 0], types.INT_DTYPE
+      )
+    return types.ContinuousAndCategoricalArray(
+        continuous=jax.random.uniform(
+            cont_seed,
+            shape=(
+                num_samples,
+                n_parallel,
+                self.n_feature_dimensions_with_padding.continuous,
+            ),
+        ),
+        categorical=random_categorical_features,
+    )
+
   def _populate_pool_with_prior_trials(
       self,
       seed: jax.random.KeyArray,
-      prior_features: types.Array,
+      prior_features: types.ContinuousAndCategoricalArray,
       prior_rewards: types.Array,
-  ) -> jax.Array:
+  ) -> types.ContinuousAndCategoricalArray:
     """Populate the pool with prior trials.

     Args:
@@ -304,22 +400,40 @@ def _populate_pool_with_prior_trials(
     """
     if prior_features is None or prior_rewards is None:
       raise ValueError("One of prior features / prior rewards wasn't provided!")
-    if prior_features.shape[0] != prior_rewards.shape[0]:
-      raise ValueError(
-          f"prior features shape ({prior_features.shape[0]}) doesn't match"
-          f" prior rewards shape ({prior_rewards.shape[0]})!"
-      )
-    if prior_features.shape[-1] != self.n_feature_dimensions_with_padding:
-      raise ValueError(
-          f"prior features shape ({prior_features.shape[-1]}) doesn't match"
-          f" n_features {self.n_feature_dimensions_with_padding}!"
-      )
+    if prior_features.continuous is not None:
+      continuous_obs, _, continuous_dim = prior_features.continuous.shape
+      if continuous_obs != prior_rewards.shape[0]:
+        raise ValueError(
+            f"prior continuous features shape ({continuous_obs}) doesn't match"
+            f" prior rewards shape ({prior_rewards.shape[0]})!"
+        )
+      expected_dim = self.n_feature_dimensions_with_padding.continuous
+      if continuous_dim != expected_dim:
+        raise ValueError(
+            f"prior continuous features shape ({continuous_dim}) doesn't match "
+            f"n_features {expected_dim}!"
+        )
+    if prior_features.categorical is not None:
+      categorical_obs, _, categorical_dim = prior_features.categorical.shape
+      if categorical_obs != prior_rewards.shape[0]:
+        raise ValueError(
+            f"prior categorical features shape ({categorical_obs}) doesn't "
+            f"match prior rewards shape ({prior_rewards.shape[0]})!"
+        )
+      expected_dim = self.n_feature_dimensions_with_padding.categorical
+      if categorical_dim != expected_dim:
+        raise ValueError(
+            f"prior categorical features shape ({categorical_dim}) doesn't"
+            f" match n_features {expected_dim}!"
+        )
     if len(prior_rewards.shape) > 1:
       raise ValueError("prior rewards is expected to be 1D array!")

     # Reverse the order of prior trials to assign more weight to recent trials.
-    flipped_prior_features = jnp.flip(prior_features, axis=0)
+    flipped_prior_features = jax.tree_util.tree_map(
+        lambda x: jnp.flip(x, axis=0), prior_features
+    )
     flipped_prior_rewards = jnp.flip(prior_rewards, axis=0)

     # Fill pool with random features.
@@ -328,36 +442,52 @@ def _populate_pool_with_prior_trials(
     )
     seed1, seed2 = jax.random.split(seed)
-    _, n_parallel, _ = prior_features.shape
-    init_features = self.param_handler.random_features(
+    n_parallel = flipped_prior_features.continuous.shape[1]
+    init_features = self._sample_random_features(
         n_random_flies, n_parallel, seed1
     )
     pool_left_space = self.pool_size - n_random_flies

-    if prior_features.shape[0] < pool_left_space:
+    if prior_rewards.shape[0] < pool_left_space:
       # Less prior trials than left space. Take all prior trials for the pool.
-      init_features = jnp.concatenate([init_features, flipped_prior_features])
+      init_features = jax.tree_util.tree_map(
+          lambda x, y: jnp.concatenate([x, y]),
+          init_features,
+          flipped_prior_features,
+      )
       # Randomize the rest of the pool fireflies.
-      random_features = self.param_handler.random_features(
-          self.pool_size - len(init_features), n_parallel, seed2
+      random_features = self._sample_random_features(
+          self.pool_size - init_features.continuous.shape[0], n_parallel, seed2
+      )
+      return jax.tree_util.tree_map(
+          lambda x, y: jnp.concatenate([x, y]), init_features, random_features
       )
-      return jnp.concatenate([init_features, random_features])
     else:
       # More prior trials than left space. Iteratively populate the pool.
-      tmp_features = flipped_prior_features[:pool_left_space]
+      tmp_features = jax.tree_util.tree_map(
+          lambda x: x[:pool_left_space], flipped_prior_features
+      )
       tmp_rewards = flipped_prior_rewards[:pool_left_space]

       def _loop_body(i, args):
         features, rewards = args
         ind = jnp.argmin(
-            jnp.sum(
-                jnp.square(flipped_prior_features[i] - features), axis=(-1, -2)
-            )
-        )
+            _compute_features_dist(
+                jax.tree_util.tree_map(
+                    lambda x: x[i][jnp.newaxis], flipped_prior_features
+                ),
+                features,
+            ),
+            axis=-1,
+        )[0]
         return jax.lax.cond(
             rewards[ind] < flipped_prior_rewards[i],
             lambda: (
-                features.at[ind].set(flipped_prior_features[i]),
+                jax.tree_util.tree_map(
+                    lambda f, pf: f.at[ind].set(pf[i]),
+                    features,
+                    flipped_prior_features,
+                ),
                 rewards.at[ind].set(flipped_prior_rewards[i]),
             ),
             lambda: (features, rewards),
@@ -367,11 +497,13 @@ def _loop_body(i, args):
       # the for-loop.
       tmp_features, _ = jax.lax.fori_loop(
           lower=pool_left_space,
-          upper=prior_features.shape[0],
+          upper=prior_rewards.shape[0],
           body_fun=_loop_body,
           init_val=(tmp_features, tmp_rewards),
       )
-      return jnp.concatenate([init_features, tmp_features])
+      return jax.tree_util.tree_map(
+          lambda x, y: jnp.concatenate([x, y]), init_features, tmp_features
+      )

   @property
   def suggestion_batch_size(self) -> int:
@@ -383,7 +515,7 @@ def suggest(
       seed: jax.random.KeyArray,
       state: VectorizedEagleStrategyState,
       n_parallel: int = 1,
-  ) -> jax.Array:
+  ) -> vb.VectorizedOptimizerInput:
     """Suggest new mutated and perturbed features.
    After initializing, at each call `batch_size` fireflies are mutated to
@@ -402,8 +534,9 @@ def suggest(
     """
     batch_id = state.iterations % (self.pool_size // self.batch_size)
     start = batch_id * self.batch_size
-    features_batch = jax.lax.dynamic_slice_in_dim(
-        state.features, start, self.batch_size
+    features_batch = jax.tree_util.tree_map(
+        lambda f: jax.lax.dynamic_slice_in_dim(f, start, self.batch_size),
+        state.features,
     )
     rewards_batch = jax.lax.dynamic_slice_in_dim(
         state.rewards, start, self.batch_size
@@ -411,20 +544,20 @@ def suggest(
     )
     perturbations_batch = jax.lax.dynamic_slice_in_dim(
         state.perturbations, start, self.batch_size
     )
-    features_seed, perturbations_seed, cat_seed = jax.random.split(seed, num=3)
+    features_seed, perturbations_seed = jax.random.split(seed)

     def _mutate_features(features_batch_):
-      mutated_features = self._create_features(
+      perturbations = self._create_random_perturbations(
+          perturbations_batch, n_parallel, perturbations_seed
+      )
+      return self._create_features(
           state.features,
           state.rewards,
           features_batch_,
           rewards_batch,
+          perturbations,
           features_seed,
       )
-      perturbations = self._create_random_perturbations(
-          perturbations_batch, n_parallel, perturbations_seed
-      )
-      return mutated_features + perturbations

     # If the strategy is still initializing, return the random/prior features.
     new_features = jax.lax.cond(
@@ -434,22 +567,25 @@ def suggest(
         features_batch,
     )

-    new_features = self.param_handler.sample_categorical(new_features, cat_seed)
     # TODO: The range of features is not always [0, 1].
     # Specifically, for features that are single-point, it can be [0, 0]; we
     # also want this code to be aware of the feature's bounds to enable
     # contextual bandit operation. Note that if a parameter's bound changes,
     # we might also want to change the firefly noise or normalizations.
-    return jnp.clip(new_features, 0.0, 1.0)
+    return vb.VectorizedOptimizerInput(
+        continuous=jnp.clip(new_features.continuous, 0.0, 1.0),
+        categorical=new_features.categorical,
+    )

   def _create_features(
       self,
-      features: jax.Array,
+      features: vb.VectorizedOptimizerInput,
       rewards: jax.Array,
-      features_batch: jax.Array,
+      features_batch: vb.VectorizedOptimizerInput,
       rewards_batch: jax.Array,
+      perturbations_batch: types.ContinuousAndCategoricalArray,
       seed: jax.random.KeyArray,
-  ) -> jax.Array:
+  ) -> vb.VectorizedOptimizerInput:
     """Create new batch of mutated and perturbed features.

     The pool fireflies forces (pull/push) are being normalized to ensure the
@@ -462,6 +598,7 @@ def _create_features(
       rewards: (pool_size,)
      features_batch: (batch_size, n_parallel, n_features)
       rewards_batch: (batch_size,)
+      perturbations_batch: Random perturbations to add to the mutated features.
       seed: Random seed.

     Returns:
@@ -471,13 +608,7 @@ def _create_features(
     # pool. We use a less numerically precise squared distance formulation to
     # avoid materializing a possibly large intermediate of shape
     # (batch_size, pool_size, n_features).
-    flat_features = jnp.reshape(features, (self.pool_size, -1))
-    flat_features_batch = jnp.reshape(features_batch, (self.batch_size, -1))
-    dists = (
-        jnp.sum(flat_features_batch**2, axis=-1, keepdims=True)
-        + jnp.sum(flat_features**2, axis=-1)
-        - 2.0 * jnp.matmul(flat_features_batch, flat_features.T)
-    )  # shape (batch_size, pool_size)
+    dists = _compute_features_dist(features_batch, features)

     # Compute the scaled direction for applying pull between two flies.
    # scaled_directions[i,j] := direction of force applied by fly 'j' on fly
@@ -493,8 +624,11 @@ def _create_features(
     # Normalize the distance by the number of features.
     # Get the number of non-padded features.
+    n_feature_dimensions = sum(
+        jax.tree_util.tree_leaves(self.n_feature_dimensions)
+    )
     force = jnp.exp(
-        -self.config.visibility * dists / self.n_feature_dimensions * 10.0
+        -self.config.visibility * dists / n_feature_dimensions * 10.0
     )
     scaled_force = scaled_directions * force
     # Handle removed fireflies without updated rewards.
@@ -507,6 +641,7 @@ def _create_features(
     scaled_pulls = jnp.maximum(scaled_force, 0.0)
     scaled_push = jnp.minimum(scaled_force, 0.0)

+    seed, categorical_seed = jax.random.split(seed)
     if self.config.mutate_normalization_type == MutateNormalizationType.MEAN:
       # Divide the push and pull forces by the number of flies participating.
       # Also multiply by normalization_scale.
@@ -555,18 +690,108 @@ def _create_features(
     # features_dist[i, j] := distance between fly 'j' and fly 'i'
     # but avoids materializing the large pairwise distance matrix.
     scale = norm_scaled_pulls + norm_scaled_push
-    flat_features_changes = jnp.matmul(
+    flat_features = jnp.reshape(
+        features.continuous, (features.continuous.shape[0], -1)
+    )
+    flat_features_batch = jnp.reshape(
+        features_batch.continuous, (features_batch.continuous.shape[0], -1)
+    )
+
+    # TODO: Consider computing per batch member.
+    features_changes_continuous = jnp.matmul(
         scale, flat_features
     ) - flat_features_batch * jnp.sum(scale, axis=-1, keepdims=True)
-    features_changes = jnp.reshape(flat_features_changes, features_batch.shape)
-    return features_batch + features_changes
+
+    features_continuous = (
+        features_batch.continuous
+        + jnp.reshape(
+            features_changes_continuous, features_batch.continuous.shape
+        )
+        + perturbations_batch.continuous
+    )
+    if self.max_categorical_size > 0:
+      features_categorical_logits = (
+          self._create_categorical_feature_logits(
+              features.categorical, features_batch.categorical, scale
+          )
+          + perturbations_batch.categorical
+      )
+      features_categorical = tfd.Categorical(
+          logits=features_categorical_logits
+      ).sample(seed=categorical_seed)
+    else:
+      features_categorical = jnp.zeros(
+          features_batch.continuous.shape[:2] + (0,), dtype=types.INT_DTYPE
+      )
+    return vb.VectorizedOptimizerInput(
+        continuous=features_continuous, categorical=features_categorical
+    )
+
+  def _create_logits_vector(
+      self,
+      features_one_category: jax.Array,  # [pool_size]
+      feature_batch_member_one_category: jax.Array,  # scalar integer
+      scale_batch_member: jax.Array,  # [pool_size]
+      feature_size: jax.Array,
+  ):  # scalar
+    categories = jnp.arange(self.max_categorical_size)
+    logit_same_category = jnp.log(
+        self.config.prob_same_category_without_perturbation
+    )
+    logit_different_category = jnp.log(
+        (1.0 - self.config.prob_same_category_without_perturbation)
+        / (feature_size - 1.0)
+    )
+    logits = (
+        jnp.sum(
+            jnp.where(
+                categories[:, jnp.newaxis] == features_one_category,
+                scale_batch_member,
+                0.0,
+            ),
+            axis=-1,
+        )
+        + logit_different_category
+    )
+    logits = jnp.where(categories < feature_size, logits, -jnp.inf)
+    return logits.at[feature_batch_member_one_category].add(
+        -jnp.sum(scale_batch_member)
+        + logit_same_category
+        - logit_different_category
+    )  # [num_categories]
+
+  def _create_logits_one_feature(
+      self,
+      features_one_category: jax.Array,  # [pool_size, num_parallel]
+      features_batch_one_category: jax.Array,  # [batch_size, num_parallel]
+      scale: jax.Array,  # [batch_size, pool_size]
+      feature_size: jax.Array,  # scalar
+  ):
+    return jax.vmap(  # map over batch
+        jax.vmap(self._create_logits_vector, in_axes=(-1, -1, None, None)),
+        in_axes=(None, 0, 0, None),
+    )(
+        features_one_category, features_batch_one_category, scale, feature_size
+    )  # [batch_size, num_parallel, max_num_categories]
+
+  def _create_categorical_feature_logits(
+      self,
+      features: jax.Array,  # [pool_size, num_parallel, num_features]
+      features_batch: jax.Array,  # [batch_size, num_parallel, num_features]
+      scale: jax.Array,  # [batch_size, pool_size]
+  ):
+    return jax.vmap(
+        self._create_logits_one_feature, in_axes=(-1, -1, None, 0), out_axes=2
+    )(
+        features, features_batch, scale, jnp.array(self.categorical_sizes)
+    )  # [batch_size, num_parallel, num_features, num_categories]

   def _create_random_perturbations(
       self,
       perturbations_batch: jax.Array,
       n_parallel: int,
       seed: jax.random.KeyArray,
-  ) -> jax.Array:
+  ) -> types.ContinuousAndCategoricalArray:
     """Create random perturbations for the newly created batch.

     Args:
@@ -577,29 +802,57 @@ def _create_random_perturbations(
       seed: Random seed.

     Returns:
-      perturbations: (batch_size, n_features)
+      perturbations: (batch_size, n_parallel, n_features)
     """
+    cont_seed, cat_seed = jax.random.split(seed)
     # Generate normalized noise for each batch.
-    batch_noise = jax.random.laplace(
-        seed,
+    batch_noise_continuous = jax.random.laplace(
+        cont_seed,
         shape=(
             self.batch_size,
             n_parallel,
-            self.n_feature_dimensions_with_padding,
+            self.n_feature_dimensions_with_padding.continuous,
         ),
     )
-    batch_noise /= jnp.max(jnp.abs(batch_noise), axis=-1, keepdims=True)
-    return (
-        batch_noise
-        * perturbations_batch[:, jnp.newaxis, jnp.newaxis]
-        * self.param_handler.perturbation_factors
+    if self.n_feature_dimensions_with_padding.continuous > 0:
+      batch_noise_continuous /= jnp.max(
+          jnp.abs(batch_noise_continuous), axis=-1, keepdims=True
+      )
+
+    if self.n_feature_dimensions_with_padding.continuous == 0:
+      categorical_perturbation = (
+          self.config.pure_categorical_perturbation_factor
+      )
+    else:
+      categorical_perturbation = self.config.categorical_perturbation_factor
+    batch_noise_categorical = (
+        jax.random.laplace(
+            cat_seed,
+            shape=(
+                self.batch_size,
+                n_parallel,
+                self.n_feature_dimensions_with_padding.categorical,
+                self.max_categorical_size,
+            ),
+        )
+        * categorical_perturbation
+    )
+    return types.ContinuousAndCategoricalArray(
+        continuous=(
+            batch_noise_continuous
+            * perturbations_batch[:, jnp.newaxis, jnp.newaxis]
+        ),
+        categorical=(
+            batch_noise_categorical
+            * perturbations_batch[:, jnp.newaxis, jnp.newaxis, jnp.newaxis]
+        ),
     )

   def update(
       self,
       seed: jax.random.KeyArray,
       state: VectorizedEagleStrategyState,
-      batch_features: types.Array,
+      batch_features: vb.VectorizedOptimizerInput,
       batch_rewards: types.Array,
   ) -> VectorizedEagleStrategyState:
     """Update the firefly pool based on the new batch of results.
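[Reviewer note — commentary for this patch, not diff content.] The `_create_logits_vector` hunk above builds the categorical mutation distribution: a fly keeps its current category with baseline probability `prob_same_category_without_perturbation`, spreads the remainder uniformly over the other categories, and each pool member then adds its signed pull/push mass to the logit of its own category, with the total mass debited from the fly's current category. A minimal single-feature sketch of the same construction, assuming a plain Python int `k` for the category count (illustrative names, not the production API):

    import jax.numpy as jnp

    def mutation_logits(pool_categories, current_category, scale, k,
                        p_same=0.98):
      """pool_categories, scale: (pool_size,) arrays; returns (k,) logits."""
      # Baseline: keep the current category with probability ~p_same and
      # spread the remainder uniformly over the other k - 1 categories.
      logits = jnp.full((k,), jnp.log((1.0 - p_same) / (k - 1)))
      logits = logits.at[current_category].set(jnp.log(p_same))
      # Each pool member votes for its own category with its signed force;
      # the fly's current category is debited by the total mass, mirroring
      # the `.at[...].add(-jnp.sum(scale_batch_member) + ...)` step above.
      votes = jnp.zeros((k,)).at[pool_categories].add(scale)
      votes = votes.at[current_category].add(-jnp.sum(scale))
      return logits + votes

Sampling then goes through `tfd.Categorical(logits=...)`, and padded categories can be masked out with -inf logits exactly as `_create_logits_vector` does via `jnp.where(categories < feature_size, logits, -jnp.inf)`.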
@@ -626,8 +879,11 @@ def _update(batch_features, batch_rewards, batch_perturbations):
       self._update_pool_features_and_rewards(
           batch_features,
           batch_rewards,
-          jax.lax.dynamic_slice_in_dim(
-              state.features, batch_start_ind, self.batch_size
+          jax.tree_util.tree_map(
+              lambda f: jax.lax.dynamic_slice_in_dim(
+                  f, batch_start_ind, self.batch_size
+              ),
+              state.features,
           ),
           jax.lax.dynamic_slice_in_dim(
               state.rewards, batch_start_ind, self.batch_size
@@ -657,8 +913,12 @@ def _update(batch_features, batch_rewards, batch_perturbations):

     return VectorizedEagleStrategyState(
         iterations=state.iterations + 1,
-        features=jax.lax.dynamic_update_slice_in_dim(
-            state.features, new_batch_features, batch_start_ind, axis=0
+        features=jax.tree_util.tree_map(
+            lambda sf, nbf: jax.lax.dynamic_update_slice_in_dim(
+                sf, nbf, batch_start_ind, axis=0
+            ),
+            state.features,
+            new_batch_features,
         ),
         rewards=jax.lax.dynamic_update_slice_in_dim(
             state.rewards, new_batch_rewards, batch_start_ind, axis=0
@@ -674,12 +934,12 @@ def _update(batch_features, batch_rewards, batch_perturbations):

   def _update_pool_features_and_rewards(
       self,
-      batch_features: jax.Array,
+      batch_features: vb.VectorizedOptimizerInput,
       batch_rewards: jax.Array,
-      prev_batch_features: jax.Array,
+      prev_batch_features: vb.VectorizedOptimizerInput,
       prev_batch_rewards: jax.Array,
       perturbations: jax.Array,
-  ) -> Tuple[jax.Array, jax.Array, jax.Array]:
+  ) -> Tuple[vb.VectorizedOptimizerInput, jax.Array, jax.Array]:
     """Update the features and rewards for flies with improved rewards.

     Arguments:
@@ -697,8 +957,10 @@ def _update_pool_features_and_rewards(
     # Find indices of flies that their generated features made an improvement.
     improve_indx = batch_rewards > prev_batch_rewards
     # Update successful flies' with the associated last features and rewards.
-    new_batch_features = jnp.where(
-        improve_indx[..., jnp.newaxis, jnp.newaxis],
+    new_batch_features = jax.tree_util.tree_map(
+        lambda bf, pbf: jnp.where(
+            improve_indx[..., jnp.newaxis, jnp.newaxis], bf, pbf
+        ),
         batch_features,
         prev_batch_features,
     )
@@ -713,12 +975,12 @@ def _update_pool_features_and_rewards(

   def _trim_pool(
       self,
-      batch_features: jax.Array,
+      batch_features: vb.VectorizedOptimizerInput,
       batch_rewards: jax.Array,
       batch_perturbations: jax.Array,
       best_reward: jax.Array,
       seed: jax.random.KeyArray,
-  ) -> Tuple[jax.Array, jax.Array, jax.Array]:
+  ) -> Tuple[vb.VectorizedOptimizerInput, jax.Array, jax.Array]:
     """Trim the pool by replacing unsuccessful fireflies with new random ones.

     A firefly is considered unsuccessful if its current perturbation is below
@@ -743,11 +1005,15 @@ def _trim_pool(
     indx = indx & (batch_rewards != best_reward)
     # Replace fireflies with random features and evaluate rewards.
-    random_features = self.param_handler.random_features(
-        self.batch_size, n_parallel=batch_features.shape[1], seed=seed
+    random_features = self._sample_random_features(
+        self.batch_size,
+        n_parallel=batch_features.continuous.shape[1],
+        seed=seed,
     )
-    new_batch_features = jnp.where(
-        indx[..., jnp.newaxis, jnp.newaxis], random_features, batch_features
+    new_batch_features = jax.tree_util.tree_map(
+        lambda rf, bf: jnp.where(indx[..., jnp.newaxis, jnp.newaxis], rf, bf),
+        random_features,
+        batch_features,
     )
     new_batch_perturbations = jnp.where(
         indx, self.config.perturbation, batch_perturbations
diff --git a/vizier/_src/algorithms/optimizers/eagle_strategy_test.py b/vizier/_src/algorithms/optimizers/eagle_strategy_test.py
index 557805cfb..0baa81237 100644
--- a/vizier/_src/algorithms/optimizers/eagle_strategy_test.py
+++ b/vizier/_src/algorithms/optimizers/eagle_strategy_test.py
@@ -20,8 +20,8 @@
 import jax
 from jax import numpy as jnp
 import numpy as np
+from tensorflow_probability.substrates import jax as tfp
 from vizier import pyvizier as vz
-from vizier._src.algorithms.optimizers import eagle_param_handler
 from vizier._src.algorithms.optimizers import eagle_strategy
 from vizier._src.algorithms.optimizers import vectorized_base as vb
 from vizier.pyvizier import converters
@@ -29,34 +29,105 @@
 from absl.testing import absltest


+tfd = tfp.distributions
+
+
+def _create_logits_vector_simple(
+    categorical_features,
+    categorical_features_batch,
+    scale,
+    categorical_sizes,
+    max_categorical_size,
+    config,
+):
+  n_batch = categorical_features_batch.shape[0]
+  n_feat = categorical_features.shape[0]
+  logits = np.zeros((n_batch, len(categorical_sizes), max_categorical_size))
+  for i, s in enumerate(categorical_sizes):
+    one_hot_features = np.zeros([n_feat, s])
+    one_hot_features[np.arange(n_feat), categorical_features[:, i]] = 1
+
+    one_hot_batch = np.zeros([n_batch, s])
+    one_hot_batch[np.arange(n_batch), categorical_features_batch[:, i]] = 1
+
+    features_change = np.matmul(
+        scale, one_hot_features
+    ) - one_hot_batch * np.sum(scale, axis=-1, keepdims=True)
+
+    diff_category_logit = np.log(
+        (1.0 - config.prob_same_category_without_perturbation) / (s - 1)
+    )
+    logits_i = np.zeros((n_batch, max_categorical_size)) + diff_category_logit
+    logits_i[:, s:] = -np.inf
+    logits_i[np.arange(n_batch), categorical_features_batch[:, i]] = np.log(
+        config.prob_same_category_without_perturbation
+    )
+    logits_i[:, :s] = logits_i[:, :s] + features_change
+    logits[:, i, :] = logits_i
+  return logits
+
+
 def _create_features_simple(
-    features, rewards, features_batch, rewards_batch, config, n_features
+    features,
+    rewards,
+    features_batch,
+    rewards_batch,
+    config,
+    n_features,
+    categorical_sizes,
+    max_categorical_size,
+    seed,
 ):
   """A version of `_create_features` that materializes large intermediates."""
-  features_diffs = features - features_batch[:, jnp.newaxis, :]
-  dists = jnp.sum(jnp.square(features_diffs), axis=-1)
+  # Only works with no parallel batch dimension.
+  continuous_features_diffs = (
+      features.continuous - features_batch.continuous[:, jnp.newaxis, :]
+  )
+  categorical_features_diffs = (
+      features.categorical != features_batch.categorical[:, jnp.newaxis, :]
+  )
+  features_diffs = vb.VectorizedOptimizerInput(
+      continuous=continuous_features_diffs,
+      categorical=categorical_features_diffs,
+  )
+  dists = jax.tree_util.tree_map(
+      lambda x: jnp.sum(jnp.square(x), axis=-1), features_diffs
+  )
   directions = rewards - rewards_batch[:, jnp.newaxis]
   scaled_directions = jnp.where(
       directions >= 0.0, config.gravity, -config.negative_gravity
   )
-  # Normalize the distance by the number of features.
-  force = jnp.exp(-config.visibility * dists / n_features * 10.0)
-  scaled_force = scaled_directions * force
   # Handle removed fireflies without updated rewards.
-  finite_ind = jnp.isfinite(rewards).astype(scaled_force.dtype)
+  finite_ind = jnp.isfinite(rewards).astype(directions.dtype)
   # Ignore fireflies that were removed from the pool.
-  scaled_force = scaled_force * finite_ind
+  scale = jax.tree_util.tree_map(
+      lambda x: finite_ind  # pylint: disable=g-long-lambda
+      * scaled_directions
+      * jnp.exp(-config.visibility * x / n_features * 10.0),
+      dists,
+  )
   # Separate forces to pull and push so to normalize them separately.
-  scaled_pulls = jnp.maximum(scaled_force, 0.0)
-  scaled_push = jnp.minimum(scaled_force, 0.0)
-  features_changes = jnp.sum(
-      features_diffs * (scaled_pulls + scaled_push)[..., jnp.newaxis], axis=1
+  new_continuous_features = features_batch.continuous + jnp.sum(
+      features_diffs.continuous * scale.continuous[..., jnp.newaxis], axis=1
+  )
+  categorical_features_logits = _create_logits_vector_simple(
+      features.categorical,
+      features_batch.categorical,
+      scale.categorical,
+      categorical_sizes,
+      max_categorical_size,
+      config,
+  )
+  new_categorical_features = tfd.Categorical(
+      logits=categorical_features_logits
+  ).sample(seed=seed)
+
+  return vb.VectorizedOptimizerInput(
+      new_continuous_features, new_categorical_features
   )
-  return features_batch + features_changes


 class VectorizedEagleStrategyContinuousTest(parameterized.TestCase):
@@ -70,48 +141,103 @@ def setUp(self):
     root = problem.search_space.select_root()
     root.add_float_param('x1', 0.0, 1.0)
     root.add_float_param('x2', 0.0, 1.0)
-    converter = converters.TrialToArrayConverter.from_study_config(problem)
+    root.add_categorical_param('c1', ['a', 'b'])
+    root.add_categorical_param('c2', ['a', 'b', 'c'])
+    self.converter = converters.TrialToModelInputConverter.from_problem(problem)
     self.eagle = eagle_strategy.VectorizedEagleStrategyFactory(
         eagle_config=self.config
-    )(converter=converter, suggestion_batch_size=2)
+    )(converter=self.converter, suggestion_batch_size=2)

-  def test_create_features(self):
-    features = jnp.array([[1, 2], [3, 4], [7, 7], [8, 8]])
+  def test_create_features_and_logits(self):
+    features_continuous = jnp.array(
+        [[[1.0, 2.0]], [[3.0, 4.0]], [[7.0, 7.0]], [[8.0, 8.0]]]
+    )
+    features_categorical = jnp.array([[1, 2], [0, 0], [0, 1], [1, 1]])[
+        :, jnp.newaxis, :
+    ]
     rewards = jnp.array([2, 3, 4, 1])
     seed = jax.random.PRNGKey(0)
-    features_batch = features[: self.eagle.batch_size]
+    features_continuous_batch = features_continuous[: self.eagle.batch_size]
+    features_categorical_batch = features_categorical[: self.eagle.batch_size]
+    features = vb.VectorizedOptimizerInput(
+        continuous=features_continuous, categorical=features_categorical
+    )
+    features_batch = vb.VectorizedOptimizerInput(
+        continuous=features_continuous_batch,
+        categorical=features_categorical_batch,
+    )
     rewards_batch = rewards[: self.eagle.batch_size]
-    self.assertEqual(
-        self.eagle._create_features(
-            features,
-            rewards,
-            features_batch,
-            rewards_batch,
-            seed=seed,
-        ).shape,
-        (2, 2),
+    created_features = self.eagle._create_features(
+        features,
+        rewards,
+        features_batch,
+        rewards_batch,
+        vb.VectorizedOptimizerInput(
+            jnp.zeros_like(features_continuous_batch),
+            jnp.zeros(features_categorical_batch.shape + (3,)),
+        ),
+        seed=seed,
     )
+    self.assertEqual(created_features.continuous.shape, (2, 1, 2))
+    self.assertEqual(created_features.categorical.shape, (2, 1, 2))
+    features_2d = vb.VectorizedOptimizerInput(
+        features.continuous[:, 0, :], features.categorical[:, 0, :]
+    )
+    features_batch_2d = vb.VectorizedOptimizerInput(
+        features_batch.continuous[:, 0, :], features_batch.categorical[:, 0, :]
+    )
     expected = _create_features_simple(
-        features,
+        features_2d,
         rewards,
-        features_batch,
+        features_batch_2d,
         rewards_batch,
         self.config.replace(
             mutate_normalization_type=(
                 eagle_strategy.MutateNormalizationType.UNNORMALIZED
             )
         ),
-        self.eagle.n_feature_dimensions,
+        (
+            self.eagle.n_feature_dimensions.continuous
+            + self.eagle.n_feature_dimensions.categorical
+        ),
+        self.eagle.categorical_sizes,
+        self.eagle.max_categorical_size,
+        seed,
     )
     actual = self.eagle._create_features(
         features,
         rewards,
         features_batch,
         rewards_batch,
+        vb.VectorizedOptimizerInput(
+            jnp.zeros_like(features_continuous_batch),
+            jnp.zeros(features_categorical_batch.shape + (3,)),
+        ),
         seed=seed,
     )
-    np.testing.assert_array_equal(expected, actual)
+    np.testing.assert_array_equal(
+        expected.continuous, actual.continuous[:, 0, :]
+    )
+    np.testing.assert_array_equal(
+        expected.categorical, actual.categorical[:, 0, :]
+    )
+
+    scale = np.random.normal(size=[self.eagle.batch_size, 4])
+    expected_logits = _create_logits_vector_simple(
+        features_2d.categorical,
+        features_batch_2d.categorical,
+        scale,
+        self.eagle.categorical_sizes,
+        self.eagle.max_categorical_size,
+        self.config,
+    )
+    actual_logits = self.eagle._create_categorical_feature_logits(
+        features.categorical, features_batch.categorical, scale
+    )
+    np.testing.assert_allclose(
+        expected_logits, actual_logits[:, 0, :, :], rtol=1e-6
+    )

   @parameterized.parameters(1, 5)
@@ -122,32 +248,48 @@ def test_create_random_perturbations(self, n_parallel):
         n_parallel=n_parallel,
         seed=seed,
     )
-    self.assertEqual(perturbations.shape, (2, n_parallel, 2))
+    self.assertEqual(perturbations.continuous.shape, (2, n_parallel, 2))
+    self.assertEqual(perturbations.categorical.shape, (2, n_parallel, 2, 3))

   def test_update_pool_features_and_rewards(self):
-    features = jnp.array(
-        [[[1, 2]], [[3, 4]], [[7, 7]], [[8, 8]]], dtype=jnp.float64
+    features = vb.VectorizedOptimizerInput(
+        continuous=jnp.array(
+            [[[1, 2]], [[3, 4]], [[7, 7]], [[8, 8]]], dtype=jnp.float64
+        ),
+        categorical=jnp.array(
+            [[[1, 2]], [[3, 0]], [[0, 1]], [[2, 1]]], dtype=jnp.int32
+        ),
     )
     rewards = jnp.array([2, 3, 4, 1], dtype=jnp.float64)
     perturbations = jnp.array([1, 1, 1, 1], dtype=jnp.float64)
-    batch_features = jnp.array([[[9, 9]], [[10, 10]]], dtype=jnp.float64)
+    batch_features = vb.VectorizedOptimizerInput(
+        continuous=jnp.array([[[9, 9]], [[10, 10]]], dtype=jnp.float64),
+        categorical=jnp.array([[[0, 0]], [[1, 1]]], dtype=jnp.int32),
+    )
     batch_rewards = jnp.array([5, 0.5], dtype=jnp.float64)
     new_features, new_rewards, new_perturbations = (
         self.eagle._update_pool_features_and_rewards(
            batch_features,
            batch_rewards,
-           features[: self.eagle.batch_size],
+           jax.tree_util.tree_map(
+               lambda f: f[: self.eagle.batch_size], features
+           ),
            rewards[: self.eagle.batch_size],
            perturbations[: self.eagle.batch_size],
        )
     )
     np.testing.assert_array_equal(
-        new_features,
+        new_features.continuous,
         np.array([[[9, 9]], [[3, 4]]], dtype=np.float64),
         err_msg='Features are not equal.',
     )
+    np.testing.assert_array_equal(
+        new_features.categorical,
+        np.array([[[0, 0]], [[3, 0]]], dtype=np.int32),
+        err_msg='Features are not equal.',
+    )

     np.testing.assert_array_equal(
         new_rewards,
@@ -164,8 +306,13 @@ def test_update_best_reward(self):
   def test_update_best_reward(self):
     # Test replacing the best reward.
-    features = jnp.array(
-        [[[1, 2]], [[3, 4]], [[7, 7]], [[8, 8]]], dtype=jnp.float64
+    features = vb.VectorizedOptimizerInput(
+        continuous=jnp.array(
+            [[[1, 2]], [[3, 4]], [[7, 7]], [[8, 8]]], dtype=jnp.float64
+        ),
+        categorical=jnp.array(
+            [[[1, 2]], [[3, 0]], [[0, 1]], [[2, 1]]], dtype=jnp.int32
+        ),
     )
     rewards = jnp.array([2, 3, 4, 1], dtype=jnp.float64)
     state = eagle_strategy.VectorizedEagleStrategyState(
@@ -175,7 +322,10 @@ def test_update_best_reward(self):
         best_reward=jnp.max(rewards),
         perturbations=jnp.ones_like(rewards),
     )
-    batch_features = jnp.array([[[9, 9]], [[10, 10]]], dtype=jnp.float64)
+    batch_features = vb.VectorizedOptimizerInput(
+        continuous=jnp.array([[[9, 9]], [[10, 10]]], dtype=jnp.float64),
+        categorical=jnp.array([[[0, 0]], [[1, 1]]], dtype=jnp.int32),
+    )
     batch_rewards = jnp.array([5, 0.5], dtype=jnp.float64)
     seed = jax.random.PRNGKey(0)
     new_state = self.eagle.update(seed, state, batch_features, batch_rewards)
@@ -200,7 +350,7 @@ def test_batch_size_and_pool_size(
     root = problem.search_space.root
     for i in range(100):
       root.add_float_param(f'x{i}', 0.0, 1.0)
-    converter = converters.TrialToArrayConverter.from_study_config(problem)
+    converter = converters.TrialToModelInputConverter.from_problem(problem)
     config = eagle_strategy.EagleStrategyConfig(max_pool_size=max_pool_size)
     eagle = eagle_strategy.VectorizedEagleStrategyFactory(eagle_config=config)(
         converter=converter, suggestion_batch_size=batch_size
@@ -210,7 +360,10 @@ def test_batch_size_and_pool_size(

   def test_trim_pool(self):
     pc = self.config.perturbation
-    features_batch = jnp.array([[[1, 2]], [[3, 4]]], dtype=jnp.float64)
+    features_batch = vb.VectorizedOptimizerInput(
+        continuous=jnp.array([[[1, 2]], [[3, 4]]], dtype=jnp.float64),
+        categorical=jnp.array([[[1, 2]], [[3, 0]]], dtype=jnp.int32),
+    )
     rewards_batch = jnp.array([2, 3], dtype=jnp.float64)
     perturbations = jnp.array([pc, 0], dtype=jnp.float64)
     seed = jax.random.PRNGKey(0)
@@ -223,13 +376,30 @@ def test_trim_pool(self):
     )

     np.testing.assert_array_almost_equal(
-        new_features[0],
-        features_batch[0],
-        err_msg='Features are not equal.',
+        new_features.continuous[0],
+        features_batch.continuous[0],
+        err_msg='Continuous features are not equal.',
+    )
+    np.testing.assert_array_almost_equal(
+        new_features.categorical[0],
+        features_batch.categorical[0],
+        err_msg='Categorical features are not equal.',
     )
     self.assertTrue(
-        np.all(np.not_equal(new_features[1], features_batch[1])),
-        msg='Features are not equal.',
+        np.all(
+            np.not_equal(
+                new_features.continuous[1], features_batch.continuous[1]
+            )
+        ),
+        msg='Continuous features are not equal.',
+    )
+    self.assertTrue(
+        np.all(
+            np.not_equal(
+                new_features.categorical[1], features_batch.categorical[1]
+            )
+        ),
+        msg='Categorical features are not equal.',
     )
np.testing.assert_array_equal( @@ -251,28 +421,41 @@ def test_create_strategy_from_factory(self): root.add_float_param('x2', 0.0, 1.0) root.add_float_param('x3', 0.0, 1.0) eagle_factory = eagle_strategy.VectorizedEagleStrategyFactory() - converter = converters.TrialToArrayConverter.from_study_config(problem) + converter = converters.TrialToModelInputConverter.from_problem(problem) eagle = eagle_factory(converter) - self.assertEqual(eagle.n_feature_dimensions, 3) + self.assertEqual(eagle.n_feature_dimensions.continuous, 3) + self.assertEqual(eagle.n_feature_dimensions.categorical, 0) def test_optimize_with_eagle(self): + + eagle_factory = eagle_strategy.VectorizedEagleStrategyFactory() + optimizer = vb.VectorizedOptimizerFactory(strategy_factory=eagle_factory)( + self.converter + ) + optimizer( + score_fn=lambda x: -jnp.sum(x.continuous.padded_array, 1), count=1 + ) + + def test_optimize_with_eagle_continuous_only(self): problem = vz.ProblemStatement() root = problem.search_space.select_root() root.add_float_param('x1', 0.0, 1.0) root.add_float_param('x2', 0.0, 1.0) root.add_float_param('x3', 0.0, 1.0) - converter = converters.TrialToArrayConverter.from_study_config(problem) + converter = converters.TrialToModelInputConverter.from_problem(problem) eagle_factory = eagle_strategy.VectorizedEagleStrategyFactory() optimizer = vb.VectorizedOptimizerFactory(strategy_factory=eagle_factory)( converter ) n_parallel = 5 results = optimizer( - score_fn=lambda x: -jnp.sum(x, axis=(1, 2)), + score_fn=lambda x: -jnp.sum(x.continuous.padded_array, axis=(1, 2)), count=1, n_parallel=n_parallel, ) - self.assertSequenceEqual(results.features.shape, (1, n_parallel, 3)) + self.assertSequenceEqual( + results.features.continuous.shape, (1, n_parallel, 3) + ) def test_optimize_with_eagle_padding(self): problem = vz.ProblemStatement() @@ -280,7 +463,7 @@ def test_optimize_with_eagle_padding(self): root.add_float_param('x1', 0.0, 1.0) root.add_float_param('x2', 0.0, 1.0) root.add_float_param('x3', 0.0, 1.0) - converter = converters.PaddedTrialToArrayConverter.from_study_config( + converter = converters.TrialToModelInputConverter.from_problem( problem, padding_schedule=padding.PaddingSchedule( num_trials=padding.PaddingType.POWERS_OF_2, @@ -293,95 +476,38 @@ def test_optimize_with_eagle_padding(self): ) n_parallel = 2 results = optimizer( - score_fn=lambda x: -jnp.sum(x, axis=(1, 2)), + score_fn=lambda x: -jnp.sum(x.continuous.padded_array, axis=(1, 2)), count=1, n_parallel=n_parallel, ) - self.assertSequenceEqual(results.features.shape, (1, n_parallel, 4)) - - -class EagleParamHandlerTest(parameterized.TestCase): - - def setUp(self): - super(EagleParamHandlerTest, self).setUp() - problem = vz.ProblemStatement() - root = problem.search_space.select_root() - root.add_categorical_param('c1', ['a', 'b']) - root.add_float_param('f1', 0.0, 5.0) - root.add_categorical_param('c2', ['a', 'b', 'c']) - root.add_discrete_param('d1', [2.0, 3.0, 5.0, 11.0]) - converter = converters.TrialToArrayConverter.from_study_config( - problem, max_discrete_indices=0, pad_oovs=True - ) - self.config = eagle_strategy.EagleStrategyConfig() - self.param_handler = eagle_param_handler.EagleParamHandler.build( - converter=converter, - categorical_perturbation_factor=self.config.categorical_perturbation_factor, - pure_categorical_perturbation_factor=self.config.pure_categorical_perturbation_factor, - ) - - def test_init(self): - self.assertEqual(self.param_handler.n_feature_dimensions, 9) - self.assertLen( - self.param_handler.perturbation_factors, 
-        self.param_handler.n_feature_dimensions,
+    self.assertSequenceEqual(
+        results.features.continuous.shape, (1, n_parallel, 4)
     )
-    self.assertEqual(self.param_handler.n_categorical, 2)

-  def test_categorical_params_mask(self):
-    expected_categorical_params_mask = np.array(
-        [[1, 1, 1, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 1, 1, 1, 1, 0]]
-    )
-    np.testing.assert_array_equal(
-        self.param_handler._categorical_params_mask,
-        expected_categorical_params_mask,
-    )
-
-  def test_categorical_mask(self):
-    expected_categorical_mask = np.array([1, 1, 1, 0, 1, 1, 1, 1, 0])
-    np.testing.assert_array_equal(
-        self.param_handler._categorical_mask, expected_categorical_mask
-    )
-
-  def test_tiebreak_mask(self):
-    eps = self.param_handler._epsilon
-    expected_tiebreak_mask = np.array(
-        [-eps * (i + 1) for i in range(9)], dtype=float
-    )
-    np.testing.assert_array_almost_equal(
-        self.param_handler._tiebreak_array, expected_tiebreak_mask
-    )
-
-  def test_categorical_oov_mask(self):
-    expected_oov_mask = np.array([1, 1, 0, 1, 1, 1, 1, 0, 1], dtype=float)
-    np.testing.assert_array_equal(
-        self.param_handler._oov_mask, expected_oov_mask
-    )
-
-  def test_perturbation_factors(self):
-    cp = self.config.categorical_perturbation_factor
-    expected_perturbation_factors = np.array(
-        [cp, cp, cp, 1, cp, cp, cp, cp, 1], dtype=float
-    )
-    np.testing.assert_array_equal(
-        self.param_handler.perturbation_factors, expected_perturbation_factors
-    )
+  def test_factory(self):
+    self.assertEqual(self.eagle.n_feature_dimensions.continuous, 2)
+    self.assertEqual(self.eagle.n_feature_dimensions.categorical, 2)
+    self.assertLen(self.eagle.categorical_sizes, 2)

  def test_sample_categorical_features(self):
-    # features shouldn't have values in oov_mask, and have the structure of:
-    # [c1,c1,c1,f1,c2,c2,c2,c2,d1]
-    features = jnp.array([
-        [2.0, 0.0, 0.0, 1.5, 0.0, 0.0, 0.1, 0.0, 9.0],
-        [3.0, 0.0, 0.0, 3.5, 5.0, 0.0, 0.0, 0.0, 8.0],
-    ])
-    expected_sampled_features = np.array([
-        [1.0, 0.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0, 9.0],
-        [1.0, 0.0, 0.0, 3.5, 1.0, 0.0, 0.0, 0.0, 8.0],
-    ])
-    sampled_features = self.param_handler.sample_categorical(
-        features, seed=jax.random.PRNGKey(0)
-    )
-    np.testing.assert_array_equal(sampled_features, expected_sampled_features)
+    # Sampled continuous features should lie in [0, 1]; sampled categorical
+    # features should be valid category indices for each parameter.
+    num_parallel = 3
+    sampled_features = self.eagle._sample_random_features(
+        20, n_parallel=num_parallel, seed=jax.random.PRNGKey(0)
+    )
+    self.assertTrue(
+        np.all(
+            (sampled_features.continuous >= 0.0)
+            & (sampled_features.continuous <= 1.0)
+        )
+    )
+    self.assertTrue(
+        np.all(
+            sampled_features.categorical
+            < np.array(self.eagle.categorical_sizes)
+        )
+    )

  def test_prior_trials(self):
    config = eagle_strategy.EagleStrategyConfig(
@@ -391,9 +517,16 @@
    root = problem.search_space.select_root()
    root.add_float_param('x1', 0.0, 1.0)
    root.add_float_param('x2', 0.0, 1.0)
-    converter = converters.TrialToArrayConverter.from_study_config(problem)
+    root.add_categorical_param('c1', ['a', 'b', 'c'])
+    converter = converters.TrialToModelInputConverter.from_problem(problem)

-    prior_features = jnp.array([[[1, -1]], [[2, 1]], [[3, 2]], [[4, 5]]])
+    prior_features_continuous = jnp.array(
+        [[[1, -1]], [[2, 1]], [[3, 2]], [[4, 5]]]
+    )
+    prior_features = vb.VectorizedOptimizerInput(
+        continuous=prior_features_continuous,
+        categorical=jnp.array([[0], [2], [1], [1]])[:, jnp.newaxis],
+    )
    prior_rewards = jnp.array([1, 2, 3, 4])
    eagle = eagle_strategy.VectorizedEagleStrategyFactory(
        eagle_config=config,
@@ -403,8 +536,10 @@
prior_features=prior_features, prior_rewards=prior_rewards, ) - np.testing.assert_array_equal( - init_state.features, jnp.flip(prior_features, axis=0) + jax.tree_util.tree_map( + lambda x, y: np.testing.assert_array_equal(y, jnp.flip(x, axis=0)), + init_state.features, + prior_features, ) @parameterized.parameters(2, 10) @@ -418,10 +553,14 @@ def test_prior_trials_with_too_few_or_many_trials(self, n_prior_trials): root = problem.search_space.select_root() root.add_float_param('x1', 0.0, 1.0) root.add_float_param('x2', 0.0, 1.0) - converter = converters.TrialToArrayConverter.from_study_config(problem) + root.add_categorical_param('c1', ['a', 'b', 'c']) + converter = converters.TrialToModelInputConverter.from_problem(problem) n_parallel = 3 - prior_features = np.random.randn(n_prior_trials, n_parallel, 2) + prior_features = vb.VectorizedOptimizerInput( + continuous=np.random.randn(n_prior_trials, n_parallel, 2), + categorical=np.random.randint(3, size=(n_prior_trials, n_parallel, 1)), + ) prior_rewards = np.random.randn(n_prior_trials) eagle = eagle_strategy.VectorizedEagleStrategyFactory( @@ -434,7 +573,10 @@ def test_prior_trials_with_too_few_or_many_trials(self, n_prior_trials): prior_rewards=prior_rewards, ) self.assertEqual( - init_state.features.shape, (config.pool_size, n_parallel, 2) + init_state.features.continuous.shape, (config.pool_size, n_parallel, 2) + ) + self.assertEqual( + init_state.features.categorical.shape, (config.pool_size, n_parallel, 1) ) diff --git a/vizier/_src/algorithms/optimizers/lbfgsb_optimizer.py b/vizier/_src/algorithms/optimizers/lbfgsb_optimizer.py index a9801a9db..e49c78632 100644 --- a/vizier/_src/algorithms/optimizers/lbfgsb_optimizer.py +++ b/vizier/_src/algorithms/optimizers/lbfgsb_optimizer.py @@ -25,6 +25,7 @@ from vizier import pyvizier as vz from vizier._src.algorithms.optimizers import vectorized_base from vizier._src.jax import stochastic_process_model as sp +from vizier._src.jax import types from vizier.jax import optimizers from vizier.pyvizier import converters @@ -42,7 +43,7 @@ class LBFGSBOptimizer: def optimize( self, - converter: converters.TrialToArrayConverter, + converter: converters.TrialToModelInputConverter, score_fn: Union[ vectorized_base.ParallelArrayScoreFunction, vectorized_base.ArrayScoreFunction, @@ -76,7 +77,9 @@ def optimize( optimize = optimizers.JaxoptScipyLbfgsB( optimizers.LbfgsBOptions(random_restarts=self.random_restarts) ) - num_features = sum(spec.num_dimensions for spec in converter.output_specs) + num_features = sum( + spec.num_dimensions for spec in converter.output_specs.continuous + ) feature_shape = [num_features] if self.num_parallel_candidates is not None: @@ -105,9 +108,15 @@ def wrapped_score_fn(x): ) new_rewards = np.asarray(score_fn(new_features[jnp.newaxis, ...]))[0] if self.num_parallel_candidates is None: - parameters = converter.to_parameters(new_features[jnp.newaxis, ...]) - else: - parameters = converter.to_parameters(new_features) + new_features = new_features[jnp.newaxis, ...] 
+ parameters = converter.to_parameters( + types.ModelInput( + continuous=types.PaddedArray.as_padded(new_features), + categorical=types.PaddedArray.as_padded( + jnp.zeros(new_features.shape[:-1] + (0,), dtype=types.INT_DTYPE) + ), + ) + ) trials = [] for i in range(len(parameters)): trial = vz.Trial(parameters=parameters[i]) diff --git a/vizier/_src/algorithms/optimizers/lbfgsb_optimizer_test.py b/vizier/_src/algorithms/optimizers/lbfgsb_optimizer_test.py index 3c6544d1e..1e25b7353 100644 --- a/vizier/_src/algorithms/optimizers/lbfgsb_optimizer_test.py +++ b/vizier/_src/algorithms/optimizers/lbfgsb_optimizer_test.py @@ -33,7 +33,7 @@ def test_optimize_candidates_len(self): problem.search_space.root.add_float_param('f1', 0.0, 10.0) problem.search_space.root.add_float_param('f2', 0.0, 10.0) problem.search_space.root.add_float_param('f3', 0.0, 10.0) - converter = converters.TrialToArrayConverter.from_study_config(problem) + converter = converters.TrialToModelInputConverter.from_problem(problem) score_fn = lambda x: jnp.sum(x, axis=-1) optimizer = lo.LBFGSBOptimizer(random_restarts=10) res = optimizer.optimize(converter=converter, score_fn=score_fn) @@ -43,7 +43,7 @@ def test_best_candidates_count_is_1(self): problem = vz.ProblemStatement() problem.search_space.root.add_float_param('f1', 0.0, 1.0) problem.search_space.root.add_float_param('f2', 0.0, 1.0) - converter = converters.TrialToArrayConverter.from_study_config(problem) + converter = converters.TrialToModelInputConverter.from_problem(problem) score_fn = lambda x: -jnp.sum(jnp.square(x - 0.52), axis=-1) optimizer = lo.LBFGSBOptimizer(random_restarts=10) candidates = optimizer.optimize(converter=converter, score_fn=score_fn) @@ -66,7 +66,7 @@ def test_batch_candidates(self): problem.search_space.root.add_float_param('f1', 0.0, 1.0) problem.search_space.root.add_float_param('f2', 0.0, 1.0) problem.search_space.root.add_float_param('f3', 0.0, 1.0) - converter = converters.TrialToArrayConverter.from_study_config(problem) + converter = converters.TrialToModelInputConverter.from_problem(problem) # Minimize sum over all features. 
score_fn = lambda x: -jnp.sum(jnp.square(x - 0.52), axis=[-1, -2]) optimizer = lo.LBFGSBOptimizer( diff --git a/vizier/_src/algorithms/optimizers/random_vectorized_optimizer.py b/vizier/_src/algorithms/optimizers/random_vectorized_optimizer.py index f75613ab1..905046416 100644 --- a/vizier/_src/algorithms/optimizers/random_vectorized_optimizer.py +++ b/vizier/_src/algorithms/optimizers/random_vectorized_optimizer.py @@ -19,31 +19,54 @@ from typing import Optional import jax +from jax import numpy as jnp +import numpy as np +from tensorflow_probability.substrates import jax as tfp from vizier._src.algorithms.optimizers import vectorized_base as vb from vizier._src.jax import types from vizier.pyvizier import converters +tfd = tfp.distributions -class RandomVectorizedStrategy(vb.VectorizedStrategy): + +class RandomVectorizedStrategy(vb.VectorizedStrategy[None]): """Random vectorized strategy.""" def __init__( self, - converter: converters.TrialToArrayConverter, + converter: converters.TrialToModelInputConverter, suggestion_batch_size: int, ): - self._converter = converter - self._suggestion_batch_size = suggestion_batch_size - self._n_features = sum( - [spec.num_dimensions for spec in self._converter.output_specs] + empty_features = converter.to_features([]) + n_feature_dimensions_with_padding = types.ContinuousAndCategorical( + empty_features.continuous.shape[-1], + empty_features.categorical.shape[-1], ) + categorical_sizes = [] + for spec in converter.output_specs.categorical: + categorical_sizes.append(spec.bounds[1]) + + self._suggestion_batch_size = suggestion_batch_size + self.n_feature_dimensions_with_padding = n_feature_dimensions_with_padding + self.n_feature_dimensions = n_feature_dimensions_with_padding + self.dtype = types.ContinuousAndCategorical(jnp.float64, types.INT_DTYPE) + + self._categorical_logits = None + if categorical_sizes: + categorical_logits = np.zeros( + [len(categorical_sizes), max(categorical_sizes)] + ) + for i, s in enumerate(categorical_sizes): + categorical_logits[i, s:] = -np.inf + self._categorical_logits = categorical_logits + def init_state( self, seed: jax.random.KeyArray, n_parallel: int = 1, *, - prior_features: Optional[types.Array] = None, + prior_features: Optional[vb.VectorizedOptimizerInput] = None, prior_rewards: Optional[types.Array] = None, ) -> None: del seed @@ -54,12 +77,26 @@ def suggest( seed: jax.random.KeyArray, state: None, n_parallel: int = 1, - ) -> jax.Array: + ) -> vb.VectorizedOptimizerInput: del state - return jax.random.uniform( - seed, - shape=(self._suggestion_batch_size, n_parallel, self._n_features), + cont_seed, cat_seed = jax.random.split(seed) + cont_data = jax.random.uniform( + cont_seed, + shape=( + self._suggestion_batch_size, + n_parallel, + self.n_feature_dimensions_with_padding.continuous, + ), ) + if self._categorical_logits is None: + cat_data = jnp.zeros( + [self._suggestion_batch_size, n_parallel, 0], dtype=jnp.int32 + ) + else: + cat_data = tfd.Categorical(logits=self._categorical_logits).sample( + (self._suggestion_batch_size, n_parallel), seed=cat_seed + ) + return vb.VectorizedOptimizerInput(cont_data, cat_data) def suggestion_batch_size(self) -> int: return self._suggestion_batch_size @@ -68,14 +105,14 @@ def update( self, seed: jax.random.KeyArray, state: None, - batch_features: types.Array, + batch_features: vb.VectorizedOptimizerInput, batch_rewards: types.Array, ) -> None: return def random_strategy_factory( - converter: converters.TrialToArrayConverter, + converter: 
converters.TrialToModelInputConverter, suggestion_batch_size: int, ) -> vb.VectorizedStrategy: """Creates a new vectorized strategy based on the Protocol.""" @@ -86,7 +123,7 @@ def random_strategy_factory( def create_random_optimizer( - converter: converters.TrialToArrayConverter, + converter: converters.TrialToModelInputConverter, max_evaluations: int, suggestion_batch_size: int, ) -> vb.VectorizedOptimizer: diff --git a/vizier/_src/algorithms/optimizers/random_vectorized_optimizer_test.py b/vizier/_src/algorithms/optimizers/random_vectorized_optimizer_test.py index a3b49e267..2313a3a72 100644 --- a/vizier/_src/algorithms/optimizers/random_vectorized_optimizer_test.py +++ b/vizier/_src/algorithms/optimizers/random_vectorized_optimizer_test.py @@ -31,15 +31,15 @@ def test_random_optimizer(self): problem = vz.ProblemStatement() problem.search_space.root.add_float_param('f1', 0.0, 10.0) problem.search_space.root.add_float_param('f2', 0.0, 10.0) - converter = converters.TrialToArrayConverter.from_study_config(problem) - score_fn = lambda x: np.sum(x, axis=(1, 2)) + converter = converters.TrialToModelInputConverter.from_problem(problem) + score_fn = lambda x: np.sum(x.continuous.padded_array, axis=(-1, -2)) n_parallel = 3 random_optimizer = rvo.create_random_optimizer( converter=converter, max_evaluations=100, suggestion_batch_size=10 ) res = random_optimizer(score_fn=score_fn, count=5, n_parallel=n_parallel) self.assertLen(res.rewards, 5) - self.assertSequenceEqual(res.features.shape, (5, n_parallel, 2)) + self.assertSequenceEqual(res.features.continuous.shape, (5, n_parallel, 2)) def test_random_optimizer_factory(self): random_optimizer_factory = rvo.create_random_optimizer_factory( diff --git a/vizier/_src/algorithms/optimizers/vectorized_base.py b/vizier/_src/algorithms/optimizers/vectorized_base.py index 9fd9bf27b..20b3a6b53 100644 --- a/vizier/_src/algorithms/optimizers/vectorized_base.py +++ b/vizier/_src/algorithms/optimizers/vectorized_base.py @@ -29,21 +29,73 @@ from vizier import pyvizier as vz from vizier._src.jax import types from vizier.pyvizier import converters -from vizier.pyvizier.converters import feature_mapper from vizier.utils import json_utils - _S = TypeVar('_S') # A container of optimizer state that works as a Pytree. 
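# --- Editor's sketch (illustrative addition, not part of the diff) ---
# The helpers defined next construct `types.PaddedArray` values by hand; as a
# refresher, the public construction path pads data out to a target shape and
# records which entries are real. Shapes are toy values, and this assumes
# `from_array`/`unpad` behave as they are used elsewhere in this change.
import numpy as np
from vizier._src.jax import types

padded = types.PaddedArray.from_array(
    np.ones((3, 2)), (4, 4), fill_value=0.0  # pad 3x2 data out to 4x4
)
assert padded.padded_array.shape == (4, 4)  # fill_value where padded
assert padded.unpad().shape == (3, 2)       # original data recovered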
-ArrayConverter = Union[
-    converters.TrialToArrayConverter, converters.PaddedTrialToArrayConverter
-]
+# Each component has shape (batch_size, n_parallel, n_padded_features)
+VectorizedOptimizerInput = types.ContinuousAndCategorical[types.Array]
+
+
+def _optimizer_to_model_input_single_array(
+    x: types.Array, n_features: jax.Array
+) -> types.PaddedArray:
+  mask = jnp.ones_like(x, dtype=bool)
+  mask = jnp.logical_and(mask, jnp.arange(x.shape[-1]) < n_features)
+  return types.PaddedArray(
+      x,
+      fill_value=jnp.zeros([], dtype=x.dtype),
+      _original_shape=jnp.concatenate(
+          [jnp.array(x.shape[:-1]), jnp.array([n_features])], axis=0
+      ),
+      _mask=mask,
+      _nopadding_done=False,  # Trailing feature dimensions may be padded.
+  )
+
+
+def _optimizer_to_model_input(
+    x: VectorizedOptimizerInput,
+    n_features: types.ContinuousAndCategorical,
+    squeeze_middle_dim: bool = False,
+) -> types.ModelInput:
+  if squeeze_middle_dim:
+    x_cont = jnp.squeeze(x.continuous, axis=1)
+    x_cat = jnp.squeeze(x.categorical, axis=1)
+  else:
+    x_cont = x.continuous
+    x_cat = x.categorical
+  return types.ModelInput(
+      continuous=_optimizer_to_model_input_single_array(
+          x_cont, n_features.continuous
+      ),
+      categorical=_optimizer_to_model_input_single_array(
+          x_cat, n_features.categorical
+      ),
+  )
+
+
+# pylint: disable=protected-access
+def _reshape_to_parallel_batches(
+    x: types.PaddedArray, parallel_dim: int
+) -> tuple[jax.Array, jax.Array]:
+  """Reshapes a PaddedArray into parallel batches plus a validity mask."""
+
+  new_batch_dim = x.shape[0] // parallel_dim
+  new_padded_array = jnp.reshape(
+      x.padded_array[: new_batch_dim * parallel_dim],
+      (new_batch_dim, parallel_dim, x.shape[-1]),
+  )
+
+  valid_batch_mask = (
+      jnp.arange(new_batch_dim) < x._original_shape[0] // parallel_dim
+  )
+  return new_padded_array, valid_batch_mask


 class VectorizedStrategyResults(eqx.Module):
   """Container for a vectorized strategy result."""

-  features: types.Array  # (batch_size, n_parallel, n_features)
+  features: VectorizedOptimizerInput  # Each component: (batch_size, n_parallel, n_padded_features)
   rewards: types.Array  # (batch_size,)
   aux: dict[str, jax.Array] = eqx.field(default_factory=dict)

@@ -62,7 +114,7 @@
       seed: jax.random.KeyArray,
       n_parallel: int = 1,
       *,
-      prior_features: Optional[types.Array] = None,
+      prior_features: Optional[VectorizedOptimizerInput] = None,
       prior_rewards: Optional[types.Array] = None,
   ) -> _S:
     """Initialize the state.

     Arguments:
@@ -85,7 +137,7 @@
       seed: jax.random.KeyArray,
       state: _S,
       n_parallel: int = 1,
-  ) -> jax.Array:
+  ) -> VectorizedOptimizerInput:
     """Generate new suggestions.

     Arguments:
@@ -109,7 +161,7 @@
       self,
       seed: jax.random.KeyArray,
       state: _S,
-      batch_features: types.Array,
+      batch_features: VectorizedOptimizerInput,
       batch_rewards: types.Array,
   ) -> _S:
     """Update the strategy state with the results of the last suggestions.
@@ -131,7 +183,7 @@ class VectorizedStrategyFactory(Protocol):

   def __call__(
       self,
-      converter: ArrayConverter,
+      converter: converters.TrialToModelInputConverter,
       *,
       suggestion_batch_size: int,
   ) -> VectorizedStrategy:
@@ -151,7 +203,7 @@ class ArrayScoreFunction(Protocol):
   its own separate score).
   """

-  def __call__(self, batched_array_trials: types.Array) -> types.Array:
+  def __call__(self, batched_array_trials: types.ModelInput) -> types.Array:
     """Evaluates the array of batched trials.

     Arguments:
@@ -171,7 +223,7 @@ class ParallelArrayScoreFunction(Protocol):
   (e.g. qUCB).
""" - def __call__(self, parallel_array_trials: types.Array) -> types.Array: + def __call__(self, parallel_array_trials: types.ModelInput) -> types.Array: """Evaluates the array of batched trials. Arguments: @@ -184,7 +236,7 @@ def __call__(self, parallel_array_trials: types.Array) -> types.Array: @struct.dataclass -class VectorizedOptimizer: +class VectorizedOptimizer(Generic[_S]): """Vectorized strategy optimizer. The optimizer is stateless and will create a new vectorized strategy at the @@ -208,14 +260,23 @@ class VectorizedOptimizer: `n_feature_dimensions_with_padding`). suggestion_batch_size: Number of suggested points returned at each call. max_evaluations: The maximum number of objective function evaluations. + dtype: Dtype of input data. use_fori: Whether to use JAX's fori_loop in the suggest-evalute-update loop. """ - strategy: VectorizedStrategy - n_feature_dimensions: int - n_feature_dimensions_with_padding: int = struct.field(pytree_node=False) + strategy: VectorizedStrategy[_S] + n_feature_dimensions: types.ContinuousAndCategorical[jax.Array] + n_feature_dimensions_with_padding: types.ContinuousAndCategorical[int] = ( + struct.field(pytree_node=False) + ) suggestion_batch_size: int = struct.field(pytree_node=False, default=25) max_evaluations: int = struct.field(pytree_node=False, default=75_000) + dtype: types.ContinuousAndCategorical[jnp.dtype] = struct.field( + pytree_node=False, + default=types.ContinuousAndCategorical[jnp.dtype]( + jnp.float64, types.INT_DTYPE + ), + ) use_fori: bool = struct.field(pytree_node=False, default=True) # TODO: Remove score_fn argument. @@ -226,7 +287,7 @@ def __call__( *, score_with_aux_fn: Optional[Callable] = None, count: int = 1, - prior_features: Optional[types.Array] = None, + prior_features: Optional[types.ModelInput] = None, n_parallel: Optional[int] = None, seed: Optional[int] = None, ) -> VectorizedStrategyResults: @@ -276,29 +337,38 @@ def __call__( """ seed = jax.random.PRNGKey(0) if seed is None else seed - dimension_is_missing = ( - jnp.arange(self.n_feature_dimensions_with_padding) - > self.n_feature_dimensions - ) - if n_parallel is None: # Squeeze out the singleton dimension of `features` before passing to a # non-parallel acquisition function to avoid batch shape collisions. - eval_score_fn = lambda x: score_fn(x[:, 0, :]) + eval_score_fn = lambda x: score_fn( # pylint: disable=g-long-lambda + _optimizer_to_model_input( + x, self.n_feature_dimensions, squeeze_middle_dim=True + ) + ) else: - eval_score_fn = score_fn + eval_score_fn = lambda x: score_fn( # pylint: disable=g-long-lambda + _optimizer_to_model_input(x, self.n_feature_dimensions) + ) # TODO: We should pass RNGKey to score_fn. 
    prior_rewards = None
    parallel_dim = n_parallel or 1
    if prior_features is not None:
-      num_prior_obs = prior_features.shape[0]
-      num_prior_batches = num_prior_obs // parallel_dim
-      prior_features = jnp.reshape(
-          prior_features[: num_prior_batches * parallel_dim],
-          [num_prior_batches, parallel_dim, -1],
+      continuous_prior, continuous_mask = _reshape_to_parallel_batches(
+          prior_features.continuous, parallel_dim
+      )
+      categorical_prior, categorical_mask = _reshape_to_parallel_batches(
+          prior_features.categorical, parallel_dim
+      )
+      prior_features = VectorizedOptimizerInput(
+          continuous=continuous_prior, categorical=categorical_prior
      )
      prior_rewards = eval_score_fn(prior_features)
+      prior_rewards = jnp.where(
+          jnp.logical_and(continuous_mask, categorical_mask),
+          prior_rewards,
+          -jnp.inf * jnp.ones_like(prior_rewards),
+      )

    def _optimization_one_step(_, args):
      state, best_results, seed = args
@@ -306,11 +376,6 @@ def _optimization_one_step(_, args):
      new_features = self.strategy.suggest(
          suggest_seed, state=state, n_parallel=parallel_dim
      )
-      # Ensure masking out padded dimensions in new features.
-      new_features = jnp.where(
-          dimension_is_missing, jnp.zeros_like(new_features), new_features
-      )
-      # We assume `score_fn` is aware of padded dimensions.
      new_rewards = eval_score_fn(new_features)
      new_state = self.strategy.update(
          update_seed, state, new_features, new_rewards
      )
@@ -323,8 +388,23 @@
    init_seed, loop_seed = jax.random.split(seed)
    init_best_results = VectorizedStrategyResults(
        rewards=-jnp.inf * jnp.ones([count]),
-        features=jnp.zeros(
-            [count, parallel_dim, self.n_feature_dimensions_with_padding]
+        features=VectorizedOptimizerInput(
+            continuous=jnp.zeros(
+                [
+                    count,
+                    parallel_dim,
+                    self.n_feature_dimensions_with_padding.continuous,
+                ],
+                dtype=self.dtype.continuous,
+            ),
+            categorical=jnp.zeros(
+                [
+                    count,
+                    parallel_dim,
+                    self.n_feature_dimensions_with_padding.categorical,
+                ],
+                dtype=self.dtype.categorical,
+            ),
        ),
    )
    init_args = (
@@ -354,9 +434,19 @@

    if score_with_aux_fn:
      if n_parallel is None:
-        aux = score_with_aux_fn(best_results.features[:, 0, :])[1]
+        aux = score_with_aux_fn(
+            _optimizer_to_model_input(
+                best_results.features,
+                self.n_feature_dimensions,
+                squeeze_middle_dim=True,
+            )
+        )[1]
      else:
-        aux = score_with_aux_fn(best_results.features)[1]
+        aux = score_with_aux_fn(
+            _optimizer_to_model_input(
+                best_results.features, self.n_feature_dimensions
+            )
+        )[1]

    return VectorizedStrategyResults(
        best_results.features,
@@ -370,7 +460,7 @@ def _update_best_results(
      self,
      best_results: VectorizedStrategyResults,
      count: int,
-      batch_features: jax.Array,
+      batch_features: VectorizedOptimizerInput,
      batch_rewards: jax.Array,
  ) -> VectorizedStrategyResults:
    """Update the best results the optimizer has seen thus far.
@@ -392,44 +482,63 @@ def _update_best_results( trials: """ all_rewards = jnp.concatenate([batch_rewards, best_results.rewards], axis=0) - all_features = jnp.concatenate( - [batch_features, best_results.features], axis=0 + all_features = VectorizedOptimizerInput( + continuous=jnp.concatenate( + [batch_features.continuous, best_results.features.continuous], + axis=0, + ), + categorical=jnp.concatenate( + [batch_features.categorical, best_results.features.categorical], + axis=0, + ), ) top_indices = jnp.argpartition(-all_rewards, count - 1)[:count] return VectorizedStrategyResults( rewards=all_rewards[top_indices], - features=all_features[top_indices], + features=VectorizedOptimizerInput( + continuous=all_features.continuous[top_indices], + categorical=all_features.categorical[top_indices], + ), ) # TODO: Should return suggestions not trials. def best_candidates_to_trials( best_results: VectorizedStrategyResults, - converter: ArrayConverter, + converter: converters.TrialToModelInputConverter, ) -> list[vz.Trial]: """Returns the best candidate trials in the original search space.""" + best_features = best_results.features trials = [] sorted_ind = jnp.argsort(-best_results.rewards) - features = best_results.features - if isinstance(features, types.ContinuousAndCategorical): - features = feature_mapper.ContinuousCategoricalFeatureMapper( - converter - ).unmap(features) for i in range(len(best_results.rewards)): # Create trials and convert the strategy features back to parameters. ind = sorted_ind[i] - suggested_features = features[ind] + suggested_features = VectorizedOptimizerInput( + best_features.continuous[ind], best_features.categorical[ind] + ) reward = best_results.rewards[ind] # Loop over the number of candidates per batch (which will be one, unless a # parallel acquisition function is used). - for j in range(suggested_features.shape[0]): + for j in range(suggested_features.continuous.shape[0]): + features = VectorizedOptimizerInput( + continuous=jnp.expand_dims(suggested_features.continuous[j], axis=0), + categorical=jnp.expand_dims( + suggested_features.categorical[j], axis=0 + ), + ) trial = vz.Trial( parameters=converter.to_parameters( - jnp.expand_dims(suggested_features[j], axis=0) + _optimizer_to_model_input( + features, + n_features=types.ContinuousAndCategorical( + len(converter.output_specs.continuous), + len(converter.output_specs.categorical), + ), + ) )[0] ) - metadata = trial.metadata.ns('devinfo') metadata['acquisition_optimization'] = json.dumps( {'acquisition': best_results.rewards[ind]} @@ -447,18 +556,12 @@ def best_candidates_to_trials( # TODO: This function should return jax types. def trials_to_sorted_array( prior_trials: list[vz.Trial], - converter: ArrayConverter, -) -> Optional[types.Array]: + converter: converters.TrialToModelInputConverter, +) -> Optional[types.ModelInput]: """Sorts trials by the order they were created and converts to array.""" if prior_trials: prior_trials = sorted(prior_trials, key=lambda x: x.creation_time) prior_features = converter.to_features(prior_trials) - # TODO: Update this code to work more cleanly with - # PaddedArrays. - if isinstance(converter, converters.PaddedTrialToArrayConverter): - # We need to mask out the `NaN` padded trials with zeroes. - prior_features = np.array(prior_features.padded_array) - prior_features[len(prior_trials) :, ...] 
= 0.0
    else:
      prior_features = None
    return prior_features
@@ -475,16 +578,29 @@ class VectorizedOptimizerFactory:

  def __call__(
      self,
-      converter: ArrayConverter,
+      converter: converters.TrialToModelInputConverter,
  ) -> VectorizedOptimizer:
    """Generates a new VectorizedOptimizer object."""
    strategy = self.strategy_factory(
        converter, suggestion_batch_size=self.suggestion_batch_size
    )
-    n_feature_dimensions = sum(
-        spec.num_dimensions for spec in converter.output_specs
+    n_feature_dimensions = getattr(
+        strategy,
+        'n_feature_dimensions',
+        types.ContinuousAndCategorical(
+            len(converter.output_specs.continuous),
+            len(converter.output_specs.categorical),
+        ),
+    )
+    empty_features = converter.to_features([])
+    n_feature_dimensions_with_padding = getattr(
+        strategy,
+        'n_feature_dimensions_with_padding',
+        types.ContinuousAndCategorical[int](
+            empty_features.continuous.shape[-1],
+            empty_features.categorical.shape[-1],
+        ),
    )
-    n_feature_dimensions_with_padding = converter.to_features([]).shape[-1]
    return VectorizedOptimizer(
        strategy=strategy,
        n_feature_dimensions=n_feature_dimensions,
@@ -492,4 +608,5 @@ def __call__(
        suggestion_batch_size=self.suggestion_batch_size,
        max_evaluations=self.max_evaluations,
        use_fori=self.use_fori,
+        dtype=converter._impl.dtype,
    )
diff --git a/vizier/_src/algorithms/optimizers/vectorized_base_test.py b/vizier/_src/algorithms/optimizers/vectorized_base_test.py
index d46d533fc..efc872613 100644
--- a/vizier/_src/algorithms/optimizers/vectorized_base_test.py
+++ b/vizier/_src/algorithms/optimizers/vectorized_base_test.py
@@ -29,6 +29,8 @@
 from absl.testing import absltest
 from absl.testing import parameterized

+# pylint: disable=g-long-lambda
+

 @chex.dataclass(frozen=True)
 class FakeIncrementVectorizedStrategyState:
@@ -37,7 +39,9 @@ class FakeIncrementVectorizedStrategyState:
   iterations: int


-class FakeIncrementVectorizedStrategy(vb.VectorizedStrategy):
+class FakeIncrementVectorizedStrategy(
+    vb.VectorizedStrategy[FakeIncrementVectorizedStrategyState]
+):
   """Fake vectorized strategy with incrementing suggestions."""

   def __init__(self, *args, **kwargs):
@@ -48,7 +52,7 @@ def suggest(
       seed: jax.random.KeyArray,
       state: FakeIncrementVectorizedStrategyState,
       n_parallel: int = 1,
-  ) -> jax.Array:
+  ) -> vb.VectorizedOptimizerInput:
     # The following structure allows testing the top-K results.
    i = state.iterations
    suggestions = (
@@ -61,7 +65,10 @@
        ])
        / 10
    )
-    return jnp.repeat(suggestions[:, jnp.newaxis, :], n_parallel, axis=1)
+    return vb.VectorizedOptimizerInput(
+        jnp.repeat(suggestions[:, jnp.newaxis, :], n_parallel, axis=1),
+        jnp.zeros([5, n_parallel, 0], dtype=types.INT_DTYPE),
+    )

  @property
  def suggestion_batch_size(self) -> int:
@@ -71,7 +78,7 @@
      self,
      seed: jax.random.KeyArray,
      state: FakeIncrementVectorizedStrategyState,
-      batch_features: types.Array,
+      batch_features: vb.VectorizedOptimizerInput,
      batch_rewards: types.Array,
  ) -> FakeIncrementVectorizedStrategyState:
    return FakeIncrementVectorizedStrategyState(iterations=state.iterations + 5)

@@ -81,7 +88,7 @@
      seed: jax.random.KeyArray,
      n_parallel: int = 1,
      *,
-      prior_features: Optional[types.Array] = None,
+      prior_features: Optional[vb.VectorizedOptimizerInput] = None,
      prior_rewards: Optional[types.Array] = None,
  ) -> FakeIncrementVectorizedStrategyState:
    del seed
@@ -90,7 +97,7 @@

 # pylint: disable=unused-argument
 def fake_increment_strategy_factory(
-    converter: converters.TrialToArrayConverter,
+    converter: converters.TrialToModelInputConverter,
    suggestion_batch_size: int,
) -> vb.VectorizedStrategy:
  return FakeIncrementVectorizedStrategy()
@@ -100,11 +107,13 @@ class FakePriorTrialsStrategyState:
-  """State for FakeIncrementVectorizedStrategy."""
+  """State for FakePriorTrialsVectorizedStrategy."""

-  features: types.Array
+  features: vb.VectorizedOptimizerInput
  rewards: types.Array


-class FakePriorTrialsVectorizedStrategy(vb.VectorizedStrategy):
+class FakePriorTrialsVectorizedStrategy(
+    vb.VectorizedStrategy[FakePriorTrialsStrategyState]
+):
  """Fake vectorized strategy to test prior trials."""

  def init_state(
@@ -112,7 +121,7 @@
      seed: jax.random.KeyArray,
      n_parallel: int = 1,
      *,
-      prior_features: Optional[types.Array] = None,
+      prior_features: Optional[vb.VectorizedOptimizerInput] = None,
      prior_rewards: Optional[types.Array] = None,
  ):
    if prior_rewards is not None and len(prior_rewards.shape) != 1:
@@ -126,8 +135,13 @@
      seed: jax.random.KeyArray,
      state: FakePriorTrialsStrategyState,
      n_parallel: int = 1,
-  ) -> jax.Array:
-    return state.features[jnp.argmax(state.rewards, axis=-1)][jnp.newaxis, :]
+  ) -> vb.VectorizedOptimizerInput:
+    return vb.VectorizedOptimizerInput(
+        continuous=state.features.continuous[
+            jnp.argmax(state.rewards, axis=-1)
+        ][jnp.newaxis, :],
+        categorical=jnp.zeros([1, n_parallel, 0], dtype=types.INT_DTYPE),
+    )

  @property
  def suggestion_batch_size(self) -> int:
@@ -137,7 +151,7 @@
      self,
      seed: jax.random.KeyArray,
      state: FakePriorTrialsStrategyState,
-      batch_features: types.Array,
+      batch_features: vb.VectorizedOptimizerInput,
      batch_rewards: types.Array,
  ) -> FakePriorTrialsStrategyState:
    return state
@@ -145,7 +159,7 @@ def update(

 # pylint: disable=unused-argument
 def fake_prior_trials_strategy_factory(
-    converter: converters.TrialToArrayConverter,
+    converter: converters.TrialToModelInputConverter,
    suggestion_batch_size: int,
) -> vb.VectorizedStrategy:
  return FakePriorTrialsVectorizedStrategy()
@@ -158,8 +172,8 @@ def test_optimize_candidates_len(self, count):
    problem = vz.ProblemStatement()
    problem.search_space.root.add_float_param('f1', 0.0, 10.0)
    problem.search_space.root.add_float_param('f2', 0.0, 10.0)
-    converter = converters.TrialToArrayConverter.from_study_config(problem)
-    score_fn = lambda x: jnp.sum(x, axis=-1)
+    converter = converters.TrialToModelInputConverter.from_problem(problem)
+    score_fn
= lambda x: jnp.sum(x.continuous.padded_array, axis=-1) optimizer = vb.VectorizedOptimizerFactory( strategy_factory=fake_increment_strategy_factory, max_evaluations=100, @@ -176,8 +190,8 @@ def test_optimize_parallel_candidates_len(self, count, n_parallel): problem = vz.ProblemStatement() problem.search_space.root.add_float_param('f1', 0.0, 10.0) problem.search_space.root.add_float_param('f2', 0.0, 10.0) - converter = converters.TrialToArrayConverter.from_study_config(problem) - score_fn = lambda x: jnp.sum(x, axis=(-1, -2)) + converter = converters.TrialToModelInputConverter.from_problem(problem) + score_fn = lambda x: jnp.sum(x.continuous.padded_array, axis=(-1, -2)) optimizer = vb.VectorizedOptimizerFactory( strategy_factory=fake_increment_strategy_factory, max_evaluations=100, @@ -191,8 +205,10 @@ def test_best_candidates_count_is_1(self, use_fori): problem = vz.ProblemStatement() problem.search_space.root.add_float_param('f1', 0.0, 1.0) problem.search_space.root.add_float_param('f2', 0.0, 1.0) - converter = converters.TrialToArrayConverter.from_study_config(problem) - score_fn = lambda x: -jnp.max(jnp.square(x - 0.52), axis=-1) + converter = converters.TrialToModelInputConverter.from_problem(problem) + score_fn = lambda x: -jnp.max( + jnp.square(x.continuous.padded_array - 0.52), axis=-1 + ) strategy_factory = FakeIncrementVectorizedStrategy optimizer = vb.VectorizedOptimizerFactory( strategy_factory=strategy_factory, @@ -217,8 +233,10 @@ def test_best_candidates_count_is_3(self, use_fori): problem = vz.ProblemStatement() problem.search_space.root.add_float_param('f1', 0.0, 1.0) problem.search_space.root.add_float_param('f2', 0.0, 1.0) - converter = converters.TrialToArrayConverter.from_study_config(problem) - score_fn = lambda x: -jnp.max(jnp.square(x - 0.52), axis=-1) + converter = converters.TrialToModelInputConverter.from_problem(problem) + score_fn = lambda x: -jnp.max( + jnp.square(x.continuous.padded_array - 0.52), axis=-1 + ) optimizer = vb.VectorizedOptimizerFactory( strategy_factory=fake_increment_strategy_factory, suggestion_batch_size=5, @@ -254,7 +272,7 @@ def test_best_candidates_count_is_3(self, use_fori): def test_vectorized_optimizer_factory(self): problem = vz.ProblemStatement() problem.search_space.root.add_float_param('f1', 0.0, 1.0) - converter = converters.TrialToArrayConverter.from_study_config(problem) + converter = converters.TrialToModelInputConverter.from_problem(problem) optimizer_factory = vb.VectorizedOptimizerFactory( strategy_factory=fake_increment_strategy_factory, suggestion_batch_size=5, @@ -284,7 +302,7 @@ def test_prior_trials(self, use_fori): root = study_config.search_space.root root.add_float_param('x1', 0.0, 10.0) root.add_float_param('x2', 0.0, 10.0) - converter = converters.TrialToArrayConverter.from_study_config(study_config) + converter = converters.TrialToModelInputConverter.from_problem(study_config) optimizer = optimizer_factory(converter) trial1 = vz.Trial(parameters={'x1': 1, 'x2': 1}) @@ -299,7 +317,9 @@ def test_prior_trials(self, use_fori): [trial1, trial2, trial1], converter=converter ) best_trial_array = optimizer( - lambda x: -jnp.max(jnp.square(x - 0.52), axis=-1), + lambda x: -jnp.max( + jnp.square(x.continuous.padded_array - 0.52), axis=-1 + ), count=1, prior_features=prior_features, ) @@ -310,7 +330,9 @@ def test_prior_trials(self, use_fori): self.assertEqual(best_trial[0].parameters['x2'].value, 2) best_trial_array = optimizer( - lambda x: -jnp.max(jnp.square(x - 0.52), axis=-1), + lambda x: -jnp.max( + 
jnp.square(x.continuous.padded_array - 0.52), axis=-1 + ), count=1, prior_features=vb.trials_to_sorted_array([trial1], converter=converter), ) @@ -339,15 +361,24 @@ def test_prior_trials_parallel(self, use_fori): root.add_float_param('x1', 0.0, 10.0) root.add_float_param('x2', 0.0, 10.0) root.add_float_param('x3', 0.0, 10.0) - converter = converters.TrialToArrayConverter.from_study_config(study_config) + converter = converters.TrialToModelInputConverter.from_problem(study_config) optimizer = optimizer_factory(converter) - prior_features = jax.random.uniform(jax.random.PRNGKey(0), (14, 3)) + prior_features = types.ModelInput( + continuous=types.PaddedArray.as_padded( + jax.random.uniform(jax.random.PRNGKey(0), (14, 3)) + ), + categorical=types.PaddedArray.as_padded( + jnp.zeros([14, 0], dtype=types.INT_DTYPE) + ), + ) suggestions = optimizer( - lambda x: -jnp.max(jnp.square(x - 0.52), axis=(-1, -2)), + lambda x: -jnp.max( + jnp.square(x.continuous.padded_array - 0.52), axis=(-1, -2) + ), prior_features=prior_features, n_parallel=2, ) - self.assertSequenceEqual(suggestions.features.shape, (1, 2, 3)) + self.assertSequenceEqual(suggestions.features.continuous.shape, (1, 2, 3)) self.assertSequenceEqual(suggestions.rewards.shape, (1,)) diff --git a/vizier/_src/algorithms/testing/comparator_runner.py b/vizier/_src/algorithms/testing/comparator_runner.py index 5da46d2c6..70a15f145 100644 --- a/vizier/_src/algorithms/testing/comparator_runner.py +++ b/vizier/_src/algorithms/testing/comparator_runner.py @@ -144,7 +144,7 @@ class SimpleRegretComparisonTester: def assert_optimizer_better_simple_regret( self, - converter: converters.TrialToArrayConverter, + converter: converters.TrialToModelInputConverter, score_fn: vb.ArrayScoreFunction, baseline_strategy_factory: vb.VectorizedStrategyFactory, candidate_strategy_factory: vb.VectorizedStrategyFactory, diff --git a/vizier/_src/algorithms/testing/comparator_runner_test.py b/vizier/_src/algorithms/testing/comparator_runner_test.py index 1714403b1..dceb14d4c 100644 --- a/vizier/_src/algorithms/testing/comparator_runner_test.py +++ b/vizier/_src/algorithms/testing/comparator_runner_test.py @@ -39,7 +39,7 @@ class FakeVectorizedStrategy(vb.VectorizedStrategy): def __init__( self, - converter: converters.TrialToArrayConverter, + converter: converters.TrialToModelInputConverter, good_value: float = 1.0, bad_value: float = 0.0, num_trial_to_converge: int = 0, @@ -55,7 +55,7 @@ def init_state( seed: jax.random.KeyArray, n_parallel: int = 1, *, - prior_features: Optional[types.Array] = None, + prior_features: Optional[vb.VectorizedOptimizerInput] = None, prior_rewards: Optional[types.Array] = None, ) -> None: return @@ -65,15 +65,22 @@ def suggest( seed: jax.random.KeyArray, state: None, n_parallel: int = 1, - ) -> jax.Array: - output_len = sum( - [spec.num_dimensions for spec in self.converter.output_specs] + ) -> vb.VectorizedOptimizerInput: + continuous_output_len = sum( + [spec.num_dimensions for spec in self.converter.output_specs.continuous] ) - shape = (1, n_parallel, output_len) + categorical_output_len = len(self.converter.output_specs.categorical) + shape = (1, n_parallel, continuous_output_len) if self.num_trials_so_far < self.num_trial_to_converge: - return jnp.ones(shape) * self.bad_value + continuous = jnp.ones(shape) * self.bad_value else: - return jnp.ones(shape) * self.good_value + continuous = jnp.ones(shape) * self.good_value + return vb.VectorizedOptimizerInput( + continuous=continuous, + categorical=jnp.zeros( + (1, n_parallel, 
categorical_output_len), dtype=types.INT_DTYPE + ), + ) @property def suggestion_batch_size(self) -> int: @@ -83,7 +90,7 @@ def update( self, seed: jax.random.KeyArray, state: None, - batch_features: types.Array, + batch_features: vb.VectorizedOptimizerInput, batch_rewards: types.Array, ) -> None: pass @@ -192,7 +199,7 @@ class SimpleRegretConvergenceRunnerTest(parameterized.TestCase): def setUp(self): super(SimpleRegretConvergenceRunnerTest, self).setUp() self.experimenter = experimenters.BBOBExperimenterFactory('Sphere', 3)() - self.converter = converters.TrialToArrayConverter.from_study_config( + self.converter = converters.TrialToModelInputConverter.from_problem( self.experimenter.problem_statement() ) @@ -327,7 +334,7 @@ def _baseline_designer_factory(problem, seed): }, ) def test_optimizer_convergence(self, candidate_x_value, goal, should_pass): - score_fn = lambda x: np.sum(x, axis=-1) + score_fn = lambda x: jnp.sum(x.continuous.padded_array, axis=-1) simple_regret_test = comparator_runner.SimpleRegretComparisonTester( baseline_num_trials=100, candidate_num_trials=100, diff --git a/vizier/_src/jax/gp_bandit_utils.py b/vizier/_src/jax/gp_bandit_utils.py index af50cc517..38d879fdc 100644 --- a/vizier/_src/jax/gp_bandit_utils.py +++ b/vizier/_src/jax/gp_bandit_utils.py @@ -50,35 +50,3 @@ def stochastic_process_model_setup( ): """Setup function for a stochastic process model.""" return model.init(key, data.features)['params'] - - -# TODO: Remove this when Vectorized Optimizer works on CACV. -def make_one_hot_to_modelinput_fn(seed_features_unpad, mapper, cacpa): - """Temporary utility fn for converting one hot to ModelInput.""" - - def _one_hot_to_cacpa(x_): - if seed_features_unpad is not None: - x_unpad = x_[..., : seed_features_unpad.shape[1]] - else: - x_unpad = x_ - cacv = mapper.map(x_unpad) - return types.ModelInput( - continuous=types.PaddedArray.from_array( - cacv.continuous, - ( - cacv.continuous.shape[:-1] - + (cacpa.continuous.padded_array.shape[1],) - ), - fill_value=cacpa.continuous.fill_value, - ), - categorical=types.PaddedArray.from_array( - cacv.categorical, - ( - cacv.categorical.shape[:-1] - + (cacpa.categorical.padded_array.shape[1],) - ), - fill_value=cacpa.categorical.fill_value, - ), - ) - - return _one_hot_to_cacpa diff --git a/vizier/_src/jax/types.py b/vizier/_src/jax/types.py index acbb1eb1a..95d739f0a 100644 --- a/vizier/_src/jax/types.py +++ b/vizier/_src/jax/types.py @@ -143,9 +143,6 @@ def unpad(self) -> jt.Shaped[jax.Array, '...']: ) -# TODO: Remove this. -MaybePaddedArray = Union[Array, PaddedArray] - ArrayTree = Union[ArrayLike, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']] # An ArrayTree that allows None values. @@ -181,19 +178,6 @@ class ModelData(eqx.Module): # Tuple representing a box constraint of the form (lower, upper) bounds. Bounds = tuple[Optional[ArrayTreeOptional], Optional[ArrayTreeOptional]] -# TODO: Deprecate it. We will always use ModelInput type. -Features = TypeVar('Features', Array, ContinuousAndCategoricalArray) - - -# TODO: Deprecate it in favor of ModelData type. -class StochasticProcessModelData(Generic[Features], eqx.Module): - """Data that feed into GP.""" - - features: Features - labels: Array = eqx.field(converter=jnp.asarray) - label_is_missing: Optional[Array] = None - dimension_is_missing: Optional[Features] = None - # TODO: Deprecate it in favor of # PrecomputedPredictive for full predictive state including cholesky, and
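# --- End-to-end sketch (editor's addition, not part of the diff) ---
# How the migrated pieces fit together, mirroring the updated tests: build a
# TrialToModelInputConverter, wire it into a VectorizedOptimizerFactory, and
# score with a ModelInput-consuming function. The toy problem is assumed.
from jax import numpy as jnp
from vizier import pyvizier as vz
from vizier._src.algorithms.optimizers import eagle_strategy
from vizier._src.algorithms.optimizers import vectorized_base as vb
from vizier.pyvizier import converters

problem = vz.ProblemStatement()
problem.search_space.root.add_float_param('x1', 0.0, 1.0)
problem.search_space.root.add_categorical_param('c1', ['a', 'b'])

converter = converters.TrialToModelInputConverter.from_problem(problem)
optimizer = vb.VectorizedOptimizerFactory(
    strategy_factory=eagle_strategy.VectorizedEagleStrategyFactory()
)(converter)

# score_fn receives types.ModelInput; padded dimensions are zero-filled.
results = optimizer(
    score_fn=lambda x: -jnp.sum(x.continuous.padded_array, axis=-1),
    count=1,
)
trials = vb.best_candidates_to_trials(results, converter)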