Refactor ConjugateGradientsSolver

1. Convert it from a class to a template function. Where the
   template parameter is "DenseVectorType". This allows us
   to have a single implementation of Conjugate Gradients
   without worrying about where the matrix and the vectors
   are stored or what their internal representation is.

   For the case of CPU based vectors, we abstract operations
   on Eigen vectors using eigen_vector_ops.
2. Introduce ConjugateGradientsLinearOperator which is
   templated on DenseVectorType. It is the matrix vector
   multiplication abstraction.
3. Port the tests and all usages of ConjugateGradientsSolver
   to this new implementation.
4. Introduce Eigen::Vector based RightMultiply and LeftMultiply
   methods into LinearOperator which by default delete to the
   bare pointer based interfaces.
5. Add an identity preconditioner.

These changes are being made in preparation for adding a CUDA
based CGNR solver.

Change-Id: I9da36dc6c131856dd1a4aa7e645aaf12d25dd79b
diff --git a/internal/ceres/CMakeLists.txt b/internal/ceres/CMakeLists.txt
index ea21df2..e4ef4c0 100644
--- a/internal/ceres/CMakeLists.txt
+++ b/internal/ceres/CMakeLists.txt
@@ -191,7 +191,6 @@
     compressed_row_jacobian_writer.cc
     compressed_row_sparse_matrix.cc
     conditioned_cost_function.cc
-    conjugate_gradients_solver.cc
     context.cc
     context_impl.cc
     coordinate_descent_minimizer.cc
diff --git a/internal/ceres/cgnr_linear_operator.h b/internal/ceres/cgnr_linear_operator.h
deleted file mode 100644
index 4e47aff..0000000
--- a/internal/ceres/cgnr_linear_operator.h
+++ /dev/null
@@ -1,121 +0,0 @@
-// Ceres Solver - A fast non-linear least squares minimizer
-// Copyright 2015 Google Inc. All rights reserved.
-// http://ceres-solver.org/
-//
-// Redistribution and use in source and binary forms, with or without
-// modification, are permitted provided that the following conditions are met:
-//
-// * Redistributions of source code must retain the above copyright notice,
-//   this list of conditions and the following disclaimer.
-// * Redistributions in binary form must reproduce the above copyright notice,
-//   this list of conditions and the following disclaimer in the documentation
-//   and/or other materials provided with the distribution.
-// * Neither the name of Google Inc. nor the names of its contributors may be
-//   used to endorse or promote products derived from this software without
-//   specific prior written permission.
-//
-// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
-// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
-// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
-// POSSIBILITY OF SUCH DAMAGE.
-//
-// Author: keir@google.com (Keir Mierle)
-
-#ifndef CERES_INTERNAL_CGNR_LINEAR_OPERATOR_H_
-#define CERES_INTERNAL_CGNR_LINEAR_OPERATOR_H_
-
-#include <algorithm>
-#include <memory>
-
-#include "ceres/internal/disable_warnings.h"
-#include "ceres/internal/eigen.h"
-#include "ceres/internal/export.h"
-#include "ceres/linear_operator.h"
-
-namespace ceres::internal {
-
-class SparseMatrix;
-
-// A linear operator which takes a matrix A and a diagonal vector D and
-// performs products of the form
-//
-//   (A^T A + D^T D)x
-//
-// This is used to implement iterative general sparse linear solving with
-// conjugate gradients, where A is the Jacobian and D is a regularizing
-// parameter. A brief proof that D^T D is the correct regularizer:
-//
-// Given a regularized least squares problem:
-//
-//   min  ||Ax - b||^2 + ||Dx||^2
-//    x
-//
-// First expand into matrix notation:
-//
-//   (Ax - b)^T (Ax - b) + xD^TDx
-//
-// Then multiply out to get:
-//
-//   = xA^TAx - 2b^T Ax + b^Tb + xD^TDx
-//
-// Take the derivative:
-//
-//   0 = 2A^TAx - 2A^T b + 2 D^TDx
-//   0 = A^TAx - A^T b + D^TDx
-//   0 = (A^TA + D^TD)x - A^T b
-//
-// Thus, the symmetric system we need to solve for CGNR is
-//
-//   Sx = z
-//
-// with S = A^TA + D^TD
-//  and z = A^T b
-//
-// Note: This class is not thread safe, since it uses some temporary storage.
-class CERES_NO_EXPORT CgnrLinearOperator final : public LinearOperator {
- public:
-  CgnrLinearOperator(const LinearOperator& A, const double* D)
-      : A_(A), D_(D), z_(new double[A.num_rows()]) {}
-
-  void RightMultiply(const double* x, double* y) const final {
-    std::fill(z_.get(), z_.get() + A_.num_rows(), 0.0);
-
-    // z = Ax
-    A_.RightMultiply(x, z_.get());
-
-    // y = y + Atz
-    A_.LeftMultiply(z_.get(), y);
-
-    // y = y + DtDx
-    if (D_ != nullptr) {
-      int n = A_.num_cols();
-      VectorRef(y, n).array() +=
-          ConstVectorRef(D_, n).array().square() * ConstVectorRef(x, n).array();
-    }
-  }
-
-  void LeftMultiply(const double* x, double* y) const final {
-    RightMultiply(x, y);
-  }
-
-  int num_rows() const final { return A_.num_cols(); }
-  int num_cols() const final { return A_.num_cols(); }
-
- private:
-  const LinearOperator& A_;
-  const double* D_;
-  std::unique_ptr<double[]> z_;
-};
-
-}  // namespace ceres::internal
-
-#include "ceres/internal/reenable_warnings.h"
-
-#endif  // CERES_INTERNAL_CGNR_LINEAR_OPERATOR_H_
diff --git a/internal/ceres/cgnr_solver.cc b/internal/ceres/cgnr_solver.cc
index 12e2ef9..f79b897 100644
--- a/internal/ceres/cgnr_solver.cc
+++ b/internal/ceres/cgnr_solver.cc
@@ -34,7 +34,6 @@
 #include <utility>
 
 #include "ceres/block_jacobi_preconditioner.h"
-#include "ceres/cgnr_linear_operator.h"
 #include "ceres/conjugate_gradients_solver.h"
 #include "ceres/internal/eigen.h"
 #include "ceres/linear_solver.h"
