Add MixedSparseCholesky.

A simple class that composes SparseCholesky with IterativeRefiner.

Change-Id: I4a67b8ca33a604aaa7b6a4bf511dad9501815f5b
diff --git a/internal/ceres/iterative_refiner.cc b/internal/ceres/iterative_refiner.cc
index 6a5a0c7..fff343d 100644
--- a/internal/ceres/iterative_refiner.cc
+++ b/internal/ceres/iterative_refiner.cc
@@ -46,6 +46,8 @@
       correction_(num_cols),
       lhs_x_solution_(num_cols) {}
 
+IterativeRefiner::~IterativeRefiner() {}
+
 IterativeRefiner::Summary IterativeRefiner::Refine(
     const SparseMatrix& lhs,
     const double* rhs_ptr,
diff --git a/internal/ceres/iterative_refiner.h b/internal/ceres/iterative_refiner.h
index 471116c..d1efb5c 100644
--- a/internal/ceres/iterative_refiner.h
+++ b/internal/ceres/iterative_refiner.h
@@ -84,6 +84,9 @@
   // to perform.
   IterativeRefiner(int num_cols, int max_num_iterations);
 
+  // Needed for mocking.
+  virtual ~IterativeRefiner();
+
   // Given an initial estimate of the solution of lhs * x = rhs, use
   // iterative refinement to improve it.
   //
@@ -92,10 +95,12 @@
   //
   // solution is expected to contain a approximation to the solution
   // to lhs * x = rhs. It can be zero.
-  Summary Refine(const SparseMatrix& lhs,
-                 const double* rhs,
-                 SparseCholesky* sparse_cholesky,
-                 double* solution);
+  //
+  // This method is virtual to facilitate mocking.
+  virtual Summary Refine(const SparseMatrix& lhs,
+                         const double* rhs,
+                         SparseCholesky* sparse_cholesky,
+                         double* solution);
 
  private:
   int num_cols_;
diff --git a/internal/ceres/sparse_cholesky.cc b/internal/ceres/sparse_cholesky.cc
index 1b5e638..3c4e97f 100644
--- a/internal/ceres/sparse_cholesky.cc
+++ b/internal/ceres/sparse_cholesky.cc
@@ -32,6 +32,7 @@
 
 #include "ceres/cxsparse.h"
 #include "ceres/eigensparse.h"
+#include "ceres/iterative_refiner.h"
 #include "ceres/suitesparse.h"
 
 namespace ceres {
@@ -96,5 +97,36 @@
   return CompressedRowSparseMatrix::LOWER_TRIANGULAR;
 }
 
+RefinedSparseCholesky::RefinedSparseCholesky(
+    std::unique_ptr<SparseCholesky> sparse_cholesky,
+    std::unique_ptr<IterativeRefiner> iterative_refiner)
+    : sparse_cholesky_(std::move(sparse_cholesky)),
+      iterative_refiner_(std::move(iterative_refiner)) {}
+
+RefinedSparseCholesky::~RefinedSparseCholesky() {}
+
+CompressedRowSparseMatrix::StorageType RefinedSparseCholesky::StorageType() const {
+  return sparse_cholesky_->StorageType();
+}
+
+LinearSolverTerminationType RefinedSparseCholesky::Factorize(
+    CompressedRowSparseMatrix* lhs, std::string* message) {
+  lhs_ = lhs;
+  return sparse_cholesky_->Factorize(lhs, message);
+}
+
+LinearSolverTerminationType RefinedSparseCholesky::Solve(const double* rhs,
+                                                         double* solution,
+                                                         std::string* message) {
+  CHECK(lhs_ != nullptr);
+  auto termination_type = sparse_cholesky_->Solve(rhs, solution, message);
+  if (termination_type != LINEAR_SOLVER_SUCCESS) {
+    return termination_type;
+  }
+
+  iterative_refiner_->Refine(*lhs_, rhs, sparse_cholesky_.get(), solution);
+  return LINEAR_SOLVER_SUCCESS;
+}
+
 }  // namespace internal
 }  // namespace ceres
