Add mixed precision support for CPU based DenseCholesky

On problem-744-543562-pre.txt

The time spent in linear solver on my M1 Pro is

eigen        81.550970
eigen+mixed  54.107383
LAPACK       47.078127
LAPACK+mixed 28.639868

Solution quality is unaffected.

The implementation of RefinedDenseCholesky and DenseIterativeRefiner
are straightforward ports of RefinedSparseCholesky and
SparseIterativeRefiner (formerly IterativeRefiner).

It maybe possible to refactor the SparseCholesky and DenseCholesky
interfaces so that this code duplication can be removed in the
future.

Change-Id: I921334224cb97629a60390f2add822de207f7923
diff --git a/internal/ceres/cuda_buffer.h b/internal/ceres/cuda_buffer.h
index 8868e1a..64774fa 100644
--- a/internal/ceres/cuda_buffer.h
+++ b/internal/ceres/cuda_buffer.h
@@ -84,9 +84,10 @@
   // Perform an asynchronous copy from GPU memory using the stream provided.
   void CopyFromGpuAsync(const T* data, const size_t size, cudaStream_t stream) {
     Reserve(size);
-    CHECK_EQ(cudaMemcpyAsync(
-        data_, data, size * sizeof(T), cudaMemcpyDeviceToDevice, stream),
-            cudaSuccess);
+    CHECK_EQ(
+        cudaMemcpyAsync(
+            data_, data, size * sizeof(T), cudaMemcpyDeviceToDevice, stream),
+        cudaSuccess);
   }
 
   // Copy data from the GPU to CPU memory. This is necessarily synchronous since
diff --git a/internal/ceres/cuda_dense_cholesky_test.cc b/internal/ceres/cuda_dense_cholesky_test.cc
index b9acc99..c7b11ce 100644
--- a/internal/ceres/cuda_dense_cholesky_test.cc
+++ b/internal/ceres/cuda_dense_cholesky_test.cc
@@ -258,8 +258,6 @@
   EXPECT_NEAR(x(3), 1.0000, kEpsilon);
 }
 
-
-
 TEST(CUDADenseCholeskyMixedPrecision, Randomized1600x1600Tests) {
   const int kNumCols = 1600;
   using LhsType = Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>;
diff --git a/internal/ceres/dense_cholesky.cc b/internal/ceres/dense_cholesky.cc
index e2c5036..16d3e1a 100644
--- a/internal/ceres/dense_cholesky.cc
+++ b/internal/ceres/dense_cholesky.cc
@@ -36,6 +36,7 @@
 #include <vector>
 
 #include "ceres/internal/config.h"
+#include "ceres/iterative_refiner.h"
 
 #ifndef CERES_NO_CUDA
 #include "ceres/ceres_cuda_kernels.h"
@@ -58,6 +59,18 @@
                         double* b,
                         const int* ldb,
                         int* info);
+
+extern "C" void spotrf_(
+    const char* uplo, const int* n, float* a, const int* lda, int* info);
+
+extern "C" void spotrs_(const char* uplo,
+                        const int* n,
+                        const int* nrhs,
+                        const float* a,
+                        const int* lda,
+                        float* b,
+                        const int* ldb,
+                        int* info);
 #endif
 
 namespace ceres::internal {
@@ -69,17 +82,23 @@
   std::unique_ptr<DenseCholesky> dense_cholesky;
 
   switch (options.dense_linear_algebra_library_type) {
-    case EIGEN: {
+    case EIGEN:
       // Eigen mixed precision solver not yet implemented.
-      if (options.use_mixed_precision_solves) return nullptr;
-      dense_cholesky = std::make_unique<EigenDenseCholesky>();
-    } break;
+      if (options.use_mixed_precision_solves) {
+        dense_cholesky = std::make_unique<FloatEigenDenseCholesky>();
+      } else {
+        dense_cholesky = std::make_unique<EigenDenseCholesky>();
+      }
+      break;
 
     case LAPACK:
 #ifndef CERES_NO_LAPACK
       // LAPACK mixed precision solver not yet implemented.
-      if (options.use_mixed_precision_solves) return nullptr;
-      dense_cholesky = std::make_unique<LAPACKDenseCholesky>();
+      if (options.use_mixed_precision_solves) {
+        dense_cholesky = std::make_unique<FloatLAPACKDenseCholesky>();
+      } else {
+        dense_cholesky = std::make_unique<LAPACKDenseCholesky>();
+      }
       break;
 #else
       LOG(FATAL) << "Ceres was compiled without support for LAPACK.";
@@ -102,6 +121,14 @@
                  << DenseLinearAlgebraLibraryTypeToString(
                         options.dense_linear_algebra_library_type);
   }
+
+  if (options.max_num_refinement_iterations > 0) {
+    auto refiner = std::make_unique<DenseIterativeRefiner>(
+        options.max_num_refinement_iterations);
+    dense_cholesky = std::make_unique<RefinedDenseCholesky>(
+        std::move(dense_cholesky), std::move(refiner));
+  }
+
   return dense_cholesky;
 }
 
@@ -146,6 +173,34 @@
   return LinearSolverTerminationType::SUCCESS;
 }
 
+LinearSolverTerminationType FloatEigenDenseCholesky::Factorize(
+    int num_cols, double* lhs, std::string* message) {
+  // TODO(sameeragarwal): Check if this causes a double allocation.
+  lhs_ = Eigen::Map<Eigen::MatrixXd>(lhs, num_cols, num_cols).cast<float>();
+  llt_ = std::make_unique<LLTType>(lhs_);
+  if (llt_->info() != Eigen::Success) {
+    *message = "Eigen failure. Unable to perform dense Cholesky factorization.";
+    return LinearSolverTerminationType::FAILURE;
+  }
+
+  *message = "Success.";
+  return LinearSolverTerminationType::SUCCESS;
+}
+
+LinearSolverTerminationType FloatEigenDenseCholesky::Solve(
+    const double* rhs, double* solution, std::string* message) {
+  if (llt_->info() != Eigen::Success) {
+    *message = "Eigen failure. Unable to perform dense Cholesky factorization.";
+    return LinearSolverTerminationType::FAILURE;
+  }
+
+  rhs_ = ConstVectorRef(rhs, llt_->cols()).cast<float>();
+  solution_ = llt_->solve(rhs_);
+  VectorRef(solution, llt_->cols()) = solution_.cast<double>();
+  *message = "Success.";
+  return LinearSolverTerminationType::SUCCESS;
+}
+
 #ifndef CERES_NO_LAPACK
 LinearSolverTerminationType LAPACKDenseCholesky::Factorize(
     int num_cols, double* lhs, std::string* message) {
@@ -182,7 +237,7 @@
   const int nrhs = 1;
   int info = 0;
 
-  std::copy_n(rhs, num_cols_, solution);
+  VectorRef(solution, num_cols_) = ConstVectorRef(rhs, num_cols_);
   dpotrs_(
       &uplo, &num_cols_, &nrhs, lhs_, &num_cols_, solution, &num_cols_, &info);
 
@@ -200,8 +255,95 @@
   return termination_type_;
 }
 
