Avoid midpoint overflow in the differential
Change-Id: I8c75a819db33f4105e62942cddc83b2653318169
diff --git a/include/ceres/jet.h b/include/ceres/jet.h
index 6ebcf45..b3c65fa 100644
--- a/include/ceres/jet.h
+++ b/include/ceres/jet.h
@@ -881,7 +881,13 @@
// midpoint(a + da, b + db) ~= midpoint(a, b) + (da + db) / 2 .
template <typename T, int N>
inline Jet<T, N> midpoint(const Jet<T, N>& a, const Jet<T, N>& b) {
- return Jet<T, N>{midpoint(a.a, b.a), T(0.5) * (a.v + b.v)};
+ Jet<T, N> result{midpoint(a.a, b.a)};
+ // To avoid overflow in the differential, compute
+ // (da + db) / 2 using midpoint.
+ for (int i = 0; i < N; ++i) {
+ result.v[i] = midpoint(a.v[i], b.v[i]);
+ }
+ return result;
}
#endif // defined(CERES_HAS_CPP20)
diff --git a/internal/ceres/jet_test.cc b/internal/ceres/jet_test.cc
index 0a8f3e5..4680fd0 100644
--- a/internal/ceres/jet_test.cc
+++ b/internal/ceres/jet_test.cc
@@ -958,7 +958,7 @@
ExpectJetsClose(z, x);
}
- { // Check that midpoint(x, x) = x while avoiding overflow
+ { // Check that midpoint(x, y) = (x + y) / 2 while avoiding overflow
J x = MakeJet(std::numeric_limits<double>::min(), 1, 2);
J y = MakeJet(std::numeric_limits<double>::max(), 3, 4);
J z = midpoint(x, y);
@@ -967,6 +967,23 @@
VL << "v = " << v;
ExpectJetsClose(z, v);
}
+
+ { // Check that midpoint(x, x) = x while avoiding overflow
+ constexpr double a = std::numeric_limits<double>::max();
+ J x = MakeJet(a, a, a);
+ J z = midpoint(x, x);
+ VL << "z = " << z;
+ ExpectJetsClose(z, x);
+ }
+
+ { // Check that midpoint does not overflow for very large values
+ constexpr double a = 0.75 * std::numeric_limits<double>::max();
+ J x = MakeJet(a, a, -a);
+ J y = MakeJet(a, a, a);
+ J z = midpoint(x, y);
+ VL << "z = " << z;
+ ExpectJetsClose(z, MakeJet(a, a, 0));
+ }
#endif // defined(CERES_HAS_CPP20)
{