Skip to content

Commit 892c37a

Browse files
authored
Check for observed variables in the trace (#7641)
1 parent 2012262 commit 892c37a

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

pymc/sampling/forward.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -345,10 +345,13 @@ def draw(
345345
return [np.stack(v) for v in drawn_values]
346346

347347

348-
def observed_dependent_deterministics(model: Model):
348+
def observed_dependent_deterministics(model: Model, extra_observeds=None):
349349
"""Find deterministics that depend directly on observed variables."""
350+
if extra_observeds is None:
351+
extra_observeds = []
352+
350353
deterministics = model.deterministics
351-
observed_rvs = set(model.observed_RVs)
354+
observed_rvs = set(model.observed_RVs + extra_observeds)
352355
blockers = model.basic_RVs
353356
return [
354357
deterministic
@@ -767,13 +770,15 @@ def sample_posterior_predictive(
767770
if "coords" not in idata_kwargs:
768771
idata_kwargs["coords"] = {}
769772
idata: InferenceData | None = None
773+
observed_data = None
770774
stacked_dims = None
771775
if isinstance(trace, InferenceData):
772776
_constant_data = getattr(trace, "constant_data", None)
773777
if _constant_data is not None:
774778
trace_coords.update({str(k): v.data for k, v in _constant_data.coords.items()})
775779
constant_data.update({str(k): v.data for k, v in _constant_data.items()})
776780
idata = trace
781+
observed_data = trace.get("observed_data", None)
777782
trace = trace["posterior"]
778783
if isinstance(trace, xarray.Dataset):
779784
trace_coords.update({str(k): v.data for k, v in trace.coords.items()})
@@ -816,7 +821,12 @@ def sample_posterior_predictive(
816821
if var_names is not None:
817822
vars_ = [model[x] for x in var_names]
818823
else:
819-
vars_ = model.observed_RVs + observed_dependent_deterministics(model)
824+
observed_vars = model.observed_RVs
825+
if observed_data is not None:
826+
observed_vars += [
827+
model[x] for x in observed_data if x in model and x not in observed_vars
828+
]
829+
vars_ = observed_vars + observed_dependent_deterministics(model, observed_vars)
820830

821831
vars_to_sample = list(get_default_varnames(vars_, include_transformed=False))
822832

tests/sampling/test_forward.py

+18
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,24 @@ def test_normal_scalar_idata(self):
540540
ppc = pm.sample_posterior_predictive(idata, return_inferencedata=False)
541541
assert ppc["a"].shape == (nchains, ndraws)
542542

543+
def test_external_trace_det(self):
544+
with pm.Model() as model:
545+
mu = pm.Normal("mu", 0.0, 1.0)
546+
a = pm.Normal("a", mu=mu, sigma=1, observed=0.0)
547+
b = pm.Deterministic("b", a + 1)
548+
trace = pm.sample(tune=50, draws=50, chains=1, compute_convergence_checks=False)
549+
550+
# test that trace is used in ppc
551+
with pm.Model() as model_ppc:
552+
mu = pm.Normal("mu", 0.0, 1.0)
553+
a = pm.Normal("a", mu=mu, sigma=1)
554+
c = pm.Deterministic("c", a + 1)
555+
556+
ppc = pm.sample_posterior_predictive(
557+
trace=trace, model=model_ppc, return_inferencedata=False
558+
)
559+
assert list(ppc.keys()) == ["a", "c"]
560+
543561
def test_normal_vector(self):
544562
with pm.Model() as model:
545563
mu = pm.Normal("mu", 0.0, 1.0)

0 commit comments

Comments
 (0)