Fix user iteration callbacks.
User callbacks got broken at some point due to the extra
layer of copying from Solver::Options to Minimizer::Options.
This copies the user callbacks when initializing
Minimizer::Options from Solver::Options, and adds a test to
this effect.
This change also fixes a bug where the state-updating callback was
not called before the user callbacks, and adds a test to
solver_impl_test to ensure the state-updating callbacks work as
expected.
Thanks to Luis Alberto Zarrabeitia for the report.
Issue: 46
Change-Id: I2b36415c89dafaa5c84ecaa727a325df122e1092
diff --git a/internal/ceres/solver_impl_test.cc b/internal/ceres/solver_impl_test.cc
index 81775fb..ef4a6e0 100644
--- a/internal/ceres/solver_impl_test.cc
+++ b/internal/ceres/solver_impl_test.cc
@@ -29,6 +29,7 @@
// Author: sameeragarwal@google.com (Sameer Agarwal)
#include "gtest/gtest.h"
+#include "ceres/autodiff_cost_function.h"
#include "ceres/linear_solver.h"
#include "ceres/parameter_block.h"
#include "ceres/problem_impl.h"
@@ -560,5 +561,69 @@
EXPECT_TRUE(solver.get() != NULL);
}
+struct QuadraticCostFunction {  // Templated functor producing one residual: 5 - x.
+ template <typename T> bool operator()(const T* const x,
+ T* residual) const {
+ residual[0] = T(5.0) - *x;  // The residual is zero exactly when *x == 5.
+ return true;  // Evaluation always reports success.
+ }
+};
+
+struct RememberingCallback : public IterationCallback {  // Records *x each time it is invoked.
+ RememberingCallback(double *x) : calls(0), x(x) {}  // Stores the pointer only; no copy of the value is made.
+ virtual ~RememberingCallback() {}
+ virtual CallbackReturnType operator()(const IterationSummary& summary) {  // summary is not used.
+ x_values.push_back(*x);  // Snapshot whatever value x points to right now.
+ return SOLVER_CONTINUE;  // This callback never returns any other value.
+ }
+ int calls;  // Set to 0 in the constructor; operator() does not modify it.
+ double *x;  // Dereferenced on every call, so it must stay valid while the callback is in use.
+ vector<double> x_values;  // One entry per operator() invocation, in call order.
+};
+
+TEST(SolverImpl, UpdateStateEveryIterationOption) {  // Checks which x values a user callback observes.
+ double x = 50.0;
+ const double original_x = x;  // Kept for comparison after the solver has modified x.
+
+ scoped_ptr<CostFunction> cost_function(
+ new AutoDiffCostFunction<QuadraticCostFunction, 1, 1>(
+ new QuadraticCostFunction));
+
+ Problem::Options problem_options;
+ problem_options.cost_function_ownership = DO_NOT_TAKE_OWNERSHIP;  // The scoped_ptr above holds the cost function.
+ Problem problem(problem_options);
+ problem.AddResidualBlock(cost_function.get(), NULL, &x);  // NULL is passed for the loss function argument.
+
+ Solver::Options options;
+ options.linear_solver_type = DENSE_QR;
+
+ RememberingCallback callback(&x);  // Watches the same x that the problem optimizes.
+ options.callbacks.push_back(&callback);
+
+ Solver::Summary summary;
+
+ int num_iterations;
+
+ // First try: no updating.
+ SolverImpl::Solve(options, &problem, &summary);  // update_state_every_iteration has not been set here.
+ num_iterations = summary.num_successful_steps +
+ summary.num_unsuccessful_steps;
+ EXPECT_GT(num_iterations, 1);
+ for (int i = 0; i < callback.x_values.size(); ++i) {  // NOTE(review): compares int against the vector's size type.
+ EXPECT_EQ(50.0, callback.x_values[i]);  // Every recorded value must still be the initial x.
+ }
+
+ // Second try: with updating
+ x = 50.0;  // Restore the starting point changed by the first solve.
+ options.update_state_every_iteration = true;
+ callback.x_values.clear();  // Drop the values recorded during the first solve.
+ SolverImpl::Solve(options, &problem, &summary);
+ num_iterations = summary.num_successful_steps +
+ summary.num_unsuccessful_steps;
+ EXPECT_GT(num_iterations, 1);
+ EXPECT_EQ(original_x, callback.x_values[0]);  // The first recorded value is the starting point.
+ EXPECT_NE(original_x, callback.x_values[1]);  // NOTE(review): x_values.size() >= 2 is assumed, not checked.
+}
+
} // namespace internal
} // namespace ceres