Enable usage of schur power series expansion preconditioner.

Add an option to use schur power series expansion for initialization
of pcg solution in ITERATIVE_SCHUR linear solver.

Change-Id: Ifb8bce02bc5f5ceebc74f961eefd3f6dd2ffab4a
diff --git a/include/ceres/solver.h b/include/ceres/solver.h
index 2a1dd18..27198ac 100644
--- a/include/ceres/solver.h
+++ b/include/ceres/solver.h
@@ -691,6 +691,23 @@
     // as its termination type.
     int max_linear_solver_iterations = 500;
 
+    // Maximum number of iterations performed by SCHUR_POWER_SERIES_EXPANSION.
+    // This value controls the maximum number of iterations whether it is used
+    // as a preconditioner or just to initialize the solution for
+    // ITERATIVE_SCHUR.
+    int max_num_spse_iterations = 5;
+
+    // Use SCHUR_POWER_SERIES_EXPANSION to initialize the solution for
+    // ITERATIVE_SCHUR. This option can be set true regardless of what
+    // preconditioner is being used.
+    bool use_power_series_expansion_initialization = false;
+
+    // When use_power_series_expansion_initialization is true, this parameter
+    // along with max_num_spse_iterations controls the number of
+    // SCHUR_POWER_SERIES_EXPANSION iterations performed for initialization. It
+    // is not used to control the preconditioner.
+    double spse_tolerance = 0.1;
+
     // Forcing sequence parameter. The truncated Newton solver uses
     // this number to control the relative accuracy with which the
     // Newton step is computed.
diff --git a/include/ceres/types.h b/include/ceres/types.h
index 52e1e7b..1b07ab3 100644
--- a/include/ceres/types.h
+++ b/include/ceres/types.h
@@ -107,9 +107,6 @@
 
   // Use power series expansion to approximate the inversion of Schur complement
   // as a preconditioner.
-  // WARNING! Application of this preconditioner currently is not integrated
-  // into linear solvers, so failure to use it via public API is expected
-  // behaviour.
   SCHUR_POWER_SERIES_EXPANSION,
 
   // Visibility clustering based preconditioners.
diff --git a/internal/ceres/implicit_schur_complement.cc b/internal/ceres/implicit_schur_complement.cc
index 92631de..8f7a2a0 100644
--- a/internal/ceres/implicit_schur_complement.cc
+++ b/internal/ceres/implicit_schur_complement.cc
@@ -57,6 +57,7 @@
   b_ = b;
 
   compute_ftf_inverse_ =
+      options_.use_power_series_expansion_initialization ||
       options_.preconditioner_type == JACOBI ||
       options_.preconditioner_type == SCHUR_POWER_SERIES_EXPANSION;
 
diff --git a/internal/ceres/iterative_schur_complement_solver.cc b/internal/ceres/iterative_schur_complement_solver.cc
index 5f56be2..8be33bc 100644
--- a/internal/ceres/iterative_schur_complement_solver.cc
+++ b/internal/ceres/iterative_schur_complement_solver.cc
@@ -43,6 +43,7 @@
 #include "ceres/implicit_schur_complement.h"
 #include "ceres/internal/eigen.h"
 #include "ceres/linear_solver.h"
+#include "ceres/power_series_expansion_preconditioner.h"
 #include "ceres/preconditioner.h"
 #include "ceres/schur_jacobi_preconditioner.h"
 #include "ceres/triplet_sparse_matrix.h"
@@ -90,9 +91,18 @@
     return summary;
   }
 
-  // Initialize the solution to the Schur complement system to zero.
+  // Initialize the solution to the Schur complement system.
   reduced_linear_system_solution_.resize(schur_complement_->num_rows());
   reduced_linear_system_solution_.setZero();
+  if (options_.use_power_series_expansion_initialization) {
+    PowerSeriesExpansionPreconditioner pse_solver(
+        schur_complement_.get(),
+        options_.max_num_spse_iterations,
+        options_.spse_tolerance);
+    pse_solver.RightMultiplyAndAccumulate(
+        schur_complement_->rhs().data(),
+        reduced_linear_system_solution_.data());
+  }
 
   CreatePreconditioner(A);
   if (preconditioner_ != nullptr) {
@@ -168,6 +178,12 @@
       preconditioner_ = std::make_unique<SparseMatrixPreconditionerWrapper>(
           schur_complement_->block_diagonal_FtF_inverse());
       break;
+    case SCHUR_POWER_SERIES_EXPANSION:
+      // Ignoring the value of spse_tolerance to ensure preconditioner stays
+      // fixed during the iterations of cg.
+      preconditioner_ = std::make_unique<PowerSeriesExpansionPreconditioner>(
+          schur_complement_.get(), options_.max_num_spse_iterations, 0);
+      break;
     case SCHUR_JACOBI:
       preconditioner_ = std::make_unique<SchurJacobiPreconditioner>(
           *A->block_structure(), preconditioner_options);
diff --git a/internal/ceres/iterative_schur_complement_solver_test.cc b/internal/ceres/iterative_schur_complement_solver_test.cc
index 80e388a..a98b295 100644
--- a/internal/ceres/iterative_schur_complement_solver_test.cc
+++ b/internal/ceres/iterative_schur_complement_solver_test.cc
@@ -74,7 +74,9 @@
     num_eliminate_blocks_ = problem->num_eliminate_blocks;
   }
 
