Support lerp Jet Provide autodiff support for C++20 std::lerp if the standard is active. Change-Id: I04ef6f17c707dd5f8ac426d9127b221e17aa08d6
diff --git a/include/ceres/internal/port.h b/include/ceres/internal/port.h index 620f13a..7c72bb5 100644 --- a/include/ceres/internal/port.h +++ b/include/ceres/internal/port.h
@@ -111,4 +111,20 @@ #define CERES_GET_FLAG(X) X #endif +// Indicates whether C++17 is currently active +#ifndef CERES_HAS_CPP17 +#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) +#define CERES_HAS_CPP17 +#endif // __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= + // 201703L) +#endif // !defined(CERES_HAS_CPP17) + +// Indicates whether C++20 is currently active +#ifndef CERES_HAS_CPP20 +#if __cplusplus >= 202002L || (defined(_MSVC_LANG) && _MSVC_LANG >= 202002L) +#define CERES_HAS_CPP20 +#endif // __cplusplus >= 202002L || (defined(_MSVC_LANG) && _MSVC_LANG >= + // 202002L) +#endif // !defined(CERES_HAS_CPP20) + #endif // CERES_PUBLIC_INTERNAL_PORT_H_
diff --git a/include/ceres/jet.h b/include/ceres/jet.h index d20038b..f6789d8 100644 --- a/include/ceres/jet.h +++ b/include/ceres/jet.h
@@ -162,6 +162,7 @@ #include <iosfwd> #include <iostream> // NOLINT #include <limits> +#include <numeric> #include <string> #include "Eigen/Core" @@ -415,6 +416,10 @@ using std::tan; using std::tanh; +#ifdef CERES_HAS_CPP20 +using std::lerp; +#endif // defined(CERES_HAS_CPP20) + // Legacy names from pre-C++11 days. // clang-format off inline bool IsFinite(double x) { return std::isfinite(x); } @@ -630,7 +635,7 @@ return Jet<T, N>(tmp, x.a / tmp * x.v + y.a / tmp * y.v); } -#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) +#ifdef CERES_HAS_CPP17 // Like sqrt(x^2 + y^2 + z^2), // but acts to prevent underflow/overflow for small/large x/y/z. // Note that the function is non-smooth at x=y=z=0, @@ -650,8 +655,7 @@ const T tmp = hypot(x.a, y.a, z.a); return Jet<T, N>(tmp, x.a / tmp * x.v + y.a / tmp * y.v + z.a / tmp * z.v); } -#endif // __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= - // 201703L) +#endif // defined(CERES_HAS_CPP17) template <typename T, int N> inline Jet<T, N> fmax(const Jet<T, N>& x, const Jet<T, N>& y) { @@ -841,6 +845,30 @@ return isinf(f); } +#ifdef CERES_HAS_CPP20 +// Computes the linear interpolation a + t(b - a) between a and b at the value +// t. For arguments outside of the range 0 <= t <= 1, the values are +// extrapolated. +// +// Differentiating lerp(a, b, t) with respect to a, b, and t gives: +// +// d/da lerp(a, b, t) = 1 - t +// d/db lerp(a, b, t) = t +// d/dt lerp(a, b, t) = b - a +// +// with the dual representation given by +// +// lerp(a + da, b + db, t + dt) +// ~= lerp(a, b, t) + (1 - t) da + t db + (b - a) dt . +template <typename T, int N> +inline Jet<T, N> lerp(const Jet<T, N>& a, + const Jet<T, N>& b, + const Jet<T, N>& t) { + return Jet<T, N>{lerp(a.a, b.a, t.a), + (T(1) - t.a) * a.v + t.a * b.v + (b.a - a.a) * t.v}; +} +#endif // defined(CERES_HAS_CPP20) + // atan2(b + db, a + da) ~= atan2(b, a) + (- b da + a db) / (a^2 + b^2) // // In words: the rate of change of theta is 1/r times the rate of
diff --git a/internal/ceres/jet_test.cc b/internal/ceres/jet_test.cc index 55a14b3..743405b 100644 --- a/internal/ceres/jet_test.cc +++ b/internal/ceres/jet_test.cc
@@ -810,7 +810,7 @@ NumericalTest2("hypot2", hypot2, 1.0, 2.0); // clang-format on -#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) +#ifdef CERES_HAS_CPP17 { // Check that hypot(x, y) == sqrt(x^2 + y^2) J h = hypot(x, y, z); J s = sqrt(x * x + y * y + z * z); @@ -893,8 +893,57 @@ VL << "h = " << h; ExpectJetsClose(h, huge); } -#endif // __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= - // 201703L) +#endif // defined(CERES_HAS_CPP17) + +#ifdef CERES_HAS_CPP20 + { // Check lerp(x, y, 0) == x + J z = lerp(x, y, J{0}); + VL << "z = " << z; + ExpectJetsClose(z, x); + } + + { // Check lerp(x, y, 1) == y + J z = lerp(x, y, J{1}); + VL << "z = " << z; + ExpectJetsClose(z, y); + } + + { // Check lerp(x, x, 1) == x + J z = lerp(x, x, J{1}); + VL << "z = " << z; + ExpectJetsClose(z, x); + } + + { // Check lerp(y, y, 0) == y + J z = lerp(y, y, J{1}); + VL << "z = " << z; + ExpectJetsClose(z, y); + } + + { // Check lerp(x, y, 0.5) == (x + y) / 2 + J z = lerp(x, y, J{0.5}); + J v = (x + y) / J{2}; + VL << "z = " << z; + VL << "v = " << v; + ExpectJetsClose(z, v); + } + + { // Check lerp(x, y, 2) == 2y - x + J z = lerp(x, y, J{2}); + J v = J{2} * y - x; + VL << "z = " << z; + VL << "v = " << v; + ExpectJetsClose(z, v); + } + + { // Check lerp(x, y, -2) == 3x - 2y + J z = lerp(x, y, -J{2}); + J v = J{3} * x - J{2} * y; + VL << "z = " << z; + VL << "v = " << v; + ExpectJetsClose(z, v); + } +#endif // defined(CERES_HAS_CPP20) { J z = fmax(x, y);