Skip to content
Open
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions .github/workflows/build-special.yml
Original file line number Diff line number Diff line change
Expand Up @@ -197,11 +197,25 @@ jobs:
echo "GTSAM_BUILD_UNSTABLE=OFF" >> $GITHUB_ENV
echo "GTSAM 'unstable' will not be built."

- name: Set Swap Space
- name: Create swap (Linux only)
if: runner.os == 'Linux'
uses: pierotofy/set-swap-space@master
with:
swap-size-gb: 8
shell: bash
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please document what's going on here? It's much more complicated than what we had before and is not quite obvious.

run: |
set -euo pipefail
SWAP=/mnt/swapfile
sudo swapoff $SWAP 2>/dev/null || true
sudo rm -f $SWAP

for SIZE in 8 4 2 1; do
if sudo fallocate -l ${SIZE}G $SWAP; then
sudo chmod 600 $SWAP
sudo mkswap $SWAP
sudo swapon $SWAP && break
fi
sudo rm -f $SWAP
done

swapon --show

- name: Build & Test
run: |
Expand Down
50 changes: 26 additions & 24 deletions gtsam/discrete/tests/testDiscreteBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/base/TestableAssertions.h>

#include <CppUnitLite/TestHarness.h>

Expand Down Expand Up @@ -290,30 +291,31 @@ TEST(DiscreteBayesTree, Dot) {
std::string actual = self.bayesTree->dot();
// print actual:
if (debug) std::cout << actual << std::endl;
EXPECT(actual ==
"digraph G{\n"
"0[label=\"13, 11, 6, 7\"];\n"
"0->1\n"
"1[label=\"14 : 11, 13\"];\n"
"1->2\n"
"2[label=\"9, 12 : 14\"];\n"
"2->3\n"
"3[label=\"3 : 9, 12\"];\n"
"2->4\n"
"4[label=\"2 : 9, 12\"];\n"
"2->5\n"
"5[label=\"8 : 12, 14\"];\n"
"5->6\n"
"6[label=\"1 : 8, 12\"];\n"
"5->7\n"
"7[label=\"0 : 8, 12\"];\n"
"1->8\n"
"8[label=\"10 : 13, 14\"];\n"
"8->9\n"
"9[label=\"5 : 10, 13\"];\n"
"8->10\n"
"10[label=\"4 : 10, 13\"];\n"
"}");
std::string expected =
R"(digraph G{
13[label="13, 11, 6, 7"];
13->14
14[label="14 : 11, 13"];
14->9
9[label="9, 12 : 14"];
9->3
3[label="3 : 9, 12"];
9->2
2[label="2 : 9, 12"];
9->8
8[label="8 : 12, 14"];
8->1
1[label="1 : 8, 12"];
8->0
0[label="0 : 8, 12"];
14->10
10[label="10 : 13, 14"];
10->5
5[label="5 : 10, 13"];
10->4
4[label="4 : 10, 13"];
})";
EXPECT(assert_equal(expected, actual));
}

/* ************************************************************************* */
Expand Down
14 changes: 7 additions & 7 deletions gtsam/hybrid/HybridConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,25 +112,25 @@ bool HybridConditional::equals(const HybridFactor &other, double tol) const {
}