diff --git a/internal/ceres/sparse_cholesky.h b/internal/ceres/sparse_cholesky.h
index 96d2bfa..c0e3e86 100644
--- a/internal/ceres/sparse_cholesky.h
+++ b/internal/ceres/sparse_cholesky.h
@@ -34,6 +34,7 @@
 // This include must come before any #ifndef check on Ceres compile options.
 #include "ceres/internal/port.h"
 
+#include <memory>
 #include "ceres/linear_solver.h"
 #include "glog/logging.h"
 
@@ -114,6 +115,29 @@
 
 };
 
+class IterativeRefiner;
+
+// Computes an initial solution using the given instance of
+// SparseCholesky, and then refines it using the IterativeRefiner.
+class RefinedSparseCholesky : public SparseCholesky {
+ public:
+  RefinedSparseCholesky(std::unique_ptr<SparseCholesky> sparse_cholesky,
+                        std::unique_ptr<IterativeRefiner> iterative_refiner);
+  virtual ~RefinedSparseCholesky();
+
+  virtual CompressedRowSparseMatrix::StorageType StorageType() const;
+  virtual LinearSolverTerminationType Factorize(
+      CompressedRowSparseMatrix* lhs, std::string* message);
+  virtual LinearSolverTerminationType Solve(const double* rhs,
+                                            double* solution,
+                                            std::string* message);
+
+ private:
+  std::unique_ptr<SparseCholesky> sparse_cholesky_;
+  std::unique_ptr<IterativeRefiner> iterative_refiner_;
+  CompressedRowSparseMatrix* lhs_ = nullptr;
+};
+
 }  // namespace internal
 }  // namespace ceres
 
diff --git a/internal/ceres/sparse_cholesky_test.cc b/internal/ceres/sparse_cholesky_test.cc
index b75e3aa..f0cd729 100644
--- a/internal/ceres/sparse_cholesky_test.cc
+++ b/internal/ceres/sparse_cholesky_test.cc
@@ -40,8 +40,10 @@
 #include "ceres/compressed_row_sparse_matrix.h"
 #include "ceres/inner_product_computer.h"
 #include "ceres/internal/eigen.h"
+#include "ceres/iterative_refiner.h"
 #include "ceres/random.h"
 #include "glog/logging.h"
+#include "gmock/gmock.h"
 #include "gtest/gtest.h"
 
 namespace ceres {
@@ -214,5 +216,111 @@
                         ParamInfoToString);
 #endif
 
