Skip to content

Commit 0b866a9

Browse files
committed
Add AmortizedPointEstimator
1 parent c4ccad8 commit 0b866a9

File tree

1 file changed

+154
-1
lines changed

1 file changed

+154
-1
lines changed

bayesflow/amortizers.py

Lines changed: 154 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import logging
2222
from abc import ABC, abstractmethod
2323
from functools import partial
24-
from warnings import warn
2524

2625
logging.basicConfig()
2726

@@ -1214,3 +1213,157 @@ def _get_local_global(self, input_dict, **kwargs):
12141213
local_summaries = input_dict.get("direct_local_conditions")
12151214
global_summaries = input_dict.get("direct_global_conditions")
12161215
return local_summaries, global_summaries
1216+
1217+
1218+
class AmortizedPointEstimator(tf.keras.Model):
    """Couples a neural point estimator with an optional summary network for Bayesian estimation [1].

    The inference network maps the (optionally summarized) conditioning variables directly to point
    estimates of the target quantities. The training loss is computed on the residual between these
    estimates and the true parameters.

    [1] Sainsbury-Dale, M., Zammit-Mangion, A., & Huser, R. (2024).
    Likelihood-free parameter estimation with neural Bayes estimators.
    The American Statistician, 78(1), 1-14.
    """

    def __init__(self, inference_net, summary_net=None, norm_ord=2, loss_fun=None):
        """Initializes a composite neural architecture for amortized point estimation.

        Parameters
        ----------
        inference_net : tf.keras.Model
            A neural network whose final output dimension equals that of the target quantities.
        summary_net   : tf.keras.Model or None, optional, default: None
            An optional summary network applied to the ``summary_conditions`` before inference.
        norm_ord      : int or np.inf, optional, default: 2
            The order of the norm used as a loss function for the point estimator.
            Should be in ``[1, 2, np.inf]``. Ignored if ``loss_fun`` is provided.
        loss_fun      : callable or None, optional, default: None
            If not None, it overrides the ``norm_ord`` keyword argument. The callable receives the
            residual tensor ``estimates - parameters``; its output is averaged to obtain the loss.

        Important
        ---------
        - If no ``summary_net`` is provided, any ``summary_conditions`` in the input dictionary are ignored
          and only the ``direct_conditions`` are passed to the inference network.

        - If no custom ``loss_fun`` is provided, the loss is the norm of order ``norm_ord`` of the residual,
          taken over the last axis.
        """

        super().__init__()

        # Assignment order is kept as inference net first, summary net second
        self.inference_net = inference_net
        self.summary_net = summary_net
        self.loss_fn = self._determine_loss(loss_fun, norm_ord)

    def call(self, input_dict, return_summary=False, **kwargs):
        """Performs a forward pass through the summary and inference network given an input dictionary.

        Parameters
        ----------
        input_dict     : dict
            Input dictionary from which the following entries are read, if ``DEFAULT_KEYS`` unchanged:
            ``summary_conditions`` - the conditioning variables (including data) that are first
            passed through a summary network
            ``direct_conditions``  - the conditioning variables that are directly passed to the
            inference network
            Missing entries are treated as None.
        return_summary : bool, optional, default: False
            A flag which determines whether the learnable data summaries (representations) are returned or not.
        **kwargs       : dict, optional, default: {}
            Additional keyword arguments passed to both networks (e.g., ``training=True``).

        Returns
        -------
        net_out or (net_out, summary_out) : tf.Tensor or tuple
            The output of the inference network, usually a batch of point estimates. If ``return_summary``
            is True, a tuple ``(estimates, summaries)`` is returned instead, where ``summaries`` is None
            when no summary network is defined.
        """

        # Look up both kinds of conditions; absent keys simply yield None
        summary_inputs = input_dict.get(DEFAULT_KEYS["summary_conditions"])
        direct_inputs = input_dict.get(DEFAULT_KEYS["direct_conditions"])

        # Summarize (if applicable) and assemble the full conditioning tensor
        summaries, conditions = self._compute_summary_condition(summary_inputs, direct_inputs, **kwargs)

        # Map the conditions to point estimates
        point_estimates = self.inference_net(conditions, **kwargs)

        return (point_estimates, summaries) if return_summary else point_estimates

    def estimate(self, input_dict, to_numpy=True, **kwargs):
        """Obtains Bayesian point estimates given the data in input_dict.

        Parameters
        ----------
        input_dict : dict
            Input dictionary containing at least one of the following keys, if ``DEFAULT_KEYS`` unchanged:
            ``summary_conditions`` : the conditioning variables (including data) that are first
            passed through a summary network
            ``direct_conditions``  : the conditioning variables that are directly passed to the
            inference network
        to_numpy   : bool, optional, default: True
            Flag indicating whether to return the estimates as a ``np.ndarray`` or a ``tf.Tensor``.
        **kwargs   : dict, optional, default: {}
            Additional keyword arguments passed to the networks.

        Returns
        -------
        estimates : tf.Tensor or np.ndarray, typically of shape (num_data_sets, num_params)
            The point estimates of the parameters for each data set.
        """

        point_estimates = self(input_dict, **kwargs)
        return point_estimates.numpy() if to_numpy else point_estimates

    def compute_loss(self, input_dict, **kwargs):
        """Computes the loss of the point estimator given an input dictionary, which will
        typically be the output of a Bayesian ``GenerativeModel`` instance.

        Parameters
        ----------
        input_dict : dict
            Input dictionary containing the following keys, if ``DEFAULT_KEYS`` unchanged:
            ``parameters``         - the true parameters serving as regression targets (mandatory)
            ``summary_conditions`` - the conditioning variables that are first passed through a summary network
            ``direct_conditions``  - the conditioning variables that are directly passed to the inference network
        **kwargs   : dict, optional, default: {}
            Additional keyword arguments passed to the networks (e.g., ``training=True``).

        Returns
        -------
        total_loss : tf.Tensor - the mean of the loss function applied to ``estimates - parameters``
        """

        # The forward pass comes first; the targets are looked up afterwards
        point_estimates = self(input_dict, **kwargs)
        residuals = point_estimates - input_dict[DEFAULT_KEYS["parameters"]]

        # Average the per-sample losses over the batch
        per_sample_loss = self.loss_fn(residuals)
        return tf.reduce_mean(per_sample_loss)

    def _compute_summary_condition(self, summary_conditions, direct_conditions, **kwargs):
        """Computes the learnable summaries (if a summary network exists) and assembles the full condition.

        Returns a tuple ``(summaries, full_condition)``, where ``summaries`` is None if no summary network
        is defined. Raises a ``SummaryStatsError`` if neither summaries nor direct conditions are available.
        """

        # Learnable summaries are only computed when a summary network is present
        learned = None if self.summary_net is None else self.summary_net(summary_conditions, **kwargs)

        has_learned = learned is not None
        has_direct = direct_conditions is not None

        # Nothing to condition on at all
        if not (has_learned or has_direct):
            raise SummaryStatsError("Could not concatenarte or determine conditioning inputs...")

        # Join learnable and fixed conditions along the last axis, or use whichever one exists
        if has_learned and has_direct:
            combined = tf.concat([learned, direct_conditions], axis=-1)
        elif has_learned:
            combined = learned
        else:
            combined = direct_conditions
        return learned, combined

    def _determine_loss(self, loss_fun, norm_ord):
        """Selects the loss function: a user-provided ``loss_fun`` takes precedence, otherwise the norm of
        order ``norm_ord`` over the last axis is used (``norm_ord=2`` by default, see ``__init__``)."""

        # A custom loss, when given, overrides the norm order
        return partial(tf.norm, ord=norm_ord, axis=-1) if loss_fun is None else loss_fun

0 commit comments

Comments
 (0)