Support lerp Jet
Provide autodiff support for C++20 std::lerp if the standard is active.
Change-Id: I04ef6f17c707dd5f8ac426d9127b221e17aa08d6
diff --git a/include/ceres/internal/port.h b/include/ceres/internal/port.h
index 620f13a..7c72bb5 100644
--- a/include/ceres/internal/port.h
+++ b/include/ceres/internal/port.h
@@ -111,4 +111,20 @@
#define CERES_GET_FLAG(X) X
#endif
+// Indicates whether C++17 is currently active
+#ifndef CERES_HAS_CPP17
+#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L)
+#define CERES_HAS_CPP17
+#endif // __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >=
+ // 201703L)
+#endif // !defined(CERES_HAS_CPP17)
+
+// Indicates whether C++20 is currently active
+#ifndef CERES_HAS_CPP20
+#if __cplusplus >= 202002L || (defined(_MSVC_LANG) && _MSVC_LANG >= 202002L)
+#define CERES_HAS_CPP20
+#endif // __cplusplus >= 202002L || (defined(_MSVC_LANG) && _MSVC_LANG >=
+ // 202002L)
+#endif // !defined(CERES_HAS_CPP20)
+
#endif // CERES_PUBLIC_INTERNAL_PORT_H_
diff --git a/include/ceres/jet.h b/include/ceres/jet.h
index d20038b..f6789d8 100644
--- a/include/ceres/jet.h
+++ b/include/ceres/jet.h
@@ -162,6 +162,7 @@
#include <iosfwd>
#include <iostream> // NOLINT
#include <limits>
+#include <numeric>
#include <string>
#include "Eigen/Core"
@@ -415,6 +416,10 @@
using std::tan;
using std::tanh;
+#ifdef CERES_HAS_CPP20
+using std::lerp;
+#endif // defined(CERES_HAS_CPP20)
+
// Legacy names from pre-C++11 days.
// clang-format off
inline bool IsFinite(double x) { return std::isfinite(x); }
@@ -630,7 +635,7 @@
return Jet<T, N>(tmp, x.a / tmp * x.v + y.a / tmp * y.v);
}
-#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L)
+#ifdef CERES_HAS_CPP17
// Like sqrt(x^2 + y^2 + z^2),
// but acts to prevent underflow/overflow for small/large x/y/z.
// Note that the function is non-smooth at x=y=z=0,
@@ -650,8 +655,7 @@
const T tmp = hypot(x.a, y.a, z.a);
return Jet<T, N>(tmp, x.a / tmp * x.v + y.a / tmp * y.v + z.a / tmp * z.v);
}
-#endif // __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >=
- // 201703L)
+#endif // defined(CERES_HAS_CPP17)
template <typename T, int N>
inline Jet<T, N> fmax(const Jet<T, N>& x, const Jet<T, N>& y) {
@@ -841,6 +845,30 @@
return isinf(f);
}
+#ifdef CERES_HAS_CPP20
+// Computes the linear interpolation a + t(b - a) between a and b at the value
+// t. For arguments outside of the range 0 <= t <= 1, the values are
+// extrapolated.
+//
+// Differentiating lerp(a, b, t) with respect to a, b, and t gives:
+//
+// d/da lerp(a, b, t) = 1 - t
+// d/db lerp(a, b, t) = t
+// d/dt lerp(a, b, t) = b - a
+//
+// with the dual representation given by
+//
+// lerp(a + da, b + db, t + dt)
+// ~= lerp(a, b, t) + (1 - t) da + t db + (b - a) dt .
+template <typename T, int N>
+inline Jet<T, N> lerp(const Jet<T, N>& a,
+ const Jet<T, N>& b,
+ const Jet<T, N>& t) {
+ return Jet<T, N>{lerp(a.a, b.a, t.a),
+ (T(1) - t.a) * a.v + t.a * b.v + (b.a - a.a) * t.v};
+}
+#endif // defined(CERES_HAS_CPP20)
+
// atan2(b + db, a + da) ~= atan2(b, a) + (- b da + a db) / (a^2 + b^2)
//
// In words: the rate of change of theta is 1/r times the rate of
diff --git a/internal/ceres/jet_test.cc b/internal/ceres/jet_test.cc
index 55a14b3..743405b 100644
--- a/internal/ceres/jet_test.cc
+++ b/internal/ceres/jet_test.cc
@@ -810,7 +810,7 @@
NumericalTest2("hypot2", hypot2, 1.0, 2.0);
// clang-format on
-#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L)
+#ifdef CERES_HAS_CPP17
{ // Check that hypot(x, y) == sqrt(x^2 + y^2)
J h = hypot(x, y, z);
J s = sqrt(x * x + y * y + z * z);
@@ -893,8 +893,57 @@
VL << "h = " << h;
ExpectJetsClose(h, huge);
}
-#endif // __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >=
- // 201703L)
+#endif // defined(CERES_HAS_CPP17)
+
+#ifdef CERES_HAS_CPP20
+ { // Check lerp(x, y, 0) == x
+ J z = lerp(x, y, J{0});
+ VL << "z = " << z;
+ ExpectJetsClose(z, x);
+ }
+
+ { // Check lerp(x, y, 1) == y
+ J z = lerp(x, y, J{1});
+ VL << "z = " << z;
+ ExpectJetsClose(z, y);
+ }
+
+ { // Check lerp(x, x, 1) == x
+ J z = lerp(x, x, J{1});
+ VL << "z = " << z;
+ ExpectJetsClose(z, x);
+ }
+
+ { // Check lerp(y, y, 0) == y
+ J z = lerp(y, y, J{1});
+ VL << "z = " << z;
+ ExpectJetsClose(z, y);
+ }
+
+ { // Check lerp(x, y, 0.5) == (x + y) / 2
+ J z = lerp(x, y, J{0.5});
+ J v = (x + y) / J{2};
+ VL << "z = " << z;
+ VL << "v = " << v;
+ ExpectJetsClose(z, v);
+ }
+
+ { // Check lerp(x, y, 2) == 2y - x
+ J z = lerp(x, y, J{2});
+ J v = J{2} * y - x;
+ VL << "z = " << z;
+ VL << "v = " << v;
+ ExpectJetsClose(z, v);
+ }
+
+ { // Check lerp(x, y, -2) == 3x - 2y
+ J z = lerp(x, y, -J{2});
+ J v = J{3} * x - J{2} * y;
+ VL << "z = " << z;
+ VL << "v = " << v;
+ ExpectJetsClose(z, v);
+ }
+#endif // defined(CERES_HAS_CPP20)
{
J z = fmax(x, y);