+class MockSparseCholesky : public SparseCholesky {
+ public:
+  MOCK_CONST_METHOD0(StorageType, CompressedRowSparseMatrix::StorageType());
+  MOCK_METHOD2(Factorize,
+               LinearSolverTerminationType(CompressedRowSparseMatrix* lhs,
+                                           std::string* message));
+  MOCK_METHOD3(Solve,
+               LinearSolverTerminationType(const double* rhs,
+                                           double* solution,
+                                           std::string* message));
+};
+
+class MockIterativeRefiner : public IterativeRefiner {
+ public:
+  MockIterativeRefiner() : IterativeRefiner(1, 1) {}
+  MOCK_METHOD4(Refine,
+               Summary(const SparseMatrix& lhs,
+                       const double* rhs,
+                       SparseCholesky* sparse_cholesky,
+                       double* solution));
+};
+
+
+using testing::_;
+using testing::Return;
+
+TEST(RefinedSparseCholesky, StorageType) {
+  MockSparseCholesky* mock_sparse_cholesky = new MockSparseCholesky;
+  MockIterativeRefiner* mock_iterative_refiner = new MockIterativeRefiner;
+  EXPECT_CALL(*mock_sparse_cholesky, StorageType())
+      .Times(1)
+      .WillRepeatedly(Return(CompressedRowSparseMatrix::UPPER_TRIANGULAR));
+  EXPECT_CALL(*mock_iterative_refiner, Refine(_, _, _, _))
+      .Times(0);
+  std::unique_ptr<SparseCholesky> sparse_cholesky(mock_sparse_cholesky);
+  std::unique_ptr<IterativeRefiner> iterative_refiner(mock_iterative_refiner);
+  RefinedSparseCholesky refined_sparse_cholesky(std::move(sparse_cholesky),
+                                                std::move(iterative_refiner));
+  EXPECT_EQ(refined_sparse_cholesky.StorageType(),
+            CompressedRowSparseMatrix::UPPER_TRIANGULAR);
+};
+
+TEST(RefinedSparseCholesky, Factorize) {
+  MockSparseCholesky* mock_sparse_cholesky = new MockSparseCholesky;
+  MockIterativeRefiner* mock_iterative_refiner = new MockIterativeRefiner;
+  EXPECT_CALL(*mock_sparse_cholesky, Factorize(_, _))
+      .Times(1)
+      .WillRepeatedly(Return(LINEAR_SOLVER_SUCCESS));
+  EXPECT_CALL(*mock_iterative_refiner, Refine(_, _, _, _))
+      .Times(0);
+  std::unique_ptr<SparseCholesky> sparse_cholesky(mock_sparse_cholesky);
+  std::unique_ptr<IterativeRefiner> iterative_refiner(mock_iterative_refiner);
+  RefinedSparseCholesky refined_sparse_cholesky(std::move(sparse_cholesky),
+                                                std::move(iterative_refiner));
+  CompressedRowSparseMatrix m(1, 1, 1);
+  std::string message;
+  EXPECT_EQ(refined_sparse_cholesky.Factorize(&m, &message),
+            LINEAR_SOLVER_SUCCESS);
+};
+
+TEST(RefinedSparseCholesky, FactorAndSolveWithUnsuccessfulFactorization) {
+  MockSparseCholesky* mock_sparse_cholesky = new MockSparseCholesky;
+  MockIterativeRefiner* mock_iterative_refiner = new MockIterativeRefiner;
+  EXPECT_CALL(*mock_sparse_cholesky, Factorize(_, _))
+      .Times(1)
+      .WillRepeatedly(Return(LINEAR_SOLVER_FAILURE));
+  EXPECT_CALL(*mock_sparse_cholesky, Solve(_, _, _))
+      .Times(0);
+  EXPECT_CALL(*mock_iterative_refiner, Refine(_, _, _, _))
+      .Times(0);
+  std::unique_ptr<SparseCholesky> sparse_cholesky(mock_sparse_cholesky);
+  std::unique_ptr<IterativeRefiner> iterative_refiner(mock_iterative_refiner);
+  RefinedSparseCholesky refined_sparse_cholesky(std::move(sparse_cholesky),
+                                                std::move(iterative_refiner));
+  CompressedRowSparseMatrix m(1, 1, 1);
+  std::string message;
+  double rhs;
+  double solution;
+  EXPECT_EQ(refined_sparse_cholesky.FactorAndSolve(&m, &rhs, &solution, &message),
+            LINEAR_SOLVER_FAILURE);
+};
+
+TEST(RefinedSparseCholesky, FactorAndSolveWithSuccess) {
+  MockSparseCholesky* mock_sparse_cholesky = new MockSparseCholesky;
+  std::unique_ptr<MockIterativeRefiner> mock_iterative_refiner(new MockIterativeRefiner);
+  EXPECT_CALL(*mock_sparse_cholesky, Factorize(_, _))
+      .Times(1)
+      .WillRepeatedly(Return(LINEAR_SOLVER_SUCCESS));
+  EXPECT_CALL(*mock_sparse_cholesky, Solve(_, _, _))
+      .Times(1)
+      .WillRepeatedly(Return(LINEAR_SOLVER_SUCCESS));
+  EXPECT_CALL(*mock_iterative_refiner, Refine(_, _, _, _))
+      .Times(1);
+
+  std::unique_ptr<SparseCholesky> sparse_cholesky(mock_sparse_cholesky);
+  std::unique_ptr<IterativeRefiner> iterative_refiner(std::move(mock_iterative_refiner));
+  RefinedSparseCholesky refined_sparse_cholesky(std::move(sparse_cholesky),
+                                                std::move(iterative_refiner));
+  CompressedRowSparseMatrix m(1, 1, 1);
+  std::string message;
+  double rhs;
+  double solution;
+  EXPECT_EQ(refined_sparse_cholesky.FactorAndSolve(&m, &rhs, &solution, &message),
+            LINEAR_SOLVER_SUCCESS);
+};
+
 }  // namespace internal
 }  // namespace ceres