+LinearSolverTerminationType FloatLAPACKDenseCholesky::Factorize(
+    int num_cols, double* lhs, std::string* message) {
+  num_cols_ = num_cols;
+  lhs_ = Eigen::Map<Eigen::MatrixXd>(lhs, num_cols, num_cols).cast<float>();
+
+  const char uplo = 'L';
+  int info = 0;
+  spotrf_(&uplo, &num_cols_, lhs_.data(), &num_cols_, &info);
+
+  if (info < 0) {
+    termination_type_ = LinearSolverTerminationType::FATAL_ERROR;
+    LOG(FATAL) << "Congratulations, you found a bug in Ceres. "
+               << "Please report it. "
+               << "LAPACK::spotrf fatal error. "
+               << "Argument: " << -info << " is invalid.";
+  } else if (info > 0) {
+    termination_type_ = LinearSolverTerminationType::FAILURE;
+    *message = StringPrintf(
+        "LAPACK::spotrf numerical failure. "
+        "The leading minor of order %d is not positive definite.",
+        info);
+  } else {
+    termination_type_ = LinearSolverTerminationType::SUCCESS;
+    *message = "Success.";
+  }
+  return termination_type_;
+}
+
+LinearSolverTerminationType FloatLAPACKDenseCholesky::Solve(
+    const double* rhs, double* solution, std::string* message) {
+  const char uplo = 'L';
+  const int nrhs = 1;
+  int info = 0;
+  rhs_and_solution_ = ConstVectorRef(rhs, num_cols_).cast<float>();
+  spotrs_(&uplo,
+          &num_cols_,
+          &nrhs,
+          lhs_.data(),
+          &num_cols_,
+          rhs_and_solution_.data(),
+          &num_cols_,
+          &info);
+
+  if (info < 0) {
+    termination_type_ = LinearSolverTerminationType::FATAL_ERROR;
+    LOG(FATAL) << "Congratulations, you found a bug in Ceres. "
+               << "Please report it. "
+               << "LAPACK::dpotrs fatal error. "
+               << "Argument: " << -info << " is invalid.";
+  }
+
+  *message = "Success";
+  termination_type_ = LinearSolverTerminationType::SUCCESS;
+  VectorRef(solution, num_cols_) =
+      rhs_and_solution_.head(num_cols_).cast<double>();
+  return termination_type_;
+}
+
 #endif  // CERES_NO_LAPACK
 
+RefinedDenseCholesky::RefinedDenseCholesky(
+    std::unique_ptr<DenseCholesky> dense_cholesky,
+    std::unique_ptr<DenseIterativeRefiner> iterative_refiner)
+    : dense_cholesky_(std::move(dense_cholesky)),
+      iterative_refiner_(std::move(iterative_refiner)) {}
+
+RefinedDenseCholesky::~RefinedDenseCholesky() = default;
+
+LinearSolverTerminationType RefinedDenseCholesky::Factorize(
+    const int num_cols, double* lhs, std::string* message) {
+  lhs_ = lhs;
+  num_cols_ = num_cols;
+  return dense_cholesky_->Factorize(num_cols, lhs, message);
+}
+
+LinearSolverTerminationType RefinedDenseCholesky::Solve(const double* rhs,
+                                                        double* solution,
+                                                        std::string* message) {
+  CHECK(lhs_ != nullptr);
+  auto termination_type = dense_cholesky_->Solve(rhs, solution, message);
+  if (termination_type != LinearSolverTerminationType::SUCCESS) {
+    return termination_type;
+  }
+
+  iterative_refiner_->Refine(
+      num_cols_, lhs_, rhs, dense_cholesky_.get(), solution);
+  return LinearSolverTerminationType::SUCCESS;
+}
+
 #ifndef CERES_NO_CUDA
 
 bool CUDADenseCholesky::Init(ContextImpl* context, std::string* message) {
@@ -323,15 +465,14 @@
   if (cuda_dense_cholesky->Init(options.context, &cuda_error)) {
     return cuda_dense_cholesky;
   }
-  // Initialization failed, destroy the object (done automatically) and return a
-  // nullptr.
+  // Initialization failed, destroy the object (done automatically) and return
+  // a nullptr.
   LOG(ERROR) << "CUDADenseCholesky::Init failed: " << cuda_error;
   return nullptr;
 }
 
 std::unique_ptr<CUDADenseCholeskyMixedPrecision>
