Fix QuaternionToAngleAxis to ensure rotations are between -pi and pi.

Thanks to Guoxuan Zhang for reporting this.

Change-Id: I2831ca3a04d5dc6467849c290461adbe23faaea3
diff --git a/include/ceres/rotation.h b/include/ceres/rotation.h
index 7f05187..0d8a390 100644
--- a/include/ceres/rotation.h
+++ b/include/ceres/rotation.h
@@ -148,9 +148,9 @@
 
 template<typename T>
 inline void AngleAxisToQuaternion(const T* angle_axis, T* quaternion) {
-  const T &a0 = angle_axis[0];
-  const T &a1 = angle_axis[1];
-  const T &a2 = angle_axis[2];
+  const T& a0 = angle_axis[0];
+  const T& a1 = angle_axis[1];
+  const T& a2 = angle_axis[2];
   const T theta_squared = a0 * a0 + a1 * a1 + a2 * a2;
 
   // For points not at the origin, the full conversion is numerically stable.
@@ -177,16 +177,35 @@
 
 template<typename T>
 inline void QuaternionToAngleAxis(const T* quaternion, T* angle_axis) {
-  const T &q1 = quaternion[1];
-  const T &q2 = quaternion[2];
-  const T &q3 = quaternion[3];
-  const T sin_squared = q1 * q1 + q2 * q2 + q3 * q3;
+  const T& q1 = quaternion[1];
+  const T& q2 = quaternion[2];
+  const T& q3 = quaternion[3];
+  const T sin_squared_theta = q1 * q1 + q2 * q2 + q3 * q3;
 
   // For quaternions representing non-zero rotation, the conversion
   // is numerically stable.
-  if (sin_squared > T(0.0)) {
-    const T sin_theta = sqrt(sin_squared);
-    const T k = T(2.0) * atan2(sin_theta, quaternion[0]) / sin_theta;
+  if (sin_squared_theta > T(0.0)) {
+    const T sin_theta = sqrt(sin_squared_theta);
+    const T& cos_theta = quaternion[0];
+
+    // If cos_theta is negative, theta is greater than pi/2, which
+    // means that angle for the angle_axis vector which is 2 * theta
+    // would be greater than pi.
+    //
+    // While this will result in the correct rotation, it does not
+    // result in a normalized angle-axis vector.
+    //
+    // In that case we observe that 2 * theta ~ 2 * theta - 2 * pi,
+    // which is equivalent saying
+    //
+    //   theta - pi = atan(sin(theta - pi), cos(theta - pi))
+    //              = atan(-sin(theta), -cos(theta))
+    //
+    const T two_theta =
+        T(2.0) * ((cos_theta < 0.0)
+                  ? atan2(-sin_theta, -cos_theta)
+                  : atan2(sin_theta, cos_theta));
+    const T k = two_theta / sin_theta;
     angle_axis[0] = q1 * k;
     angle_axis[1] = q2 * k;
     angle_axis[2] = q3 * k;
diff --git a/internal/ceres/rotation_test.cc b/internal/ceres/rotation_test.cc
index 65421fa..8e40507 100644
--- a/internal/ceres/rotation_test.cc
+++ b/internal/ceres/rotation_test.cc
@@ -87,14 +87,43 @@
     return false;
   }
 
+  // Quaternions are equivalent upto a sign change. So we will compare
+  // both signs before declaring failure.
+  bool near = true;
   for (int i = 0; i < 4; i++) {
     if (fabs(arg[i] - expected[i]) > kTolerance) {
-      *result_listener << "component " << i << " should be " << expected[i];
-      return false;
+      near = false;
+      break;
     }
   }
 
-  return true;
+  if (near) {
+    return true;
+  }
+
+  near = true;
+  for (int i = 0; i < 4; i++) {
+    if (fabs(arg[i] + expected[i]) > kTolerance) {
+      near = false;
+      break;
+    }
+  }
+
+  if (near) {
+    return true;
+  }
+
+  *result_listener << "expected : "
+                   << expected[0] << " "
+                   << expected[1] << " "
+                   << expected[2] << " "
+                   << expected[3] << " "
+                   << "actual : "
+                   << arg[0] << " "
+                   << arg[1] << " "
+                   << arg[2] << " "
+                   << arg[3];
+  return false;
 }
 
 // Use as:
@@ -259,6 +288,23 @@
   EXPECT_THAT(axis_angle, IsNearAngleAxis(expected));
 }
 
+TEST(Rotation, QuaternionToAngleAxisAngleIsLessThanPi) {
+  double quaternion[4];
+  double angle_axis[3];
+
+  const double half_theta = 0.75 * kPi;
+
+  quaternion[0] = cos(half_theta);
+  quaternion[1] = 1.0 * sin(half_theta);
+  quaternion[2] = 0.0;
+  quaternion[3] = 0.0;
+  QuaternionToAngleAxis(quaternion, angle_axis);
+  const double angle = sqrt(angle_axis[0] * angle_axis[0] +
+                            angle_axis[1] * angle_axis[1] +
+                            angle_axis[2] * angle_axis[2]);
+  EXPECT_LE(angle, kPi);
+}
+
 static const int kNumTrials = 10000;
 
 // Takes a bunch of random axis/angle values, converts them to quaternions,