-  AssertionResult TestSolver(double* D) {
+  AssertionResult TestSolver(double* D,
+                             PreconditionerType preconditioner_type,
+                             bool use_power_series_expansion_initialization) {
     TripletSparseMatrix triplet_A(
         A_->num_rows(), A_->num_cols(), A_->num_nonzeros());
     A_->ToTripletSparseMatrix(&triplet_A);
@@ -95,7 +97,10 @@
     options.elimination_groups.push_back(num_eliminate_blocks_);
     options.elimination_groups.push_back(0);
     options.max_num_iterations = num_cols_;
-    options.preconditioner_type = SCHUR_JACOBI;
+    options.max_num_spse_iterations = 1;
+    options.use_power_series_expansion_initialization =
+        use_power_series_expansion_initialization;
+    options.preconditioner_type = preconditioner_type;
     IterativeSchurComplementSolver isc(options);
 
     Vector isc_sol(num_cols_);
@@ -119,16 +124,30 @@
   std::unique_ptr<double[]> D_;
 };
 
-TEST_F(IterativeSchurComplementSolverTest, NormalProblem) {
+TEST_F(IterativeSchurComplementSolverTest, NormalProblemSchurJacobi) {
   SetUpProblem(2);
-  EXPECT_TRUE(TestSolver(nullptr));
-  EXPECT_TRUE(TestSolver(D_.get()));
+  EXPECT_TRUE(TestSolver(nullptr, SCHUR_JACOBI, false));
+  EXPECT_TRUE(TestSolver(D_.get(), SCHUR_JACOBI, false));
+}
+
+TEST_F(IterativeSchurComplementSolverTest,
+       NormalProblemSchurJacobiWithPowerSeriesExpansionInitialization) {
+  SetUpProblem(2);
+  EXPECT_TRUE(TestSolver(nullptr, SCHUR_JACOBI, true));
+  EXPECT_TRUE(TestSolver(D_.get(), SCHUR_JACOBI, true));
+}
+
+TEST_F(IterativeSchurComplementSolverTest,
+       NormalProblemPowerSeriesExpansionPreconditioner) {
+  SetUpProblem(5);
+  EXPECT_TRUE(TestSolver(nullptr, SCHUR_POWER_SERIES_EXPANSION, false));
+  EXPECT_TRUE(TestSolver(D_.get(), SCHUR_POWER_SERIES_EXPANSION, false));
 }
 
 TEST_F(IterativeSchurComplementSolverTest, ProblemWithNoFBlocks) {
   SetUpProblem(3);
-  EXPECT_TRUE(TestSolver(nullptr));
-  EXPECT_TRUE(TestSolver(D_.get()));
+  EXPECT_TRUE(TestSolver(nullptr, SCHUR_JACOBI, false));
+  EXPECT_TRUE(TestSolver(D_.get(), SCHUR_JACOBI, false));
 }
 
 }  // namespace internal
diff --git a/internal/ceres/linear_solver.h b/internal/ceres/linear_solver.h
index 4916e41..69e984f 100644
--- a/internal/ceres/linear_solver.h
+++ b/internal/ceres/linear_solver.h
@@ -165,6 +165,23 @@
     int min_num_iterations = 1;
     int max_num_iterations = 1;
 
+    // Maximum number of iterations performed by SCHUR_POWER_SERIES_EXPANSION.
+    // This value controls the maximum number of iterations whether it is used
+    // as a preconditioner or just to initialize the solution for
+    // ITERATIVE_SCHUR.
+    int max_num_spse_iterations = 5;
+
+    // Use SCHUR_POWER_SERIES_EXPANSION to initialize the solution for
+    // ITERATIVE_SCHUR. This option can be set true regardless of what
+    // preconditioner is being used.
+    bool use_power_series_expansion_initialization = false;
+
+    // When use_power_series_expansion_initialization is true, this parameter
+    // along with max_num_spse_iterations controls the number of
+    // SCHUR_POWER_SERIES_EXPANSION iterations performed for initialization. It
+    // is not used to control the preconditioner.
+    double spse_tolerance = 0.1;
+
     // If possible, how many threads can the solver use.
     int num_threads = 1;
 
diff --git a/internal/ceres/power_series_expansion_preconditioner.cc b/internal/ceres/power_series_expansion_preconditioner.cc
index 1b2dbb9..7a36d92 100644
--- a/internal/ceres/power_series_expansion_preconditioner.cc
+++ b/internal/ceres/power_series_expansion_preconditioner.cc
@@ -34,13 +34,11 @@
 
 PowerSeriesExpansionPreconditioner::PowerSeriesExpansionPreconditioner(
     const ImplicitSchurComplement* isc,
-    const double spse_tolerance,
-    const int min_num_iterations,
-    const int max_num_iterations)
+    const int max_num_spse_iterations,
+    const double spse_tolerance)
     : isc_(isc),
