Simplify some template metaprograms using fold expressions. Change-Id: I865b670b99df30db39d33cbfe45170b70472e532
diff --git a/include/ceres/internal/integer_sequence_algorithm.h b/include/ceres/internal/integer_sequence_algorithm.h index 777c119..80c821c 100644 --- a/include/ceres/internal/integer_sequence_algorithm.h +++ b/include/ceres/internal/integer_sequence_algorithm.h
@@ -43,68 +43,6 @@ namespace ceres { namespace internal { -// Implementation of calculating the sum of an integer sequence. -// Recursively instantiate SumImpl and calculate the sum of the N first -// numbers. This reduces the number of instantiations and speeds up -// compilation. -// -// Examples: -// 1) integer_sequence<int, 5>: -// Value = 5 -// -// 2) integer_sequence<int, 4, 2>: -// Value = 4 + 2 + SumImpl<integer_sequence<int>>::Value -// Value = 4 + 2 + 0 -// -// 3) integer_sequence<int, 2, 1, 4>: -// Value = 2 + 1 + SumImpl<integer_sequence<int, 4>>::Value -// Value = 2 + 1 + 4 -template <typename Seq> -struct SumImpl; - -// Strip of and sum the first number. -template <typename T, T N, T... Ns> -struct SumImpl<std::integer_sequence<T, N, Ns...>> { - static constexpr T Value = - N + SumImpl<std::integer_sequence<T, Ns...>>::Value; -}; - -// Strip of and sum the first two numbers. -template <typename T, T N1, T N2, T... Ns> -struct SumImpl<std::integer_sequence<T, N1, N2, Ns...>> { - static constexpr T Value = - N1 + N2 + SumImpl<std::integer_sequence<T, Ns...>>::Value; -}; - -// Strip of and sum the first four numbers. -template <typename T, T N1, T N2, T N3, T N4, T... Ns> -struct SumImpl<std::integer_sequence<T, N1, N2, N3, N4, Ns...>> { - static constexpr T Value = - N1 + N2 + N3 + N4 + SumImpl<std::integer_sequence<T, Ns...>>::Value; -}; - -// Only one number is left. 'Value' is just that number ('recursion' ends). -template <typename T, T N> -struct SumImpl<std::integer_sequence<T, N>> { - static constexpr T Value = N; -}; - -// No number is left. 'Value' is the identity element (for sum this is zero). -template <typename T> -struct SumImpl<std::integer_sequence<T>> { - static constexpr T Value = T(0); -}; - -// Calculate the sum of an integer sequence. The resulting sum will be stored in -// 'Value'. -template <typename Seq> -class Sum { - using T = typename Seq::value_type; - - public: - static constexpr T Value = SumImpl<Seq>::Value; -}; - // Implementation of calculating an exclusive scan (exclusive prefix sum) of an // integer sequence. Exclusive means that the i-th input element is not included // in the i-th sum. Calculating the exclusive scan for an input array I results @@ -232,40 +170,11 @@ template <typename Sequence, typename Sequence::value_type ValueToRemove> using RemoveValue_t = typename RemoveValue<Sequence, ValueToRemove>::type; -// Determines whether the values of an integer sequence are all the same. +// Returns true if all elements of Values are equal to HeadValue. // -// The integer sequence must contain at least one value. The predicate is -// undefined for empty sequences. The evaluation result of the predicate for a -// sequence containing only one value is defined to be true. -template <typename... Sequence> -struct AreAllEqual; - -// The predicate result for a sequence containing one element is defined to be -// true. -template <typename T, T Value> -struct AreAllEqual<std::integer_sequence<T, Value>> : std::true_type {}; - -// Recursion end. -template <typename T, T Value1, T Value2> -struct AreAllEqual<std::integer_sequence<T, Value1, Value2>> - : std::integral_constant<bool, Value1 == Value2> {}; - -// Recursion for sequences containing at least two elements. -template <typename T, T Value1, T Value2, T... Values> -// clang-format off -struct AreAllEqual<std::integer_sequence<T, Value1, Value2, Values...> > - : std::integral_constant -< - bool, - AreAllEqual<std::integer_sequence<T, Value1, Value2> >::value && - AreAllEqual<std::integer_sequence<T, Value2, Values...> >::value -> -// clang-format on -{}; - -// Convenience variable template for AreAllEqual. -template <class Sequence> -constexpr bool AreAllEqual_v = AreAllEqual<Sequence>::value; +// Returns true if Values is empty. +template <typename T, T HeadValue, T... Values> +inline constexpr bool AreAllEqual_v = ((HeadValue == Values) && ...); // Predicate determining whether an integer sequence is either empty or all // values are equal. @@ -279,11 +188,12 @@ // General case for sequences containing at least one value. template <typename T, T HeadValue, T... Values> struct IsEmptyOrAreAllEqual<std::integer_sequence<T, HeadValue, Values...>> - : AreAllEqual<std::integer_sequence<T, HeadValue, Values...>> {}; + : std::integral_constant<bool, AreAllEqual_v<T, HeadValue, Values...>> {}; // Convenience variable template for IsEmptyOrAreAllEqual. template <class Sequence> -constexpr bool IsEmptyOrAreAllEqual_v = IsEmptyOrAreAllEqual<Sequence>::value; +inline constexpr bool IsEmptyOrAreAllEqual_v = + IsEmptyOrAreAllEqual<Sequence>::value; } // namespace internal } // namespace ceres
diff --git a/include/ceres/internal/jet_traits.h b/include/ceres/internal/jet_traits.h index 2a38c05..746638f 100644 --- a/include/ceres/internal/jet_traits.h +++ b/include/ceres/internal/jet_traits.h
@@ -42,17 +42,6 @@ namespace ceres { namespace internal { -// Predicate that determines whether T is a Jet. -template <typename T, typename E = void> -struct IsJet : std::false_type {}; - -template <typename T, int N> -struct IsJet<Jet<T, N>> : std::true_type {}; - -// Convenience variable template for IsJet. -template <typename T> -constexpr bool IsJet_v = IsJet<T>::value; - // Predicate that determines whether any of the Types is a Jet. template <typename... Types> struct AreAnyJet : std::false_type {}; @@ -65,7 +54,7 @@ // Convenience variable template for AreAnyJet. template <typename... Types> -constexpr bool AreAnyJet_v = AreAnyJet<Types...>::value; +inline constexpr bool AreAnyJet_v = AreAnyJet<Types...>::value; // Extracts the underlying floating-point from a type T. template <typename T, typename E = void> @@ -84,27 +73,8 @@ // // Specifically, the predicate applies std::is_same recursively to pairs of // Types in the pack. -// -// The predicate is defined only for template packs containing at least two -// types. -template <typename T1, typename T2, typename... Types> -// clang-format off -struct AreAllSame : std::integral_constant -< - bool, - AreAllSame<T1, T2>::value && - AreAllSame<T2, Types...>::value -> -// clang-format on -{}; - -// AreAllSame pairwise test. -template <typename T1, typename T2> -struct AreAllSame<T1, T2> : std::is_same<T1, T2> {}; - -// Convenience variable template for AreAllSame. -template <typename... Types> -constexpr bool AreAllSame_v = AreAllSame<Types...>::value; +template <typename T1, typename... Types> +inline constexpr bool AreAllSame_v = (std::is_same<T1, Types>::value && ...); // Determines the rank of a type. This allows to ensure that types passed as // arguments are compatible to each other. The rank of Jet is determined by the @@ -124,7 +94,7 @@ // Convenience variable template for Rank. template <typename T> -constexpr int Rank_v = Rank<T>::value; +inline constexpr int Rank_v = Rank<T>::value; // Constructs an integer sequence of ranks for each of the Types in the pack. template <typename... Types> @@ -186,7 +156,8 @@ // This trait is a candidate for a concept definition once C++20 features can // be used. template <typename... Types> -constexpr bool CompatibleJetOperands_v = CompatibleJetOperands<Types...>::value; +inline constexpr bool CompatibleJetOperands_v = + CompatibleJetOperands<Types...>::value; // Type trait ensuring at least one of the types is a Jet, // the underlying scalar types are compatible among each other and Jet @@ -216,7 +187,8 @@ // This trait is a candidate for a concept definition once C++20 features can // be used. template <typename... Types> -constexpr bool PromotableJetOperands_v = PromotableJetOperands<Types...>::value; +inline constexpr bool PromotableJetOperands_v = + PromotableJetOperands<Types...>::value; } // namespace ceres
diff --git a/include/ceres/internal/parameter_dims.h b/include/ceres/internal/parameter_dims.h index 2402106..efe2df4 100644 --- a/include/ceres/internal/parameter_dims.h +++ b/include/ceres/internal/parameter_dims.h
@@ -39,20 +39,6 @@ namespace ceres { namespace internal { -// Checks, whether the given parameter block sizes are valid. Valid means every -// dimension is bigger than zero. -constexpr bool IsValidParameterDimensionSequence(std::integer_sequence<int>) { - return true; -} - -template <int N, int... Ts> -constexpr bool IsValidParameterDimensionSequence( - std::integer_sequence<int, N, Ts...>) { - return (N <= 0) ? false - : IsValidParameterDimensionSequence( - std::integer_sequence<int, Ts...>()); -} - // Helper class that represents the parameter dimensions. The parameter // dimensions are either dynamic or the sizes are known at compile time. It is // used to pass parameter block dimensions around (e.g. between functions or @@ -70,8 +56,7 @@ // The parameter dimensions are only valid if all parameter block dimensions // are greater than zero. - static constexpr bool kIsValid = - IsValidParameterDimensionSequence(Parameters()); + static constexpr bool kIsValid = ((Ns > 0) && ...); static_assert(kIsValid, "Invalid parameter block dimension detected. Each parameter " "block dimension must be bigger than zero."); @@ -81,8 +66,7 @@ static_assert(kIsDynamic || kNumParameterBlocks > 0, "At least one parameter block must be specified."); - static constexpr int kNumParameters = - Sum<std::integer_sequence<int, Ns...>>::Value; + static constexpr int kNumParameters = (Ns + ... + 0); static constexpr int GetDim(int dim) { return params_[dim]; }
diff --git a/internal/ceres/integer_sequence_algorithm_test.cc b/internal/ceres/integer_sequence_algorithm_test.cc index 7e04148..4622fea 100644 --- a/internal/ceres/integer_sequence_algorithm_test.cc +++ b/internal/ceres/integer_sequence_algorithm_test.cc
@@ -39,20 +39,6 @@ namespace ceres { namespace internal { -// Unit tests for summation of integer sequence. -static_assert(Sum<std::integer_sequence<int>>::Value == 0, - "Unit test of summing up an integer sequence failed."); -static_assert(Sum<std::integer_sequence<int, 2>>::Value == 2, - "Unit test of summing up an integer sequence failed."); -static_assert(Sum<std::integer_sequence<int, 2, 3>>::Value == 5, - "Unit test of summing up an integer sequence failed."); -static_assert(Sum<std::integer_sequence<int, 2, 3, 10>>::Value == 15, - "Unit test of summing up an integer sequence failed."); -static_assert(Sum<std::integer_sequence<int, 2, 3, 10, 4>>::Value == 19, - "Unit test of summing up an integer sequence failed."); -static_assert(Sum<std::integer_sequence<int, 2, 3, 10, 4, 1>>::Value == 20, - "Unit test of summing up an integer sequence failed."); - // Unit tests for exclusive scan of integer sequence. static_assert(std::is_same<ExclusiveScan<std::integer_sequence<int>>, std::integer_sequence<int>>::value, @@ -129,15 +115,15 @@ static_assert(!AreAllSame_v<int, short, char>, "types must not be the same"); // Ensure all values in the integer sequence match -static_assert(AreAllEqual_v<std::integer_sequence<int, 1, 1>>, +static_assert(AreAllEqual_v<int, 1, 1>, "integer sequence must contain same values"); -static_assert(AreAllEqual_v<std::integer_sequence<long, 2>>, +static_assert(AreAllEqual_v<long, 2>, "integer sequence must contain one value"); -static_assert(!AreAllEqual_v<std::integer_sequence<short, 3, 4>>, +static_assert(!AreAllEqual_v<short, 3, 4>, "integer sequence must not contain the same values"); -static_assert(!AreAllEqual_v<std::integer_sequence<unsigned, 3, 4, 3>>, +static_assert(!AreAllEqual_v<unsigned, 3, 4, 3>, "integer sequence must not contain the same values"); -static_assert(!AreAllEqual_v<std::integer_sequence<int, 4, 4, 3>>, +static_assert(!AreAllEqual_v<int, 4, 4, 3>, "integer sequence must not contain the same values"); static_assert(IsEmptyOrAreAllEqual_v<std::integer_sequence<short>>,
diff --git a/internal/ceres/jet_traits_test.cc b/internal/ceres/jet_traits_test.cc index ee38f47..43afc3c 100644 --- a/internal/ceres/jet_traits_test.cc +++ b/internal/ceres/jet_traits_test.cc
@@ -44,20 +44,6 @@ using J0 = Jet<T, 0>; using J0d = J0<double>; -struct NotAJet {}; - -static_assert(IsJet_v<J0d>, "Jet is not identified as one"); -static_assert(IsJet_v<J0<NotAJet>>, "Jet is not identified as one"); -static_assert(IsJet_v<J0<J0d>>, "nested Jet is not identified as one"); -static_assert(IsJet_v<J0<J0<J0d>>>, "nested Jet is not identified as one"); - -static_assert(!IsJet_v<double>, "double must not be a Jet"); -static_assert(!IsJet_v<Eigen::VectorXd>, "Eigen::VectorXd must not be a Jet"); -static_assert(!IsJet_v<decltype(std::declval<Eigen::MatrixXd>() * - std::declval<Eigen::MatrixXd>())>, - "product of Eigen::MatrixXd must not be a Jet"); -static_assert(!IsJet_v<NotAJet>, "NotAJet must not be a Jet"); - // Extract the ranks of given types using Ranks001 = Ranks_t<Jet<double, 0>, double, Jet<double, 1>>; using Ranks1 = Ranks_t<Jet<double, 1>>;
diff --git a/internal/ceres/parameter_dims_test.cc b/internal/ceres/parameter_dims_test.cc index ee3be8f..58d2500 100644 --- a/internal/ceres/parameter_dims_test.cc +++ b/internal/ceres/parameter_dims_test.cc
@@ -32,20 +32,6 @@ namespace ceres { namespace internal { -// Is valid parameter dims unit test -static_assert(IsValidParameterDimensionSequence(std::integer_sequence<int>()) == - true, - "Unit test of is valid parameter dimension sequence failed."); -static_assert(IsValidParameterDimensionSequence( - std::integer_sequence<int, 2, 1>()) == true, - "Unit test of is valid parameter dimension sequence failed."); -static_assert(IsValidParameterDimensionSequence( - std::integer_sequence<int, 0, 1>()) == false, - "Unit test of is valid parameter dimension sequence failed."); -static_assert(IsValidParameterDimensionSequence( - std::integer_sequence<int, 3, 0>()) == false, - "Unit test of is valid parameter dimension sequence failed."); - // Static parameter dims unit test static_assert( std::is_same<StaticParameterDims<4, 2, 1>::Parameters,
diff --git a/internal/ceres/program_test.cc b/internal/ceres/program_test.cc index 8dc1377..300a3a5 100644 --- a/internal/ceres/program_test.cc +++ b/internal/ceres/program_test.cc
@@ -70,7 +70,7 @@ bool Evaluate(double const* const* parameters, double* residuals, double** jacobians) const final { - const int kNumParameters = Sum<std::integer_sequence<int, Ns...>>::Value; + constexpr int kNumParameters = (Ns + ... + 0); for (int i = 0; i < kNumResiduals; ++i) { residuals[i] = kNumResiduals + kNumParameters;