-
Notifications
You must be signed in to change notification settings - Fork 870
New docs: hybrid #2120
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
New docs: hybrid #2120
Changes from 20 commits
9df0c31
fb4c000
34f43af
d30ec2a
20c773f
8de40a0
42e56a2
0be29f3
e11952f
8a9b02c
f1789bf
f2456f1
64ad4ae
2d956dd
bc9d65e
7507819
938b1b3
a20044c
1d6a3f4
cb98390
661d7e3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -112,25 +112,25 @@ bool HybridConditional::equals(const HybridFactor &other, double tol) const { | |
| } | ||
|
|
||
| /* ************************************************************************ */ | ||
| double HybridConditional::error(const HybridValues &values) const { | ||
| double HybridConditional::error(const HybridValues &hybridValues) const { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change is a bit superficial since |
||
| if (auto gc = asGaussian()) { | ||
| return gc->error(values.continuous()); | ||
| return gc->error(hybridValues.continuous()); | ||
| } else if (auto gm = asHybrid()) { | ||
| return gm->error(values); | ||
| return gm->error(hybridValues); | ||
| } else if (auto dc = asDiscrete()) { | ||
| return dc->error(values.discrete()); | ||
| return dc->error(hybridValues.discrete()); | ||
| } else | ||
| throw std::runtime_error( | ||
| "HybridConditional::error: conditional type not handled"); | ||
| } | ||
|
|
||
| /* ************************************************************************ */ | ||
| AlgebraicDecisionTree<Key> HybridConditional::errorTree( | ||
| const VectorValues &values) const { | ||
| const VectorValues &continuousValues) const { | ||
| if (auto gc = asGaussian()) { | ||
| return {gc->error(values)}; // NOTE: a "constant" tree | ||
| return {gc->error(continuousValues)}; // NOTE: a "constant" tree | ||
| } else if (auto gm = asHybrid()) { | ||
| return gm->errorTree(values); | ||
| return gm->errorTree(continuousValues); | ||
| } else if (auto dc = asDiscrete()) { | ||
| return dc->errorTree(); | ||
| } else | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -135,7 +135,7 @@ class GTSAM_EXPORT HybridFactor : public Factor { | |
|
|
||
| /// Compute tree of linear errors. | ||
| virtual AlgebraicDecisionTree<Key> errorTree( | ||
| const VectorValues &values) const = 0; | ||
| const VectorValues &continuousValues) const = 0; | ||
|
|
||
| /// Restrict the factor to the given discrete values. | ||
| virtual std::shared_ptr<Factor> restrict( | ||
|
|
@@ -162,4 +162,7 @@ class GTSAM_EXPORT HybridFactor : public Factor { | |
| template <> | ||
| struct traits<HybridFactor> : public Testable<HybridFactor> {}; | ||
|
|
||
| // For wrapper: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The reason I haven't wrapped a lot of stuff in hybrid is primarily because of this class, and also because I want to add a new class called This would have many benefits since we wouldn't have to always pass in |
||
| using AlgebraicDecisionTreeKey = AlgebraicDecisionTree<Key>; | ||
|
|
||
| } // namespace gtsam | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,6 +22,7 @@ | |
| #include <gtsam/base/utilities.h> | ||
| #include <gtsam/discrete/DecisionTree-inl.h> | ||
| #include <gtsam/discrete/DecisionTree.h> | ||
| #include <gtsam/discrete/DiscreteValues.h> | ||
| #include <gtsam/hybrid/HybridFactor.h> | ||
| #include <gtsam/hybrid/HybridGaussianFactor.h> | ||
| #include <gtsam/hybrid/HybridGaussianProductFactor.h> | ||
|
|
@@ -193,16 +194,44 @@ AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree( | |
| } | ||
|
|
||
| /* *******************************************************************************/ | ||
| double HybridGaussianFactor::error(const HybridValues& values) const { | ||
| double HybridGaussianFactor::error(const HybridValues& hybridValues) const { | ||
| // Directly index to get the component, no need to build the whole tree. | ||
| const GaussianFactorValuePair pair = factors_(values.discrete()); | ||
| return PotentiallyPrunedComponentError(pair, values.continuous()); | ||
| const GaussianFactorValuePair pair = factors_(hybridValues.discrete()); | ||
| return PotentiallyPrunedComponentError(pair, hybridValues.continuous()); | ||
| } | ||
|
|
||
| /* ************************************************************************ */ | ||
| std::shared_ptr<Factor> HybridGaussianFactor::restrict( | ||
| const DiscreteValues& assignment) const { | ||
| throw std::runtime_error("HybridGaussianFactor::restrict not implemented"); | ||
| const DiscreteValues& assignment) const { | ||
| FactorValuePairs restrictedTree = this->factors_; // Start with the original tree | ||
|
|
||
| const DiscreteKeys& currentFactorDiscreteKeys = this->discreteKeys(); | ||
| DiscreteKeys newFactorDiscreteKeys; // For the new, restricted factor | ||
|
|
||
| // Iterate over the discrete keys of the current factor | ||
| for (const DiscreteKey& factor_dk_pair : currentFactorDiscreteKeys) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nitpick: The iterate can simply be called |
||
| const Key& key = factor_dk_pair.first; | ||
|
|
||
| // Check if this key is specified in the assignment | ||
| auto assignment_it = assignment.find(key); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This can be combined with the below line if (assignment.find(key) != assignment.end()) {and since |
||
|
|
||
| if (assignment_it != assignment.end()) { | ||
| // Key is in assignment: restrict the tree by choosing the branch | ||
| size_t assigned_value = assignment_it->second; | ||
| restrictedTree = restrictedTree.choose(key, assigned_value); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any reason you do this and not just call |
||
| // This key is now fixed, so it's not a discrete key for the new factor | ||
| } | ||
| else { | ||
| // Key is not in assignment: it remains a discrete key for the new factor | ||
| newFactorDiscreteKeys.push_back(factor_dk_pair); | ||
| } | ||
| } | ||
|
|
||
| // Create and return the new HybridGaussianFactor. | ||
| // Its constructor will derive continuous keys from the GaussianFactor | ||
| // shared_ptrs within the restrictedTree. | ||
| return std::make_shared<HybridGaussianFactor>(newFactorDiscreteKeys, | ||
| restrictedTree); | ||
| } | ||
|
|
||
| /* ************************************************************************ */ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -117,8 +117,8 @@ double HybridNonlinearFactor::error( | |
| } | ||
|
|
||
| /* *******************************************************************************/ | ||
| double HybridNonlinearFactor::error(const HybridValues& values) const { | ||
| return error(values.nonlinear(), values.discrete()); | ||
| double HybridNonlinearFactor::error(const HybridValues& hybridValues) const { | ||
| return error(hybridValues.nonlinear(), hybridValues.discrete()); | ||
| } | ||
|
|
||
| /* *******************************************************************************/ | ||
|
|
@@ -138,6 +138,7 @@ void HybridNonlinearFactor::print(const std::string& s, | |
| auto [factor, val] = v; | ||
| if (factor) { | ||
| RedirectCout rd; | ||
| std::cout << "(" << val << ") "; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would prefer this to be |
||
| factor->print("", keyFormatter); | ||
| return rd.str(); | ||
| } else { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please document what's going on here? It's much more complicated than what we had before and is not quite obvious.