Skip to content

Commit c0de4de

Browse files
committed
fix #41 (issue with ceres optimizer)
1 parent 6402e86 commit c0de4de

File tree

3 files changed

+54
-141
lines changed

3 files changed

+54
-141
lines changed

include/operon/optimizer/dynamic_cost_function.hpp

+3 -1
Original file line number | Diff line number | Diff line change
@@ -5,11 +5,13 @@
55
#define OPERON_OPTIMIZER_COST_FUNCTION_HPP
66

77
#ifdef HAVE_CERES
8+
#include <ceres/dynamic_cost_function.h>
9+
#include <ceres/ceres.h>
810
#ifndef CERES_EXPORT
911
#define CERES_EXPORT OPERON_EXPORT
1012
#endif
1113

12-
#include <ceres/ceres.h>
14+
1315
#include "operon/core/contracts.hpp"
1416

1517
namespace Operon {

include/operon/optimizer/optimizer.hpp

+51 -30
Original file line number | Diff line number | Diff line change
@@ -151,7 +151,7 @@ struct LevenbergMarquardtOptimizer<DTable, OptimizerType::Eigen> final : public
151151
Operon::Interpreter<Operon::Scalar, DTable> interpreter{dtable, dataset, tree};
152152
Operon::LMCostFunction<Operon::Scalar> cf{interpreter, target, range};
153153
Eigen::LevenbergMarquardt<decltype(cf)> lm(cf);
154-
lm.setMaxfev(static_cast<int>(iterations+1));
154+
lm.setMaxfev(static_cast<int>(iterations));
155155

156156
auto x0 = tree.GetCoefficients();
157157
OptimizerSummary summary;
@@ -195,50 +195,71 @@ struct LevenbergMarquardtOptimizer<DTable, OptimizerType::Eigen> final : public
195195
};
196196

