Do not restrict Grid1D and Grid2D output to double
Fixes #1158
Change-Id: Ibeb4188393023308231c32b944f2c9dfdd674a60
diff --git a/include/ceres/cubic_interpolation.h b/include/ceres/cubic_interpolation.h
index 57fcdf4..1fed7b9 100644
--- a/include/ceres/cubic_interpolation.h
+++ b/include/ceres/cubic_interpolation.h
@@ -1,5 +1,5 @@
// Ceres Solver - A fast non-linear least squares minimizer
-// Copyright 2023 Google Inc. All rights reserved.
+// Copyright 2025 Google Inc. All rights reserved.
// http://ceres-solver.org/
//
// Redistribution and use in source and binary forms, with or without
@@ -31,6 +31,8 @@
#ifndef CERES_PUBLIC_CUBIC_INTERPOLATION_H_
#define CERES_PUBLIC_CUBIC_INTERPOLATION_H_
+#include <type_traits>
+
#include "Eigen/Core"
#include "absl/log/check.h"
#include "ceres/internal/export.h"
@@ -190,15 +192,18 @@
CHECK_LT(begin, end);
}
- EIGEN_STRONG_INLINE void GetValue(const int n, double* f) const {
+ template <typename U>
+ EIGEN_STRONG_INLINE void GetValue(const int n, U* f) const {
+ static_assert(std::is_convertible_v<T, U>,
+ "Grid1D::GetValue output type U must be convertible to T");
const int idx = (std::min)((std::max)(begin_, n), end_ - 1) - begin_;
if (kInterleaved) {
for (int i = 0; i < kDataDimension; ++i) {
- f[i] = static_cast<double>(data_[kDataDimension * idx + i]);
+ f[i] = static_cast<U>(data_[kDataDimension * idx + i]);
}
} else {
for (int i = 0; i < kDataDimension; ++i) {
- f[i] = static_cast<double>(data_[i * num_values_ + idx]);
+ f[i] = static_cast<U>(data_[i * num_values_ + idx]);
}
}
}
@@ -400,7 +405,10 @@
CHECK_LT(col_begin, col_end);
}
- EIGEN_STRONG_INLINE void GetValue(const int r, const int c, double* f) const {
+ template <typename U>
+ EIGEN_STRONG_INLINE void GetValue(const int r, const int c, U* f) const {
+ static_assert(std::is_convertible_v<T, U>,
+ "Grid2D::GetValue output type U must be convertible to T");
const int row_idx =
(std::min)((std::max)(row_begin_, r), row_end_ - 1) - row_begin_;
const int col_idx =
@@ -411,11 +419,11 @@
if (kInterleaved) {
for (int i = 0; i < kDataDimension; ++i) {
- f[i] = static_cast<double>(data_[kDataDimension * n + i]);
+ f[i] = static_cast<U>(data_[kDataDimension * n + i]);
}
} else {
for (int i = 0; i < kDataDimension; ++i) {
- f[i] = static_cast<double>(data_[i * num_values_ + n]);
+ f[i] = static_cast<U>(data_[i * num_values_ + n]);
}
}
}
diff --git a/internal/ceres/cubic_interpolation_test.cc b/internal/ceres/cubic_interpolation_test.cc
index ace9b2b..beaf4c6 100644
--- a/internal/ceres/cubic_interpolation_test.cc
+++ b/internal/ceres/cubic_interpolation_test.cc
@@ -1,5 +1,5 @@
// Ceres Solver - A fast non-linear least squares minimizer
-// Copyright 2023 Google Inc. All rights reserved.
+// Copyright 2025 Google Inc. All rights reserved.
// http://ceres-solver.org/
//
// Redistribution and use in source and binary forms, with or without
@@ -94,6 +94,26 @@
}
}
+TEST(Grid1D, JetSupport) {
+ const ceres::Jet<double, 1> x[] = {
+ ceres::Jet<double, 1>{1},
+ ceres::Jet<double, 1>{2},
+ ceres::Jet<double, 1>{3},
+ ceres::Jet<double, 1>{5},
+ ceres::Jet<double, 1>{6},
+ ceres::Jet<double, 1>{7},
+ };
+
+ ceres::Grid1D<ceres::Jet<double, 1>, 2, false> grid(x, 0, 3);
+
+ for (int i = 0; i < 3; ++i) {
+ ceres::Jet<double, 1> value[2];
+ grid.GetValue(i, value);
+ EXPECT_EQ(value[0], static_cast<double>(i + 1));
+ EXPECT_EQ(value[1], static_cast<double>(i + 5));
+ }
+}
+
TEST(Grid2D, OneDataDimensionRowMajor) {
// clang-format off
int x[] = {1, 2, 3,
@@ -217,6 +237,33 @@
}
}
+TEST(Grid2D, JetSupport) {
+ ceres::Jet<double, 1> x[] = {
+ ceres::Jet<double, 1>{1},
+ ceres::Jet<double, 1>{2},
+ ceres::Jet<double, 1>{2},
+ ceres::Jet<double, 1>{3},
+ ceres::Jet<double, 1>{3},
+ ceres::Jet<double, 1>{4},
+ ceres::Jet<double, 1>{4},
+ ceres::Jet<double, 1>{8},
+ ceres::Jet<double, 1>{8},
+ ceres::Jet<double, 1>{12},
+ ceres::Jet<double, 1>{12},
+ ceres::Jet<double, 1>{16},
+ };
+
+ ceres::Grid2D<ceres::Jet<double, 1>, 2, false, false> grid(x, 0, 2, 0, 3);
+ for (int r = 0; r < 2; ++r) {
+ for (int c = 0; c < 3; ++c) {
+ ceres::Jet<double, 1> value[2];
+ grid.GetValue(r, c, value);
+ EXPECT_EQ(value[0], static_cast<double>(r + c + 1));
+ EXPECT_EQ(value[1], static_cast<double>(4 * (r + c + 1)));
+ }
+ }
+}
+
class CubicInterpolatorTest : public ::testing::Test {
public:
template <int kDataDimension>