Skip to content

Commit c09af15

Browse files
sdaultonfacebook-github-bot
authored andcommitted
update optimize with NSGA-II (#2937)
Summary: Pull Request resolved: #2937 Updates: - Use HV-maximizing greedy subset selection - Pass MultiOutputAcquisitionFunction instead of Model as an argument - Add support for fixed features - Move optimize_with_nsgaii to a new file to avoid circular imports Reviewed By: bletham Differential Revision: D77696697 fbshipit-source-id: 89ffa62a6431ce70d29b295dd80a1943935f5241
1 parent adb4cfb commit c09af15

File tree

5 files changed

+419
-282
lines changed

5 files changed

+419
-282
lines changed
Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from __future__ import annotations
8+
9+
import warnings
10+
11+
from typing import Callable
12+
13+
import numpy as np
14+
import torch
15+
from botorch.acquisition.multi_objective.objective import (
16+
IdentityMCMultiOutputObjective,
17+
MCMultiOutputObjective,
18+
)
19+
from botorch.acquisition.multioutput_acquisition import MultiOutputAcquisitionFunction
20+
from botorch.exceptions import BotorchWarning
21+
from botorch.utils.multi_objective.hypervolume import get_hypervolume_maximizing_subset
22+
from botorch.utils.multi_objective.pareto import is_non_dominated
23+
from torch import Tensor
24+
25+
try:
26+
from pymoo.algorithms.moo.nsga2 import NSGA2
27+
from pymoo.core.problem import Problem
28+
from pymoo.optimize import minimize
29+
from pymoo.termination.max_gen import MaximumGenerationTermination
30+
31+
class BotorchPymooProblem(Problem):
32+
def __init__(
33+
self,
34+
n_var: int,
35+
n_obj: int,
36+
xl: np.ndarray,
37+
xu: np.ndarray,
38+
acqf: MultiOutputAcquisitionFunction,
39+
dtype: torch.dtype,
40+
device: torch.device,
41+
ref_point: Tensor | None = None,
42+
objective: MCMultiOutputObjective | None = None,
43+
constraints: list[Callable[[Tensor], Tensor]] | None = None,
44+
) -> None:
45+
"""PyMOO problem for optimizing the model posterior mean using NSGA-II.
46+
47+
This is instantiated and used within `optimize_with_nsgaii` to define
48+
the optimization problem to interface with pymoo.
49+
50+
This assumes maximization of all objectives.
51+
52+
Args:
53+
n_var: The number of tunable parameters (`d`).
54+
n_obj: The number of objectives.
55+
xl: A `d`-dim np.ndarray of lower bounds for each tunable parameter.
56+
xu: A `d`-dim np.ndarray of upper bounds for each tunable parameter.
57+
acqf: A MultiOutputAcquisitionFunction.
58+
dtype: The torch dtype.
59+
device: The torch device.
60+
acqf: The acquisition function to optimize.
61+
ref_point: A list or tensor with `m` elements representing the reference
62+
point (in the outcome space), which is treated as a lower bound
63+
on the objectives, after applying `objective` to the samples.
64+
objective: The MCMultiOutputObjective under which the samples are
65+
evaluated. Defaults to `IdentityMultiOutputObjective()`.
66+
This can be used to determine which outputs of the
67+
MultiOutputAcquisitionFunction should be used as
68+
objectives/constraints in NSGA-II.
69+
constraints: A list of callables, each mapping a Tensor of dimension
70+
`sample_shape x batch-shape x q x m` to a Tensor of dimension
71+
`sample_shape x batch-shape x q`, where negative values imply
72+
feasibility.
73+
"""
74+
num_constraints = 0 if constraints is None else len(constraints)
75+
if ref_point is not None:
76+
num_constraints += ref_point.shape[0]
77+
super().__init__(
78+
n_var=n_var,
79+
n_obj=n_obj,
80+
n_ieq_constr=num_constraints,
81+
xl=xl,
82+
xu=xu,
83+
type_var=np.double,
84+
)
85+
self.botorch_acqf = acqf
86+
self.botorch_ref_point = ref_point
87+
self.botorch_objective = (
88+
IdentityMCMultiOutputObjective() if objective is None else objective
89+
)
90+
self.botorch_constraints = constraints
91+
self.torch_dtype = dtype
92+
self.torch_device = device
93+
94+
def _evaluate(self, x: np.ndarray, out: dict[str, np.ndarray]) -> None:
95+
"""Evaluate x with respect to the objective/constraints."""
96+
X = torch.from_numpy(x).to(dtype=self.torch_dtype, device=self.torch_device)
97+
with torch.no_grad():
98+
# eval in batch mode, since all we need is the mean and this helps
99+
# avoid ill-conditioning
100+
y = self.botorch_acqf(X=X.unsqueeze(-2))
101+
obj = self.botorch_objective(y)
102+
# negate the objectives, since we want to maximize this function
103+
out["F"] = -obj.cpu().numpy()
104+
constraint_vals = None
105+
if self.botorch_constraints is not None:
106+
constraint_vals = torch.stack(
107+
[c(y) for c in self.botorch_constraints], dim=-1
108+
)
109+
if self.botorch_ref_point is not None:
110+
# add constraints for the ref point
111+
ref_constraints = self.botorch_ref_point - obj
112+
if constraint_vals is not None:
113+
constraint_vals = torch.cat(
114+
[constraint_vals, ref_constraints], dim=-1
115+
)
116+
else:
117+
constraint_vals = ref_constraints
118+
if constraint_vals is not None:
119+
out["G"] = constraint_vals.cpu().numpy()
120+
121+
def optimize_with_nsgaii(
122+
acq_function: MultiOutputAcquisitionFunction,
123+
bounds: Tensor,
124+
num_objectives: int,
125+
q: int | None = None,
126+
ref_point: list[float] | Tensor | None = None,
127+
objective: MCMultiOutputObjective | None = None,
128+
constraints: list[Callable[[Tensor], Tensor]] | None = None,
129+
population_size: int = 250,
130+
max_gen: int | None = None,
131+
seed: int | None = None,
132+
fixed_features: dict[int, float] | None = None,
133+
) -> tuple[Tensor, Tensor]:
134+
"""Optimize the posterior mean via NSGA-II, returning the Pareto set and front.
135+
136+
This assumes maximization of all objectives.
137+
138+
TODO: Add support for discrete parameters.
139+
140+
Args:
141+
acq_function: The MultiOutputAcquisitionFunction to optimize.
142+
bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
143+
q: The number of candidates. If None, return the full population.
144+
num_objectives: The number of objectives.
145+
ref_point: A list or tensor with `m` elements representing the reference
146+
point (in the outcome space), which is treated as a lower bound
147+
on the objectives, after applying `objective` to the samples.
148+
objective: The MCMultiOutputObjective under which the samples are
149+
evaluated. Defaults to `IdentityMultiOutputObjective()`.
150+
This can be used to determine which outputs of the
151+
MultiOutputAcquisitionFunction should be used as
152+
objectives/constraints in NSGA-II.
153+
constraints: A list of callables, each mapping a Tensor of dimension
154+
`sample_shape x batch-shape x q x m` to a Tensor of dimension
155+
`sample_shape x batch-shape x q`, where negative values imply
156+
feasibility.
157+
population_size: the population size for NSGA-II.
158+
max_gen: The number of iterations for NSGA-II. If None, this uses the
159+
default termination condition in pymoo for NSGA-II.
160+
seed: The random seed for NSGA-II.
161+
fixed_features: A map `{feature_index: value}` for features that
162+
should be fixed to a particular value during generation. All indices
163+
should be non-negative.
164+
165+
Returns:
166+
A two-element tuple containing the pareto set X and pareto frontier Y.
167+
"""
168+
tkwargs = {"dtype": bounds.dtype, "device": bounds.device}
169+
if ref_point is not None:
170+
ref_point = torch.as_tensor(ref_point, **tkwargs)
171+
if fixed_features is not None:
172+
bounds = bounds.clone()
173+
# set lower and upper bounds to the fixed value
174+
for i, val in fixed_features.items():
175+
bounds[:, i] = val
176+
with warnings.catch_warnings():
177+
warnings.simplefilter("ignore", category=DeprecationWarning)
178+
pymoo_problem = BotorchPymooProblem(
179+
n_var=bounds.shape[-1],
180+
n_obj=num_objectives,
181+
xl=bounds[0].cpu().numpy(),
182+
xu=bounds[1].cpu().numpy(),
183+
acqf=acq_function,
184+
ref_point=ref_point,
185+
objective=objective,
186+
constraints=constraints,
187+
**tkwargs,
188+
)
189+
if q is not None:
190+
population_size = max(population_size, q)
191+
algorithm = NSGA2(pop_size=population_size, eliminate_duplicates=True)
192+
res = minimize(
193+
problem=pymoo_problem,
194+
algorithm=algorithm,
195+
termination=(
196+
None
197+
if max_gen is None
198+
else MaximumGenerationTermination(n_max_gen=max_gen)
199+
),
200+
seed=seed,
201+
verbose=False,
202+
)
203+
X = torch.tensor(res.X, **tkwargs)
204+
# multiply by negative one to return the correct sign for maximization
205+
Y = -torch.tensor(res.F, **tkwargs)
206+
pareto_mask = is_non_dominated(Y, deduplicate=True)
207+
X_pareto = X[pareto_mask]
208+
Y_pareto = Y[pareto_mask]
209+
if q is not None:
210+
if Y_pareto.shape[0] > q:
211+
Y_pareto, indices = get_hypervolume_maximizing_subset(
212+
# use nadir as reference point since we likely don't care about the
213+
# extrema as much as the interior
214+
n=q,
215+
Y=Y_pareto,
216+
ref_point=Y_pareto.min(dim=0).values,
217+
)
218+
X_pareto = X_pareto[indices]
219+
elif Y_pareto.shape[0] < q:
220+
n_missing = q - Y_pareto.shape[0]
221+
if Y.shape[0] >= q:
222+
# select some dominated solutions
223+
rand_idcs = np.random.choice(
224+
(~pareto_mask).nonzero().view(-1).cpu().numpy(),
225+
n_missing,
226+
replace=False,
227+
)
228+
rand_idcs = torch.from_numpy(rand_idcs).to(
229+
device=pareto_mask.device
230+
)
231+
pareto_mask[rand_idcs] = 1
232+
X_pareto = X[pareto_mask]
233+
Y_pareto = Y[pareto_mask]
234+
else:
235+
warnings.warn(
236+
f"NSGA-II only returned {Y.shape[0]} points.",
237+
BotorchWarning,
238+
stacklevel=3,
239+
)
240+
return X, Y
241+
return X_pareto, Y_pareto
242+
243+
except ImportError: # pragma: no cover
244+
pass

0 commit comments

Comments
 (0)