support 3-argument hypot jet C++17 provides a three argument hypot(x, y, z) which can now be used for jets if the standard is active. Change-Id: Ide62e101f780fe738bb2d4f826b10daf94c585b3
diff --git a/include/ceres/jet.h b/include/ceres/jet.h index 38a69f4..d20038b 100644 --- a/include/ceres/jet.h +++ b/include/ceres/jet.h
@@ -630,6 +630,29 @@ return Jet<T, N>(tmp, x.a / tmp * x.v + y.a / tmp * y.v); } +#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) +// Like sqrt(x^2 + y^2 + z^2), +// but acts to prevent underflow/overflow for small/large x/y/z. +// Note that the function is non-smooth at x=y=z=0, +// so the derivative is undefined there. +template <typename T, int N> +inline Jet<T, N> hypot(const Jet<T, N>& x, + const Jet<T, N>& y, + const Jet<T, N>& z) { + // d/da sqrt(a) = 0.5 / sqrt(a) + // d/dx x^2 + y^2 + z^2 = 2x + // So by the chain rule: + // d/dx sqrt(x^2 + y^2 + z^2) + // = 0.5 / sqrt(x^2 + y^2 + z^2) * 2x + // = x / sqrt(x^2 + y^2 + z^2) + // d/dy sqrt(x^2 + y^2 + z^2) = y / sqrt(x^2 + y^2 + z^2) + // d/dz sqrt(x^2 + y^2 + z^2) = z / sqrt(x^2 + y^2 + z^2) + const T tmp = hypot(x.a, y.a, z.a); + return Jet<T, N>(tmp, x.a / tmp * x.v + y.a / tmp * y.v + z.a / tmp * z.v); +} +#endif // __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= + // 201703L) + template <typename T, int N> inline Jet<T, N> fmax(const Jet<T, N>& x, const Jet<T, N>& y) { using std::isgreater;
diff --git a/internal/ceres/jet_test.cc b/internal/ceres/jet_test.cc index b7b6004..55a14b3 100644 --- a/internal/ceres/jet_test.cc +++ b/internal/ceres/jet_test.cc
@@ -116,9 +116,11 @@ // Pick arbitrary values for x and y. J x = MakeJet(2.3, -2.7, 1e-3); J y = MakeJet(1.7, 0.5, 1e+2); + J z = MakeJet(1e-6, 1e-4, 1e-2); VL << "x = " << x; VL << "y = " << y; + VL << "z = " << z; { // Check that log(exp(x)) == x. J z = exp(x); @@ -792,18 +794,108 @@ ExpectJetsClose(h, huge); } + // Resolve the ambiguity between two and three argument hypot overloads + using Hypot2 = J(const J&, const J&); + Hypot2* const hypot2 = static_cast<Hypot2*>(&hypot<double, 2>); + // clang-format off - NumericalTest2("hypot", hypot<double, 2>, 0.0, 1e-5); - NumericalTest2("hypot", hypot<double, 2>, -1e-5, 0.0); - NumericalTest2("hypot", hypot<double, 2>, 1e-5, 1e-5); - NumericalTest2("hypot", hypot<double, 2>, 0.0, 1.0); - NumericalTest2("hypot", hypot<double, 2>, 1e-3, 1.0); - NumericalTest2("hypot", hypot<double, 2>, 1e-3, -1.0); - NumericalTest2("hypot", hypot<double, 2>, -1e-3, 1.0); - NumericalTest2("hypot", hypot<double, 2>, -1e-3, -1.0); - NumericalTest2("hypot", hypot<double, 2>, 1.0, 2.0); + NumericalTest2("hypot2", hypot2, 0.0, 1e-5); + NumericalTest2("hypot2", hypot2, -1e-5, 0.0); + NumericalTest2("hypot2", hypot2, 1e-5, 1e-5); + NumericalTest2("hypot2", hypot2, 0.0, 1.0); + NumericalTest2("hypot2", hypot2, 1e-3, 1.0); + NumericalTest2("hypot2", hypot2, 1e-3, -1.0); + NumericalTest2("hypot2", hypot2, -1e-3, 1.0); + NumericalTest2("hypot2", hypot2, -1e-3, -1.0); + NumericalTest2("hypot2", hypot2, 1.0, 2.0); // clang-format on +#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) + { // Check that hypot(x, y) == sqrt(x^2 + y^2) + J h = hypot(x, y, z); + J s = sqrt(x * x + y * y + z * z); + VL << "h = " << h; + VL << "s = " << s; + ExpectJetsClose(h, s); + } + + { // Check that hypot(x, x) == sqrt(3) * abs(x) + J h = hypot(x, x, x); + J s = sqrt(3.0) * abs(x); + VL << "h = " << h; + VL << "s = " << s; + ExpectJetsClose(h, s); + } + + { // Check that the derivative is zero tangentially to the circle: + J h = hypot(MakeJet(2.0, 1.0, 1.0), + MakeJet(2.0, 1.0, -1.0), + MakeJet(2.0, -1.0, 0.0)); + VL << "h = " << h; + ExpectJetsClose(h, MakeJet(sqrt(12.0), 1.0 / std::sqrt(3.0), 0.0)); + } + + { // Check that hypot(x, 0, 0) == x + J zero = MakeJet(0.0, 2.0, 3.14); + J h = hypot(x, zero, zero); + VL << "h = " << h; + ExpectJetsClose(x, h); + } + + { // Check that hypot(0, y, 0) == y + J zero = MakeJet(0.0, 2.0, 3.14); + J h = hypot(zero, y, zero); + VL << "h = " << h; + ExpectJetsClose(y, h); + } + + { // Check that hypot(0, 0, z) == z + J zero = MakeJet(0.0, 2.0, 3.14); + J h = hypot(zero, zero, z); + VL << "h = " << h; + ExpectJetsClose(z, h); + } + + { // Check that hypot(x, y, z) == hypot(hypot(x, y), z) + J v = hypot(x, y, z); + J w = hypot(hypot(x, y), z); + VL << "v = " << v; + VL << "w = " << w; + ExpectJetsClose(v, w); + } + + { // Check that hypot(x, y, z) == hypot(x, hypot(y, z)) + J v = hypot(x, y, z); + J w = hypot(x, hypot(y, z)); + VL << "v = " << v; + VL << "w = " << w; + ExpectJetsClose(v, w); + } + + { // Check that hypot(x, 0, 0) == sqrt(x * x) == x, even when x * x + // underflows: + EXPECT_EQ( + std::numeric_limits<double>::min() * std::numeric_limits<double>::min(), + 0.0); // Make sure it underflows + J huge = MakeJet(std::numeric_limits<double>::min(), 2.0, 3.14); + J h = hypot(huge, J(0.0), J(0.0)); + VL << "h = " << h; + ExpectJetsClose(h, huge); + } + + { // Check that hypot(x, 0, 0) == sqrt(x * x) == x, even when x * x + // overflows: + EXPECT_EQ( + std::numeric_limits<double>::max() * std::numeric_limits<double>::max(), + std::numeric_limits<double>::infinity()); + J huge = MakeJet(std::numeric_limits<double>::max(), 2.0, 3.14); + J h = hypot(huge, J(0.0), J(0.0)); + VL << "h = " << h; + ExpectJetsClose(h, huge); + } +#endif // __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= + // 201703L) + { J z = fmax(x, y); VL << "z = " << z;