Do not restrict Grid1D and Grid2D output to double Fixes #1158 Change-Id: Ibeb4188393023308231c32b944f2c9dfdd674a60
diff --git a/include/ceres/cubic_interpolation.h b/include/ceres/cubic_interpolation.h index 57fcdf4..1fed7b9 100644 --- a/include/ceres/cubic_interpolation.h +++ b/include/ceres/cubic_interpolation.h
@@ -1,5 +1,5 @@ // Ceres Solver - A fast non-linear least squares minimizer -// Copyright 2023 Google Inc. All rights reserved. +// Copyright 2025 Google Inc. All rights reserved. // http://ceres-solver.org/ // // Redistribution and use in source and binary forms, with or without @@ -31,6 +31,8 @@ #ifndef CERES_PUBLIC_CUBIC_INTERPOLATION_H_ #define CERES_PUBLIC_CUBIC_INTERPOLATION_H_ +#include <type_traits> + #include "Eigen/Core" #include "absl/log/check.h" #include "ceres/internal/export.h" @@ -190,15 +192,18 @@ CHECK_LT(begin, end); } - EIGEN_STRONG_INLINE void GetValue(const int n, double* f) const { + template <typename U> + EIGEN_STRONG_INLINE void GetValue(const int n, U* f) const { + static_assert(std::is_convertible_v<T, U>, + "Grid1D::GetValue output type U must be convertible to T"); const int idx = (std::min)((std::max)(begin_, n), end_ - 1) - begin_; if (kInterleaved) { for (int i = 0; i < kDataDimension; ++i) { - f[i] = static_cast<double>(data_[kDataDimension * idx + i]); + f[i] = static_cast<U>(data_[kDataDimension * idx + i]); } } else { for (int i = 0; i < kDataDimension; ++i) { - f[i] = static_cast<double>(data_[i * num_values_ + idx]); + f[i] = static_cast<U>(data_[i * num_values_ + idx]); } } } @@ -400,7 +405,10 @@ CHECK_LT(col_begin, col_end); } - EIGEN_STRONG_INLINE void GetValue(const int r, const int c, double* f) const { + template <typename U> + EIGEN_STRONG_INLINE void GetValue(const int r, const int c, U* f) const { + static_assert(std::is_convertible_v<T, U>, + "Grid2D::GetValue output type U must be convertible to T"); const int row_idx = (std::min)((std::max)(row_begin_, r), row_end_ - 1) - row_begin_; const int col_idx = @@ -411,11 +419,11 @@ if (kInterleaved) { for (int i = 0; i < kDataDimension; ++i) { - f[i] = static_cast<double>(data_[kDataDimension * n + i]); + f[i] = static_cast<U>(data_[kDataDimension * n + i]); } } else { for (int i = 0; i < kDataDimension; ++i) { - f[i] = static_cast<double>(data_[i * num_values_ + n]); + f[i] = static_cast<U>(data_[i * num_values_ + n]); } } }
diff --git a/internal/ceres/cubic_interpolation_test.cc b/internal/ceres/cubic_interpolation_test.cc index ace9b2b..beaf4c6 100644 --- a/internal/ceres/cubic_interpolation_test.cc +++ b/internal/ceres/cubic_interpolation_test.cc
@@ -1,5 +1,5 @@ // Ceres Solver - A fast non-linear least squares minimizer -// Copyright 2023 Google Inc. All rights reserved. +// Copyright 2025 Google Inc. All rights reserved. // http://ceres-solver.org/ // // Redistribution and use in source and binary forms, with or without @@ -94,6 +94,26 @@ } } +TEST(Grid1D, JetSupport) { + const ceres::Jet<double, 1> x[] = { + ceres::Jet<double, 1>{1}, + ceres::Jet<double, 1>{2}, + ceres::Jet<double, 1>{3}, + ceres::Jet<double, 1>{5}, + ceres::Jet<double, 1>{6}, + ceres::Jet<double, 1>{7}, + }; + + ceres::Grid1D<ceres::Jet<double, 1>, 2, false> grid(x, 0, 3); + + for (int i = 0; i < 3; ++i) { + ceres::Jet<double, 1> value[2]; + grid.GetValue(i, value); + EXPECT_EQ(value[0], static_cast<double>(i + 1)); + EXPECT_EQ(value[1], static_cast<double>(i + 5)); + } +} + TEST(Grid2D, OneDataDimensionRowMajor) { // clang-format off int x[] = {1, 2, 3, @@ -217,6 +237,33 @@ } } +TEST(Grid2D, JetSupport) { + ceres::Jet<double, 1> x[] = { + ceres::Jet<double, 1>{1}, + ceres::Jet<double, 1>{2}, + ceres::Jet<double, 1>{2}, + ceres::Jet<double, 1>{3}, + ceres::Jet<double, 1>{3}, + ceres::Jet<double, 1>{4}, + ceres::Jet<double, 1>{4}, + ceres::Jet<double, 1>{8}, + ceres::Jet<double, 1>{8}, + ceres::Jet<double, 1>{12}, + ceres::Jet<double, 1>{12}, + ceres::Jet<double, 1>{16}, + }; + + ceres::Grid2D<ceres::Jet<double, 1>, 2, false, false> grid(x, 0, 2, 0, 3); + for (int r = 0; r < 2; ++r) { + for (int c = 0; c < 3; ++c) { + ceres::Jet<double, 1> value[2]; + grid.GetValue(r, c, value); + EXPECT_EQ(value[0], static_cast<double>(r + c + 1)); + EXPECT_EQ(value[1], static_cast<double>(4 * (r + c + 1))); + } + } +} + class CubicInterpolatorTest : public ::testing::Test { public: template <int kDataDimension>