@@ -345,10 +345,13 @@ def draw(
345
345
return [np .stack (v ) for v in drawn_values ]
346
346
347
347
348
- def observed_dependent_deterministics (model : Model ):
348
+ def observed_dependent_deterministics (model : Model , extra_observeds = None ):
349
349
"""Find deterministics that depend directly on observed variables."""
350
+ if extra_observeds is None :
351
+ extra_observeds = []
352
+
350
353
deterministics = model .deterministics
351
- observed_rvs = set (model .observed_RVs )
354
+ observed_rvs = set (model .observed_RVs + extra_observeds )
352
355
blockers = model .basic_RVs
353
356
return [
354
357
deterministic
@@ -767,13 +770,15 @@ def sample_posterior_predictive(
767
770
if "coords" not in idata_kwargs :
768
771
idata_kwargs ["coords" ] = {}
769
772
idata : InferenceData | None = None
773
+ observed_data = None
770
774
stacked_dims = None
771
775
if isinstance (trace , InferenceData ):
772
776
_constant_data = getattr (trace , "constant_data" , None )
773
777
if _constant_data is not None :
774
778
trace_coords .update ({str (k ): v .data for k , v in _constant_data .coords .items ()})
775
779
constant_data .update ({str (k ): v .data for k , v in _constant_data .items ()})
776
780
idata = trace
781
+ observed_data = trace .get ("observed_data" , None )
777
782
trace = trace ["posterior" ]
778
783
if isinstance (trace , xarray .Dataset ):
779
784
trace_coords .update ({str (k ): v .data for k , v in trace .coords .items ()})
@@ -816,7 +821,12 @@ def sample_posterior_predictive(
816
821
if var_names is not None :
817
822
vars_ = [model [x ] for x in var_names ]
818
823
else :
819
- vars_ = model .observed_RVs + observed_dependent_deterministics (model )
824
+ observed_vars = model .observed_RVs
825
+ if observed_data is not None :
826
+ observed_vars += [
827
+ model [x ] for x in observed_data if x in model and x not in observed_vars
828
+ ]
829
+ vars_ = observed_vars + observed_dependent_deterministics (model , observed_vars )
820
830
821
831
vars_to_sample = list (get_default_varnames (vars_ , include_transformed = False ))
822
832
0 commit comments