Support fma Jet Change-Id: I9b1ad49611e3e6f117190d56512ef2ec4bdbb1c1
diff --git a/include/ceres/jet.h b/include/ceres/jet.h index 9a66317..3048cdc 100644 --- a/include/ceres/jet.h +++ b/include/ceres/jet.h
@@ -452,6 +452,7 @@ using std::exp2; using std::expm1; using std::floor; +using std::fma; using std::fmax; using std::fmin; using std::hypot; @@ -713,6 +714,17 @@ } #endif // defined(CERES_HAS_CPP17) +// Like x * y + z but rounded only once. +template <typename T, int N> +inline Jet<T, N> fma(const Jet<T, N>& x, + const Jet<T, N>& y, + const Jet<T, N>& z) { + // d/dx fma(x, y, z) = y + // d/dy fma(x, y, z) = x + // d/dz fma(x, y, z) = 1 + return Jet<T, N>(fma(x.a, y.a, z.a), y.a * x.v + x.a * y.v + z.v); +} + template <typename T, int N> inline Jet<T, N> fmax(const Jet<T, N>& x, const Jet<T, N>& y) { using std::isgreater;
diff --git a/internal/ceres/jet_test.cc b/internal/ceres/jet_test.cc index e613b3f..fb7ee81 100644 --- a/internal/ceres/jet_test.cc +++ b/internal/ceres/jet_test.cc
@@ -645,6 +645,12 @@ #endif // defined(CERES_HAS_CPP20) +TEST(Jet, Fma) { + J v = fma(x, y, z); + J w = x * y + z; + EXPECT_THAT(v, IsAlmostEqualTo(w)); +} + TEST(Jet, Fmax) { EXPECT_THAT(fmax(x, y), IsAlmostEqualTo(x)); EXPECT_THAT(fmax(y, x), IsAlmostEqualTo(x));