support log1p and expm1 jet
Currently, it is not possible to accurately evaluate the derivative of
d/dx log(1 + x) under all circumstances and significant deviations from
the actual derivative d/dx log1p(x) can occur. This changeset introduces
the necessary Jet overload and its inverse, expm1.
Change-Id: Ifcf88f6d684f61ba86bbe49f0d551b703f34ad0d
diff --git a/include/ceres/jet.h b/include/ceres/jet.h
index eb3f8a7..4f3b144 100644
--- a/include/ceres/jet.h
+++ b/include/ceres/jet.h
@@ -394,6 +394,7 @@
using std::erfc;
using std::exp;
using std::exp2;
+using std::expm1;
using std::floor;
using std::fmax;
using std::fmin;
@@ -403,6 +404,7 @@
using std::isnan;
using std::isnormal;
using std::log;
+using std::log1p;
using std::log2;
using std::norm;
using std::pow;
@@ -471,6 +473,13 @@
return Jet<T, N>(log(f.a), f.v * a_inverse);
}
+// log1p(a + h) ~= log1p(a) + h / (1 + a)
+template <typename T, int N>
+inline Jet<T, N> log1p(const Jet<T, N>& f) {
+ const T a_inverse = T(1.0) / (T(1.0) + f.a);
+ return Jet<T, N>(log1p(f.a), f.v * a_inverse);
+}
+
// exp(a + h) ~= exp(a) + exp(a) h
template <typename T, int N>
inline Jet<T, N> exp(const Jet<T, N>& f) {
@@ -478,6 +487,12 @@
return Jet<T, N>(tmp, tmp * f.v);
}
+// expm1(a + h) ~= expm1(a) + exp(a) h
+template <typename T, int N>
+inline Jet<T, N> expm1(const Jet<T, N>& f) {
+ return Jet<T, N>(expm1(f.a), exp(f.a) * f.v);
+}
+
// sqrt(a + h) ~= sqrt(a) + h / (2 sqrt(a))
template <typename T, int N>
inline Jet<T, N> sqrt(const Jet<T, N>& f) {
diff --git a/internal/ceres/jet_test.cc b/internal/ceres/jet_test.cc
index 6c3fe44..6a7011d 100644
--- a/internal/ceres/jet_test.cc
+++ b/internal/ceres/jet_test.cc
@@ -128,6 +128,22 @@
ExpectJetsClose(w, x);
}
+ { // Check that expm1(log1p(x)) == x.
+ J z = expm1(x);
+ J w = log1p(z);
+ VL << "z = " << z;
+ VL << "w = " << w;
+ ExpectJetsClose(w, x);
+ }
+
+ { // Check that log1p(expm1(x)) == x.
+ J z = log1p(x);
+ J w = expm1(z);
+ VL << "z = " << z;
+ VL << "w = " << w;
+ ExpectJetsClose(w, x);
+ }
+
{ // Check that (x * y) / x == y.
J z = x * y;
J w = z / x;
@@ -631,6 +647,46 @@
NumericalTest("cbrt", cbrt<double, 2>, 1e-5);
NumericalTest("cbrt", cbrt<double, 2>, 1.0);
+ { // Check that log1p(x) == log(1 + x)
+ J z = log1p(x);
+ J w = log(J{1} + x);
+ VL << "z = " << z;
+ VL << "w = " << w;
+ ExpectJetsClose(z, w);
+ }
+
+ { // Check that log1p(x) does not loose precision for small x
+ J x = MakeJet(1e-16, 1e-8, 1e-4);
+ J z = log1p(x);
+ J w = MakeJet(9.9999999999999998e-17, 1e-8, 1e-4);
+ VL << "z = " << z;
+ VL << "w = " << w;
+ ExpectJetsClose(z, w);
+ // log(1 + x) collapes to 0
+ J v = log(J{1} + x);
+ EXPECT_TRUE(v.a == 0);
+ }
+
+ { // Check that expm1(x) == exp(x) - 1
+ J z = expm1(x);
+ J w = exp(x) - J{1};
+ VL << "z = " << z;
+ VL << "w = " << w;
+ ExpectJetsClose(z, w);
+ }
+
+ { // Check that expm1(x) does not loose precision for small x
+ J x = MakeJet(9.9999999999999998e-17, 1e-8, 1e-4);
+ J z = expm1(x);
+ J w = MakeJet(1e-16, 1e-8, 1e-4);
+ VL << "z = " << z;
+ VL << "w = " << w;
+ ExpectJetsClose(z, w);
+ // exp(x) - 1 collapes to 0
+ J v = exp(x) - J{1};
+ EXPECT_TRUE(v.a == 0);
+ }
+
{ // Check that exp2(x) == exp(x * log(2))
J z = exp2(x);
J w = exp(x * log(2.0));