From 82e3c56b59bfb19656a4908686af69175c3b40d0 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 27 May 2025 23:55:44 -0400 Subject: [PATCH 1/4] Wrap HybridBayesNet methods --- gtsam/discrete/discrete.i | 1 + gtsam/hybrid/HybridBayesNet.h | 2 +- gtsam/hybrid/HybridFactor.h | 3 ++ gtsam/hybrid/hybrid.i | 43 ++++++++++++++++++++++--- python/gtsam/specializations/discrete.h | 1 + 5 files changed, 45 insertions(+), 5 deletions(-) diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 54d00f82ac..92897e1526 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -47,6 +47,7 @@ virtual class DiscreteFactor : gtsam::Factor { gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DiscreteFactor& lf, double tol = 1e-9) const; double operator()(const gtsam::DiscreteValues& values) const; + gtsam::AlgebraicDecisionTreeKey errorTree() const; }; #include diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 0058c406c9..83fc1a7256 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -248,7 +248,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> { * @param discrete Optional DiscreteValues * @return double */ - double negLogConstant(const std::optional<DiscreteValues> &discrete) const; + double negLogConstant(const std::optional<DiscreteValues>& discrete = {}) const; /** * @brief Compute normalized posterior P(M|X=x) and return as a tree. 
diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index 4147420bd1..0028dc43dc 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -162,4 +162,7 @@ class GTSAM_EXPORT HybridFactor : public Factor { template <> struct traits<HybridFactor> : public Testable<HybridFactor> {}; +// For wrapper: +using AlgebraicDecisionTreeKey = AlgebraicDecisionTree<Key>; + } // namespace gtsam diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i index 72e55e7515..264a71bc34 100644 --- a/gtsam/hybrid/hybrid.i +++ b/gtsam/hybrid/hybrid.i @@ -114,6 +114,12 @@ class HybridValues { }; #include +class AlgebraicDecisionTreeKey { + const double& operator()(const gtsam::DiscreteValues& x) const; + void print(string s = "AlgebraicDecisionTreeKey\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; +}; virtual class HybridFactor : gtsam::Factor { void print(string s = "HybridFactor\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::HybridFactor& lf, double tol = 1e-9) const; // Standard interface: - double error(const gtsam::HybridValues& values) const; bool isDiscrete() const; bool isContinuous() const; bool isHybrid() const; size_t nrContinuous() const; gtsam::DiscreteKeys discreteKeys() const; gtsam::KeyVector continuousKeys() const; + double error(const gtsam::HybridValues& hybridValues) const; + gtsam::AlgebraicDecisionTreeKey errorTree(const gtsam::VectorValues &continuousValues); + gtsam::Factor restrict(const gtsam::DiscreteValues& assignment) const; }; #include -virtual class HybridConditional { +virtual class HybridConditional : gtsam::HybridFactor { void print(string s = "Hybrid Conditional\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; @@ -148,7 +156,6 @@ virtual class HybridConditional { gtsam::GaussianConditional* asGaussian() const; gtsam::DiscreteConditional* asDiscrete() const; gtsam::Factor* inner(); - double 
error(const gtsam::HybridValues& values) const; }; #include @@ -167,19 +174,31 @@ class HybridGaussianFactor : gtsam::HybridFactor { }; #include -class HybridGaussianConditional : gtsam::HybridFactor { +class HybridGaussianConditional : gtsam::HybridGaussianFactor { HybridGaussianConditional( const gtsam::DiscreteKeys& discreteParents, const gtsam::HybridGaussianConditional::Conditionals& conditionals); HybridGaussianConditional( const gtsam::DiscreteKey& discreteParent, const std::vector<gtsam::GaussianConditional::shared_ptr>& conditionals); + // Standard API + gtsam::GaussianConditional::shared_ptr choose( + const gtsam::DiscreteValues &discreteValues) const; +// gtsam::GaussianConditional::shared_ptr operator()( +// const gtsam::DiscreteValues &discreteValues) const; + size_t nrComponents() const; + gtsam::KeyVector continuousParents() const; + double negLogConstant() const; gtsam::HybridGaussianFactor* likelihood( const gtsam::VectorValues& frontals) const; double logProbability(const gtsam::HybridValues& values) const; double evaluate(const gtsam::HybridValues& values) const; + gtsam::HybridGaussianConditional::shared_ptr prune( + const gtsam::DiscreteConditional &discreteProbs) const; + bool pruned() const; + void print(string s = "HybridGaussianConditional\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; @@ -217,6 +236,7 @@ class HybridBayesNet { void push_back(const gtsam::HybridGaussianConditional* s); void push_back(const gtsam::GaussianConditional* s); void push_back(const gtsam::DiscreteConditional* s); + void push_back(gtsam::HybridConditional::shared_ptr conditional); bool empty() const; size_t size() const; @@ -227,19 +247,33 @@ class HybridBayesNet { double logProbability(const gtsam::HybridValues& x) const; double evaluate(const gtsam::HybridValues& values) const; double error(const gtsam::HybridValues& values) const; + gtsam::AlgebraicDecisionTreeKey errorTree( + const gtsam::VectorValues& continuousValues) const; gtsam::HybridGaussianFactorGraph toFactorGraph( const gtsam::VectorValues& measurements) const; + double negLogConstant() const; + double negLogConstant(const gtsam::DiscreteValues &discrete) const; + gtsam::AlgebraicDecisionTreeKey discretePosterior( + const gtsam::VectorValues &continuousValues) const; gtsam::DiscreteBayesNet discreteMarginal() const; gtsam::GaussianBayesNet choose(const gtsam::DiscreteValues& assignment) const; + gtsam::DiscreteValues mpe() const; + gtsam::HybridValues optimize() const; gtsam::VectorValues optimize(const gtsam::DiscreteValues& assignment) const; gtsam::HybridValues sample(const gtsam::HybridValues& given, std::mt19937_64@ rng = nullptr) const; gtsam::HybridValues sample(std::mt19937_64@ rng = nullptr) const; + gtsam::HybridBayesNet prune(size_t maxNrLeaves) const; + gtsam::HybridBayesNet prune(size_t maxNrLeaves, double marginalThreshold) const; + // gtsam::HybridBayesNet prune(size_t maxNrLeaves, + // const std::optional<double> &marginalThreshold = std::nullopt, + // gtsam::DiscreteValues *fixedValues = nullptr) const; + void print(string s = "HybridBayesNet\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; diff --git a/python/gtsam/specializations/discrete.h b/python/gtsam/specializations/discrete.h index 458a2ea4c0..9ef9de2ae1 100644 --- a/python/gtsam/specializations/discrete.h +++ b/python/gtsam/specializations/discrete.h @@ -14,4 +14,5 @@ // Seems this is not a good idea with inherited stl //py::bind_vector>(m_, "DiscreteKeys"); +py::bind_map<gtsam::Assignment<gtsam::Key>>(m_, "AssignmentKey"); py::bind_map<gtsam::DiscreteValues>(m_, "DiscreteValues"); From 8e2cbb369eab89add39e10f57928ec7991025e71 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 27 May 2025 23:55:59 -0400 Subject: [PATCH 2/4] New comprehensive tests --- python/gtsam/tests/test_HybridBayesNet.py | 359 ++++++++++++++++++---- 1 file changed, 302 insertions(+), 57 deletions(-) diff --git a/python/gtsam/tests/test_HybridBayesNet.py b/python/gtsam/tests/test_HybridBayesNet.py 
index 57346d4d4b..a32001cd45 100644 --- a/python/gtsam/tests/test_HybridBayesNet.py +++ b/python/gtsam/tests/test_HybridBayesNet.py @@ -8,7 +8,8 @@ Unit tests for Hybrid Values. Author: Frank Dellaert """ -# pylint: disable=invalid-name, no-name-in-module, no-member + +# pylint: disable=invalid-name, no-name-in-module, no-member, E1101 import math import unittest @@ -17,80 +18,324 @@ from gtsam.symbol_shorthand import A, X from gtsam.utils.test_case import GtsamTestCase -from gtsam import (DiscreteConditional, DiscreteValues, - GaussianConditional, HybridBayesNet, - HybridGaussianConditional, HybridValues, VectorValues, - noiseModel) +from gtsam import ( + AlgebraicDecisionTreeKey, + AssignmentKey, + DecisionTreeFactor, + DiscreteBayesNet, + DiscreteConditional, + DiscreteValues, + GaussianBayesNet, + GaussianConditional, + HybridBayesNet, + HybridGaussianConditional, + HybridGaussianFactorGraph, + HybridValues, + KeySet, + VectorValues, + noiseModel, +) class TestHybridBayesNet(GtsamTestCase): - """Unit tests for HybridValues.""" + """Unit tests for HybridBayesNet.""" - def test_evaluate(self): - """Test evaluate for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia).""" - asiaKey = A(0) - Asia = (asiaKey, 2) - - # Create the continuous conditional - I_1x1 = np.eye(1) - conditional = GaussianConditional.FromMeanAndStddev( - X(0), 2 * I_1x1, X(1), [-4], 5.0) - - # Create the noise models - model0 = noiseModel.Diagonal.Sigmas([2.0]) - model1 = noiseModel.Diagonal.Sigmas([3.0]) - - # Create the conditionals - conditional0 = GaussianConditional(X(1), [5], I_1x1, model0) - conditional1 = GaussianConditional(X(1), [2], I_1x1, model1) - - # Create hybrid Bayes net. - bayesNet = HybridBayesNet() - bayesNet.push_back(conditional) - bayesNet.push_back( - HybridGaussianConditional(Asia, [conditional0, conditional1])) - bayesNet.push_back(DiscreteConditional(Asia, "99/1")) - - # Create values at which to evaluate. 
- values = HybridValues() + def setUp(self): + """Set up a common HybridBayesNet for testing P(X0|X1) P(X1|Asia) P(Asia).""" + self.Asia = (A(0), 2) + + self.I_1x1 = np.eye(1) + self.conditional_X0_X1 = GaussianConditional.FromMeanAndStddev( + X(0), 2 * self.I_1x1, X(1), [-4], 5.0 + ) + + self.model0 = noiseModel.Diagonal.Sigmas([2.0]) + self.model1 = noiseModel.Diagonal.Sigmas([3.0]) + + self.conditional0_X1_A = GaussianConditional(X(1), [5], self.I_1x1, self.model0) + self.conditional1_X1_A1 = GaussianConditional( + X(1), [2], self.I_1x1, self.model1 + ) + + self.hybrid_conditional_X1_Asia = HybridGaussianConditional( + self.Asia, [self.conditional0_X1_A, self.conditional1_X1_A1] + ) + + self.conditional_Asia = DiscreteConditional(self.Asia, "99/1") + + self.bayesNet = HybridBayesNet() + self.bayesNet.push_back(self.conditional_X0_X1) + self.bayesNet.push_back(self.hybrid_conditional_X1_Asia) + self.bayesNet.push_back(self.conditional_Asia) + + self.values = HybridValues() continuous = VectorValues() continuous.insert(X(0), [-6]) continuous.insert(X(1), [1]) - values.insert(continuous) + self.values.insert(continuous) discrete = DiscreteValues() - discrete[asiaKey] = 0 - values.insert(discrete) + discrete[A(0)] = 0 + self.values.insert(discrete) + + def test_constructor_and_basic_props(self): + """Test constructor, empty(), size(), keys(), print(), dot().""" + bn_empty = HybridBayesNet() + self.assertTrue(bn_empty.empty()) + self.assertEqual(bn_empty.size(), 0) + + self.assertFalse(self.bayesNet.empty()) + self.assertEqual(self.bayesNet.size(), 3) + + keys = self.bayesNet.keys() + self.assertIsInstance(keys, KeySet) + self.assertTrue(X(0) in keys) + self.assertTrue(X(1) in keys) + self.assertTrue(A(0) in keys) + self.assertEqual(keys.size(), 3) + + # Test dot (returns a string) + self.assertIsInstance(self.bayesNet.dot(), str) + + def test_equals_method(self): + """Test the equals(HybridBayesNet) method.""" + bn_copy = HybridBayesNet() + 
bn_copy.push_back(self.conditional_X0_X1) + bn_copy.push_back(self.hybrid_conditional_X1_Asia) + bn_copy.push_back(self.conditional_Asia) + self.assertTrue(self.bayesNet.equals(bn_copy, 1e-9)) + + bn_different_order = HybridBayesNet() # Order matters for BayesNets + bn_different_order.push_back(self.conditional_Asia) # Different order + bn_different_order.push_back(self.conditional_X0_X1) + bn_different_order.push_back(self.hybrid_conditional_X1_Asia) + self.assertFalse(self.bayesNet.equals(bn_different_order, 1e-9)) + + bn_different_cond = HybridBayesNet() + bn_different_cond.push_back(self.conditional_X0_X1) + bn_different_cond.push_back(self.hybrid_conditional_X1_Asia) + # Different P(Asia) + bn_different_cond.push_back(DiscreteConditional(self.Asia, "5/5")) + self.assertFalse(self.bayesNet.equals(bn_different_cond, 1e-9)) + + def test_error_method(self): + """Test the error(HybridValues) method of HybridBayesNet.""" + # logProbability(x) = -(K + error(x)) with K = -log(k) + self.assertAlmostEqual( + self.bayesNet.negLogConstant() + self.bayesNet.error(self.values), + -self.bayesNet.logProbability(self.values), + places=5, + ) + + def test_evaluate(self): + """Test evaluate for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia).""" + # Original test_evaluate content + conditionalProbability = self.conditional_X0_X1.evaluate( + self.values.continuous() + ) + # For HybridGaussianConditional, we need to evaluate its chosen component + # given the discrete assignment in self.values + chosen_gaussian_conditional_X1 = self.hybrid_conditional_X1_Asia.choose( + self.values.discrete() + ) + mixtureProbability = chosen_gaussian_conditional_X1.evaluate( + self.values.continuous() + ) + discreteProbability = self.conditional_Asia.evaluate(self.values.discrete()) - conditionalProbability = conditional.evaluate(values.continuous()) - mixtureProbability = conditional0.evaluate(values.continuous()) - self.assertAlmostEqual(conditionalProbability * mixtureProbability * - 0.99, - 
bayesNet.evaluate(values), - places=5) + self.assertAlmostEqual( + conditionalProbability * mixtureProbability * discreteProbability, + self.bayesNet.evaluate(self.values), + places=5, + ) # Check logProbability - self.assertAlmostEqual(bayesNet.logProbability(values), - math.log(bayesNet.evaluate(values))) + self.assertAlmostEqual( + self.bayesNet.logProbability(self.values), + math.log(self.bayesNet.evaluate(self.values)), + places=5, + ) # Check invariance for all conditionals: - self.check_invariance(bayesNet.at(0).asGaussian(), continuous) - self.check_invariance(bayesNet.at(0).asGaussian(), values) - self.check_invariance(bayesNet.at(0), values) + self.check_invariance( + self.bayesNet.at(0).asGaussian(), self.values.continuous() + ) + self.check_invariance(self.bayesNet.at(0).asGaussian(), self.values) + self.check_invariance(self.bayesNet.at(0), self.values) - self.check_invariance(bayesNet.at(1), values) + self.check_invariance(self.bayesNet.at(1), self.values) - self.check_invariance(bayesNet.at(2).asDiscrete(), discrete) - self.check_invariance(bayesNet.at(2).asDiscrete(), values) - self.check_invariance(bayesNet.at(2), values) + self.check_invariance(self.bayesNet.at(2).asDiscrete(), self.values.discrete()) + self.check_invariance(self.bayesNet.at(2).asDiscrete(), self.values) + self.check_invariance(self.bayesNet.at(2), self.values) - def check_invariance(self, conditional, values): + def check_invariance(self, conditional, values_obj): """Check invariance for given conditional.""" - probability = conditional.evaluate(values) + probability = conditional.evaluate(values_obj) self.assertTrue(probability >= 0.0) - logProb = conditional.logProbability(values) - self.assertAlmostEqual(probability, np.exp(logProb)) - expected = -(conditional.negLogConstant() + conditional.error(values)) - self.assertAlmostEqual(logProb, expected) + logProb = conditional.logProbability(values_obj) + if probability > 1e-9: # Avoid issues with log(0) + self.assertAlmostEqual( + 
probability, np.exp(logProb), delta=probability * 1e-6 + ) # Relative delta + + expected = -(conditional.negLogConstant() + conditional.error(values_obj)) + self.assertAlmostEqual(logProb, expected, places=5) + + def test_mpe_and_optimize_methods(self): + """Test mpe(), optimize(), and optimize(DiscreteValues).""" + # MPE: Most Probable Explanation for discrete variables + mpe_assignment = self.bayesNet.mpe() + self.assertIsInstance(mpe_assignment, DiscreteValues) + # Given P(Asia) = "99/1", Asia=0 is more probable + self.assertEqual(mpe_assignment[A(0)], 0) + + # Optimize(): MAP for discrete, then optimize continuous given MAP discrete + map_solution = self.bayesNet.optimize() + self.assertIsInstance(map_solution, HybridValues) + self.assertEqual(map_solution.atDiscrete(A(0)), 0) # Asia=0 + # If Asia=0, X1 from conditional0_X1_A: N(mean=[5], model0_sigmas=[2.0]) -> optimal X1=5 + self.assertAlmostEqual(map_solution.at(X(1))[0], 5.0, places=5) + # If X1=5, X0 from conditional_X0_X1: N(mean=2*5-4=6, stddev=5.0) -> optimal X0=6 + self.assertAlmostEqual(map_solution.at(X(0))[0], 6.0, places=5) + + # Optimize(DiscreteValues): optimize continuous given fixed discrete + discrete_choice = DiscreteValues() + discrete_choice[A(0)] = 1 # Fix Asia=1 (less likely choice) + optimized_continuous = self.bayesNet.optimize(discrete_choice) + self.assertIsInstance(optimized_continuous, VectorValues) + # If Asia=1, X1 from conditional1_X1_A1: N(mean=[2], model1_sigmas=[3.0]) -> optimal X1=2 + self.assertAlmostEqual(optimized_continuous.at(X(1))[0], 2.0, places=5) + # If X1=2, X0 from conditional_X0_X1: N(mean=2*2-4=0, stddev=5.0) -> optimal X0=0 + self.assertAlmostEqual(optimized_continuous.at(X(0))[0], 0.0, places=5) + + def test_sampling_methods(self): + """Test sample() and sample(HybridValues).""" + # sample() + full_sample = self.bayesNet.sample() + self.assertIsInstance(full_sample, HybridValues) + self.assertTrue(full_sample.existsDiscrete(A(0))) + 
self.assertTrue(full_sample.existsVector(X(0))) + self.assertTrue(full_sample.existsVector(X(1))) + self.assertIn(full_sample.atDiscrete(A(0)), [0, 1]) + + # sample(HybridValues) - conditional sampling + given_values = HybridValues() + discrete_given = DiscreteValues() + discrete_given[A(0)] = 0 # Condition on Asia=0 + given_values.insert(discrete_given) + + conditional_sample = self.bayesNet.sample(given_values) + self.assertIsInstance(conditional_sample, HybridValues) + self.assertEqual( + conditional_sample.atDiscrete(A(0)), 0 + ) # Should respect condition + self.assertTrue(conditional_sample.existsVector(X(0))) + self.assertTrue(conditional_sample.existsVector(X(1))) + + def test_marginals_and_choice_methods(self): + """Test discreteMarginal() and choose(DiscreteValues).""" + # discreteMarginal() -> DiscreteBayesNet P(M) + discrete_marginal_bn = self.bayesNet.discreteMarginal() + self.assertIsInstance(discrete_marginal_bn, DiscreteBayesNet) + # Our BN has only one discrete var Asia, so marginal is just P(Asia) + self.assertEqual(discrete_marginal_bn.size(), 1) + d_cond = discrete_marginal_bn.at(0) # Should be P(Asia) + self.gtsamAssertEquals(d_cond, self.conditional_Asia, 1e-9) + + # choose(DiscreteValues) -> GaussianBayesNet P(X | M=m) + fixed_discrete = DiscreteValues() + fixed_discrete[A(0)] = 0 # Fix Asia = 0 + gaussian_bn_given_asia0 = self.bayesNet.choose(fixed_discrete) + self.assertIsInstance(gaussian_bn_given_asia0, GaussianBayesNet) + # Should contain P(X0|X1) and P(X1|Asia=0) + # The order is important for BayesNets, parents are always *last* + self.assertEqual(gaussian_bn_given_asia0.size(), 2) + self.gtsamAssertEquals( + gaussian_bn_given_asia0.at(1), self.conditional0_X1_A, 1e-9 + ) # P(X1|A=0) + self.gtsamAssertEquals( + gaussian_bn_given_asia0.at(0), self.conditional_X0_X1, 1e-9 + ) # P(X0|X1) + + def test_errorTree(self): + """Test errorTree(VectorValues).""" + vector_values = self.values.continuous() + + # errorTree(VectorValues) -> 
unnormalized log P(M | X=continuous_values) + error_tree = self.bayesNet.errorTree(vector_values) + self.assertIsInstance(error_tree, AlgebraicDecisionTreeKey) + + # Get the errorTree for X1 given Asia=0 (key A(0)): + error_tree_X1_A = self.hybrid_conditional_X1_Asia.errorTree(vector_values) + + # For Asia=0 (key A(0)) from self.values: + assignment_Asia0 = AssignmentKey() + assignment_Asia0[A(0)] = 0 + # self.values has Asia=0 + error_Asia0 = self.conditional_Asia.error(self.values.discrete()) + expected_error_Asia0 = ( + self.conditional_X0_X1.error(vector_values) + + error_tree_X1_A(assignment_Asia0) + + error_Asia0 + ) + self.assertAlmostEqual( + error_tree(assignment_Asia0), expected_error_Asia0, places=4 + ) + + # For Asia=1: + # Need error(A=1) from P(Asia)="99/1" -> P(A=1)=0.01 + dv = DiscreteValues() + dv[A(0)] = 1 + error_Asia1 = self.conditional_Asia.error(dv) + assignment_Asia1 = AssignmentKey() + assignment_Asia1[A(0)] = 1 + expected_error_Asia1 = ( + self.conditional_X0_X1.error(vector_values) + + error_tree_X1_A(assignment_Asia1) + + error_Asia1 + ) + self.assertAlmostEqual( + error_tree(assignment_Asia1), expected_error_Asia1, places=4 + ) + + def test_errorTree_discretePosterior(self): + """Test discretePosterior(VectorValues).""" + vector_values = self.values.continuous() + posterior_tree = self.bayesNet.discretePosterior(vector_values) + self.assertIsInstance(posterior_tree, AlgebraicDecisionTreeKey) + + def test_pruning_methods(self): + """Test prune(maxNrLeaves) and prune(maxNrLeaves, marginalThreshold).""" + # Prune to max 1 leaf (most likely path) + pruned_bn_1leaf = self.bayesNet.prune(1) + self.assertIsInstance(pruned_bn_1leaf, HybridBayesNet) + # The discrete part should now be deterministic for Asia=0 + # TODO(frank): why does it not become a conditional? 
+ mpe_pruned = pruned_bn_1leaf.mpe() + self.assertEqual(mpe_pruned[A(0)], 0) + # The discrete conditional for Asia should reflect P(Asia=0)=1.0 + # TODO(frank): why is there no dead-mode removal here? + actual = pruned_bn_1leaf.discreteMarginal().at(0).evaluate(mpe_pruned) + self.assertAlmostEqual(actual, 1.0, places=5) + + # Prune with marginalThreshold + # P(Asia=0)=0.99, P(Asia=1)=0.01 + # Threshold 0.5 should keep only Asia=0 branch + pruned_bn_thresh = self.bayesNet.prune( + 2, marginalThreshold=0.5 + ) # maxNrLeaves=2 allows both if above thresh + # TODO(Frank): I don't understand *how* discrete got cut here. + self.assertEqual(pruned_bn_thresh.size(), 2) # Should keep both conditionals + + def test_toFactorGraph_method(self): + """Test toFactorGraph(VectorValues) method.""" + # Create measurements for conditioning + measurements = VectorValues() + measurements.insert(X(0), [-5.0]) # Example measurement for X(0) + + hfg = self.bayesNet.toFactorGraph(measurements) + self.assertIsInstance(hfg, HybridGaussianFactorGraph) + self.assertEqual(hfg.size(), self.bayesNet.size()) if __name__ == "__main__": From 2aa6d9360ad9aaa1df3eacb726618e4ec06d070c Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 28 May 2025 09:28:58 -0400 Subject: [PATCH 3/4] Also add HybridConditional changes --- gtsam/hybrid/HybridConditional.cpp | 14 +- gtsam/hybrid/HybridConditional.h | 4 +- gtsam/hybrid/doc/HybridBayesNet.ipynb | 696 +++++++++++++++++++++++ gtsam/hybrid/doc/HybridConditional.ipynb | 414 ++++++++++++++ gtsam/hybrid/hybrid.i | 17 + 5 files changed, 1136 insertions(+), 9 deletions(-) create mode 100644 gtsam/hybrid/doc/HybridBayesNet.ipynb create mode 100644 gtsam/hybrid/doc/HybridConditional.ipynb diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index 7fffb06d3b..d398b8b04c 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -112,13 +112,13 @@ bool HybridConditional::equals(const HybridFactor 
&other, double tol) const { } /* ************************************************************************ */ -double HybridConditional::error(const HybridValues &values) const { +double HybridConditional::error(const HybridValues &hybridValues) const { if (auto gc = asGaussian()) { - return gc->error(values.continuous()); + return gc->error(hybridValues.continuous()); } else if (auto gm = asHybrid()) { - return gm->error(values); + return gm->error(hybridValues); } else if (auto dc = asDiscrete()) { - return dc->error(values.discrete()); + return dc->error(hybridValues.discrete()); } else throw std::runtime_error( "HybridConditional::error: conditional type not handled"); @@ -126,11 +126,11 @@ double HybridConditional::error(const HybridValues &values) const { /* ************************************************************************ */ AlgebraicDecisionTree HybridConditional::errorTree( - const VectorValues &values) const { + const VectorValues &continuousValues) const { if (auto gc = asGaussian()) { - return {gc->error(values)}; // NOTE: a "constant" tree + return {gc->error(continuousValues)}; // NOTE: a "constant" tree } else if (auto gm = asHybrid()) { - return gm->errorTree(values); + return gm->errorTree(continuousValues); } else if (auto dc = asDiscrete()) { return dc->errorTree(); } else diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index 45b00969b2..d3aa5ade43 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -182,7 +182,7 @@ class GTSAM_EXPORT HybridConditional std::shared_ptr inner() const { return inner_; } /// Return the error of the underlying conditional. - double error(const HybridValues& values) const override; + double error(const HybridValues& hybridValues) const override; /** * @brief Compute error of the HybridConditional as a tree. @@ -192,7 +192,7 @@ class GTSAM_EXPORT HybridConditional * as the conditionals involved, and leaf values as the error. 
*/ AlgebraicDecisionTree errorTree( - const VectorValues& values) const override; + const VectorValues& continuousValues) const override; /// Return the log-probability (or density) of the underlying conditional. double logProbability(const HybridValues& values) const override; diff --git a/gtsam/hybrid/doc/HybridBayesNet.ipynb b/gtsam/hybrid/doc/HybridBayesNet.ipynb new file mode 100644 index 0000000000..03789ceae9 --- /dev/null +++ b/gtsam/hybrid/doc/HybridBayesNet.ipynb @@ -0,0 +1,696 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# HybridBayesNet" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "try:\n", + " import google.colab\n", + " %pip install --quiet gtsam-develop\n", + "except ImportError:\n", + " pass # Not running on Colab, do nothing" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A `HybridBayesNet` represents a directed graphical model (Bayes Net) specifically designed for hybrid systems. It is a collection of `gtsam.HybridConditional` objects, ordered according to an elimination sequence.\n", + "\n", + "It extends `gtsam.BayesNet` and allows representing the joint probability distribution $P(X, M)$ over continuous variables $X$ and discrete variables $M$ as a product of conditional probabilities:\n", + "$$\n", + "P(X, M) = \\prod_i P(\\text{Frontal}_i | \\text{Parents}_i)\n", + "$$\n", + "where each conditional $P(\\text{Frontal}_i | \\text{Parents}_i)$ is stored as a `HybridConditional`. This structure allows for representing complex dependencies, such as continuous variables conditioned on discrete modes ($P(X|M)$) alongside purely discrete ($P(M)$) or purely continuous ($P(X)$) relationships.\n", + "\n", + "`HybridBayesNet` objects are typically obtained by eliminating a `HybridGaussianFactorGraph`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import gtsam\n", + "import numpy as np\n", + "import graphviz\n", + "\n", + "from gtsam import (\n", + " HybridConditional,\n", + " GaussianConditional,\n", + " DiscreteConditional,\n", + " HybridGaussianConditional,\n", + " HybridGaussianFactorGraph,\n", + " HybridGaussianFactor,\n", + " JacobianFactor,\n", + " DecisionTreeFactor,\n", + " Ordering,\n", + ")\n", + "from gtsam.symbol_shorthand import X, D" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating a HybridBayesNet\n", + "\n", + "While they can be constructed manually by adding `HybridConditional`s, they are more commonly obtained via elimination of a `HybridGaussianFactorGraph`." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Manually Constructed HybridBayesNet:\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "var7205759403792793600\n", + "\n", + "d0\n", + "\n", + "\n", + "\n", + "var8646911284551352320\n", + "\n", + "x0\n", + "\n", + "\n", + "\n", + "var7205759403792793600->var8646911284551352320\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "var8646911284551352321\n", + "\n", + "x1\n", + "\n", + "\n", + "\n", + "var8646911284551352320->var8646911284551352321\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# --- Method 1: Manual Construction ---\n", + "hbn_manual = gtsam.HybridBayesNet()\n", + "\n", + "# P(D0)\n", + "dk0 = (D(0), 2)\n", + "cond_d0 = DiscreteConditional(dk0, [], \"7/3\") # P(D0=0)=0.7\n", + "hbn_manual.push_back(HybridConditional(cond_d0))\n", + "\n", + "# P(X0 | D0)\n", + "dk0_parent = (D(0), 2)\n", + " # Mode 0: P(X0 | D0=0) = 
N(0, 1)\n", + "gc0 = GaussianConditional(X(0), np.zeros(1), np.eye(1), gtsam.noiseModel.Unit.Create(1))\n", + " # Mode 1: P(X0 | D0=1) = N(5, 4)\n", + "gc1 = GaussianConditional(X(0), np.array([2.5]), np.eye(1)*0.5, gtsam.noiseModel.Isotropic.Sigma(1,2.0))\n", + "cond_x0_d0 = HybridGaussianConditional(dk0_parent, [gc0, gc1])\n", + "hbn_manual.push_back(HybridConditional(cond_x0_d0))\n", + "\n", + "# P(X1 | X0)\n", + "cond_x1_x0 = GaussianConditional(X(1), np.array([0.0]), np.eye(1), # d, R=I\n", + " X(0), np.eye(1), # Parent X0, S=I\n", + " gtsam.noiseModel.Isotropic.Sigma(1, 1.0)) # N(X1; X0, I)\n", + "hbn_manual.push_back(HybridConditional(cond_x1_x0))\n", + "\n", + "print(\"Manually Constructed HybridBayesNet:\")\n", + "# hbn_manual.print()\n", + "graphviz.Source(hbn_manual.dot())" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "f11254b8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Original HybridGaussianFactorGraph for Elimination:\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "var7205759403792793600\n", + "\n", + "d0\n", + "\n", + "\n", + "\n", + "factor0\n", + "\n", + "\n", + "\n", + "\n", + "var7205759403792793600--factor0\n", + "\n", + "\n", + "\n", + "\n", + "factor1\n", + "\n", + "\n", + "\n", + "\n", + "var7205759403792793600--factor1\n", + "\n", + "\n", + "\n", + "\n", + "var8646911284551352320\n", + "\n", + "x0\n", + "\n", + "\n", + "\n", + "var8646911284551352320--factor1\n", + "\n", + "\n", + "\n", + "\n", + "factor2\n", + "\n", + "\n", + "\n", + "\n", + "var8646911284551352320--factor2\n", + "\n", + "\n", + "\n", + "\n", + "var8646911284551352321\n", + "\n", + "x1\n", + "\n", + "\n", + "\n", + "var8646911284551352321--factor2\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + 
"source": [ + "# --- Method 2: From Elimination ---\n", + "hgfg = HybridGaussianFactorGraph()\n", + "# P(D0) = 70/30\n", + "hgfg.push_back(DecisionTreeFactor([dk0], \"0.7 0.3\"))\n", + "# P(X0|D0) = mixture N(0,1); N(5,4)\n", + "# Factor version: 0.5*|X0-0|^2/1 + C0 ; 0.5*|X0-5|^2/4 + C1\n", + "factor_gf0 = JacobianFactor(X(0), np.eye(1), np.zeros(1), gtsam.noiseModel.Isotropic.Sigma(1, 1.0))\n", + "factor_gf1 = JacobianFactor(X(0), np.eye(1), np.array([5.0]), gtsam.noiseModel.Isotropic.Sigma(1, 2.0))\n", + "# Store -log(prior) for D0 in the hybrid factor (optional, could keep separate)\n", + "logP_D0_0 = -np.log(0.7)\n", + "logP_D0_1 = -np.log(0.3)\n", + "hgfg.push_back(HybridGaussianFactor(dk0, [(factor_gf0, logP_D0_0), (factor_gf1, logP_D0_1)]))\n", + "# P(X1|X0) = N(X0, 1)\n", + "hgfg.push_back(JacobianFactor(X(0), -np.eye(1), X(1), np.eye(1), np.zeros(1), gtsam.noiseModel.Isotropic.Sigma(1, 1.0)))\n", + "\n", + "print(\"\\nOriginal HybridGaussianFactorGraph for Elimination:\")\n", + "# hgfg.print()\n", + "graphviz.Source(hgfg.dot())" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "07db8de1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "HybridBayesNet from Elimination:\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "var7205759403792793600\n", + "\n", + "d0\n", + "\n", + "\n", + "\n", + "var8646911284551352320\n", + "\n", + "x0\n", + "\n", + "\n", + "\n", + "var7205759403792793600->var8646911284551352320\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "var8646911284551352321\n", + "\n", + "x1\n", + "\n", + "\n", + "\n", + "var8646911284551352320->var8646911284551352321\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Note: Using HybridOrdering(hgfg) is generally recommended: 
\n", + "# it returns a Colamd constrained ordering where the discrete keys are\n", + "# eliminated after the continuous keys.\n", + "ordering = gtsam.HybridOrdering(hgfg)\n", + "\n", + "hbn_elim, _ = hgfg.eliminatePartialSequential(ordering)\n", + "print(\"\\nHybridBayesNet from Elimination:\")\n", + "# hbn_elim.print()\n", + "graphviz.Source(hbn_elim.dot())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Operations on HybridBayesNet\n", + "\n", + "`HybridBayesNet` allows evaluating the joint probability, sampling, optimizing (finding the MAP state), and extracting marginal or conditional distributions." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "LogProbability P(X0=0.1, X1=0.2, D0=0): -2.160749387689685\n", + "Probability P(X0=0.1, X1=0.2, D0=0): 0.11523873018620859\n", + "\n", + "Sampled HybridValues:\n", + "HybridValues: \n", + " Continuous: 2 elements\n", + " x0: 6.29382\n", + " x1: 6.6918\n", + " Discrete: (d0, 1)\n", + " Nonlinear\n", + "Values with 0 values:\n", + "\n", + "MAP Solution (Optimize):\n", + "HybridValues: \n", + " Continuous: 2 elements\n", + " x0: 0\n", + " x1: 0\n", + " Discrete: (d0, 0)\n", + " Nonlinear\n", + "Values with 0 values:\n", + "\n", + "MPE Discrete Assignment:\n", + "DiscreteValues{7205759403792793600: 0}\n" + ] + } + ], + "source": [ + "# Use the Bayes Net from elimination for consistency\n", + "hbn = hbn_elim\n", + "\n", + "# --- Evaluation ---\n", + "values = gtsam.HybridValues()\n", + "values.insert(D(0), 0)\n", + "values.insert(X(0), np.array([0.1]))\n", + "values.insert(X(1), np.array([0.2]))\n", + "\n", + "log_prob = hbn.logProbability(values)\n", + "prob = hbn.evaluate(values) # Same as exp(log_prob)\n", + "print(f\"\\nLogProbability P(X0=0.1, X1=0.2, D0=0): {log_prob}\")\n", + "print(f\"Probability P(X0=0.1, X1=0.2, D0=0): {prob}\")\n", + "\n", + "# --- Sampling ---\n", 
+ "full_sample = hbn.sample()\n", + "print(\"\\nSampled HybridValues:\")\n", + "full_sample.print()\n", + "\n", + "# --- Optimization (Finding MAP state) ---\n", + "# Computes MPE for discrete, then optimizes continuous given MPE\n", + "map_solution = hbn.optimize()\n", + "print(\"\\nMAP Solution (Optimize):\")\n", + "map_solution.print()\n", + "\n", + "# --- MPE (Most Probable Explanation for Discrete Variables) ---\n", + "mpe_assignment = hbn.mpe()\n", + "print(\"\\nMPE Discrete Assignment:\")\n", + "print(mpe_assignment) # Should match discrete part of map_solution" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "d8e3e0ee", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Optimized Continuous Solution for D0=1:\n", + "VectorValues: 2 elements\n", + " x0: 5\n", + " x1: 5\n" + ] + } + ], + "source": [ + "# --- Optimize Continuous given specific Discrete Assignment ---\n", + "dv = gtsam.DiscreteValues()\n", + "dv[D(0)] = 1\n", + "cont_solution_d0_eq_1 = hbn.optimize(dv)\n", + "print(\"\\nOptimized Continuous Solution for D0=1:\")\n", + "cont_solution_d0_eq_1.print()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "758c1790", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Discrete Marginal P(M):\n", + "DiscreteBayesNet\n", + " \n", + "size: 1\n", + "conditional 0: P( d0 ):\n", + " f[ (d0,2), ]\n", + "(d0, 0) | 0.731343 | 0\n", + "(d0, 1) | 0.268657 | 1\n", + "number of nnzs: 2\n", + "\n", + "\n", + "Gaussian Conditional P(X | D0=0):\n", + "\n", + "size: 2\n", + "conditional 0: p(x1 | x0)\n", + " R = [ 1 ]\n", + " S[x0] = [ -1 ]\n", + " d = [ 0 ]\n", + " logNormalizationConstant: -0.918939\n", + " No noise model\n", + "conditional 1: p(x0)\n", + " R = [ 1 ]\n", + " d = [ 0 ]\n", + " mean: 1 elements\n", + " x0: 0\n", + " logNormalizationConstant: -0.918939\n", + " No noise model\n" + ] + } + ], + "source": [ 
+ "# --- Extract Marginal/Conditional Distributions ---\n", + "# Get P(M) = P(D0)\n", + "discrete_marginal_bn = hbn.discreteMarginal()\n", + "print(\"\\nDiscrete Marginal P(M):\")\n", + "discrete_marginal_bn.print()\n", + "\n", + "# Get P(X | M=m) = P(X0, X1 | D0=0)\n", + "dv[D(0)] = 0\n", + "gaussian_conditional_bn = hbn.choose(dv)\n", + "print(\"\\nGaussian Conditional P(X | D0=0):\")\n", + "gaussian_conditional_bn.print()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Advanced Operations (`errorTree`, `discretePosterior`, `prune`)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Error Tree (Unnormalized Log Posterior Log P'(M|x)) for x0=0.5, x1=1.0:\n", + "AlgebraicDecisionTreeKey\n", + " Choice(d0) \n", + "AlgebraicDecisionTreeKey\n", + " 0 Leaf 0.56287232\n", + "AlgebraicDecisionTreeKey\n", + " 1 Leaf 4.663718\n", + "\n", + "Discrete Posterior Tree P(M|x) for x0=0.5, x1=1.0:\n", + "AlgebraicDecisionTreeKey\n", + " Choice(d0) \n", + "AlgebraicDecisionTreeKey\n", + " 0 Leaf 0.98371106\n", + "AlgebraicDecisionTreeKey\n", + " 1 Leaf 0.016288942\n" + ] + } + ], + "source": [ + "# --- Error Tree (Log P'(M|x) = log P(x|M) + log P(M)) ---\n", + "# Evaluate unnormalized log posterior of discrete modes given continuous values\n", + "cont_values_for_error = gtsam.VectorValues()\n", + "cont_values_for_error.insert(X(0), np.array([0.5]))\n", + "cont_values_for_error.insert(X(1), np.array([1.0]))\n", + "\n", + "error_tree = hbn.errorTree(cont_values_for_error)\n", + "print(\"\\nError Tree (Unnormalized Log Posterior Log P'(M|x)) for x0=0.5, x1=1.0:\")\n", + "error_tree.print()\n", + "\n", + "# --- Discrete Posterior P(M|x) ---\n", + "# Normalized version of exp(-errorTree)\n", + "posterior_tree = hbn.discretePosterior(cont_values_for_error)\n", + "print(\"\\nDiscrete Posterior Tree P(M|x) for x0=0.5, x1=1.0:\")\n", + 
"posterior_tree.print()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "9bec1c66", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Pruned HybridBayesNet (max_leaves=1):\n", + "HybridBayesNet\n", + " \n", + "size: 2\n", + "conditional 0: p(x1 | x0)\n", + " R = [ 1 ]\n", + " S[x0] = [ -1 ]\n", + " d = [ 0 ]\n", + " logNormalizationConstant: -0.918939\n", + " No noise model\n", + "conditional 1: p(x0)\n", + " R = [ 1 ]\n", + " d = [ 0 ]\n", + " mean: 1 elements\n", + " x0: 0\n", + " logNormalizationConstant: -0.918939\n", + " No noise model\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "var8646911284551352320\n", + "\n", + "x0\n", + "\n", + "\n", + "\n", + "var8646911284551352321\n", + "\n", + "x1\n", + "\n", + "\n", + "\n", + "var8646911284551352320->var8646911284551352321\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# --- Pruning ---\n", + "# Reduces complexity by removing low-probability discrete branches\n", + "max_leaves = 1 # Force pruning to the most likely mode\n", + "pruned_hbn = hbn.prune(max_leaves, marginalThreshold=0.8)\n", + "\n", + "print(f\"\\nPruned HybridBayesNet (max_leaves={max_leaves}):\")\n", + "pruned_hbn.print()\n", + "# Visualize the pruned Bayes Net\n", + "graphviz.Source(pruned_hbn.dot())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "py312", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git 
a/gtsam/hybrid/doc/HybridConditional.ipynb b/gtsam/hybrid/doc/HybridConditional.ipynb new file mode 100644 index 0000000000..c3235dd5e8 --- /dev/null +++ b/gtsam/hybrid/doc/HybridConditional.ipynb @@ -0,0 +1,414 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# HybridConditional" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "9cb314a6", + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "try:\n", + " import google.colab\n", + " %pip install --quiet gtsam-develop\n", + "except ImportError:\n", + " pass # Not running on Colab, do nothing" + ] + }, + { + "cell_type": "markdown", + "id": "94f2d735", + "metadata": {}, + "source": [ + "`HybridConditional` acts as a **type-erased wrapper** for different kinds of conditional distributions that can appear in a `HybridBayesNet` or `HybridBayesTree`. It allows these containers to hold conditionals resulting from eliminating different types of variables (discrete, continuous, or mixtures) without needing to be templated on the specific conditional type.\n", + "\n", + "A `HybridConditional` object internally holds a shared pointer to one of the following concrete conditional types:\n", + "* `gtsam.GaussianConditional`\n", + "* `gtsam.DiscreteConditional`\n", + "* `gtsam.HybridGaussianConditional`\n", + "\n", + "It inherits from `HybridFactor` and `Conditional`, providing access to both factor-like properties (keys) and conditional-like properties (frontals, parents).\n", + "\n", + "```mermaid\n", + "graph TD\n", + " HybridConditional --> HybridFactor\n", + " HybridConditional --> Conditional\n", + " HybridConditional -- Holds shared pointer to --> GaussianConditional\n", + " HybridConditional -- Holds shared pointer to --> DiscreteConditional\n", + " HybridConditional -- Holds shared pointer to --> HybridGaussianConditional\n", + "```" + ] + }, + { + 
"cell_type": "code", + "execution_count": 2, + "id": "6e6df071", + "metadata": {}, + "outputs": [], + "source": [ + "import gtsam\n", + "import numpy as np\n", + "\n", + "from gtsam import (\n", + " GaussianConditional,\n", + " DiscreteConditional,\n", + " HybridConditional,\n", + " HybridGaussianConditional,\n", + ")\n", + "from gtsam.symbol_shorthand import X, D" + ] + }, + { + "cell_type": "markdown", + "id": "23447d16", + "metadata": {}, + "source": [ + "## Initialization\n", + "\n", + "A `HybridConditional` is created by wrapping a shared pointer to one of the concrete conditional types. These concrete conditionals are usually obtained from factor graph elimination." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "f5ff356f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "HybridConditional from GaussianConditional:\n", + "Hybrid Conditional\n", + "p(x0 | x1)\n", + " R = [ 2 ]\n", + " S[x1] = [ 0.5 ]\n", + " d = [ 1 ]\n", + " logNormalizationConstant: 0.467356\n", + "isotropic dim=1 sigma=0.5\n", + "\n", + "HybridConditional from DiscreteConditional:\n", + "Hybrid Conditional\n", + " P( d0 | d1 ):\n", + " Choice(d1) \n", + " 0 Choice(d0) \n", + " 0 0 Leaf 0.8\n", + " 0 1 Leaf 0.2\n", + " 1 Choice(d0) \n", + " 1 0 Leaf 0.2\n", + " 1 1 Leaf 0.8\n", + "\n", + "\n", + "HybridConditional from HybridGaussianConditional:\n", + "Hybrid Conditional\n", + " P( x2 | d2)\n", + " Discrete Keys = (d2, 2), \n", + " logNormalizationConstant: 0.467356\n", + "\n", + " Choice(d2) \n", + " 0 Leaf p(x2)\n", + " R = [ 1 ]\n", + " d = [ 0 ]\n", + " mean: 1 elements\n", + " x2: 0\n", + " logNormalizationConstant: -0.918939\n", + " Noise model: unit (1) \n", + "\n", + " 1 Leaf p(x2)\n", + " R = [ 2 ]\n", + " d = [ 10 ]\n", + " mean: 1 elements\n", + " x2: 5\n", + " logNormalizationConstant: 0.467356\n", + "isotropic dim=1 sigma=0.5\n", + "\n" + ] + } + ], + "source": [ + "# --- Create concrete conditionals 
(examples) ---\n", + "# 1. GaussianConditional P(X0 | X1)\n", + "gc = GaussianConditional(X(0), np.array([1.0]), np.eye(1)*2.0, # d, R\n", + " X(1), np.array([[0.5]]), # Parent, S\n", + " gtsam.noiseModel.Diagonal.Sigmas([0.5])) # sigma=0.5 -> prec=4 -> R=2\n", + "\n", + "# 2. DiscreteConditional P(D0 | D1) (D0, D1 binary)\n", + "dk0 = (D(0), 2)\n", + "dk1 = (D(1), 2)\n", + "dc = DiscreteConditional(dk0, [dk1], \"4/1 1/4\") # P(D0|D1=0) = 80/20, P(D0|D1=1) = 20/80\n", + "\n", + "# 3. HybridGaussianConditional P(X2 | D2) (X2 1D, D2 binary)\n", + "dk2 = (D(2), 2)\n", + "# Mode 0: P(X2 | D2=0) = N(0, 1) -> R=1, d=0\n", + "hgc_gc0 = GaussianConditional(X(2), np.zeros(1), np.eye(1), gtsam.noiseModel.Unit.Create(1))\n", + "# Mode 1: P(X2 | D2=1) = N(5, 0.25) -> R=2, d=10\n", + "hgc_gc1 = GaussianConditional(X(2), np.array([10.0]), np.eye(1)*2.0, gtsam.noiseModel.Isotropic.Sigma(1,0.5))\n", + "# This constructor takes vector of conditionals directly if parents match\n", + "hgc = HybridGaussianConditional(dk2, [hgc_gc0, hgc_gc1])\n", + "\n", + "# --- Wrap them into HybridConditionals ---\n", + "hybrid_cond_g = HybridConditional(gc)\n", + "hybrid_cond_d = HybridConditional(dc)\n", + "hybrid_cond_h = HybridConditional(hgc)\n", + "\n", + "print(\"HybridConditional from GaussianConditional:\")\n", + "hybrid_cond_g.print()\n", + "print(\"\\nHybridConditional from DiscreteConditional:\")\n", + "hybrid_cond_d.print()\n", + "print(\"\\nHybridConditional from HybridGaussianConditional:\")\n", + "hybrid_cond_h.print()" + ] + }, + { + "cell_type": "markdown", + "id": "da85d70c", + "metadata": {}, + "source": [ + "## Accessing Information and Inner Type\n", + "\n", + "You can access keys, frontals, and parents like any conditional. You can also check the underlying type and attempt to cast back to the concrete type." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "885418ff", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "--- Inspecting HybridConditional from Gaussian ---\n", + "Keys: [8646911284551352320, 8646911284551352321]\n", + "Frontals: 1\n", + "Parents: 1\n", + "Is Continuous? True\n", + "Is Discrete? False\n", + "Is Hybrid? False\n", + "Successfully cast back to GaussianConditional:\n", + "GaussianConditional p(x0 | x1)\n", + " R = [ 2 ]\n", + " S[x1] = [ 0.5 ]\n", + " d = [ 1 ]\n", + " logNormalizationConstant: 0.467356\n", + "isotropic dim=1 sigma=0.5\n", + "Cast back to DiscreteConditional successful? False\n", + "\n", + "--- Inspecting HybridConditional from Hybrid ---\n", + "Keys: [8646911284551352322, 7205759403792793602]\n", + "Frontals: 1\n", + "Parents: 1\n", + "Continuous Keys: [8646911284551352322]\n", + "Discrete Keys: \n", + "d2 2\n", + "\n", + "Is Continuous? False\n", + "Is Discrete? False\n", + "Is Hybrid? True\n", + "Successfully cast back to HybridGaussianConditional.\n" + ] + } + ], + "source": [ + "print(\"\\n--- Inspecting HybridConditional from Gaussian ---\")\n", + "print(f\"Keys: {hybrid_cond_g.keys()}\")\n", + "print(f\"Frontals: {hybrid_cond_g.nrFrontals()}\")\n", + "print(f\"Parents: {hybrid_cond_g.nrParents()}\")\n", + "print(f\"Is Continuous? {hybrid_cond_g.isContinuous()}\") # True\n", + "print(f\"Is Discrete? {hybrid_cond_g.isDiscrete()}\") # False\n", + "print(f\"Is Hybrid? {hybrid_cond_g.isHybrid()}\") # False\n", + "\n", + "# Try casting back\n", + "inner_gaussian = hybrid_cond_g.asGaussian()\n", + "if inner_gaussian:\n", + " print(\"Successfully cast back to GaussianConditional:\")\n", + " inner_gaussian.print()\n", + "else:\n", + " print(\"Failed to cast back to GaussianConditional.\")\n", + "\n", + "inner_discrete = hybrid_cond_g.asDiscrete()\n", + "print(f\"Cast back to DiscreteConditional successful? 
{inner_discrete is not None}\")\n", + "\n", + "print(\"\\n--- Inspecting HybridConditional from Hybrid ---\")\n", + "print(f\"Keys: {hybrid_cond_h.keys()}\")\n", + "print(f\"Frontals: {hybrid_cond_h.nrFrontals()}\")\n", + "print(f\"Parents: {hybrid_cond_h.nrParents()}\") # Contains continuous AND discrete parents\n", + "print(f\"Continuous Keys: {hybrid_cond_h.continuousKeys()}\")\n", + "print(f\"Discrete Keys: {hybrid_cond_h.discreteKeys()}\")\n", + "print(f\"Is Continuous? {hybrid_cond_h.isContinuous()}\") # False\n", + "print(f\"Is Discrete? {hybrid_cond_h.isDiscrete()}\") # False\n", + "print(f\"Is Hybrid? {hybrid_cond_h.isHybrid()}\") # True\n", + "\n", + "# Try casting back\n", + "inner_hybrid = hybrid_cond_h.asHybrid()\n", + "if inner_hybrid:\n", + " print(\"Successfully cast back to HybridGaussianConditional.\")\n", + "else:\n", + " print(\"Failed to cast back to HybridGaussianConditional.\")" + ] + }, + { + "cell_type": "markdown", + "id": "23fc9fc6", + "metadata": {}, + "source": [ + "## Evaluation (`error`, `logProbability`, `evaluate`)\n", + "\n", + "These methods delegate to the underlying concrete conditional's implementation. They require a `HybridValues` object containing assignments for all involved variables (frontal and parents)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "ce6716ae", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Gaussian HybridConditional P(X0=2|X1=1):\n", + " Error: 24.5\n", + " LogProbability: -24.032644172084783\n", + " Probability: 3.6538881633458336e-11\n", + "\n", + "Discrete HybridConditional P(D0=1|D1=0):\n", + " Error: 1.6094379124341003\n", + " LogProbability: -1.6094379124341003\n", + " Probability: 0.2\n", + "\n", + "Hybrid Gaussian HybridConditional P(X2=4.5|D2=1):\n", + " Error: 2.0\n", + " LogProbability: -1.5326441720847823\n", + " Probability: 0.21596386605275217\n" + ] + } + ], + "source": [ + "# --- Evaluate the Gaussian Conditional P(X0 | X1) ---\n", + "vals_g = gtsam.HybridValues()\n", + "vals_g.insert(X(0), np.array([2.0])) # Frontal\n", + "vals_g.insert(X(1), np.array([1.0])) # Parent\n", + "\n", + "err_g = hybrid_cond_g.error(vals_g)\n", + "log_prob_g = hybrid_cond_g.logProbability(vals_g)\n", + "prob_g = hybrid_cond_g.evaluate(vals_g) # Equivalent to exp(logProbability)\n", + "\n", + "print(f\"\\nGaussian HybridConditional P(X0=2|X1=1):\")\n", + "print(f\" Error: {err_g}\")\n", + "print(f\" LogProbability: {log_prob_g}\")\n", + "print(f\" Probability: {prob_g}\")\n", + "\n", + "# --- Evaluate the Discrete Conditional P(D0 | D1) ---\n", + "vals_d = gtsam.HybridValues()\n", + "vals_d.insert(D(0), 1) # Frontal = 1\n", + "vals_d.insert(D(1), 0) # Parent = 0\n", + "\n", + "err_d = hybrid_cond_d.error(vals_d) # -log(P(D0=1|D1=0)) = -log(0.2)\n", + "log_prob_d = hybrid_cond_d.logProbability(vals_d) # log(0.2)\n", + "prob_d = hybrid_cond_d.evaluate(vals_d) # 0.2\n", + "\n", + "print(f\"\\nDiscrete HybridConditional P(D0=1|D1=0):\")\n", + "print(f\" Error: {err_d}\")\n", + "print(f\" LogProbability: {log_prob_d}\")\n", + "print(f\" Probability: {prob_d}\")\n", + "\n", + "# --- Evaluate the Hybrid Gaussian Conditional P(X2 | D2) ---\n", + "vals_h = gtsam.HybridValues()\n", + 
"vals_h.insert(X(2), np.array([4.5])) # Frontal\n", + "vals_h.insert(D(2), 1) # Parent (selects mode 1: N(5, 0.25))\n", + "\n", + "err_h = hybrid_cond_h.error(vals_h)\n", + "log_prob_h = hybrid_cond_h.logProbability(vals_h)\n", + "prob_h = hybrid_cond_h.evaluate(vals_h)\n", + "\n", + "print(f\"\\nHybrid Gaussian HybridConditional P(X2=4.5|D2=1):\")\n", + "print(f\" Error: {err_h}\")\n", + "print(f\" LogProbability: {log_prob_h}\")\n", + "print(f\" Probability: {prob_h}\")" + ] + }, + { + "cell_type": "markdown", + "id": "ac079490", + "metadata": {}, + "source": [ + "## Restriction (`restrict`)\n", + "\n", + "The `restrict` method allows fixing the discrete parent variables, potentially simplifying the conditional (e.g., a `HybridGaussianConditional` might become a `GaussianConditional`)." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "6f6a5dab", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Restricted HybridConditional (D2=1):p(x2)\n", + " R = [ 2 ]\n", + " d = [ 10 ]\n", + " mean: 1 elements\n", + " x2: 5\n", + " logNormalizationConstant: 0.467356\n", + "isotropic dim=1 sigma=0.5\n" + ] + } + ], + "source": [ + "# Restrict the HybridGaussianConditional P(X2 | D2)\n", + "assignment = gtsam.DiscreteValues()\n", + "assignment[D(2)] = 1 # Fix D2 to mode 1\n", + "\n", + "restricted_factor = hybrid_cond_h.restrict(assignment)\n", + "\n", + "restricted_factor.print(\"\\nRestricted HybridConditional (D2=1):\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "py312", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i 
index 264a71bc34..1596aa63a3 100644 --- a/gtsam/hybrid/hybrid.i +++ b/gtsam/hybrid/hybrid.i @@ -140,6 +140,17 @@ virtual class HybridFactor : gtsam::Factor { #include virtual class HybridConditional : gtsam::HybridFactor { + HybridConditional(); + HybridConditional(const gtsam::KeyVector& continuousKeys, + const gtsam::DiscreteKeys& discreteKeys, size_t nFrontals); + HybridConditional(const gtsam::KeyVector& continuousFrontals, + const gtsam::DiscreteKeys& discreteFrontals, + const gtsam::KeyVector& continuousParents, + const gtsam::DiscreteKeys& discreteParents); + HybridConditional(const gtsam::GaussianConditional::shared_ptr& continuousConditional); + HybridConditional(const gtsam::DiscreteConditional::shared_ptr& discreteConditional); + HybridConditional(const gtsam::HybridGaussianConditional::shared_ptr& hybridGaussianCond); + void print(string s = "Hybrid Conditional\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; @@ -152,9 +163,14 @@ virtual class HybridConditional : gtsam::HybridFactor { double logProbability(const gtsam::HybridValues& values) const; double evaluate(const gtsam::HybridValues& values) const; double operator()(const gtsam::HybridValues& values) const; + + bool isDiscrete() const; + bool isContinuous() const; + bool isHybrid() const; gtsam::HybridGaussianConditional* asHybrid() const; gtsam::GaussianConditional* asGaussian() const; gtsam::DiscreteConditional* asDiscrete() const; + gtsam::Factor* inner(); }; @@ -338,6 +354,7 @@ class HybridGaussianFactorGraph { const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, const gtsam::DotWriter& writer = gtsam::DotWriter()) const; }; +const gtsam::Ordering HybridOrdering(const gtsam::HybridGaussianFactorGraph& graph); #include class HybridNonlinearFactorGraph { From 2f9976e069026895faf125982d25f76b57ebd293 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 28 May 2025 23:06:20 -0400 Subject: [PATCH 4/4] DiscreteBN::prune test --- 
gtsam/discrete/DecisionTreeFactor.h | 3 ++ gtsam/discrete/tests/testDiscreteBayesNet.cpp | 30 +++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 63f0384aac..38880b0325 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -197,6 +197,9 @@ namespace gtsam { /// Convert into a decision tree DecisionTreeFactor toDecisionTreeFactor() const override { return *this; } + /// Use sum() from AlgebraicDecisionTree + using ADT::sum; + /// Create new factor by summing all values with the same separator values DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { return combine(nrFrontals, Ring::add); diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 61e3a28206..111466ccac 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -122,6 +122,36 @@ TEST(DiscreteBayesNet, Sugar) { bn.add(C | S = "1/1/2 5/2/3"); } +/* ************************************************************************* */ +// Test that pruning a DiscreteBayesNet results in a conditional whose leaves sum to 1. +TEST(DiscreteBayesNet, PruneSumToOne) { + using namespace asia_example; + const DiscreteBayesNet asiaBn = createAsiaExample(); + + // Test with various numbers of max leaves. + // Asia network has 8 binary variables, so 2^8 = 256 possible assignments in the full joint. + std::vector maxLeavesToTest = { 1, 2, 5, 10, 50, 100, 256, 500 }; + + for (size_t maxLeaves : maxLeavesToTest) { + // We expect maxLeaves >= 1 for the sum-to-one property to hold meaningfully. + + DiscreteBayesNet prunedBn = asiaBn.prune(maxLeaves); + + // If the original BN was not empty and maxLeaves >= 1, + // the prunedBN should contain one conditional (the pruned joint). 
+ EXPECT(prunedBn.size() > 0); + if (prunedBn.size() > 0) { + EXPECT_LONGS_EQUAL(1, prunedBn.size()); + + DiscreteConditional::shared_ptr prunedConditional = prunedBn.front(); + CHECK(prunedConditional); // Ensure it's not null + + double sumOfProbs = prunedConditional->sum(); + EXPECT_DOUBLES_EQUAL(1.0, sumOfProbs, 1e-8); + } + } +} + /* ************************************************************************* */ TEST(DiscreteBayesNet, Dot) { using namespace asia_example;