@@ -44,6 +43,68 @@
 
 namespace ceres::internal {
 
+// A linear operator which takes a matrix A and a diagonal vector D and
+// performs products of the form
+//
+//   (A^T A + D^T D)x
+//
+// This is used to implement iterative general sparse linear solving with
+// conjugate gradients, where A is the Jacobian and D is a regularizing
+// parameter. A brief proof that D^T D is the correct regularizer:
+//
+// Given a regularized least squares problem:
+//
+//   min  ||Ax - b||^2 + ||Dx||^2
+//    x
+//
+// First expand into matrix notation:
+//
+//   (Ax - b)^T (Ax - b) + xD^TDx
+//
+// Then multiply out to get:
+//
+//   = xA^TAx - 2b^T Ax + b^Tb + xD^TDx
+//
+// Take the derivative:
+//
+//   0 = 2A^TAx - 2A^T b + 2 D^TDx
+//   0 = A^TAx - A^T b + D^TDx
+//   0 = (A^TA + D^TD)x - A^T b
+//
+// Thus, the symmetric system we need to solve for CGNR is
+//
+//   Sx = z
+//
+// with S = A^TA + D^TD
+//  and z = A^T b
+//
+// Note: This class is not thread safe, since it uses some temporary storage.
+class CERES_NO_EXPORT CgnrLinearOperator final
+    : public ConjugateGradientsLinearOperator<Vector> {
+ public:
+  CgnrLinearOperator(const LinearOperator& A, const double* D)
+      : A_(A), D_(D), z_(Vector::Zero(A.num_rows())) {}
+
+  void RightMultiply(const Vector& x, Vector& y) final {
+    // z = Ax
+    // y = y + Atz
+    z_.setZero();
+    A_.RightMultiply(x, z_);
+    A_.LeftMultiply(z_, y);
+
+    // y = y + DtDx
+    if (D_ != nullptr) {
+      int n = A_.num_cols();
+      y.array() += ConstVectorRef(D_, n).array().square() * x.array();
+    }
+  }
+
+ private:
+  const LinearOperator& A_;
+  const double* D_;
+  Vector z_;
+};
+
 CgnrSolver::CgnrSolver(LinearSolver::Options options)
     : options_(std::move(options)) {
   if (options_.preconditioner_type != JACOBI &&
@@ -64,12 +125,6 @@
     const LinearSolver::PerSolveOptions& per_solve_options,
     double* x) {
   EventLogger event_logger("CgnrSolver::Solve");
-
-  // Form z = Atb.
-  Vector z(A->num_cols());
-  z.setZero();
-  A->LeftMultiply(b, z.data());
-
   if (!preconditioner_) {
     if (options_.preconditioner_type == JACOBI) {
       preconditioner_ = std::make_unique<BlockJacobiPreconditioner>(*A);
@@ -85,24 +140,37 @@
       preconditioner_options.context = options_.context;
       preconditioner_ =
           std::make_unique<SubsetPreconditioner>(preconditioner_options, *A);
+    } else {
+      preconditioner_ = std::make_unique<IdentityPreconditioner>(A->num_cols());
     }
   }
 
-  if (preconditioner_) {
-    preconditioner_->Update(*A, per_solve_options.D);
-  }
+  preconditioner_->Update(*A, per_solve_options.D);
 
-  LinearSolver::PerSolveOptions cg_per_solve_options = per_solve_options;
-  cg_per_solve_options.preconditioner = preconditioner_.get();
+  ConjugateGradientsSolverOptions cg_options;
+  cg_options.min_num_iterations = options_.min_num_iterations;
+  cg_options.max_num_iterations = options_.max_num_iterations;
+  cg_options.residual_reset_period = options_.residual_reset_period;
+  cg_options.q_tolerance = per_solve_options.q_tolerance;
+  cg_options.r_tolerance = per_solve_options.r_tolerance;
 
-  // Solve (AtA + DtD)x = z (= Atb).
-  VectorRef(x, A->num_cols()).setZero();
+  // lhs = AtA + DtD
   CgnrLinearOperator lhs(*A, per_solve_options.D);
+  // rhs = Atb.
+  Vector rhs(A->num_cols());
+  rhs.setZero();
+  A->LeftMultiply(b, rhs.data());
+
+  cg_solution_ = Vector::Zero(A->num_cols());
+  for (int i = 0; i < 4; ++i) {
+    scratch_[i] = Vector::Zero(A->num_cols());
+  }
   event_logger.AddEvent("Setup");
 
-  ConjugateGradientsSolver conjugate_gradient_solver(options_);
-  LinearSolver::Summary summary =
-      conjugate_gradient_solver.Solve(&lhs, z.data(), cg_per_solve_options, x);
+  LinearOperatorAdapter preconditioner(*preconditioner_);
+  auto summary = ConjugateGradientsSolver(
+      cg_options, lhs, rhs, preconditioner, scratch_, cg_solution_);
+  VectorRef(x, A->num_cols()) = cg_solution_;
   event_logger.AddEvent("Solve");
   return summary;
 }
diff --git a/internal/ceres/cgnr_solver.h b/internal/ceres/cgnr_solver.h
index 119f838..6982296 100644
--- a/internal/ceres/cgnr_solver.h
+++ b/internal/ceres/cgnr_solver.h
@@ -64,6 +64,8 @@
  private:
   const LinearSolver::Options options_;
   std::unique_ptr<Preconditioner> preconditioner_;
+  Vector cg_solution_;
+  Vector scratch_[4];
 };
 
 }  // namespace ceres::internal
