support norm jet

norm allows to compute the squared magnitude both of complex and real
numbers. While Jet does not support decaying of a std::complex to a
scalar performed by norm, the function is still useful when applied to
scalars alone for computing the square.

Change-Id: I27a3513f53f37fb3960411362e80d170d1ae6f74
diff --git a/include/ceres/jet.h b/include/ceres/jet.h
index c97e621..eb3f8a7 100644
--- a/include/ceres/jet.h
+++ b/include/ceres/jet.h
@@ -158,6 +158,7 @@
 #define CERES_PUBLIC_JET_H_
 
 #include <cmath>
+#include <complex>
 #include <iosfwd>
 #include <iostream>  // NOLINT
 #include <limits>
@@ -403,6 +404,7 @@
 using std::isnormal;
 using std::log;
 using std::log2;
+using std::norm;
 using std::pow;
 using std::sin;
 using std::sinh;
@@ -799,6 +801,22 @@
   return Jet<T, N>(atan2(g.a, f.a), tmp * (-g.a * f.v + f.a * g.v));
 }
 
+// Computes the square x^2 of a real number x (not the Euclidean L^2 norm as
+// the name might suggest).
+//
+// NOTE While std::norm is primarly intended for computing the squared magnitude
+// of a std::complex<> number, the current Jet implementation does not support
+// mixing a scalar T in its real part and std::complex<T> and in the
+// infinitesimal. Mixed Jet support is necessary for the type decay from
+// std::complex<T> to T (the squared magnitude of a complex number is always
+// real) performed by std::norm.
+//
+// norm(x + h) ~= norm(x) + 2x h
+template <typename T, int N>
+inline Jet<T, N> norm(const Jet<T, N>& f) {
+  return Jet<T, N>(norm(f.a), T(2) * f.a * f.v);
+}
+
 // pow -- base is a differentiable function, exponent is a constant.
 // (a+da)^p ~= a^p + p*a^(p-1) da
 template <typename T, int N>
diff --git a/internal/ceres/jet_test.cc b/internal/ceres/jet_test.cc
index 5631312..6c3fe44 100644
--- a/internal/ceres/jet_test.cc
+++ b/internal/ceres/jet_test.cc
@@ -657,6 +657,22 @@
   NumericalTest("log2", log2<double, 2>, 1.0);
   NumericalTest("log2", log2<double, 2>, 100.0);
 
+  {  // Check that norm(x) == x^2
+    J v = norm(x);
+    J w = x * x;
+    VL << "v = " << v;
+    VL << "w = " << w;
+    ExpectJetsClose(v, w);
+  }
+
+  {  // Check that norm(-x) == x^2
+    J v = norm(-x);
+    J w = x * x;
+    VL << "v = " << v;
+    VL << "w = " << w;
+    ExpectJetsClose(v, w);
+  }
+
   {  // Check that hypot(x, y) == sqrt(x^2 + y^2)
     J h = hypot(x, y);
     J s = sqrt(x * x + y * y);