Add MixedSparseCholesky. A simple class that composes SparseCholesky with IterativeRefiner. Change-Id: I4a67b8ca33a604aaa7b6a4bf511dad9501815f5b
diff --git a/internal/ceres/iterative_refiner.cc b/internal/ceres/iterative_refiner.cc index 6a5a0c7..fff343d 100644 --- a/internal/ceres/iterative_refiner.cc +++ b/internal/ceres/iterative_refiner.cc
@@ -46,6 +46,8 @@ correction_(num_cols), lhs_x_solution_(num_cols) {} +IterativeRefiner::~IterativeRefiner() {} + IterativeRefiner::Summary IterativeRefiner::Refine( const SparseMatrix& lhs, const double* rhs_ptr,
diff --git a/internal/ceres/iterative_refiner.h b/internal/ceres/iterative_refiner.h index 471116c..d1efb5c 100644 --- a/internal/ceres/iterative_refiner.h +++ b/internal/ceres/iterative_refiner.h
@@ -84,6 +84,9 @@ // to perform. IterativeRefiner(int num_cols, int max_num_iterations); + // Needed for mocking. + virtual ~IterativeRefiner(); + // Given an initial estimate of the solution of lhs * x = rhs, use // iterative refinement to improve it. // @@ -92,10 +95,12 @@ // // solution is expected to contain a approximation to the solution // to lhs * x = rhs. It can be zero. - Summary Refine(const SparseMatrix& lhs, - const double* rhs, - SparseCholesky* sparse_cholesky, - double* solution); + // + // This method is virtual to facilitate mocking. + virtual Summary Refine(const SparseMatrix& lhs, + const double* rhs, + SparseCholesky* sparse_cholesky, + double* solution); private: int num_cols_;
diff --git a/internal/ceres/sparse_cholesky.cc b/internal/ceres/sparse_cholesky.cc index 1b5e638..3c4e97f 100644 --- a/internal/ceres/sparse_cholesky.cc +++ b/internal/ceres/sparse_cholesky.cc
@@ -32,6 +32,7 @@ #include "ceres/cxsparse.h" #include "ceres/eigensparse.h" +#include "ceres/iterative_refiner.h" #include "ceres/suitesparse.h" namespace ceres { @@ -96,5 +97,36 @@ return CompressedRowSparseMatrix::LOWER_TRIANGULAR; } +RefinedSparseCholesky::RefinedSparseCholesky( + std::unique_ptr<SparseCholesky> sparse_cholesky, + std::unique_ptr<IterativeRefiner> iterative_refiner) + : sparse_cholesky_(std::move(sparse_cholesky)), + iterative_refiner_(std::move(iterative_refiner)) {} + +RefinedSparseCholesky::~RefinedSparseCholesky() {} + +CompressedRowSparseMatrix::StorageType RefinedSparseCholesky::StorageType() const { + return sparse_cholesky_->StorageType(); +} + +LinearSolverTerminationType RefinedSparseCholesky::Factorize( + CompressedRowSparseMatrix* lhs, std::string* message) { + lhs_ = lhs; + return sparse_cholesky_->Factorize(lhs, message); +} + +LinearSolverTerminationType RefinedSparseCholesky::Solve(const double* rhs, + double* solution, + std::string* message) { + CHECK(lhs_ != nullptr); + auto termination_type = sparse_cholesky_->Solve(rhs, solution, message); + if (termination_type != LINEAR_SOLVER_SUCCESS) { + return termination_type; + } + + iterative_refiner_->Refine(*lhs_, rhs, sparse_cholesky_.get(), solution); + return LINEAR_SOLVER_SUCCESS; +} + } // namespace internal } // namespace ceres
diff --git a/internal/ceres/sparse_cholesky.h b/internal/ceres/sparse_cholesky.h index 96d2bfa..c0e3e86 100644 --- a/internal/ceres/sparse_cholesky.h +++ b/internal/ceres/sparse_cholesky.h
@@ -34,6 +34,7 @@ // This include must come before any #ifndef check on Ceres compile options. #include "ceres/internal/port.h" +#include <memory> #include "ceres/linear_solver.h" #include "glog/logging.h" @@ -114,6 +115,29 @@ }; +class IterativeRefiner; + +// Computes an initial solution using the given instance of +// SparseCholesky, and then refines it using the IterativeRefiner. +class RefinedSparseCholesky : public SparseCholesky { + public: + RefinedSparseCholesky(std::unique_ptr<SparseCholesky> sparse_cholesky, + std::unique_ptr<IterativeRefiner> iterative_refiner); + virtual ~RefinedSparseCholesky(); + + virtual CompressedRowSparseMatrix::StorageType StorageType() const; + virtual LinearSolverTerminationType Factorize( + CompressedRowSparseMatrix* lhs, std::string* message); + virtual LinearSolverTerminationType Solve(const double* rhs, + double* solution, + std::string* message); + + private: + std::unique_ptr<SparseCholesky> sparse_cholesky_; + std::unique_ptr<IterativeRefiner> iterative_refiner_; + CompressedRowSparseMatrix* lhs_ = nullptr; +}; + } // namespace internal } // namespace ceres
diff --git a/internal/ceres/sparse_cholesky_test.cc b/internal/ceres/sparse_cholesky_test.cc index b75e3aa..f0cd729 100644 --- a/internal/ceres/sparse_cholesky_test.cc +++ b/internal/ceres/sparse_cholesky_test.cc
@@ -40,8 +40,10 @@ #include "ceres/compressed_row_sparse_matrix.h" #include "ceres/inner_product_computer.h" #include "ceres/internal/eigen.h" +#include "ceres/iterative_refiner.h" #include "ceres/random.h" #include "glog/logging.h" +#include "gmock/gmock.h" #include "gtest/gtest.h" namespace ceres { @@ -214,5 +216,111 @@ ParamInfoToString); #endif +class MockSparseCholesky : public SparseCholesky { + public: + MOCK_CONST_METHOD0(StorageType, CompressedRowSparseMatrix::StorageType()); + MOCK_METHOD2(Factorize, + LinearSolverTerminationType(CompressedRowSparseMatrix* lhs, + std::string* message)); + MOCK_METHOD3(Solve, + LinearSolverTerminationType(const double* rhs, + double* solution, + std::string* message)); +}; + +class MockIterativeRefiner : public IterativeRefiner { + public: + MockIterativeRefiner() : IterativeRefiner(1, 1) {} + MOCK_METHOD4(Refine, + Summary(const SparseMatrix& lhs, + const double* rhs, + SparseCholesky* sparse_cholesky, + double* solution)); +}; + + +using testing::_; +using testing::Return; + +TEST(RefinedSparseCholesky, StorageType) { + MockSparseCholesky* mock_sparse_cholesky = new MockSparseCholesky; + MockIterativeRefiner* mock_iterative_refiner = new MockIterativeRefiner; + EXPECT_CALL(*mock_sparse_cholesky, StorageType()) + .Times(1) + .WillRepeatedly(Return(CompressedRowSparseMatrix::UPPER_TRIANGULAR)); + EXPECT_CALL(*mock_iterative_refiner, Refine(_, _, _, _)) + .Times(0); + std::unique_ptr<SparseCholesky> sparse_cholesky(mock_sparse_cholesky); + std::unique_ptr<IterativeRefiner> iterative_refiner(mock_iterative_refiner); + RefinedSparseCholesky refined_sparse_cholesky(std::move(sparse_cholesky), + std::move(iterative_refiner)); + EXPECT_EQ(refined_sparse_cholesky.StorageType(), + CompressedRowSparseMatrix::UPPER_TRIANGULAR); +}; + +TEST(RefinedSparseCholesky, Factorize) { + MockSparseCholesky* mock_sparse_cholesky = new MockSparseCholesky; + MockIterativeRefiner* mock_iterative_refiner = new MockIterativeRefiner; + EXPECT_CALL(*mock_sparse_cholesky, Factorize(_, _)) + .Times(1) + .WillRepeatedly(Return(LINEAR_SOLVER_SUCCESS)); + EXPECT_CALL(*mock_iterative_refiner, Refine(_, _, _, _)) + .Times(0); + std::unique_ptr<SparseCholesky> sparse_cholesky(mock_sparse_cholesky); + std::unique_ptr<IterativeRefiner> iterative_refiner(mock_iterative_refiner); + RefinedSparseCholesky refined_sparse_cholesky(std::move(sparse_cholesky), + std::move(iterative_refiner)); + CompressedRowSparseMatrix m(1, 1, 1); + std::string message; + EXPECT_EQ(refined_sparse_cholesky.Factorize(&m, &message), + LINEAR_SOLVER_SUCCESS); +}; + +TEST(RefinedSparseCholesky, FactorAndSolveWithUnsuccessfulFactorization) { + MockSparseCholesky* mock_sparse_cholesky = new MockSparseCholesky; + MockIterativeRefiner* mock_iterative_refiner = new MockIterativeRefiner; + EXPECT_CALL(*mock_sparse_cholesky, Factorize(_, _)) + .Times(1) + .WillRepeatedly(Return(LINEAR_SOLVER_FAILURE)); + EXPECT_CALL(*mock_sparse_cholesky, Solve(_, _, _)) + .Times(0); + EXPECT_CALL(*mock_iterative_refiner, Refine(_, _, _, _)) + .Times(0); + std::unique_ptr<SparseCholesky> sparse_cholesky(mock_sparse_cholesky); + std::unique_ptr<IterativeRefiner> iterative_refiner(mock_iterative_refiner); + RefinedSparseCholesky refined_sparse_cholesky(std::move(sparse_cholesky), + std::move(iterative_refiner)); + CompressedRowSparseMatrix m(1, 1, 1); + std::string message; + double rhs; + double solution; + EXPECT_EQ(refined_sparse_cholesky.FactorAndSolve(&m, &rhs, &solution, &message), + LINEAR_SOLVER_FAILURE); +}; + +TEST(RefinedSparseCholesky, FactorAndSolveWithSuccess) { + MockSparseCholesky* mock_sparse_cholesky = new MockSparseCholesky; + std::unique_ptr<MockIterativeRefiner> mock_iterative_refiner(new MockIterativeRefiner); + EXPECT_CALL(*mock_sparse_cholesky, Factorize(_, _)) + .Times(1) + .WillRepeatedly(Return(LINEAR_SOLVER_SUCCESS)); + EXPECT_CALL(*mock_sparse_cholesky, Solve(_, _, _)) + .Times(1) + .WillRepeatedly(Return(LINEAR_SOLVER_SUCCESS)); + EXPECT_CALL(*mock_iterative_refiner, Refine(_, _, _, _)) + .Times(1); + + std::unique_ptr<SparseCholesky> sparse_cholesky(mock_sparse_cholesky); + std::unique_ptr<IterativeRefiner> iterative_refiner(std::move(mock_iterative_refiner)); + RefinedSparseCholesky refined_sparse_cholesky(std::move(sparse_cholesky), + std::move(iterative_refiner)); + CompressedRowSparseMatrix m(1, 1, 1); + std::string message; + double rhs; + double solution; + EXPECT_EQ(refined_sparse_cholesky.FactorAndSolve(&m, &rhs, &solution, &message), + LINEAR_SOLVER_SUCCESS); +}; + } // namespace internal } // namespace ceres