@@ -21,7 +21,6 @@
 import logging
 from abc import ABC, abstractmethod
 from functools import partial
-from warnings import warn
 
 logging.basicConfig()
 
@@ -1214,3 +1213,157 @@ def _get_local_global(self, input_dict, **kwargs):
         local_summaries = input_dict.get("direct_local_conditions")
         global_summaries = input_dict.get("direct_global_conditions")
         return local_summaries, global_summaries
+
+
+class AmortizedPointEstimator(tf.keras.Model):
+    """An interface to connect a neural point estimator for Bayesian estimation with an optional summary network [1].
+
+    [1] Sainsbury-Dale, M., Zammit-Mangion, A., & Huser, R. (2024).
+    Likelihood-free parameter estimation with neural Bayes estimators.
+    The American Statistician, 78(1), 1-14.
+    """
+
+    def __init__(self, inference_net, summary_net=None, norm_ord=2, loss_fun=None):
+        """Initializes a composite neural architecture for amortized Bayesian point estimation.
+
+        Parameters
+        ----------
+        inference_net : tf.keras.Model
+            A neural network whose final output dimension equals that of the target quantities.
+        summary_net : tf.keras.Model or None, optional, default: None
+            An optional summary network.
+        norm_ord : int or np.inf, optional, default: 2
+            The order of the norm used as a loss function for the point estimator. Should be in ``[1, 2, np.inf]``.
+        loss_fun : callable or None, optional, default: None
+            If not None, it overrides the ``norm_ord`` keyword argument.
+
+        Important
+        ---------
+        - If no ``summary_net`` is provided, the output dictionary of your generative model should not contain
+          any ``summary_conditions``, i.e., ``summary_conditions`` should be set to None, otherwise these will be ignored.
+
+        - If no custom ``loss_fun`` is provided, the loss will be the vector norm of order ``norm_ord`` of the
+          difference between point estimates and true parameters.
+        """
+
+        super().__init__()
+
+        self.inference_net = inference_net
+        self.summary_net = summary_net
+        self.loss_fn = self._determine_loss(loss_fun, norm_ord)
+
+    def call(self, input_dict, return_summary=False, **kwargs):
+        """Performs a forward pass through the summary and inference network given an input dictionary.
+
+        Parameters
+        ----------
+        input_dict : dict
+            Input dictionary containing the following mandatory keys, if ``DEFAULT_KEYS`` unchanged:
+            ``parameters`` - the latent model parameters for which point estimates are learned
+            ``summary_conditions`` - the conditioning variables (including data) that are first passed through a summary network
+            ``direct_conditions`` - the conditioning variables that are directly passed to the inference network
+        return_summary : bool, optional, default: False
+            A flag which determines whether the learnable data summaries (representations) are returned or not.
+        **kwargs : dict, optional, default: {}
+            Additional keyword arguments passed to the networks.
+            For instance, ``kwargs={'training': True}`` is passed automatically during training.
+
+        Returns
+        -------
+        net_out or (net_out, summary_out) : tf.Tensor or tuple of tf.Tensor
+            The outputs of ``inference_net(summary_net(x, c_s), c_d)``, usually a batch of point estimates,
+            that is, a tensor ``estimates``, or ``(estimates, sum_outputs)`` if ``return_summary`` is set
+            to True and a summary network is defined.
+        """
+
+        # Concatenate conditions, if given
+        summary_out, full_cond = self._compute_summary_condition(
+            input_dict.get(DEFAULT_KEYS["summary_conditions"]),
+            input_dict.get(DEFAULT_KEYS["direct_conditions"]),
+            **kwargs,
+        )
+
+        # Compute output of inference net
+        net_out = self.inference_net(full_cond, **kwargs)
+
+        # Return summary outputs or not, depending on parameter
+        if return_summary:
+            return net_out, summary_out
+        return net_out
+
+    def estimate(self, input_dict, to_numpy=True, **kwargs):
+        """Obtains Bayesian point estimates given the data in ``input_dict``.
+
+        Parameters
+        ----------
+        input_dict : dict
+            Input dictionary containing at least one of the following mandatory keys, if ``DEFAULT_KEYS`` unchanged:
+            ``summary_conditions`` : the conditioning variables (including data) that are first passed through a summary network
+            ``direct_conditions`` : the conditioning variables that are directly passed to the inference network
+        to_numpy : bool, optional, default: True
+            Flag indicating whether to return the estimates as a ``np.ndarray`` or a ``tf.Tensor``.
+        **kwargs : dict, optional, default: {}
+            Additional keyword arguments passed to the networks.
+
+        Returns
+        -------
+        estimates : tf.Tensor or np.ndarray of shape (num_data_sets, num_params)
+            The point estimates of the parameters for each data set.
+        """
+
+        estimates = self(input_dict, **kwargs)
+        if to_numpy:
+            return estimates.numpy()
+        return estimates
+
+    def compute_loss(self, input_dict, **kwargs):
+        """Computes the loss of the point estimator given an input dictionary, which will
+        typically be the output of a Bayesian ``GenerativeModel`` instance.
+
+        Parameters
+        ----------
+        input_dict : dict
+            Input dictionary containing the following mandatory keys, if ``DEFAULT_KEYS`` unchanged:
+            ``parameters`` - the latent model parameters for which point estimates are learned
+            ``summary_conditions`` - the conditioning variables that are first passed through a summary network
+            ``direct_conditions`` - the conditioning variables that are directly passed to the inference network
+        **kwargs : dict, optional, default: {}
+            Additional keyword arguments passed to the networks.
+            For instance, ``kwargs={'training': True}`` is passed automatically during training.
+
+        Returns
+        -------
+        total_loss : tf.Tensor
+            A scalar tensor holding the total computed loss given input variables.
+        """
+
+        net_out = self(input_dict, **kwargs)
+        loss = tf.reduce_mean(self.loss_fn(net_out - input_dict[DEFAULT_KEYS["parameters"]]))
+        return loss
+
+    def _compute_summary_condition(self, summary_conditions, direct_conditions, **kwargs):
+        """Determines how to concatenate the provided conditions."""
+
+        # Compute learnable summaries, if given
+        if self.summary_net is not None:
+            sum_condition = self.summary_net(summary_conditions, **kwargs)
+        else:
+            sum_condition = None
+
+        # Concatenate learnable summaries with fixed summaries
+        if sum_condition is not None and direct_conditions is not None:
+            full_cond = tf.concat([sum_condition, direct_conditions], axis=-1)
+        elif sum_condition is not None:
+            full_cond = sum_condition
+        elif direct_conditions is not None:
+            full_cond = direct_conditions
+        else:
+            raise SummaryStatsError("Could not concatenate or determine conditioning inputs...")
+        return sum_condition, full_cond
+
+    def _determine_loss(self, loss_fun, norm_ord):
+        """Determines which loss function to use, defaulting to the vector norm of order ``norm_ord``
+        (2, unless specified otherwise in the ``__init__`` method)."""
+
+        # In case of user-provided loss, override norm order
+        if loss_fun is not None:
+            return loss_fun
+        return partial(tf.norm, ord=norm_ord, axis=-1)
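
For context, here is a minimal usage sketch of the new class (not part of the diff). The networks, shapes, and dictionary keys below are illustrative assumptions; in particular, it presumes ``DEFAULT_KEYS`` maps ``summary_conditions``, ``direct_conditions``, and ``parameters`` to keys of the same name, as elsewhere in the codebase.

```python
import numpy as np
import tensorflow as tf

num_data_sets, num_obs, data_dim, num_params = 16, 100, 2, 3

# Hypothetical stand-in networks; any tf.keras.Model with matching
# output dimensions would do.
summary_net = tf.keras.Sequential([
    tf.keras.layers.GlobalAveragePooling1D(),  # pool over the observation axis
    tf.keras.layers.Dense(8, activation="relu"),
])
inference_net = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(num_params),  # one output per target parameter
])

amortizer = AmortizedPointEstimator(inference_net, summary_net=summary_net)

# Simulated stand-in data: 16 data sets of 100 bivariate observations each;
# "direct_conditions" is simply omitted here.
input_dict = {
    "summary_conditions": np.random.normal(size=(num_data_sets, num_obs, data_dim)).astype(np.float32),
    "parameters": np.random.normal(size=(num_data_sets, num_params)).astype(np.float32),
}

estimates = amortizer.estimate(input_dict)  # np.ndarray of shape (16, 3)
loss = amortizer.compute_loss(input_dict)   # scalar: mean L2 error norm
```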
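As a side note on ``_determine_loss``: the default ``partial(tf.norm, ord=norm_ord, axis=-1)`` computes one error norm per data set, which ``compute_loss`` then averages over the batch. A tiny illustration with made-up numbers:

```python
from functools import partial

import tensorflow as tf

loss_fn = partial(tf.norm, ord=2, axis=-1)   # the default loss (Euclidean norm)
err = tf.constant([[3.0, 4.0], [0.0, 1.0]])  # estimates - true parameters
per_set = loss_fn(err)                       # -> [5.0, 1.0]
total = tf.reduce_mean(per_set)              # -> 3.0, as in compute_loss
```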