197197
#if defined(HAVE_CERES)
198-
template <typename T = Operon::Scalar>
199-
struct NonlinearLeastSquaresOptimizer<T, OptimizerType::Ceres> : public OptimizerBase<T> {
200-
explicit NonlinearLeastSquaresOptimizer(InterpreterBase<T>& interpreter)
201-
: OptimizerBase<T>{interpreter}
198+
template <typename DTable>
199+
struct LevenbergMarquardtOptimizer<DTable, OptimizerType::Ceres> final : public OptimizerBase {
200+
explicit LevenbergMarquardtOptimizer(DTable const& dtable, Problem const& problem)
201+
: OptimizerBase{problem}, dtable_{dtable}
202202
{
203203
}
204204

205-
auto Optimize(Operon::Span<Operon::Scalar const> target, Range range, size_t iterations, OptimizerSummary& summary) -> std::vector<Operon::Scalar> final
205+
[[nodiscard]] auto Optimize(Operon::RandomGenerator& /*unused*/, Operon::Tree const& tree) const -> OptimizerSummary final
206206
{
207-
auto const& tree = this->GetTree();
208-
auto const& ds = this->GetDataset();
209-
auto const& dt = this->GetDispatchTable();
210-
211-
auto x0 = tree.GetCoefficients();
207+
auto const& dtable = this->GetDispatchTable();
208+
auto const& problem = this->GetProblem();
209+
auto const& dataset = problem.GetDataset();
210+
auto range = problem.TrainingRange();
211+
auto target = problem.TargetValues(range);
212+
auto iterations = this->Iterations();;
212213

213-
Operon::CostFunction<DTable, Eigen::RowMajor> cf(tree, ds, target, range, dt);
214-
auto costFunction = new Operon::DynamicCostFunction(cf); // NOLINT
214+
auto initialParameters = tree.GetCoefficients();
215+
auto finalParameters = initialParameters;
215216

217+
Operon::Interpreter<Operon::Scalar, DTable> interpreter{dtable, dataset, tree};
218+
Operon::LMCostFunction<Operon::Scalar, Eigen::RowMajor> cf{interpreter, target, range};
219+
auto* dynamicCostFunction = new Operon::DynamicCostFunction{cf};
216220
ceres::Solver::Summary s;
217-
if (!x0.empty()) {
218-
Eigen::Map<Eigen::Matrix<Operon::Scalar, -1, 1>> m0(x0.data(), std::ssize(x0));
219-
auto sz = static_cast<Eigen::Index>(x0.size());
220-
Eigen::VectorXd params = Eigen::Map<Eigen::Matrix<Operon::Scalar, -1, 1>>(x0.data(), sz).template cast<double>();
221+
if (!initialParameters.empty()) {
222+
auto sz = std::ssize(finalParameters);
223+
Eigen::Map<Eigen::Matrix<Operon::Scalar, -1, 1>> m0(finalParameters.data(), sz);
224+
Eigen::VectorXd params = m0.template cast<double>();
221225
ceres::Problem problem;
222-
problem.AddResidualBlock(costFunction, nullptr, params.data());
226+
problem.AddResidualBlock(dynamicCostFunction, nullptr, params.data());
223227
ceres::Solver::Options options;
224228
options.linear_solver_type = ceres::DENSE_QR;
225229
options.logging_type = ceres::LoggingType::SILENT;
226-
options.max_num_iterations = static_cast<int>(iterations - 1); // workaround since for some reason ceres sometimes does 1 more iteration
230+
options.max_num_iterations = static_cast<int>(iterations);
227231
options.minimizer_progress_to_stdout = false;
228232
options.num_threads = 1;
229233
options.trust_region_strategy_type = ceres::LEVENBERG_MARQUARDT;
230234
options.use_inner_iterations = false;
231235
Solve(options, &problem, &s);
232236
m0 = params.cast<Operon::Scalar>();
233237
}
234-
summary.InitialCost = s.initial_cost;
235-
summary.FinalCost = s.final_cost;
236-
summary.Iterations = static_cast<int>(s.iterations.size());
237-
summary.FunctionEvaluations = s.num_residual_evaluations;
238-
summary.JacobianEvaluations = s.num_jacobian_evaluations;
239-
summary.Success = detail::CheckSuccess(summary.InitialCost, summary.FinalCost);
240-
return x0;
238+
return Operon::OptimizerSummary {
239+
.InitialParameters = initialParameters,
240+
.FinalParameters = finalParameters,
241+
.InitialCost = static_cast<Operon::Scalar>(s.initial_cost),
242+
.FinalCost = static_cast<Operon::Scalar>(s.final_cost),
243+
.Iterations = static_cast<int>(s.iterations.size()),
244+
.FunctionEvaluations = s.num_residual_evaluations,
245+
.JacobianEvaluations = s.num_jacobian_evaluations,
246+
.Success = detail::CheckSuccess(s.initial_cost, s.final_cost)
247+
};
241248
}
249+
250+
auto GetDispatchTable() const -> DTable const& { return dtable_.get(); }
251+
252+
[[nodiscard]] auto ComputeLikelihood(Operon::Span<Operon::Scalar const> x, Operon::Span<Operon::Scalar const> y, Operon::Span<Operon::Scalar const> w) const -> Operon::Scalar final
253+
{
254+
return GaussianLikelihood<Operon::Scalar>::ComputeLikelihood(x, y, w);
255+
}
256+
257+
[[nodiscard]] auto ComputeFisherMatrix(Operon::Span<Operon::Scalar const> pred, Operon::Span<Operon::Scalar const> jac, Operon::Span<Operon::Scalar const> sigma) const -> Eigen::Matrix<Operon::Scalar, -1, -1> final {
258+
return GaussianLikelihood<Operon::Scalar>::ComputeFisherMatrix(pred, jac, sigma);
259+
}
260+
261+
private:
262+
std::reference_wrapper<DTable const> dtable_;
242263
};
243264
#endif
244265

@@ -298,12 +319,12 @@ struct LBFGSOptimizer final : public OptimizerBase {
298319

299320
auto GetDispatchTable() const -> DTable const& { return dtable_.get(); }
300321

301-
[[nodiscard]] virtual auto ComputeLikelihood(Operon::Span<Operon::Scalar const> x, Operon::Span<Operon::Scalar const> y, Operon::Span<Operon::Scalar const> w) const -> Operon::Scalar
322+
[[nodiscard]] auto ComputeLikelihood(Operon::Span<Operon::Scalar const> x, Operon::Span<Operon::Scalar const> y, Operon::Span<Operon::Scalar const> w) const -> Operon::Scalar override
302323
{
303324
return LossFunction::ComputeLikelihood(x, y, w);
304325
}
305326

306-
[[nodiscard]] virtual auto ComputeFisherMatrix(Operon::Span<Operon::Scalar const> pred, Operon::Span<Operon::Scalar const> jac, Operon::Span<Operon::Scalar const> sigma) const -> Eigen::Matrix<Operon::Scalar, -1, -1> final {
327+
[[nodiscard]] auto ComputeFisherMatrix(Operon::Span<Operon::Scalar const> pred, Operon::Span<Operon::Scalar const> jac, Operon::Span<Operon::Scalar const> sigma) const -> Eigen::Matrix<Operon::Scalar, -1, -1> final {
307328
return LossFunction::ComputeFisherMatrix(pred, jac, sigma);
308329
}
309330

@@ -371,12 +392,12 @@ struct SGDOptimizer final : public OptimizerBase {
371392
return summary;
372393
}
373394

374-
[[nodiscard]] virtual auto ComputeLikelihood(Operon::Span<Operon::Scalar const> x, Operon::Span<Operon::Scalar const> y, Operon::Span<Operon::Scalar const> w) const -> Operon::Scalar
395+
[[nodiscard]] auto ComputeLikelihood(Operon::Span<Operon::Scalar const> x, Operon::Span<Operon::Scalar const> y, Operon::Span<Operon::Scalar const> w) const -> Operon::Scalar override
375396
{
376397
return LossFunction::ComputeLikelihood(x, y, w);
377398
}
378399

379-
[[nodiscard]] virtual auto ComputeFisherMatrix(Operon::Span<Operon::Scalar const> pred, Operon::Span<Operon::Scalar const> jac, Operon::Span<Operon::Scalar const> sigma) const -> Eigen::Matrix<Operon::Scalar, -1, -1> final {
400+
[[nodiscard]] auto ComputeFisherMatrix(Operon::Span<Operon::Scalar const> pred, Operon::Span<Operon::Scalar const> jac, Operon::Span<Operon::Scalar const> sigma) const -> Eigen::Matrix<Operon::Scalar, -1, -1> final {
380401
return LossFunction::ComputeFisherMatrix(pred, jac, sigma);
381402
}
382403

include/operon/optimizer/tiny_cost_function.hpp

-110
This file was deleted.

0 commit comments

Comments
 (0)