Avoid midpoint overflow in the differential

Change-Id: I8c75a819db33f4105e62942cddc83b2653318169
diff --git a/include/ceres/jet.h b/include/ceres/jet.h
index 6ebcf45..b3c65fa 100644
--- a/include/ceres/jet.h
+++ b/include/ceres/jet.h
@@ -881,7 +881,13 @@
 //   midpoint(a + da, b + db) ~= midpoint(a, b) + (da + db) / 2 .
 template <typename T, int N>
 inline Jet<T, N> midpoint(const Jet<T, N>& a, const Jet<T, N>& b) {
-  return Jet<T, N>{midpoint(a.a, b.a), T(0.5) * (a.v + b.v)};
+  Jet<T, N> result{midpoint(a.a, b.a)};
+  // To avoid overflow in the differential, compute
+  // (da + db) / 2 using midpoint.
+  for (int i = 0; i < N; ++i) {
+    result.v[i] = midpoint(a.v[i], b.v[i]);
+  }
+  return result;
 }
 #endif  // defined(CERES_HAS_CPP20)
 
diff --git a/internal/ceres/jet_test.cc b/internal/ceres/jet_test.cc
index 0a8f3e5..4680fd0 100644
--- a/internal/ceres/jet_test.cc
+++ b/internal/ceres/jet_test.cc
@@ -958,7 +958,7 @@
     ExpectJetsClose(z, x);
   }
 
-  {  // Check that midpoint(x, x) = x while avoiding overflow
+  {  // Check that midpoint(x, y) = (x + y) / 2 while avoiding overflow
     J x = MakeJet(std::numeric_limits<double>::min(), 1, 2);
     J y = MakeJet(std::numeric_limits<double>::max(), 3, 4);
     J z = midpoint(x, y);
@@ -967,6 +967,23 @@
     VL << "v = " << v;
     ExpectJetsClose(z, v);
   }
+
+  {  // Check that midpoint(x, x) = x while avoiding overflow
+    constexpr double a = std::numeric_limits<double>::max();
+    J x = MakeJet(a, a, a);
+    J z = midpoint(x, x);
+    VL << "z = " << z;
+    ExpectJetsClose(z, x);
+  }
+
+  {  // Check that midpoint does not overflow for very large values
+    constexpr double a = 0.75 * std::numeric_limits<double>::max();
+    J x = MakeJet(a, a, -a);
+    J y = MakeJet(a, a, a);
+    J z = midpoint(x, y);
+    VL << "z = " << z;
+    ExpectJetsClose(z, MakeJet(a, a, 0));
+  }
 #endif  // defined(CERES_HAS_CPP20)
 
   {