Update Jet.h and rotation.h to use the new IF/ELSE macros Also use branchless implementation for isfinite(Jet), isinf(Jet), isnan(Jet), and isnormal(Jet). Change-Id: Ia881df03ba873e0560d67e976ab1e99e199eb523
diff --git a/examples/autodiff_codegen.cc b/examples/autodiff_codegen.cc index 788a7ce..7813ac4 100644 --- a/examples/autodiff_codegen.cc +++ b/examples/autodiff_codegen.cc
@@ -39,6 +39,7 @@ template <typename T> bool operator()(const T* x, T* residual) const { residual[0] = x[0] * x[0]; + isfinite(x[0]); return true; } };
diff --git a/include/ceres/codegen/internal/expression_ref.h b/include/ceres/codegen/internal/expression_ref.h index 6a04edb..1bdc3b5 100644 --- a/include/ceres/codegen/internal/expression_ref.h +++ b/include/ceres/codegen/internal/expression_ref.h
@@ -208,6 +208,10 @@ const ComparisonExpressionRef& y); ComparisonExpressionRef operator||(const ComparisonExpressionRef& x, const ComparisonExpressionRef& y); +ComparisonExpressionRef operator&(const ComparisonExpressionRef& x, + const ComparisonExpressionRef& y); +ComparisonExpressionRef operator|(const ComparisonExpressionRef& x, + const ComparisonExpressionRef& y); ComparisonExpressionRef operator!(const ComparisonExpressionRef& x); #define CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL(name) \
diff --git a/include/ceres/codegen/macros.h b/include/ceres/codegen/macros.h index 0efc9e2..fbb2951 100644 --- a/include/ceres/codegen/macros.h +++ b/include/ceres/codegen/macros.h
@@ -116,4 +116,11 @@ AddExpressionToGraph(ceres::internal::Expression::CreateEndIf()); #endif +namespace ceres { +// A function equivalent to the ternary ?-operator. +// This function is required, because in the context of code generation a +// comparison returns an expression type which is not convertible to bool. +inline double Ternary(bool c, double a, double b) { return c ? a : b; } +} // namespace ceres + #endif // CERES_PUBLIC_CODEGEN_MACROS_H_
diff --git a/include/ceres/jet.h b/include/ceres/jet.h index fb7afce..00ecdca 100644 --- a/include/ceres/jet.h +++ b/include/ceres/jet.h
@@ -378,14 +378,6 @@ CERES_DEFINE_JET_COMPARISON_OPERATOR(!=) // NOLINT #undef CERES_DEFINE_JET_COMPARISON_OPERATOR -// A function equivalent to the ternary ?-operator. -// This function is required, because in the context of code generation a -// comparison returns an expression type which is not convertible to bool. -template <typename T> -inline T Ternary(bool c, T a, T b) { - return c ? a : b; -} - template <typename T, int N> inline Jet<T, N> Ternary(typename ComparisonReturnType<T>::type c, const Jet<T, N>& f, @@ -444,7 +436,7 @@ // abs(x + h) ~= x + h or -(x + h) template <typename T, int N> inline Jet<T, N> abs(const Jet<T, N>& f) { - return f.a < T(0.0) ? -f : f; + return Ternary(f.a < T(0.0), -f, f); } // log(a + h) ~= log(a) + h / a @@ -588,13 +580,13 @@ } template <typename T, int N> -inline const Jet<T, N>& fmax(const Jet<T, N>& x, const Jet<T, N>& y) { - return x < y ? y : x; +inline Jet<T, N> fmax(const Jet<T, N>& x, const Jet<T, N>& y) { + return Ternary(x < y, y, x); } template <typename T, int N> -inline const Jet<T, N>& fmin(const Jet<T, N>& x, const Jet<T, N>& y) { - return y < x ? y : x; +inline Jet<T, N> fmin(const Jet<T, N>& x, const Jet<T, N>& y) { + return Ternary(y < x, y, x); } // Bessel functions of the first kind with integer order equal to 0, 1, n. @@ -667,79 +659,65 @@ // The jet is finite if all parts of the jet are finite. template <typename T, int N> -inline bool isfinite(const Jet<T, N>& f) { - if (!std::isfinite(f.a)) { - return false; - } +inline typename ComparisonReturnType<T>::type isfinite(const Jet<T, N>& f) { + // Branchless implementation. This is more efficient for the false-case and + // works with the codegen system. + auto result = isfinite(f.a); for (int i = 0; i < N; ++i) { - if (!std::isfinite(f.v[i])) { - return false; - } + result = result & isfinite(f.v[i]); } - return true; + return result; } // The jet is infinite if any part of the Jet is infinite. template <typename T, int N> -inline bool isinf(const Jet<T, N>& f) { - if (std::isinf(f.a)) { - return true; - } +inline typename ComparisonReturnType<T>::type isinf(const Jet<T, N>& f) { + auto result = isinf(f.a); for (int i = 0; i < N; ++i) { - if (std::isinf(f.v[i])) { - return true; - } + result = result | isinf(f.v[i]); } - return false; + return result; } // The jet is NaN if any part of the jet is NaN. template <typename T, int N> -inline bool isnan(const Jet<T, N>& f) { - if (std::isnan(f.a)) { - return true; - } +inline typename ComparisonReturnType<T>::type isnan(const Jet<T, N>& f) { + auto result = isnan(f.a); for (int i = 0; i < N; ++i) { - if (std::isnan(f.v[i])) { - return true; - } + result = result | isnan(f.v[i]); } - return false; + return result; } // The jet is normal if all parts of the jet are normal. template <typename T, int N> -inline bool isnormal(const Jet<T, N>& f) { - if (!std::isnormal(f.a)) { - return false; - } +inline typename ComparisonReturnType<T>::type isnormal(const Jet<T, N>& f) { + auto result = isnormal(f.a); for (int i = 0; i < N; ++i) { - if (!std::isnormal(f.v[i])) { - return false; - } + result = result & isnormal(f.v[i]); } - return true; + return result; } // Legacy functions from the pre-C++11 days. template <typename T, int N> -inline bool IsFinite(const Jet<T, N>& f) { +inline typename ComparisonReturnType<T>::type IsFinite(const Jet<T, N>& f) { return isfinite(f); } template <typename T, int N> -inline bool IsNaN(const Jet<T, N>& f) { +inline typename ComparisonReturnType<T>::type IsNaN(const Jet<T, N>& f) { return isnan(f); } template <typename T, int N> -inline bool IsNormal(const Jet<T, N>& f) { +inline typename ComparisonReturnType<T>::type IsNormal(const Jet<T, N>& f) { return isnormal(f); } // The jet is infinite if any part of the jet is infinite. template <typename T, int N> -inline bool IsInfinite(const Jet<T, N>& f) { +inline typename ComparisonReturnType<T>::type IsInfinite(const Jet<T, N>& f) { return isinf(f); } @@ -778,25 +756,33 @@ // != 0, the derivatives are not defined and we return NaN. template <typename T, int N> -inline Jet<T, N> pow(double f, const Jet<T, N>& g) { - if (f == 0 && g.a > 0) { +inline Jet<T, N> pow(T f, const Jet<T, N>& g) { + Jet<T, N> result; + + CERES_IF(f == T(0) && g.a > T(0)) { // Handle case 2. - return Jet<T, N>(T(0.0)); + result = Jet<T, N>(T(0.0)); } - if (f < 0 && g.a == floor(g.a)) { - // Handle case 3. - Jet<T, N> ret(pow(f, g.a)); - for (int i = 0; i < N; i++) { - if (g.v[i] != T(0.0)) { - // Return a NaN when g.v != 0. - ret.v[i] = std::numeric_limits<T>::quiet_NaN(); + CERES_ELSE { + CERES_IF(f < 0 && g.a == floor(g.a)) { // Handle case 3. + result = Jet<T, N>(pow(f, g.a)); + for (int i = 0; i < N; i++) { + CERES_IF(g.v[i] != T(0.0)) { + // Return a NaN when g.v != 0. + result.v[i] = std::numeric_limits<T>::quiet_NaN(); + } + CERES_ENDIF } } - return ret; + CERES_ELSE { + // Handle case 1. + T const tmp = pow(f, g.a); + result = Jet<T, N>(tmp, log(f) * tmp * g.v); + } + CERES_ENDIF; } - // Handle case 1. - T const tmp = pow(f, g.a); - return Jet<T, N>(tmp, log(f) * tmp * g.v); + CERES_ENDIF + return result; } // pow -- both base and exponent are differentiable functions. This has a @@ -837,32 +823,40 @@ template <typename T, int N> inline Jet<T, N> pow(const Jet<T, N>& f, const Jet<T, N>& g) { - if (f.a == 0 && g.a >= 1) { + Jet<T, N> result; + + CERES_IF(f.a == T(0) && g.a >= T(1)) { // Handle cases 2 and 3. - if (g.a > 1) { - return Jet<T, N>(T(0.0)); - } - return f; + CERES_IF(g.a > T(1)) { result = Jet<T, N>(T(0.0)); } + CERES_ELSE { result = f; } + CERES_ENDIF; } - if (f.a < 0 && g.a == floor(g.a)) { - // Handle cases 7 and 8. - T const tmp = g.a * pow(f.a, g.a - T(1.0)); - Jet<T, N> ret(pow(f.a, g.a), tmp * f.v); - for (int i = 0; i < N; i++) { - if (g.v[i] != T(0.0)) { - // Return a NaN when g.v != 0. - ret.v[i] = std::numeric_limits<T>::quiet_NaN(); + CERES_ELSE { + CERES_IF(f.a < T(0) && g.a == floor(g.a)) { + // Handle cases 7 and 8. + T const tmp = g.a * pow(f.a, g.a - T(1.0)); + result = Jet<T, N>(pow(f.a, g.a), tmp * f.v); + for (int i = 0; i < N; i++) { + CERES_IF(g.v[i] != T(0.0)) { + // Return a NaN when g.v != 0. + result.v[i] = T(std::numeric_limits<double>::quiet_NaN()); + } + CERES_ENDIF; } } - return ret; + CERES_ELSE { + // Handle the remaining cases. For cases 4,5,6,9 we allow the log() + // function to generate -HUGE_VAL or NaN, since those cases result in a + // nonfinite derivative. + T const tmp1 = pow(f.a, g.a); + T const tmp2 = g.a * pow(f.a, g.a - T(1.0)); + T const tmp3 = tmp1 * log(f.a); + result = Jet<T, N>(tmp1, tmp2 * f.v + tmp3 * g.v); + } + CERES_ENDIF; } - // Handle the remaining cases. For cases 4,5,6,9 we allow the log() function - // to generate -HUGE_VAL or NaN, since those cases result in a nonfinite - // derivative. - T const tmp1 = pow(f.a, g.a); - T const tmp2 = g.a * pow(f.a, g.a - T(1.0)); - T const tmp3 = tmp1 * log(f.a); - return Jet<T, N>(tmp1, tmp2 * f.v + tmp3 * g.v); + CERES_ENDIF; + return result; } // Note: This has to be in the ceres namespace for argument dependent lookup to
diff --git a/include/ceres/rotation.h b/include/ceres/rotation.h index 7d5c8ef..ce06a2c 100644 --- a/include/ceres/rotation.h +++ b/include/ceres/rotation.h
@@ -48,7 +48,7 @@ #include <algorithm> #include <cmath> #include <limits> - +#include "codegen/macros.h" #include "glog/logging.h" namespace ceres { @@ -253,7 +253,7 @@ const T theta_squared = a0 * a0 + a1 * a1 + a2 * a2; // For points not at the origin, the full conversion is numerically stable. - if (theta_squared > T(0.0)) { + CERES_IF(theta_squared > T(0.0)) { const T theta = sqrt(theta_squared); const T half_theta = theta * T(0.5); const T k = sin(half_theta) / theta; @@ -261,7 +261,8 @@ quaternion[1] = a0 * k; quaternion[2] = a1 * k; quaternion[3] = a2 * k; - } else { + } + CERES_ELSE { // At the origin, sqrt() will produce NaN in the derivative since // the argument is zero. By approximating with a Taylor series, // and truncating at one term, the value and first derivatives will be @@ -272,6 +273,7 @@ quaternion[2] = a1 * k; quaternion[3] = a2 * k; } + CERES_ENDIF; } template <typename T> @@ -283,7 +285,7 @@ // For quaternions representing non-zero rotation, the conversion // is numerically stable. - if (sin_squared_theta > T(0.0)) { + CERES_IF(sin_squared_theta > T(0.0)) { const T sin_theta = sqrt(sin_squared_theta); const T& cos_theta = quaternion[0]; @@ -300,14 +302,15 @@ // theta - pi = atan(sin(theta - pi), cos(theta - pi)) // = atan(-sin(theta), -cos(theta)) // - const T two_theta = - T(2.0) * ((cos_theta < T(0.0)) ? atan2(-sin_theta, -cos_theta) - : atan2(sin_theta, cos_theta)); + const T two_theta = T(2.0) * Ternary((cos_theta < T(0.0)), + atan2(-sin_theta, -cos_theta), + atan2(sin_theta, cos_theta)); const T k = two_theta / sin_theta; angle_axis[0] = q1 * k; angle_axis[1] = q2 * k; angle_axis[2] = q3 * k; - } else { + } + CERES_ELSE { // For zero rotation, sqrt() will produce NaN in the derivative since // the argument is zero. By approximating with a Taylor series, // and truncating at one term, the value and first derivatives will be @@ -317,6 +320,7 @@ angle_axis[1] = q2 * k; angle_axis[2] = q3 * k; } + CERES_ENDIF; } template <typename T> @@ -387,7 +391,7 @@ const T* angle_axis, const MatrixAdapter<T, row_stride, col_stride>& R) { static const T kOne = T(1.0); const T theta2 = DotProduct(angle_axis, angle_axis); - if (theta2 > T(std::numeric_limits<double>::epsilon())) { + CERES_IF(theta2 > T(std::numeric_limits<double>::epsilon())) { // We want to be careful to only evaluate the square root if the // norm of the angle_axis vector is greater than zero. Otherwise // we get a division by zero. @@ -410,7 +414,8 @@ R(1, 2) = -wx*sintheta + wy*wz*(kOne - costheta); R(2, 2) = costheta + wz*wz*(kOne - costheta); // clang-format on - } else { + } + CERES_ELSE { // Near zero, we switch to using the first order Taylor expansion. R(0, 0) = kOne; R(1, 0) = angle_axis[2]; @@ -422,6 +427,7 @@ R(1, 2) = -angle_axis[0]; R(2, 2) = kOne; } + CERES_ENDIF; } template <typename T> @@ -591,7 +597,7 @@ DCHECK_NE(pt, result) << "Inplace rotation is not supported."; const T theta2 = DotProduct(angle_axis, angle_axis); - if (theta2 > T(std::numeric_limits<double>::epsilon())) { + CERES_IF(theta2 > T(std::numeric_limits<double>::epsilon())) { // Away from zero, use the rodriguez formula // // result = pt costheta + @@ -622,7 +628,8 @@ result[0] = pt[0] * costheta + w_cross_pt[0] * sintheta + w[0] * tmp; result[1] = pt[1] * costheta + w_cross_pt[1] * sintheta + w[1] * tmp; result[2] = pt[2] * costheta + w_cross_pt[2] * sintheta + w[2] * tmp; - } else { + } + CERES_ELSE { // Near zero, the first order Taylor approximation of the rotation // matrix R corresponding to a vector w and angle w is // @@ -648,6 +655,7 @@ result[1] = pt[1] + w_cross_pt[1]; result[2] = pt[2] + w_cross_pt[2]; } + CERES_ENDIF; } } // namespace ceres
diff --git a/internal/ceres/expression_ref.cc b/internal/ceres/expression_ref.cc index 7c43595..52b7e0b 100644 --- a/internal/ceres/expression_ref.cc +++ b/internal/ceres/expression_ref.cc
@@ -190,6 +190,8 @@ CERES_DEFINE_EXPRESSION_COMPARISON_OPERATOR(!=) CERES_DEFINE_EXPRESSION_LOGICAL_OPERATOR(&&) CERES_DEFINE_EXPRESSION_LOGICAL_OPERATOR(||) +CERES_DEFINE_EXPRESSION_LOGICAL_OPERATOR(&) +CERES_DEFINE_EXPRESSION_LOGICAL_OPERATOR(|) #undef CERES_DEFINE_EXPRESSION_COMPARISON_OPERATOR #undef CERES_DEFINE_EXPRESSION_LOGICAL_OPERATOR