Add more inspection methods to Problem. Problem::GetCostFunctionForResidualBlock Problem::GetLossFunctionForResidualBlock are added, so that users do not have to maintain this mapping outside the Problem. Change-Id: I38356dfa094b2c7eec90651dafeaf3a33c5f5f56
diff --git a/include/ceres/problem.h b/include/ceres/problem.h index b1cb99a..f75ede3 100644 --- a/include/ceres/problem.h +++ b/include/ceres/problem.h
@@ -368,6 +368,15 @@ const ResidualBlockId residual_block, vector<double*>* parameter_blocks) const; + // Get the CostFunction for the given residual block. + const CostFunction* GetCostFunctionForResidualBlock( + const ResidualBlockId residual_block) const; + + // Get the LossFunction for the given residual block. Returns NULL + // if no loss function is associated with this residual block. + const LossFunction* GetLossFunctionForResidualBlock( + const ResidualBlockId residual_block) const; + // Get all the residual blocks that depend on the given parameter block. // // If Problem::Options::enable_fast_removal is true, then
diff --git a/internal/ceres/problem.cc b/internal/ceres/problem.cc index 674694d..bbfaa98 100644 --- a/internal/ceres/problem.cc +++ b/internal/ceres/problem.cc
@@ -251,6 +251,16 @@ parameter_blocks); } +const CostFunction* Problem::GetCostFunctionForResidualBlock( + const ResidualBlockId residual_block) const { + return problem_impl_->GetCostFunctionForResidualBlock(residual_block); +} + +const LossFunction* Problem::GetLossFunctionForResidualBlock( + const ResidualBlockId residual_block) const { + return problem_impl_->GetLossFunctionForResidualBlock(residual_block); +} + void Problem::GetResidualBlocksForParameterBlock( const double* values, vector<ResidualBlockId>* residual_blocks) const {
diff --git a/internal/ceres/problem_impl.cc b/internal/ceres/problem_impl.cc index 7c86efb..67cac94 100644 --- a/internal/ceres/problem_impl.cc +++ b/internal/ceres/problem_impl.cc
@@ -823,6 +823,16 @@ } } +const CostFunction* ProblemImpl::GetCostFunctionForResidualBlock( + const ResidualBlockId residual_block) const { + return residual_block->cost_function(); +} + +const LossFunction* ProblemImpl::GetLossFunctionForResidualBlock( + const ResidualBlockId residual_block) const { + return residual_block->loss_function(); +} + void ProblemImpl::GetResidualBlocksForParameterBlock( const double* values, vector<ResidualBlockId>* residual_blocks) const {
diff --git a/internal/ceres/problem_impl.h b/internal/ceres/problem_impl.h index 7b5547b..3d84de8 100644 --- a/internal/ceres/problem_impl.h +++ b/internal/ceres/problem_impl.h
@@ -157,6 +157,11 @@ const ResidualBlockId residual_block, vector<double*>* parameter_blocks) const; + const CostFunction* GetCostFunctionForResidualBlock( + const ResidualBlockId residual_block) const; + const LossFunction* GetLossFunctionForResidualBlock( + const ResidualBlockId residual_block) const; + void GetResidualBlocksForParameterBlock( const double* values, vector<ResidualBlockId>* residual_blocks) const;
diff --git a/internal/ceres/problem_test.cc b/internal/ceres/problem_test.cc index db082ec..36e4996 100644 --- a/internal/ceres/problem_test.cc +++ b/internal/ceres/problem_test.cc
@@ -39,6 +39,7 @@ #include "ceres/internal/eigen.h" #include "ceres/internal/scoped_ptr.h" #include "ceres/local_parameterization.h" +#include "ceres/loss_function.h" #include "ceres/map_util.h" #include "ceres/parameter_block.h" #include "ceres/program.h" @@ -342,6 +343,30 @@ CHECK_EQ(num_destructions, 1); } +TEST(Problem, GetCostFunctionForResidualBlock) { + double x[3]; + Problem problem; + CostFunction* cost_function = new UnaryCostFunction(2, 3); + const ResidualBlockId residual_block = + problem.AddResidualBlock(cost_function, NULL, x); + EXPECT_EQ(problem.GetCostFunctionForResidualBlock(residual_block), + cost_function); + EXPECT_TRUE(problem.GetLossFunctionForResidualBlock(residual_block) == NULL); +} + +TEST(Problem, GetLossFunctionForResidualBlock) { + double x[3]; + Problem problem; + CostFunction* cost_function = new UnaryCostFunction(2, 3); + LossFunction* loss_function = new TrivialLoss(); + const ResidualBlockId residual_block = + problem.AddResidualBlock(cost_function, loss_function, x); + EXPECT_EQ(problem.GetCostFunctionForResidualBlock(residual_block), + cost_function); + EXPECT_EQ(problem.GetLossFunctionForResidualBlock(residual_block), + loss_function); +} + TEST(Problem, CostFunctionsAreDeletedEvenWithRemovals) { double y[4], z[5], w[4]; int num_destructions = 0;