Update Jet.h and rotation.h to use the new IF/ELSE macros
Also use branchless implementation for isfinite(Jet),
isinf(Jet), isnan(Jet), and isnormal(Jet).
Change-Id: Ia881df03ba873e0560d67e976ab1e99e199eb523
diff --git a/examples/autodiff_codegen.cc b/examples/autodiff_codegen.cc
index 788a7ce..7813ac4 100644
--- a/examples/autodiff_codegen.cc
+++ b/examples/autodiff_codegen.cc
@@ -39,6 +39,7 @@
template <typename T>
bool operator()(const T* x, T* residual) const {
residual[0] = x[0] * x[0];
+ isfinite(x[0]);
return true;
}
};
diff --git a/include/ceres/codegen/internal/expression_ref.h b/include/ceres/codegen/internal/expression_ref.h
index 6a04edb..1bdc3b5 100644
--- a/include/ceres/codegen/internal/expression_ref.h
+++ b/include/ceres/codegen/internal/expression_ref.h
@@ -208,6 +208,10 @@
const ComparisonExpressionRef& y);
ComparisonExpressionRef operator||(const ComparisonExpressionRef& x,
const ComparisonExpressionRef& y);
+ComparisonExpressionRef operator&(const ComparisonExpressionRef& x,
+ const ComparisonExpressionRef& y);
+ComparisonExpressionRef operator|(const ComparisonExpressionRef& x,
+ const ComparisonExpressionRef& y);
ComparisonExpressionRef operator!(const ComparisonExpressionRef& x);
#define CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL(name) \
diff --git a/include/ceres/codegen/macros.h b/include/ceres/codegen/macros.h
index 0efc9e2..fbb2951 100644
--- a/include/ceres/codegen/macros.h
+++ b/include/ceres/codegen/macros.h
@@ -116,4 +116,11 @@
AddExpressionToGraph(ceres::internal::Expression::CreateEndIf());
#endif
+namespace ceres {
+// A function equivalent to the ternary ?-operator.
+// This function is required, because in the context of code generation a
+// comparison returns an expression type which is not convertible to bool.
+inline double Ternary(bool c, double a, double b) { return c ? a : b; }
+} // namespace ceres
+
#endif // CERES_PUBLIC_CODEGEN_MACROS_H_
diff --git a/include/ceres/jet.h b/include/ceres/jet.h
index fb7afce..00ecdca 100644
--- a/include/ceres/jet.h
+++ b/include/ceres/jet.h
@@ -378,14 +378,6 @@
CERES_DEFINE_JET_COMPARISON_OPERATOR(!=) // NOLINT
#undef CERES_DEFINE_JET_COMPARISON_OPERATOR
-// A function equivalent to the ternary ?-operator.
-// This function is required, because in the context of code generation a
-// comparison returns an expression type which is not convertible to bool.
-template <typename T>
-inline T Ternary(bool c, T a, T b) {
- return c ? a : b;
-}
-
template <typename T, int N>
inline Jet<T, N> Ternary(typename ComparisonReturnType<T>::type c,
const Jet<T, N>& f,
@@ -444,7 +436,7 @@
// abs(x + h) ~= x + h or -(x + h)
template <typename T, int N>
inline Jet<T, N> abs(const Jet<T, N>& f) {
- return f.a < T(0.0) ? -f : f;
+ return Ternary(f.a < T(0.0), -f, f);
}
// log(a + h) ~= log(a) + h / a
@@ -588,13 +580,13 @@
}
template <typename T, int N>
-inline const Jet<T, N>& fmax(const Jet<T, N>& x, const Jet<T, N>& y) {
- return x < y ? y : x;
+inline Jet<T, N> fmax(const Jet<T, N>& x, const Jet<T, N>& y) {
+ return Ternary(x < y, y, x);
}
template <typename T, int N>
-inline const Jet<T, N>& fmin(const Jet<T, N>& x, const Jet<T, N>& y) {
- return y < x ? y : x;
+inline Jet<T, N> fmin(const Jet<T, N>& x, const Jet<T, N>& y) {
+ return Ternary(y < x, y, x);
}
// Bessel functions of the first kind with integer order equal to 0, 1, n.
@@ -667,79 +659,65 @@
// The jet is finite if all parts of the jet are finite.
template <typename T, int N>
-inline bool isfinite(const Jet<T, N>& f) {
- if (!std::isfinite(f.a)) {
- return false;
- }
+inline typename ComparisonReturnType<T>::type isfinite(const Jet<T, N>& f) {
+ // Branchless implementation. This is more efficient for the false-case and
+ // works with the codegen system.
+ auto result = isfinite(f.a);
for (int i = 0; i < N; ++i) {
- if (!std::isfinite(f.v[i])) {
- return false;
- }
+ result = result & isfinite(f.v[i]);
}
- return true;
+ return result;
}
// The jet is infinite if any part of the Jet is infinite.
template <typename T, int N>
-inline bool isinf(const Jet<T, N>& f) {
- if (std::isinf(f.a)) {
- return true;
- }
+inline typename ComparisonReturnType<T>::type isinf(const Jet<T, N>& f) {
+ auto result = isinf(f.a);
for (int i = 0; i < N; ++i) {
- if (std::isinf(f.v[i])) {
- return true;
- }
+ result = result | isinf(f.v[i]);
}
- return false;
+ return result;
}
// The jet is NaN if any part of the jet is NaN.
template <typename T, int N>
-inline bool isnan(const Jet<T, N>& f) {
- if (std::isnan(f.a)) {
- return true;
- }
+inline typename ComparisonReturnType<T>::type isnan(const Jet<T, N>& f) {
+ auto result = isnan(f.a);
for (int i = 0; i < N; ++i) {
- if (std::isnan(f.v[i])) {
- return true;
- }
+ result = result | isnan(f.v[i]);
}
- return false;
+ return result;
}
// The jet is normal if all parts of the jet are normal.
template <typename T, int N>
-inline bool isnormal(const Jet<T, N>& f) {
- if (!std::isnormal(f.a)) {
- return false;
- }
+inline typename ComparisonReturnType<T>::type isnormal(const Jet<T, N>& f) {
+ auto result = isnormal(f.a);
for (int i = 0; i < N; ++i) {
- if (!std::isnormal(f.v[i])) {
- return false;
- }
+ result = result & isnormal(f.v[i]);
}
- return true;
+ return result;
}
// Legacy functions from the pre-C++11 days.
template <typename T, int N>
-inline bool IsFinite(const Jet<T, N>& f) {
+inline typename ComparisonReturnType<T>::type IsFinite(const Jet<T, N>& f) {
return isfinite(f);
}
template <typename T, int N>
-inline bool IsNaN(const Jet<T, N>& f) {
+inline typename ComparisonReturnType<T>::type IsNaN(const Jet<T, N>& f) {
return isnan(f);
}
template <typename T, int N>
-inline bool IsNormal(const Jet<T, N>& f) {
+inline typename ComparisonReturnType<T>::type IsNormal(const Jet<T, N>& f) {
return isnormal(f);
}
// The jet is infinite if any part of the jet is infinite.
template <typename T, int N>
-inline bool IsInfinite(const Jet<T, N>& f) {
+inline typename ComparisonReturnType<T>::type IsInfinite(const Jet<T, N>& f) {
return isinf(f);
}
@@ -778,25 +756,33 @@
// != 0, the derivatives are not defined and we return NaN.
template <typename T, int N>
-inline Jet<T, N> pow(double f, const Jet<T, N>& g) {
- if (f == 0 && g.a > 0) {
+inline Jet<T, N> pow(T f, const Jet<T, N>& g) {
+ Jet<T, N> result;
+
+ CERES_IF(f == T(0) && g.a > T(0)) {
// Handle case 2.
- return Jet<T, N>(T(0.0));
+ result = Jet<T, N>(T(0.0));
}
- if (f < 0 && g.a == floor(g.a)) {
- // Handle case 3.
- Jet<T, N> ret(pow(f, g.a));
- for (int i = 0; i < N; i++) {
- if (g.v[i] != T(0.0)) {
- // Return a NaN when g.v != 0.
- ret.v[i] = std::numeric_limits<T>::quiet_NaN();
+ CERES_ELSE {
+ CERES_IF(f < 0 && g.a == floor(g.a)) { // Handle case 3.
+ result = Jet<T, N>(pow(f, g.a));
+ for (int i = 0; i < N; i++) {
+ CERES_IF(g.v[i] != T(0.0)) {
+ // Return a NaN when g.v != 0.
+ result.v[i] = std::numeric_limits<T>::quiet_NaN();
+ }
+ CERES_ENDIF
}
}
- return ret;
+ CERES_ELSE {
+ // Handle case 1.
+ T const tmp = pow(f, g.a);
+ result = Jet<T, N>(tmp, log(f) * tmp * g.v);
+ }
+ CERES_ENDIF;
}
- // Handle case 1.
- T const tmp = pow(f, g.a);
- return Jet<T, N>(tmp, log(f) * tmp * g.v);
+ CERES_ENDIF
+ return result;
}
// pow -- both base and exponent are differentiable functions. This has a
@@ -837,32 +823,40 @@
template <typename T, int N>
inline Jet<T, N> pow(const Jet<T, N>& f, const Jet<T, N>& g) {
- if (f.a == 0 && g.a >= 1) {
+ Jet<T, N> result;
+
+ CERES_IF(f.a == T(0) && g.a >= T(1)) {
// Handle cases 2 and 3.
- if (g.a > 1) {
- return Jet<T, N>(T(0.0));
- }
- return f;
+ CERES_IF(g.a > T(1)) { result = Jet<T, N>(T(0.0)); }
+ CERES_ELSE { result = f; }
+ CERES_ENDIF;
}
- if (f.a < 0 && g.a == floor(g.a)) {
- // Handle cases 7 and 8.
- T const tmp = g.a * pow(f.a, g.a - T(1.0));
- Jet<T, N> ret(pow(f.a, g.a), tmp * f.v);
- for (int i = 0; i < N; i++) {
- if (g.v[i] != T(0.0)) {
- // Return a NaN when g.v != 0.
- ret.v[i] = std::numeric_limits<T>::quiet_NaN();
+ CERES_ELSE {
+ CERES_IF(f.a < T(0) && g.a == floor(g.a)) {
+ // Handle cases 7 and 8.
+ T const tmp = g.a * pow(f.a, g.a - T(1.0));
+ result = Jet<T, N>(pow(f.a, g.a), tmp * f.v);
+ for (int i = 0; i < N; i++) {
+ CERES_IF(g.v[i] != T(0.0)) {
+ // Return a NaN when g.v != 0.
+ result.v[i] = T(std::numeric_limits<double>::quiet_NaN());
+ }
+ CERES_ENDIF;
}
}
- return ret;
+ CERES_ELSE {
+ // Handle the remaining cases. For cases 4,5,6,9 we allow the log()
+ // function to generate -HUGE_VAL or NaN, since those cases result in a
+ // nonfinite derivative.
+ T const tmp1 = pow(f.a, g.a);
+ T const tmp2 = g.a * pow(f.a, g.a - T(1.0));
+ T const tmp3 = tmp1 * log(f.a);
+ result = Jet<T, N>(tmp1, tmp2 * f.v + tmp3 * g.v);
+ }
+ CERES_ENDIF;
}
- // Handle the remaining cases. For cases 4,5,6,9 we allow the log() function
- // to generate -HUGE_VAL or NaN, since those cases result in a nonfinite
- // derivative.
- T const tmp1 = pow(f.a, g.a);
- T const tmp2 = g.a * pow(f.a, g.a - T(1.0));
- T const tmp3 = tmp1 * log(f.a);
- return Jet<T, N>(tmp1, tmp2 * f.v + tmp3 * g.v);
+ CERES_ENDIF;
+ return result;
}
// Note: This has to be in the ceres namespace for argument dependent lookup to
diff --git a/include/ceres/rotation.h b/include/ceres/rotation.h
index 7d5c8ef..ce06a2c 100644
--- a/include/ceres/rotation.h
+++ b/include/ceres/rotation.h
@@ -48,7 +48,7 @@
#include <algorithm>
#include <cmath>
#include <limits>
-
+#include "codegen/macros.h"
#include "glog/logging.h"
namespace ceres {
@@ -253,7 +253,7 @@
const T theta_squared = a0 * a0 + a1 * a1 + a2 * a2;
// For points not at the origin, the full conversion is numerically stable.
- if (theta_squared > T(0.0)) {
+ CERES_IF(theta_squared > T(0.0)) {
const T theta = sqrt(theta_squared);
const T half_theta = theta * T(0.5);
const T k = sin(half_theta) / theta;
@@ -261,7 +261,8 @@
quaternion[1] = a0 * k;
quaternion[2] = a1 * k;
quaternion[3] = a2 * k;
- } else {
+ }
+ CERES_ELSE {
// At the origin, sqrt() will produce NaN in the derivative since
// the argument is zero. By approximating with a Taylor series,
// and truncating at one term, the value and first derivatives will be
@@ -272,6 +273,7 @@
quaternion[2] = a1 * k;
quaternion[3] = a2 * k;
}
+ CERES_ENDIF;
}
template <typename T>
@@ -283,7 +285,7 @@
// For quaternions representing non-zero rotation, the conversion
// is numerically stable.
- if (sin_squared_theta > T(0.0)) {
+ CERES_IF(sin_squared_theta > T(0.0)) {
const T sin_theta = sqrt(sin_squared_theta);
const T& cos_theta = quaternion[0];
@@ -300,14 +302,15 @@
// theta - pi = atan(sin(theta - pi), cos(theta - pi))
// = atan(-sin(theta), -cos(theta))
//
- const T two_theta =
- T(2.0) * ((cos_theta < T(0.0)) ? atan2(-sin_theta, -cos_theta)
- : atan2(sin_theta, cos_theta));
+ const T two_theta = T(2.0) * Ternary((cos_theta < T(0.0)),
+ atan2(-sin_theta, -cos_theta),
+ atan2(sin_theta, cos_theta));
const T k = two_theta / sin_theta;
angle_axis[0] = q1 * k;
angle_axis[1] = q2 * k;
angle_axis[2] = q3 * k;
- } else {
+ }
+ CERES_ELSE {
// For zero rotation, sqrt() will produce NaN in the derivative since
// the argument is zero. By approximating with a Taylor series,
// and truncating at one term, the value and first derivatives will be
@@ -317,6 +320,7 @@
angle_axis[1] = q2 * k;
angle_axis[2] = q3 * k;
}
+ CERES_ENDIF;
}
template <typename T>
@@ -387,7 +391,7 @@
const T* angle_axis, const MatrixAdapter<T, row_stride, col_stride>& R) {
static const T kOne = T(1.0);
const T theta2 = DotProduct(angle_axis, angle_axis);
- if (theta2 > T(std::numeric_limits<double>::epsilon())) {
+ CERES_IF(theta2 > T(std::numeric_limits<double>::epsilon())) {
// We want to be careful to only evaluate the square root if the
// norm of the angle_axis vector is greater than zero. Otherwise
// we get a division by zero.
@@ -410,7 +414,8 @@
R(1, 2) = -wx*sintheta + wy*wz*(kOne - costheta);
R(2, 2) = costheta + wz*wz*(kOne - costheta);
// clang-format on
- } else {
+ }
+ CERES_ELSE {
// Near zero, we switch to using the first order Taylor expansion.
R(0, 0) = kOne;
R(1, 0) = angle_axis[2];
@@ -422,6 +427,7 @@
R(1, 2) = -angle_axis[0];
R(2, 2) = kOne;
}
+ CERES_ENDIF;
}
template <typename T>
@@ -591,7 +597,7 @@
DCHECK_NE(pt, result) << "Inplace rotation is not supported.";
const T theta2 = DotProduct(angle_axis, angle_axis);
- if (theta2 > T(std::numeric_limits<double>::epsilon())) {
+ CERES_IF(theta2 > T(std::numeric_limits<double>::epsilon())) {
// Away from zero, use the rodriguez formula
//
// result = pt costheta +
@@ -622,7 +628,8 @@
result[0] = pt[0] * costheta + w_cross_pt[0] * sintheta + w[0] * tmp;
result[1] = pt[1] * costheta + w_cross_pt[1] * sintheta + w[1] * tmp;
result[2] = pt[2] * costheta + w_cross_pt[2] * sintheta + w[2] * tmp;
- } else {
+ }
+ CERES_ELSE {
// Near zero, the first order Taylor approximation of the rotation
// matrix R corresponding to a vector w and angle w is
//
@@ -648,6 +655,7 @@
result[1] = pt[1] + w_cross_pt[1];
result[2] = pt[2] + w_cross_pt[2];
}
+ CERES_ENDIF;
}
} // namespace ceres
diff --git a/internal/ceres/expression_ref.cc b/internal/ceres/expression_ref.cc
index 7c43595..52b7e0b 100644
--- a/internal/ceres/expression_ref.cc
+++ b/internal/ceres/expression_ref.cc
@@ -190,6 +190,8 @@
CERES_DEFINE_EXPRESSION_COMPARISON_OPERATOR(!=)
CERES_DEFINE_EXPRESSION_LOGICAL_OPERATOR(&&)
CERES_DEFINE_EXPRESSION_LOGICAL_OPERATOR(||)
+CERES_DEFINE_EXPRESSION_LOGICAL_OPERATOR(&)
+CERES_DEFINE_EXPRESSION_LOGICAL_OPERATOR(|)
#undef CERES_DEFINE_EXPRESSION_COMPARISON_OPERATOR
#undef CERES_DEFINE_EXPRESSION_LOGICAL_OPERATOR