Add support for up to 10 parameter blocks in autodiff. Six parameter blocks were enough for most cost functions, but not for all of them. This extends the autodiff implementation to work with up to 10 parameter blocks. It also raises the number of parameter blocks supported by SizedCostFunction to 10. Change-Id: Ic783602f93e6ddf4af24fa34eff37c0a4b775dc1
diff --git a/include/ceres/autodiff_cost_function.h b/include/ceres/autodiff_cost_function.h index da9ee2c..bb08d64 100644 --- a/include/ceres/autodiff_cost_function.h +++ b/include/ceres/autodiff_cost_function.h
@@ -154,16 +154,20 @@ int N2 = 0, // Number of parameters in block 2. int N3 = 0, // Number of parameters in block 3. int N4 = 0, // Number of parameters in block 4. - int N5 = 0> // Number of parameters in block 5. + int N5 = 0, // Number of parameters in block 5. + int N6 = 0, // Number of parameters in block 6. + int N7 = 0, // Number of parameters in block 7. + int N8 = 0, // Number of parameters in block 8. + int N9 = 0> // Number of parameters in block 9. class AutoDiffCostFunction : - public SizedCostFunction<M, N0, N1, N2, N3, N4, N5> { + public SizedCostFunction<M, N0, N1, N2, N3, N4, N5, N6, N7, N8, N9> { public: // Takes ownership of functor. Uses the template-provided value for the // number of residuals ("M"). explicit AutoDiffCostFunction(CostFunctor* functor) : functor_(functor) { CHECK_NE(M, DYNAMIC) << "Can't run the fixed-size constructor if the " - << "number of residuals is set to ceres::DYNAMIC."; + << "number of residuals is set to ceres::DYNAMIC."; } // Takes ownership of functor. 
Ignores the template-provided number of @@ -174,8 +178,9 @@ AutoDiffCostFunction(CostFunctor* functor, int num_residuals) : functor_(functor) { CHECK_EQ(M, DYNAMIC) << "Can't run the dynamic-size constructor if the " - << "number of residuals is not ceres::DYNAMIC."; - SizedCostFunction<M, N0, N1, N2, N3, N4, N5>::set_num_residuals(num_residuals); + << "number of residuals is not ceres::DYNAMIC."; + SizedCostFunction<M, N0, N1, N2, N3, N4, N5, N6, N7, N8, N9> + ::set_num_residuals(num_residuals); } virtual ~AutoDiffCostFunction() {} @@ -190,14 +195,15 @@ double** jacobians) const { if (!jacobians) { return internal::VariadicEvaluate< - CostFunctor, double, N0, N1, N2, N3, N4, N5> + CostFunctor, double, N0, N1, N2, N3, N4, N5, N6, N7, N8, N9> ::Call(*functor_, parameters, residuals); } return internal::AutoDiff<CostFunctor, double, - N0, N1, N2, N3, N4, N5>::Differentiate( + N0, N1, N2, N3, N4, N5, N6, N7, N8, N9>::Differentiate( *functor_, parameters, - SizedCostFunction<M, N0, N1, N2, N3, N4, N5>::num_residuals(), + SizedCostFunction<M, N0, N1, N2, N3, N4, N5, N6, N7, N8, N9> + ::num_residuals(), residuals, jacobians); }
diff --git a/include/ceres/internal/autodiff.h b/include/ceres/internal/autodiff.h index 4f5081f..581e881 100644 --- a/include/ceres/internal/autodiff.h +++ b/include/ceres/internal/autodiff.h
@@ -203,8 +203,8 @@ // Supporting variadic functions is the primary source of complexity in the // autodiff implementation. -template<typename Functor, typename T, - int N0, int N1, int N2, int N3, int N4, int N5> +template<typename Functor, typename T, int N0, int N1, int N2, int N3, int N4, + int N5, int N6, int N7, int N8, int N9> struct VariadicEvaluate { static bool Call(const Functor& functor, T const *const *input, T* output) { return functor(input[0], @@ -213,13 +213,78 @@ input[3], input[4], input[5], + input[6], + input[7], + input[8], + input[9], + output); + } +}; + +template<typename Functor, typename T, int N0, int N1, int N2, int N3, int N4, + int N5, int N6, int N7, int N8> +struct VariadicEvaluate<Functor, T, N0, N1, N2, N3, N4, N5, N6, N7, N8, 0> { + static bool Call(const Functor& functor, T const *const *input, T* output) { + return functor(input[0], + input[1], + input[2], + input[3], + input[4], + input[5], + input[6], + input[7], + input[8], + output); + } +}; + +template<typename Functor, typename T, int N0, int N1, int N2, int N3, int N4, + int N5, int N6, int N7> +struct VariadicEvaluate<Functor, T, N0, N1, N2, N3, N4, N5, N6, N7, 0, 0> { + static bool Call(const Functor& functor, T const *const *input, T* output) { + return functor(input[0], + input[1], + input[2], + input[3], + input[4], + input[5], + input[6], + input[7], + output); + } +}; + +template<typename Functor, typename T, int N0, int N1, int N2, int N3, int N4, + int N5, int N6> +struct VariadicEvaluate<Functor, T, N0, N1, N2, N3, N4, N5, N6, 0, 0, 0> { + static bool Call(const Functor& functor, T const *const *input, T* output) { + return functor(input[0], + input[1], + input[2], + input[3], + input[4], + input[5], + input[6], + output); + } +}; + +template<typename Functor, typename T, int N0, int N1, int N2, int N3, int N4, + int N5> +struct VariadicEvaluate<Functor, T, N0, N1, N2, N3, N4, N5, 0, 0, 0, 0> { + static bool Call(const Functor& functor, T const *const *input, 
T* output) { + return functor(input[0], + input[1], + input[2], + input[3], + input[4], + input[5], output); } }; -template<typename Functor, typename T, - int N0, int N1, int N2, int N3, int N4> -struct VariadicEvaluate<Functor, T, N0, N1, N2, N3, N4, 0> { +template<typename Functor, typename T, int N0, int N1, int N2, int N3, int N4> +struct VariadicEvaluate<Functor, T, N0, N1, N2, N3, N4, 0, 0, 0, 0, 0> { static bool Call(const Functor& functor, T const *const *input, T* output) { return functor(input[0], input[1], @@ -230,9 +295,8 @@ } }; -template<typename Functor, typename T, - int N0, int N1, int N2, int N3> -struct VariadicEvaluate<Functor, T, N0, N1, N2, N3, 0, 0> { +template<typename Functor, typename T, int N0, int N1, int N2, int N3> +struct VariadicEvaluate<Functor, T, N0, N1, N2, N3, 0, 0, 0, 0, 0, 0> { static bool Call(const Functor& functor, T const *const *input, T* output) { return functor(input[0], input[1], @@ -242,9 +306,8 @@ } }; -template<typename Functor, typename T, - int N0, int N1, int N2> -struct VariadicEvaluate<Functor, T, N0, N1, N2, 0, 0, 0> { +template<typename Functor, typename T, int N0, int N1, int N2> +struct VariadicEvaluate<Functor, T, N0, N1, N2, 0, 0, 0, 0, 0, 0, 0> { static bool Call(const Functor& functor, T const *const *input, T* output) { return functor(input[0], input[1], @@ -253,9 +316,8 @@ } }; -template<typename Functor, typename T, - int N0, int N1> -struct VariadicEvaluate<Functor, T, N0, N1, 0, 0, 0, 0> { +template<typename Functor, typename T, int N0, int N1> +struct VariadicEvaluate<Functor, T, N0, N1, 0, 0, 0, 0, 0, 0, 0, 0> { static bool Call(const Functor& functor, T const *const *input, T* output) { return functor(input[0], input[1], @@ -264,7 +326,7 @@ }; template<typename Functor, typename T, int N0> -struct VariadicEvaluate<Functor, T, N0, 0, 0, 0, 0, 0> { +struct VariadicEvaluate<Functor, T, N0, 0, 0, 0, 0, 0, 0, 0, 0, 0> { static bool Call(const Functor& functor, T const *const *input, T* output) { 
return functor(input[0], output); @@ -275,48 +337,58 @@ - // supported in C++03 (though it is available in C++0x). N0 through N5 are the + // supported in C++03 (though it is available in C++0x). N0 through N9 are the // dimension of the input arguments to the user supplied functor. template <typename Functor, typename T, - int N0 = 0, int N1 = 0, int N2 = 0, int N3 = 0, int N4 = 0, int N5=0> + int N0 = 0, int N1 = 0, int N2 = 0, int N3 = 0, int N4 = 0, + int N5 = 0, int N6 = 0, int N7 = 0, int N8 = 0, int N9 = 0> struct AutoDiff { static bool Differentiate(const Functor& functor, T const *const *parameters, int num_outputs, T *function_value, T **jacobians) { - typedef Jet<T, N0 + N1 + N2 + N3 + N4 + N5> JetT; - - DCHECK_GT(N0, 0) - << "Cost functions must have at least one parameter block."; - DCHECK((!N1 && !N2 && !N3 && !N4 && !N5) || - ((N1 > 0) && !N2 && !N3 && !N4 && !N5) || - ((N1 > 0) && (N2 > 0) && !N3 && !N4 && !N5) || - ((N1 > 0) && (N2 > 0) && (N3 > 0) && !N4 && !N5) || - ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && !N5) || - ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0))) + // This block breaks the 80 column rule to keep it somewhat readable. 
+ DCHECK_GT(num_outputs, 0); + CHECK((!N1 && !N2 && !N3 && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) || + ((N1 > 0) && !N2 && !N3 && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) || + ((N1 > 0) && (N2 > 0) && !N3 && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) || + ((N1 > 0) && (N2 > 0) && (N3 > 0) && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) || + ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && !N5 && !N6 && !N7 && !N8 && !N9) || + ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && !N6 && !N7 && !N8 && !N9) || + ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && (N6 > 0) && !N7 && !N8 && !N9) || + ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && (N6 > 0) && (N7 > 0) && !N8 && !N9) || + ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && (N6 > 0) && (N7 > 0) && (N8 > 0) && !N9) || + ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && (N6 > 0) && (N7 > 0) && (N8 > 0) && (N9 > 0))) << "Zero block cannot precede a non-zero block. Block sizes are " << "(ignore trailing 0s): " << N0 << ", " << N1 << ", " << N2 << ", " - << N3 << ", " << N4 << ", " << N5; + << N3 << ", " << N4 << ", " << N5 << ", " << N6 << ", " << N7 << ", " + << N8 << ", " << N9; - DCHECK_GT(num_outputs, 0); - + typedef Jet<T, N0 + N1 + N2 + N3 + N4 + N5 + N6 + N7 + N8 + N9> JetT; FixedArray<JetT, (256 * 7) / sizeof(JetT)> x( - N0 + N1 + N2 + N3 + N4 + N5 + num_outputs); + N0 + N1 + N2 + N3 + N4 + N5 + N6 + N7 + N8 + N9 + num_outputs); - // It's ugly, but it works. - const int jet0 = 0; - const int jet1 = N0; - const int jet2 = N0 + N1; - const int jet3 = N0 + N1 + N2; - const int jet4 = N0 + N1 + N2 + N3; - const int jet5 = N0 + N1 + N2 + N3 + N4; - const int jet6 = N0 + N1 + N2 + N3 + N4 + N5; + // These are the positions of the respective jets in the fixed array x. 
+ const int jet0 = 0; + const int jet1 = N0; + const int jet2 = N0 + N1; + const int jet3 = N0 + N1 + N2; + const int jet4 = N0 + N1 + N2 + N3; + const int jet5 = N0 + N1 + N2 + N3 + N4; + const int jet6 = N0 + N1 + N2 + N3 + N4 + N5; + const int jet7 = N0 + N1 + N2 + N3 + N4 + N5 + N6; + const int jet8 = N0 + N1 + N2 + N3 + N4 + N5 + N6 + N7; + const int jet9 = N0 + N1 + N2 + N3 + N4 + N5 + N6 + N7 + N8; - const JetT *unpacked_parameters[6] = { + const JetT *unpacked_parameters[10] = { x.get() + jet0, x.get() + jet1, x.get() + jet2, x.get() + jet3, x.get() + jet4, x.get() + jet5, + x.get() + jet6, + x.get() + jet7, + x.get() + jet8, + x.get() + jet9, }; - JetT *output = x.get() + jet6; + JetT *output = x.get() + N0 + N1 + N2 + N3 + N4 + N5 + N6 + N7 + N8 + N9; @@ -333,10 +405,14 @@ CERES_MAKE_1ST_ORDER_PERTURBATION(3); CERES_MAKE_1ST_ORDER_PERTURBATION(4); CERES_MAKE_1ST_ORDER_PERTURBATION(5); + CERES_MAKE_1ST_ORDER_PERTURBATION(6); + CERES_MAKE_1ST_ORDER_PERTURBATION(7); + CERES_MAKE_1ST_ORDER_PERTURBATION(8); + CERES_MAKE_1ST_ORDER_PERTURBATION(9); #undef CERES_MAKE_1ST_ORDER_PERTURBATION if (!VariadicEvaluate<Functor, JetT, - N0, N1, N2, N3, N4, N5>::Call( + N0, N1, N2, N3, N4, N5, N6, N7, N8, N9>::Call( functor, unpacked_parameters, output)) { return false; } @@ -359,6 +435,10 @@ CERES_TAKE_1ST_ORDER_PERTURBATION(3); CERES_TAKE_1ST_ORDER_PERTURBATION(4); CERES_TAKE_1ST_ORDER_PERTURBATION(5); + CERES_TAKE_1ST_ORDER_PERTURBATION(6); + CERES_TAKE_1ST_ORDER_PERTURBATION(7); + CERES_TAKE_1ST_ORDER_PERTURBATION(8); + CERES_TAKE_1ST_ORDER_PERTURBATION(9); #undef CERES_TAKE_1ST_ORDER_PERTURBATION return true; }
diff --git a/include/ceres/sized_cost_function.h b/include/ceres/sized_cost_function.h index 2894a9f..6bfc1af 100644 --- a/include/ceres/sized_cost_function.h +++ b/include/ceres/sized_cost_function.h
@@ -45,25 +45,29 @@ namespace ceres { template<int kNumResiduals, - int N0 = 0, int N1 = 0, int N2 = 0, int N3 = 0, int N4 = 0, int N5 = 0> + int N0 = 0, int N1 = 0, int N2 = 0, int N3 = 0, int N4 = 0, + int N5 = 0, int N6 = 0, int N7 = 0, int N8 = 0, int N9 = 0> class SizedCostFunction : public CostFunction { public: SizedCostFunction() { - // Sanity checking. CHECK(kNumResiduals > 0 || kNumResiduals == DYNAMIC) << "Cost functions must have at least one residual block."; - CHECK_GT(N0, 0) - << "Cost functions must have at least one parameter block."; - CHECK((!N1 && !N2 && !N3 && !N4 && !N5) || - ((N1 > 0) && !N2 && !N3 && !N4 && !N5) || - ((N1 > 0) && (N2 > 0) && !N3 && !N4 && !N5) || - ((N1 > 0) && (N2 > 0) && (N3 > 0) && !N4 && !N5) || - ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && !N5) || - ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0))) + // This block breaks the 80 column rule to keep it somewhat readable. + CHECK((!N1 && !N2 && !N3 && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) || + ((N1 > 0) && !N2 && !N3 && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) || + ((N1 > 0) && (N2 > 0) && !N3 && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) || + ((N1 > 0) && (N2 > 0) && (N3 > 0) && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) || + ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && !N5 && !N6 && !N7 && !N8 && !N9) || + ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && !N6 && !N7 && !N8 && !N9) || + ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && (N6 > 0) && !N7 && !N8 && !N9) || + ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && (N6 > 0) && (N7 > 0) && !N8 && !N9) || + ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && (N6 > 0) && (N7 > 0) && (N8 > 0) && !N9) || + ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && (N6 > 0) && (N7 > 0) && (N8 > 0) && (N9 > 0))) << "Zero block cannot precede a non-zero block. 
Block sizes are " << "(ignore trailing 0s): " << N0 << ", " << N1 << ", " << N2 << ", " - << N3 << ", " << N4 << ", " << N5; + << N3 << ", " << N4 << ", " << N5 << ", " << N6 << ", " << N7 << ", " + << N8 << ", " << N9; set_num_residuals(kNumResiduals); @@ -75,6 +79,10 @@ ADD_PARAMETER_BLOCK(N3); ADD_PARAMETER_BLOCK(N4); ADD_PARAMETER_BLOCK(N5); + ADD_PARAMETER_BLOCK(N6); + ADD_PARAMETER_BLOCK(N7); + ADD_PARAMETER_BLOCK(N8); + ADD_PARAMETER_BLOCK(N9); #undef ADD_PARAMETER_BLOCK }
diff --git a/internal/ceres/autodiff_cost_function_test.cc b/internal/ceres/autodiff_cost_function_test.cc index 33e576f..e98397a 100644 --- a/internal/ceres/autodiff_cost_function_test.cc +++ b/internal/ceres/autodiff_cost_function_test.cc
@@ -51,7 +51,7 @@ double a_; }; -TEST(AutoDiffResidualAndJacobian, BilinearDifferentiationTest) { +TEST(AutodiffCostFunction, BilinearDifferentiationTest) { CostFunction* cost_function = new AutoDiffCostFunction<BinaryScalarCost, 1, 2, 2>( new BinaryScalarCost(1.0)); @@ -73,20 +73,72 @@ double residuals = 0.0; cost_function->Evaluate(parameters, &residuals, NULL); - EXPECT_EQ(residuals, 10); + EXPECT_EQ(10.0, residuals); cost_function->Evaluate(parameters, &residuals, jacobians); - EXPECT_EQ(jacobians[0][0], 3); - EXPECT_EQ(jacobians[0][1], 4); - EXPECT_EQ(jacobians[1][0], 1); - EXPECT_EQ(jacobians[1][1], 2); + EXPECT_EQ(3, jacobians[0][0]); + EXPECT_EQ(4, jacobians[0][1]); + EXPECT_EQ(1, jacobians[1][0]); + EXPECT_EQ(2, jacobians[1][1]); - delete []jacobians[0]; - delete []jacobians[1]; - delete []parameters[0]; - delete []parameters[1]; - delete []jacobians; - delete []parameters; + delete[] jacobians[0]; + delete[] jacobians[1]; + delete[] parameters[0]; + delete[] parameters[1]; + delete[] jacobians; + delete[] parameters; + delete cost_function; +} + +struct TenParameterCost { + template <typename T> + bool operator()(const T* const x0, + const T* const x1, + const T* const x2, + const T* const x3, + const T* const x4, + const T* const x5, + const T* const x6, + const T* const x7, + const T* const x8, + const T* const x9, + T* cost) const { + cost[0] = *x0 + *x1 + *x2 + *x3 + *x4 + *x5 + *x6 + *x7 + *x8 + *x9; + return true; + } +}; + +TEST(AutodiffCostFunction, ManyParameterAutodiffInstantiates) { + CostFunction* cost_function = + new AutoDiffCostFunction< + TenParameterCost, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>( + new TenParameterCost); + + double** parameters = new double*[10]; + double** jacobians = new double*[10]; + for (int i = 0; i < 10; ++i) { + parameters[i] = new double[1]; + parameters[i][0] = i; + jacobians[i] = new double[1]; + } + + double residuals = 0.0; + + cost_function->Evaluate(parameters, &residuals, NULL); + EXPECT_EQ(45.0, residuals); + + 
cost_function->Evaluate(parameters, &residuals, jacobians); + EXPECT_EQ(45.0, residuals); + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(1.0, jacobians[i][0]); + } + + for (int i = 0; i < 10; ++i) { + delete[] jacobians[i]; + delete[] parameters[i]; + } + delete[] jacobians; + delete[] parameters; delete cost_function; }