Support fma Jet
Change-Id: I9b1ad49611e3e6f117190d56512ef2ec4bdbb1c1
diff --git a/include/ceres/jet.h b/include/ceres/jet.h
index 9a66317..3048cdc 100644
--- a/include/ceres/jet.h
+++ b/include/ceres/jet.h
@@ -452,6 +452,7 @@
using std::exp2;
using std::expm1;
using std::floor;
+using std::fma;
using std::fmax;
using std::fmin;
using std::hypot;
@@ -713,6 +714,17 @@
}
#endif // defined(CERES_HAS_CPP17)
+// Like x * y + z but rounded only once.
+template <typename T, int N>
+inline Jet<T, N> fma(const Jet<T, N>& x,
+ const Jet<T, N>& y,
+ const Jet<T, N>& z) {
+ // d/dx fma(x, y, z) = y
+ // d/dy fma(x, y, z) = x
+ // d/dz fma(x, y, z) = 1
+ return Jet<T, N>(fma(x.a, y.a, z.a), y.a * x.v + x.a * y.v + z.v);
+}
+
template <typename T, int N>
inline Jet<T, N> fmax(const Jet<T, N>& x, const Jet<T, N>& y) {
using std::isgreater;
diff --git a/internal/ceres/jet_test.cc b/internal/ceres/jet_test.cc
index e613b3f..fb7ee81 100644
--- a/internal/ceres/jet_test.cc
+++ b/internal/ceres/jet_test.cc
@@ -645,6 +645,12 @@
#endif // defined(CERES_HAS_CPP20)
+TEST(Jet, Fma) {
+ J v = fma(x, y, z);
+ J w = x * y + z;
+ EXPECT_THAT(v, IsAlmostEqualTo(w));
+}
+
TEST(Jet, Fmax) {
EXPECT_THAT(fmax(x, y), IsAlmostEqualTo(x));
EXPECT_THAT(fmax(y, x), IsAlmostEqualTo(x));