Do not restrict Grid1D and Grid2D output to double

Fixes #1158

Change-Id: Ibeb4188393023308231c32b944f2c9dfdd674a60
diff --git a/include/ceres/cubic_interpolation.h b/include/ceres/cubic_interpolation.h
index 57fcdf4..1fed7b9 100644
--- a/include/ceres/cubic_interpolation.h
+++ b/include/ceres/cubic_interpolation.h
@@ -1,5 +1,5 @@
 // Ceres Solver - A fast non-linear least squares minimizer
-// Copyright 2023 Google Inc. All rights reserved.
+// Copyright 2025 Google Inc. All rights reserved.
 // http://ceres-solver.org/
 //
 // Redistribution and use in source and binary forms, with or without
@@ -31,6 +31,8 @@
 #ifndef CERES_PUBLIC_CUBIC_INTERPOLATION_H_
 #define CERES_PUBLIC_CUBIC_INTERPOLATION_H_
 
+#include <type_traits>
+
 #include "Eigen/Core"
 #include "absl/log/check.h"
 #include "ceres/internal/export.h"
@@ -190,15 +192,18 @@
     CHECK_LT(begin, end);
   }
 
-  EIGEN_STRONG_INLINE void GetValue(const int n, double* f) const {
+  template <typename U>
+  EIGEN_STRONG_INLINE void GetValue(const int n, U* f) const {
+    static_assert(std::is_convertible_v<T, U>,
+                  "Grid1D::GetValue output type U must be convertible to T");
     const int idx = (std::min)((std::max)(begin_, n), end_ - 1) - begin_;
     if (kInterleaved) {
       for (int i = 0; i < kDataDimension; ++i) {
-        f[i] = static_cast<double>(data_[kDataDimension * idx + i]);
+        f[i] = static_cast<U>(data_[kDataDimension * idx + i]);
       }
     } else {
       for (int i = 0; i < kDataDimension; ++i) {
-        f[i] = static_cast<double>(data_[i * num_values_ + idx]);
+        f[i] = static_cast<U>(data_[i * num_values_ + idx]);
       }
     }
   }
@@ -400,7 +405,10 @@
     CHECK_LT(col_begin, col_end);
   }
 
-  EIGEN_STRONG_INLINE void GetValue(const int r, const int c, double* f) const {
+  template <typename U>
+  EIGEN_STRONG_INLINE void GetValue(const int r, const int c, U* f) const {
+    static_assert(std::is_convertible_v<T, U>,
+                  "Grid2D::GetValue output type U must be convertible to T");
     const int row_idx =
         (std::min)((std::max)(row_begin_, r), row_end_ - 1) - row_begin_;
     const int col_idx =
@@ -411,11 +419,11 @@
 
     if (kInterleaved) {
       for (int i = 0; i < kDataDimension; ++i) {
-        f[i] = static_cast<double>(data_[kDataDimension * n + i]);
+        f[i] = static_cast<U>(data_[kDataDimension * n + i]);
       }
     } else {
       for (int i = 0; i < kDataDimension; ++i) {
-        f[i] = static_cast<double>(data_[i * num_values_ + n]);
+        f[i] = static_cast<U>(data_[i * num_values_ + n]);
       }
     }
   }
diff --git a/internal/ceres/cubic_interpolation_test.cc b/internal/ceres/cubic_interpolation_test.cc
index ace9b2b..beaf4c6 100644
--- a/internal/ceres/cubic_interpolation_test.cc
+++ b/internal/ceres/cubic_interpolation_test.cc
@@ -1,5 +1,5 @@
 // Ceres Solver - A fast non-linear least squares minimizer
-// Copyright 2023 Google Inc. All rights reserved.
+// Copyright 2025 Google Inc. All rights reserved.
 // http://ceres-solver.org/
 //
 // Redistribution and use in source and binary forms, with or without
@@ -94,6 +94,26 @@
   }
 }
 
+TEST(Grid1D, JetSupport) {
+  const ceres::Jet<double, 1> x[] = {
+      ceres::Jet<double, 1>{1},
+      ceres::Jet<double, 1>{2},
+      ceres::Jet<double, 1>{3},
+      ceres::Jet<double, 1>{5},
+      ceres::Jet<double, 1>{6},
+      ceres::Jet<double, 1>{7},
+  };
+
+  ceres::Grid1D<ceres::Jet<double, 1>, 2, false> grid(x, 0, 3);
+
+  for (int i = 0; i < 3; ++i) {
+    ceres::Jet<double, 1> value[2];
+    grid.GetValue(i, value);
+    EXPECT_EQ(value[0], static_cast<double>(i + 1));
+    EXPECT_EQ(value[1], static_cast<double>(i + 5));
+  }
+}
+
 TEST(Grid2D, OneDataDimensionRowMajor) {
   // clang-format off
   int x[] = {1, 2, 3,
@@ -217,6 +237,33 @@
   }
 }
 
+TEST(Grid2D, JetSupport) {
+  ceres::Jet<double, 1> x[] = {
+      ceres::Jet<double, 1>{1},
+      ceres::Jet<double, 1>{2},
+      ceres::Jet<double, 1>{2},
+      ceres::Jet<double, 1>{3},
+      ceres::Jet<double, 1>{3},
+      ceres::Jet<double, 1>{4},
+      ceres::Jet<double, 1>{4},
+      ceres::Jet<double, 1>{8},
+      ceres::Jet<double, 1>{8},
+      ceres::Jet<double, 1>{12},
+      ceres::Jet<double, 1>{12},
+      ceres::Jet<double, 1>{16},
+  };
+
+  ceres::Grid2D<ceres::Jet<double, 1>, 2, false, false> grid(x, 0, 2, 0, 3);
+  for (int r = 0; r < 2; ++r) {
+    for (int c = 0; c < 3; ++c) {
+      ceres::Jet<double, 1> value[2];
+      grid.GetValue(r, c, value);
+      EXPECT_EQ(value[0], static_cast<double>(r + c + 1));
+      EXPECT_EQ(value[1], static_cast<double>(4 * (r + c + 1)));
+    }
+  }
+}
+
 class CubicInterpolatorTest : public ::testing::Test {
  public:
   template <int kDataDimension>