Fix broken constant parameter blocks

This fixes the bug introduced in a previous commit,
and adds a test to check that constant parameter
blocks work as expected.

This also refactors the Solver/SolverImpl split so
that SolverImpl is no longer a friend of Problem;
instead, Solver is. This makes it possible to
verify the invariant on parameter block states in
the unit test, and is a more symmetric design
anyway.

Bug: 51
Change-Id: Id503f5b526cfb8bc24aae3aaad2e414b14063d78
diff --git a/include/ceres/problem.h b/include/ceres/problem.h
index 9710e46..2b08c67 100644
--- a/include/ceres/problem.h
+++ b/include/ceres/problem.h
@@ -50,13 +50,13 @@
 class CostFunction;
 class LossFunction;
 class LocalParameterization;
+class Solver;
 
 namespace internal {
 class Preprocessor;
 class ProblemImpl;
 class ParameterBlock;
 class ResidualBlock;
-class SolverImpl;
 }  // namespace internal
 
 // A ResidualBlockId is a handle clients can use to delete residual
@@ -255,7 +255,7 @@
   int NumResiduals() const;
 
  private:
-  friend class internal::SolverImpl;
+  friend class Solver;
   internal::scoped_ptr<internal::ProblemImpl> problem_impl_;
   CERES_DISALLOW_COPY_AND_ASSIGN(Problem);
 };
diff --git a/internal/ceres/program.cc b/internal/ceres/program.cc
index 529e4a3..3d62272 100644
--- a/internal/ceres/program.cc
+++ b/internal/ceres/program.cc
@@ -90,7 +90,7 @@
   }
 }
 