diff --git a/internal/ceres/conjugate_gradients_solver.cc b/internal/ceres/conjugate_gradients_solver.cc
deleted file mode 100644
index 2a0c3ab..0000000
--- a/internal/ceres/conjugate_gradients_solver.cc
+++ /dev/null
@@ -1,251 +0,0 @@
-// Ceres Solver - A fast non-linear least squares minimizer
-// Copyright 2015 Google Inc. All rights reserved.
-// http://ceres-solver.org/
-//
-// Redistribution and use in source and binary forms, with or without
-// modification, are permitted provided that the following conditions are met:
-//
-// * Redistributions of source code must retain the above copyright notice,
-//   this list of conditions and the following disclaimer.
-// * Redistributions in binary form must reproduce the above copyright notice,
-//   this list of conditions and the following disclaimer in the documentation
-//   and/or other materials provided with the distribution.
-// * Neither the name of Google Inc. nor the names of its contributors may be
-//   used to endorse or promote products derived from this software without
-//   specific prior written permission.
-//
-// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
-// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
-// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
-// POSSIBILITY OF SUCH DAMAGE.
-//
-// Author: sameeragarwal@google.com (Sameer Agarwal)
-//
-// A preconditioned conjugate gradients solver
-// (ConjugateGradientsSolver) for positive semidefinite linear
-// systems.
-//
-// We have also augmented the termination criterion used by this
-// solver to support not just residual based termination but also
-// termination based on decrease in the value of the quadratic model
-// that CG optimizes.
-
-#include "ceres/conjugate_gradients_solver.h"
-
-#include <cmath>
-#include <cstddef>
-#include <utility>
-
-#include "ceres/internal/eigen.h"
-#include "ceres/linear_operator.h"
-#include "ceres/stringprintf.h"
-#include "ceres/types.h"
-#include "glog/logging.h"
-
-namespace ceres::internal {
-namespace {
-
-bool IsZeroOrInfinity(double x) { return ((x == 0.0) || std::isinf(x)); }
-
-}  // namespace
-
-ConjugateGradientsSolver::ConjugateGradientsSolver(
-    LinearSolver::Options options)
-    : options_(std::move(options)) {}
-
-LinearSolver::Summary ConjugateGradientsSolver::Solve(
-    LinearOperator* A,
-    const double* b,
-    const LinearSolver::PerSolveOptions& per_solve_options,
-    double* x) {
-  CHECK(A != nullptr);
-  CHECK(x != nullptr);
-  CHECK(b != nullptr);
-  CHECK_EQ(A->num_rows(), A->num_cols());
-
-  LinearSolver::Summary summary;
-  summary.termination_type = LinearSolverTerminationType::NO_CONVERGENCE;
-  summary.message = "Maximum number of iterations reached.";
-  summary.num_iterations = 0;
-
-  const int num_cols = A->num_cols();
-  VectorRef xref(x, num_cols);
-  ConstVectorRef bref(b, num_cols);
-
-  const double norm_b = bref.norm();
-  if (norm_b == 0.0) {
-    xref.setZero();
-    summary.termination_type = LinearSolverTerminationType::SUCCESS;
-    summary.message = "Convergence. |b| = 0.";
-    return summary;
-  }
-
-  Vector r(num_cols);
-  Vector p(num_cols);
-  Vector z(num_cols);
-  Vector tmp(num_cols);
-
-  const double tol_r = per_solve_options.r_tolerance * norm_b;
-
-  tmp.setZero();
-  A->RightMultiply(x, tmp.data());
-  r = bref - tmp;
-  double norm_r = r.norm();
-  if (options_.min_num_iterations == 0 && norm_r <= tol_r) {
-    summary.termination_type = LinearSolverTerminationType::SUCCESS;
-    summary.message =
-        StringPrintf("Convergence. |r| = %e <= %e.", norm_r, tol_r);
-    return summary;
-  }
-
-  double rho = 1.0;
-
-  // Initial value of the quadratic model Q = x'Ax - 2 * b'x.
-  double Q0 = -1.0 * xref.dot(bref + r);
-
-  for (summary.num_iterations = 1;; ++summary.num_iterations) {
-    // Apply preconditioner
-    if (per_solve_options.preconditioner != nullptr) {
-      z.setZero();
-      per_solve_options.preconditioner->RightMultiply(r.data(), z.data());
-    } else {
-      z = r;
-    }
-
-    double last_rho = rho;
-    rho = r.dot(z);
-    if (IsZeroOrInfinity(rho)) {
-      summary.termination_type = LinearSolverTerminationType::FAILURE;
-      summary.message = StringPrintf("Numerical failure. rho = r'z = %e.", rho);
-      break;
-    }
-
-    if (summary.num_iterations == 1) {
-      p = z;
-    } else {
-      double beta = rho / last_rho;
-      if (IsZeroOrInfinity(beta)) {
-        summary.termination_type = LinearSolverTerminationType::FAILURE;
-        summary.message = StringPrintf(
-            "Numerical failure. beta = rho_n / rho_{n-1} = %e, "
-            "rho_n = %e, rho_{n-1} = %e",
-            beta,
-            rho,
-            last_rho);
-        break;
-      }
-      p = z + beta * p;
-    }
-
-    Vector& q = z;
-    q.setZero();
-    A->RightMultiply(p.data(), q.data());
-    const double pq = p.dot(q);
-    if ((pq <= 0) || std::isinf(pq)) {
-      summary.termination_type = LinearSolverTerminationType::NO_CONVERGENCE;
-      summary.message = StringPrintf(
-          "Matrix is indefinite, no more progress can be made. "
-          "p'q = %e. |p| = %e, |q| = %e",
-          pq,
-          p.norm(),
-          q.norm());
-      break;
-    }
-
-    const double alpha = rho / pq;
-    if (std::isinf(alpha)) {
-      summary.termination_type = LinearSolverTerminationType::FAILURE;
-      summary.message = StringPrintf(
-          "Numerical failure. alpha = rho / pq = %e, rho = %e, pq = %e.",
-          alpha,
-          rho,
-          pq);
-      break;
-    }
-
-    xref = xref + alpha * p;
-
-    // Ideally we would just use the update r = r - alpha*q to keep
-    // track of the residual vector. However this estimate tends to
-    // drift over time due to round off errors. Thus every
-    // residual_reset_period iterations, we calculate the residual as
-    // r = b - Ax. We do not do this every iteration because this
-    // requires an additional matrix vector multiply which would
-    // double the complexity of the CG algorithm.
-    if (summary.num_iterations % options_.residual_reset_period == 0) {
-      tmp.setZero();
-      A->RightMultiply(x, tmp.data());
-      r = bref - tmp;
-    } else {
-      r = r - alpha * q;
-    }
-
-    // Quadratic model based termination.
-    //   Q1 = x'Ax - 2 * b' x.
-    const double Q1 = -1.0 * xref.dot(bref + r);
-
-    // For PSD matrices A, let
-    //
-    //   Q(x) = x'Ax - 2b'x
-    //
-    // be the cost of the quadratic function defined by A and b. Then,
-    // the solver terminates at iteration i if
-    //
-    //   i * (Q(x_i) - Q(x_i-1)) / Q(x_i) < q_tolerance.
-    //
-    // This termination criterion is more useful when using CG to
-    // solve the Newton step. This particular convergence test comes
-    // from Stephen Nash's work on truncated Newton
-    // methods. References:
-    //
-    //   1. Stephen G. Nash & Ariela Sofer, Assessing A Search
-    //   Direction Within A Truncated Newton Method, Operation
-    //   Research Letters 9(1990) 219-221.
-    //
-    //   2. Stephen G. Nash, A Survey of Truncated Newton Methods,
-    //   Journal of Computational and Applied Mathematics,
-    //   124(1-2), 45-59, 2000.
-    //
-    const double zeta = summary.num_iterations * (Q1 - Q0) / Q1;
-    if (zeta < per_solve_options.q_tolerance &&
-        summary.num_iterations >= options_.min_num_iterations) {
-      summary.termination_type = LinearSolverTerminationType::SUCCESS;
-      summary.message =
-          StringPrintf("Iteration: %d Convergence: zeta = %e < %e. |r| = %e",
-                       summary.num_iterations,
-                       zeta,
-                       per_solve_options.q_tolerance,
-                       r.norm());
-      break;
-    }
-    Q0 = Q1;
-
-    // Residual based termination.
-    norm_r = r.norm();
-    if (norm_r <= tol_r &&
-        summary.num_iterations >= options_.min_num_iterations) {
-      summary.termination_type = LinearSolverTerminationType::SUCCESS;
-      summary.message =
-          StringPrintf("Iteration: %d Convergence. |r| = %e <= %e.",
-                       summary.num_iterations,
-                       norm_r,
-                       tol_r);
-      break;
-    }
-
-    if (summary.num_iterations >= options_.max_num_iterations) {
-      break;
-    }
-  }
-
-  return summary;
-}
-
-}  // namespace ceres::internal
diff --git a/internal/ceres/conjugate_gradients_solver.h b/internal/ceres/conjugate_gradients_solver.h
index c5efac7..6254d2c 100644
--- a/internal/ceres/conjugate_gradients_solver.h
+++ b/internal/ceres/conjugate_gradients_solver.h
@@ -1,5 +1,5 @@
 // Ceres Solver - A fast non-linear least squares minimizer
