Add a templated TypedPreconditioner class.
This sets the stage for preconditioners that can use
different kinds of matrix layouts, just like the LinearSolver
class hierarchy.
Change-Id: I3579cf344bcd2eeeecb1ae621cab02a3c9a0f920
diff --git a/internal/ceres/block_jacobi_preconditioner.cc b/internal/ceres/block_jacobi_preconditioner.cc
index 5525d4c..749e0b6 100644
--- a/internal/ceres/block_jacobi_preconditioner.cc
+++ b/internal/ceres/block_jacobi_preconditioner.cc
@@ -66,8 +66,8 @@
BlockJacobiPreconditioner::~BlockJacobiPreconditioner() {}
-bool BlockJacobiPreconditioner::Update(const BlockSparseMatrix& A,
- const double* D) {
+bool BlockJacobiPreconditioner::UpdateImpl(const BlockSparseMatrix& A,
+ const double* D) {
const CompressedRowBlockStructure* bs = A.block_structure();
// Compute the diagonal blocks by block inner products.
diff --git a/internal/ceres/block_jacobi_preconditioner.h b/internal/ceres/block_jacobi_preconditioner.h
index dc291bf..3505a01 100644
--- a/internal/ceres/block_jacobi_preconditioner.h
+++ b/internal/ceres/block_jacobi_preconditioner.h
@@ -51,20 +51,21 @@
// update the matrix by running Update(A, D). The values of the matrix A are
// inspected to construct the preconditioner. The vector D is applied as the
// D^TD diagonal term.
-class BlockJacobiPreconditioner : public Preconditioner {
+class BlockJacobiPreconditioner : public BlockSparseMatrixPreconditioner {
public:
// A must remain valid while the BlockJacobiPreconditioner is.
explicit BlockJacobiPreconditioner(const BlockSparseMatrix& A);
virtual ~BlockJacobiPreconditioner();
// Preconditioner interface
- virtual bool Update(const BlockSparseMatrix& A, const double* D);
virtual void RightMultiply(const double* x, double* y) const;
virtual void LeftMultiply(const double* x, double* y) const;
virtual int num_rows() const { return num_rows_; }
virtual int num_cols() const { return num_rows_; }
private:
+ virtual bool UpdateImpl(const BlockSparseMatrix& A, const double* D);
+
std::vector<double*> blocks_;
std::vector<double> block_storage_;
int num_rows_;
diff --git a/internal/ceres/preconditioner.cc b/internal/ceres/preconditioner.cc
index 19e58fc..505a47d 100644
--- a/internal/ceres/preconditioner.cc
+++ b/internal/ceres/preconditioner.cc
@@ -45,8 +45,8 @@
SparseMatrixPreconditionerWrapper::~SparseMatrixPreconditionerWrapper() {
}
-bool SparseMatrixPreconditionerWrapper::Update(const BlockSparseMatrix& A,
- const double* D) {
+bool SparseMatrixPreconditionerWrapper::UpdateImpl(const SparseMatrix& A,
+ const double* D) {  // Nothing to recompute: A and D are unused; always succeeds.
return true;
}
diff --git a/internal/ceres/preconditioner.h b/internal/ceres/preconditioner.h
index 7206536..cb0a381 100644
--- a/internal/ceres/preconditioner.h
+++ b/internal/ceres/preconditioner.h
@@ -32,6 +32,8 @@
#define CERES_INTERNAL_PRECONDITIONER_H_
#include <vector>
+#include "ceres/casts.h"
+#include "ceres/compressed_row_sparse_matrix.h"
#include "ceres/linear_operator.h"
#include "ceres/sparse_matrix.h"
@@ -105,7 +107,7 @@
//
// D can be NULL, in which case its interpreted as a diagonal matrix
// of size zero.
- virtual bool Update(const BlockSparseMatrix& A, const double* D) = 0;
+ virtual bool Update(const LinearOperator& A, const double* D) = 0;
// LinearOperator interface. Since the operator is symmetric,
// LeftMultiply and num_cols are just calls to RightMultiply and
@@ -122,19 +124,40 @@
}
};
+// This templated subclass of Preconditioner serves as a base class for
+// other preconditioners that depend on the particular matrix layout of
+// the underlying linear operator.
+template <typename MatrixType>  // Expected to be a LinearOperator subclass (e.g. SparseMatrix), since Update downcasts to it.
+class TypedPreconditioner : public Preconditioner {
+ public:
+ virtual ~TypedPreconditioner() {}
+ virtual bool Update(const LinearOperator& A, const double* D) {  // Implements Preconditioner::Update by downcasting A and forwarding to UpdateImpl.
+ return UpdateImpl(*down_cast<const MatrixType*>(&A), D);  // Caller must pass an A whose dynamic type is MatrixType; see down_cast in ceres/casts.h for what is checked.
+ }
+
+ private:
+ virtual bool UpdateImpl(const MatrixType& A, const double* D) = 0;  // Subclasses implement the layout-specific update; A arrives already typed as MatrixType.
+};
+
+// Preconditioners that depend on access to the low-level structure
+// of a particular SparseMatrix type.
+typedef TypedPreconditioner<SparseMatrix> SparseMatrixPreconditioner; // NOLINT
+typedef TypedPreconditioner<BlockSparseMatrix> BlockSparseMatrixPreconditioner; // NOLINT
+typedef TypedPreconditioner<CompressedRowSparseMatrix> CompressedRowSparseMatrixPreconditioner; // NOLINT
+
// Wrap a SparseMatrix object as a preconditioner.
-class SparseMatrixPreconditionerWrapper : public Preconditioner {
+class SparseMatrixPreconditionerWrapper : public SparseMatrixPreconditioner {
public:
// Wrapper does NOT take ownership of the matrix pointer.
explicit SparseMatrixPreconditionerWrapper(const SparseMatrix* matrix);
virtual ~SparseMatrixPreconditionerWrapper();
// Preconditioner interface
- virtual bool Update(const BlockSparseMatrix& A, const double* D);
virtual void RightMultiply(const double* x, double* y) const;
virtual int num_rows() const;
private:
+ virtual bool UpdateImpl(const SparseMatrix& A, const double* D);  // Ignores A and D (the wrapped matrix is used as-is) and returns true.
const SparseMatrix* matrix_;
};
diff --git a/internal/ceres/schur_jacobi_preconditioner.cc b/internal/ceres/schur_jacobi_preconditioner.cc
index 780795b..aa840c5 100644
--- a/internal/ceres/schur_jacobi_preconditioner.cc
+++ b/internal/ceres/schur_jacobi_preconditioner.cc
@@ -91,8 +91,8 @@
}
// Update the values of the preconditioner matrix and factorize it.
-bool SchurJacobiPreconditioner::Update(const BlockSparseMatrix& A,
- const double* D) {
+bool SchurJacobiPreconditioner::UpdateImpl(const BlockSparseMatrix& A,
+ const double* D) {
const int num_rows = m_->num_rows();
CHECK_GT(num_rows, 0);
diff --git a/internal/ceres/schur_jacobi_preconditioner.h b/internal/ceres/schur_jacobi_preconditioner.h
index b80a249..f6e7b0d 100644
--- a/internal/ceres/schur_jacobi_preconditioner.h
+++ b/internal/ceres/schur_jacobi_preconditioner.h
@@ -73,7 +73,7 @@
// preconditioner.Update(A, NULL);
// preconditioner.RightMultiply(x, y);
//
-class SchurJacobiPreconditioner : public Preconditioner {
+class SchurJacobiPreconditioner : public BlockSparseMatrixPreconditioner {
public:
// Initialize the symbolic structure of the preconditioner. bs is
// the block structure of the linear system to be solved. It is used
@@ -86,12 +86,12 @@
virtual ~SchurJacobiPreconditioner();
// Preconditioner interface.
- virtual bool Update(const BlockSparseMatrix& A, const double* D);
virtual void RightMultiply(const double* x, double* y) const;
virtual int num_rows() const;
private:
void InitEliminator(const CompressedRowBlockStructure& bs);
+ virtual bool UpdateImpl(const BlockSparseMatrix& A, const double* D);
Preconditioner::Options options_;
diff --git a/internal/ceres/visibility_based_preconditioner.cc b/internal/ceres/visibility_based_preconditioner.cc
index 94266e5..7af1339 100644
--- a/internal/ceres/visibility_based_preconditioner.cc
+++ b/internal/ceres/visibility_based_preconditioner.cc
@@ -324,8 +324,8 @@
}
// Update the values of the preconditioner matrix and factorize it.
-bool VisibilityBasedPreconditioner::Update(const BlockSparseMatrix& A,
- const double* D) {
+bool VisibilityBasedPreconditioner::UpdateImpl(const BlockSparseMatrix& A,
+ const double* D) {
const time_t start_time = time(NULL);
const int num_rows = m_->num_rows();
CHECK_GT(num_rows, 0);
diff --git a/internal/ceres/visibility_based_preconditioner.h b/internal/ceres/visibility_based_preconditioner.h
index 54a03e6..c58b1a7 100644
--- a/internal/ceres/visibility_based_preconditioner.h
+++ b/internal/ceres/visibility_based_preconditioner.h
@@ -123,7 +123,7 @@
// preconditioner.RightMultiply(x, y);
//
#ifndef CERES_NO_SUITESPARSE
-class VisibilityBasedPreconditioner : public Preconditioner {
+class VisibilityBasedPreconditioner : public BlockSparseMatrixPreconditioner {
public:
// Initialize the symbolic structure of the preconditioner. bs is
// the block structure of the linear system to be solved. It is used
@@ -136,12 +136,13 @@
virtual ~VisibilityBasedPreconditioner();
// Preconditioner interface
- virtual bool Update(const BlockSparseMatrix& A, const double* D);
virtual void RightMultiply(const double* x, double* y) const;
virtual int num_rows() const;
friend class VisibilityBasedPreconditionerTest;
+
private:
+ virtual bool UpdateImpl(const BlockSparseMatrix& A, const double* D);
void ComputeClusterJacobiSparsity(const CompressedRowBlockStructure& bs);
void ComputeClusterTridiagonalSparsity(const CompressedRowBlockStructure& bs);
void InitStorage(const CompressedRowBlockStructure& bs);
@@ -203,7 +204,7 @@
#else // SuiteSparse
// If SuiteSparse is not compiled in, the preconditioner is not
// available.
-class VisibilityBasedPreconditioner : public Preconditioner {
+class VisibilityBasedPreconditioner : public BlockSparseMatrixPreconditioner {
public:
VisibilityBasedPreconditioner(const CompressedRowBlockStructure& bs,
const Preconditioner::Options& options) {
@@ -215,7 +216,9 @@
virtual void LeftMultiply(const double* x, double* y) const {}
virtual int num_rows() const { return -1; }
virtual int num_cols() const { return -1; }
- bool Update(const BlockSparseMatrix& A, const double* D) {
+
+ private:
+ virtual bool UpdateImpl(const BlockSparseMatrix& A, const double* D) {  // SuiteSparse not compiled in: updating always fails.
return false;
}
};