Add support for up to 10 parameters in autodiff.
Supporting only 6 parameters in autodiff was enough for most
cases, but 6 was not always sufficient. This extends the
current implementation to work with up to 10 parameters.
This also increases the number of parameters supported in
SizedCostFunction to 10.
Change-Id: Ic783602f93e6ddf4af24fa34eff37c0a4b775dc1
diff --git a/include/ceres/autodiff_cost_function.h b/include/ceres/autodiff_cost_function.h
index da9ee2c..bb08d64 100644
--- a/include/ceres/autodiff_cost_function.h
+++ b/include/ceres/autodiff_cost_function.h
@@ -154,16 +154,20 @@
int N2 = 0, // Number of parameters in block 2.
int N3 = 0, // Number of parameters in block 3.
int N4 = 0, // Number of parameters in block 4.
- int N5 = 0> // Number of parameters in block 5.
+ int N5 = 0, // Number of parameters in block 5.
+ int N6 = 0, // Number of parameters in block 6.
+ int N7 = 0, // Number of parameters in block 7.
+ int N8 = 0, // Number of parameters in block 8.
+ int N9 = 0> // Number of parameters in block 9.
class AutoDiffCostFunction :
- public SizedCostFunction<M, N0, N1, N2, N3, N4, N5> {
+ public SizedCostFunction<M, N0, N1, N2, N3, N4, N5, N6, N7, N8, N9> {
public:
// Takes ownership of functor. Uses the template-provided value for the
// number of residuals ("M").
explicit AutoDiffCostFunction(CostFunctor* functor)
: functor_(functor) {
CHECK_NE(M, DYNAMIC) << "Can't run the fixed-size constructor if the "
- << "number of residuals is set to ceres::DYNAMIC.";
+ << "number of residuals is set to ceres::DYNAMIC.";
}
// Takes ownership of functor. Ignores the template-provided number of
@@ -174,8 +178,9 @@
AutoDiffCostFunction(CostFunctor* functor, int num_residuals)
: functor_(functor) {
CHECK_EQ(M, DYNAMIC) << "Can't run the dynamic-size constructor if the "
- << "number of residuals is not ceres::DYNAMIC.";
- SizedCostFunction<M, N0, N1, N2, N3, N4, N5>::set_num_residuals(num_residuals);
+ << "number of residuals is not ceres::DYNAMIC.";
+ SizedCostFunction<M, N0, N1, N2, N3, N4, N5, N6, N7, N8, N9>
+ ::set_num_residuals(num_residuals);
}
virtual ~AutoDiffCostFunction() {}
@@ -190,14 +195,15 @@
double** jacobians) const {
if (!jacobians) {
return internal::VariadicEvaluate<
- CostFunctor, double, N0, N1, N2, N3, N4, N5>
+ CostFunctor, double, N0, N1, N2, N3, N4, N5, N6, N7, N8, N9>
::Call(*functor_, parameters, residuals);
}
return internal::AutoDiff<CostFunctor, double,
- N0, N1, N2, N3, N4, N5>::Differentiate(
+ N0, N1, N2, N3, N4, N5, N6, N7, N8, N9>::Differentiate(
*functor_,
parameters,
- SizedCostFunction<M, N0, N1, N2, N3, N4, N5>::num_residuals(),
+ SizedCostFunction<M, N0, N1, N2, N3, N4, N5, N6, N7, N8, N9>
+ ::num_residuals(),
residuals,
jacobians);
}
diff --git a/include/ceres/internal/autodiff.h b/include/ceres/internal/autodiff.h
index 4f5081f..581e881 100644
--- a/include/ceres/internal/autodiff.h
+++ b/include/ceres/internal/autodiff.h
@@ -203,8 +203,8 @@
// Supporting variadic functions is the primary source of complexity in the
// autodiff implementation.
-template<typename Functor, typename T,
- int N0, int N1, int N2, int N3, int N4, int N5>
+template<typename Functor, typename T, int N0, int N1, int N2, int N3, int N4,
+ int N5, int N6, int N7, int N8, int N9>
struct VariadicEvaluate {
static bool Call(const Functor& functor, T const *const *input, T* output) {
return functor(input[0],
@@ -213,13 +213,78 @@
input[3],
input[4],
input[5],
+ input[6],
+ input[7],
+ input[8],
+ input[9],
+ output);
+ }
+};
+
+template<typename Functor, typename T, int N0, int N1, int N2, int N3, int N4,
+ int N5, int N6, int N7, int N8>
+struct VariadicEvaluate<Functor, T, N0, N1, N2, N3, N4, N5, N6, N7, N8, 0> {
+ static bool Call(const Functor& functor, T const *const *input, T* output) {
+ return functor(input[0],
+ input[1],
+ input[2],
+ input[3],
+ input[4],
+ input[5],
+ input[6],
+ input[7],
+ input[8],
+ output);
+ }
+};
+
+template<typename Functor, typename T, int N0, int N1, int N2, int N3, int N4,
+ int N5, int N6, int N7>
+struct VariadicEvaluate<Functor, T, N0, N1, N2, N3, N4, N5, N6, N7, 0, 0> {
+ static bool Call(const Functor& functor, T const *const *input, T* output) {
+ return functor(input[0],
+ input[1],
+ input[2],
+ input[3],
+ input[4],
+ input[5],
+ input[6],
+ input[7],
+ output);
+ }
+};
+
+template<typename Functor, typename T, int N0, int N1, int N2, int N3, int N4,
+ int N5, int N6>
+struct VariadicEvaluate<Functor, T, N0, N1, N2, N3, N4, N5, N6, 0, 0, 0> {
+ static bool Call(const Functor& functor, T const *const *input, T* output) {
+ return functor(input[0],
+ input[1],
+ input[2],
+ input[3],
+ input[4],
+ input[5],
+ input[6],
+ output);
+ }
+};
+
+template<typename Functor, typename T, int N0, int N1, int N2, int N3, int N4,
+ int N5>
+struct VariadicEvaluate<Functor, T, N0, N1, N2, N3, N4, N5, 0, 0, 0, 0> {
+ static bool Call(const Functor& functor, T const *const *input, T* output) {
+ return functor(input[0],
+ input[1],
+ input[2],
+ input[3],
+ input[4],
+ input[5],
output);
}
};
-template<typename Functor, typename T,
- int N0, int N1, int N2, int N3, int N4>
-struct VariadicEvaluate<Functor, T, N0, N1, N2, N3, N4, 0> {
+template<typename Functor, typename T, int N0, int N1, int N2, int N3, int N4>
+struct VariadicEvaluate<Functor, T, N0, N1, N2, N3, N4, 0, 0, 0, 0, 0> {
static bool Call(const Functor& functor, T const *const *input, T* output) {
return functor(input[0],
input[1],
@@ -230,9 +295,8 @@
}
};
-template<typename Functor, typename T,
- int N0, int N1, int N2, int N3>
-struct VariadicEvaluate<Functor, T, N0, N1, N2, N3, 0, 0> {
+template<typename Functor, typename T, int N0, int N1, int N2, int N3>
+struct VariadicEvaluate<Functor, T, N0, N1, N2, N3, 0, 0, 0, 0, 0, 0> {
static bool Call(const Functor& functor, T const *const *input, T* output) {
return functor(input[0],
input[1],
@@ -242,9 +306,8 @@
}
};
-template<typename Functor, typename T,
- int N0, int N1, int N2>
-struct VariadicEvaluate<Functor, T, N0, N1, N2, 0, 0, 0> {
+template<typename Functor, typename T, int N0, int N1, int N2>
+struct VariadicEvaluate<Functor, T, N0, N1, N2, 0, 0, 0, 0, 0, 0, 0> {
static bool Call(const Functor& functor, T const *const *input, T* output) {
return functor(input[0],
input[1],
@@ -253,9 +316,8 @@
}
};
-template<typename Functor, typename T,
- int N0, int N1>
-struct VariadicEvaluate<Functor, T, N0, N1, 0, 0, 0, 0> {
+template<typename Functor, typename T, int N0, int N1>
+struct VariadicEvaluate<Functor, T, N0, N1, 0, 0, 0, 0, 0, 0, 0, 0> {
static bool Call(const Functor& functor, T const *const *input, T* output) {
return functor(input[0],
input[1],
@@ -264,7 +326,7 @@
};
template<typename Functor, typename T, int N0>
-struct VariadicEvaluate<Functor, T, N0, 0, 0, 0, 0, 0> {
+struct VariadicEvaluate<Functor, T, N0, 0, 0, 0, 0, 0, 0, 0, 0, 0> {
static bool Call(const Functor& functor, T const *const *input, T* output) {
return functor(input[0],
output);
@@ -275,48 +337,58 @@
// supported in C++03 (though it is available in C++0x). N0 through N5 are the
// dimension of the input arguments to the user supplied functor.
template <typename Functor, typename T,
- int N0 = 0, int N1 = 0, int N2 = 0, int N3 = 0, int N4 = 0, int N5=0>
+ int N0 = 0, int N1 = 0, int N2 = 0, int N3 = 0, int N4 = 0,
+ int N5 = 0, int N6 = 0, int N7 = 0, int N8 = 0, int N9 = 0>
struct AutoDiff {
static bool Differentiate(const Functor& functor,
T const *const *parameters,
int num_outputs,
T *function_value,
T **jacobians) {
- typedef Jet<T, N0 + N1 + N2 + N3 + N4 + N5> JetT;
-
- DCHECK_GT(N0, 0)
- << "Cost functions must have at least one parameter block.";
- DCHECK((!N1 && !N2 && !N3 && !N4 && !N5) ||
- ((N1 > 0) && !N2 && !N3 && !N4 && !N5) ||
- ((N1 > 0) && (N2 > 0) && !N3 && !N4 && !N5) ||
- ((N1 > 0) && (N2 > 0) && (N3 > 0) && !N4 && !N5) ||
- ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && !N5) ||
- ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0)))
+ // This block breaks the 80 column rule to keep it somewhat readable.
+ DCHECK_GT(num_outputs, 0);
+ CHECK((!N1 && !N2 && !N3 && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) ||
+ ((N1 > 0) && !N2 && !N3 && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) ||
+ ((N1 > 0) && (N2 > 0) && !N3 && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) ||
+ ((N1 > 0) && (N2 > 0) && (N3 > 0) && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) ||
+ ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && !N5 && !N6 && !N7 && !N8 && !N9) ||
+ ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && !N6 && !N7 && !N8 && !N9) ||
+ ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && (N6 > 0) && !N7 && !N8 && !N9) ||
+ ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && (N6 > 0) && (N7 > 0) && !N8 && !N9) ||
+ ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && (N6 > 0) && (N7 > 0) && (N8 > 0) && !N9) ||
+ ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && (N6 > 0) && (N7 > 0) && (N8 > 0) && (N9 > 0)))
<< "Zero block cannot precede a non-zero block. Block sizes are "
<< "(ignore trailing 0s): " << N0 << ", " << N1 << ", " << N2 << ", "
- << N3 << ", " << N4 << ", " << N5;
+ << N3 << ", " << N4 << ", " << N5 << ", " << N6 << ", " << N7 << ", "
+ << N8 << ", " << N9;
- DCHECK_GT(num_outputs, 0);
-
+ typedef Jet<T, N0 + N1 + N2 + N3 + N4 + N5 + N6 + N7 + N8 + N9> JetT;
FixedArray<JetT, (256 * 7) / sizeof(JetT)> x(
- N0 + N1 + N2 + N3 + N4 + N5 + num_outputs);
+ N0 + N1 + N2 + N3 + N4 + N5 + N6 + N7 + N8 + N9 + num_outputs);
- // It's ugly, but it works.
- const int jet0 = 0;
- const int jet1 = N0;
- const int jet2 = N0 + N1;
- const int jet3 = N0 + N1 + N2;
- const int jet4 = N0 + N1 + N2 + N3;
- const int jet5 = N0 + N1 + N2 + N3 + N4;
- const int jet6 = N0 + N1 + N2 + N3 + N4 + N5;
+ // These are the positions of the respective jets in the fixed array x.
+ const int jet0 = 0;
+ const int jet1 = N0;
+ const int jet2 = N0 + N1;
+ const int jet3 = N0 + N1 + N2;
+ const int jet4 = N0 + N1 + N2 + N3;
+ const int jet5 = N0 + N1 + N2 + N3 + N4;
+ const int jet6 = N0 + N1 + N2 + N3 + N4 + N5;
+ const int jet7 = N0 + N1 + N2 + N3 + N4 + N5 + N6;
+ const int jet8 = N0 + N1 + N2 + N3 + N4 + N5 + N6 + N7;
+ const int jet9 = N0 + N1 + N2 + N3 + N4 + N5 + N6 + N7 + N8;
- const JetT *unpacked_parameters[6] = {
+ const JetT *unpacked_parameters[10] = {
x.get() + jet0,
x.get() + jet1,
x.get() + jet2,
x.get() + jet3,
x.get() + jet4,
x.get() + jet5,
+ x.get() + jet6,
+ x.get() + jet7,
+ x.get() + jet8,
+ x.get() + jet9,
};
JetT *output = x.get() + jet6;
@@ -333,10 +405,14 @@
CERES_MAKE_1ST_ORDER_PERTURBATION(3);
CERES_MAKE_1ST_ORDER_PERTURBATION(4);
CERES_MAKE_1ST_ORDER_PERTURBATION(5);
+ CERES_MAKE_1ST_ORDER_PERTURBATION(6);
+ CERES_MAKE_1ST_ORDER_PERTURBATION(7);
+ CERES_MAKE_1ST_ORDER_PERTURBATION(8);
+ CERES_MAKE_1ST_ORDER_PERTURBATION(9);
#undef CERES_MAKE_1ST_ORDER_PERTURBATION
if (!VariadicEvaluate<Functor, JetT,
- N0, N1, N2, N3, N4, N5>::Call(
+ N0, N1, N2, N3, N4, N5, N6, N7, N8, N9>::Call(
functor, unpacked_parameters, output)) {
return false;
}
@@ -359,6 +435,10 @@
CERES_TAKE_1ST_ORDER_PERTURBATION(3);
CERES_TAKE_1ST_ORDER_PERTURBATION(4);
CERES_TAKE_1ST_ORDER_PERTURBATION(5);
+ CERES_TAKE_1ST_ORDER_PERTURBATION(6);
+ CERES_TAKE_1ST_ORDER_PERTURBATION(7);
+ CERES_TAKE_1ST_ORDER_PERTURBATION(8);
+ CERES_TAKE_1ST_ORDER_PERTURBATION(9);
#undef CERES_TAKE_1ST_ORDER_PERTURBATION
return true;
}
diff --git a/include/ceres/sized_cost_function.h b/include/ceres/sized_cost_function.h
index 2894a9f..6bfc1af 100644
--- a/include/ceres/sized_cost_function.h
+++ b/include/ceres/sized_cost_function.h
@@ -45,25 +45,29 @@
namespace ceres {
template<int kNumResiduals,
- int N0 = 0, int N1 = 0, int N2 = 0, int N3 = 0, int N4 = 0, int N5 = 0>
+ int N0 = 0, int N1 = 0, int N2 = 0, int N3 = 0, int N4 = 0,
+ int N5 = 0, int N6 = 0, int N7 = 0, int N8 = 0, int N9 = 0>
class SizedCostFunction : public CostFunction {
public:
SizedCostFunction() {
- // Sanity checking.
CHECK(kNumResiduals > 0 || kNumResiduals == DYNAMIC)
<< "Cost functions must have at least one residual block.";
- CHECK_GT(N0, 0)
- << "Cost functions must have at least one parameter block.";
- CHECK((!N1 && !N2 && !N3 && !N4 && !N5) ||
- ((N1 > 0) && !N2 && !N3 && !N4 && !N5) ||
- ((N1 > 0) && (N2 > 0) && !N3 && !N4 && !N5) ||
- ((N1 > 0) && (N2 > 0) && (N3 > 0) && !N4 && !N5) ||
- ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && !N5) ||
- ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0)))
+ // This block breaks the 80 column rule to keep it somewhat readable.
+ CHECK((!N1 && !N2 && !N3 && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) ||
+ ((N1 > 0) && !N2 && !N3 && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) ||
+ ((N1 > 0) && (N2 > 0) && !N3 && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) ||
+ ((N1 > 0) && (N2 > 0) && (N3 > 0) && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) ||
+ ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && !N5 && !N6 && !N7 && !N8 && !N9) ||
+ ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && !N6 && !N7 && !N8 && !N9) ||
+ ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && (N6 > 0) && !N7 && !N8 && !N9) ||
+ ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && (N6 > 0) && (N7 > 0) && !N8 && !N9) ||
+ ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && (N6 > 0) && (N7 > 0) && (N8 > 0) && !N9) ||
+ ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && (N6 > 0) && (N7 > 0) && (N8 > 0) && (N9 > 0)))
<< "Zero block cannot precede a non-zero block. Block sizes are "
<< "(ignore trailing 0s): " << N0 << ", " << N1 << ", " << N2 << ", "
- << N3 << ", " << N4 << ", " << N5;
+ << N3 << ", " << N4 << ", " << N5 << ", " << N6 << ", " << N7 << ", "
+ << N8 << ", " << N9;
set_num_residuals(kNumResiduals);
@@ -75,6 +79,10 @@
ADD_PARAMETER_BLOCK(N3);
ADD_PARAMETER_BLOCK(N4);
ADD_PARAMETER_BLOCK(N5);
+ ADD_PARAMETER_BLOCK(N6);
+ ADD_PARAMETER_BLOCK(N7);
+ ADD_PARAMETER_BLOCK(N8);
+ ADD_PARAMETER_BLOCK(N9);
#undef ADD_PARAMETER_BLOCK
}
diff --git a/internal/ceres/autodiff_cost_function_test.cc b/internal/ceres/autodiff_cost_function_test.cc
index 33e576f..e98397a 100644
--- a/internal/ceres/autodiff_cost_function_test.cc
+++ b/internal/ceres/autodiff_cost_function_test.cc
@@ -51,7 +51,7 @@
double a_;
};
-TEST(AutoDiffResidualAndJacobian, BilinearDifferentiationTest) {
+TEST(AutodiffCostFunction, BilinearDifferentiationTest) {
CostFunction* cost_function =
new AutoDiffCostFunction<BinaryScalarCost, 1, 2, 2>(
new BinaryScalarCost(1.0));
@@ -73,20 +73,72 @@
double residuals = 0.0;
cost_function->Evaluate(parameters, &residuals, NULL);
- EXPECT_EQ(residuals, 10);
+ EXPECT_EQ(10.0, residuals);
cost_function->Evaluate(parameters, &residuals, jacobians);
- EXPECT_EQ(jacobians[0][0], 3);
- EXPECT_EQ(jacobians[0][1], 4);
- EXPECT_EQ(jacobians[1][0], 1);
- EXPECT_EQ(jacobians[1][1], 2);
+ EXPECT_EQ(3, jacobians[0][0]);
+ EXPECT_EQ(4, jacobians[0][1]);
+ EXPECT_EQ(1, jacobians[1][0]);
+ EXPECT_EQ(2, jacobians[1][1]);
- delete []jacobians[0];
- delete []jacobians[1];
- delete []parameters[0];
- delete []parameters[1];
- delete []jacobians;
- delete []parameters;
+ delete[] jacobians[0];
+ delete[] jacobians[1];
+ delete[] parameters[0];
+ delete[] parameters[1];
+ delete[] jacobians;
+ delete[] parameters;
+ delete cost_function;
+}
+
+struct TenParameterCost {
+ template <typename T>
+ bool operator()(const T* const x0,
+ const T* const x1,
+ const T* const x2,
+ const T* const x3,
+ const T* const x4,
+ const T* const x5,
+ const T* const x6,
+ const T* const x7,
+ const T* const x8,
+ const T* const x9,
+ T* cost) const {
+ cost[0] = *x0 + *x1 + *x2 + *x3 + *x4 + *x5 + *x6 + *x7 + *x8 + *x9;
+ return true;
+ }
+};
+
+TEST(AutodiffCostFunction, ManyParameterAutodiffInstantiates) {
+ CostFunction* cost_function =
+ new AutoDiffCostFunction<
+ TenParameterCost, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>(
+ new TenParameterCost);
+
+ double** parameters = new double*[10];
+ double** jacobians = new double*[10];
+ for (int i = 0; i < 10; ++i) {
+ parameters[i] = new double[1];
+ parameters[i][0] = i;
+ jacobians[i] = new double[1];
+ }
+
+ double residuals = 0.0;
+
+ cost_function->Evaluate(parameters, &residuals, NULL);
+ EXPECT_EQ(45.0, residuals);
+
+ cost_function->Evaluate(parameters, &residuals, jacobians);
+ EXPECT_EQ(residuals, 45.0);
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_EQ(1.0, jacobians[i][0]);
+ }
+
+ for (int i = 0; i < 10; ++i) {
+ delete[] jacobians[i];
+ delete[] parameters[i];
+ }
+ delete[] jacobians;
+ delete[] parameters;
delete cost_function;
}