Simplify some template metaprograms using fold expressions.
Change-Id: I865b670b99df30db39d33cbfe45170b70472e532
diff --git a/include/ceres/internal/integer_sequence_algorithm.h b/include/ceres/internal/integer_sequence_algorithm.h
index 777c119..80c821c 100644
--- a/include/ceres/internal/integer_sequence_algorithm.h
+++ b/include/ceres/internal/integer_sequence_algorithm.h
@@ -43,68 +43,6 @@
namespace ceres {
namespace internal {
-// Implementation of calculating the sum of an integer sequence.
-// Recursively instantiate SumImpl and calculate the sum of the N first
-// numbers. This reduces the number of instantiations and speeds up
-// compilation.
-//
-// Examples:
-// 1) integer_sequence<int, 5>:
-// Value = 5
-//
-// 2) integer_sequence<int, 4, 2>:
-// Value = 4 + 2 + SumImpl<integer_sequence<int>>::Value
-// Value = 4 + 2 + 0
-//
-// 3) integer_sequence<int, 2, 1, 4>:
-// Value = 2 + 1 + SumImpl<integer_sequence<int, 4>>::Value
-// Value = 2 + 1 + 4
-template <typename Seq>
-struct SumImpl;
-
-// Strip of and sum the first number.
-template <typename T, T N, T... Ns>
-struct SumImpl<std::integer_sequence<T, N, Ns...>> {
- static constexpr T Value =
- N + SumImpl<std::integer_sequence<T, Ns...>>::Value;
-};
-
-// Strip of and sum the first two numbers.
-template <typename T, T N1, T N2, T... Ns>
-struct SumImpl<std::integer_sequence<T, N1, N2, Ns...>> {
- static constexpr T Value =
- N1 + N2 + SumImpl<std::integer_sequence<T, Ns...>>::Value;
-};
-
-// Strip of and sum the first four numbers.
-template <typename T, T N1, T N2, T N3, T N4, T... Ns>
-struct SumImpl<std::integer_sequence<T, N1, N2, N3, N4, Ns...>> {
- static constexpr T Value =
- N1 + N2 + N3 + N4 + SumImpl<std::integer_sequence<T, Ns...>>::Value;
-};
-
-// Only one number is left. 'Value' is just that number ('recursion' ends).
-template <typename T, T N>
-struct SumImpl<std::integer_sequence<T, N>> {
- static constexpr T Value = N;
-};
-
-// No number is left. 'Value' is the identity element (for sum this is zero).
-template <typename T>
-struct SumImpl<std::integer_sequence<T>> {
- static constexpr T Value = T(0);
-};
-
-// Calculate the sum of an integer sequence. The resulting sum will be stored in
-// 'Value'.
-template <typename Seq>
-class Sum {
- using T = typename Seq::value_type;
-
- public:
- static constexpr T Value = SumImpl<Seq>::Value;
-};
-
// Implementation of calculating an exclusive scan (exclusive prefix sum) of an
// integer sequence. Exclusive means that the i-th input element is not included
// in the i-th sum. Calculating the exclusive scan for an input array I results
@@ -232,40 +170,11 @@
template <typename Sequence, typename Sequence::value_type ValueToRemove>
using RemoveValue_t = typename RemoveValue<Sequence, ValueToRemove>::type;
-// Determines whether the values of an integer sequence are all the same.
+// Returns true if all elements of Values are equal to HeadValue.
//
-// The integer sequence must contain at least one value. The predicate is
-// undefined for empty sequences. The evaluation result of the predicate for a
-// sequence containing only one value is defined to be true.
-template <typename... Sequence>
-struct AreAllEqual;
-
-// The predicate result for a sequence containing one element is defined to be
-// true.
-template <typename T, T Value>
-struct AreAllEqual<std::integer_sequence<T, Value>> : std::true_type {};
-
-// Recursion end.
-template <typename T, T Value1, T Value2>
-struct AreAllEqual<std::integer_sequence<T, Value1, Value2>>
- : std::integral_constant<bool, Value1 == Value2> {};
-
-// Recursion for sequences containing at least two elements.
-template <typename T, T Value1, T Value2, T... Values>
-// clang-format off
-struct AreAllEqual<std::integer_sequence<T, Value1, Value2, Values...> >
- : std::integral_constant
-<
- bool,
- AreAllEqual<std::integer_sequence<T, Value1, Value2> >::value &&
- AreAllEqual<std::integer_sequence<T, Value2, Values...> >::value
->
-// clang-format on
-{};
-
-// Convenience variable template for AreAllEqual.
-template <class Sequence>
-constexpr bool AreAllEqual_v = AreAllEqual<Sequence>::value;
+// Returns true if Values is empty.
+template <typename T, T HeadValue, T... Values>
+inline constexpr bool AreAllEqual_v = ((HeadValue == Values) && ...);
// Predicate determining whether an integer sequence is either empty or all
// values are equal.
@@ -279,11 +188,12 @@
// General case for sequences containing at least one value.
template <typename T, T HeadValue, T... Values>
struct IsEmptyOrAreAllEqual<std::integer_sequence<T, HeadValue, Values...>>
- : AreAllEqual<std::integer_sequence<T, HeadValue, Values...>> {};
+ : std::integral_constant<bool, AreAllEqual_v<T, HeadValue, Values...>> {};
// Convenience variable template for IsEmptyOrAreAllEqual.
template <class Sequence>
-constexpr bool IsEmptyOrAreAllEqual_v = IsEmptyOrAreAllEqual<Sequence>::value;
+inline constexpr bool IsEmptyOrAreAllEqual_v =
+ IsEmptyOrAreAllEqual<Sequence>::value;
} // namespace internal
} // namespace ceres
diff --git a/include/ceres/internal/jet_traits.h b/include/ceres/internal/jet_traits.h
index 2a38c05..746638f 100644
--- a/include/ceres/internal/jet_traits.h
+++ b/include/ceres/internal/jet_traits.h
@@ -42,17 +42,6 @@
namespace ceres {
namespace internal {
-// Predicate that determines whether T is a Jet.
-template <typename T, typename E = void>
-struct IsJet : std::false_type {};
-
-template <typename T, int N>
-struct IsJet<Jet<T, N>> : std::true_type {};
-
-// Convenience variable template for IsJet.
-template <typename T>
-constexpr bool IsJet_v = IsJet<T>::value;
-
// Predicate that determines whether any of the Types is a Jet.
template <typename... Types>
struct AreAnyJet : std::false_type {};
@@ -65,7 +54,7 @@
// Convenience variable template for AreAnyJet.
template <typename... Types>
-constexpr bool AreAnyJet_v = AreAnyJet<Types...>::value;
+inline constexpr bool AreAnyJet_v = AreAnyJet<Types...>::value;
// Extracts the underlying floating-point from a type T.
template <typename T, typename E = void>
@@ -84,27 +73,8 @@
//
// Specifically, the predicate applies std::is_same recursively to pairs of
// Types in the pack.
-//
-// The predicate is defined only for template packs containing at least two
-// types.
-template <typename T1, typename T2, typename... Types>
-// clang-format off
-struct AreAllSame : std::integral_constant
-<
- bool,
- AreAllSame<T1, T2>::value &&
- AreAllSame<T2, Types...>::value
->
-// clang-format on
-{};
-
-// AreAllSame pairwise test.
-template <typename T1, typename T2>
-struct AreAllSame<T1, T2> : std::is_same<T1, T2> {};
-
-// Convenience variable template for AreAllSame.
-template <typename... Types>
-constexpr bool AreAllSame_v = AreAllSame<Types...>::value;
+template <typename T1, typename... Types>
+inline constexpr bool AreAllSame_v = (std::is_same<T1, Types>::value && ...);
// Determines the rank of a type. This allows to ensure that types passed as
// arguments are compatible to each other. The rank of Jet is determined by the
@@ -124,7 +94,7 @@
// Convenience variable template for Rank.
template <typename T>
-constexpr int Rank_v = Rank<T>::value;
+inline constexpr int Rank_v = Rank<T>::value;
// Constructs an integer sequence of ranks for each of the Types in the pack.
template <typename... Types>
@@ -186,7 +156,8 @@
// This trait is a candidate for a concept definition once C++20 features can
// be used.
template <typename... Types>
-constexpr bool CompatibleJetOperands_v = CompatibleJetOperands<Types...>::value;
+inline constexpr bool CompatibleJetOperands_v =
+ CompatibleJetOperands<Types...>::value;
// Type trait ensuring at least one of the types is a Jet,
// the underlying scalar types are compatible among each other and Jet
@@ -216,7 +187,8 @@
// This trait is a candidate for a concept definition once C++20 features can
// be used.
template <typename... Types>
-constexpr bool PromotableJetOperands_v = PromotableJetOperands<Types...>::value;
+inline constexpr bool PromotableJetOperands_v =
+ PromotableJetOperands<Types...>::value;
} // namespace ceres
diff --git a/include/ceres/internal/parameter_dims.h b/include/ceres/internal/parameter_dims.h
index 2402106..efe2df4 100644
--- a/include/ceres/internal/parameter_dims.h
+++ b/include/ceres/internal/parameter_dims.h
@@ -39,20 +39,6 @@
namespace ceres {
namespace internal {
-// Checks, whether the given parameter block sizes are valid. Valid means every
-// dimension is bigger than zero.
-constexpr bool IsValidParameterDimensionSequence(std::integer_sequence<int>) {
- return true;
-}
-
-template <int N, int... Ts>
-constexpr bool IsValidParameterDimensionSequence(
- std::integer_sequence<int, N, Ts...>) {
- return (N <= 0) ? false
- : IsValidParameterDimensionSequence(
- std::integer_sequence<int, Ts...>());
-}
-
// Helper class that represents the parameter dimensions. The parameter
// dimensions are either dynamic or the sizes are known at compile time. It is
// used to pass parameter block dimensions around (e.g. between functions or
@@ -70,8 +56,7 @@
// The parameter dimensions are only valid if all parameter block dimensions
// are greater than zero.
- static constexpr bool kIsValid =
- IsValidParameterDimensionSequence(Parameters());
+ static constexpr bool kIsValid = ((Ns > 0) && ...);
static_assert(kIsValid,
"Invalid parameter block dimension detected. Each parameter "
"block dimension must be bigger than zero.");
@@ -81,8 +66,7 @@
static_assert(kIsDynamic || kNumParameterBlocks > 0,
"At least one parameter block must be specified.");
- static constexpr int kNumParameters =
- Sum<std::integer_sequence<int, Ns...>>::Value;
+ static constexpr int kNumParameters = (Ns + ... + 0);
static constexpr int GetDim(int dim) { return params_[dim]; }
diff --git a/internal/ceres/integer_sequence_algorithm_test.cc b/internal/ceres/integer_sequence_algorithm_test.cc
index 7e04148..4622fea 100644
--- a/internal/ceres/integer_sequence_algorithm_test.cc
+++ b/internal/ceres/integer_sequence_algorithm_test.cc
@@ -39,20 +39,6 @@
namespace ceres {
namespace internal {
-// Unit tests for summation of integer sequence.
-static_assert(Sum<std::integer_sequence<int>>::Value == 0,
- "Unit test of summing up an integer sequence failed.");
-static_assert(Sum<std::integer_sequence<int, 2>>::Value == 2,
- "Unit test of summing up an integer sequence failed.");
-static_assert(Sum<std::integer_sequence<int, 2, 3>>::Value == 5,
- "Unit test of summing up an integer sequence failed.");
-static_assert(Sum<std::integer_sequence<int, 2, 3, 10>>::Value == 15,
- "Unit test of summing up an integer sequence failed.");
-static_assert(Sum<std::integer_sequence<int, 2, 3, 10, 4>>::Value == 19,
- "Unit test of summing up an integer sequence failed.");
-static_assert(Sum<std::integer_sequence<int, 2, 3, 10, 4, 1>>::Value == 20,
- "Unit test of summing up an integer sequence failed.");
-
// Unit tests for exclusive scan of integer sequence.
static_assert(std::is_same<ExclusiveScan<std::integer_sequence<int>>,
std::integer_sequence<int>>::value,
@@ -129,15 +115,15 @@
static_assert(!AreAllSame_v<int, short, char>, "types must not be the same");
// Ensure all values in the integer sequence match
-static_assert(AreAllEqual_v<std::integer_sequence<int, 1, 1>>,
+static_assert(AreAllEqual_v<int, 1, 1>,
"integer sequence must contain same values");
-static_assert(AreAllEqual_v<std::integer_sequence<long, 2>>,
+static_assert(AreAllEqual_v<long, 2>,
"integer sequence must contain one value");
-static_assert(!AreAllEqual_v<std::integer_sequence<short, 3, 4>>,
+static_assert(!AreAllEqual_v<short, 3, 4>,
"integer sequence must not contain the same values");
-static_assert(!AreAllEqual_v<std::integer_sequence<unsigned, 3, 4, 3>>,
+static_assert(!AreAllEqual_v<unsigned, 3, 4, 3>,
"integer sequence must not contain the same values");
-static_assert(!AreAllEqual_v<std::integer_sequence<int, 4, 4, 3>>,
+static_assert(!AreAllEqual_v<int, 4, 4, 3>,
"integer sequence must not contain the same values");
static_assert(IsEmptyOrAreAllEqual_v<std::integer_sequence<short>>,
diff --git a/internal/ceres/jet_traits_test.cc b/internal/ceres/jet_traits_test.cc
index ee38f47..43afc3c 100644
--- a/internal/ceres/jet_traits_test.cc
+++ b/internal/ceres/jet_traits_test.cc
@@ -44,20 +44,6 @@
using J0 = Jet<T, 0>;
using J0d = J0<double>;
-struct NotAJet {};
-
-static_assert(IsJet_v<J0d>, "Jet is not identified as one");
-static_assert(IsJet_v<J0<NotAJet>>, "Jet is not identified as one");
-static_assert(IsJet_v<J0<J0d>>, "nested Jet is not identified as one");
-static_assert(IsJet_v<J0<J0<J0d>>>, "nested Jet is not identified as one");
-
-static_assert(!IsJet_v<double>, "double must not be a Jet");
-static_assert(!IsJet_v<Eigen::VectorXd>, "Eigen::VectorXd must not be a Jet");
-static_assert(!IsJet_v<decltype(std::declval<Eigen::MatrixXd>() *
- std::declval<Eigen::MatrixXd>())>,
- "product of Eigen::MatrixXd must not be a Jet");
-static_assert(!IsJet_v<NotAJet>, "NotAJet must not be a Jet");
-
// Extract the ranks of given types
using Ranks001 = Ranks_t<Jet<double, 0>, double, Jet<double, 1>>;
using Ranks1 = Ranks_t<Jet<double, 1>>;
diff --git a/internal/ceres/parameter_dims_test.cc b/internal/ceres/parameter_dims_test.cc
index ee3be8f..58d2500 100644
--- a/internal/ceres/parameter_dims_test.cc
+++ b/internal/ceres/parameter_dims_test.cc
@@ -32,20 +32,6 @@
namespace ceres {
namespace internal {
-// Is valid parameter dims unit test
-static_assert(IsValidParameterDimensionSequence(std::integer_sequence<int>()) ==
- true,
- "Unit test of is valid parameter dimension sequence failed.");
-static_assert(IsValidParameterDimensionSequence(
- std::integer_sequence<int, 2, 1>()) == true,
- "Unit test of is valid parameter dimension sequence failed.");
-static_assert(IsValidParameterDimensionSequence(
- std::integer_sequence<int, 0, 1>()) == false,
- "Unit test of is valid parameter dimension sequence failed.");
-static_assert(IsValidParameterDimensionSequence(
- std::integer_sequence<int, 3, 0>()) == false,
- "Unit test of is valid parameter dimension sequence failed.");
-
// Static parameter dims unit test
static_assert(
std::is_same<StaticParameterDims<4, 2, 1>::Parameters,
diff --git a/internal/ceres/program_test.cc b/internal/ceres/program_test.cc
index 8dc1377..300a3a5 100644
--- a/internal/ceres/program_test.cc
+++ b/internal/ceres/program_test.cc
@@ -70,7 +70,7 @@
bool Evaluate(double const* const* parameters,
double* residuals,
double** jacobians) const final {
- const int kNumParameters = Sum<std::integer_sequence<int, Ns...>>::Value;
+ constexpr int kNumParameters = (Ns + ... + 0);
for (int i = 0; i < kNumResiduals; ++i) {
residuals[i] = kNumResiduals + kNumParameters;