Implement AddResidualBlock using variadic templates
This CL changes the implementation of AddResidualBlock() in
ceres::Problem using variadic templates. Also one new overload for
AddResidualBlock() is added using a double** and the number of
parameter blocks.
Change-Id: I007a82a06897335a117213a0d12fedb4a77076a0
diff --git a/include/ceres/problem.h b/include/ceres/problem.h
index e220e66..503e0fe 100644
--- a/include/ceres/problem.h
+++ b/include/ceres/problem.h
@@ -34,6 +34,7 @@
#ifndef CERES_PUBLIC_PROBLEM_H_
#define CERES_PUBLIC_PROBLEM_H_
+#include <array>
#include <cstddef>
#include <map>
#include <memory>
@@ -212,56 +213,32 @@
// problem.AddResidualBlock(new MyUnaryCostFunction(...), NULL, x1);
// problem.AddResidualBlock(new MyBinaryCostFunction(...), NULL, x2, x1);
//
+ // Add a residual block by listing the parameter block pointers
+ // directly instead of wapping them in a container.
+ template <typename... Ts>
+ ResidualBlockId AddResidualBlock(CostFunction* cost_function,
+ LossFunction* loss_function,
+ double* x0,
+ Ts*... xs) {
+ const std::array<double*, sizeof...(Ts) + 1> parameter_blocks{{x0, xs...}};
+ return AddResidualBlock(cost_function, loss_function,
+ parameter_blocks.data(),
+ static_cast<int>(parameter_blocks.size()));
+ }
+
+ // Add a residual block by providing a vector of parameter blocks.
ResidualBlockId AddResidualBlock(
CostFunction* cost_function,
LossFunction* loss_function,
const std::vector<double*>& parameter_blocks);
- // Convenience methods for adding residuals with a small number of
- // parameters. This is the common case. Instead of specifying the
- // parameter block arguments as a vector, list them as pointers.
- ResidualBlockId AddResidualBlock(CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0);
- ResidualBlockId AddResidualBlock(CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1);
- ResidualBlockId AddResidualBlock(CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2);
- ResidualBlockId AddResidualBlock(CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2,
- double* x3);
- ResidualBlockId AddResidualBlock(CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2,
- double* x3, double* x4);
- ResidualBlockId AddResidualBlock(CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2,
- double* x3, double* x4, double* x5);
- ResidualBlockId AddResidualBlock(CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2,
- double* x3, double* x4, double* x5,
- double* x6);
- ResidualBlockId AddResidualBlock(CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2,
- double* x3, double* x4, double* x5,
- double* x6, double* x7);
- ResidualBlockId AddResidualBlock(CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2,
- double* x3, double* x4, double* x5,
- double* x6, double* x7, double* x8);
- ResidualBlockId AddResidualBlock(CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2,
- double* x3, double* x4, double* x5,
- double* x6, double* x7, double* x8,
- double* x9);
+ // Add a residual block by providing a pointer to the parameter block array
+ // and the number of parameter blocks.
+ ResidualBlockId AddResidualBlock(
+ CostFunction* cost_function,
+ LossFunction* loss_function,
+ double* const* const parameter_blocks,
+ int num_parameter_blocks);
// Add a parameter block with appropriate size to the problem.
// Repeated calls with the same arguments are ignored. Repeated
diff --git a/internal/ceres/gradient_checking_cost_function.cc b/internal/ceres/gradient_checking_cost_function.cc
index 9542af7..1afbec3 100644
--- a/internal/ceres/gradient_checking_cost_function.cc
+++ b/internal/ceres/gradient_checking_cost_function.cc
@@ -266,7 +266,8 @@
gradient_checking_problem_impl->AddResidualBlock(
gradient_checking_cost_function,
const_cast<LossFunction*>(residual_block->loss_function()),
- parameter_blocks);
+ parameter_blocks.data(),
+ static_cast<int>(parameter_blocks.size()));
}
// Normally, when a problem is given to the solver, we guarantee
diff --git a/internal/ceres/problem.cc b/internal/ceres/problem.cc
index 956afa1..6939b46 100644
--- a/internal/ceres/problem.cc
+++ b/internal/ceres/problem.cc
@@ -48,104 +48,22 @@
CostFunction* cost_function,
LossFunction* loss_function,
const vector<double*>& parameter_blocks) {
- return problem_impl_->AddResidualBlock(cost_function,
- loss_function,
- parameter_blocks);
-}
-
-ResidualBlockId Problem::AddResidualBlock(
- CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0) {
- return problem_impl_->AddResidualBlock(cost_function,
- loss_function,
- x0);
-}
-
-ResidualBlockId Problem::AddResidualBlock(
- CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1) {
- return problem_impl_->AddResidualBlock(cost_function,
- loss_function,
- x0, x1);
-}
-
-ResidualBlockId Problem::AddResidualBlock(
- CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2) {
- return problem_impl_->AddResidualBlock(cost_function,
- loss_function,
- x0, x1, x2);
-}
-
-ResidualBlockId Problem::AddResidualBlock(
- CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2, double* x3) {
- return problem_impl_->AddResidualBlock(cost_function,
- loss_function,
- x0, x1, x2, x3);
-}
-
-ResidualBlockId Problem::AddResidualBlock(
- CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2, double* x3, double* x4) {
- return problem_impl_->AddResidualBlock(cost_function,
- loss_function,
- x0, x1, x2, x3, x4);
-}
-
-ResidualBlockId Problem::AddResidualBlock(
- CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2, double* x3, double* x4, double* x5) {
- return problem_impl_->AddResidualBlock(cost_function,
- loss_function,
- x0, x1, x2, x3, x4, x5);
-}
-
-ResidualBlockId Problem::AddResidualBlock(
- CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2, double* x3, double* x4, double* x5,
- double* x6) {
- return problem_impl_->AddResidualBlock(cost_function,
- loss_function,
- x0, x1, x2, x3, x4, x5, x6);
-}
-
-ResidualBlockId Problem::AddResidualBlock(
- CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2, double* x3, double* x4, double* x5,
- double* x6, double* x7) {
- return problem_impl_->AddResidualBlock(cost_function,
- loss_function,
- x0, x1, x2, x3, x4, x5, x6, x7);
-}
-
-ResidualBlockId Problem::AddResidualBlock(
- CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2, double* x3, double* x4, double* x5,
- double* x6, double* x7, double* x8) {
- return problem_impl_->AddResidualBlock(cost_function,
- loss_function,
- x0, x1, x2, x3, x4, x5, x6, x7, x8);
-}
-
-ResidualBlockId Problem::AddResidualBlock(
- CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2, double* x3, double* x4, double* x5,
- double* x6, double* x7, double* x8, double* x9) {
return problem_impl_->AddResidualBlock(
cost_function,
loss_function,
- x0, x1, x2, x3, x4, x5, x6, x7, x8, x9);
+ parameter_blocks.data(),
+ static_cast<int>(parameter_blocks.size()));
+}
+
+ResidualBlockId Problem::AddResidualBlock(
+ CostFunction* cost_function,
+ LossFunction* loss_function,
+ double* const* const parameter_blocks,
+ int num_parameter_blocks) {
+ return problem_impl_->AddResidualBlock(cost_function,
+ loss_function,
+ parameter_blocks,
+ num_parameter_blocks);
}
void Problem::AddParameterBlock(double* values, int size) {
diff --git a/internal/ceres/problem_impl.cc b/internal/ceres/problem_impl.cc
index 00cd2fd..40d5aa2 100644
--- a/internal/ceres/problem_impl.cc
+++ b/internal/ceres/problem_impl.cc
@@ -247,14 +247,12 @@
ProblemImpl::ProblemImpl()
: options_(Problem::Options()),
program_(new internal::Program) {
- residual_parameters_.reserve(10);
InitializeContext(options_.context, &context_impl_, &context_impl_owned_);
}
ProblemImpl::ProblemImpl(const Problem::Options& options)
: options_(options),
program_(new internal::Program) {
- residual_parameters_.reserve(10);
InitializeContext(options_.context, &context_impl_, &context_impl_owned_);
}
@@ -286,12 +284,13 @@
}
}
-ResidualBlock* ProblemImpl::AddResidualBlock(
- CostFunction* cost_function,
- LossFunction* loss_function,
- const vector<double*>& parameter_blocks) {
+ResidualBlockId ProblemImpl::AddResidualBlock(
+ CostFunction* cost_function,
+ LossFunction* loss_function,
+ double* const* const parameter_blocks,
+ int num_parameter_blocks) {
CHECK(cost_function != nullptr);
- CHECK_EQ(parameter_blocks.size(),
+ CHECK_EQ(num_parameter_blocks,
cost_function->parameter_block_sizes().size());
// Check the sizes match.
@@ -299,12 +298,13 @@
cost_function->parameter_block_sizes();
if (!options_.disable_all_safety_checks) {
- CHECK_EQ(parameter_block_sizes.size(), parameter_blocks.size())
+ CHECK_EQ(parameter_block_sizes.size(), num_parameter_blocks)
<< "Number of blocks input is different than the number of blocks "
<< "that the cost function expects.";
// Check for duplicate parameter blocks.
- vector<double*> sorted_parameter_blocks(parameter_blocks);
+ vector<double*> sorted_parameter_blocks(
+ parameter_blocks, parameter_blocks + num_parameter_blocks);
sort(sorted_parameter_blocks.begin(), sorted_parameter_blocks.end());
const bool has_duplicate_items =
(std::adjacent_find(sorted_parameter_blocks.begin(),
@@ -312,7 +312,7 @@
!= sorted_parameter_blocks.end());
if (has_duplicate_items) {
string blocks;
- for (int i = 0; i < parameter_blocks.size(); ++i) {
+ for (int i = 0; i < num_parameter_blocks; ++i) {
blocks += StringPrintf(" %p ", parameter_blocks[i]);
}
@@ -323,8 +323,8 @@
}
// Add parameter blocks and convert the double*'s to parameter blocks.
- vector<ParameterBlock*> parameter_block_ptrs(parameter_blocks.size());
- for (int i = 0; i < parameter_blocks.size(); ++i) {
+ vector<ParameterBlock*> parameter_block_ptrs(num_parameter_blocks);
+ for (int i = 0; i < num_parameter_blocks; ++i) {
parameter_block_ptrs[i] =
InternalAddParameterBlock(parameter_blocks[i],
parameter_block_sizes[i]);
@@ -351,7 +351,7 @@
// Add dependencies on the residual to the parameter blocks.
if (options_.enable_fast_removal) {
- for (int i = 0; i < parameter_blocks.size(); ++i) {
+ for (int i = 0; i < num_parameter_blocks; ++i) {
parameter_block_ptrs[i]->AddResidualBlock(new_residual_block);
}
}
@@ -377,147 +377,6 @@
return new_residual_block;
}
-// Unfortunately, macros don't help much to reduce this code, and var args don't
-// work because of the ambiguous case that there is no loss function.
-ResidualBlock* ProblemImpl::AddResidualBlock(
- CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0) {
- residual_parameters_.clear();
- residual_parameters_.push_back(x0);
- return AddResidualBlock(cost_function, loss_function, residual_parameters_);
-}
-
-ResidualBlock* ProblemImpl::AddResidualBlock(
- CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1) {
- residual_parameters_.clear();
- residual_parameters_.push_back(x0);
- residual_parameters_.push_back(x1);
- return AddResidualBlock(cost_function, loss_function, residual_parameters_);
-}
-
-ResidualBlock* ProblemImpl::AddResidualBlock(
- CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2) {
- residual_parameters_.clear();
- residual_parameters_.push_back(x0);
- residual_parameters_.push_back(x1);
- residual_parameters_.push_back(x2);
- return AddResidualBlock(cost_function, loss_function, residual_parameters_);
-}
-
-ResidualBlock* ProblemImpl::AddResidualBlock(
- CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2, double* x3) {
- residual_parameters_.clear();
- residual_parameters_.push_back(x0);
- residual_parameters_.push_back(x1);
- residual_parameters_.push_back(x2);
- residual_parameters_.push_back(x3);
- return AddResidualBlock(cost_function, loss_function, residual_parameters_);
-}
-
-ResidualBlock* ProblemImpl::AddResidualBlock(
- CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2, double* x3, double* x4) {
- residual_parameters_.clear();
- residual_parameters_.push_back(x0);
- residual_parameters_.push_back(x1);
- residual_parameters_.push_back(x2);
- residual_parameters_.push_back(x3);
- residual_parameters_.push_back(x4);
- return AddResidualBlock(cost_function, loss_function, residual_parameters_);
-}
-
-ResidualBlock* ProblemImpl::AddResidualBlock(
- CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2, double* x3, double* x4, double* x5) {
- residual_parameters_.clear();
- residual_parameters_.push_back(x0);
- residual_parameters_.push_back(x1);
- residual_parameters_.push_back(x2);
- residual_parameters_.push_back(x3);
- residual_parameters_.push_back(x4);
- residual_parameters_.push_back(x5);
- return AddResidualBlock(cost_function, loss_function, residual_parameters_);
-}
-
-ResidualBlock* ProblemImpl::AddResidualBlock(
- CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2, double* x3, double* x4, double* x5,
- double* x6) {
- residual_parameters_.clear();
- residual_parameters_.push_back(x0);
- residual_parameters_.push_back(x1);
- residual_parameters_.push_back(x2);
- residual_parameters_.push_back(x3);
- residual_parameters_.push_back(x4);
- residual_parameters_.push_back(x5);
- residual_parameters_.push_back(x6);
- return AddResidualBlock(cost_function, loss_function, residual_parameters_);
-}
-
-ResidualBlock* ProblemImpl::AddResidualBlock(
- CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2, double* x3, double* x4, double* x5,
- double* x6, double* x7) {
- residual_parameters_.clear();
- residual_parameters_.push_back(x0);
- residual_parameters_.push_back(x1);
- residual_parameters_.push_back(x2);
- residual_parameters_.push_back(x3);
- residual_parameters_.push_back(x4);
- residual_parameters_.push_back(x5);
- residual_parameters_.push_back(x6);
- residual_parameters_.push_back(x7);
- return AddResidualBlock(cost_function, loss_function, residual_parameters_);
-}
-
-ResidualBlock* ProblemImpl::AddResidualBlock(
- CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2, double* x3, double* x4, double* x5,
- double* x6, double* x7, double* x8) {
- residual_parameters_.clear();
- residual_parameters_.push_back(x0);
- residual_parameters_.push_back(x1);
- residual_parameters_.push_back(x2);
- residual_parameters_.push_back(x3);
- residual_parameters_.push_back(x4);
- residual_parameters_.push_back(x5);
- residual_parameters_.push_back(x6);
- residual_parameters_.push_back(x7);
- residual_parameters_.push_back(x8);
- return AddResidualBlock(cost_function, loss_function, residual_parameters_);
-}
-
-ResidualBlock* ProblemImpl::AddResidualBlock(
- CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2, double* x3, double* x4, double* x5,
- double* x6, double* x7, double* x8, double* x9) {
- residual_parameters_.clear();
- residual_parameters_.push_back(x0);
- residual_parameters_.push_back(x1);
- residual_parameters_.push_back(x2);
- residual_parameters_.push_back(x3);
- residual_parameters_.push_back(x4);
- residual_parameters_.push_back(x5);
- residual_parameters_.push_back(x6);
- residual_parameters_.push_back(x7);
- residual_parameters_.push_back(x8);
- residual_parameters_.push_back(x9);
- return AddResidualBlock(cost_function, loss_function, residual_parameters_);
-}
-
void ProblemImpl::AddParameterBlock(double* values, int size) {
InternalAddParameterBlock(values, size);
}
diff --git a/internal/ceres/problem_impl.h b/internal/ceres/problem_impl.h
index ff89e94..eabeaed 100644
--- a/internal/ceres/problem_impl.h
+++ b/internal/ceres/problem_impl.h
@@ -39,6 +39,7 @@
#ifndef CERES_PUBLIC_PROBLEM_IMPL_H_
#define CERES_PUBLIC_PROBLEM_IMPL_H_
+#include <array>
#include <map>
#include <memory>
#include <unordered_set>
@@ -79,49 +80,21 @@
ResidualBlockId AddResidualBlock(
CostFunction* cost_function,
LossFunction* loss_function,
- const std::vector<double*>& parameter_blocks);
+ double* const* const parameter_blocks,
+ int num_parameter_blocks);
+
+ template <typename... Ts>
ResidualBlockId AddResidualBlock(CostFunction* cost_function,
LossFunction* loss_function,
- double* x0);
- ResidualBlockId AddResidualBlock(CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1);
- ResidualBlockId AddResidualBlock(CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2);
- ResidualBlockId AddResidualBlock(CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2,
- double* x3);
- ResidualBlockId AddResidualBlock(CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2,
- double* x3, double* x4);
- ResidualBlockId AddResidualBlock(CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2,
- double* x3, double* x4, double* x5);
- ResidualBlockId AddResidualBlock(CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2,
- double* x3, double* x4, double* x5,
- double* x6);
- ResidualBlockId AddResidualBlock(CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2,
- double* x3, double* x4, double* x5,
- double* x6, double* x7);
- ResidualBlockId AddResidualBlock(CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2,
- double* x3, double* x4, double* x5,
- double* x6, double* x7, double* x8);
- ResidualBlockId AddResidualBlock(CostFunction* cost_function,
- LossFunction* loss_function,
- double* x0, double* x1, double* x2,
- double* x3, double* x4, double* x5,
- double* x6, double* x7, double* x8,
- double* x9);
+ double* x0,
+ Ts*... xs) {
+ const std::array<double*, sizeof...(Ts) + 1> parameter_blocks{{x0, xs...}};
+ return AddResidualBlock(cost_function,
+ loss_function,
+ parameter_blocks.data(),
+ static_cast<int>(parameter_blocks.size()));
+ }
+
void AddParameterBlock(double* values, int size);
void AddParameterBlock(double* values,
int size,
@@ -194,7 +167,7 @@
// Delete the arguments in question. These differ from the Remove* functions
// in that they do not clean up references to the block to delete; they
// merely delete them.
- template<typename Block>
+ template <typename Block>
void DeleteBlockInVector(std::vector<Block*>* mutable_blocks,
Block* block_to_remove);
void DeleteBlock(ResidualBlock* residual_block);
@@ -228,7 +201,6 @@
// destroyed.
CostFunctionRefCount cost_function_ref_count_;
LossFunctionRefCount loss_function_ref_count_;
- std::vector<double*> residual_parameters_;
};
} // namespace internal
diff --git a/internal/ceres/problem_test.cc b/internal/ceres/problem_test.cc
index 937f84e..3f9f804 100644
--- a/internal/ceres/problem_test.cc
+++ b/internal/ceres/problem_test.cc
@@ -142,7 +142,7 @@
// UnaryCostFunction takes only one parameter, but two are passed.
EXPECT_DEATH_IF_SUPPORTED(
problem.AddResidualBlock(new UnaryCostFunction(2, 3), NULL, x, y),
- "parameter_blocks.size");
+ "num_parameter_blocks");
}
TEST(Problem, AddResidualWithDifferentSizesOnTheSameVariableDies) {
diff --git a/internal/ceres/program_test.cc b/internal/ceres/program_test.cc
index 6cb316e..6cb8e9e 100644
--- a/internal/ceres/program_test.cc
+++ b/internal/ceres/program_test.cc
@@ -352,7 +352,8 @@
problem.AddResidualBlock(new NumParameterBlocksCostFunction<1, 20>(),
nullptr,
- parameter_blocks);
+ parameter_blocks.data(),
+ static_cast<int>(parameter_blocks.size()));
TripletSparseMatrix expected_block_sparse_jacobian(20, 1, 20);
{