Enable usage of schur power series expansion preconditioner.
Add an option to use schur power series expansion for initialization
of pcg solution in ITERATIVE_SCHUR linear solver.
Change-Id: Ifb8bce02bc5f5ceebc74f961eefd3f6dd2ffab4a
diff --git a/include/ceres/solver.h b/include/ceres/solver.h
index 2a1dd18..27198ac 100644
--- a/include/ceres/solver.h
+++ b/include/ceres/solver.h
@@ -691,6 +691,23 @@
// as its termination type.
int max_linear_solver_iterations = 500;
+ // Maximum number of iterations performed by SCHUR_POWER_SERIES_EXPANSION.
+ // This value controls the maximum number of iterations whether it is used
+ // as a preconditioner or just to initialize the solution for
+ // ITERATIVE_SCHUR.
+ int max_num_spse_iterations = 5;
+
+ // Use SCHUR_POWER_SERIES_EXPANSION to initialize the solution for
+ // ITERATIVE_SCHUR. This option can be set true regardless of what
+ // preconditioner is being used.
+ bool use_power_series_expansion_initialization = false;
+
+ // When use_power_series_expansion_initialization is true, this parameter
+ // along with max_num_spse_iterations controls the number of
+ // SCHUR_POWER_SERIES_EXPANSION iterations performed for initialization. It
+ // is not used to control the preconditioner.
+ double spse_tolerance = 0.1;
+
// Forcing sequence parameter. The truncated Newton solver uses
// this number to control the relative accuracy with which the
// Newton step is computed.
diff --git a/include/ceres/types.h b/include/ceres/types.h
index 52e1e7b..1b07ab3 100644
--- a/include/ceres/types.h
+++ b/include/ceres/types.h
@@ -107,9 +107,6 @@
// Use power series expansion to approximate the inversion of Schur complement
// as a preconditioner.
- // WARNING! Application of this preconditioner currently is not integrated
- // into linear solvers, so failure to use it via public API is expected
- // behaviour.
SCHUR_POWER_SERIES_EXPANSION,
// Visibility clustering based preconditioners.
diff --git a/internal/ceres/implicit_schur_complement.cc b/internal/ceres/implicit_schur_complement.cc
index 92631de..8f7a2a0 100644
--- a/internal/ceres/implicit_schur_complement.cc
+++ b/internal/ceres/implicit_schur_complement.cc
@@ -57,6 +57,7 @@
b_ = b;
compute_ftf_inverse_ =
+ options_.use_power_series_expansion_initialization ||
options_.preconditioner_type == JACOBI ||
options_.preconditioner_type == SCHUR_POWER_SERIES_EXPANSION;
diff --git a/internal/ceres/iterative_schur_complement_solver.cc b/internal/ceres/iterative_schur_complement_solver.cc
index 5f56be2..8be33bc 100644
--- a/internal/ceres/iterative_schur_complement_solver.cc
+++ b/internal/ceres/iterative_schur_complement_solver.cc
@@ -43,6 +43,7 @@
#include "ceres/implicit_schur_complement.h"
#include "ceres/internal/eigen.h"
#include "ceres/linear_solver.h"
+#include "ceres/power_series_expansion_preconditioner.h"
#include "ceres/preconditioner.h"
#include "ceres/schur_jacobi_preconditioner.h"
#include "ceres/triplet_sparse_matrix.h"
@@ -90,9 +91,18 @@
return summary;
}
- // Initialize the solution to the Schur complement system to zero.
+ // Initialize the solution to the Schur complement system.
reduced_linear_system_solution_.resize(schur_complement_->num_rows());
reduced_linear_system_solution_.setZero();
+ if (options_.use_power_series_expansion_initialization) {
+ PowerSeriesExpansionPreconditioner pse_solver(
+ schur_complement_.get(),
+ options_.max_num_spse_iterations,
+ options_.spse_tolerance);
+ pse_solver.RightMultiplyAndAccumulate(
+ schur_complement_->rhs().data(),
+ reduced_linear_system_solution_.data());
+ }
CreatePreconditioner(A);
if (preconditioner_ != nullptr) {
@@ -168,6 +178,12 @@
preconditioner_ = std::make_unique<SparseMatrixPreconditionerWrapper>(
schur_complement_->block_diagonal_FtF_inverse());
break;
+ case SCHUR_POWER_SERIES_EXPANSION:
+ // Ignoring the value of spse_tolerance to ensure preconditioner stays
+ // fixed during the iterations of cg.
+ preconditioner_ = std::make_unique<PowerSeriesExpansionPreconditioner>(
+ schur_complement_.get(), options_.max_num_spse_iterations, 0);
+ break;
case SCHUR_JACOBI:
preconditioner_ = std::make_unique<SchurJacobiPreconditioner>(
*A->block_structure(), preconditioner_options);
diff --git a/internal/ceres/iterative_schur_complement_solver_test.cc b/internal/ceres/iterative_schur_complement_solver_test.cc
index 80e388a..a98b295 100644
--- a/internal/ceres/iterative_schur_complement_solver_test.cc
+++ b/internal/ceres/iterative_schur_complement_solver_test.cc
@@ -74,7 +74,9 @@
num_eliminate_blocks_ = problem->num_eliminate_blocks;
}
- AssertionResult TestSolver(double* D) {
+ AssertionResult TestSolver(double* D,
+ PreconditionerType preconditioner_type,
+ bool use_power_series_expansion_initialization) {
TripletSparseMatrix triplet_A(
A_->num_rows(), A_->num_cols(), A_->num_nonzeros());
A_->ToTripletSparseMatrix(&triplet_A);
@@ -95,7 +97,10 @@
options.elimination_groups.push_back(num_eliminate_blocks_);
options.elimination_groups.push_back(0);
options.max_num_iterations = num_cols_;
- options.preconditioner_type = SCHUR_JACOBI;
+ options.max_num_spse_iterations = 1;
+ options.use_power_series_expansion_initialization =
+ use_power_series_expansion_initialization;
+ options.preconditioner_type = preconditioner_type;
IterativeSchurComplementSolver isc(options);
Vector isc_sol(num_cols_);
@@ -119,16 +124,30 @@
std::unique_ptr<double[]> D_;
};
-TEST_F(IterativeSchurComplementSolverTest, NormalProblem) {
+TEST_F(IterativeSchurComplementSolverTest, NormalProblemSchurJacobi) {
SetUpProblem(2);
- EXPECT_TRUE(TestSolver(nullptr));
- EXPECT_TRUE(TestSolver(D_.get()));
+ EXPECT_TRUE(TestSolver(nullptr, SCHUR_JACOBI, false));
+ EXPECT_TRUE(TestSolver(D_.get(), SCHUR_JACOBI, false));
+}
+
+TEST_F(IterativeSchurComplementSolverTest,
+ NormalProblemSchurJacobiWithPowerSeriesExpansionInitialization) {
+ SetUpProblem(2);
+ EXPECT_TRUE(TestSolver(nullptr, SCHUR_JACOBI, true));
+ EXPECT_TRUE(TestSolver(D_.get(), SCHUR_JACOBI, true));
+}
+
+TEST_F(IterativeSchurComplementSolverTest,
+ NormalProblemPowerSeriesExpansionPreconditioner) {
+ SetUpProblem(5);
+ EXPECT_TRUE(TestSolver(nullptr, SCHUR_POWER_SERIES_EXPANSION, false));
+ EXPECT_TRUE(TestSolver(D_.get(), SCHUR_POWER_SERIES_EXPANSION, false));
}
TEST_F(IterativeSchurComplementSolverTest, ProblemWithNoFBlocks) {
SetUpProblem(3);
- EXPECT_TRUE(TestSolver(nullptr));
- EXPECT_TRUE(TestSolver(D_.get()));
+ EXPECT_TRUE(TestSolver(nullptr, SCHUR_JACOBI, false));
+ EXPECT_TRUE(TestSolver(D_.get(), SCHUR_JACOBI, false));
}
} // namespace internal
diff --git a/internal/ceres/linear_solver.h b/internal/ceres/linear_solver.h
index 4916e41..69e984f 100644
--- a/internal/ceres/linear_solver.h
+++ b/internal/ceres/linear_solver.h
@@ -165,6 +165,23 @@
int min_num_iterations = 1;
int max_num_iterations = 1;
+ // Maximum number of iterations performed by SCHUR_POWER_SERIES_EXPANSION.
+ // This value controls the maximum number of iterations whether it is used
+ // as a preconditioner or just to initialize the solution for
+ // ITERATIVE_SCHUR.
+ int max_num_spse_iterations = 5;
+
+ // Use SCHUR_POWER_SERIES_EXPANSION to initialize the solution for
+ // ITERATIVE_SCHUR. This option can be set true regardless of what
+ // preconditioner is being used.
+ bool use_power_series_expansion_initialization = false;
+
+ // When use_power_series_expansion_initialization is true, this parameter
+ // along with max_num_spse_iterations controls the number of
+ // SCHUR_POWER_SERIES_EXPANSION iterations performed for initialization. It
+ // is not used to control the preconditioner.
+ double spse_tolerance = 0.1;
+
// If possible, how many threads can the solver use.
int num_threads = 1;
diff --git a/internal/ceres/power_series_expansion_preconditioner.cc b/internal/ceres/power_series_expansion_preconditioner.cc
index 1b2dbb9..7a36d92 100644
--- a/internal/ceres/power_series_expansion_preconditioner.cc
+++ b/internal/ceres/power_series_expansion_preconditioner.cc
@@ -34,13 +34,11 @@
PowerSeriesExpansionPreconditioner::PowerSeriesExpansionPreconditioner(
const ImplicitSchurComplement* isc,
- const double spse_tolerance,
- const int min_num_iterations,
- const int max_num_iterations)
+ const int max_num_spse_iterations,
+ const double spse_tolerance)
: isc_(isc),
- spse_tolerance_(spse_tolerance),
- min_num_iterations_(min_num_iterations),
- max_num_iterations_(max_num_iterations) {}
+ max_num_spse_iterations_(max_num_spse_iterations),
+ spse_tolerance_(spse_tolerance) {}
PowerSeriesExpansionPreconditioner::~PowerSeriesExpansionPreconditioner() =
default;
@@ -66,8 +64,7 @@
isc_->InversePowerSeriesOperatorRightMultiplyAccumulate(
previous_series_term.data(), series_term.data());
yref += series_term;
- if (i >= min_num_iterations_ &&
- (i >= max_num_iterations_ || series_term.norm() < norm_threshold)) {
+ if (i >= max_num_spse_iterations_ || series_term.norm() < norm_threshold) {
break;
}
std::swap(previous_series_term, series_term);
diff --git a/internal/ceres/power_series_expansion_preconditioner.h b/internal/ceres/power_series_expansion_preconditioner.h
index c5cf32d..a8bb9a6 100644
--- a/internal/ceres/power_series_expansion_preconditioner.h
+++ b/internal/ceres/power_series_expansion_preconditioner.h
@@ -46,9 +46,8 @@
: public Preconditioner {
public:
PowerSeriesExpansionPreconditioner(const ImplicitSchurComplement* isc,
- const double spse_tolerance_,
- const int min_num_iterations_,
- const int max_num_iterations_);
+ const int max_num_spse_iterations,
+ const double spse_tolerance);
PowerSeriesExpansionPreconditioner(
const PowerSeriesExpansionPreconditioner&) = delete;
void operator=(const PowerSeriesExpansionPreconditioner&) = delete;
@@ -60,9 +59,8 @@
private:
const ImplicitSchurComplement* isc_;
+ const int max_num_spse_iterations_;
const double spse_tolerance_;
- const int min_num_iterations_;
- const int max_num_iterations_;
};
} // namespace ceres::internal
diff --git a/internal/ceres/power_series_expansion_preconditioner_test.cc b/internal/ceres/power_series_expansion_preconditioner_test.cc
index a101831..a291405 100644
--- a/internal/ceres/power_series_expansion_preconditioner_test.cc
+++ b/internal/ceres/power_series_expansion_preconditioner_test.cc
@@ -81,10 +81,9 @@
TEST_F(PowerSeriesExpansionPreconditionerTest,
InverseValidPreconditionerToleranceReached) {
const double spse_tolerance = kEpsilon;
- const int min_iterations = 1;
- const int max_iterations = 50;
+ const int max_num_iterations = 50;
PowerSeriesExpansionPreconditioner preconditioner(
- isc_.get(), spse_tolerance, min_iterations, max_iterations);
+ isc_.get(), max_num_iterations, spse_tolerance);
Vector x(num_f_cols_), y(num_f_cols_);
for (int i = 0; i < num_f_cols_; i++) {
@@ -104,10 +103,10 @@
TEST_F(PowerSeriesExpansionPreconditionerTest,
InverseValidPreconditionerMaxIterations) {
- const double spse_tolerance = 1 / kEpsilon;
- const int num_iterations = 50;
+ const double spse_tolerance = 0;
+ const int max_num_iterations = 50;
PowerSeriesExpansionPreconditioner preconditioner_fixed_n_iterations(
- isc_.get(), spse_tolerance, num_iterations, num_iterations);
+ isc_.get(), max_num_iterations, spse_tolerance);
Vector x(num_f_cols_), y(num_f_cols_);
for (int i = 0; i < num_f_cols_; i++) {
@@ -129,10 +128,9 @@
TEST_F(PowerSeriesExpansionPreconditionerTest,
InverseInvalidBadPreconditionerTolerance) {
const double spse_tolerance = 1 / kEpsilon;
- const int min_iterations = 1;
- const int max_iterations = 50;
+ const int max_num_iterations = 50;
PowerSeriesExpansionPreconditioner preconditioner_bad_tolerance(
- isc_.get(), spse_tolerance, min_iterations, max_iterations);
+ isc_.get(), max_num_iterations, spse_tolerance);
Vector x(num_f_cols_), y(num_f_cols_);
for (int i = 0; i < num_f_cols_; i++) {
@@ -151,9 +149,9 @@
TEST_F(PowerSeriesExpansionPreconditionerTest,
InverseInvalidBadPreconditionerMaxIterations) {
const double spse_tolerance = kEpsilon;
- const int num_iterations = 1;
+ const int max_num_iterations = 1;
PowerSeriesExpansionPreconditioner preconditioner_bad_iterations_limit(
- isc_.get(), spse_tolerance, num_iterations, num_iterations);
+ isc_.get(), max_num_iterations, spse_tolerance);
Vector x(num_f_cols_), y(num_f_cols_);
for (int i = 0; i < num_f_cols_; i++) {
diff --git a/internal/ceres/solver.cc b/internal/ceres/solver.cc
index b07d4f0..71b4537 100644
--- a/internal/ceres/solver.cc
+++ b/internal/ceres/solver.cc
@@ -259,12 +259,25 @@
return false;
}
- if (options.use_explicit_schur_complement &&
- options.preconditioner_type != SCHUR_JACOBI) {
- *error =
- "use_explicit_schur_complement only supports "
- "SCHUR_JACOBI as the preconditioner.";
- return false;
+ if (options.use_explicit_schur_complement) {
+ if (options.preconditioner_type != SCHUR_JACOBI) {
+ *error =
+ "use_explicit_schur_complement only supports "
+ "SCHUR_JACOBI as the preconditioner.";
+ return false;
+ }
+ if (options.use_power_series_expansion_initialization) {
+ *error =
+ "use_explicit_schur_complement does not support "
+ "use_power_series_expansion_initialization.";
+ return false;
+ }
+ }
+
+ if (options.use_power_series_expansion_initialization ||
+ options.preconditioner_type == SCHUR_POWER_SERIES_EXPANSION) {
+ OPTION_GE(max_num_spse_iterations, 1)
+ OPTION_GE(spse_tolerance, 0.0)
}
if (options.use_mixed_precision_solves) {
diff --git a/internal/ceres/trust_region_preprocessor.cc b/internal/ceres/trust_region_preprocessor.cc
index a5a986f..e9406d1 100644
--- a/internal/ceres/trust_region_preprocessor.cc
+++ b/internal/ceres/trust_region_preprocessor.cc
@@ -199,6 +199,11 @@
options.max_linear_solver_iterations;
pp->linear_solver_options.type = options.linear_solver_type;
pp->linear_solver_options.preconditioner_type = options.preconditioner_type;
+ pp->linear_solver_options.use_power_series_expansion_initialization =
+ options.use_power_series_expansion_initialization;
+ pp->linear_solver_options.spse_tolerance = options.spse_tolerance;
+ pp->linear_solver_options.max_num_spse_iterations =
+ options.max_num_spse_iterations;
pp->linear_solver_options.visibility_clustering_type =
options.visibility_clustering_type;
pp->linear_solver_options.sparse_linear_algebra_library_type =
diff --git a/internal/ceres/types.cc b/internal/ceres/types.cc
index c0e3355..1d514de 100644
--- a/internal/ceres/types.cc
+++ b/internal/ceres/types.cc
@@ -81,6 +81,7 @@
CASESTR(IDENTITY);
CASESTR(JACOBI);
CASESTR(SCHUR_JACOBI);
+ CASESTR(SCHUR_POWER_SERIES_EXPANSION);
CASESTR(CLUSTER_JACOBI);
CASESTR(CLUSTER_TRIDIAGONAL);
CASESTR(SUBSET);
@@ -94,6 +95,7 @@
STRENUM(IDENTITY);
STRENUM(JACOBI);
STRENUM(SCHUR_JACOBI);
+ STRENUM(SCHUR_POWER_SERIES_EXPANSION);
STRENUM(CLUSTER_JACOBI);
STRENUM(CLUSTER_TRIDIAGONAL);
STRENUM(SUBSET);