Add the ability to query the Problem about parameter blocks.
Change-Id: Ieda1aefa28e7a1d18fe6c8d1665882e4d9c274f2
diff --git a/include/ceres/problem.h b/include/ceres/problem.h
index 0a449cb..707a8eb 100644
--- a/include/ceres/problem.h
+++ b/include/ceres/problem.h
@@ -328,6 +328,19 @@
// sizes of all of the residual blocks.
int NumResiduals() const;
+ // The size of the parameter block.
+ int ParameterBlockSize(double* values) const;
+
+ // The size of local parameterization for the parameter block. If
+ // there is no local parameterization associated with this parameter
+ // block, then ParmeterBlockLocalSize = ParameterBlockSize.
+ int ParameterBlockLocalSize(double* values) const;
+
+ // Fills the passed parameter_blocks vector with pointers to the
+ // parameter blocks currently in the problem. After this call,
+ // parameter_block.size() == NumParameterBlocks.
+ void GetParameterBlocks(vector<double*>* parameter_blocks) const;
+
// Options struct to control Problem::Evaluate.
struct EvaluateOptions {
EvaluateOptions()
diff --git a/internal/ceres/problem.cc b/internal/ceres/problem.cc
index 43e7883..b483932 100644
--- a/internal/ceres/problem.cc
+++ b/internal/ceres/problem.cc
@@ -206,4 +206,16 @@
return problem_impl_->NumResiduals();
}
+int Problem::ParameterBlockSize(double* parameter_block) const {
+ return problem_impl_->ParameterBlockSize(parameter_block);
+};
+
+int Problem::ParameterBlockLocalSize(double* parameter_block) const {
+ return problem_impl_->ParameterBlockLocalSize(parameter_block);
+};
+
+void Problem::GetParameterBlocks(vector<double*>* parameter_blocks) const {
+ problem_impl_->GetParameterBlocks(parameter_blocks);
+}
+
} // namespace ceres
diff --git a/internal/ceres/problem_impl.cc b/internal/ceres/problem_impl.cc
index f4615f9..34c3785 100644
--- a/internal/ceres/problem_impl.cc
+++ b/internal/ceres/problem_impl.cc
@@ -711,5 +711,25 @@
return program_->NumResiduals();
}
+int ProblemImpl::ParameterBlockSize(double* parameter_block) const {
+ return FindParameterBlockOrDie(parameter_block_map_, parameter_block)->Size();
+};
+
+int ProblemImpl::ParameterBlockLocalSize(double* parameter_block) const {
+ return FindParameterBlockOrDie(parameter_block_map_,
+ parameter_block)->LocalSize();
+};
+
+void ProblemImpl::GetParameterBlocks(vector<double*>* parameter_blocks) const {
+ CHECK_NOTNULL(parameter_blocks);
+ parameter_blocks->resize(0);
+ for (ParameterMap::const_iterator it = parameter_block_map_.begin();
+ it != parameter_block_map_.end();
+ ++it) {
+ parameter_blocks->push_back(it->first);
+ }
+}
+
+
} // namespace internal
} // namespace ceres
diff --git a/internal/ceres/problem_impl.h b/internal/ceres/problem_impl.h
index ccc315d..2609389 100644
--- a/internal/ceres/problem_impl.h
+++ b/internal/ceres/problem_impl.h
@@ -139,6 +139,10 @@
int NumResidualBlocks() const;
int NumResiduals() const;
+ int ParameterBlockSize(double* parameter_block) const;
+ int ParameterBlockLocalSize(double* parameter_block) const;
+ void GetParameterBlocks(vector<double*>* parameter_blocks) const;
+
const Program& program() const { return *program_; }
Program* mutable_program() { return program_.get(); }
diff --git a/internal/ceres/problem_test.cc b/internal/ceres/problem_test.cc
index ab40e05..0944d3f 100644
--- a/internal/ceres/problem_test.cc
+++ b/internal/ceres/problem_test.cc
@@ -502,6 +502,35 @@
problem.RemoveParameterBlock(y), "Parameter block not found:");
}
+TEST(Problem, ParameterBlockQueryTest) {
+ double x[3];
+ double y[4];
+ Problem problem;
+ problem.AddParameterBlock(x, 3);
+ problem.AddParameterBlock(y, 4);
+
+ vector<int> constant_parameters;
+ constant_parameters.push_back(0);
+ problem.SetParameterization(
+ x,
+ new SubsetParameterization(3, constant_parameters));
+ EXPECT_EQ(problem.ParameterBlockSize(x), 3);
+ EXPECT_EQ(problem.ParameterBlockLocalSize(x), 2);
+ EXPECT_EQ(problem.ParameterBlockLocalSize(y), 4);
+
+ vector<double*> parameter_blocks;
+ problem.GetParameterBlocks(¶meter_blocks);
+ EXPECT_EQ(parameter_blocks.size(), 2);
+ EXPECT_NE(parameter_blocks[0], parameter_blocks[1]);
+ EXPECT_TRUE(parameter_blocks[0] == x || parameter_blocks[0] == y);
+ EXPECT_TRUE(parameter_blocks[1] == x || parameter_blocks[1] == y);
+
+ problem.RemoveParameterBlock(x);
+ problem.GetParameterBlocks(¶meter_blocks);
+ EXPECT_EQ(parameter_blocks.size(), 1);
+ EXPECT_TRUE(parameter_blocks[0] == y);
+}
+
TEST_P(DynamicProblem, RemoveParameterBlockWithNoResiduals) {
problem->AddParameterBlock(y, 4);
problem->AddParameterBlock(z, 5);