Use if constexpr and map SparseMatrix as const Change-Id: Id72d77ef91d2bfc1055ab67b604f26ebc0d65769
diff --git a/internal/ceres/accelerate_sparse.cc b/internal/ceres/accelerate_sparse.cc index 68a307b..b01414d 100644 --- a/internal/ceres/accelerate_sparse.cc +++ b/internal/ceres/accelerate_sparse.cc
@@ -1,5 +1,5 @@ // Ceres Solver - A fast non-linear least squares minimizer -// Copyright 2018 Google Inc. All rights reserved. +// Copyright 2022 Google Inc. All rights reserved. // http://ceres-solver.org/ // // Redistribution and use in source and binary forms, with or without @@ -127,8 +127,8 @@ At.structure.attributes._reserved = 0; At.structure.attributes._allocatedBySparse = 0; At.structure.blockSize = 1; - if (std::is_same<Scalar, double>::value) { - At.data = reinterpret_cast<Scalar*>(A->mutable_values()); + if constexpr (std::is_same_v<Scalar, double>) { + At.data = A->mutable_values(); } else { values_ = ConstVectorRef(A->values(), A->num_nonzeros()).template cast<Scalar>(); @@ -262,8 +262,8 @@ typename SparseTypesTrait<Scalar>::DenseVector as_rhs_and_solution; as_rhs_and_solution.count = num_cols; - if (std::is_same<Scalar, double>::value) { - as_rhs_and_solution.data = reinterpret_cast<Scalar*>(solution); + if constexpr (std::is_same_v<Scalar, double>) { + as_rhs_and_solution.data = solution; std::copy_n(rhs, num_cols, solution); } else { scalar_rhs_and_solution_ =
diff --git a/internal/ceres/eigensparse.cc b/internal/ceres/eigensparse.cc index ce01658..5f2c8ad 100644 --- a/internal/ceres/eigensparse.cc +++ b/internal/ceres/eigensparse.cc
@@ -1,5 +1,5 @@ // Ceres Solver - A fast non-linear least squares minimizer -// Copyright 2017 Google Inc. All rights reserved. +// Copyright 2022 Google Inc. All rights reserved. // http://ceres-solver.org/ // // Redistribution and use in source and binary forms, with or without @@ -49,8 +49,6 @@ namespace ceres::internal { -// TODO(sameeragarwal): Use enable_if to clean up the implementations -// for when Scalar == double. template <typename Solver> class EigenSparseCholeskyTemplate final : public SparseCholesky { public: @@ -92,17 +90,17 @@ std::string* message) override { CHECK(analyzed_) << "Solve called without a call to Factorize first."; - scalar_rhs_ = ConstVectorRef(rhs_ptr, solver_.cols()) - .template cast<typename Solver::Scalar>(); - - // The two casts are needed if the Scalar in this class is not - // double. For code simplicity we are going to assume that Eigen - // is smart enough to figure out that casting a double Vector to a - // double Vector is a straight copy. If this turns into a - // performance bottleneck (unlikely), we can revisit this. - scalar_solution_ = solver_.solve(scalar_rhs_); - VectorRef(solution_ptr, solver_.cols()) = - scalar_solution_.template cast<double>(); + // Avoid copying when the scalar type is double + if constexpr (std::is_same_v<typename Solver::Scalar, double>) { + ConstVectorRef scalar_rhs(rhs_ptr, solver_.cols()); + VectorRef(solution_ptr, solver_.cols()) = solver_.solve(scalar_rhs); + } else { + auto scalar_rhs = ConstVectorRef(rhs_ptr, solver_.cols()) + .template cast<typename Solver::Scalar>(); + auto scalar_solution = solver_.solve(scalar_rhs); + VectorRef(solution_ptr, solver_.cols()) = + scalar_solution.template cast<double>(); + } if (solver_.info() != Eigen::Success) { *message = "Eigen failure. Unable to do triangular solve."; @@ -116,9 +114,8 @@ CHECK_EQ(lhs->storage_type(), StorageType()); typename Solver::Scalar* values_ptr = nullptr; - if (std::is_same<typename Solver::Scalar, double>::value) { - values_ptr = - reinterpret_cast<typename Solver::Scalar*>(lhs->mutable_values()); + if constexpr (std::is_same_v<typename Solver::Scalar, double>) { + values_ptr = lhs->mutable_values(); } else { // In the case where the scalar used in this class is not // double. In that case, make a copy of the values array in the @@ -128,19 +125,20 @@ values_ptr = values_.data(); } - Eigen::Map<Eigen::SparseMatrix<typename Solver::Scalar, Eigen::ColMajor>> + Eigen::Map< + const Eigen::SparseMatrix<typename Solver::Scalar, Eigen::ColMajor>> eigen_lhs(lhs->num_rows(), lhs->num_rows(), lhs->num_nonzeros(), - lhs->mutable_rows(), - lhs->mutable_cols(), + lhs->rows(), + lhs->cols(), values_ptr); return Factorize(eigen_lhs, message); } private: - Eigen::Matrix<typename Solver::Scalar, Eigen::Dynamic, 1> values_, - scalar_rhs_, scalar_solution_; + Eigen::Matrix<typename Solver::Scalar, Eigen::Dynamic, 1> values_; + bool analyzed_{false}; Solver solver_; };