Add Problem::HasParameterization This will help the transition from LocalParameterization to Manifolds, since most uses of Problem::GetParameterization is to just check whether a parameter block has a local parameterization associated with it or not. Change-Id: Ib3539f377eaed853d7542c9844ec1487aa0fb4d6
diff --git a/include/ceres/problem.h b/include/ceres/problem.h index add12ea..734a968 100644 --- a/include/ceres/problem.h +++ b/include/ceres/problem.h
@@ -329,6 +329,10 @@ // associated then nullptr is returned. const LocalParameterization* GetParameterization(const double* values) const; + // Returns true if a parameterization is associated with this parameter block, + // false otherwise. + bool HasParameterization(const double* values) const; + // Set the lower/upper bound for the parameter at position "index". void SetParameterLowerBound(double* values, int index, double lower_bound); void SetParameterUpperBound(double* values, int index, double upper_bound);
diff --git a/internal/ceres/problem.cc b/internal/ceres/problem.cc index f3ffd54..de4a2b3 100644 --- a/internal/ceres/problem.cc +++ b/internal/ceres/problem.cc
@@ -106,6 +106,10 @@ return impl_->GetParameterization(values); } +bool Problem::HasParameterization(const double* values) const { + return impl_->HasParameterization(values); +} + void Problem::SetParameterLowerBound(double* values, int index, double lower_bound) {
diff --git a/internal/ceres/problem_impl.cc b/internal/ceres/problem_impl.cc index 3155bc3..492cabc 100644 --- a/internal/ceres/problem_impl.cc +++ b/internal/ceres/problem_impl.cc
@@ -543,6 +543,18 @@ return parameter_block->local_parameterization(); } +bool ProblemImpl::HasParameterization(const double* values) const { + ParameterBlock* parameter_block = FindWithDefault( + parameter_block_map_, const_cast<double*>(values), nullptr); + if (parameter_block == nullptr) { + LOG(FATAL) << "Parameter block not found: " << values + << ". You must add the parameter block to the problem before " + << "you can get its local parameterization."; + } + + return (parameter_block->local_parameterization() != nullptr); +} + void ProblemImpl::SetParameterLowerBound(double* values, int index, double lower_bound) {
diff --git a/internal/ceres/problem_impl.h b/internal/ceres/problem_impl.h index 9abff3f..5051c57 100644 --- a/internal/ceres/problem_impl.h +++ b/internal/ceres/problem_impl.h
@@ -110,6 +110,7 @@ void SetParameterization(double* values, LocalParameterization* local_parameterization); const LocalParameterization* GetParameterization(const double* values) const; + bool HasParameterization(const double* values) const; void SetParameterLowerBound(double* values, int index, double lower_bound); void SetParameterUpperBound(double* values, int index, double upper_bound);
diff --git a/internal/ceres/problem_test.cc b/internal/ceres/problem_test.cc index 9f129df..67faa7e 100644 --- a/internal/ceres/problem_test.cc +++ b/internal/ceres/problem_test.cc
@@ -612,6 +612,20 @@ EXPECT_TRUE(problem.GetParameterization(y) == NULL); } +TEST(Problem, HasParameterization) { + double x[3]; + double y[2]; + + Problem problem; + problem.AddParameterBlock(x, 3); + problem.AddParameterBlock(y, 2); + + LocalParameterization* parameterization = new IdentityParameterization(3); + problem.SetParameterization(x, parameterization); + EXPECT_TRUE(problem.HasParameterization(x)); + EXPECT_FALSE(problem.HasParameterization(y)); +} + TEST(Problem, ParameterBlockQueryTest) { double x[3]; double y[4];