Add sinh, cosh, tanh and tan functions to automatic differentiation Change-Id: I6eb43fe9b340d4074ed3eed1461dda315f6e8ce8
diff --git a/include/ceres/jet.h b/include/ceres/jet.h index 000bd1c..7299a48 100644 --- a/include/ceres/jet.h +++ b/include/ceres/jet.h
@@ -405,7 +405,6 @@ // double-valued and Jet-valued functions, but we are not allowed to put // Jet-valued functions inside namespace std. // -// Missing: cosh, sinh, tanh, tan // TODO(keir): Switch to "using". inline double abs (double x) { return std::abs(x); } inline double log (double x) { return std::log(x); } @@ -415,6 +414,11 @@ inline double acos (double x) { return std::acos(x); } inline double sin (double x) { return std::sin(x); } inline double asin (double x) { return std::asin(x); } +inline double tan (double x) { return std::tan(x); } +inline double atan (double x) { return std::atan(x); } +inline double sinh (double x) { return std::sinh(x); } +inline double cosh (double x) { return std::cosh(x); } +inline double tanh (double x) { return std::tanh(x); } inline double pow (double x, double y) { return std::pow(x, y); } inline double atan2(double y, double x) { return std::atan2(y, x); } @@ -495,6 +499,58 @@ return g; } +// tan(a + h) ~= tan(a) + (1 + tan(a)^2) h +template <typename T, int N> inline +Jet<T, N> tan(const Jet<T, N>& f) { + Jet<T, N> g; + g.a = tan(f.a); + double tan_a = tan(f.a); + const T tmp = T(1.0) + tan_a * tan_a; + g.v = tmp * f.v; + return g; +} + +// atan(a + h) ~= atan(a) + 1 / (1 + a^2) h +template <typename T, int N> inline +Jet<T, N> atan(const Jet<T, N>& f) { + Jet<T, N> g; + g.a = atan(f.a); + const T tmp = T(1.0) / (T(1.0) + f.a * f.a); + g.v = tmp * f.v; + return g; +} + +// sinh(a + h) ~= sinh(a) + cosh(a) h +template <typename T, int N> inline +Jet<T, N> sinh(const Jet<T, N>& f) { + Jet<T, N> g; + g.a = sinh(f.a); + const T cosh_a = cosh(f.a); + g.v = cosh_a * f.v; + return g; +} + +// cosh(a + h) ~= cosh(a) + sinh(a) h +template <typename T, int N> inline +Jet<T, N> cosh(const Jet<T, N>& f) { + Jet<T, N> g; + g.a = cosh(f.a); + const T sinh_a = sinh(f.a); + g.v = sinh_a * f.v; + return g; +} + +// tanh(a + h) ~= tanh(a) + (1 - tanh(a)^2) h +template <typename T, int N> inline +Jet<T, N> tanh(const Jet<T, N>& f) { + Jet<T, N> g; + g.a = tanh(f.a); + double tanh_fa = tanh(f.a); + const T tmp = 1 - tanh_fa * tanh_fa; + g.v = tmp * f.v; + return g; +} + // Jet Classification. It is not clear what the appropriate semantics are for // these classifications. This picks that IsFinite and isnormal are "all" // operations, i.e. all elements of the jet must be finite for the jet itself @@ -645,6 +701,11 @@ template<typename T, int N> inline Jet<T, N> ei_log (const Jet<T, N>& x) { return log(x); } // NOLINT template<typename T, int N> inline Jet<T, N> ei_sin (const Jet<T, N>& x) { return sin(x); } // NOLINT template<typename T, int N> inline Jet<T, N> ei_cos (const Jet<T, N>& x) { return cos(x); } // NOLINT +template<typename T, int N> inline Jet<T, N> ei_tan (const Jet<T, N>& x) { return tan(x); } // NOLINT +template<typename T, int N> inline Jet<T, N> ei_atan(const Jet<T, N>& x) { return atan(x); } // NOLINT +template<typename T, int N> inline Jet<T, N> ei_sinh(const Jet<T, N>& x) { return sinh(x); } // NOLINT +template<typename T, int N> inline Jet<T, N> ei_cosh(const Jet<T, N>& x) { return cosh(x); } // NOLINT +template<typename T, int N> inline Jet<T, N> ei_tanh(const Jet<T, N>& x) { return tanh(x); } // NOLINT template<typename T, int N> inline Jet<T, N> ei_pow (const Jet<T, N>& x, Jet<T, N> y) { return pow(x, y); } // NOLINT // Note: This has to be in the ceres namespace for argument dependent lookup to
diff --git a/internal/ceres/jet_test.cc b/internal/ceres/jet_test.cc index 0dd4336..6b8cf17 100644 --- a/internal/ceres/jet_test.cc +++ b/internal/ceres/jet_test.cc
@@ -142,6 +142,38 @@ ExpectJetsClose(u, t); } + { // Check that tan(x) = sin(x) / cos(x). + J z = tan(x); + J w = sin(x) / cos(x); + VL << "z = " << z; + VL << "w = " << w; + ExpectJetsClose(z, w); + } + + { // Check that tan(atan(x)) = x. + J z = tan(atan(x)); + J w = x; + VL << "z = " << z; + VL << "w = " << w; + ExpectJetsClose(z, w); + } + + { // Check that cosh(x)*cosh(x) - sinh(x)*sinh(x) = 1 + J z = cosh(x) * cosh(x); + J w = sinh(x) * sinh(x); + VL << "z = " << z; + VL << "w = " << w; + ExpectJetsClose(z - w, J(1.0)); + } + + { // Check that tanh(x + y) = (tanh(x) + tanh(y)) / (1 + tanh(x) tanh(y)) + J z = tanh(x + y); + J w = (tanh(x) + tanh(y)) / (J(1.0) + tanh(x) * tanh(y)); + VL << "z = " << z; + VL << "w = " << w; + ExpectJetsClose(z, w); + } + { // Check that pow(x, 1) == x. VL << "x = " << x;