diff --git a/pymc/model/core.py b/pymc/model/core.py index b85cc802fc..4d38bd7d41 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -569,15 +569,16 @@ def logp_dlogp_function( for var in self.value_vars if var in input_vars and var not in grad_vars } - return ValueGradFunction( - costs, - grad_vars, - extra_vars_and_values, - model=self, - initial_point=initial_point, - ravel_inputs=ravel_inputs, - **kwargs, - ) + with self: + return ValueGradFunction( + costs, + grad_vars, + extra_vars_and_values, + model=self, + initial_point=initial_point, + ravel_inputs=ravel_inputs, + **kwargs, + ) def compile_logp( self, diff --git a/pymc/model/fgraph.py b/pymc/model/fgraph.py index 5dc47fe0ee..ab2b554bfd 100644 --- a/pymc/model/fgraph.py +++ b/pymc/model/fgraph.py @@ -223,6 +223,7 @@ def fgraph_from_model( copy_inputs=True, ) # Copy model meta-info to fgraph + fgraph.check_bounds = model.check_bounds fgraph._coords = model._coords.copy() fgraph._dim_lengths = {k: memo.get(v, v) for k, v in model._dim_lengths.items()} @@ -318,6 +319,7 @@ def first_non_model_var(var): # TODO: Consider representing/extracting them from the fgraph! _dim_lengths = {k: memo.get(v, v) for k, v in _dim_lengths.items()} + model.check_bounds = getattr(fgraph, "check_bounds", False) model._coords = _coords model._dim_lengths = _dim_lengths diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index 82eca936b7..71c3ffe232 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -563,24 +563,24 @@ def join_nonshared_inputs( raise ValueError("Empty list of input variables.") raveled_inputs = pt.concatenate([var.ravel() for var in inputs]) + input_sizes = [point[var_name].size for var_name in point] + size = sum(input_sizes) if not make_inputs_shared: - tensor_type = raveled_inputs.type - joined_inputs = tensor_type("joined_inputs") + joined_inputs = pt.tensor("joined_inputs", shape=(size,), dtype=raveled_inputs.dtype) else: joined_values = np.concatenate([point[var.name].ravel() for var in inputs]) - joined_inputs = pytensor.shared(joined_values, "joined_inputs") + joined_inputs = pytensor.shared(joined_values, "joined_inputs", shape=(size,)) if pytensor.config.compute_test_value != "off": joined_inputs.tag.test_value = raveled_inputs.tag.test_value replace: dict[TensorVariable, TensorVariable] = {} - last_idx = 0 - for var in inputs: + for var, flat_var in zip( + inputs, pt.split(joined_inputs, input_sizes, len(inputs)), strict=True + ): shape = point[var.name].shape - arr_len = np.prod(shape, dtype=int) - replace[var] = joined_inputs[last_idx : last_idx + arr_len].reshape(shape).astype(var.dtype) - last_idx += arr_len + replace[var] = flat_var.reshape(shape).astype(var.dtype) if shared_inputs is not None: replace.update(shared_inputs) diff --git a/pymc/step_methods/hmc/integration.py b/pymc/step_methods/hmc/integration.py index 4eb7a15d8f..6d8f0d593d 100644 --- a/pymc/step_methods/hmc/integration.py +++ b/pymc/step_methods/hmc/integration.py @@ -42,6 +42,7 @@ def __init__(self, potential: QuadPotential, logp_dlogp_func): self._potential = potential # Sidestep logp_dlogp_function.__call__ pytensor_function = logp_dlogp_func._pytensor_function + pytensor_function.vm.allow_gc = False # Create some wrappers for backwards compatibility during transition # When raveled_inputs=False is forbidden, func = pytensor_function if logp_dlogp_func._raveled_inputs: diff --git a/tests/model/test_core.py b/tests/model/test_core.py index 4375a17ad2..97f4a3e5aa 100644 --- a/tests/model/test_core.py +++ b/tests/model/test_core.py @@ -443,6 +443,15 @@ def test_missing_data(self): # Assert that all the elements of res are equal assert res[1:] == res[:-1] + def test_check_bounds_out_of_model_context(self): + with pm.Model(check_bounds=False) as m: + x = pm.Normal("x") + y = pm.Normal("y", sigma=x) + fn = m.logp_dlogp_function(ravel_inputs=True) + fn.set_extra_values({}) + # When there are no bounds check logp turns into `nan` + assert np.isnan(fn(np.array([-1.0, -1.0]))[0]) + class TestPytensorRelatedLogpBugs: def test_pytensor_switch_broadcast_edge_cases_1(self): diff --git a/tests/model/test_fgraph.py b/tests/model/test_fgraph.py index 178eb39683..d36fd2bc13 100644 --- a/tests/model/test_fgraph.py +++ b/tests/model/test_fgraph.py @@ -397,3 +397,13 @@ def test_multivariate_transform(): new_ip = new_m.initial_point() np.testing.assert_allclose(ip["x_simplex__"], new_ip["x_simplex__"]) np.testing.assert_allclose(ip["y_cholesky-cov-packed__"], new_ip["y_cholesky-cov-packed__"]) + + +def test_check_bounds_preserved(): + with pm.Model(check_bounds=True) as m: + x = pm.HalfNormal("x") + + assert clone_model(m).check_bounds + + m.check_bounds = False + assert not clone_model(m).check_bounds