Fix fmin/fmax() to use Jet averaging on equality

- Prior to 48cb54d1, Ceres' fmin/fmax() for Jets followed the convention
  of std::min/max(), and always returned the first argument on equality,
  irrespective of whether this argument was natively a scalar or a Jet.
- After 48cb54d1, Ceres' fmin/fmax() instead returned the second
  argument on equality, again irrespective of whether this argument was
  natively a scalar or a Jet.
- Now on equality we average the arguments as Jets, which ensures that
  a consistent answer is produced irrespective of the ordering or type
  (Jet or scalar) of the input arguments. This also ensures that we
  preserve a non-zero derivative where it exists, excluding the edge
  case of two Jet inputs with equal but oppositely signed infinitesimal
  components.
- We retain the behaviour introduced in 48cb54d1 whereby NaNs are
  treated as missing values, following the convention of
  std::fmin/fmax().
- Raised as issue #816.

Change-Id: I01217c0e32c1be83be440e4515b57c79dd290923
diff --git a/include/ceres/jet.h b/include/ceres/jet.h
index 2aefdc6..90f916c 100644
--- a/include/ceres/jet.h
+++ b/include/ceres/jet.h
@@ -755,28 +755,76 @@
   return Jet<T, N>(fma(x.a, y.a, z.a), y.a * x.v + x.a * y.v + z.v);
 }
 
-// Returns the larger of the two arguments. NaNs are treated as missing data.
+// Return value of fmax() and fmin() on equality
+// ---------------------------------------------
+//
+// There is arguably no good answer to what fmax() & fmin() should return on
+// equality, which for Jets by definition ONLY compares the scalar parts. We
+// choose what we think is the least worst option (averaging as Jets) which
+// minimises undesirable/unexpected behaviour as used, and also supports client
+// code written against Ceres versions prior to type promotion being supported
+// in Jet comparisons (< v2.1).
+//
+// The std::max() convention of returning the first argument on equality is
+// problematic, as it means that the derivative component may or may not be
+// preserved (when comparing a Jet with a scalar) depending upon the ordering.
+//
+// Always returning the Jet in {Jet, scalar} cases on equality is problematic
+// as it is inconsistent with the behaviour that would be obtained if the scalar
+// was first cast to Jet and the {Jet, Jet} case was used. Prior to type
+// promotion (Ceres v2.1) client code would typically cast constants to Jets
+// e.g: fmax(x, T(2.0)) which means the {Jet, Jet} case predominates, and we
+// still want the result to be order independent.
+//
+// Our intuition is that preserving a non-zero derivative is best, even if
+// its value does not match either of the inputs. Averaging achieves this
+// whilst ensuring argument ordering independence. This is also the approach
+// used by the Jax library, and TensorFlow's reduce_max().
+
+// Returns the larger of the two arguments, with Jet averaging on equality.
+// NaNs are treated as missing data.
 //
 // NOTE: This function is NOT subject to any of the error conditions specified
-// in `math_errhandling`.
+//       in `math_errhandling`.
 template <typename Lhs,
           typename Rhs,
           std::enable_if_t<CompatibleJetOperands_v<Lhs, Rhs>>* = nullptr>
