Implement AddResidualBlock using variadic templates This CL changes the implementation of AddResidualBlock() in ceres::Problem using variadic templates. Also one new overload for AddResidualBlock() is added using a double** and the number of parameter blocks. Change-Id: I007a82a06897335a117213a0d12fedb4a77076a0
diff --git a/include/ceres/problem.h b/include/ceres/problem.h index e220e66..503e0fe 100644 --- a/include/ceres/problem.h +++ b/include/ceres/problem.h
@@ -34,6 +34,7 @@ #ifndef CERES_PUBLIC_PROBLEM_H_ #define CERES_PUBLIC_PROBLEM_H_ +#include <array> #include <cstddef> #include <map> #include <memory> @@ -212,56 +213,32 @@ // problem.AddResidualBlock(new MyUnaryCostFunction(...), NULL, x1); // problem.AddResidualBlock(new MyBinaryCostFunction(...), NULL, x2, x1); // + // Add a residual block by listing the parameter block pointers + // directly instead of wapping them in a container. + template <typename... Ts> + ResidualBlockId AddResidualBlock(CostFunction* cost_function, + LossFunction* loss_function, + double* x0, + Ts*... xs) { + const std::array<double*, sizeof...(Ts) + 1> parameter_blocks{{x0, xs...}}; + return AddResidualBlock(cost_function, loss_function, + parameter_blocks.data(), + static_cast<int>(parameter_blocks.size())); + } + + // Add a residual block by providing a vector of parameter blocks. ResidualBlockId AddResidualBlock( CostFunction* cost_function, LossFunction* loss_function, const std::vector<double*>& parameter_blocks); - // Convenience methods for adding residuals with a small number of - // parameters. This is the common case. Instead of specifying the - // parameter block arguments as a vector, list them as pointers. - ResidualBlockId AddResidualBlock(CostFunction* cost_function, - LossFunction* loss_function, - double* x0); - ResidualBlockId AddResidualBlock(CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1); - ResidualBlockId AddResidualBlock(CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2); - ResidualBlockId AddResidualBlock(CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, - double* x3); - ResidualBlockId AddResidualBlock(CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, - double* x3, double* x4); - ResidualBlockId AddResidualBlock(CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, - double* x3, double* x4, double* x5); - ResidualBlockId AddResidualBlock(CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, - double* x3, double* x4, double* x5, - double* x6); - ResidualBlockId AddResidualBlock(CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, - double* x3, double* x4, double* x5, - double* x6, double* x7); - ResidualBlockId AddResidualBlock(CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, - double* x3, double* x4, double* x5, - double* x6, double* x7, double* x8); - ResidualBlockId AddResidualBlock(CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, - double* x3, double* x4, double* x5, - double* x6, double* x7, double* x8, - double* x9); + // Add a residual block by providing a pointer to the parameter block array + // and the number of parameter blocks. + ResidualBlockId AddResidualBlock( + CostFunction* cost_function, + LossFunction* loss_function, + double* const* const parameter_blocks, + int num_parameter_blocks); // Add a parameter block with appropriate size to the problem. // Repeated calls with the same arguments are ignored. Repeated
diff --git a/internal/ceres/gradient_checking_cost_function.cc b/internal/ceres/gradient_checking_cost_function.cc index 9542af7..1afbec3 100644 --- a/internal/ceres/gradient_checking_cost_function.cc +++ b/internal/ceres/gradient_checking_cost_function.cc
@@ -266,7 +266,8 @@ gradient_checking_problem_impl->AddResidualBlock( gradient_checking_cost_function, const_cast<LossFunction*>(residual_block->loss_function()), - parameter_blocks); + parameter_blocks.data(), + static_cast<int>(parameter_blocks.size())); } // Normally, when a problem is given to the solver, we guarantee
diff --git a/internal/ceres/problem.cc b/internal/ceres/problem.cc index 956afa1..6939b46 100644 --- a/internal/ceres/problem.cc +++ b/internal/ceres/problem.cc
@@ -48,104 +48,22 @@ CostFunction* cost_function, LossFunction* loss_function, const vector<double*>& parameter_blocks) { - return problem_impl_->AddResidualBlock(cost_function, - loss_function, - parameter_blocks); -} - -ResidualBlockId Problem::AddResidualBlock( - CostFunction* cost_function, - LossFunction* loss_function, - double* x0) { - return problem_impl_->AddResidualBlock(cost_function, - loss_function, - x0); -} - -ResidualBlockId Problem::AddResidualBlock( - CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1) { - return problem_impl_->AddResidualBlock(cost_function, - loss_function, - x0, x1); -} - -ResidualBlockId Problem::AddResidualBlock( - CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2) { - return problem_impl_->AddResidualBlock(cost_function, - loss_function, - x0, x1, x2); -} - -ResidualBlockId Problem::AddResidualBlock( - CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, double* x3) { - return problem_impl_->AddResidualBlock(cost_function, - loss_function, - x0, x1, x2, x3); -} - -ResidualBlockId Problem::AddResidualBlock( - CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, double* x3, double* x4) { - return problem_impl_->AddResidualBlock(cost_function, - loss_function, - x0, x1, x2, x3, x4); -} - -ResidualBlockId Problem::AddResidualBlock( - CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, double* x3, double* x4, double* x5) { - return problem_impl_->AddResidualBlock(cost_function, - loss_function, - x0, x1, x2, x3, x4, x5); -} - -ResidualBlockId Problem::AddResidualBlock( - CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, double* x3, double* x4, double* x5, - double* x6) { - return problem_impl_->AddResidualBlock(cost_function, - loss_function, - x0, x1, x2, x3, x4, x5, x6); -} - -ResidualBlockId Problem::AddResidualBlock( - CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, double* x3, double* x4, double* x5, - double* x6, double* x7) { - return problem_impl_->AddResidualBlock(cost_function, - loss_function, - x0, x1, x2, x3, x4, x5, x6, x7); -} - -ResidualBlockId Problem::AddResidualBlock( - CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, double* x3, double* x4, double* x5, - double* x6, double* x7, double* x8) { - return problem_impl_->AddResidualBlock(cost_function, - loss_function, - x0, x1, x2, x3, x4, x5, x6, x7, x8); -} - -ResidualBlockId Problem::AddResidualBlock( - CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, double* x3, double* x4, double* x5, - double* x6, double* x7, double* x8, double* x9) { return problem_impl_->AddResidualBlock( cost_function, loss_function, - x0, x1, x2, x3, x4, x5, x6, x7, x8, x9); + parameter_blocks.data(), + static_cast<int>(parameter_blocks.size())); +} + +ResidualBlockId Problem::AddResidualBlock( + CostFunction* cost_function, + LossFunction* loss_function, + double* const* const parameter_blocks, + int num_parameter_blocks) { + return problem_impl_->AddResidualBlock(cost_function, + loss_function, + parameter_blocks, + num_parameter_blocks); } void Problem::AddParameterBlock(double* values, int size) {
diff --git a/internal/ceres/problem_impl.cc b/internal/ceres/problem_impl.cc index 00cd2fd..40d5aa2 100644 --- a/internal/ceres/problem_impl.cc +++ b/internal/ceres/problem_impl.cc
@@ -247,14 +247,12 @@ ProblemImpl::ProblemImpl() : options_(Problem::Options()), program_(new internal::Program) { - residual_parameters_.reserve(10); InitializeContext(options_.context, &context_impl_, &context_impl_owned_); } ProblemImpl::ProblemImpl(const Problem::Options& options) : options_(options), program_(new internal::Program) { - residual_parameters_.reserve(10); InitializeContext(options_.context, &context_impl_, &context_impl_owned_); } @@ -286,12 +284,13 @@ } } -ResidualBlock* ProblemImpl::AddResidualBlock( - CostFunction* cost_function, - LossFunction* loss_function, - const vector<double*>& parameter_blocks) { +ResidualBlockId ProblemImpl::AddResidualBlock( + CostFunction* cost_function, + LossFunction* loss_function, + double* const* const parameter_blocks, + int num_parameter_blocks) { CHECK(cost_function != nullptr); - CHECK_EQ(parameter_blocks.size(), + CHECK_EQ(num_parameter_blocks, cost_function->parameter_block_sizes().size()); // Check the sizes match. @@ -299,12 +298,13 @@ cost_function->parameter_block_sizes(); if (!options_.disable_all_safety_checks) { - CHECK_EQ(parameter_block_sizes.size(), parameter_blocks.size()) + CHECK_EQ(parameter_block_sizes.size(), num_parameter_blocks) << "Number of blocks input is different than the number of blocks " << "that the cost function expects."; // Check for duplicate parameter blocks. - vector<double*> sorted_parameter_blocks(parameter_blocks); + vector<double*> sorted_parameter_blocks( + parameter_blocks, parameter_blocks + num_parameter_blocks); sort(sorted_parameter_blocks.begin(), sorted_parameter_blocks.end()); const bool has_duplicate_items = (std::adjacent_find(sorted_parameter_blocks.begin(), @@ -312,7 +312,7 @@ != sorted_parameter_blocks.end()); if (has_duplicate_items) { string blocks; - for (int i = 0; i < parameter_blocks.size(); ++i) { + for (int i = 0; i < num_parameter_blocks; ++i) { blocks += StringPrintf(" %p ", parameter_blocks[i]); } @@ -323,8 +323,8 @@ } // Add parameter blocks and convert the double*'s to parameter blocks. - vector<ParameterBlock*> parameter_block_ptrs(parameter_blocks.size()); - for (int i = 0; i < parameter_blocks.size(); ++i) { + vector<ParameterBlock*> parameter_block_ptrs(num_parameter_blocks); + for (int i = 0; i < num_parameter_blocks; ++i) { parameter_block_ptrs[i] = InternalAddParameterBlock(parameter_blocks[i], parameter_block_sizes[i]); @@ -351,7 +351,7 @@ // Add dependencies on the residual to the parameter blocks. if (options_.enable_fast_removal) { - for (int i = 0; i < parameter_blocks.size(); ++i) { + for (int i = 0; i < num_parameter_blocks; ++i) { parameter_block_ptrs[i]->AddResidualBlock(new_residual_block); } } @@ -377,147 +377,6 @@ return new_residual_block; } -// Unfortunately, macros don't help much to reduce this code, and var args don't -// work because of the ambiguous case that there is no loss function. -ResidualBlock* ProblemImpl::AddResidualBlock( - CostFunction* cost_function, - LossFunction* loss_function, - double* x0) { - residual_parameters_.clear(); - residual_parameters_.push_back(x0); - return AddResidualBlock(cost_function, loss_function, residual_parameters_); -} - -ResidualBlock* ProblemImpl::AddResidualBlock( - CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1) { - residual_parameters_.clear(); - residual_parameters_.push_back(x0); - residual_parameters_.push_back(x1); - return AddResidualBlock(cost_function, loss_function, residual_parameters_); -} - -ResidualBlock* ProblemImpl::AddResidualBlock( - CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2) { - residual_parameters_.clear(); - residual_parameters_.push_back(x0); - residual_parameters_.push_back(x1); - residual_parameters_.push_back(x2); - return AddResidualBlock(cost_function, loss_function, residual_parameters_); -} - -ResidualBlock* ProblemImpl::AddResidualBlock( - CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, double* x3) { - residual_parameters_.clear(); - residual_parameters_.push_back(x0); - residual_parameters_.push_back(x1); - residual_parameters_.push_back(x2); - residual_parameters_.push_back(x3); - return AddResidualBlock(cost_function, loss_function, residual_parameters_); -} - -ResidualBlock* ProblemImpl::AddResidualBlock( - CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, double* x3, double* x4) { - residual_parameters_.clear(); - residual_parameters_.push_back(x0); - residual_parameters_.push_back(x1); - residual_parameters_.push_back(x2); - residual_parameters_.push_back(x3); - residual_parameters_.push_back(x4); - return AddResidualBlock(cost_function, loss_function, residual_parameters_); -} - -ResidualBlock* ProblemImpl::AddResidualBlock( - CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, double* x3, double* x4, double* x5) { - residual_parameters_.clear(); - residual_parameters_.push_back(x0); - residual_parameters_.push_back(x1); - residual_parameters_.push_back(x2); - residual_parameters_.push_back(x3); - residual_parameters_.push_back(x4); - residual_parameters_.push_back(x5); - return AddResidualBlock(cost_function, loss_function, residual_parameters_); -} - -ResidualBlock* ProblemImpl::AddResidualBlock( - CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, double* x3, double* x4, double* x5, - double* x6) { - residual_parameters_.clear(); - residual_parameters_.push_back(x0); - residual_parameters_.push_back(x1); - residual_parameters_.push_back(x2); - residual_parameters_.push_back(x3); - residual_parameters_.push_back(x4); - residual_parameters_.push_back(x5); - residual_parameters_.push_back(x6); - return AddResidualBlock(cost_function, loss_function, residual_parameters_); -} - -ResidualBlock* ProblemImpl::AddResidualBlock( - CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, double* x3, double* x4, double* x5, - double* x6, double* x7) { - residual_parameters_.clear(); - residual_parameters_.push_back(x0); - residual_parameters_.push_back(x1); - residual_parameters_.push_back(x2); - residual_parameters_.push_back(x3); - residual_parameters_.push_back(x4); - residual_parameters_.push_back(x5); - residual_parameters_.push_back(x6); - residual_parameters_.push_back(x7); - return AddResidualBlock(cost_function, loss_function, residual_parameters_); -} - -ResidualBlock* ProblemImpl::AddResidualBlock( - CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, double* x3, double* x4, double* x5, - double* x6, double* x7, double* x8) { - residual_parameters_.clear(); - residual_parameters_.push_back(x0); - residual_parameters_.push_back(x1); - residual_parameters_.push_back(x2); - residual_parameters_.push_back(x3); - residual_parameters_.push_back(x4); - residual_parameters_.push_back(x5); - residual_parameters_.push_back(x6); - residual_parameters_.push_back(x7); - residual_parameters_.push_back(x8); - return AddResidualBlock(cost_function, loss_function, residual_parameters_); -} - -ResidualBlock* ProblemImpl::AddResidualBlock( - CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, double* x3, double* x4, double* x5, - double* x6, double* x7, double* x8, double* x9) { - residual_parameters_.clear(); - residual_parameters_.push_back(x0); - residual_parameters_.push_back(x1); - residual_parameters_.push_back(x2); - residual_parameters_.push_back(x3); - residual_parameters_.push_back(x4); - residual_parameters_.push_back(x5); - residual_parameters_.push_back(x6); - residual_parameters_.push_back(x7); - residual_parameters_.push_back(x8); - residual_parameters_.push_back(x9); - return AddResidualBlock(cost_function, loss_function, residual_parameters_); -} - void ProblemImpl::AddParameterBlock(double* values, int size) { InternalAddParameterBlock(values, size); }
diff --git a/internal/ceres/problem_impl.h b/internal/ceres/problem_impl.h index ff89e94..eabeaed 100644 --- a/internal/ceres/problem_impl.h +++ b/internal/ceres/problem_impl.h
@@ -39,6 +39,7 @@ #ifndef CERES_PUBLIC_PROBLEM_IMPL_H_ #define CERES_PUBLIC_PROBLEM_IMPL_H_ +#include <array> #include <map> #include <memory> #include <unordered_set> @@ -79,49 +80,21 @@ ResidualBlockId AddResidualBlock( CostFunction* cost_function, LossFunction* loss_function, - const std::vector<double*>& parameter_blocks); + double* const* const parameter_blocks, + int num_parameter_blocks); + + template <typename... Ts> ResidualBlockId AddResidualBlock(CostFunction* cost_function, LossFunction* loss_function, - double* x0); - ResidualBlockId AddResidualBlock(CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1); - ResidualBlockId AddResidualBlock(CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2); - ResidualBlockId AddResidualBlock(CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, - double* x3); - ResidualBlockId AddResidualBlock(CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, - double* x3, double* x4); - ResidualBlockId AddResidualBlock(CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, - double* x3, double* x4, double* x5); - ResidualBlockId AddResidualBlock(CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, - double* x3, double* x4, double* x5, - double* x6); - ResidualBlockId AddResidualBlock(CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, - double* x3, double* x4, double* x5, - double* x6, double* x7); - ResidualBlockId AddResidualBlock(CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, - double* x3, double* x4, double* x5, - double* x6, double* x7, double* x8); - ResidualBlockId AddResidualBlock(CostFunction* cost_function, - LossFunction* loss_function, - double* x0, double* x1, double* x2, - double* x3, double* x4, double* x5, - double* x6, double* x7, double* x8, - double* x9); + double* x0, + Ts*... xs) { + const std::array<double*, sizeof...(Ts) + 1> parameter_blocks{{x0, xs...}}; + return AddResidualBlock(cost_function, + loss_function, + parameter_blocks.data(), + static_cast<int>(parameter_blocks.size())); + } + void AddParameterBlock(double* values, int size); void AddParameterBlock(double* values, int size, @@ -194,7 +167,7 @@ // Delete the arguments in question. These differ from the Remove* functions // in that they do not clean up references to the block to delete; they // merely delete them. - template<typename Block> + template <typename Block> void DeleteBlockInVector(std::vector<Block*>* mutable_blocks, Block* block_to_remove); void DeleteBlock(ResidualBlock* residual_block); @@ -228,7 +201,6 @@ // destroyed. CostFunctionRefCount cost_function_ref_count_; LossFunctionRefCount loss_function_ref_count_; - std::vector<double*> residual_parameters_; }; } // namespace internal
diff --git a/internal/ceres/problem_test.cc b/internal/ceres/problem_test.cc index 937f84e..3f9f804 100644 --- a/internal/ceres/problem_test.cc +++ b/internal/ceres/problem_test.cc
@@ -142,7 +142,7 @@ // UnaryCostFunction takes only one parameter, but two are passed. EXPECT_DEATH_IF_SUPPORTED( problem.AddResidualBlock(new UnaryCostFunction(2, 3), NULL, x, y), - "parameter_blocks.size"); + "num_parameter_blocks"); } TEST(Problem, AddResidualWithDifferentSizesOnTheSameVariableDies) {
diff --git a/internal/ceres/program_test.cc b/internal/ceres/program_test.cc index 6cb316e..6cb8e9e 100644 --- a/internal/ceres/program_test.cc +++ b/internal/ceres/program_test.cc
@@ -352,7 +352,8 @@ problem.AddResidualBlock(new NumParameterBlocksCostFunction<1, 20>(), nullptr, - parameter_blocks); + parameter_blocks.data(), + static_cast<int>(parameter_blocks.size())); TripletSparseMatrix expected_block_sparse_jacobian(20, 1, 20); {