Fix fmin/fmax() to use Jet averaging on equality - Prior to 48cb54d1, Ceres' fmin/fmax() for Jets followed the convention of std::min/max(), and always returned the first argument on equality, irrespective of whether this argument was natively a scalar or a Jet. - After 48cb54d1, Ceres' fmin/fmax() instead returned the second argument on equality, again irrespective of whether this argument was natively a scalar or a Jet. - Now on equality we average the arguments as Jets, which ensures that a consistent answer is produced irrespective of the ordering or type (Jet or scalar) of the input arguments. This also ensures that we preserve a non-zero derivative where it exists, excluding the edge case of two Jet inputs with equal but oppositely signed infinitesimal components. - We retain the behaviour introduced in 48cb54d1 whereby NaNs are treated as missing values, following the convention of std::fmin/fmax(). - Raised as issue #816. Change-Id: I01217c0e32c1be83be440e4515b57c79dd290923
diff --git a/include/ceres/jet.h b/include/ceres/jet.h index 2aefdc6..90f916c 100644 --- a/include/ceres/jet.h +++ b/include/ceres/jet.h
@@ -755,28 +755,76 @@ return Jet<T, N>(fma(x.a, y.a, z.a), y.a * x.v + x.a * y.v + z.v); } -// Returns the larger of the two arguments. NaNs are treated as missing data. +// Return value of fmax() and fmin() on equality +// --------------------------------------------- +// +// There is arguably no good answer to what fmax() & fmin() should return on +// equality, which for Jets by definition ONLY compares the scalar parts. We +// choose what we think is the least worst option (averaging as Jets) which +// minimises undesirable/unexpected behaviour as used, and also supports client +// code written against Ceres versions prior to type promotion being supported +// in Jet comparisons (< v2.1). +// +// The std::max() convention of returning the first argument on equality is +// problematic, as it means that the derivative component may or may not be +// preserved (when comparing a Jet with a scalar) depending upon the ordering. +// +// Always returning the Jet in {Jet, scalar} cases on equality is problematic +// as it is inconsistent with the behaviour that would be obtained if the scalar +// was first cast to Jet and the {Jet, Jet} case was used. Prior to type +// promotion (Ceres v2.1) client code would typically cast constants to Jets +// e.g: fmax(x, T(2.0)) which means the {Jet, Jet} case predominates, and we +// still want the result to be order independent. +// +// Our intuition is that preserving a non-zero derivative is best, even if +// its value does not match either of the inputs. Averaging achieves this +// whilst ensuring argument ordering independence. This is also the approach +// used by the Jax library, and TensorFlow's reduce_max(). + +// Returns the larger of the two arguments, with Jet averaging on equality. +// NaNs are treated as missing data. // // NOTE: This function is NOT subject to any of the error conditions specified -// in `math_errhandling`. +// in `math_errhandling`. template <typename Lhs, typename Rhs, std::enable_if_t<CompatibleJetOperands_v<Lhs, Rhs>>* = nullptr> -inline decltype(auto) fmax(const Lhs& f, const Rhs& g) { +inline decltype(auto) fmax(const Lhs& x, const Rhs& y) { using J = std::common_type_t<Lhs, Rhs>; - return (isnan(g) || isgreater(f, g)) ? J{f} : J{g}; + // As x == y may set FP exceptions in the presence of NaNs when used with + // non-default compiler options so we avoid its use here. + if (isnan(x) || isnan(y) || islessgreater(x, y)) { + return isnan(x) || isless(x, y) ? J{y} : J{x}; + } + // x == y (scalar parts) return the average of their Jet representations. +#if defined(CERES_HAS_CPP20) + return midpoint(J{x}, J{y}); +#else + return (J{x} + J{y}) * 0.5; +#endif // defined(CERES_HAS_CPP20) } -// Returns the smaller of the two arguments. NaNs are treated as missing data. +// Returns the smaller of the two arguments, with Jet averaging on equality. +// NaNs are treated as missing data. // // NOTE: This function is NOT subject to any of the error conditions specified -// in `math_errhandling`. +// in `math_errhandling`. template <typename Lhs, typename Rhs, std::enable_if_t<CompatibleJetOperands_v<Lhs, Rhs>>* = nullptr> -inline decltype(auto) fmin(const Lhs& f, const Rhs& g) { +inline decltype(auto) fmin(const Lhs& x, const Rhs& y) { using J = std::common_type_t<Lhs, Rhs>; - return (isnan(f) || isless(g, f)) ? J{g} : J{f}; + // As x == y may set FP exceptions in the presence of NaNs when used with + // non-default compiler options so we avoid its use here. + if (isnan(x) || isnan(y) || islessgreater(x, y)) { + return isnan(x) || isgreater(x, y) ? J{y} : J{x}; + } + // x == y (scalar parts) return the average of their Jet representations. +#if defined(CERES_HAS_CPP20) + return midpoint(J{x}, J{y}); +#else + return (J{x} + J{y}) * 0.5; +#endif // defined(CERES_HAS_CPP20) } // Returns the positive difference (f - g) of two arguments and zero if f <= g.
diff --git a/internal/ceres/jet_test.cc b/internal/ceres/jet_test.cc index 1c85d01..72dff01 100644 --- a/internal/ceres/jet_test.cc +++ b/internal/ceres/jet_test.cc
@@ -682,7 +682,7 @@ EXPECT_THAT(v, IsAlmostEqualTo(w)); } -TEST(Jet, Fmax) { +TEST(Jet, FmaxJetWithJet) { Fenv env; // Clear all exceptions to ensure none are set by the following function // calls. @@ -690,21 +690,54 @@ EXPECT_THAT(fmax(x, y), IsAlmostEqualTo(x)); EXPECT_THAT(fmax(y, x), IsAlmostEqualTo(x)); - EXPECT_THAT(fmax(x, y.a), IsAlmostEqualTo(x)); - EXPECT_THAT(fmax(y.a, x), IsAlmostEqualTo(x)); - EXPECT_THAT(fmax(y, x.a), IsAlmostEqualTo(J{x.a})); - EXPECT_THAT(fmax(x.a, y), IsAlmostEqualTo(J{x.a})); - EXPECT_THAT(fmax(x, std::numeric_limits<double>::quiet_NaN()), - IsAlmostEqualTo(x)); - EXPECT_THAT(fmax(std::numeric_limits<double>::quiet_NaN(), x), - IsAlmostEqualTo(x)); + + // Average the Jets on equality (of scalar parts). + const J scalar_part_only_equal_to_x = J(x.a, 2 * x.v); + const J average = (x + scalar_part_only_equal_to_x) * 0.5; + EXPECT_THAT(fmax(x, scalar_part_only_equal_to_x), IsAlmostEqualTo(average)); + EXPECT_THAT(fmax(scalar_part_only_equal_to_x, x), IsAlmostEqualTo(average)); + + // Follow convention of fmax(): treat NANs as missing values. + const J nan_scalar_part(std::numeric_limits<double>::quiet_NaN(), 2 * x.v); + EXPECT_THAT(fmax(x, nan_scalar_part), IsAlmostEqualTo(x)); + EXPECT_THAT(fmax(nan_scalar_part, x), IsAlmostEqualTo(x)); #ifndef CERES_NO_FENV_ACCESS EXPECT_EQ(std::fetestexcept(FE_ALL_EXCEPT & ~FE_INEXACT), 0); #endif } -TEST(Jet, Fmin) { +TEST(Jet, FmaxJetWithScalar) { + Fenv env; + // Clear all exceptions to ensure none are set by the following function + // calls. + std::feclearexcept(FE_ALL_EXCEPT); + + EXPECT_THAT(fmax(x, y.a), IsAlmostEqualTo(x)); + EXPECT_THAT(fmax(y.a, x), IsAlmostEqualTo(x)); + EXPECT_THAT(fmax(y, x.a), IsAlmostEqualTo(J{x.a})); + EXPECT_THAT(fmax(x.a, y), IsAlmostEqualTo(J{x.a})); + + // Average the Jet and scalar cast to a Jet on equality (of scalar parts). + const J average = (x + J{x.a}) * 0.5; + EXPECT_THAT(fmax(x, x.a), IsAlmostEqualTo(average)); + EXPECT_THAT(fmax(x.a, x), IsAlmostEqualTo(average)); + + // Follow convention of fmax(): treat NANs as missing values. + EXPECT_THAT(fmax(x, std::numeric_limits<double>::quiet_NaN()), + IsAlmostEqualTo(x)); + EXPECT_THAT(fmax(std::numeric_limits<double>::quiet_NaN(), x), + IsAlmostEqualTo(x)); + const J nan_scalar_part(std::numeric_limits<double>::quiet_NaN(), 2 * x.v); + EXPECT_THAT(fmax(nan_scalar_part, x.a), IsAlmostEqualTo(J{x.a})); + EXPECT_THAT(fmax(x.a, nan_scalar_part), IsAlmostEqualTo(J{x.a})); + +#ifndef CERES_NO_FENV_ACCESS + EXPECT_EQ(std::fetestexcept(FE_ALL_EXCEPT & ~FE_INEXACT), 0); +#endif +} + +TEST(Jet, FminJetWithJet) { Fenv env; // Clear all exceptions to ensure none are set by the following function // calls. @@ -712,14 +745,47 @@ EXPECT_THAT(fmin(x, y), IsAlmostEqualTo(y)); EXPECT_THAT(fmin(y, x), IsAlmostEqualTo(y)); + + // Average the Jets on equality (of scalar parts). + const J scalar_part_only_equal_to_x = J(x.a, 2 * x.v); + const J average = (x + scalar_part_only_equal_to_x) * 0.5; + EXPECT_THAT(fmin(x, scalar_part_only_equal_to_x), IsAlmostEqualTo(average)); + EXPECT_THAT(fmin(scalar_part_only_equal_to_x, x), IsAlmostEqualTo(average)); + + // Follow convention of fmin(): treat NANs as missing values. + const J nan_scalar_part(std::numeric_limits<double>::quiet_NaN(), 2 * x.v); + EXPECT_THAT(fmin(x, nan_scalar_part), IsAlmostEqualTo(x)); + EXPECT_THAT(fmin(nan_scalar_part, x), IsAlmostEqualTo(x)); + +#ifndef CERES_NO_FENV_ACCESS + EXPECT_EQ(std::fetestexcept(FE_ALL_EXCEPT & ~FE_INEXACT), 0); +#endif +} + +TEST(Jet, FminJetWithScalar) { + Fenv env; + // Clear all exceptions to ensure none are set by the following function + // calls. + std::feclearexcept(FE_ALL_EXCEPT); + EXPECT_THAT(fmin(x, y.a), IsAlmostEqualTo(J{y.a})); EXPECT_THAT(fmin(y.a, x), IsAlmostEqualTo(J{y.a})); EXPECT_THAT(fmin(y, x.a), IsAlmostEqualTo(y)); EXPECT_THAT(fmin(x.a, y), IsAlmostEqualTo(y)); + + // Average the Jet and scalar cast to a Jet on equality (of scalar parts). + const J average = (x + J{x.a}) * 0.5; + EXPECT_THAT(fmin(x, x.a), IsAlmostEqualTo(average)); + EXPECT_THAT(fmin(x.a, x), IsAlmostEqualTo(average)); + + // Follow convention of fmin(): treat NANs as missing values. EXPECT_THAT(fmin(x, std::numeric_limits<double>::quiet_NaN()), IsAlmostEqualTo(x)); EXPECT_THAT(fmin(std::numeric_limits<double>::quiet_NaN(), x), IsAlmostEqualTo(x)); + const J nan_scalar_part(std::numeric_limits<double>::quiet_NaN(), 2 * x.v); + EXPECT_THAT(fmin(nan_scalar_part, x.a), IsAlmostEqualTo(J{x.a})); + EXPECT_THAT(fmin(x.a, nan_scalar_part), IsAlmostEqualTo(J{x.a})); #ifndef CERES_NO_FENV_ACCESS EXPECT_EQ(std::fetestexcept(FE_ALL_EXCEPT & ~FE_INEXACT), 0);