/* ************************************************************************ */
double HybridConditional::error(const HybridValues &values) const {
double HybridConditional::error(const HybridValues &hybridValues) const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is a bit superficial since HybridConditional::error only ever accepts HybridValues.

if (auto gc = asGaussian()) {
return gc->error(values.continuous());
return gc->error(hybridValues.continuous());
} else if (auto gm = asHybrid()) {
return gm->error(values);
return gm->error(hybridValues);
} else if (auto dc = asDiscrete()) {
return dc->error(values.discrete());
return dc->error(hybridValues.discrete());
} else
throw std::runtime_error(
"HybridConditional::error: conditional type not handled");
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridConditional::errorTree(
const VectorValues &values) const {
const VectorValues &continuousValues) const {
if (auto gc = asGaussian()) {
return {gc->error(values)}; // NOTE: a "constant" tree
return {gc->error(continuousValues)}; // NOTE: a "constant" tree
} else if (auto gm = asHybrid()) {
return gm->errorTree(values);
return gm->errorTree(continuousValues);
} else if (auto dc = asDiscrete()) {
return dc->errorTree();
} else
Expand Down
4 changes: 2 additions & 2 deletions gtsam/hybrid/HybridConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ class GTSAM_EXPORT HybridConditional
std::shared_ptr<Factor> inner() const { return inner_; }

/// Return the error of the underlying conditional.
double error(const HybridValues& values) const override;
double error(const HybridValues& hybridValues) const override;

/**
* @brief Compute error of the HybridConditional as a tree.
Expand All @@ -192,7 +192,7 @@ class GTSAM_EXPORT HybridConditional
* as the conditionals involved, and leaf values as the error.
*/
AlgebraicDecisionTree<Key> errorTree(
const VectorValues& values) const override;
const VectorValues& continuousValues) const override;

/// Return the log-probability (or density) of the underlying conditional.
double logProbability(const HybridValues& values) const override;
Expand Down
7 changes: 4 additions & 3 deletions gtsam/hybrid/HybridEliminationTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,22 @@ class GTSAM_EXPORT HybridEliminationTree
/// @{

/**
* Build the elimination tree of a factor graph using pre-computed column
* Construct the elimination tree of a factor graph using pre-computed column
* structure.
* @param factorGraph The factor graph for which to build the elimination tree
* @param structure The set of factors involving each variable. If this is
* not precomputed, you can call the Create(const FactorGraph<DERIVEDFACTOR>&)
* named constructor instead.
* @return The elimination tree
* @param order The ordering of the variables.
*/
HybridEliminationTree(const HybridGaussianFactorGraph& factorGraph,
const VariableIndex& structure, const Ordering& order);

/** Build the elimination tree of a factor graph. Note that this has to
/** Construct the elimination tree of a factor graph. Note that this has to
* compute the column structure as a VariableIndex, so if you already have
* this precomputed, use the other constructor instead.
* @param factorGraph The factor graph for which to build the elimination tree
* @param order The ordering of the variables.
*/
HybridEliminationTree(const HybridGaussianFactorGraph& factorGraph,
const Ordering& order);
Expand Down
5 changes: 4 additions & 1 deletion gtsam/hybrid/HybridFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class GTSAM_EXPORT HybridFactor : public Factor {

/// Compute tree of linear errors.
virtual AlgebraicDecisionTree<Key> errorTree(
const VectorValues &values) const = 0;
const VectorValues &continuousValues) const = 0;

/// Restrict the factor to the given discrete values.
virtual std::shared_ptr<Factor> restrict(
Expand All @@ -162,4 +162,7 @@ class GTSAM_EXPORT HybridFactor : public Factor {
template <>
struct traits<HybridFactor> : public Testable<HybridFactor> {};

// For wrapper:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason I haven't wrapped a lot of stuff in hybrid is primarily because of this class, and also because I want to add a new class called KeyDecisionTree which specializes the template AlgebraicDecisionTree.

This would have many benefits since we wouldn't have to always pass in keyFormatter arguments, and wrapping would be a cinch, but it is a big PR which is lower priority than some others (such as DCSAM).

using AlgebraicDecisionTreeKey = AlgebraicDecisionTree<Key>;

} // namespace gtsam
39 changes: 34 additions & 5 deletions gtsam/hybrid/HybridGaussianFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <gtsam/base/utilities.h>
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
Expand Down Expand Up @@ -193,16 +194,44 @@ AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
}

/* *******************************************************************************/
double HybridGaussianFactor::error(const HybridValues& values) const {
double HybridGaussianFactor::error(const HybridValues& hybridValues) const {
// Directly index to get the component, no need to build the whole tree.
const GaussianFactorValuePair pair = factors_(values.discrete());
return PotentiallyPrunedComponentError(pair, values.continuous());
const GaussianFactorValuePair pair = factors_(hybridValues.discrete());
return PotentiallyPrunedComponentError(pair, hybridValues.continuous());
}

/* ************************************************************************ */
std::shared_ptr<Factor> HybridGaussianFactor::restrict(
const DiscreteValues& assignment) const {
throw std::runtime_error("HybridGaussianFactor::restrict not implemented");
const DiscreteValues& assignment) const {
FactorValuePairs restrictedTree = this->factors_; // Start with the original tree

const DiscreteKeys& currentFactorDiscreteKeys = this->discreteKeys();
DiscreteKeys newFactorDiscreteKeys; // For the new, restricted factor

// Iterate over the discrete keys of the current factor
for (const DiscreteKey& factor_dk_pair : currentFactorDiscreteKeys) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: The iterate can simply be called discreteKey.

const Key& key = factor_dk_pair.first;

// Check if this key is specified in the assignment
auto assignment_it = assignment.find(key);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be combined with the below line

if (assignment.find(key) != assignment.end()) {

and since assignment is a std::map, you can just do assignment.at(key) to get the assigned index. You shouldn't need the iterator assignment_it.


if (assignment_it != assignment.end()) {
// Key is in assignment: restrict the tree by choosing the branch
size_t assigned_value = assignment_it->second;
restrictedTree = restrictedTree.choose(key, assigned_value);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason you do this and not just call restrictedTree.restrict(assignment)?
You can then use the loop only for finding the assigned keys, which can possibly be done with a set intersection so it is faster than looping.

// This key is now fixed, so it's not a discrete key for the new factor
}
else {
// Key is not in assignment: it remains a discrete key for the new factor
newFactorDiscreteKeys.push_back(factor_dk_pair);
}
}

// Create and return the new HybridGaussianFactor.
// Its constructor will derive continuous keys from the GaussianFactor
// shared_ptrs within the restrictedTree.
return std::make_shared<HybridGaussianFactor>(newFactorDiscreteKeys,
restrictedTree);
}

/* ************************************************************************ */
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/HybridGaussianFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
* @brief Compute the log-likelihood, including the log-normalizing constant.
* @return double
*/
double error(const HybridValues &values) const override;
double error(const HybridValues &hybridValues) const override;

/// Getter for GaussianFactor decision tree
const FactorValuePairs &factors() const { return factors_; }
Expand Down
10 changes: 1 addition & 9 deletions gtsam/hybrid/HybridJunctionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,7 @@ class GTSAM_EXPORT HybridJunctionTree
typedef HybridJunctionTree This; ///< This class
typedef std::shared_ptr<This> shared_ptr; ///< Shared pointer to this class

/**
* Build the elimination tree of a factor graph using precomputed column
* structure.
* @param factorGraph The factor graph for which to build the elimination tree
* @param structure The set of factors involving each variable. If this is
* not precomputed, you can call the Create(const FactorGraph<DERIVEDFACTOR>&)
* named constructor instead.
* @return The elimination tree
*/
/// Construct the junction tree from an elimination tree
HybridJunctionTree(const HybridEliminationTree& eliminationTree);
};

Expand Down
5 changes: 3 additions & 2 deletions gtsam/hybrid/HybridNonlinearFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ double HybridNonlinearFactor::error(
}

/* *******************************************************************************/
double HybridNonlinearFactor::error(const HybridValues& values) const {
return error(values.nonlinear(), values.discrete());
double HybridNonlinearFactor::error(const HybridValues& hybridValues) const {
return error(hybridValues.nonlinear(), hybridValues.discrete());
}

/* *******************************************************************************/
Expand All @@ -138,6 +138,7 @@ void HybridNonlinearFactor::print(const std::string& s,
auto [factor, val] = v;
if (factor) {
RedirectCout rd;
std::cout << "(" << val << ") ";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer this to be std::cout << "(val=" << val << ") ";.

factor->print("", keyFormatter);
return rd.str();
} else {
Expand Down
Loading
Loading