Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,9 @@ namespace gtsam {
/// Convert into a decision tree
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }

/// Use sum() from AlgebraicDecisionTree
using ADT::sum;

/// Create new factor by summing all values with the same separator values
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::add);
Expand Down
1 change: 1 addition & 0 deletions gtsam/discrete/discrete.i
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ virtual class DiscreteFactor : gtsam::Factor {
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteFactor& lf, double tol = 1e-9) const;
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::AlgebraicDecisionTreeKey errorTree() const;
};

#include <gtsam/discrete/DecisionTreeFactor.h>
Expand Down
30 changes: 30 additions & 0 deletions gtsam/discrete/tests/testDiscreteBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,36 @@ TEST(DiscreteBayesNet, Sugar) {
bn.add(C | S = "1/1/2 5/2/3");
}

/* ************************************************************************* */
// Test that pruning a DiscreteBayesNet results in a conditional whose leaves sum to 1.
TEST(DiscreteBayesNet, PruneSumToOne) {
using namespace asia_example;
const DiscreteBayesNet asiaBn = createAsiaExample();

// Test with various numbers of max leaves.
// Asia network has 8 binary variables, so 2^8 = 256 possible assignments in the full joint.
std::vector<size_t> maxLeavesToTest = { 1, 2, 5, 10, 50, 100, 256, 500 };

for (size_t maxLeaves : maxLeavesToTest) {
// We expect maxLeaves >= 1 for the sum-to-one property to hold meaningfully.

DiscreteBayesNet prunedBn = asiaBn.prune(maxLeaves);

// If the original BN was not empty and maxLeaves >= 1,
// the prunedBN should contain one conditional (the pruned joint).
EXPECT(prunedBn.size() > 0);
if (prunedBn.size() > 0) {
EXPECT_LONGS_EQUAL(1, prunedBn.size());

DiscreteConditional::shared_ptr prunedConditional = prunedBn.front();
CHECK(prunedConditional); // Ensure it's not null

double sumOfProbs = prunedConditional->sum();
EXPECT_DOUBLES_EQUAL(1.0, sumOfProbs, 1e-8);
}
}
}

/* ************************************************************************* */
TEST(DiscreteBayesNet, Dot) {
using namespace asia_example;
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @param discrete Optional DiscreteValues
* @return double
*/
double negLogConstant(const std::optional<DiscreteValues> &discrete) const;
double negLogConstant(const std::optional<DiscreteValues>& discrete = {}) const;

/**
* @brief Compute normalized posterior P(M|X=x) and return as a tree.
Expand Down
14 changes: 7 additions & 7 deletions gtsam/hybrid/HybridConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,25 +112,25 @@ bool HybridConditional::equals(const HybridFactor &other, double tol) const {
}

/* ************************************************************************ */
double HybridConditional::error(const HybridValues &values) const {
double HybridConditional::error(const HybridValues &hybridValues) const {
if (auto gc = asGaussian()) {
return gc->error(values.continuous());
return gc->error(hybridValues.continuous());
} else if (auto gm = asHybrid()) {
return gm->error(values);
return gm->error(hybridValues);
} else if (auto dc = asDiscrete()) {
return dc->error(values.discrete());
return dc->error(hybridValues.discrete());
} else
throw std::runtime_error(
"HybridConditional::error: conditional type not handled");
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridConditional::errorTree(
const VectorValues &values) const {
const VectorValues &continuousValues) const {
if (auto gc = asGaussian()) {
return {gc->error(values)}; // NOTE: a "constant" tree
return {gc->error(continuousValues)}; // NOTE: a "constant" tree
} else if (auto gm = asHybrid()) {
return gm->errorTree(values);
return gm->errorTree(continuousValues);
} else if (auto dc = asDiscrete()) {
return dc->errorTree();
} else
Expand Down
4 changes: 2 additions & 2 deletions gtsam/hybrid/HybridConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ class GTSAM_EXPORT HybridConditional
std::shared_ptr<Factor> inner() const { return inner_; }

/// Return the error of the underlying conditional.
double error(const HybridValues& values) const override;
double error(const HybridValues& hybridValues) const override;

/**
* @brief Compute error of the HybridConditional as a tree.
Expand All @@ -192,7 +192,7 @@ class GTSAM_EXPORT HybridConditional
* as the conditionals involved, and leaf values as the error.
*/
AlgebraicDecisionTree<Key> errorTree(
const VectorValues& values) const override;
const VectorValues& continuousValues) const override;

/// Return the log-probability (or density) of the underlying conditional.
double logProbability(const HybridValues& values) const override;
Expand Down
3 changes: 3 additions & 0 deletions gtsam/hybrid/HybridFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,4 +162,7 @@ class GTSAM_EXPORT HybridFactor : public Factor {
template <>
struct traits<HybridFactor> : public Testable<HybridFactor> {};

// For wrapper:
using AlgebraicDecisionTreeKey = AlgebraicDecisionTree<Key>;

} // namespace gtsam
Loading
Loading