Integrate the SchurEliminatorForOneFBlock for the case <2,3,6>. Also run clang-format on schur_complement_solver.cc. A more elaborate integration was not done, because it would cause binary bloat. Change-Id: Idf4e9e1794f7401a39f05fb145ec06459209a1e1
diff --git a/internal/ceres/schur_complement_solver.cc b/internal/ceres/schur_complement_solver.cc index a80bd23..27a44a7 100644 --- a/internal/ceres/schur_complement_solver.cc +++ b/internal/ceres/schur_complement_solver.cc
@@ -67,8 +67,7 @@ public: explicit BlockRandomAccessSparseMatrixAdapter( const BlockRandomAccessSparseMatrix& m) - : m_(m) { - } + : m_(m) {} virtual ~BlockRandomAccessSparseMatrixAdapter() {} @@ -93,8 +92,7 @@ public: explicit BlockRandomAccessDiagonalMatrixAdapter( const BlockRandomAccessDiagonalMatrix& m) - : m_(m) { - } + : m_(m) {} virtual ~BlockRandomAccessDiagonalMatrixAdapter() {} @@ -115,7 +113,7 @@ const BlockRandomAccessDiagonalMatrix& m_; }; -} // namespace +} // namespace LinearSolver::Summary SchurComplementSolver::SolveImpl( BlockSparseMatrix* A, @@ -124,19 +122,36 @@ double* x) { EventLogger event_logger("SchurComplementSolver::Solve"); + const CompressedRowBlockStructure* bs = A->block_structure(); if (eliminator_.get() == NULL) { - InitStorage(A->block_structure()); - DetectStructure(*A->block_structure(), - options_.elimination_groups[0], + const int num_eliminate_blocks = options_.elimination_groups[0]; + const int num_f_blocks = bs->cols.size() - num_eliminate_blocks; + + InitStorage(bs); + DetectStructure(*bs, + num_eliminate_blocks, &options_.row_block_size, &options_.e_block_size, &options_.f_block_size); - eliminator_.reset(SchurEliminatorBase::Create(options_)); - CHECK(eliminator_ != nullptr); + + // For the special case of the static structure <2,3,6> with + // exactly one f block use the SchurEliminatorForOneFBlock. + // + // TODO(sameeragarwal): A more scalable template specialization + // mechanism that does not cause binary bloat. 
+ if (options_.row_block_size == 2 && + options_.e_block_size == 3 && + options_.f_block_size == 6 && + num_f_blocks == 1) { + eliminator_.reset(new SchurEliminatorForOneFBlock<2, 3, 6>); + } else { + eliminator_.reset(SchurEliminatorBase::Create(options_)); + } + + CHECK(eliminator_); const bool kFullRankETE = true; - eliminator_->Init( - options_.elimination_groups[0], kFullRankETE, A->block_structure()); - }; + eliminator_->Init(num_eliminate_blocks, kFullRankETE, bs); + } std::fill(x, x + A->num_cols(), 0.0); event_logger.AddEvent("Setup"); @@ -165,9 +180,7 @@ const int num_col_blocks = bs->cols.size(); vector<int> blocks(num_col_blocks - num_eliminate_blocks, 0); - for (int i = num_eliminate_blocks, j = 0; - i < num_col_blocks; - ++i, ++j) { + for (int i = num_eliminate_blocks, j = 0; i < num_col_blocks; ++i, ++j) { blocks[j] = bs->cols[i].size; } @@ -178,10 +191,8 @@ // Solve the system Sx = r, assuming that the matrix S is stored in a // BlockRandomAccessDenseMatrix. The linear system is solved using // Eigen's Cholesky factorization. 
-LinearSolver::Summary -DenseSchurComplementSolver::SolveReducedLinearSystem( - const LinearSolver::PerSolveOptions& per_solve_options, - double* solution) { +LinearSolver::Summary DenseSchurComplementSolver::SolveReducedLinearSystem( + const LinearSolver::PerSolveOptions& per_solve_options, double* solution) { LinearSolver::Summary summary; summary.num_iterations = 0; summary.termination_type = LINEAR_SOLVER_SUCCESS; @@ -202,8 +213,8 @@ if (options().dense_linear_algebra_library_type == EIGEN) { Eigen::LLT<Matrix, Eigen::Upper> llt = ConstMatrixRef(m->values(), num_rows, num_rows) - .selfadjointView<Eigen::Upper>() - .llt(); + .selfadjointView<Eigen::Upper>() + .llt(); if (llt.info() != Eigen::Success) { summary.termination_type = LINEAR_SOLVER_FAILURE; summary.message = @@ -214,11 +225,8 @@ VectorRef(solution, num_rows) = llt.solve(ConstVectorRef(rhs(), num_rows)); } else { VectorRef(solution, num_rows) = ConstVectorRef(rhs(), num_rows); - summary.termination_type = - LAPACK::SolveInPlaceUsingCholesky(num_rows, - m->values(), - solution, - &summary.message); + summary.termination_type = LAPACK::SolveInPlaceUsingCholesky( + num_rows, m->values(), solution, &summary.message); } return summary; @@ -232,8 +240,7 @@ } } -SparseSchurComplementSolver::~SparseSchurComplementSolver() { -} +SparseSchurComplementSolver::~SparseSchurComplementSolver() {} // Determine the non-zero blocks in the Schur Complement matrix, and // initialize a BlockRandomAccessSparseMatrix object. 
@@ -347,8 +354,7 @@ LinearSolver::Summary SparseSchurComplementSolver::SolveReducedLinearSystemUsingConjugateGradients( - const LinearSolver::PerSolveOptions& per_solve_options, - double* solution) { + const LinearSolver::PerSolveOptions& per_solve_options, double* solution) { CHECK(options().use_explicit_schur_complement); const int num_rows = lhs()->num_rows(); // The case where there are no f blocks, and the system is block @@ -368,13 +374,12 @@ preconditioner_.reset(new BlockRandomAccessDiagonalMatrix(blocks_)); } - BlockRandomAccessSparseMatrix* sc = - down_cast<BlockRandomAccessSparseMatrix*>( - const_cast<BlockRandomAccessMatrix*>(lhs())); + BlockRandomAccessSparseMatrix* sc = down_cast<BlockRandomAccessSparseMatrix*>( + const_cast<BlockRandomAccessMatrix*>(lhs())); // Extract block diagonal from the Schur complement to construct the // schur_jacobi preconditioner. - for (int i = 0; i < blocks_.size(); ++i) { + for (int i = 0; i < blocks_.size(); ++i) { const int block_size = blocks_[i]; int sc_r, sc_c, sc_row_stride, sc_col_stride; @@ -401,7 +406,6 @@ std::unique_ptr<LinearOperator> preconditioner_adapter( new BlockRandomAccessDiagonalMatrixAdapter(*preconditioner_)); - LinearSolver::Options cg_options; cg_options.min_num_iterations = options().min_num_iterations; cg_options.max_num_iterations = options().max_num_iterations; @@ -412,10 +416,8 @@ cg_per_solve_options.q_tolerance = per_solve_options.q_tolerance; cg_per_solve_options.preconditioner = preconditioner_adapter.get(); - return cg_solver.Solve(lhs_adapter.get(), - rhs(), - cg_per_solve_options, - solution); + return cg_solver.Solve( + lhs_adapter.get(), rhs(), cg_per_solve_options, solution); } } // namespace internal