-bool Program::CopyUserStateToParameterBlocks() {
+bool Program::SetParameterBlockStatePtrsToUserStatePtrs() {
   for (int i = 0; i < parameter_blocks_.size(); ++i) {
     if (!parameter_blocks_[i]->SetState(parameter_blocks_[i]->user_state())) {
       return false;
diff --git a/internal/ceres/program.h b/internal/ceres/program.h
index 27b58e1..1386d3c 100644
--- a/internal/ceres/program.h
+++ b/internal/ceres/program.h
@@ -71,9 +71,13 @@
   bool StateVectorToParameterBlocks(const double *state);
   void ParameterBlocksToStateVector(double *state) const;
 
-  // Copy internal state to and from the user's parameters.
+  // Copy internal state to the user's parameters.
   void CopyParameterBlockStateToUserState();
-  bool CopyUserStateToParameterBlocks();
+
+  // Set the parameter block pointers to the user pointers. Since this
+  // runs parameter block set state internally, which may call local
+  // parameterizations, this can fail. False is returned on failure.
+  bool SetParameterBlockStatePtrsToUserStatePtrs();
 
   // Update a state vector for the program given a delta.
   bool Plus(const double* state,
diff --git a/internal/ceres/solver.cc b/internal/ceres/solver.cc
index c61383c..b8122cd 100644
--- a/internal/ceres/solver.cc
+++ b/internal/ceres/solver.cc
@@ -32,31 +32,32 @@
 #include "ceres/solver.h"
 
 #include <vector>
+#include "ceres/problem.h"
+#include "ceres/problem_impl.h"
 #include "ceres/program.h"
 #include "ceres/solver_impl.h"
 #include "ceres/stringprintf.h"
-#include "ceres/problem.h"
 
 namespace ceres {
 
 Solver::~Solver() {}
 
-// TODO(sameeragarwal): The timing code here should use a sub-second
-// timer.
+// TODO(sameeragarwal): Use subsecond timers.
 void Solver::Solve(const Solver::Options& options,
                    Problem* problem,
                    Solver::Summary* summary) {
   time_t start_time_seconds = time(NULL);
-  internal::SolverImpl::Solve(options, problem, summary);
+  internal::ProblemImpl* problem_impl =
+      CHECK_NOTNULL(problem)->problem_impl_.get();
+  internal::SolverImpl::Solve(options, problem_impl, summary);
   summary->total_time_in_seconds =  time(NULL) - start_time_seconds;
 }
 
 void Solve(const Solver::Options& options,
            Problem* problem,
            Solver::Summary* summary) {
-  time_t start_time_seconds = time(NULL);
-  internal::SolverImpl::Solve(options, problem, summary);
-  summary->total_time_in_seconds =  time(NULL) - start_time_seconds;
+  Solver solver;
+  solver.Solve(options, problem, summary);
 }
 
 Solver::Summary::Summary()
diff --git a/internal/ceres/solver_impl.cc b/internal/ceres/solver_impl.cc
index 982fb40..46c8b1a 100644
--- a/internal/ceres/solver_impl.cc
+++ b/internal/ceres/solver_impl.cc
@@ -168,7 +168,7 @@
 }
 
 void SolverImpl::Solve(const Solver::Options& original_options,
-                       Problem* problem,
+                       ProblemImpl* problem_impl,
                        Solver::Summary* summary) {
   time_t solver_start_time = time(NULL);
   Solver::Options options(original_options);
@@ -199,8 +199,6 @@
       original_options.num_linear_solver_threads;
   summary->ordering_type = original_options.ordering_type;
 
-  ProblemImpl* problem_impl = CHECK_NOTNULL(problem)->problem_impl_.get();
-
   summary->num_parameter_blocks = problem_impl->NumParameterBlocks();
   summary->num_parameters = problem_impl->NumParameters();
   summary->num_residual_blocks = problem_impl->NumResidualBlocks();
@@ -211,22 +209,18 @@
       options.sparse_linear_algebra_library;
   summary->trust_region_strategy_type = options.trust_region_strategy_type;
 
-  // Ensure the program state is set to the user parameters.
-  Program* program = CHECK_NOTNULL(problem_impl)->mutable_program();
-  if (!program->CopyUserStateToParameterBlocks())  {
-    // This can only happen if there was a numerical problem updating the local
-    // jacobians. Indicate as such and fail out.
-    summary->termination_type = NUMERICAL_FAILURE;
-    summary->error = "Local parameterization failure.";
-    return;
-  }
-
   // Evaluate the initial cost and residual vector (if needed). The
   // initial cost needs to be computed on the original unpreprocessed
   // problem, as it is used to determine the value of the "fixed" part
   // of the objective function after the problem has undergone
   // reduction. Also the initial residuals are in the order in which
   // the user added the ResidualBlocks to the optimization problem.
+  //
+  // Note: This assumes the parameter block states are pointing to the
+  // user state at start of Solve(), instead of some other pointer.
+  // The invariant is ensured by the ParameterBlock constructor and by
+  // the call to SetParameterBlockStatePtrsToUserStatePtrs() at the
+  // bottom of this function.
   EvaluateCostAndResiduals(problem_impl,
                            &summary->initial_cost,
                            options.return_initial_residuals
@@ -238,6 +232,9 @@
   // GradientCheckingCostFunction and replacing problem_impl with
   // gradient_checking_problem_impl.
   scoped_ptr<ProblemImpl> gradient_checking_problem_impl;
+  // Save the original problem impl so we don't use the gradient
+  // checking one when computing the residuals.
+  ProblemImpl* original_problem_impl = problem_impl;
   if (options.check_gradients) {
     VLOG(1) << "Checking Gradients";
     gradient_checking_problem_impl.reset(
@@ -318,8 +315,11 @@
   reduced_program->StateVectorToParameterBlocks(parameters.data());
   reduced_program->CopyParameterBlockStateToUserState();
 
+  // Ensure the program state is set to the user parameters on the way out.
+  reduced_program->SetParameterBlockStatePtrsToUserStatePtrs();
+
   // Return the final cost and residuals for the original problem.
-  EvaluateCostAndResiduals(problem->problem_impl_.get(),
+  EvaluateCostAndResiduals(original_problem_impl,
                            &summary->final_cost,
                            options.return_final_residuals
                            ? &summary->final_residuals
diff --git a/internal/ceres/solver_impl.h b/internal/ceres/solver_impl.h
index 7dee03c..6b0340c 100644
--- a/internal/ceres/solver_impl.h
+++ b/internal/ceres/solver_impl.h
@@ -46,7 +46,7 @@
   // Mirrors the interface in solver.h, but exposes implementation
   // details for testing internally.
   static void Solve(const Solver::Options& options,
-                    Problem* problem,
+                    ProblemImpl* problem_impl,
                     Solver::Summary* summary);
 
   // Create the transformed Program, which has all the fixed blocks
diff --git a/internal/ceres/solver_impl_test.cc b/internal/ceres/solver_impl_test.cc
index ef4a6e0..c30abbc 100644
--- a/internal/ceres/solver_impl_test.cc
+++ b/internal/ceres/solver_impl_test.cc
@@ -591,7 +591,7 @@
 
   Problem::Options problem_options;
   problem_options.cost_function_ownership = DO_NOT_TAKE_OWNERSHIP;
-  Problem problem(problem_options);
+  ProblemImpl problem(problem_options);
   problem.AddResidualBlock(cost_function.get(), NULL, &x);
 
   Solver::Options options;
@@ -625,5 +625,64 @@
   EXPECT_NE(original_x, callback.x_values[1]);
 }
 
+// The parameters must be in separate blocks so that they can be individually
+// set constant or not.
+struct Quadratic4DCostFunction {
+  template <typename T> bool operator()(const T* const x,
+                                        const T* const y,
+                                        const T* const z,
+                                        const T* const w,
+                                        T* residual) const {
+    // A 4-dimension axis-aligned quadratic.
+    residual[0] = T(10.0) - *x +
+                  T(20.0) - *y +
+                  T(30.0) - *z +
+                  T(40.0) - *w;
+    return true;
+  }
+};
+
+TEST(SolverImpl, ConstantParameterBlocksDoNotChangeAndStateInvariantKept) {
+  double x = 50.0;
+  double y = 50.0;
+  double z = 50.0;
+  double w = 50.0;
+  const double original_x = 50.0;
+  const double original_y = 50.0;
+  const double original_z = 50.0;
+  const double original_w = 50.0;
+
+  scoped_ptr<CostFunction> cost_function(
+      new AutoDiffCostFunction<Quadratic4DCostFunction, 1, 1, 1, 1, 1>(
+          new Quadratic4DCostFunction));
+
+  Problem::Options problem_options;
+  problem_options.cost_function_ownership = DO_NOT_TAKE_OWNERSHIP;
+
+  ProblemImpl problem(problem_options);
+  problem.AddResidualBlock(cost_function.get(), NULL, &x, &y, &z, &w);
+  problem.SetParameterBlockConstant(&x);
+  problem.SetParameterBlockConstant(&w);
+
+  Solver::Options options;
+  options.linear_solver_type = DENSE_QR;
+
+  Solver::Summary summary;
+  SolverImpl::Solve(options, &problem, &summary);
+
+  // Verify only the non-constant parameters were mutated.
+  EXPECT_EQ(original_x, x);
+  EXPECT_NE(original_y, y);
+  EXPECT_NE(original_z, z);
+  EXPECT_EQ(original_w, w);
+
+  // Check that the parameter block state pointers are pointing back at the
+  // user state, instead of inside a random temporary vector made by Solve().
+  EXPECT_EQ(&x, problem.program().parameter_blocks()[0]->state());
+  EXPECT_EQ(&y, problem.program().parameter_blocks()[1]->state());
+  EXPECT_EQ(&z, problem.program().parameter_blocks()[2]->state());
+  EXPECT_EQ(&w, problem.program().parameter_blocks()[3]->state());
+}
+
 }  // namespace internal
 }  // namespace ceres