-// Copyright 2015 Google Inc. All rights reserved.
+// Copyright 2022 Google Inc. All rights reserved.
 // http://ceres-solver.org/
 //
 // Redistribution and use in source and binary forms, with or without
@@ -34,39 +34,267 @@
 #ifndef CERES_INTERNAL_CONJUGATE_GRADIENTS_SOLVER_H_
 #define CERES_INTERNAL_CONJUGATE_GRADIENTS_SOLVER_H_
 
+#include <cmath>
+#include <cstddef>
+#include <utility>
+
+#include "ceres/eigen_vector_ops.h"
 #include "ceres/internal/disable_warnings.h"
+#include "ceres/internal/eigen.h"
 #include "ceres/internal/export.h"
+#include "ceres/linear_operator.h"
 #include "ceres/linear_solver.h"
+#include "ceres/stringprintf.h"
+#include "ceres/types.h"
+#include "glog/logging.h"
 
 namespace ceres::internal {
 
-class LinearOperator;
-
-// This class implements the now classical Conjugate Gradients
-// algorithm of Hestenes & Stiefel for solving positive semidefinite
-// linear systems. Optionally it can use a preconditioner also to
-// reduce the condition number of the linear system and improve the
-// convergence rate. Modern references for Conjugate Gradients are the
-// books by Yousef Saad and Trefethen & Bau. This implementation of CG
-// has been augmented with additional termination tests that are
-// needed for forcing early termination when used as part of an
-// inexact Newton solver.
-//
-// For more details see the documentation for
-// LinearSolver::PerSolveOptions::r_tolerance and
-// LinearSolver::PerSolveOptions::q_tolerance in linear_solver.h.
-class CERES_NO_EXPORT ConjugateGradientsSolver final : public LinearSolver {
+// Interface for the linear operator used by ConjugateGradientsSolver.
+template <typename DenseVectorType>
+class ConjugateGradientsLinearOperator {
  public:
-  explicit ConjugateGradientsSolver(LinearSolver::Options options);
-  Summary Solve(LinearOperator* A,
-                const double* b,
-                const LinearSolver::PerSolveOptions& per_solve_options,
-                double* x) final;
+  ~ConjugateGradientsLinearOperator() = default;
+  virtual void RightMultiply(const DenseVectorType& x, DenseVectorType& y) = 0;
+};
+
+// Adapter class that makes LinearOperator appear like an instance of
+// ConjugateGradientsLinearOperator.
+class LinearOperatorAdapter : public ConjugateGradientsLinearOperator<Vector> {
+ public:
+  LinearOperatorAdapter(LinearOperator& linear_operator)
+      : linear_operator_(linear_operator) {}
+
+  void RightMultiply(const Vector& x, Vector& y) final {
+    linear_operator_.RightMultiply(x, y);
+  }
 
  private:
-  const LinearSolver::Options options_;
+  LinearOperator& linear_operator_;
 };
 
+// Options to control the ConjugateGradientsSolver. For detailed documentation
+// for each of these options see linear_solver.h
+struct ConjugateGradientsSolverOptions {
+  int min_num_iterations = 1;
+  int max_num_iterations = 1;
+  int residual_reset_period = 10;
+  double r_tolerance = 0.0;
+  double q_tolerance = 0.0;
+};
+
+// This function implements the now classical Conjugate Gradients algorithm of
+// Hestenes & Stiefel for solving positive semidefinite linear systems.
+// Optionally it can use a preconditioner also to reduce the condition number of
+// the linear system and improve the convergence rate. Modern references for
+// Conjugate Gradients are the books by Yousef Saad and Trefethen & Bau. This
+// implementation of CG has been augmented with additional termination tests
+// that are needed for forcing early termination when used as part of an inexact
+// Newton solver.
+//
+// This implementation is templated over DenseVectorType and then in turn on
+// ConjugateGradientsLinearOperator, which allows us to write an abstract
+// implementaion of the Conjugate Gradients algorithm without worrying about how
+// these objects are implemented or where they are stored. In particular it
+// allows us to have a single implementation that works on CPU and GPU based
+// matrices and vectors.
+//
+// scratch must contain four DenseVector objects of the same size as rhs and
+// solution. By asking the user for scratch space, we guarantee that we will not
+// perform any allocations inside this function.
+template <typename DenseVectorType>
+LinearSolver::Summary ConjugateGradientsSolver(
+    const ConjugateGradientsSolverOptions options,
+    ConjugateGradientsLinearOperator<DenseVectorType>& lhs,
+    const DenseVectorType& rhs,
+    ConjugateGradientsLinearOperator<DenseVectorType>& preconditioner,
+    DenseVectorType scratch[4],
+    DenseVectorType& solution) {
+  auto IsZeroOrInfinity = [](double x) {
+    return ((x == 0.0) || std::isinf(x));
+  };
+
+  DenseVectorType& p = scratch[0];
+  DenseVectorType& r = scratch[1];
+  DenseVectorType& z = scratch[2];
+  DenseVectorType& tmp = scratch[3];
+
+  LinearSolver::Summary summary;
+  summary.termination_type = LinearSolverTerminationType::NO_CONVERGENCE;
+  summary.message = "Maximum number of iterations reached.";
+  summary.num_iterations = 0;
+
+  const double norm_b = Norm(rhs);
+  if (norm_b == 0.0) {
+    SetZero(solution);
+    summary.termination_type = LinearSolverTerminationType::SUCCESS;
+    summary.message = "Convergence. |b| = 0.";
+    return summary;
+  }
+
+  const double tol_r = options.r_tolerance * norm_b;
+
+  SetZero(tmp);
+  lhs.RightMultiply(solution, tmp);
+
+  // r = rhs - tmp
+  Axpby(1.0, rhs, -1.0, tmp, r);
+
+  double norm_r = Norm(r);
+  if (options.min_num_iterations == 0 && norm_r <= tol_r) {
+    summary.termination_type = LinearSolverTerminationType::SUCCESS;
+    summary.message =
+        StringPrintf("Convergence. |r| = %e <= %e.", norm_r, tol_r);
+    return summary;
+  }
+
+  double rho = 1.0;
+
+  // Initial value of the quadratic model Q = x'Ax - 2 * b'x.
+  // double Q0 = -1.0 * solution.dot(rhs + r);
+  Axpby(1.0, rhs, 1.0, r, tmp);
+  double Q0 = -Dot(solution, tmp);
+
+  for (summary.num_iterations = 1;; ++summary.num_iterations) {
+    SetZero(z);
+    preconditioner.RightMultiply(r, z);
+
+    double last_rho = rho;
+    // rho = r.dot(z);
+    rho = Dot(r, z);
+    if (IsZeroOrInfinity(rho)) {
+      summary.termination_type = LinearSolverTerminationType::FAILURE;
+      summary.message = StringPrintf("Numerical failure. rho = r'z = %e.", rho);
+      break;
+    }
+
+    if (summary.num_iterations == 1) {
+      Copy(z, p);
+    } else {
+      double beta = rho / last_rho;
+      if (IsZeroOrInfinity(beta)) {
+        summary.termination_type = LinearSolverTerminationType::FAILURE;
+        summary.message = StringPrintf(
+            "Numerical failure. beta = rho_n / rho_{n-1} = %e, "
+            "rho_n = %e, rho_{n-1} = %e",
+            beta,
+            rho,
+            last_rho);
+        break;
+      }
+      // p = z + beta * p;
+      Axpby(1.0, z, beta, p, p);
+    }
+
+    DenseVectorType& q = z;
+    SetZero(q);
+    lhs.RightMultiply(p, q);
+    const double pq = Dot(p, q);
+    if ((pq <= 0) || std::isinf(pq)) {
+      summary.termination_type = LinearSolverTerminationType::NO_CONVERGENCE;
+      summary.message = StringPrintf(
+          "Matrix is indefinite, no more progress can be made. "
+          "p'q = %e. |p| = %e, |q| = %e",
+          pq,
+          Norm(p),
+          Norm(q));
+      break;
+    }
+
+    const double alpha = rho / pq;
+    if (std::isinf(alpha)) {
+      summary.termination_type = LinearSolverTerminationType::FAILURE;
+      summary.message = StringPrintf(
+          "Numerical failure. alpha = rho / pq = %e, rho = %e, pq = %e.",
+          alpha,
+          rho,
+          pq);
+      break;
+    }
+
+    // solution = solution + alpha * p;
+    Axpby(1.0, solution, alpha, p, solution);
+
+    // Ideally we would just use the update r = r - alpha*q to keep
+    // track of the residual vector. However this estimate tends to
+    // drift over time due to round off errors. Thus every
+    // residual_reset_period iterations, we calculate the residual as
+    // r = b - Ax. We do not do this every iteration because this
+    // requires an additional matrix vector multiply which would
+    // double the complexity of the CG algorithm.
+    if (summary.num_iterations % options.residual_reset_period == 0) {
+      SetZero(tmp);
+      lhs.RightMultiply(solution, tmp);
+      Axpby(1.0, rhs, -1.0, tmp, r);
+      // r = rhs - tmp;
+    } else {
+      Axpby(1.0, r, -alpha, q, r);
+      // r = r - alpha * q;
+    }
+
+    // Quadratic model based termination.
+    //   Q1 = x'Ax - 2 * b' x.
+    // const double Q1 = -1.0 * solution.dot(rhs + r);
+    Axpby(1.0, rhs, 1.0, r, tmp);
+    const double Q1 = -Dot(solution, tmp);
+
+    // For PSD matrices A, let
+    //
+    //   Q(x) = x'Ax - 2b'x
+    //
+    // be the cost of the quadratic function defined by A and b. Then,
+    // the solver terminates at iteration i if
+    //
+    //   i * (Q(x_i) - Q(x_i-1)) / Q(x_i) < q_tolerance.
+    //
+    // This termination criterion is more useful when using CG to
+    // solve the Newton step. This particular convergence test comes
+    // from Stephen Nash's work on truncated Newton
+    // methods. References:
+    //
+    //   1. Stephen G. Nash & Ariela Sofer, Assessing A Search
+    //   Direction Within A Truncated Newton Method, Operation
+    //   Research Letters 9(1990) 219-221.
+    //
+    //   2. Stephen G. Nash, A Survey of Truncated Newton Methods,
+    //   Journal of Computational and Applied Mathematics,
+    //   124(1-2), 45-59, 2000.
+    //
+    const double zeta = summary.num_iterations * (Q1 - Q0) / Q1;
+    if (zeta < options.q_tolerance &&
+        summary.num_iterations >= options.min_num_iterations) {
+      summary.termination_type = LinearSolverTerminationType::SUCCESS;
+      summary.message =
+          StringPrintf("Iteration: %d Convergence: zeta = %e < %e. |r| = %e",
+                       summary.num_iterations,
+                       zeta,
+                       options.q_tolerance,
+                       Norm(r));
+      break;
+    }
+    Q0 = Q1;
+
+    // Residual based termination.
+    norm_r = Norm(r);
+    if (norm_r <= tol_r &&
+        summary.num_iterations >= options.min_num_iterations) {
+      summary.termination_type = LinearSolverTerminationType::SUCCESS;
+      summary.message =
+          StringPrintf("Iteration: %d Convergence. |r| = %e <= %e.",
+                       summary.num_iterations,
+                       norm_r,
+                       tol_r);
+      break;
+    }
+
+    if (summary.num_iterations >= options.max_num_iterations) {
+      break;
+    }
+  }
+
+  return summary;
+}
+
 }  // namespace ceres::internal
 
 #include "ceres/internal/reenable_warnings.h"
diff --git a/internal/ceres/conjugate_gradients_solver_test.cc b/internal/ceres/conjugate_gradients_solver_test.cc
index b27fee0..a01dfc9 100644
--- a/internal/ceres/conjugate_gradients_solver_test.cc
+++ b/internal/ceres/conjugate_gradients_solver_test.cc
@@ -37,6 +37,7 @@
 
 #include "ceres/internal/eigen.h"
 #include "ceres/linear_solver.h"
+#include "ceres/preconditioner.h"
 #include "ceres/triplet_sparse_matrix.h"
 #include "ceres/types.h"
 #include "gtest/gtest.h"
@@ -58,15 +59,24 @@
   x(1) = 1;
   x(2) = 1;
 
-  LinearSolver::Options options;
-  options.max_num_iterations = 10;
+  ConjugateGradientsSolverOptions cg_options;
+  cg_options.min_num_iterations = 1;
+  cg_options.max_num_iterations = 10;
+  cg_options.residual_reset_period = 20;
+  cg_options.q_tolerance = 0.0;
+  cg_options.r_tolerance = 1e-9;
 
-  LinearSolver::PerSolveOptions per_solve_options;
-  per_solve_options.r_tolerance = 1e-9;
+  Vector scratch[4];
+  for (int i = 0; i < 4; ++i) {
+    scratch[i] = Vector::Zero(A->num_cols());
+  }
 
-  ConjugateGradientsSolver solver(options);
-  LinearSolver::Summary summary =
-      solver.Solve(A.get(), b.data(), per_solve_options, x.data());
+  IdentityPreconditioner identity(A->num_cols());
+  LinearOperatorAdapter lhs(*A);
+  LinearOperatorAdapter preconditioner(identity);
+
+  auto summary =
+      ConjugateGradientsSolver(cg_options, lhs, b, preconditioner, scratch, x);
 
   EXPECT_EQ(summary.termination_type, LinearSolverTerminationType::SUCCESS);
   ASSERT_EQ(summary.num_iterations, 1);
@@ -114,15 +124,24 @@
   x(1) = 1;
   x(2) = 1;
 
-  LinearSolver::Options options;
-  options.max_num_iterations = 10;
+  ConjugateGradientsSolverOptions cg_options;
+  cg_options.min_num_iterations = 1;
+  cg_options.max_num_iterations = 10;
+  cg_options.residual_reset_period = 20;
+  cg_options.q_tolerance = 0.0;
+  cg_options.r_tolerance = 1e-9;
 
-  LinearSolver::PerSolveOptions per_solve_options;
-  per_solve_options.r_tolerance = 1e-9;
+  Vector scratch[4];
+  for (int i = 0; i < 4; ++i) {
+    scratch[i] = Vector::Zero(A->num_cols());
+  }
 
-  ConjugateGradientsSolver solver(options);
-  LinearSolver::Summary summary =
-      solver.Solve(A.get(), b.data(), per_solve_options, x.data());
+  IdentityPreconditioner identity(A->num_cols());
+  LinearOperatorAdapter lhs(*A);
+  LinearOperatorAdapter preconditioner(identity);
+
+  auto summary =
+      ConjugateGradientsSolver(cg_options, lhs, b, preconditioner, scratch, x);
 
   EXPECT_EQ(summary.termination_type, LinearSolverTerminationType::SUCCESS);
 
diff --git a/internal/ceres/eigen_vector_ops.h b/internal/ceres/eigen_vector_ops.h
new file mode 100644
index 0000000..5bcf49d
--- /dev/null
+++ b/internal/ceres/eigen_vector_ops.h
@@ -0,0 +1,52 @@
+// Ceres Solver - A fast non-linear least squares minimizer
+// Copyright 2022 Google Inc. All rights reserved.
+// http://ceres-solver.org/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// * Redistributions of source code must retain the above copyright notice,
+//   this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright notice,
+//   this list of conditions and the following disclaimer in the documentation
+//   and/or other materials provided with the distribution.
+// * Neither the name of Google Inc. nor the names of its contributors may be
+//   used to endorse or promote products derived from this software without
+//   specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//
+// Author: sameeragarwal@google.com (Sameer Agarwal)
+
+#ifndef CERES_INTERNAL_EIGEN_VECTOR_OPS_H_
+#define CERES_INTERNAL_EIGEN_VECTOR_OPS_H_
+
+#include "ceres/internal/eigen.h"
+
+namespace ceres::internal {
+
+// Blas1 operations on Eigen vectors. These functions are needed as an
+// abstraction layer so that we can use different versions of a vector style
+// object in the conjugate gradients linear solver.
+inline double Norm(const Vector& x) { return x.norm(); }
+inline void SetZero(Vector& x) { x.setZero(); }
+inline void Axpby(
+    double a, const Vector& x, double b, const Vector& y, Vector& z) {
+  z = a * x + b * y;
+}
+inline double Dot(const Vector& x, const Vector& y) { return x.dot(y); }
+inline void Copy(const Vector& from, Vector& to) { to = from; }
+
+}  // namespace ceres::internal
+
+#endif  // CERES_INTERNAL_EIGEN_VECTOR_OPS_H_
diff --git a/internal/ceres/iterative_schur_complement_solver.cc b/internal/ceres/iterative_schur_complement_solver.cc
index 9cacf0b..0d0daaa 100644
--- a/internal/ceres/iterative_schur_complement_solver.cc
+++ b/internal/ceres/iterative_schur_complement_solver.cc
@@ -94,15 +94,6 @@
   reduced_linear_system_solution_.resize(schur_complement_->num_rows());
   reduced_linear_system_solution_.setZero();
 
-  LinearSolver::Options cg_options;
-  cg_options.min_num_iterations = options_.min_num_iterations;
-  cg_options.max_num_iterations = options_.max_num_iterations;
-  ConjugateGradientsSolver cg_solver(cg_options);
-
-  LinearSolver::PerSolveOptions cg_per_solve_options;
-  cg_per_solve_options.r_tolerance = per_solve_options.r_tolerance;
-  cg_per_solve_options.q_tolerance = per_solve_options.q_tolerance;
-
   CreatePreconditioner(A);
   if (preconditioner_.get() != nullptr) {
     if (!preconditioner_->Update(*A, per_solve_options.D)) {
@@ -112,16 +103,33 @@
       summary.message = "Preconditioner update failed.";
       return summary;
     }
+  }
 
-    cg_per_solve_options.preconditioner = preconditioner_.get();
+  ConjugateGradientsSolverOptions cg_options;
+  cg_options.min_num_iterations = options_.min_num_iterations;
+  cg_options.max_num_iterations = options_.max_num_iterations;
+  cg_options.residual_reset_period = options_.residual_reset_period;
+  cg_options.q_tolerance = per_solve_options.q_tolerance;
+  cg_options.r_tolerance = per_solve_options.r_tolerance;
+
+  LinearOperatorAdapter lhs(*schur_complement_);
+  LinearOperatorAdapter preconditioner(*preconditioner_);
+
+  Vector scratch[4];
+  for (int i = 0; i < 4; ++i) {
+    scratch[i] = Vector::Zero(schur_complement_->num_cols());
   }
 
   event_logger.AddEvent("Setup");
+
   LinearSolver::Summary summary =
-      cg_solver.Solve(schur_complement_.get(),
-                      schur_complement_->rhs().data(),
-                      cg_per_solve_options,
-                      reduced_linear_system_solution_.data());
+      ConjugateGradientsSolver(cg_options,
+                               lhs,
+                               schur_complement_->rhs(),
+                               preconditioner,
+                               scratch,
+                               reduced_linear_system_solution_);
+
   if (summary.termination_type != LinearSolverTerminationType::FAILURE &&
       summary.termination_type != LinearSolverTerminationType::FATAL_ERROR) {
     schur_complement_->BackSubstitute(reduced_linear_system_solution_.data(),
@@ -133,8 +141,7 @@
 
 void IterativeSchurComplementSolver::CreatePreconditioner(
     BlockSparseMatrix* A) {
-  if (options_.preconditioner_type == IDENTITY ||
-      preconditioner_.get() != nullptr) {
+  if (preconditioner_.get() != nullptr) {
     return;
   }
 
@@ -153,6 +160,10 @@
   preconditioner_options.context = options_.context;
 
   switch (options_.preconditioner_type) {
+    case IDENTITY:
+      preconditioner_ = std::make_unique<IdentityPreconditioner>(
+          schur_complement_->num_cols());
+      break;
     case JACOBI:
       preconditioner_ = std::make_unique<SparseMatrixPreconditionerWrapper>(
           schur_complement_->block_diagonal_FtF_inverse());
diff --git a/internal/ceres/linear_operator.h b/internal/ceres/linear_operator.h
index f8f2208..cab87e7 100644
--- a/internal/ceres/linear_operator.h
+++ b/internal/ceres/linear_operator.h
@@ -33,6 +33,7 @@
 #ifndef CERES_INTERNAL_LINEAR_OPERATOR_H_
 #define CERES_INTERNAL_LINEAR_OPERATOR_H_
 
+#include "ceres/internal/eigen.h"
 #include "ceres/internal/export.h"
 #include "ceres/types.h"
 
@@ -49,6 +50,14 @@
   // y = y + A'x;
   virtual void LeftMultiply(const double* x, double* y) const = 0;
 
+  virtual void RightMultiply(const Vector& x, Vector& y) const {
+    RightMultiply(x.data(), y.data());
+  }
+
+  virtual void LeftMultiply(const Vector& x, Vector& y) const {
+    LeftMultiply(x.data(), y.data());
+  }
+
   virtual int num_rows() const = 0;
   virtual int num_cols() const = 0;
 };
diff --git a/internal/ceres/preconditioner.h b/internal/ceres/preconditioner.h
index 68b575f..75613fb 100644
--- a/internal/ceres/preconditioner.h
+++ b/internal/ceres/preconditioner.h
@@ -142,6 +142,22 @@
   int num_cols() const override { return num_rows(); }
 };
 
+class CERES_NO_EXPORT IdentityPreconditioner : public Preconditioner {
+ public:
+  IdentityPreconditioner(int num_rows) : num_rows_(num_rows) {}
+
+  bool Update(const LinearOperator& A, const double* D) final { return true; }
+
+  void RightMultiply(const double* x, double* y) const final {
+    VectorRef(y, num_rows_) += ConstVectorRef(x, num_rows_);
+  }
+
+  int num_rows() const final { return num_rows_; }
+
+ private:
+  int num_rows_ = -1;
+};
+
 // This templated subclass of Preconditioner serves as a base class for
 // other preconditioners that depend on the particular matrix layout of
 // the underlying linear operator.
diff --git a/internal/ceres/schur_complement_solver.cc b/internal/ceres/schur_complement_solver.cc
index 28e6a5d..da52b78 100644
--- a/internal/ceres/schur_complement_solver.cc
+++ b/internal/ceres/schur_complement_solver.cc
@@ -61,48 +61,33 @@
 
 namespace {
 
-class BlockRandomAccessSparseMatrixAdapter final : public LinearOperator {
+class BlockRandomAccessSparseMatrixAdapter
+    : public ConjugateGradientsLinearOperator<Vector> {
  public:
   explicit BlockRandomAccessSparseMatrixAdapter(
       const BlockRandomAccessSparseMatrix& m)
       : m_(m) {}
 
-  // y = y + Ax;
-  void RightMultiply(const double* x, double* y) const final {
-    m_.SymmetricRightMultiply(x, y);
+  void RightMultiply(const Vector& x, Vector& y) final {
+    m_.SymmetricRightMultiply(x.data(), y.data());
   }
 
-  // y = y + A'x;
-  void LeftMultiply(const double* x, double* y) const final {
-    m_.SymmetricRightMultiply(x, y);
-  }
-
-  int num_rows() const final { return m_.num_rows(); }
-  int num_cols() const final { return m_.num_rows(); }
-
  private:
   const BlockRandomAccessSparseMatrix& m_;
 };
 
-class BlockRandomAccessDiagonalMatrixAdapter final : public LinearOperator {
+class BlockRandomAccessDiagonalMatrixAdapter final
+    : public ConjugateGradientsLinearOperator<Vector> {
  public:
   explicit BlockRandomAccessDiagonalMatrixAdapter(
       const BlockRandomAccessDiagonalMatrix& m)
       : m_(m) {}
 
   // y = y + Ax;
-  void RightMultiply(const double* x, double* y) const final {
-    m_.RightMultiply(x, y);
+  void RightMultiply(const Vector& x, Vector& y) final {
+    m_.RightMultiply(x.data(), y.data());
   }
 
-  // y = y + A'x;
-  void LeftMultiply(const double* x, double* y) const final {
-    m_.RightMultiply(x, y);
-  }
-
-  int num_rows() const final { return m_.num_rows(); }
-  int num_cols() const final { return m_.num_rows(); }
-
  private:
   const BlockRandomAccessDiagonalMatrix& m_;
 };
@@ -160,7 +145,7 @@
                          b,
                          per_solve_options.D,
                          lhs_.get(),
-                         rhs_.get());
+                         rhs_.data());
   event_logger.AddEvent("Eliminate");
 
   double* reduced_solution = x + A->num_cols() - lhs_->num_cols();
@@ -196,7 +181,7 @@
   }
 
   set_lhs(std::make_unique<BlockRandomAccessDenseMatrix>(blocks));
-  set_rhs(std::make_unique<double[]>(lhs()->num_rows()));
+  ResizeRhs(lhs()->num_rows());
 }
 
 // Solve the system Sx = r, assuming that the matrix S is stored in a
@@ -220,7 +205,7 @@
 
   summary.num_iterations = 1;
   summary.termination_type = cholesky_->FactorAndSolve(
-      num_rows, m->mutable_values(), rhs(), solution, &summary.message);
+      num_rows, m->mutable_values(), rhs().data(), solution, &summary.message);
   return summary;
 }
 
@@ -303,7 +288,7 @@
 
   set_lhs(
       std::make_unique<BlockRandomAccessSparseMatrix>(blocks_, block_pairs));
-  set_rhs(std::make_unique<double[]>(lhs()->num_rows()));
+  ResizeRhs(lhs()->num_rows());
 }
 
 LinearSolver::Summary SparseSchurComplementSolver::SolveReducedLinearSystem(
@@ -343,7 +328,7 @@
 
   summary.num_iterations = 1;
   summary.termination_type = sparse_cholesky_->FactorAndSolve(
-      lhs.get(), rhs(), solution, &summary.message);
+      lhs.get(), rhs().data(), solution, &summary.message);
   return summary;
 }
 
@@ -396,24 +381,28 @@
 
   VectorRef(solution, num_rows).setZero();
 
-  std::unique_ptr<LinearOperator> lhs_adapter =
-      std::make_unique<BlockRandomAccessSparseMatrixAdapter>(*sc);
-  std::unique_ptr<LinearOperator> preconditioner_adapter =
+  auto lhs = std::make_unique<BlockRandomAccessSparseMatrixAdapter>(*sc);
+  auto preconditioner =
       std::make_unique<BlockRandomAccessDiagonalMatrixAdapter>(
           *preconditioner_);
 
-  LinearSolver::Options cg_options;
+  ConjugateGradientsSolverOptions cg_options;
   cg_options.min_num_iterations = options().min_num_iterations;
   cg_options.max_num_iterations = options().max_num_iterations;
-  ConjugateGradientsSolver cg_solver(cg_options);
+  cg_options.residual_reset_period = options().residual_reset_period;
+  cg_options.q_tolerance = per_solve_options.q_tolerance;
+  cg_options.r_tolerance = per_solve_options.r_tolerance;
 
-  LinearSolver::PerSolveOptions cg_per_solve_options;
-  cg_per_solve_options.r_tolerance = per_solve_options.r_tolerance;
-  cg_per_solve_options.q_tolerance = per_solve_options.q_tolerance;
-  cg_per_solve_options.preconditioner = preconditioner_adapter.get();
+  cg_solution_ = Vector::Zero(sc->num_rows());
+  Vector scratch[4];
+  for (int i = 0; i < 4; ++i) {
+    scratch_[i] = Vector::Zero(sc->num_rows());
+  }
 
-  return cg_solver.Solve(
-      lhs_adapter.get(), rhs(), cg_per_solve_options, solution);
+  auto summary = ConjugateGradientsSolver<Vector>(
+      cg_options, *lhs, rhs(), *preconditioner, scratch_, cg_solution_);
+  VectorRef(solution, sc->num_rows()) = cg_solution_;
+  return summary;
 }
 
 }  // namespace ceres::internal
