Add the ability to query the Problem about parameter blocks. Change-Id: Ieda1aefa28e7a1d18fe6c8d1665882e4d9c274f2
diff --git a/include/ceres/problem.h b/include/ceres/problem.h index 0a449cb..707a8eb 100644 --- a/include/ceres/problem.h +++ b/include/ceres/problem.h
@@ -328,6 +328,19 @@ // sizes of all of the residual blocks. int NumResiduals() const; + // The size of the parameter block. + int ParameterBlockSize(double* values) const; + + // The size of local parameterization for the parameter block. If + // there is no local parameterization associated with this parameter + // block, then ParmeterBlockLocalSize = ParameterBlockSize. + int ParameterBlockLocalSize(double* values) const; + + // Fills the passed parameter_blocks vector with pointers to the + // parameter blocks currently in the problem. After this call, + // parameter_block.size() == NumParameterBlocks. + void GetParameterBlocks(vector<double*>* parameter_blocks) const; + // Options struct to control Problem::Evaluate. struct EvaluateOptions { EvaluateOptions()
diff --git a/internal/ceres/problem.cc b/internal/ceres/problem.cc index 43e7883..b483932 100644 --- a/internal/ceres/problem.cc +++ b/internal/ceres/problem.cc
@@ -206,4 +206,16 @@ return problem_impl_->NumResiduals(); } +int Problem::ParameterBlockSize(double* parameter_block) const { + return problem_impl_->ParameterBlockSize(parameter_block); +}; + +int Problem::ParameterBlockLocalSize(double* parameter_block) const { + return problem_impl_->ParameterBlockLocalSize(parameter_block); +}; + +void Problem::GetParameterBlocks(vector<double*>* parameter_blocks) const { + problem_impl_->GetParameterBlocks(parameter_blocks); +} + } // namespace ceres
diff --git a/internal/ceres/problem_impl.cc b/internal/ceres/problem_impl.cc index f4615f9..34c3785 100644 --- a/internal/ceres/problem_impl.cc +++ b/internal/ceres/problem_impl.cc
@@ -711,5 +711,25 @@ return program_->NumResiduals(); } +int ProblemImpl::ParameterBlockSize(double* parameter_block) const { + return FindParameterBlockOrDie(parameter_block_map_, parameter_block)->Size(); +}; + +int ProblemImpl::ParameterBlockLocalSize(double* parameter_block) const { + return FindParameterBlockOrDie(parameter_block_map_, + parameter_block)->LocalSize(); +}; + +void ProblemImpl::GetParameterBlocks(vector<double*>* parameter_blocks) const { + CHECK_NOTNULL(parameter_blocks); + parameter_blocks->resize(0); + for (ParameterMap::const_iterator it = parameter_block_map_.begin(); + it != parameter_block_map_.end(); + ++it) { + parameter_blocks->push_back(it->first); + } +} + + } // namespace internal } // namespace ceres
diff --git a/internal/ceres/problem_impl.h b/internal/ceres/problem_impl.h index ccc315d..2609389 100644 --- a/internal/ceres/problem_impl.h +++ b/internal/ceres/problem_impl.h
@@ -139,6 +139,10 @@ int NumResidualBlocks() const; int NumResiduals() const; + int ParameterBlockSize(double* parameter_block) const; + int ParameterBlockLocalSize(double* parameter_block) const; + void GetParameterBlocks(vector<double*>* parameter_blocks) const; + const Program& program() const { return *program_; } Program* mutable_program() { return program_.get(); }
diff --git a/internal/ceres/problem_test.cc b/internal/ceres/problem_test.cc index ab40e05..0944d3f 100644 --- a/internal/ceres/problem_test.cc +++ b/internal/ceres/problem_test.cc
@@ -502,6 +502,35 @@ problem.RemoveParameterBlock(y), "Parameter block not found:"); } +TEST(Problem, ParameterBlockQueryTest) { + double x[3]; + double y[4]; + Problem problem; + problem.AddParameterBlock(x, 3); + problem.AddParameterBlock(y, 4); + + vector<int> constant_parameters; + constant_parameters.push_back(0); + problem.SetParameterization( + x, + new SubsetParameterization(3, constant_parameters)); + EXPECT_EQ(problem.ParameterBlockSize(x), 3); + EXPECT_EQ(problem.ParameterBlockLocalSize(x), 2); + EXPECT_EQ(problem.ParameterBlockLocalSize(y), 4); + + vector<double*> parameter_blocks; + problem.GetParameterBlocks(¶meter_blocks); + EXPECT_EQ(parameter_blocks.size(), 2); + EXPECT_NE(parameter_blocks[0], parameter_blocks[1]); + EXPECT_TRUE(parameter_blocks[0] == x || parameter_blocks[0] == y); + EXPECT_TRUE(parameter_blocks[1] == x || parameter_blocks[1] == y); + + problem.RemoveParameterBlock(x); + problem.GetParameterBlocks(¶meter_blocks); + EXPECT_EQ(parameter_blocks.size(), 1); + EXPECT_TRUE(parameter_blocks[0] == y); +} + TEST_P(DynamicProblem, RemoveParameterBlockWithNoResiduals) { problem->AddParameterBlock(y, 4); problem->AddParameterBlock(z, 5);