Fix broken constant parameter blocks
This fixes the bug introduced in a previous commit,
and adds a test to check that constant parameter
blocks work as expected.
This also refactors the Solver/SolverImpl split so
that SolverImpl is no longer a friend of Problem;
instead, Solver is. This makes it possible to
verify the invariant on parameter block states in
the unit test, and is a more symmetric design
anyway.
Bug: 51
Change-Id: Id503f5b526cfb8bc24aae3aaad2e414b14063d78
diff --git a/include/ceres/problem.h b/include/ceres/problem.h
index 9710e46..2b08c67 100644
--- a/include/ceres/problem.h
+++ b/include/ceres/problem.h
@@ -50,13 +50,13 @@
class CostFunction;
class LossFunction;
class LocalParameterization;
+class Solver;
namespace internal {
class Preprocessor;
class ProblemImpl;
class ParameterBlock;
class ResidualBlock;
-class SolverImpl;
} // namespace internal
// A ResidualBlockId is a handle clients can use to delete residual
@@ -255,7 +255,7 @@
int NumResiduals() const;
private:
- friend class internal::SolverImpl;
+ friend class Solver;
internal::scoped_ptr<internal::ProblemImpl> problem_impl_;
CERES_DISALLOW_COPY_AND_ASSIGN(Problem);
};
diff --git a/internal/ceres/program.cc b/internal/ceres/program.cc
index 529e4a3..3d62272 100644
--- a/internal/ceres/program.cc
+++ b/internal/ceres/program.cc
@@ -90,7 +90,7 @@
}
}
-bool Program::CopyUserStateToParameterBlocks() {
+bool Program::SetParameterBlockStatePtrsToUserStatePtrs() {
for (int i = 0; i < parameter_blocks_.size(); ++i) {
if (!parameter_blocks_[i]->SetState(parameter_blocks_[i]->user_state())) {
return false;
diff --git a/internal/ceres/program.h b/internal/ceres/program.h
index 27b58e1..1386d3c 100644
--- a/internal/ceres/program.h
+++ b/internal/ceres/program.h
@@ -71,9 +71,13 @@
bool StateVectorToParameterBlocks(const double *state);
void ParameterBlocksToStateVector(double *state) const;
- // Copy internal state to and from the user's parameters.
+ // Copy internal state to the user's parameters.
void CopyParameterBlockStateToUserState();
- bool CopyUserStateToParameterBlocks();
+
+ // Set the parameter block pointers to the user pointers. Since this
+ // runs parameter block set state internally, which may call local
+ // parameterizations, this can fail. False is returned on failure.
+ bool SetParameterBlockStatePtrsToUserStatePtrs();
// Update a state vector for the program given a delta.
bool Plus(const double* state,
diff --git a/internal/ceres/solver.cc b/internal/ceres/solver.cc
index c61383c..b8122cd 100644
--- a/internal/ceres/solver.cc
+++ b/internal/ceres/solver.cc
@@ -32,31 +32,32 @@
#include "ceres/solver.h"
#include <vector>
+#include "ceres/problem.h"
+#include "ceres/problem_impl.h"
#include "ceres/program.h"
#include "ceres/solver_impl.h"
#include "ceres/stringprintf.h"
-#include "ceres/problem.h"
namespace ceres {
Solver::~Solver() {}
-// TODO(sameeragarwal): The timing code here should use a sub-second
-// timer.
+// TODO(sameeragarwal): Use subsecond timers.
void Solver::Solve(const Solver::Options& options,
Problem* problem,
Solver::Summary* summary) {
time_t start_time_seconds = time(NULL);
- internal::SolverImpl::Solve(options, problem, summary);
+ internal::ProblemImpl* problem_impl =
+ CHECK_NOTNULL(problem)->problem_impl_.get();
+ internal::SolverImpl::Solve(options, problem_impl, summary);
summary->total_time_in_seconds = time(NULL) - start_time_seconds;
}
void Solve(const Solver::Options& options,
Problem* problem,
Solver::Summary* summary) {
- time_t start_time_seconds = time(NULL);
- internal::SolverImpl::Solve(options, problem, summary);
- summary->total_time_in_seconds = time(NULL) - start_time_seconds;
+ Solver solver;
+ solver.Solve(options, problem, summary);
}
Solver::Summary::Summary()
diff --git a/internal/ceres/solver_impl.cc b/internal/ceres/solver_impl.cc
index 982fb40..46c8b1a 100644
--- a/internal/ceres/solver_impl.cc
+++ b/internal/ceres/solver_impl.cc
@@ -168,7 +168,7 @@
}
void SolverImpl::Solve(const Solver::Options& original_options,
- Problem* problem,
+ ProblemImpl* problem_impl,
Solver::Summary* summary) {
time_t solver_start_time = time(NULL);
Solver::Options options(original_options);
@@ -199,8 +199,6 @@
original_options.num_linear_solver_threads;
summary->ordering_type = original_options.ordering_type;
- ProblemImpl* problem_impl = CHECK_NOTNULL(problem)->problem_impl_.get();
-
summary->num_parameter_blocks = problem_impl->NumParameterBlocks();
summary->num_parameters = problem_impl->NumParameters();
summary->num_residual_blocks = problem_impl->NumResidualBlocks();
@@ -211,22 +209,18 @@
options.sparse_linear_algebra_library;
summary->trust_region_strategy_type = options.trust_region_strategy_type;
- // Ensure the program state is set to the user parameters.
- Program* program = CHECK_NOTNULL(problem_impl)->mutable_program();
- if (!program->CopyUserStateToParameterBlocks()) {
- // This can only happen if there was a numerical problem updating the local
- // jacobians. Indicate as such and fail out.
- summary->termination_type = NUMERICAL_FAILURE;
- summary->error = "Local parameterization failure.";
- return;
- }
-
// Evaluate the initial cost and residual vector (if needed). The
// initial cost needs to be computed on the original unpreprocessed
// problem, as it is used to determine the value of the "fixed" part
// of the objective function after the problem has undergone
// reduction. Also the initial residuals are in the order in which
// the user added the ResidualBlocks to the optimization problem.
+ //
+ // Note: This assumes the parameter block states are pointing to the
+ // user state at start of Solve(), instead of some other pointer.
+ // The invariant is ensured by the ParameterBlock constructor and by
+ // the call to SetParameterBlockStatePtrsToUserStatePtrs() at the
+ // bottom of this function.
EvaluateCostAndResiduals(problem_impl,
&summary->initial_cost,
options.return_initial_residuals
@@ -238,6 +232,9 @@
// GradientCheckingCostFunction and replacing problem_impl with
// gradient_checking_problem_impl.
scoped_ptr<ProblemImpl> gradient_checking_problem_impl;
+ // Save the original problem impl so we don't use the gradient
+ // checking one when computing the residuals.
+ ProblemImpl* original_problem_impl = problem_impl;
if (options.check_gradients) {
VLOG(1) << "Checking Gradients";
gradient_checking_problem_impl.reset(
@@ -318,8 +315,11 @@
reduced_program->StateVectorToParameterBlocks(parameters.data());
reduced_program->CopyParameterBlockStateToUserState();
+ // Ensure the program state is set to the user parameters on the way out.
+ reduced_program->SetParameterBlockStatePtrsToUserStatePtrs();
+
// Return the final cost and residuals for the original problem.
- EvaluateCostAndResiduals(problem->problem_impl_.get(),
+ EvaluateCostAndResiduals(original_problem_impl,
&summary->final_cost,
options.return_final_residuals
? &summary->final_residuals
diff --git a/internal/ceres/solver_impl.h b/internal/ceres/solver_impl.h
index 7dee03c..6b0340c 100644
--- a/internal/ceres/solver_impl.h
+++ b/internal/ceres/solver_impl.h
@@ -46,7 +46,7 @@
// Mirrors the interface in solver.h, but exposes implementation
// details for testing internally.
static void Solve(const Solver::Options& options,
- Problem* problem,
+ ProblemImpl* problem_impl,
Solver::Summary* summary);
// Create the transformed Program, which has all the fixed blocks
diff --git a/internal/ceres/solver_impl_test.cc b/internal/ceres/solver_impl_test.cc
index ef4a6e0..c30abbc 100644
--- a/internal/ceres/solver_impl_test.cc
+++ b/internal/ceres/solver_impl_test.cc
@@ -591,7 +591,7 @@
Problem::Options problem_options;
problem_options.cost_function_ownership = DO_NOT_TAKE_OWNERSHIP;
- Problem problem(problem_options);
+ ProblemImpl problem(problem_options);
problem.AddResidualBlock(cost_function.get(), NULL, &x);
Solver::Options options;
@@ -625,5 +625,64 @@
EXPECT_NE(original_x, callback.x_values[1]);
}
+// The parameters must be in separate blocks so that they can be individually
+// set constant or not.
+struct Quadratic4DCostFunction {
+ template <typename T> bool operator()(const T* const x,
+ const T* const y,
+ const T* const z,
+ const T* const w,
+ T* residual) const {
+ // A 4-dimension axis-aligned quadratic.
+ residual[0] = T(10.0) - *x +
+ T(20.0) - *y +
+ T(30.0) - *z +
+ T(40.0) - *w;
+ return true;
+ }
+};
+
+TEST(SolverImpl, ConstantParameterBlocksDoNotChangeAndStateInvariantKept) {
+ double x = 50.0;
+ double y = 50.0;
+ double z = 50.0;
+ double w = 50.0;
+ const double original_x = 50.0;
+ const double original_y = 50.0;
+ const double original_z = 50.0;
+ const double original_w = 50.0;
+
+ scoped_ptr<CostFunction> cost_function(
+ new AutoDiffCostFunction<Quadratic4DCostFunction, 1, 1, 1, 1, 1>(
+ new Quadratic4DCostFunction));
+
+ Problem::Options problem_options;
+ problem_options.cost_function_ownership = DO_NOT_TAKE_OWNERSHIP;
+
+ ProblemImpl problem(problem_options);
+ problem.AddResidualBlock(cost_function.get(), NULL, &x, &y, &z, &w);
+ problem.SetParameterBlockConstant(&x);
+ problem.SetParameterBlockConstant(&w);
+
+ Solver::Options options;
+ options.linear_solver_type = DENSE_QR;
+
+ Solver::Summary summary;
+ SolverImpl::Solve(options, &problem, &summary);
+
+ // Verify only the non-constant parameters were mutated.
+ EXPECT_EQ(original_x, x);
+ EXPECT_NE(original_y, y);
+ EXPECT_NE(original_z, z);
+ EXPECT_EQ(original_w, w);
+
+ // Check that the parameter block state pointers are pointing back at the
+ // user state, instead of inside a random temporary vector made by Solve().
+ EXPECT_EQ(&x, problem.program().parameter_blocks()[0]->state());
+ EXPECT_EQ(&y, problem.program().parameter_blocks()[1]->state());
+ EXPECT_EQ(&z, problem.program().parameter_blocks()[2]->state());
+ EXPECT_EQ(&w, problem.program().parameter_blocks()[3]->state());
+}
+
} // namespace internal
} // namespace ceres