Modernize ProductParameterization. This CL modernizes ProductParameterization in the following ways: - It uses std::unique_ptr for memory handling instead of using raw pointers and handmade memory management. - Replaces the constructors with a variadic template. Change-Id: I5c9fe42ac935b6c26e867dbd3369a4c766623047
diff --git a/include/ceres/local_parameterization.h b/include/ceres/local_parameterization.h index 5eed035..338ab54 100644 --- a/include/ceres/local_parameterization.h +++ b/include/ceres/local_parameterization.h
@@ -32,9 +32,11 @@ #ifndef CERES_PUBLIC_LOCAL_PARAMETERIZATION_H_ #define CERES_PUBLIC_LOCAL_PARAMETERIZATION_H_ +#include <array> +#include <memory> #include <vector> -#include "ceres/internal/port.h" #include "ceres/internal/disable_warnings.h" +#include "ceres/internal/port.h" namespace ceres { @@ -269,9 +271,6 @@ // manifolds. For example the parameters of a camera consist of a // rotation and a translation, i.e., SO(3) x R^3. // -// Currently this class supports taking the cartesian product of up to -// four local parameterizations. -// // Example usage: // // ProductParameterization product_param(new QuaterionionParameterization(), @@ -282,22 +281,37 @@ class CERES_EXPORT ProductParameterization : public LocalParameterization { public: // - // NOTE: All the constructors take ownership of the input local + // NOTE: The constructor takes ownership of the input local // parameterizations. // - ProductParameterization(LocalParameterization* local_param1, - LocalParameterization* local_param2); + template <typename... LocalParams> + ProductParameterization(LocalParams*... local_params) + : local_params_(sizeof...(LocalParams)), + local_size_{0}, + global_size_{0}, + buffer_size_{0} { + constexpr int kNumLocalParams = sizeof...(LocalParams); + static_assert(kNumLocalParams >= 2, + "At least two local parameterizations must be specified."); - ProductParameterization(LocalParameterization* local_param1, - LocalParameterization* local_param2, - LocalParameterization* local_param3); + using LocalParameterizationPtr = std::unique_ptr<LocalParameterization>; - ProductParameterization(LocalParameterization* local_param1, - LocalParameterization* local_param2, - LocalParameterization* local_param3, - LocalParameterization* local_param4); + // Wrap all raw pointers into std::unique_ptr for exception safety. + std::array<LocalParameterizationPtr, kNumLocalParams> local_params_array{ + LocalParameterizationPtr(local_params)...}; - virtual ~ProductParameterization(); + // Initialize internal state. + for (int i = 0; i < kNumLocalParams; ++i) { + LocalParameterizationPtr& param = local_params_[i]; + param = std::move(local_params_array[i]); + + buffer_size_ = + std::max(buffer_size_, param->LocalSize() * param->GlobalSize()); + global_size_ += param->GlobalSize(); + local_size_ += param->LocalSize(); + } + } + virtual bool Plus(const double* x, const double* delta, double* x_plus_delta) const; @@ -307,9 +321,7 @@ virtual int LocalSize() const { return local_size_; } private: - void Init(); - - std::vector<LocalParameterization*> local_params_; + std::vector<std::unique_ptr<LocalParameterization>> local_params_; int local_size_; int global_size_; int buffer_size_;
diff --git a/internal/ceres/local_parameterization.cc b/internal/ceres/local_parameterization.cc index 02ed4c9..4d63594 100644 --- a/internal/ceres/local_parameterization.cc +++ b/internal/ceres/local_parameterization.cc
@@ -287,62 +287,12 @@ return true; } -ProductParameterization::ProductParameterization( - LocalParameterization* local_param1, - LocalParameterization* local_param2) { - local_params_.push_back(local_param1); - local_params_.push_back(local_param2); - Init(); -} - -ProductParameterization::ProductParameterization( - LocalParameterization* local_param1, - LocalParameterization* local_param2, - LocalParameterization* local_param3) { - local_params_.push_back(local_param1); - local_params_.push_back(local_param2); - local_params_.push_back(local_param3); - Init(); -} - -ProductParameterization::ProductParameterization( - LocalParameterization* local_param1, - LocalParameterization* local_param2, - LocalParameterization* local_param3, - LocalParameterization* local_param4) { - local_params_.push_back(local_param1); - local_params_.push_back(local_param2); - local_params_.push_back(local_param3); - local_params_.push_back(local_param4); - Init(); -} - -ProductParameterization::~ProductParameterization() { - for (int i = 0; i < local_params_.size(); ++i) { - delete local_params_[i]; - } -} - -void ProductParameterization::Init() { - global_size_ = 0; - local_size_ = 0; - buffer_size_ = 0; - for (int i = 0; i < local_params_.size(); ++i) { - const LocalParameterization* param = local_params_[i]; - buffer_size_ = std::max(buffer_size_, - param->LocalSize() * param->GlobalSize()); - global_size_ += param->GlobalSize(); - local_size_ += param->LocalSize(); - } -} - bool ProductParameterization::Plus(const double* x, const double* delta, double* x_plus_delta) const { int x_cursor = 0; int delta_cursor = 0; - for (int i = 0; i < local_params_.size(); ++i) { - const LocalParameterization* param = local_params_[i]; + for (const auto& param : local_params_) { if (!param->Plus(x + x_cursor, delta + delta_cursor, x_plus_delta + x_cursor)) { @@ -363,8 +313,7 @@ int x_cursor = 0; int delta_cursor = 0; - for (int i = 0; i < local_params_.size(); ++i) { - const LocalParameterization* param = local_params_[i]; + for (const auto& param : local_params_) { const int local_size = param->LocalSize(); const int global_size = param->GlobalSize();