diff --git a/internal/ceres/schur_complement_solver.h b/internal/ceres/schur_complement_solver.h
index cbccc75..4fde69c 100644
--- a/internal/ceres/schur_complement_solver.h
+++ b/internal/ceres/schur_complement_solver.h
@@ -130,9 +130,8 @@
   }
   const BlockRandomAccessMatrix* lhs() const { return lhs_.get(); }
   BlockRandomAccessMatrix* mutable_lhs() { return lhs_.get(); }
-
-  void set_rhs(std::unique_ptr<double[]> rhs) { rhs_ = std::move(rhs); }
-  const double* rhs() const { return rhs_.get(); }
+  void ResizeRhs(int n) { rhs_.resize(n); }
+  const Vector& rhs() const { return rhs_; }
 
  private:
   virtual void InitStorage(const CompressedRowBlockStructure* bs) = 0;
@@ -144,7 +143,7 @@
 
   std::unique_ptr<SchurEliminatorBase> eliminator_;
   std::unique_ptr<BlockRandomAccessMatrix> lhs_;
-  std::unique_ptr<double[]> rhs_;
+  Vector rhs_;
 };
 
 // Dense Cholesky factorization based solver.
@@ -188,6 +187,8 @@
   std::vector<int> blocks_;
   std::unique_ptr<SparseCholesky> sparse_cholesky_;
   std::unique_ptr<BlockRandomAccessDiagonalMatrix> preconditioner_;
+  Vector cg_solution_;
+  Vector scratch_[4];
 };
 
 }  // namespace ceres::internal