-inline decltype(auto) fmax(const Lhs& f, const Rhs& g) {
+inline decltype(auto) fmax(const Lhs& x, const Rhs& y) {
   using J = std::common_type_t<Lhs, Rhs>;
-  return (isnan(g) || isgreater(f, g)) ? J{f} : J{g};
+  // As x == y may set FP exceptions in the presence of NaNs when used with
+  // non-default compiler options so we avoid its use here.
+  if (isnan(x) || isnan(y) || islessgreater(x, y)) {
+    return isnan(x) || isless(x, y) ? J{y} : J{x};
+  }
+  // x == y (scalar parts) return the average of their Jet representations.
+#if defined(CERES_HAS_CPP20)
+  return midpoint(J{x}, J{y});
+#else
+  return (J{x} + J{y}) * 0.5;
+#endif  // defined(CERES_HAS_CPP20)
 }
 
-// Returns the smaller of the two arguments. NaNs are treated as missing data.
+// Returns the smaller of the two arguments, with Jet averaging on equality.
+// NaNs are treated as missing data.
 //
 // NOTE: This function is NOT subject to any of the error conditions specified
-// in `math_errhandling`.
+//       in `math_errhandling`.
 template <typename Lhs,
           typename Rhs,
           std::enable_if_t<CompatibleJetOperands_v<Lhs, Rhs>>* = nullptr>
-inline decltype(auto) fmin(const Lhs& f, const Rhs& g) {
+inline decltype(auto) fmin(const Lhs& x, const Rhs& y) {
   using J = std::common_type_t<Lhs, Rhs>;
-  return (isnan(f) || isless(g, f)) ? J{g} : J{f};
+  // As x == y may set FP exceptions in the presence of NaNs when used with
+  // non-default compiler options so we avoid its use here.
+  if (isnan(x) || isnan(y) || islessgreater(x, y)) {
+    return isnan(x) || isgreater(x, y) ? J{y} : J{x};
+  }
+  // x == y (scalar parts) return the average of their Jet representations.
+#if defined(CERES_HAS_CPP20)
+  return midpoint(J{x}, J{y});
+#else
+  return (J{x} + J{y}) * 0.5;
+#endif  // defined(CERES_HAS_CPP20)
 }
 
 // Returns the positive difference (f - g) of two arguments and zero if f <= g.
diff --git a/internal/ceres/jet_test.cc b/internal/ceres/jet_test.cc
index 1c85d01..72dff01 100644
--- a/internal/ceres/jet_test.cc
+++ b/internal/ceres/jet_test.cc
@@ -682,7 +682,7 @@
   EXPECT_THAT(v, IsAlmostEqualTo(w));
 }
 
-TEST(Jet, Fmax) {
+TEST(Jet, FmaxJetWithJet) {
   Fenv env;
   // Clear all exceptions to ensure none are set by the following function
   // calls.
@@ -690,21 +690,54 @@
 
   EXPECT_THAT(fmax(x, y), IsAlmostEqualTo(x));
   EXPECT_THAT(fmax(y, x), IsAlmostEqualTo(x));
-  EXPECT_THAT(fmax(x, y.a), IsAlmostEqualTo(x));
-  EXPECT_THAT(fmax(y.a, x), IsAlmostEqualTo(x));
-  EXPECT_THAT(fmax(y, x.a), IsAlmostEqualTo(J{x.a}));
-  EXPECT_THAT(fmax(x.a, y), IsAlmostEqualTo(J{x.a}));
-  EXPECT_THAT(fmax(x, std::numeric_limits<double>::quiet_NaN()),
-              IsAlmostEqualTo(x));
-  EXPECT_THAT(fmax(std::numeric_limits<double>::quiet_NaN(), x),
-              IsAlmostEqualTo(x));
+
+  // Average the Jets on equality (of scalar parts).
+  const J scalar_part_only_equal_to_x = J(x.a, 2 * x.v);
+  const J average = (x + scalar_part_only_equal_to_x) * 0.5;
+  EXPECT_THAT(fmax(x, scalar_part_only_equal_to_x), IsAlmostEqualTo(average));
+  EXPECT_THAT(fmax(scalar_part_only_equal_to_x, x), IsAlmostEqualTo(average));
+
+  // Follow convention of fmax(): treat NANs as missing values.
+  const J nan_scalar_part(std::numeric_limits<double>::quiet_NaN(), 2 * x.v);
+  EXPECT_THAT(fmax(x, nan_scalar_part), IsAlmostEqualTo(x));
+  EXPECT_THAT(fmax(nan_scalar_part, x), IsAlmostEqualTo(x));
 
 #ifndef CERES_NO_FENV_ACCESS
   EXPECT_EQ(std::fetestexcept(FE_ALL_EXCEPT & ~FE_INEXACT), 0);
 #endif
 }
 
-TEST(Jet, Fmin) {
+TEST(Jet, FmaxJetWithScalar) {
+  Fenv env;
+  // Clear all exceptions to ensure none are set by the following function
+  // calls.
+  std::feclearexcept(FE_ALL_EXCEPT);
+
+  EXPECT_THAT(fmax(x, y.a), IsAlmostEqualTo(x));
+  EXPECT_THAT(fmax(y.a, x), IsAlmostEqualTo(x));
+  EXPECT_THAT(fmax(y, x.a), IsAlmostEqualTo(J{x.a}));
+  EXPECT_THAT(fmax(x.a, y), IsAlmostEqualTo(J{x.a}));
+
+  // Average the Jet and scalar cast to a Jet on equality (of scalar parts).
+  const J average = (x + J{x.a}) * 0.5;
+  EXPECT_THAT(fmax(x, x.a), IsAlmostEqualTo(average));
+  EXPECT_THAT(fmax(x.a, x), IsAlmostEqualTo(average));
+
+  // Follow convention of fmax(): treat NANs as missing values.
+  EXPECT_THAT(fmax(x, std::numeric_limits<double>::quiet_NaN()),
+              IsAlmostEqualTo(x));
+  EXPECT_THAT(fmax(std::numeric_limits<double>::quiet_NaN(), x),
+              IsAlmostEqualTo(x));
+  const J nan_scalar_part(std::numeric_limits<double>::quiet_NaN(), 2 * x.v);
+  EXPECT_THAT(fmax(nan_scalar_part, x.a), IsAlmostEqualTo(J{x.a}));
+  EXPECT_THAT(fmax(x.a, nan_scalar_part), IsAlmostEqualTo(J{x.a}));
+
+#ifndef CERES_NO_FENV_ACCESS
+  EXPECT_EQ(std::fetestexcept(FE_ALL_EXCEPT & ~FE_INEXACT), 0);
+#endif
+}
+
+TEST(Jet, FminJetWithJet) {
   Fenv env;
   // Clear all exceptions to ensure none are set by the following function
   // calls.
@@ -712,14 +745,47 @@
 
   EXPECT_THAT(fmin(x, y), IsAlmostEqualTo(y));
   EXPECT_THAT(fmin(y, x), IsAlmostEqualTo(y));
+
+  // Average the Jets on equality (of scalar parts).
+  const J scalar_part_only_equal_to_x = J(x.a, 2 * x.v);
+  const J average = (x + scalar_part_only_equal_to_x) * 0.5;
+  EXPECT_THAT(fmin(x, scalar_part_only_equal_to_x), IsAlmostEqualTo(average));
+  EXPECT_THAT(fmin(scalar_part_only_equal_to_x, x), IsAlmostEqualTo(average));
+
+  // Follow convention of fmin(): treat NANs as missing values.
+  const J nan_scalar_part(std::numeric_limits<double>::quiet_NaN(), 2 * x.v);
+  EXPECT_THAT(fmin(x, nan_scalar_part), IsAlmostEqualTo(x));
+  EXPECT_THAT(fmin(nan_scalar_part, x), IsAlmostEqualTo(x));
+
+#ifndef CERES_NO_FENV_ACCESS
+  EXPECT_EQ(std::fetestexcept(FE_ALL_EXCEPT & ~FE_INEXACT), 0);
+#endif
+}
+
+TEST(Jet, FminJetWithScalar) {
+  Fenv env;
+  // Clear all exceptions to ensure none are set by the following function
+  // calls.
+  std::feclearexcept(FE_ALL_EXCEPT);
+
   EXPECT_THAT(fmin(x, y.a), IsAlmostEqualTo(J{y.a}));
   EXPECT_THAT(fmin(y.a, x), IsAlmostEqualTo(J{y.a}));
   EXPECT_THAT(fmin(y, x.a), IsAlmostEqualTo(y));
   EXPECT_THAT(fmin(x.a, y), IsAlmostEqualTo(y));
+
+  // Average the Jet and scalar cast to a Jet on equality (of scalar parts).
+  const J average = (x + J{x.a}) * 0.5;
+  EXPECT_THAT(fmin(x, x.a), IsAlmostEqualTo(average));
+  EXPECT_THAT(fmin(x.a, x), IsAlmostEqualTo(average));
+
+  // Follow convention of fmin(): treat NANs as missing values.
   EXPECT_THAT(fmin(x, std::numeric_limits<double>::quiet_NaN()),
               IsAlmostEqualTo(x));
   EXPECT_THAT(fmin(std::numeric_limits<double>::quiet_NaN(), x),
               IsAlmostEqualTo(x));
+  const J nan_scalar_part(std::numeric_limits<double>::quiet_NaN(), 2 * x.v);
+  EXPECT_THAT(fmin(nan_scalar_part, x.a), IsAlmostEqualTo(J{x.a}));
+  EXPECT_THAT(fmin(x.a, nan_scalar_part), IsAlmostEqualTo(J{x.a}));
 
 #ifndef CERES_NO_FENV_ACCESS
   EXPECT_EQ(std::fetestexcept(FE_ALL_EXCEPT & ~FE_INEXACT), 0);