Support midpoint Jet Provide autodiff support for C++20 std::midpoint if the standard is active. Change-Id: I1308a1e514bef4c74f08f655cbf803fa43503ce5
diff --git a/include/ceres/jet.h b/include/ceres/jet.h index f6789d8..6ebcf45 100644 --- a/include/ceres/jet.h +++ b/include/ceres/jet.h
@@ -418,6 +418,7 @@ #ifdef CERES_HAS_CPP20 using std::lerp; +using std::midpoint; #endif // defined(CERES_HAS_CPP20) // Legacy names from pre-C++11 days. @@ -867,6 +868,21 @@ return Jet<T, N>{lerp(a.a, b.a, t.a), (T(1) - t.a) * a.v + t.a * b.v + (b.a - a.a) * t.v}; } + +// Computes the midpoint a + (b - a) / 2. +// +// Differentiating midpoint(a, b) with respect to a and b gives: +// +// d/da midpoint(a, b) = 1/2 +// d/db midpoint(a, b) = 1/2 +// +// with the dual representation given by +// +// midpoint(a + da, b + db) ~= midpoint(a, b) + (da + db) / 2 . +template <typename T, int N> +inline Jet<T, N> midpoint(const Jet<T, N>& a, const Jet<T, N>& b) { + return Jet<T, N>{midpoint(a.a, b.a), T(0.5) * (a.v + b.v)}; +} #endif // defined(CERES_HAS_CPP20) // atan2(b + db, a + da) ~= atan2(b, a) + (- b da + a db) / (a^2 + b^2)
diff --git a/internal/ceres/jet_test.cc b/internal/ceres/jet_test.cc index 743405b..0a8f3e5 100644 --- a/internal/ceres/jet_test.cc +++ b/internal/ceres/jet_test.cc
@@ -943,6 +943,30 @@ VL << "v = " << v; ExpectJetsClose(z, v); } + + { // Check that midpoint(x, y) = (x + y) / 2 + J z = midpoint(x, y); + J v = (x + y) / J{2}; + VL << "z = " << z; + VL << "v = " << v; + ExpectJetsClose(z, v); + } + + { // Check that midpoint(x, x) = x + J z = midpoint(x, x); + VL << "z = " << z; + ExpectJetsClose(z, x); + } + + { // Check that midpoint(x, x) = x while avoiding overflow + J x = MakeJet(std::numeric_limits<double>::min(), 1, 2); + J y = MakeJet(std::numeric_limits<double>::max(), 3, 4); + J z = midpoint(x, y); + J v = x + (y - x) / J{2}; + VL << "z = " << z; + VL << "v = " << v; + ExpectJetsClose(z, v); + } #endif // defined(CERES_HAS_CPP20) {