-    CUDADenseCholeskyMixedPrecision::Create(
-    const LinearSolver::Options& options) {
+CUDADenseCholeskyMixedPrecision::Create(const LinearSolver::Options& options) {
   if (options.dense_linear_algebra_library_type != CUDA ||
       !options.use_mixed_precision_solves) {
     // The user called the wrong factory method.
@@ -347,8 +488,8 @@
   return nullptr;
 }
 
-bool CUDADenseCholeskyMixedPrecision::Init(
-    const LinearSolver::Options& options, std::string* message) {
+bool CUDADenseCholeskyMixedPrecision::Init(const LinearSolver::Options& options,
+                                           std::string* message) {
   if (!options.context->InitCUDA(message)) {
     return false;
   }
@@ -452,9 +593,7 @@
 }
 
 LinearSolverTerminationType CUDADenseCholeskyMixedPrecision::Factorize(
-    int num_cols,
-    double* lhs,
-    std::string* message) {
+    int num_cols, double* lhs, std::string* message) {
   num_cols_ = num_cols;
 
   // Copy fp64 version of lhs to GPU.
@@ -472,9 +611,7 @@
 }
 
 LinearSolverTerminationType CUDADenseCholeskyMixedPrecision::Solve(
-    const double* rhs,
-    double* solution,
-    std::string* message) {
+    const double* rhs, double* solution, std::string* message) {
   // If factorization failed, return failure.
   if (factorize_result_ != LinearSolverTerminationType::SUCCESS) {
     *message = "Factorize did not complete successfully previously.";
@@ -505,8 +642,7 @@
       return result;
     }
     // [fp64] x += c.
-    CudaDsxpy(
-        x_fp64_.data(), correction_fp32_.data(), num_cols_, stream_);
+    CudaDsxpy(x_fp64_.data(), correction_fp32_.data(), num_cols_, stream_);
     if (i < max_num_refinement_iterations_) {
       // [fp64] residual = rhs - lhs * x
       // This is done in two steps:
diff --git a/internal/ceres/dense_cholesky.h b/internal/ceres/dense_cholesky.h
index 0593875..1c561c1 100644
--- a/internal/ceres/dense_cholesky.h
+++ b/internal/ceres/dense_cholesky.h
@@ -114,6 +114,23 @@
   std::unique_ptr<LLTType> llt_;
 };
 
+class CERES_NO_EXPORT FloatEigenDenseCholesky final : public DenseCholesky {
+ public:
+  LinearSolverTerminationType Factorize(int num_cols,
+                                        double* lhs,
+                                        std::string* message) override;
+  LinearSolverTerminationType Solve(const double* rhs,
+                                    double* solution,
+                                    std::string* message) override;
+
+ private:
+  Eigen::MatrixXf lhs_;
+  Eigen::VectorXf rhs_;
+  Eigen::VectorXf solution_;
+  using LLTType = Eigen::LLT<Eigen::MatrixXf, Eigen::Lower>;
+  std::unique_ptr<LLTType> llt_;
+};
+
 #ifndef CERES_NO_LAPACK
 class CERES_NO_EXPORT LAPACKDenseCholesky final : public DenseCholesky {
  public:
@@ -130,8 +147,50 @@
   LinearSolverTerminationType termination_type_ =
       LinearSolverTerminationType::FATAL_ERROR;
 };
+
+class CERES_NO_EXPORT FloatLAPACKDenseCholesky final : public DenseCholesky {
+ public:
+  LinearSolverTerminationType Factorize(int num_cols,
+                                        double* lhs,
+                                        std::string* message) override;
+  LinearSolverTerminationType Solve(const double* rhs,
+                                    double* solution,
+                                    std::string* message) override;
+
+ private:
+  Eigen::MatrixXf lhs_;
+  Eigen::VectorXf rhs_and_solution_;
+  int num_cols_ = -1;
+  LinearSolverTerminationType termination_type_ =
+      LinearSolverTerminationType::FATAL_ERROR;
+};
 #endif  // CERES_NO_LAPACK
 
+class DenseIterativeRefiner;
+
+// Computes an initial solution using the given instance of
+// DenseCholesky, and then refines it using the DenseIterativeRefiner.
+class CERES_NO_EXPORT RefinedDenseCholesky final : public DenseCholesky {
+ public:
+  RefinedDenseCholesky(
+      std::unique_ptr<DenseCholesky> dense_cholesky,
+      std::unique_ptr<DenseIterativeRefiner> iterative_refiner);
+  ~RefinedDenseCholesky() override;
+
+  LinearSolverTerminationType Factorize(int num_cols,
+                                        double* lhs,
+                                        std::string* message) override;
+  LinearSolverTerminationType Solve(const double* rhs,
+                                    double* solution,
+                                    std::string* message) override;
+
+ private:
+  std::unique_ptr<DenseCholesky> dense_cholesky_;
+  std::unique_ptr<DenseIterativeRefiner> iterative_refiner_;
+  double* lhs_ = nullptr;
+  int num_cols_;
+};
+
 #ifndef CERES_NO_CUDA
 // CUDA implementation of DenseCholesky using the cuSolverDN library using the
 // 32-bit legacy interface for maximum compatibility.
@@ -190,13 +249,13 @@
 //    symmetric positive definite.
 // 2. During the solution update, the up-cast and accumulation is performed in
 //    one step with a custom kernel.
-class CERES_NO_EXPORT CUDADenseCholeskyMixedPrecision final :
-    public DenseCholesky {
+class CERES_NO_EXPORT CUDADenseCholeskyMixedPrecision final
+    : public DenseCholesky {
  public:
   static std::unique_ptr<CUDADenseCholeskyMixedPrecision> Create(
       const LinearSolver::Options& options);
-  CUDADenseCholeskyMixedPrecision(
-      const CUDADenseCholeskyMixedPrecision&) = delete;
+  CUDADenseCholeskyMixedPrecision(const CUDADenseCholeskyMixedPrecision&) =
+      delete;
   CUDADenseCholeskyMixedPrecision& operator=(
       const CUDADenseCholeskyMixedPrecision&) = delete;
   LinearSolverTerminationType Factorize(int num_cols,
diff --git a/internal/ceres/dense_cholesky_test.cc b/internal/ceres/dense_cholesky_test.cc
index 5f96939..f8e6567 100644
--- a/internal/ceres/dense_cholesky_test.cc
+++ b/internal/ceres/dense_cholesky_test.cc
@@ -39,6 +39,7 @@
 #include "Eigen/Dense"
 #include "ceres/internal/config.h"
 #include "ceres/internal/eigen.h"
+#include "ceres/iterative_refiner.h"
 #include "ceres/linear_solver.h"
 #include "glog/logging.h"
 #include "gmock/gmock.h"
@@ -80,8 +81,7 @@
   if (options.use_mixed_precision_solves) {
     options.max_num_refinement_iterations = kNumRefinementSteps;
   }
-  std::unique_ptr<DenseCholesky> dense_cholesky =
-      DenseCholesky::Create(options);
+  auto dense_cholesky = DenseCholesky::Create(options);
 
   const int kNumTrials = 10;
   const int kMinNumCols = 1;
@@ -111,13 +111,15 @@
 INSTANTIATE_TEST_SUITE_P(EigenCholesky,
                          DenseCholeskyTest,
                          ::testing::Combine(::testing::Values(EIGEN),
-                                            ::testing::Values(kFullPrecision)),
+                                            ::testing::Values(kMixedPrecision,
+                                                              kFullPrecision)),
                          ParamInfoToString);
 #ifndef CERES_NO_LAPACK
 INSTANTIATE_TEST_SUITE_P(LapackCholesky,
                          DenseCholeskyTest,
                          ::testing::Combine(::testing::Values(LAPACK),
-                                            ::testing::Values(kFullPrecision)),
+                                            ::testing::Values(kMixedPrecision,
+                                                              kFullPrecision)),
                          ParamInfoToString);
 #endif
 #ifndef CERES_NO_CUDA
@@ -129,44 +131,86 @@
                          ParamInfoToString);
 #endif
 
-#ifndef CERES_NO_CUDA
-TEST(DenseCholesky, ValidMixedPrecisionOptions) {
-  // Dense Cholesky with CUDA: okay, supported.
-  ContextImpl context;
-  LinearSolver::Options options;
-  options.dense_linear_algebra_library_type = CUDA;
-  options.use_mixed_precision_solves = true;
-  options.context = &context;
-  std::unique_ptr<DenseCholesky> dense_cholesky =
-      DenseCholesky::Create(options);
-  EXPECT_NE(dense_cholesky, nullptr);
-}
-#endif
+class MockDenseCholesky : public DenseCholesky {
+ public:
+  MOCK_METHOD3(Factorize,
+               LinearSolverTerminationType(int num_cols,
+                                           double* lhs,
+                                           std::string* message));
+  MOCK_METHOD3(Solve,
+               LinearSolverTerminationType(const double* rhs,
+                                           double* solution,
+                                           std::string* message));
+};
 
-TEST(DenseCholesky, InvalidMixedPrecisionOptionsEigen) {
-  // Dense Cholesky with Eigen: not supported
-  ContextImpl context;
-  LinearSolver::Options options;
-  options.dense_linear_algebra_library_type = EIGEN;
-  options.use_mixed_precision_solves = true;
-  options.context = &context;
-  std::unique_ptr<DenseCholesky> dense_cholesky =
-      DenseCholesky::Create(options);
-  EXPECT_EQ(dense_cholesky, nullptr);
-}
+class MockDenseIterativeRefiner : public DenseIterativeRefiner {
+ public:
+  MockDenseIterativeRefiner() : DenseIterativeRefiner(1) {}
+  MOCK_METHOD5(Refine,
+               void(int num_cols,
+                    const double* lhs,
+                    const double* rhs,
+                    DenseCholesky* dense_cholesky,
+                    double* solution));
+};
 
-#ifndef CERES_NO_LAPACK
-TEST(DenseCholesky, InvalidMixedPrecisionOptionsLAPACK) {
-  // Dense Cholesky with Lapack: not supported
-  ContextImpl context;
-  LinearSolver::Options options;
-  options.dense_linear_algebra_library_type = LAPACK;
-  options.use_mixed_precision_solves = true;
-  options.context = &context;
-  std::unique_ptr<DenseCholesky> dense_cholesky =
-      DenseCholesky::Create(options);
-  EXPECT_EQ(dense_cholesky, nullptr);
-}
-#endif
+using testing::_;
+using testing::Return;
+
+TEST(RefinedDenseCholesky, Factorize) {
+  auto dense_cholesky = std::make_unique<MockDenseCholesky>();
+  auto iterative_refiner = std::make_unique<MockDenseIterativeRefiner>();
+  EXPECT_CALL(*dense_cholesky, Factorize(_, _, _))
+      .Times(1)
+      .WillRepeatedly(Return(LinearSolverTerminationType::SUCCESS));
+  EXPECT_CALL(*iterative_refiner, Refine(_, _, _, _, _)).Times(0);
+  RefinedDenseCholesky refined_dense_cholesky(std::move(dense_cholesky),
+                                              std::move(iterative_refiner));
+  double lhs;
+  std::string message;
+  EXPECT_EQ(refined_dense_cholesky.Factorize(1, &lhs, &message),
+            LinearSolverTerminationType::SUCCESS);
+};
+
+TEST(RefinedDenseCholesky, FactorAndSolveWithUnsuccessfulFactorization) {
+  auto dense_cholesky = std::make_unique<MockDenseCholesky>();
+  auto iterative_refiner = std::make_unique<MockDenseIterativeRefiner>();
+  EXPECT_CALL(*dense_cholesky, Factorize(_, _, _))
+      .Times(1)
+      .WillRepeatedly(Return(LinearSolverTerminationType::FAILURE));
+  EXPECT_CALL(*dense_cholesky, Solve(_, _, _)).Times(0);
+  EXPECT_CALL(*iterative_refiner, Refine(_, _, _, _, _)).Times(0);
+  RefinedDenseCholesky refined_dense_cholesky(std::move(dense_cholesky),
+                                              std::move(iterative_refiner));
+  double lhs;
+  std::string message;
+  double rhs;
+  double solution;
+  EXPECT_EQ(
+      refined_dense_cholesky.FactorAndSolve(1, &lhs, &rhs, &solution, &message),
+      LinearSolverTerminationType::FAILURE);
+};
+
+TEST(RefinedDenseCholesky, FactorAndSolveWithSuccess) {
+  auto dense_cholesky = std::make_unique<MockDenseCholesky>();
+  auto iterative_refiner = std::make_unique<MockDenseIterativeRefiner>();
+  EXPECT_CALL(*dense_cholesky, Factorize(_, _, _))
+      .Times(1)
+      .WillRepeatedly(Return(LinearSolverTerminationType::SUCCESS));
+  EXPECT_CALL(*dense_cholesky, Solve(_, _, _))
+      .Times(1)
+      .WillRepeatedly(Return(LinearSolverTerminationType::SUCCESS));
+  EXPECT_CALL(*iterative_refiner, Refine(_, _, _, _, _)).Times(1);
+
+  RefinedDenseCholesky refined_dense_cholesky(std::move(dense_cholesky),
+                                              std::move(iterative_refiner));
+  double lhs;
+  std::string message;
+  double rhs;
+  double solution;
+  EXPECT_EQ(
+      refined_dense_cholesky.FactorAndSolve(1, &lhs, &rhs, &solution, &message),
+      LinearSolverTerminationType::SUCCESS);
+};
 
 }  // namespace ceres::internal
diff --git a/internal/ceres/eigensparse.cc b/internal/ceres/eigensparse.cc
index a2407c0..ce01658 100644
--- a/internal/ceres/eigensparse.cc
+++ b/internal/ceres/eigensparse.cc
@@ -195,7 +195,7 @@
 #else
     LOG(FATAL)
         << "Congratulations you have found a bug in Ceres Solver. Please "
-        "report it to the Ceres Solver developers.";
+           "report it to the Ceres Solver developers.";
     return nullptr;
 #endif  // CERES_NO_EIGEN_METIS
   }
diff --git a/internal/ceres/iterative_refiner.cc b/internal/ceres/iterative_refiner.cc
index b9fa88a..90ff511 100644
--- a/internal/ceres/iterative_refiner.cc
+++ b/internal/ceres/iterative_refiner.cc
@@ -33,39 +33,67 @@
 #include <string>
 
 #include "Eigen/Core"
+#include "ceres/dense_cholesky.h"
 #include "ceres/sparse_cholesky.h"
 #include "ceres/sparse_matrix.h"
 
 namespace ceres::internal {
 
-IterativeRefiner::IterativeRefiner(const int max_num_iterations)
+SparseIterativeRefiner::SparseIterativeRefiner(const int max_num_iterations)
     : max_num_iterations_(max_num_iterations) {}
 
-IterativeRefiner::~IterativeRefiner() = default;
+SparseIterativeRefiner::~SparseIterativeRefiner() = default;
 
-void IterativeRefiner::Allocate(int num_cols) {
+void SparseIterativeRefiner::Allocate(int num_cols) {
   residual_.resize(num_cols);
   correction_.resize(num_cols);
   lhs_x_solution_.resize(num_cols);
 }
 
-void IterativeRefiner::Refine(const SparseMatrix& lhs,
-                              const double* rhs_ptr,
-                              SparseCholesky* sparse_cholesky,
-                              double* solution_ptr) {
+void SparseIterativeRefiner::Refine(const SparseMatrix& lhs,
+                                    const double* rhs_ptr,
+                                    SparseCholesky* cholesky,
+                                    double* solution_ptr) {
   const int num_cols = lhs.num_cols();
   Allocate(num_cols);
   ConstVectorRef rhs(rhs_ptr, num_cols);
   VectorRef solution(solution_ptr, num_cols);
+  std::string ignored_message;
   for (int i = 0; i < max_num_iterations_; ++i) {
     // residual = rhs - lhs * solution
     lhs_x_solution_.setZero();
     lhs.RightMultiply(solution_ptr, lhs_x_solution_.data());
     residual_ = rhs - lhs_x_solution_;
     // solution += lhs^-1 residual
-    std::string ignored_message;
-    sparse_cholesky->Solve(
-        residual_.data(), correction_.data(), &ignored_message);
+    cholesky->Solve(residual_.data(), correction_.data(), &ignored_message);
+    solution += correction_;
+  }
+};
+
+DenseIterativeRefiner::DenseIterativeRefiner(const int max_num_iterations)
+    : max_num_iterations_(max_num_iterations) {}
+
+DenseIterativeRefiner::~DenseIterativeRefiner() = default;
+
+void DenseIterativeRefiner::Allocate(int num_cols) {
+  residual_.resize(num_cols);
+  correction_.resize(num_cols);
+}
+
+void DenseIterativeRefiner::Refine(const int num_cols,
+                                   const double* lhs_ptr,
+                                   const double* rhs_ptr,
+                                   DenseCholesky* cholesky,
+                                   double* solution_ptr) {
+  Allocate(num_cols);
+  ConstMatrixRef lhs(lhs_ptr, num_cols, num_cols);
+  ConstVectorRef rhs(rhs_ptr, num_cols);
+  VectorRef solution(solution_ptr, num_cols);
+  std::string ignored_message;
+  for (int i = 0; i < max_num_iterations_; ++i) {
+    residual_ = rhs - lhs * solution;
+    // solution += lhs^-1 residual
+    cholesky->Solve(residual_.data(), correction_.data(), &ignored_message);
     solution += correction_;
   }
 };
diff --git a/internal/ceres/iterative_refiner.h b/internal/ceres/iterative_refiner.h
index d500212..8333124 100644
--- a/internal/ceres/iterative_refiner.h
+++ b/internal/ceres/iterative_refiner.h
@@ -41,6 +41,7 @@
 
 namespace ceres::internal {
 
+class DenseCholesky;
 class SparseCholesky;
 class SparseMatrix;
 
@@ -57,20 +58,20 @@
 // Definite linear systems.
 //
 // The above iterative loop is run until max_num_iterations is reached.
-class CERES_NO_EXPORT IterativeRefiner {
+class CERES_NO_EXPORT SparseIterativeRefiner {
  public:
   // max_num_iterations is the number of refinement iterations to
   // perform.
-  explicit IterativeRefiner(int max_num_iterations);
+  explicit SparseIterativeRefiner(int max_num_iterations);
 
   // Needed for mocking.
-  virtual ~IterativeRefiner();
+  virtual ~SparseIterativeRefiner();
 
   // Given an initial estimate of the solution of lhs * x = rhs, use
   // max_num_iterations rounds of iterative refinement to improve it.
   //
-  // sparse_cholesky is assumed to contain an already computed
-  // factorization (or approximation thereof) of lhs.
+  // cholesky is assumed to contain an already computed factorization (or
+  // an approximation thereof) of lhs.
   //
   // solution is expected to contain a approximation to the solution
   // to lhs * x = rhs. It can be zero.
@@ -78,7 +79,7 @@
   // This method is virtual to facilitate mocking.
   virtual void Refine(const SparseMatrix& lhs,
                       const double* rhs,
-                      SparseCholesky* sparse_cholesky,
+                      SparseCholesky* cholesky,
                       double* solution);
 
  private:
@@ -90,6 +91,39 @@
   Vector lhs_x_solution_;
 };
 
+class CERES_NO_EXPORT DenseIterativeRefiner {
+ public:
+  // max_num_iterations is the number of refinement iterations to
+  // perform.
+  explicit DenseIterativeRefiner(int max_num_iterations);
+
+  // Needed for mocking.
+  virtual ~DenseIterativeRefiner();
+
+  // Given an initial estimate of the solution of lhs * x = rhs, use
+  // max_num_iterations rounds of iterative refinement to improve it.
+  //
+  // cholesky is assumed to contain an already computed factorization (or
+  // an approximation thereof) of lhs.
+  //
+  // solution is expected to contain a approximation to the solution
+  // to lhs * x = rhs. It can be zero.
+  //
+  // This method is virtual to facilitate mocking.
+  virtual void Refine(int num_cols,
+                      const double* lhs,
+                      const double* rhs,
+                      DenseCholesky* cholesky,
+                      double* solution);
+
+ private:
+  void Allocate(int num_cols);
+
+  int max_num_iterations_;
+  Vector residual_;
+  Vector correction_;
+};
+
 }  // namespace ceres::internal
 
 #endif  // CERES_INTERNAL_ITERATIVE_REFINER_H_
diff --git a/internal/ceres/iterative_refiner_test.cc b/internal/ceres/iterative_refiner_test.cc
index 5718f14..0b09247 100644
--- a/internal/ceres/iterative_refiner_test.cc
+++ b/internal/ceres/iterative_refiner_test.cc
@@ -33,6 +33,7 @@
 #include <utility>
 
 #include "Eigen/Dense"
+#include "ceres/dense_cholesky.h"
 #include "ceres/internal/eigen.h"
 #include "ceres/sparse_cholesky.h"
 #include "ceres/sparse_matrix.h"
@@ -97,7 +98,9 @@
     const int num_cols = lhs_.cols();
     VectorRef solution(solution_ptr, num_cols);
     ConstVectorRef rhs(rhs_ptr, num_cols);
-    solution = lhs_.llt().solve(rhs.cast<Scalar>()).template cast<double>();
+    auto llt = lhs_.llt();
+    CHECK_NE(llt.info(), Eigen::Success);
+    solution = llt.solve(rhs.cast<Scalar>()).template cast<double>();
     return LinearSolverTerminationType::SUCCESS;
   }
 
@@ -113,10 +116,37 @@
   Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> lhs_;
 };
 
+// A fake DenseCholesky which uses Eigen's Cholesky factorization to
+// do the real work. The template parameter allows us to work in
+// doubles or floats, even though the source matrix is double.
+template <typename Scalar>
+class FakeDenseCholesky : public DenseCholesky {
+ public:
+  explicit FakeDenseCholesky(const Matrix& lhs) { lhs_ = lhs.cast<Scalar>(); }
+
+  LinearSolverTerminationType Solve(const double* rhs_ptr,
+                                    double* solution_ptr,
+                                    std::string* message) final {
+    const int num_cols = lhs_.cols();
+    VectorRef solution(solution_ptr, num_cols);
+    ConstVectorRef rhs(rhs_ptr, num_cols);
+    solution = lhs_.llt().solve(rhs.cast<Scalar>()).template cast<double>();
+    return LinearSolverTerminationType::SUCCESS;
+  }
+
+  LinearSolverTerminationType Factorize(int num_cols,
+                                        double* lhs,
+                                        std::string* message) final
+      DO_NOT_CALL_WITH_RETURN(LinearSolverTerminationType::FAILURE);
+
+ private:
+  Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> lhs_;
+};
+
 #undef DO_NOT_CALL
 #undef DO_NOT_CALL_WITH_RETURN
 
-class IterativeRefinerTest : public ::testing::Test {
+class SparseIterativeRefinerTest : public ::testing::Test {
  public:
   void SetUp() override {
     num_cols_ = 5;
@@ -136,10 +166,11 @@
   Vector rhs_, solution_;
 };
 
-TEST_F(IterativeRefinerTest, RandomSolutionWithExactFactorizationConverges) {
+TEST_F(SparseIterativeRefinerTest,
+       RandomSolutionWithExactFactorizationConverges) {
   FakeSparseMatrix lhs(lhs_);
   FakeSparseCholesky<double> sparse_cholesky(lhs_);
-  IterativeRefiner refiner(max_num_iterations_);
+  SparseIterativeRefiner refiner(max_num_iterations_);
   Vector refined_solution(num_cols_);
   refined_solution.setRandom();
   refiner.Refine(lhs, rhs_.data(), &sparse_cholesky, refined_solution.data());
@@ -148,13 +179,13 @@
               std::numeric_limits<double>::epsilon() * 10);
 }
 
-TEST_F(IterativeRefinerTest,
+TEST_F(SparseIterativeRefinerTest,
        RandomSolutionWithApproximationFactorizationConverges) {
   FakeSparseMatrix lhs(lhs_);
   // Use a single precision Cholesky factorization of the double
   // precision matrix. This will give us an approximate factorization.
   FakeSparseCholesky<float> sparse_cholesky(lhs_);
-  IterativeRefiner refiner(max_num_iterations_);
+  SparseIterativeRefiner refiner(max_num_iterations_);
   Vector refined_solution(num_cols_);
   refined_solution.setRandom();
   refiner.Refine(lhs, rhs_.data(), &sparse_cholesky, refined_solution.data());
@@ -163,4 +194,60 @@
               std::numeric_limits<double>::epsilon() * 10);
 }
 
+class DenseIterativeRefinerTest : public ::testing::Test {
+ public:
+  void SetUp() override {
+    num_cols_ = 5;
+    max_num_iterations_ = 30;
+    Matrix m(num_cols_, num_cols_);
+    m.setRandom();
+    lhs_ = m * m.transpose();
+    solution_.resize(num_cols_);
+    solution_.setRandom();
+    rhs_ = lhs_ * solution_;
+  };
+
+ protected:
+  int num_cols_;
+  int max_num_iterations_;
+  Matrix lhs_;
+  Vector rhs_, solution_;
+};
+
+TEST_F(DenseIterativeRefinerTest,
+       RandomSolutionWithExactFactorizationConverges) {
+  Matrix lhs = lhs_;
+  FakeDenseCholesky<double> dense_cholesky(lhs);
+  DenseIterativeRefiner refiner(max_num_iterations_);
+  Vector refined_solution(num_cols_);
+  refined_solution.setRandom();
+  refiner.Refine(lhs.cols(),
+                 lhs.data(),
+                 rhs_.data(),
+                 &dense_cholesky,
+                 refined_solution.data());
+  EXPECT_NEAR((lhs_ * refined_solution - rhs_).norm(),
+              0.0,
+              std::numeric_limits<double>::epsilon() * 10);
+}
+
+TEST_F(DenseIterativeRefinerTest,
+       RandomSolutionWithApproximationFactorizationConverges) {
+  Matrix lhs = lhs_;
+  // Use a single precision Cholesky factorization of the double
+  // precision matrix. This will give us an approximate factorization.
+  FakeDenseCholesky<float> dense_cholesky(lhs_);
+  DenseIterativeRefiner refiner(max_num_iterations_);
+  Vector refined_solution(num_cols_);
+  refined_solution.setRandom();
+  refiner.Refine(lhs.cols(),
+                 lhs.data(),
+                 rhs_.data(),
+                 &dense_cholesky,
+                 refined_solution.data());
+  EXPECT_NEAR((lhs_ * refined_solution - rhs_).norm(),
+              0.0,
+              std::numeric_limits<double>::epsilon() * 10);
+}
+
 }  // namespace ceres::internal
diff --git a/internal/ceres/solver.cc b/internal/ceres/solver.cc
index 93551f7..966a5dd 100644
--- a/internal/ceres/solver.cc
+++ b/internal/ceres/solver.cc
@@ -120,44 +120,36 @@
 
 bool MixedPrecisionOptionIsValid(const Solver::Options& options,
                                  string* error) {
-  if (options.use_mixed_precision_solves) {
-    if ((options.linear_solver_type == DENSE_NORMAL_CHOLESKY ||
-        options.linear_solver_type == DENSE_SCHUR) &&
-        options.dense_linear_algebra_library_type == CUDA) {
-      // Mixed precision with CUDA and dense Cholesky variant: okay.
-      return true;
-    }
-    if ((options.linear_solver_type == SPARSE_NORMAL_CHOLESKY ||
-        options.linear_solver_type == SPARSE_SCHUR) &&
-        (options.sparse_linear_algebra_library_type == EIGEN_SPARSE ||
-        options.sparse_linear_algebra_library_type == ACCELERATE_SPARSE)) {
+  if (!options.use_mixed_precision_solves) {
+    return true;
+  }
+
+  // All dense linear algebra backends support mixed precision solves now with
+  // Cholesky factorization.
+  if ((options.linear_solver_type == DENSE_NORMAL_CHOLESKY ||
+       options.linear_solver_type == DENSE_SCHUR)) {
+    return true;
+  }
+
+  if ((options.linear_solver_type == SPARSE_NORMAL_CHOLESKY ||
+       options.linear_solver_type == SPARSE_SCHUR)) {
+    if (options.sparse_linear_algebra_library_type == EIGEN_SPARSE ||
+        options.sparse_linear_algebra_library_type == ACCELERATE_SPARSE) {
       // Mixed precision with any Eigen or Accelerate Cholesky variant: okay.
       return true;
     }
-    // No other mixed precision variants are supported.
-    if (options.linear_solver_type == DENSE_NORMAL_CHOLESKY ||
-        options.linear_solver_type == DENSE_SCHUR) {
+    if (options.sparse_linear_algebra_library_type == SUITE_SPARSE) {
       *error = StringPrintf(
-          "use_mixed_precision_solves with %s is only supported with "
-          "CUDA as the dense_linear_algebra_library_type.",
-          LinearSolverTypeToString(options.linear_solver_type));
-      return false;
-    }
-    if ((options.linear_solver_type == SPARSE_NORMAL_CHOLESKY ||
-        options.linear_solver_type == SPARSE_SCHUR) &&
-        options.sparse_linear_algebra_library_type == SUITE_SPARSE) {
-      *error =  StringPrintf(
           "use_mixed_precision_solves with %s is not supported with "
           "SUITE_SPARSE as the sparse_linear_algebra_library_type.",
           LinearSolverTypeToString(options.linear_solver_type));
       return false;
     }
-    *error = StringPrintf(
-          "use_mixed_precision_solves with %s is not supported.",
-          LinearSolverTypeToString(options.linear_solver_type));
-    return false;
   }
-  return true;
+
+  *error = StringPrintf("use_mixed_precision_solves with %s is not supported.",
+                        LinearSolverTypeToString(options.linear_solver_type));
+  return false;
 }
 
 bool TrustRegionOptionsAreValid(const Solver::Options& options, string* error) {
@@ -300,7 +292,8 @@
     }
     if (options.sparse_linear_algebra_library_type == ACCELERATE_SPARSE) {
       *error =
-          "ACCELERATE_SPARSE is not currently supported with dynamic sparsity.";
+          "ACCELERATE_SPARSE is not currently supported with dynamic "
+          "sparsity.";
       return false;
     }
   }
@@ -310,7 +303,8 @@
       options.residual_blocks_for_subset_preconditioner.empty()) {
     *error =
         "When using SUBSET preconditioner, "
-        "Solver::Options::residual_blocks_for_subset_preconditioner cannot be "
+        "Solver::Options::residual_blocks_for_subset_preconditioner cannot "
+        "be "
         "empty";
     return false;
   }
@@ -348,8 +342,8 @@
 
   // Warn user if they have requested BISECTION interpolation, but constraints
   // on max/min step size change during line search prevent bisection scaling
-  // from occurring. Warn only, as this is likely a user mistake, but one which
-  // does not prevent us from continuing.
+  // from occurring. Warn only, as this is likely a user mistake, but one
+  // which does not prevent us from continuing.
   if (options.line_search_interpolation_type == ceres::BISECTION &&
       (options.max_line_search_step_contraction > 0.5 ||
        options.min_line_search_step_contraction < 0.5)) {
diff --git a/internal/ceres/sparse_cholesky.cc b/internal/ceres/sparse_cholesky.cc
index fe7412c..22df3c9 100644
--- a/internal/ceres/sparse_cholesky.cc
+++ b/internal/ceres/sparse_cholesky.cc
@@ -94,10 +94,10 @@
   }
 
   if (options.max_num_refinement_iterations > 0) {
-    std::unique_ptr<IterativeRefiner> refiner(
-        new IterativeRefiner(options.max_num_refinement_iterations));
-    sparse_cholesky = std::unique_ptr<SparseCholesky>(new RefinedSparseCholesky(
-        std::move(sparse_cholesky), std::move(refiner)));
+    auto refiner = std::make_unique<SparseIterativeRefiner>(
+        options.max_num_refinement_iterations);
+    sparse_cholesky = std::make_unique<RefinedSparseCholesky>(
+        std::move(sparse_cholesky), std::move(refiner));
   }
   return sparse_cholesky;
 }
@@ -118,7 +118,7 @@
 
 RefinedSparseCholesky::RefinedSparseCholesky(
     std::unique_ptr<SparseCholesky> sparse_cholesky,
-    std::unique_ptr<IterativeRefiner> iterative_refiner)
+    std::unique_ptr<SparseIterativeRefiner> iterative_refiner)
     : sparse_cholesky_(std::move(sparse_cholesky)),
       iterative_refiner_(std::move(iterative_refiner)) {}
 
diff --git a/internal/ceres/sparse_cholesky.h b/internal/ceres/sparse_cholesky.h
index feea7aa..9907d07 100644
--- a/internal/ceres/sparse_cholesky.h
+++ b/internal/ceres/sparse_cholesky.h
@@ -112,14 +112,15 @@
                                              std::string* message);
 };
 
-class IterativeRefiner;
+class SparseIterativeRefiner;
 
 // Computes an initial solution using the given instance of
-// SparseCholesky, and then refines it using the IterativeRefiner.
+// SparseCholesky, and then refines it using the SparseIterativeRefiner.
 class CERES_NO_EXPORT RefinedSparseCholesky final : public SparseCholesky {
  public:
-  RefinedSparseCholesky(std::unique_ptr<SparseCholesky> sparse_cholesky,
-                        std::unique_ptr<IterativeRefiner> iterative_refiner);
+  RefinedSparseCholesky(
+      std::unique_ptr<SparseCholesky> sparse_cholesky,
+      std::unique_ptr<SparseIterativeRefiner> iterative_refiner);
   ~RefinedSparseCholesky() override;
 
   CompressedRowSparseMatrix::StorageType StorageType() const override;
@@ -131,7 +132,7 @@
 
  private:
   std::unique_ptr<SparseCholesky> sparse_cholesky_;
-  std::unique_ptr<IterativeRefiner> iterative_refiner_;
+  std::unique_ptr<SparseIterativeRefiner> iterative_refiner_;
   CompressedRowSparseMatrix* lhs_ = nullptr;
 };
 
diff --git a/internal/ceres/sparse_cholesky_test.cc b/internal/ceres/sparse_cholesky_test.cc
index 31dbb02..0522232 100644
--- a/internal/ceres/sparse_cholesky_test.cc
+++ b/internal/ceres/sparse_cholesky_test.cc
@@ -209,7 +209,8 @@
                        ::testing::Values(OrderingType::NESDIS),
                        ::testing::Values(true, false)),
     ParamInfoToString);
