support 3-argument hypot jet

C++17 provides a three argument hypot(x, y, z) which can now be used
for jets if the standard is active.

Change-Id: Ide62e101f780fe738bb2d4f826b10daf94c585b3
diff --git a/include/ceres/jet.h b/include/ceres/jet.h
index 38a69f4..d20038b 100644
--- a/include/ceres/jet.h
+++ b/include/ceres/jet.h
@@ -630,6 +630,29 @@
   return Jet<T, N>(tmp, x.a / tmp * x.v + y.a / tmp * y.v);
 }
 
+#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L)
+// Like sqrt(x^2 + y^2 + z^2),
+// but acts to prevent underflow/overflow for small/large x/y/z.
+// Note that the function is non-smooth at x=y=z=0,
+// so the derivative is undefined there.
+template <typename T, int N>
+inline Jet<T, N> hypot(const Jet<T, N>& x,
+                       const Jet<T, N>& y,
+                       const Jet<T, N>& z) {
+  // d/da sqrt(a) = 0.5 / sqrt(a)
+  // d/dx x^2 + y^2 + z^2 = 2x
+  // So by the chain rule:
+  // d/dx sqrt(x^2 + y^2 + z^2)
+  //    = 0.5 / sqrt(x^2 + y^2 + z^2) * 2x
+  //    = x / sqrt(x^2 + y^2 + z^2)
+  // d/dy sqrt(x^2 + y^2 + z^2) = y / sqrt(x^2 + y^2 + z^2)
+  // d/dz sqrt(x^2 + y^2 + z^2) = z / sqrt(x^2 + y^2 + z^2)
+  const T tmp = hypot(x.a, y.a, z.a);
+  return Jet<T, N>(tmp, x.a / tmp * x.v + y.a / tmp * y.v + z.a / tmp * z.v);
+}
+#endif  // __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >=
+        // 201703L)
+
 template <typename T, int N>
 inline Jet<T, N> fmax(const Jet<T, N>& x, const Jet<T, N>& y) {
   using std::isgreater;
diff --git a/internal/ceres/jet_test.cc b/internal/ceres/jet_test.cc
index b7b6004..55a14b3 100644
--- a/internal/ceres/jet_test.cc
+++ b/internal/ceres/jet_test.cc
@@ -116,9 +116,11 @@
   // Pick arbitrary values for x and y.
   J x = MakeJet(2.3, -2.7, 1e-3);
   J y = MakeJet(1.7, 0.5, 1e+2);
+  J z = MakeJet(1e-6, 1e-4, 1e-2);
 
   VL << "x = " << x;
   VL << "y = " << y;
+  VL << "z = " << z;
 
   {  // Check that log(exp(x)) == x.
     J z = exp(x);
@@ -792,18 +794,108 @@
     ExpectJetsClose(h, huge);
   }
 
+  // Resolve the ambiguity between two and three argument hypot overloads
+  using Hypot2 = J(const J&, const J&);
+  Hypot2* const hypot2 = static_cast<Hypot2*>(&hypot<double, 2>);
+
   // clang-format off
-  NumericalTest2("hypot", hypot<double, 2>,  0.0,   1e-5);
-  NumericalTest2("hypot", hypot<double, 2>, -1e-5,  0.0);
-  NumericalTest2("hypot", hypot<double, 2>,  1e-5,  1e-5);
-  NumericalTest2("hypot", hypot<double, 2>,  0.0,   1.0);
-  NumericalTest2("hypot", hypot<double, 2>,  1e-3,  1.0);
-  NumericalTest2("hypot", hypot<double, 2>,  1e-3, -1.0);
-  NumericalTest2("hypot", hypot<double, 2>, -1e-3,  1.0);
-  NumericalTest2("hypot", hypot<double, 2>, -1e-3, -1.0);
-  NumericalTest2("hypot", hypot<double, 2>,  1.0,   2.0);
+  NumericalTest2("hypot2", hypot2,  0.0,   1e-5);
+  NumericalTest2("hypot2", hypot2, -1e-5,  0.0);
+  NumericalTest2("hypot2", hypot2,  1e-5,  1e-5);
+  NumericalTest2("hypot2", hypot2,  0.0,   1.0);
+  NumericalTest2("hypot2", hypot2,  1e-3,  1.0);
+  NumericalTest2("hypot2", hypot2,  1e-3, -1.0);
+  NumericalTest2("hypot2", hypot2, -1e-3,  1.0);
+  NumericalTest2("hypot2", hypot2, -1e-3, -1.0);
+  NumericalTest2("hypot2", hypot2,  1.0,   2.0);
   // clang-format on
 
+#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L)
+  {  // Check that hypot(x, y) == sqrt(x^2 + y^2)
+    J h = hypot(x, y, z);
+    J s = sqrt(x * x + y * y + z * z);
+    VL << "h = " << h;
+    VL << "s = " << s;
+    ExpectJetsClose(h, s);
+  }
+
+  {  // Check that hypot(x, x) == sqrt(3) * abs(x)
+    J h = hypot(x, x, x);
+    J s = sqrt(3.0) * abs(x);
+    VL << "h = " << h;
+    VL << "s = " << s;
+    ExpectJetsClose(h, s);
+  }
+
+  {  // Check that the derivative is zero tangentially to the circle:
+    J h = hypot(MakeJet(2.0, 1.0, 1.0),
+                MakeJet(2.0, 1.0, -1.0),
+                MakeJet(2.0, -1.0, 0.0));
+    VL << "h = " << h;
+    ExpectJetsClose(h, MakeJet(sqrt(12.0), 1.0 / std::sqrt(3.0), 0.0));
+  }
+
+  {  // Check that hypot(x, 0, 0) == x
+    J zero = MakeJet(0.0, 2.0, 3.14);
+    J h = hypot(x, zero, zero);
+    VL << "h = " << h;
+    ExpectJetsClose(x, h);
+  }
+
+  {  // Check that hypot(0, y, 0) == y
+    J zero = MakeJet(0.0, 2.0, 3.14);
+    J h = hypot(zero, y, zero);
+    VL << "h = " << h;
+    ExpectJetsClose(y, h);
+  }
+
+  {  // Check that hypot(0, 0, z) == z
+    J zero = MakeJet(0.0, 2.0, 3.14);
+    J h = hypot(zero, zero, z);
+    VL << "h = " << h;
+    ExpectJetsClose(z, h);
+  }
+
+  {  // Check that hypot(x, y, z) == hypot(hypot(x, y), z)
+    J v = hypot(x, y, z);
+    J w = hypot(hypot(x, y), z);
+    VL << "v = " << v;
+    VL << "w = " << w;
+    ExpectJetsClose(v, w);
+  }
+
+  {  // Check that hypot(x, y, z) == hypot(x, hypot(y, z))
+    J v = hypot(x, y, z);
+    J w = hypot(x, hypot(y, z));
+    VL << "v = " << v;
+    VL << "w = " << w;
+    ExpectJetsClose(v, w);
+  }
+
+  {  // Check that hypot(x, 0, 0) == sqrt(x * x) == x, even when x * x
+     // underflows:
+    EXPECT_EQ(
+        std::numeric_limits<double>::min() * std::numeric_limits<double>::min(),
+        0.0);  // Make sure it underflows
+    J huge = MakeJet(std::numeric_limits<double>::min(), 2.0, 3.14);
+    J h = hypot(huge, J(0.0), J(0.0));
+    VL << "h = " << h;
+    ExpectJetsClose(h, huge);
+  }
+
+  {  // Check that hypot(x, 0, 0) == sqrt(x * x) == x, even when x * x
+     // overflows:
+    EXPECT_EQ(
+        std::numeric_limits<double>::max() * std::numeric_limits<double>::max(),
+        std::numeric_limits<double>::infinity());
+    J huge = MakeJet(std::numeric_limits<double>::max(), 2.0, 3.14);
+    J h = hypot(huge, J(0.0), J(0.0));
+    VL << "h = " << h;
+    ExpectJetsClose(h, huge);
+  }
+#endif  // __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >=
+        // 201703L)
+
   {
     J z = fmax(x, y);
     VL << "z = " << z;