Speed up Jets.
Change-Id: I101bac1b1a1cf72ca49ffcf843b73c0ef5a6dfcb
diff --git a/include/ceres/jet.h b/include/ceres/jet.h
index 96e2256..1238123 100644
--- a/include/ceres/jet.h
+++ b/include/ceres/jet.h
@@ -349,7 +349,11 @@
//
// which holds because v*v = 0.
h.a = f.a / g.a;
- h.v = (f.v - f.a / g.a * g.v) / g.a;
+ const T g_a_inverse = 1.0 / g.a;
+ const T f_a_by_g_a = f.a * g_a_inverse;
+ for (int i = 0; i < N; ++i) {
+ h.v[i] = (f.v[i] - f_a_by_g_a * g.v[i]) * g_a_inverse;
+ }
return h;
}
@@ -358,7 +362,8 @@
Jet<T, N> operator/(T s, const Jet<T, N>& g) {
Jet<T, N> h;
h.a = s / g.a;
- h.v = - s * g.v / (g.a * g.a);
+ const T minus_s_g_a_inverse2 = -s / (g.a * g.a);
+ h.v = g.v * minus_s_g_a_inverse2;
return h;
}
@@ -366,8 +371,9 @@
template<typename T, int N> inline
Jet<T, N> operator/(const Jet<T, N>& f, T s) {
Jet<T, N> h;
- h.a = f.a / s;
- h.v = f.v / s;
+ const T s_inverse = 1.0 / s;
+ h.a = f.a * s_inverse;
+ h.v = f.v * s_inverse;
return h;
}
@@ -425,7 +431,8 @@
Jet<T, N> log(const Jet<T, N>& f) {
Jet<T, N> g;
g.a = log(f.a);
- g.v = f.v / f.a;
+ const T a_inverse = T(1.0) / f.a;
+ g.v = f.v * a_inverse;
return g;
}
@@ -443,7 +450,8 @@
Jet<T, N> sqrt(const Jet<T, N>& f) {
Jet<T, N> g;
g.a = sqrt(f.a);
- g.v = f.v / (T(2.0) * g.a);
+ const T two_a_inverse = 1.0 / (T(2.0) * g.a);
+ g.v = f.v * two_a_inverse;
return g;
}
@@ -452,7 +460,7 @@
Jet<T, N> cos(const Jet<T, N>& f) {
Jet<T, N> g;
g.a = cos(f.a);
- T sin_a = sin(f.a);
+ const T sin_a = sin(f.a);
g.v = - sin_a * f.v;
return g;
}
@@ -462,7 +470,8 @@
Jet<T, N> acos(const Jet<T, N>& f) {
Jet<T, N> g;
g.a = acos(f.a);
- g.v = - T(1.0) / sqrt(T(1.0) - f.a * f.a) * f.v;
+ const T tmp = - T(1.0) / sqrt(T(1.0) - f.a * f.a);
+ g.v = tmp * f.v;
return g;
}
@@ -471,7 +480,7 @@
Jet<T, N> sin(const Jet<T, N>& f) {
Jet<T, N> g;
g.a = sin(f.a);
- T cos_a = cos(f.a);
+ const T cos_a = cos(f.a);
g.v = cos_a * f.v;
return g;
}
@@ -481,7 +490,8 @@
Jet<T, N> asin(const Jet<T, N>& f) {
Jet<T, N> g;
g.a = asin(f.a);
- g.v = T(1.0) / sqrt(T(1.0) - f.a * f.a) * f.v;
+ const T tmp = T(1.0) / sqrt(T(1.0) - f.a * f.a);
+ g.v = tmp * f.v;
return g;
}