-#endif  // !defined(CERES_NO_SUITESPARSE) && !defined(CERES_NO_CHOLMOD_PARTITION)
+#endif  // !defined(CERES_NO_SUITESPARSE) &&
+        // !defined(CERES_NO_CHOLMOD_PARTITION)
 
 #ifndef CERES_NO_ACCELERATE_SPARSE
 INSTANTIATE_TEST_SUITE_P(
@@ -283,9 +284,9 @@
                                            std::string* message));
 };
 
-class MockIterativeRefiner : public IterativeRefiner {
+class MockSparseIterativeRefiner : public SparseIterativeRefiner {
  public:
-  MockIterativeRefiner() : IterativeRefiner(1) {}
+  MockSparseIterativeRefiner() : SparseIterativeRefiner(1) {}
   MOCK_METHOD4(Refine,
                void(const SparseMatrix& lhs,
                     const double* rhs,
@@ -298,14 +299,15 @@
 
 TEST(RefinedSparseCholesky, StorageType) {
   auto* mock_sparse_cholesky = new MockSparseCholesky;
-  auto* mock_iterative_refiner = new MockIterativeRefiner;
+  auto* mock_iterative_refiner = new MockSparseIterativeRefiner;
   EXPECT_CALL(*mock_sparse_cholesky, StorageType())
       .Times(1)
       .WillRepeatedly(
           Return(CompressedRowSparseMatrix::StorageType::UPPER_TRIANGULAR));
   EXPECT_CALL(*mock_iterative_refiner, Refine(_, _, _, _)).Times(0);
   std::unique_ptr<SparseCholesky> sparse_cholesky(mock_sparse_cholesky);
-  std::unique_ptr<IterativeRefiner> iterative_refiner(mock_iterative_refiner);
+  std::unique_ptr<SparseIterativeRefiner> iterative_refiner(
+      mock_iterative_refiner);
   RefinedSparseCholesky refined_sparse_cholesky(std::move(sparse_cholesky),
                                                 std::move(iterative_refiner));
   EXPECT_EQ(refined_sparse_cholesky.StorageType(),
@@ -314,13 +316,14 @@
 
 TEST(RefinedSparseCholesky, Factorize) {
   auto* mock_sparse_cholesky = new MockSparseCholesky;
-  auto* mock_iterative_refiner = new MockIterativeRefiner;
+  auto* mock_iterative_refiner = new MockSparseIterativeRefiner;
   EXPECT_CALL(*mock_sparse_cholesky, Factorize(_, _))
       .Times(1)
       .WillRepeatedly(Return(LinearSolverTerminationType::SUCCESS));
   EXPECT_CALL(*mock_iterative_refiner, Refine(_, _, _, _)).Times(0);
   std::unique_ptr<SparseCholesky> sparse_cholesky(mock_sparse_cholesky);
-  std::unique_ptr<IterativeRefiner> iterative_refiner(mock_iterative_refiner);
+  std::unique_ptr<SparseIterativeRefiner> iterative_refiner(
+      mock_iterative_refiner);
   RefinedSparseCholesky refined_sparse_cholesky(std::move(sparse_cholesky),
                                                 std::move(iterative_refiner));
   CompressedRowSparseMatrix m(1, 1, 1);
@@ -331,14 +334,15 @@
 
 TEST(RefinedSparseCholesky, FactorAndSolveWithUnsuccessfulFactorization) {
   auto* mock_sparse_cholesky = new MockSparseCholesky;
-  auto* mock_iterative_refiner = new MockIterativeRefiner;
+  auto* mock_iterative_refiner = new MockSparseIterativeRefiner;
   EXPECT_CALL(*mock_sparse_cholesky, Factorize(_, _))
       .Times(1)
       .WillRepeatedly(Return(LinearSolverTerminationType::FAILURE));
   EXPECT_CALL(*mock_sparse_cholesky, Solve(_, _, _)).Times(0);
   EXPECT_CALL(*mock_iterative_refiner, Refine(_, _, _, _)).Times(0);
   std::unique_ptr<SparseCholesky> sparse_cholesky(mock_sparse_cholesky);
-  std::unique_ptr<IterativeRefiner> iterative_refiner(mock_iterative_refiner);
+  std::unique_ptr<SparseIterativeRefiner> iterative_refiner(
+      mock_iterative_refiner);
   RefinedSparseCholesky refined_sparse_cholesky(std::move(sparse_cholesky),
                                                 std::move(iterative_refiner));
   CompressedRowSparseMatrix m(1, 1, 1);
@@ -352,8 +356,8 @@
 
 TEST(RefinedSparseCholesky, FactorAndSolveWithSuccess) {
   auto* mock_sparse_cholesky = new MockSparseCholesky;
-  std::unique_ptr<MockIterativeRefiner> mock_iterative_refiner(
-      new MockIterativeRefiner);
+  std::unique_ptr<MockSparseIterativeRefiner> mock_iterative_refiner(
+      new MockSparseIterativeRefiner);
   EXPECT_CALL(*mock_sparse_cholesky, Factorize(_, _))
       .Times(1)
       .WillRepeatedly(Return(LinearSolverTerminationType::SUCCESS));
@@ -363,7 +367,7 @@
   EXPECT_CALL(*mock_iterative_refiner, Refine(_, _, _, _)).Times(1);
 
   std::unique_ptr<SparseCholesky> sparse_cholesky(mock_sparse_cholesky);
-  std::unique_ptr<IterativeRefiner> iterative_refiner(
+  std::unique_ptr<SparseIterativeRefiner> iterative_refiner(
       std::move(mock_iterative_refiner));
   RefinedSparseCholesky refined_sparse_cholesky(std::move(sparse_cholesky),
                                                 std::move(iterative_refiner));