-      spse_tolerance_(spse_tolerance),
-      min_num_iterations_(min_num_iterations),
-      max_num_iterations_(max_num_iterations) {}
+      max_num_spse_iterations_(max_num_spse_iterations),
+      spse_tolerance_(spse_tolerance) {}
 
 PowerSeriesExpansionPreconditioner::~PowerSeriesExpansionPreconditioner() =
     default;
@@ -66,8 +64,7 @@
     isc_->InversePowerSeriesOperatorRightMultiplyAccumulate(
         previous_series_term.data(), series_term.data());
     yref += series_term;
-    if (i >= min_num_iterations_ &&
-        (i >= max_num_iterations_ || series_term.norm() < norm_threshold)) {
+    if (i >= max_num_spse_iterations_ || series_term.norm() < norm_threshold) {
       break;
     }
     std::swap(previous_series_term, series_term);
diff --git a/internal/ceres/power_series_expansion_preconditioner.h b/internal/ceres/power_series_expansion_preconditioner.h
index c5cf32d..a8bb9a6 100644
--- a/internal/ceres/power_series_expansion_preconditioner.h
+++ b/internal/ceres/power_series_expansion_preconditioner.h
@@ -46,9 +46,8 @@
     : public Preconditioner {
  public:
   PowerSeriesExpansionPreconditioner(const ImplicitSchurComplement* isc,
-                                     const double spse_tolerance_,
-                                     const int min_num_iterations_,
-                                     const int max_num_iterations_);
+                                     const int max_num_spse_iterations,
+                                     const double spse_tolerance);
   PowerSeriesExpansionPreconditioner(
       const PowerSeriesExpansionPreconditioner&) = delete;
   void operator=(const PowerSeriesExpansionPreconditioner&) = delete;
@@ -60,9 +59,8 @@
 
  private:
   const ImplicitSchurComplement* isc_;
+  const int max_num_spse_iterations_;
   const double spse_tolerance_;
-  const int min_num_iterations_;
-  const int max_num_iterations_;
 };
 
 }  // namespace ceres::internal
