Use if constexpr and map SparseMatrix as const
Change-Id: Id72d77ef91d2bfc1055ab67b604f26ebc0d65769
diff --git a/internal/ceres/accelerate_sparse.cc b/internal/ceres/accelerate_sparse.cc
index 68a307b..b01414d 100644
--- a/internal/ceres/accelerate_sparse.cc
+++ b/internal/ceres/accelerate_sparse.cc
@@ -1,5 +1,5 @@
// Ceres Solver - A fast non-linear least squares minimizer
-// Copyright 2018 Google Inc. All rights reserved.
+// Copyright 2022 Google Inc. All rights reserved.
// http://ceres-solver.org/
//
// Redistribution and use in source and binary forms, with or without
@@ -127,8 +127,8 @@
At.structure.attributes._reserved = 0;
At.structure.attributes._allocatedBySparse = 0;
At.structure.blockSize = 1;
- if (std::is_same<Scalar, double>::value) {
- At.data = reinterpret_cast<Scalar*>(A->mutable_values());
+ if constexpr (std::is_same_v<Scalar, double>) {
+ At.data = A->mutable_values();
} else {
values_ =
ConstVectorRef(A->values(), A->num_nonzeros()).template cast<Scalar>();
@@ -262,8 +262,8 @@
typename SparseTypesTrait<Scalar>::DenseVector as_rhs_and_solution;
as_rhs_and_solution.count = num_cols;
- if (std::is_same<Scalar, double>::value) {
- as_rhs_and_solution.data = reinterpret_cast<Scalar*>(solution);
+ if constexpr (std::is_same_v<Scalar, double>) {
+ as_rhs_and_solution.data = solution;
std::copy_n(rhs, num_cols, solution);
} else {
scalar_rhs_and_solution_ =
diff --git a/internal/ceres/eigensparse.cc b/internal/ceres/eigensparse.cc
index ce01658..5f2c8ad 100644
--- a/internal/ceres/eigensparse.cc
+++ b/internal/ceres/eigensparse.cc
@@ -1,5 +1,5 @@
// Ceres Solver - A fast non-linear least squares minimizer
-// Copyright 2017 Google Inc. All rights reserved.
+// Copyright 2022 Google Inc. All rights reserved.
// http://ceres-solver.org/
//
// Redistribution and use in source and binary forms, with or without
@@ -49,8 +49,6 @@
namespace ceres::internal {
-// TODO(sameeragarwal): Use enable_if to clean up the implementations
-// for when Scalar == double.
template <typename Solver>
class EigenSparseCholeskyTemplate final : public SparseCholesky {
public:
@@ -92,17 +90,17 @@
std::string* message) override {
CHECK(analyzed_) << "Solve called without a call to Factorize first.";
- scalar_rhs_ = ConstVectorRef(rhs_ptr, solver_.cols())
- .template cast<typename Solver::Scalar>();
-
- // The two casts are needed if the Scalar in this class is not
- // double. For code simplicity we are going to assume that Eigen
- // is smart enough to figure out that casting a double Vector to a
- // double Vector is a straight copy. If this turns into a
- // performance bottleneck (unlikely), we can revisit this.
- scalar_solution_ = solver_.solve(scalar_rhs_);
- VectorRef(solution_ptr, solver_.cols()) =
- scalar_solution_.template cast<double>();
+ // Avoid copying when the scalar type is double
+ if constexpr (std::is_same_v<typename Solver::Scalar, double>) {
+ ConstVectorRef scalar_rhs(rhs_ptr, solver_.cols());
+ VectorRef(solution_ptr, solver_.cols()) = solver_.solve(scalar_rhs);
+ } else {
+ auto scalar_rhs = ConstVectorRef(rhs_ptr, solver_.cols())
+ .template cast<typename Solver::Scalar>();
+ auto scalar_solution = solver_.solve(scalar_rhs);
+ VectorRef(solution_ptr, solver_.cols()) =
+ scalar_solution.template cast<double>();
+ }
if (solver_.info() != Eigen::Success) {
*message = "Eigen failure. Unable to do triangular solve.";
@@ -116,9 +114,8 @@
CHECK_EQ(lhs->storage_type(), StorageType());
typename Solver::Scalar* values_ptr = nullptr;
- if (std::is_same<typename Solver::Scalar, double>::value) {
- values_ptr =
- reinterpret_cast<typename Solver::Scalar*>(lhs->mutable_values());
+ if constexpr (std::is_same_v<typename Solver::Scalar, double>) {
+ values_ptr = lhs->mutable_values();
} else {
// In the case where the scalar used in this class is not
// double. In that case, make a copy of the values array in the
@@ -128,19 +125,20 @@
values_ptr = values_.data();
}
- Eigen::Map<Eigen::SparseMatrix<typename Solver::Scalar, Eigen::ColMajor>>
+ Eigen::Map<
+ const Eigen::SparseMatrix<typename Solver::Scalar, Eigen::ColMajor>>
eigen_lhs(lhs->num_rows(),
lhs->num_rows(),
lhs->num_nonzeros(),
- lhs->mutable_rows(),
- lhs->mutable_cols(),
+ lhs->rows(),
+ lhs->cols(),
values_ptr);
return Factorize(eigen_lhs, message);
}
private:
- Eigen::Matrix<typename Solver::Scalar, Eigen::Dynamic, 1> values_,
- scalar_rhs_, scalar_solution_;
+ Eigen::Matrix<typename Solver::Scalar, Eigen::Dynamic, 1> values_;
+
bool analyzed_{false};
Solver solver_;
};