support 3-argument hypot jet
C++17 provides a three argument hypot(x, y, z) which can now be used
for jets if the standard is active.
Change-Id: Ide62e101f780fe738bb2d4f826b10daf94c585b3
diff --git a/include/ceres/jet.h b/include/ceres/jet.h
index 38a69f4..d20038b 100644
--- a/include/ceres/jet.h
+++ b/include/ceres/jet.h
@@ -630,6 +630,29 @@
return Jet<T, N>(tmp, x.a / tmp * x.v + y.a / tmp * y.v);
}
+#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L)
+// Like sqrt(x^2 + y^2 + z^2),
+// but acts to prevent underflow/overflow for small/large x/y/z.
+// Note that the function is non-smooth at x=y=z=0,
+// so the derivative is undefined there.
+template <typename T, int N>
+inline Jet<T, N> hypot(const Jet<T, N>& x,
+ const Jet<T, N>& y,
+ const Jet<T, N>& z) {
+ // d/da sqrt(a) = 0.5 / sqrt(a)
+ // d/dx x^2 + y^2 + z^2 = 2x
+ // So by the chain rule:
+ // d/dx sqrt(x^2 + y^2 + z^2)
+ // = 0.5 / sqrt(x^2 + y^2 + z^2) * 2x
+ // = x / sqrt(x^2 + y^2 + z^2)
+ // d/dy sqrt(x^2 + y^2 + z^2) = y / sqrt(x^2 + y^2 + z^2)
+ // d/dz sqrt(x^2 + y^2 + z^2) = z / sqrt(x^2 + y^2 + z^2)
+ const T tmp = hypot(x.a, y.a, z.a);
+ return Jet<T, N>(tmp, x.a / tmp * x.v + y.a / tmp * y.v + z.a / tmp * z.v);
+}
+#endif // __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >=
+ // 201703L)
+
template <typename T, int N>
inline Jet<T, N> fmax(const Jet<T, N>& x, const Jet<T, N>& y) {
using std::isgreater;
diff --git a/internal/ceres/jet_test.cc b/internal/ceres/jet_test.cc
index b7b6004..55a14b3 100644
--- a/internal/ceres/jet_test.cc
+++ b/internal/ceres/jet_test.cc
@@ -116,9 +116,11 @@
// Pick arbitrary values for x and y.
J x = MakeJet(2.3, -2.7, 1e-3);
J y = MakeJet(1.7, 0.5, 1e+2);
+ J z = MakeJet(1e-6, 1e-4, 1e-2);
VL << "x = " << x;
VL << "y = " << y;
+ VL << "z = " << z;
{ // Check that log(exp(x)) == x.
J z = exp(x);
@@ -792,18 +794,108 @@
ExpectJetsClose(h, huge);
}
+ // Resolve the ambiguity between two and three argument hypot overloads
+ using Hypot2 = J(const J&, const J&);
+ Hypot2* const hypot2 = static_cast<Hypot2*>(&hypot<double, 2>);
+
// clang-format off
- NumericalTest2("hypot", hypot<double, 2>, 0.0, 1e-5);
- NumericalTest2("hypot", hypot<double, 2>, -1e-5, 0.0);
- NumericalTest2("hypot", hypot<double, 2>, 1e-5, 1e-5);
- NumericalTest2("hypot", hypot<double, 2>, 0.0, 1.0);
- NumericalTest2("hypot", hypot<double, 2>, 1e-3, 1.0);
- NumericalTest2("hypot", hypot<double, 2>, 1e-3, -1.0);
- NumericalTest2("hypot", hypot<double, 2>, -1e-3, 1.0);
- NumericalTest2("hypot", hypot<double, 2>, -1e-3, -1.0);
- NumericalTest2("hypot", hypot<double, 2>, 1.0, 2.0);
+ NumericalTest2("hypot2", hypot2, 0.0, 1e-5);
+ NumericalTest2("hypot2", hypot2, -1e-5, 0.0);
+ NumericalTest2("hypot2", hypot2, 1e-5, 1e-5);
+ NumericalTest2("hypot2", hypot2, 0.0, 1.0);
+ NumericalTest2("hypot2", hypot2, 1e-3, 1.0);
+ NumericalTest2("hypot2", hypot2, 1e-3, -1.0);
+ NumericalTest2("hypot2", hypot2, -1e-3, 1.0);
+ NumericalTest2("hypot2", hypot2, -1e-3, -1.0);
+ NumericalTest2("hypot2", hypot2, 1.0, 2.0);
// clang-format on
+#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L)
+ { // Check that hypot(x, y) == sqrt(x^2 + y^2)
+ J h = hypot(x, y, z);
+ J s = sqrt(x * x + y * y + z * z);
+ VL << "h = " << h;
+ VL << "s = " << s;
+ ExpectJetsClose(h, s);
+ }
+
+ { // Check that hypot(x, x) == sqrt(3) * abs(x)
+ J h = hypot(x, x, x);
+ J s = sqrt(3.0) * abs(x);
+ VL << "h = " << h;
+ VL << "s = " << s;
+ ExpectJetsClose(h, s);
+ }
+
+ { // Check that the derivative is zero tangentially to the circle:
+ J h = hypot(MakeJet(2.0, 1.0, 1.0),
+ MakeJet(2.0, 1.0, -1.0),
+ MakeJet(2.0, -1.0, 0.0));
+ VL << "h = " << h;
+ ExpectJetsClose(h, MakeJet(sqrt(12.0), 1.0 / std::sqrt(3.0), 0.0));
+ }
+
+ { // Check that hypot(x, 0, 0) == x
+ J zero = MakeJet(0.0, 2.0, 3.14);
+ J h = hypot(x, zero, zero);
+ VL << "h = " << h;
+ ExpectJetsClose(x, h);
+ }
+
+ { // Check that hypot(0, y, 0) == y
+ J zero = MakeJet(0.0, 2.0, 3.14);
+ J h = hypot(zero, y, zero);
+ VL << "h = " << h;
+ ExpectJetsClose(y, h);
+ }
+
+ { // Check that hypot(0, 0, z) == z
+ J zero = MakeJet(0.0, 2.0, 3.14);
+ J h = hypot(zero, zero, z);
+ VL << "h = " << h;
+ ExpectJetsClose(z, h);
+ }
+
+ { // Check that hypot(x, y, z) == hypot(hypot(x, y), z)
+ J v = hypot(x, y, z);
+ J w = hypot(hypot(x, y), z);
+ VL << "v = " << v;
+ VL << "w = " << w;
+ ExpectJetsClose(v, w);
+ }
+
+ { // Check that hypot(x, y, z) == hypot(x, hypot(y, z))
+ J v = hypot(x, y, z);
+ J w = hypot(x, hypot(y, z));
+ VL << "v = " << v;
+ VL << "w = " << w;
+ ExpectJetsClose(v, w);
+ }
+
+ { // Check that hypot(x, 0, 0) == sqrt(x * x) == x, even when x * x
+ // underflows:
+ EXPECT_EQ(
+ std::numeric_limits<double>::min() * std::numeric_limits<double>::min(),
+ 0.0); // Make sure it underflows
+ J huge = MakeJet(std::numeric_limits<double>::min(), 2.0, 3.14);
+ J h = hypot(huge, J(0.0), J(0.0));
+ VL << "h = " << h;
+ ExpectJetsClose(h, huge);
+ }
+
+ { // Check that hypot(x, 0, 0) == sqrt(x * x) == x, even when x * x
+ // overflows:
+ EXPECT_EQ(
+ std::numeric_limits<double>::max() * std::numeric_limits<double>::max(),
+ std::numeric_limits<double>::infinity());
+ J huge = MakeJet(std::numeric_limits<double>::max(), 2.0, 3.14);
+ J h = hypot(huge, J(0.0), J(0.0));
+ VL << "h = " << h;
+ ExpectJetsClose(h, huge);
+ }
+#endif // __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >=
+ // 201703L)
+
{
J z = fmax(x, y);
VL << "z = " << z;