Add MixedSparseCholesky.
A simple class that composes SparseCholesky with IterativeRefiner.
Change-Id: I4a67b8ca33a604aaa7b6a4bf511dad9501815f5b
diff --git a/internal/ceres/iterative_refiner.cc b/internal/ceres/iterative_refiner.cc
index 6a5a0c7..fff343d 100644
--- a/internal/ceres/iterative_refiner.cc
+++ b/internal/ceres/iterative_refiner.cc
@@ -46,6 +46,8 @@
correction_(num_cols),
lhs_x_solution_(num_cols) {}
+IterativeRefiner::~IterativeRefiner() {}
+
IterativeRefiner::Summary IterativeRefiner::Refine(
const SparseMatrix& lhs,
const double* rhs_ptr,
diff --git a/internal/ceres/iterative_refiner.h b/internal/ceres/iterative_refiner.h
index 471116c..d1efb5c 100644
--- a/internal/ceres/iterative_refiner.h
+++ b/internal/ceres/iterative_refiner.h
@@ -84,6 +84,9 @@
// to perform.
IterativeRefiner(int num_cols, int max_num_iterations);
+ // Needed for mocking.
+ virtual ~IterativeRefiner();
+
// Given an initial estimate of the solution of lhs * x = rhs, use
// iterative refinement to improve it.
//
@@ -92,10 +95,12 @@
//
// solution is expected to contain a approximation to the solution
// to lhs * x = rhs. It can be zero.
- Summary Refine(const SparseMatrix& lhs,
- const double* rhs,
- SparseCholesky* sparse_cholesky,
- double* solution);
+ //
+ // This method is virtual to facilitate mocking.
+ virtual Summary Refine(const SparseMatrix& lhs,
+ const double* rhs,
+ SparseCholesky* sparse_cholesky,
+ double* solution);
private:
int num_cols_;
diff --git a/internal/ceres/sparse_cholesky.cc b/internal/ceres/sparse_cholesky.cc
index 1b5e638..3c4e97f 100644
--- a/internal/ceres/sparse_cholesky.cc
+++ b/internal/ceres/sparse_cholesky.cc
@@ -32,6 +32,7 @@
#include "ceres/cxsparse.h"
#include "ceres/eigensparse.h"
+#include "ceres/iterative_refiner.h"
#include "ceres/suitesparse.h"
namespace ceres {
@@ -96,5 +97,36 @@
return CompressedRowSparseMatrix::LOWER_TRIANGULAR;
}
+RefinedSparseCholesky::RefinedSparseCholesky(
+ std::unique_ptr<SparseCholesky> sparse_cholesky,
+ std::unique_ptr<IterativeRefiner> iterative_refiner)
+ : sparse_cholesky_(std::move(sparse_cholesky)),
+ iterative_refiner_(std::move(iterative_refiner)) {}
+
+RefinedSparseCholesky::~RefinedSparseCholesky() {}
+
+CompressedRowSparseMatrix::StorageType RefinedSparseCholesky::StorageType() const {
+ return sparse_cholesky_->StorageType();
+}
+
+LinearSolverTerminationType RefinedSparseCholesky::Factorize(
+ CompressedRowSparseMatrix* lhs, std::string* message) {
+ lhs_ = lhs;
+ return sparse_cholesky_->Factorize(lhs, message);
+}
+
+LinearSolverTerminationType RefinedSparseCholesky::Solve(const double* rhs,
+ double* solution,
+ std::string* message) {
+ CHECK(lhs_ != nullptr);
+ auto termination_type = sparse_cholesky_->Solve(rhs, solution, message);
+ if (termination_type != LINEAR_SOLVER_SUCCESS) {
+ return termination_type;
+ }
+
+ iterative_refiner_->Refine(*lhs_, rhs, sparse_cholesky_.get(), solution);
+ return LINEAR_SOLVER_SUCCESS;
+}
+
} // namespace internal
} // namespace ceres
diff --git a/internal/ceres/sparse_cholesky.h b/internal/ceres/sparse_cholesky.h
index 96d2bfa..c0e3e86 100644
--- a/internal/ceres/sparse_cholesky.h
+++ b/internal/ceres/sparse_cholesky.h
@@ -34,6 +34,7 @@
// This include must come before any #ifndef check on Ceres compile options.
#include "ceres/internal/port.h"
+#include <memory>
#include "ceres/linear_solver.h"
#include "glog/logging.h"
@@ -114,6 +115,29 @@
};
+class IterativeRefiner;
+
+// Computes an initial solution using the given instance of
+// SparseCholesky, and then refines it using the IterativeRefiner.
+class RefinedSparseCholesky : public SparseCholesky {
+ public:
+ RefinedSparseCholesky(std::unique_ptr<SparseCholesky> sparse_cholesky,
+ std::unique_ptr<IterativeRefiner> iterative_refiner);
+ virtual ~RefinedSparseCholesky();
+
+ virtual CompressedRowSparseMatrix::StorageType StorageType() const;
+ virtual LinearSolverTerminationType Factorize(
+ CompressedRowSparseMatrix* lhs, std::string* message);
+ virtual LinearSolverTerminationType Solve(const double* rhs,
+ double* solution,
+ std::string* message);
+
+ private:
+ std::unique_ptr<SparseCholesky> sparse_cholesky_;
+ std::unique_ptr<IterativeRefiner> iterative_refiner_;
+ CompressedRowSparseMatrix* lhs_ = nullptr;
+};
+
} // namespace internal
} // namespace ceres
diff --git a/internal/ceres/sparse_cholesky_test.cc b/internal/ceres/sparse_cholesky_test.cc
index b75e3aa..f0cd729 100644
--- a/internal/ceres/sparse_cholesky_test.cc
+++ b/internal/ceres/sparse_cholesky_test.cc
@@ -40,8 +40,10 @@
#include "ceres/compressed_row_sparse_matrix.h"
#include "ceres/inner_product_computer.h"
#include "ceres/internal/eigen.h"
+#include "ceres/iterative_refiner.h"
#include "ceres/random.h"
#include "glog/logging.h"
+#include "gmock/gmock.h"
#include "gtest/gtest.h"
namespace ceres {
@@ -214,5 +216,111 @@
ParamInfoToString);
#endif
+class MockSparseCholesky : public SparseCholesky {
+ public:
+ MOCK_CONST_METHOD0(StorageType, CompressedRowSparseMatrix::StorageType());
+ MOCK_METHOD2(Factorize,
+ LinearSolverTerminationType(CompressedRowSparseMatrix* lhs,
+ std::string* message));
+ MOCK_METHOD3(Solve,
+ LinearSolverTerminationType(const double* rhs,
+ double* solution,
+ std::string* message));
+};
+
+class MockIterativeRefiner : public IterativeRefiner {
+ public:
+ MockIterativeRefiner() : IterativeRefiner(1, 1) {}
+ MOCK_METHOD4(Refine,
+ Summary(const SparseMatrix& lhs,
+ const double* rhs,
+ SparseCholesky* sparse_cholesky,
+ double* solution));
+};
+
+
+using testing::_;
+using testing::Return;
+
+TEST(RefinedSparseCholesky, StorageType) {
+ MockSparseCholesky* mock_sparse_cholesky = new MockSparseCholesky;
+ MockIterativeRefiner* mock_iterative_refiner = new MockIterativeRefiner;
+ EXPECT_CALL(*mock_sparse_cholesky, StorageType())
+ .Times(1)
+ .WillRepeatedly(Return(CompressedRowSparseMatrix::UPPER_TRIANGULAR));
+ EXPECT_CALL(*mock_iterative_refiner, Refine(_, _, _, _))
+ .Times(0);
+ std::unique_ptr<SparseCholesky> sparse_cholesky(mock_sparse_cholesky);
+ std::unique_ptr<IterativeRefiner> iterative_refiner(mock_iterative_refiner);
+ RefinedSparseCholesky refined_sparse_cholesky(std::move(sparse_cholesky),
+ std::move(iterative_refiner));
+ EXPECT_EQ(refined_sparse_cholesky.StorageType(),
+ CompressedRowSparseMatrix::UPPER_TRIANGULAR);
+};
+
+TEST(RefinedSparseCholesky, Factorize) {
+ MockSparseCholesky* mock_sparse_cholesky = new MockSparseCholesky;
+ MockIterativeRefiner* mock_iterative_refiner = new MockIterativeRefiner;
+ EXPECT_CALL(*mock_sparse_cholesky, Factorize(_, _))
+ .Times(1)
+ .WillRepeatedly(Return(LINEAR_SOLVER_SUCCESS));
+ EXPECT_CALL(*mock_iterative_refiner, Refine(_, _, _, _))
+ .Times(0);
+ std::unique_ptr<SparseCholesky> sparse_cholesky(mock_sparse_cholesky);
+ std::unique_ptr<IterativeRefiner> iterative_refiner(mock_iterative_refiner);
+ RefinedSparseCholesky refined_sparse_cholesky(std::move(sparse_cholesky),
+ std::move(iterative_refiner));
+ CompressedRowSparseMatrix m(1, 1, 1);
+ std::string message;
+ EXPECT_EQ(refined_sparse_cholesky.Factorize(&m, &message),
+ LINEAR_SOLVER_SUCCESS);
+};
+
+TEST(RefinedSparseCholesky, FactorAndSolveWithUnsuccessfulFactorization) {
+ MockSparseCholesky* mock_sparse_cholesky = new MockSparseCholesky;
+ MockIterativeRefiner* mock_iterative_refiner = new MockIterativeRefiner;
+ EXPECT_CALL(*mock_sparse_cholesky, Factorize(_, _))
+ .Times(1)
+ .WillRepeatedly(Return(LINEAR_SOLVER_FAILURE));
+ EXPECT_CALL(*mock_sparse_cholesky, Solve(_, _, _))
+ .Times(0);
+ EXPECT_CALL(*mock_iterative_refiner, Refine(_, _, _, _))
+ .Times(0);
+ std::unique_ptr<SparseCholesky> sparse_cholesky(mock_sparse_cholesky);
+ std::unique_ptr<IterativeRefiner> iterative_refiner(mock_iterative_refiner);
+ RefinedSparseCholesky refined_sparse_cholesky(std::move(sparse_cholesky),
+ std::move(iterative_refiner));
+ CompressedRowSparseMatrix m(1, 1, 1);
+ std::string message;
+ double rhs;
+ double solution;
+ EXPECT_EQ(refined_sparse_cholesky.FactorAndSolve(&m, &rhs, &solution, &message),
+ LINEAR_SOLVER_FAILURE);
+};
+
+TEST(RefinedSparseCholesky, FactorAndSolveWithSuccess) {
+ MockSparseCholesky* mock_sparse_cholesky = new MockSparseCholesky;
+ std::unique_ptr<MockIterativeRefiner> mock_iterative_refiner(new MockIterativeRefiner);
+ EXPECT_CALL(*mock_sparse_cholesky, Factorize(_, _))
+ .Times(1)
+ .WillRepeatedly(Return(LINEAR_SOLVER_SUCCESS));
+ EXPECT_CALL(*mock_sparse_cholesky, Solve(_, _, _))
+ .Times(1)
+ .WillRepeatedly(Return(LINEAR_SOLVER_SUCCESS));
+ EXPECT_CALL(*mock_iterative_refiner, Refine(_, _, _, _))
+ .Times(1);
+
+ std::unique_ptr<SparseCholesky> sparse_cholesky(mock_sparse_cholesky);
+ std::unique_ptr<IterativeRefiner> iterative_refiner(std::move(mock_iterative_refiner));
+ RefinedSparseCholesky refined_sparse_cholesky(std::move(sparse_cholesky),
+ std::move(iterative_refiner));
+ CompressedRowSparseMatrix m(1, 1, 1);
+ std::string message;
+ double rhs;
+ double solution;
+ EXPECT_EQ(refined_sparse_cholesky.FactorAndSolve(&m, &rhs, &solution, &message),
+ LINEAR_SOLVER_SUCCESS);
+};
+
} // namespace internal
} // namespace ceres