Avoid midpoint overflow in the differential Change-Id: I8c75a819db33f4105e62942cddc83b2653318169
diff --git a/include/ceres/jet.h b/include/ceres/jet.h index 6ebcf45..b3c65fa 100644 --- a/include/ceres/jet.h +++ b/include/ceres/jet.h
@@ -881,7 +881,13 @@ // midpoint(a + da, b + db) ~= midpoint(a, b) + (da + db) / 2 . template <typename T, int N> inline Jet<T, N> midpoint(const Jet<T, N>& a, const Jet<T, N>& b) { - return Jet<T, N>{midpoint(a.a, b.a), T(0.5) * (a.v + b.v)}; + Jet<T, N> result{midpoint(a.a, b.a)}; + // To avoid overflow in the differential, compute + // (da + db) / 2 using midpoint. + for (int i = 0; i < N; ++i) { + result.v[i] = midpoint(a.v[i], b.v[i]); + } + return result; } #endif // defined(CERES_HAS_CPP20)
diff --git a/internal/ceres/jet_test.cc b/internal/ceres/jet_test.cc index 0a8f3e5..4680fd0 100644 --- a/internal/ceres/jet_test.cc +++ b/internal/ceres/jet_test.cc
@@ -958,7 +958,7 @@ ExpectJetsClose(z, x); } - { // Check that midpoint(x, x) = x while avoiding overflow + { // Check that midpoint(x, y) = (x + y) / 2 while avoiding overflow J x = MakeJet(std::numeric_limits<double>::min(), 1, 2); J y = MakeJet(std::numeric_limits<double>::max(), 3, 4); J z = midpoint(x, y); @@ -967,6 +967,23 @@ VL << "v = " << v; ExpectJetsClose(z, v); } + + { // Check that midpoint(x, x) = x while avoiding overflow + constexpr double a = std::numeric_limits<double>::max(); + J x = MakeJet(a, a, a); + J z = midpoint(x, x); + VL << "z = " << z; + ExpectJetsClose(z, x); + } + + { // Check that midpoint does not overflow for very large values + constexpr double a = 0.75 * std::numeric_limits<double>::max(); + J x = MakeJet(a, a, -a); + J y = MakeJet(a, a, a); + J z = midpoint(x, y); + VL << "z = " << z; + ExpectJetsClose(z, MakeJet(a, a, 0)); + } #endif // defined(CERES_HAS_CPP20) {