diff --git a/internal/ceres/power_series_expansion_preconditioner_test.cc b/internal/ceres/power_series_expansion_preconditioner_test.cc
index a101831..a291405 100644
--- a/internal/ceres/power_series_expansion_preconditioner_test.cc
+++ b/internal/ceres/power_series_expansion_preconditioner_test.cc
@@ -81,10 +81,9 @@
 TEST_F(PowerSeriesExpansionPreconditionerTest,
        InverseValidPreconditionerToleranceReached) {
   const double spse_tolerance = kEpsilon;
-  const int min_iterations = 1;
-  const int max_iterations = 50;
+  const int max_num_iterations = 50;
   PowerSeriesExpansionPreconditioner preconditioner(
-      isc_.get(), spse_tolerance, min_iterations, max_iterations);
+      isc_.get(), max_num_iterations, spse_tolerance);
 
   Vector x(num_f_cols_), y(num_f_cols_);
   for (int i = 0; i < num_f_cols_; i++) {
@@ -104,10 +103,10 @@
 
 TEST_F(PowerSeriesExpansionPreconditionerTest,
        InverseValidPreconditionerMaxIterations) {
-  const double spse_tolerance = 1 / kEpsilon;
-  const int num_iterations = 50;
+  const double spse_tolerance = 0;
+  const int max_num_iterations = 50;
   PowerSeriesExpansionPreconditioner preconditioner_fixed_n_iterations(
-      isc_.get(), spse_tolerance, num_iterations, num_iterations);
+      isc_.get(), max_num_iterations, spse_tolerance);
 
   Vector x(num_f_cols_), y(num_f_cols_);
   for (int i = 0; i < num_f_cols_; i++) {
@@ -129,10 +128,9 @@
 TEST_F(PowerSeriesExpansionPreconditionerTest,
        InverseInvalidBadPreconditionerTolerance) {
   const double spse_tolerance = 1 / kEpsilon;
-  const int min_iterations = 1;
-  const int max_iterations = 50;
+  const int max_num_iterations = 50;
   PowerSeriesExpansionPreconditioner preconditioner_bad_tolerance(
-      isc_.get(), spse_tolerance, min_iterations, max_iterations);
+      isc_.get(), max_num_iterations, spse_tolerance);
 
   Vector x(num_f_cols_), y(num_f_cols_);
   for (int i = 0; i < num_f_cols_; i++) {
@@ -151,9 +149,9 @@
 TEST_F(PowerSeriesExpansionPreconditionerTest,
        InverseInvalidBadPreconditionerMaxIterations) {
   const double spse_tolerance = kEpsilon;
-  const int num_iterations = 1;
+  const int max_num_iterations = 1;
   PowerSeriesExpansionPreconditioner preconditioner_bad_iterations_limit(
-      isc_.get(), spse_tolerance, num_iterations, num_iterations);
+      isc_.get(), max_num_iterations, spse_tolerance);
 
   Vector x(num_f_cols_), y(num_f_cols_);
   for (int i = 0; i < num_f_cols_; i++) {
diff --git a/internal/ceres/solver.cc b/internal/ceres/solver.cc
index b07d4f0..71b4537 100644
--- a/internal/ceres/solver.cc
+++ b/internal/ceres/solver.cc
@@ -259,12 +259,25 @@
     return false;
   }
 
-  if (options.use_explicit_schur_complement &&
-      options.preconditioner_type != SCHUR_JACOBI) {
-    *error =
-        "use_explicit_schur_complement only supports "
-        "SCHUR_JACOBI as the preconditioner.";
-    return false;
+  if (options.use_explicit_schur_complement) {
+    if (options.preconditioner_type != SCHUR_JACOBI) {
+      *error =
+          "use_explicit_schur_complement only supports "
+          "SCHUR_JACOBI as the preconditioner.";
+      return false;
+    }
+    if (options.use_power_series_expansion_initialization) {
+      *error =
+          "use_explicit_schur_complement does not support "
+          "use_power_series_expansion_initialization.";
+      return false;
+    }
+  }
+
+  if (options.use_power_series_expansion_initialization ||
+      options.preconditioner_type == SCHUR_POWER_SERIES_EXPANSION) {
+    OPTION_GE(max_num_spse_iterations, 1)
+    OPTION_GE(spse_tolerance, 0.0)
   }
 
   if (options.use_mixed_precision_solves) {
diff --git a/internal/ceres/trust_region_preprocessor.cc b/internal/ceres/trust_region_preprocessor.cc
index a5a986f..e9406d1 100644
--- a/internal/ceres/trust_region_preprocessor.cc
+++ b/internal/ceres/trust_region_preprocessor.cc
@@ -199,6 +199,11 @@
       options.max_linear_solver_iterations;
   pp->linear_solver_options.type = options.linear_solver_type;
   pp->linear_solver_options.preconditioner_type = options.preconditioner_type;
+  pp->linear_solver_options.use_power_series_expansion_initialization =
+      options.use_power_series_expansion_initialization;
+  pp->linear_solver_options.spse_tolerance = options.spse_tolerance;
+  pp->linear_solver_options.max_num_spse_iterations =
+      options.max_num_spse_iterations;
   pp->linear_solver_options.visibility_clustering_type =
       options.visibility_clustering_type;
   pp->linear_solver_options.sparse_linear_algebra_library_type =
diff --git a/internal/ceres/types.cc b/internal/ceres/types.cc
index c0e3355..1d514de 100644
--- a/internal/ceres/types.cc
+++ b/internal/ceres/types.cc
@@ -81,6 +81,7 @@
     CASESTR(IDENTITY);
     CASESTR(JACOBI);
     CASESTR(SCHUR_JACOBI);
+    CASESTR(SCHUR_POWER_SERIES_EXPANSION);
     CASESTR(CLUSTER_JACOBI);
     CASESTR(CLUSTER_TRIDIAGONAL);
     CASESTR(SUBSET);
@@ -94,6 +95,7 @@
   STRENUM(IDENTITY);
   STRENUM(JACOBI);
   STRENUM(SCHUR_JACOBI);
+  STRENUM(SCHUR_POWER_SERIES_EXPANSION);
   STRENUM(CLUSTER_JACOBI);
   STRENUM(CLUSTER_TRIDIAGONAL);
   STRENUM(SUBSET);