Fix user iteration callbacks. User callbacks got broken at some point due to the extra layer of copying from Solver::Options to Minimizer::Options. This copies the user callbacks when initializing Minimizer::Options from Solver::Options, and adds a test to this effect. This also fixes a bug where the state updating callback was not called before the user callbacks. This also adds a test to solver_impl_test to ensure the state updating callbacks work as expected. Thanks to Luis Alberto Zarrabeitia for the report. Issue: 46 Change-Id: I2b36415c89dafaa5c84ecaa727a325df122e1092
diff --git a/internal/ceres/CMakeLists.txt b/internal/ceres/CMakeLists.txt index 076805f..d59469c 100644 --- a/internal/ceres/CMakeLists.txt +++ b/internal/ceres/CMakeLists.txt
@@ -182,6 +182,7 @@ CERES_TEST(levenberg_marquardt_strategy) CERES_TEST(local_parameterization) CERES_TEST(loss_function) + CERES_TEST(minimizer) CERES_TEST(normal_prior) CERES_TEST(numeric_diff_cost_function) CERES_TEST(parameter_block)
diff --git a/internal/ceres/minimizer.h b/internal/ceres/minimizer.h index e15b165..70b530f 100644 --- a/internal/ceres/minimizer.h +++ b/internal/ceres/minimizer.h
@@ -78,6 +78,7 @@ evaluator = NULL; trust_region_strategy = NULL; jacobian = NULL; + callbacks = options.callbacks; } int max_num_iterations;
diff --git a/internal/ceres/minimizer_test.cc b/internal/ceres/minimizer_test.cc new file mode 100644 index 0000000..1058036 --- /dev/null +++ b/internal/ceres/minimizer_test.cc
@@ -0,0 +1,63 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2012 Google Inc. All rights reserved. +// http://code.google.com/p/ceres-solver/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Author: keir@google.com (Keir Mierle) + +#include "gtest/gtest.h" +#include "ceres/iteration_callback.h" +#include "ceres/minimizer.h" +#include "ceres/solver.h" + +namespace ceres { +namespace internal { + +class FakeIterationCallback : public IterationCallback { + public: + virtual ~FakeIterationCallback() {} + virtual CallbackReturnType operator()(const IterationSummary& summary) { + return SOLVER_CONTINUE; + } +}; + +TEST(MinimizerTest, InitializationCopiesCallbacks) { + FakeIterationCallback callback0; + FakeIterationCallback callback1; + + Solver::Options solver_options; + solver_options.callbacks.push_back(&callback0); + solver_options.callbacks.push_back(&callback1); + + Minimizer::Options minimizer_options(solver_options); + ASSERT_EQ(2, minimizer_options.callbacks.size()); + + EXPECT_EQ(minimizer_options.callbacks[0], &callback0); + EXPECT_EQ(minimizer_options.callbacks[1], &callback1); +} + +} // namespace internal +} // namespace ceres
diff --git a/internal/ceres/solver_impl.cc b/internal/ceres/solver_impl.cc index 506ef2f..ab91681 100644 --- a/internal/ceres/solver_impl.cc +++ b/internal/ceres/solver_impl.cc
@@ -132,12 +132,16 @@ Minimizer::Options minimizer_options(options); LoggingCallback logging_callback(options.minimizer_progress_to_stdout); if (options.logging_type != SILENT) { - minimizer_options.callbacks.push_back(&logging_callback); + minimizer_options.callbacks.insert(minimizer_options.callbacks.begin(), + &logging_callback); } StateUpdatingCallback updating_callback(program, parameters); if (options.update_state_every_iteration) { - minimizer_options.callbacks.push_back(&updating_callback); + // This must get pushed to the front of the callbacks so that it is run + // before any of the user callbacks. + minimizer_options.callbacks.insert(minimizer_options.callbacks.begin(), + &updating_callback); } minimizer_options.evaluator = evaluator;
diff --git a/internal/ceres/solver_impl_test.cc b/internal/ceres/solver_impl_test.cc index 81775fb..ef4a6e0 100644 --- a/internal/ceres/solver_impl_test.cc +++ b/internal/ceres/solver_impl_test.cc
@@ -29,6 +29,7 @@ // Author: sameeragarwal@google.com (Sameer Agarwal) #include "gtest/gtest.h" +#include "ceres/autodiff_cost_function.h" #include "ceres/linear_solver.h" #include "ceres/parameter_block.h" #include "ceres/problem_impl.h" @@ -560,5 +561,69 @@ EXPECT_TRUE(solver.get() != NULL); } +struct QuadraticCostFunction { + template <typename T> bool operator()(const T* const x, + T* residual) const { + residual[0] = T(5.0) - *x; + return true; + } +}; + +struct RememberingCallback : public IterationCallback { + RememberingCallback(double *x) : calls(0), x(x) {} + virtual ~RememberingCallback() {} + virtual CallbackReturnType operator()(const IterationSummary& summary) { + x_values.push_back(*x); + return SOLVER_CONTINUE; + } + int calls; + double *x; + vector<double> x_values; +}; + +TEST(SolverImpl, UpdateStateEveryIterationOption) { + double x = 50.0; + const double original_x = x; + + scoped_ptr<CostFunction> cost_function( + new AutoDiffCostFunction<QuadraticCostFunction, 1, 1>( + new QuadraticCostFunction)); + + Problem::Options problem_options; + problem_options.cost_function_ownership = DO_NOT_TAKE_OWNERSHIP; + Problem problem(problem_options); + problem.AddResidualBlock(cost_function.get(), NULL, &x); + + Solver::Options options; + options.linear_solver_type = DENSE_QR; + + RememberingCallback callback(&x); + options.callbacks.push_back(&callback); + + Solver::Summary summary; + + int num_iterations; + + // First try: no updating. + SolverImpl::Solve(options, &problem, &summary); + num_iterations = summary.num_successful_steps + + summary.num_unsuccessful_steps; + EXPECT_GT(num_iterations, 1); + for (int i = 0; i < callback.x_values.size(); ++i) { + EXPECT_EQ(50.0, callback.x_values[i]); + } + + // Second try: with updating + x = 50.0; + options.update_state_every_iteration = true; + callback.x_values.clear(); + SolverImpl::Solve(options, &problem, &summary); + num_iterations = summary.num_successful_steps + + summary.num_unsuccessful_steps; + EXPECT_GT(num_iterations, 1); + EXPECT_EQ(original_x, callback.x_values[0]); + EXPECT_NE(original_x, callback.x_values[1]); +} + } // namespace internal } // namespace ceres
diff --git a/internal/ceres/system_test.cc b/internal/ceres/system_test.cc index 3eaebc3..405dc69 100644 --- a/internal/ceres/system_test.cc +++ b/internal/ceres/system_test.cc
@@ -500,8 +500,8 @@ #endif // CERES_NO_SUITESPARSE #ifndef CERES_NO_CXSPARSE - CONFIGURE(SPARSE_SCHUR, CX_SPARSE, USER, IDENTITY, 1); - CONFIGURE(SPARSE_SCHUR, CX_SPARSE, SCHUR, IDENTITY, 1); + CONFIGURE(SPARSE_SCHUR, CX_SPARSE, USER, IDENTITY, 1); + CONFIGURE(SPARSE_SCHUR, CX_SPARSE, SCHUR, IDENTITY, 1); #endif // CERES_NO_CXSPARSE CONFIGURE(DENSE_SCHUR, SUITE_SPARSE, USER, IDENTITY, 1);
diff --git a/internal/ceres/trust_region_minimizer.cc b/internal/ceres/trust_region_minimizer.cc index 4d0c91e..c475690 100644 --- a/internal/ceres/trust_region_minimizer.cc +++ b/internal/ceres/trust_region_minimizer.cc
@@ -59,7 +59,6 @@ // the callbacks does not return SOLVER_CONTINUE, then stop and return // its status. CallbackReturnType TrustRegionMinimizer::RunCallbacks( - const Minimizer::Options& options_, const IterationSummary& iteration_summary) { for (int i = 0; i < options_.callbacks.size(); ++i) { const CallbackReturnType status = @@ -219,7 +218,7 @@ summary->iterations.push_back(iteration_summary); // Call the various callbacks. - switch (RunCallbacks(options_, iteration_summary)) { + switch (RunCallbacks(iteration_summary)) { case SOLVER_TERMINATE_SUCCESSFULLY: summary->termination_type = USER_SUCCESS; VLOG(1) << "Terminating: User callback returned USER_SUCCESS."; @@ -441,7 +440,7 @@ summary->preprocessor_time_in_seconds; summary->iterations.push_back(iteration_summary); - switch (RunCallbacks(options_, iteration_summary)) { + switch (RunCallbacks(iteration_summary)) { case SOLVER_TERMINATE_SUCCESSFULLY: summary->termination_type = USER_SUCCESS; VLOG(1) << "Terminating: User callback returned USER_SUCCESS.";
diff --git a/internal/ceres/trust_region_minimizer.h b/internal/ceres/trust_region_minimizer.h index 4337b18..a4f5ba3 100644 --- a/internal/ceres/trust_region_minimizer.h +++ b/internal/ceres/trust_region_minimizer.h
@@ -53,8 +53,7 @@ private: void Init(const Minimizer::Options& options); void EstimateScale(const SparseMatrix& jacobian, double* scale) const; - CallbackReturnType RunCallbacks(const Minimizer::Options& options, - const IterationSummary& iteration_summary); + CallbackReturnType RunCallbacks(const IterationSummary& iteration_summary); bool MaybeDumpLinearLeastSquaresProblem( const int iteration, const SparseMatrix* jacobian, const double* residuals,