Fix broken constant parameter blocks

This fixes the bug introduced in a previous commit, and adds a test to check that constant parameter blocks work as expected.

This also refactors the Solver/SolverImpl split so that SolverImpl is no longer a friend of Problem; instead, Solver is. This makes it possible to verify the invariant on parameter block states in the unit test (that is, that after Solve() the parameter block state pointers point back at the user state), and is a more symmetric design anyway.

Bug: 51
Change-Id: Id503f5b526cfb8bc24aae3aaad2e414b14063d78
diff --git a/include/ceres/problem.h b/include/ceres/problem.h index 9710e46..2b08c67 100644 --- a/include/ceres/problem.h +++ b/include/ceres/problem.h
@@ -50,13 +50,13 @@ class CostFunction; class LossFunction; class LocalParameterization; +class Solver; namespace internal { class Preprocessor; class ProblemImpl; class ParameterBlock; class ResidualBlock; -class SolverImpl; } // namespace internal // A ResidualBlockId is a handle clients can use to delete residual @@ -255,7 +255,7 @@ int NumResiduals() const; private: - friend class internal::SolverImpl; + friend class Solver; internal::scoped_ptr<internal::ProblemImpl> problem_impl_; CERES_DISALLOW_COPY_AND_ASSIGN(Problem); };
diff --git a/internal/ceres/program.cc b/internal/ceres/program.cc index 529e4a3..3d62272 100644 --- a/internal/ceres/program.cc +++ b/internal/ceres/program.cc
@@ -90,7 +90,7 @@ } } -bool Program::CopyUserStateToParameterBlocks() { +bool Program::SetParameterBlockStatePtrsToUserStatePtrs() { for (int i = 0; i < parameter_blocks_.size(); ++i) { if (!parameter_blocks_[i]->SetState(parameter_blocks_[i]->user_state())) { return false;
diff --git a/internal/ceres/program.h b/internal/ceres/program.h index 27b58e1..1386d3c 100644 --- a/internal/ceres/program.h +++ b/internal/ceres/program.h
@@ -71,9 +71,13 @@ bool StateVectorToParameterBlocks(const double *state); void ParameterBlocksToStateVector(double *state) const; - // Copy internal state to and from the user's parameters. + // Copy internal state to the user's parameters. void CopyParameterBlockStateToUserState(); - bool CopyUserStateToParameterBlocks(); + + // Set the parameter block pointers to the user pointers. Since this + // runs parameter block set state internally, which may call local + // parameterizations, this can fail. False is returned on failure. + bool SetParameterBlockStatePtrsToUserStatePtrs(); // Update a state vector for the program given a delta. bool Plus(const double* state,
diff --git a/internal/ceres/solver.cc b/internal/ceres/solver.cc index c61383c..b8122cd 100644 --- a/internal/ceres/solver.cc +++ b/internal/ceres/solver.cc
@@ -32,31 +32,32 @@ #include "ceres/solver.h" #include <vector> +#include "ceres/problem.h" +#include "ceres/problem_impl.h" #include "ceres/program.h" #include "ceres/solver_impl.h" #include "ceres/stringprintf.h" -#include "ceres/problem.h" namespace ceres { Solver::~Solver() {} -// TODO(sameeragarwal): The timing code here should use a sub-second -// timer. +// TODO(sameeragarwal): Use subsecond timers. void Solver::Solve(const Solver::Options& options, Problem* problem, Solver::Summary* summary) { time_t start_time_seconds = time(NULL); - internal::SolverImpl::Solve(options, problem, summary); + internal::ProblemImpl* problem_impl = + CHECK_NOTNULL(problem)->problem_impl_.get(); + internal::SolverImpl::Solve(options, problem_impl, summary); summary->total_time_in_seconds = time(NULL) - start_time_seconds; } void Solve(const Solver::Options& options, Problem* problem, Solver::Summary* summary) { - time_t start_time_seconds = time(NULL); - internal::SolverImpl::Solve(options, problem, summary); - summary->total_time_in_seconds = time(NULL) - start_time_seconds; + Solver solver; + solver.Solve(options, problem, summary); } Solver::Summary::Summary()
diff --git a/internal/ceres/solver_impl.cc b/internal/ceres/solver_impl.cc index 982fb40..46c8b1a 100644 --- a/internal/ceres/solver_impl.cc +++ b/internal/ceres/solver_impl.cc
@@ -168,7 +168,7 @@ } void SolverImpl::Solve(const Solver::Options& original_options, - Problem* problem, + ProblemImpl* problem_impl, Solver::Summary* summary) { time_t solver_start_time = time(NULL); Solver::Options options(original_options); @@ -199,8 +199,6 @@ original_options.num_linear_solver_threads; summary->ordering_type = original_options.ordering_type; - ProblemImpl* problem_impl = CHECK_NOTNULL(problem)->problem_impl_.get(); - summary->num_parameter_blocks = problem_impl->NumParameterBlocks(); summary->num_parameters = problem_impl->NumParameters(); summary->num_residual_blocks = problem_impl->NumResidualBlocks(); @@ -211,22 +209,18 @@ options.sparse_linear_algebra_library; summary->trust_region_strategy_type = options.trust_region_strategy_type; - // Ensure the program state is set to the user parameters. - Program* program = CHECK_NOTNULL(problem_impl)->mutable_program(); - if (!program->CopyUserStateToParameterBlocks()) { - // This can only happen if there was a numerical problem updating the local - // jacobians. Indicate as such and fail out. - summary->termination_type = NUMERICAL_FAILURE; - summary->error = "Local parameterization failure."; - return; - } - // Evaluate the initial cost and residual vector (if needed). The // initial cost needs to be computed on the original unpreprocessed // problem, as it is used to determine the value of the "fixed" part // of the objective function after the problem has undergone // reduction. Also the initial residuals are in the order in which // the user added the ResidualBlocks to the optimization problem. + // + // Note: This assumes the parameter block states are pointing to the + // user state at start of Solve(), instead of some other pointer. + // The invariant is ensured by the ParameterBlock constructor and by + // the call to SetParameterBlockStatePtrsToUserStatePtrs() at the + // bottom of this function. 
EvaluateCostAndResiduals(problem_impl, &summary->initial_cost, options.return_initial_residuals @@ -238,6 +232,9 @@ // GradientCheckingCostFunction and replacing problem_impl with // gradient_checking_problem_impl. scoped_ptr<ProblemImpl> gradient_checking_problem_impl; + // Save the original problem impl so we don't use the gradient + // checking one when computing the residuals. + ProblemImpl* original_problem_impl = problem_impl; if (options.check_gradients) { VLOG(1) << "Checking Gradients"; gradient_checking_problem_impl.reset( @@ -318,8 +315,11 @@ reduced_program->StateVectorToParameterBlocks(parameters.data()); reduced_program->CopyParameterBlockStateToUserState(); + // Ensure the program state is set to the user parameters on the way out. + reduced_program->SetParameterBlockStatePtrsToUserStatePtrs(); + // Return the final cost and residuals for the original problem. - EvaluateCostAndResiduals(problem->problem_impl_.get(), + EvaluateCostAndResiduals(original_problem_impl, &summary->final_cost, options.return_final_residuals ? &summary->final_residuals
diff --git a/internal/ceres/solver_impl.h b/internal/ceres/solver_impl.h index 7dee03c..6b0340c 100644 --- a/internal/ceres/solver_impl.h +++ b/internal/ceres/solver_impl.h
@@ -46,7 +46,7 @@ // Mirrors the interface in solver.h, but exposes implementation // details for testing internally. static void Solve(const Solver::Options& options, - Problem* problem, + ProblemImpl* problem_impl, Solver::Summary* summary); // Create the transformed Program, which has all the fixed blocks
diff --git a/internal/ceres/solver_impl_test.cc b/internal/ceres/solver_impl_test.cc index ef4a6e0..c30abbc 100644 --- a/internal/ceres/solver_impl_test.cc +++ b/internal/ceres/solver_impl_test.cc
@@ -591,7 +591,7 @@ Problem::Options problem_options; problem_options.cost_function_ownership = DO_NOT_TAKE_OWNERSHIP; - Problem problem(problem_options); + ProblemImpl problem(problem_options); problem.AddResidualBlock(cost_function.get(), NULL, &x); Solver::Options options; @@ -625,5 +625,64 @@ EXPECT_NE(original_x, callback.x_values[1]); } +// The parameters must be in separate blocks so that they can be individually +// set constant or not. +struct Quadratic4DCostFunction { + template <typename T> bool operator()(const T* const x, + const T* const y, + const T* const z, + const T* const w, + T* residual) const { + // A 4-dimension axis-aligned quadratic. + residual[0] = T(10.0) - *x + + T(20.0) - *y + + T(30.0) - *z + + T(40.0) - *w; + return true; + } +}; + +TEST(SolverImpl, ConstantParameterBlocksDoNotChangeAndStateInvariantKept) { + double x = 50.0; + double y = 50.0; + double z = 50.0; + double w = 50.0; + const double original_x = 50.0; + const double original_y = 50.0; + const double original_z = 50.0; + const double original_w = 50.0; + + scoped_ptr<CostFunction> cost_function( + new AutoDiffCostFunction<Quadratic4DCostFunction, 1, 1, 1, 1, 1>( + new Quadratic4DCostFunction)); + + Problem::Options problem_options; + problem_options.cost_function_ownership = DO_NOT_TAKE_OWNERSHIP; + + ProblemImpl problem(problem_options); + problem.AddResidualBlock(cost_function.get(), NULL, &x, &y, &z, &w); + problem.SetParameterBlockConstant(&x); + problem.SetParameterBlockConstant(&w); + + Solver::Options options; + options.linear_solver_type = DENSE_QR; + + Solver::Summary summary; + SolverImpl::Solve(options, &problem, &summary); + + // Verify only the non-constant parameters were mutated. + EXPECT_EQ(original_x, x); + EXPECT_NE(original_y, y); + EXPECT_NE(original_z, z); + EXPECT_EQ(original_w, w); + + // Check that the parameter block state pointers are pointing back at the + // user state, instead of inside a random temporary vector made by Solve(). 
+ EXPECT_EQ(&x, problem.program().parameter_blocks()[0]->state()); + EXPECT_EQ(&y, problem.program().parameter_blocks()[1]->state()); + EXPECT_EQ(&z, problem.program().parameter_blocks()[2]->state()); + EXPECT_EQ(&w, problem.program().parameter_blocks()[3]->state()); +} + } // namespace internal } // namespace ceres