Add more inspection methods to Problem.
Problem::GetCostFunctionForResidualBlock
Problem::GetLossFunctionForResidualBlock
are added, so that users do not have to maintain this mapping
outside the Problem.
Change-Id: I38356dfa094b2c7eec90651dafeaf3a33c5f5f56
diff --git a/include/ceres/problem.h b/include/ceres/problem.h
index b1cb99a..f75ede3 100644
--- a/include/ceres/problem.h
+++ b/include/ceres/problem.h
@@ -368,6 +368,15 @@
const ResidualBlockId residual_block,
vector<double*>* parameter_blocks) const;
+ // Get the CostFunction for the given residual block.
+ const CostFunction* GetCostFunctionForResidualBlock(
+ const ResidualBlockId residual_block) const;
+
+ // Get the LossFunction for the given residual block. Returns NULL
+ // if no loss function is associated with this residual block.
+ const LossFunction* GetLossFunctionForResidualBlock(
+ const ResidualBlockId residual_block) const;
+
// Get all the residual blocks that depend on the given parameter block.
//
// If Problem::Options::enable_fast_removal is true, then
diff --git a/internal/ceres/problem.cc b/internal/ceres/problem.cc
index 674694d..bbfaa98 100644
--- a/internal/ceres/problem.cc
+++ b/internal/ceres/problem.cc
@@ -251,6 +251,16 @@
parameter_blocks);
}
+const CostFunction* Problem::GetCostFunctionForResidualBlock(
+ const ResidualBlockId residual_block) const {
+ return problem_impl_->GetCostFunctionForResidualBlock(residual_block);
+}
+
+const LossFunction* Problem::GetLossFunctionForResidualBlock(
+ const ResidualBlockId residual_block) const {
+ return problem_impl_->GetLossFunctionForResidualBlock(residual_block);
+}
+
void Problem::GetResidualBlocksForParameterBlock(
const double* values,
vector<ResidualBlockId>* residual_blocks) const {
diff --git a/internal/ceres/problem_impl.cc b/internal/ceres/problem_impl.cc
index 7c86efb..67cac94 100644
--- a/internal/ceres/problem_impl.cc
+++ b/internal/ceres/problem_impl.cc
@@ -823,6 +823,16 @@
}
}
+const CostFunction* ProblemImpl::GetCostFunctionForResidualBlock(
+ const ResidualBlockId residual_block) const {
+ return residual_block->cost_function();
+}
+
+const LossFunction* ProblemImpl::GetLossFunctionForResidualBlock(
+ const ResidualBlockId residual_block) const {
+ return residual_block->loss_function();
+}
+
void ProblemImpl::GetResidualBlocksForParameterBlock(
const double* values,
vector<ResidualBlockId>* residual_blocks) const {
diff --git a/internal/ceres/problem_impl.h b/internal/ceres/problem_impl.h
index 7b5547b..3d84de8 100644
--- a/internal/ceres/problem_impl.h
+++ b/internal/ceres/problem_impl.h
@@ -157,6 +157,11 @@
const ResidualBlockId residual_block,
vector<double*>* parameter_blocks) const;
+ const CostFunction* GetCostFunctionForResidualBlock(
+ const ResidualBlockId residual_block) const;
+ const LossFunction* GetLossFunctionForResidualBlock(
+ const ResidualBlockId residual_block) const;
+
void GetResidualBlocksForParameterBlock(
const double* values,
vector<ResidualBlockId>* residual_blocks) const;
diff --git a/internal/ceres/problem_test.cc b/internal/ceres/problem_test.cc
index db082ec..36e4996 100644
--- a/internal/ceres/problem_test.cc
+++ b/internal/ceres/problem_test.cc
@@ -39,6 +39,7 @@
#include "ceres/internal/eigen.h"
#include "ceres/internal/scoped_ptr.h"
#include "ceres/local_parameterization.h"
+#include "ceres/loss_function.h"
#include "ceres/map_util.h"
#include "ceres/parameter_block.h"
#include "ceres/program.h"
@@ -342,6 +343,30 @@
CHECK_EQ(num_destructions, 1);
}
+TEST(Problem, GetCostFunctionForResidualBlock) {
+ double x[3];
+ Problem problem;
+ CostFunction* cost_function = new UnaryCostFunction(2, 3);
+ const ResidualBlockId residual_block =
+ problem.AddResidualBlock(cost_function, NULL, x);
+ EXPECT_EQ(problem.GetCostFunctionForResidualBlock(residual_block),
+ cost_function);
+ EXPECT_TRUE(problem.GetLossFunctionForResidualBlock(residual_block) == NULL);
+}
+
+TEST(Problem, GetLossFunctionForResidualBlock) {
+ double x[3];
+ Problem problem;
+ CostFunction* cost_function = new UnaryCostFunction(2, 3);
+ LossFunction* loss_function = new TrivialLoss();
+ const ResidualBlockId residual_block =
+ problem.AddResidualBlock(cost_function, loss_function, x);
+ EXPECT_EQ(problem.GetCostFunctionForResidualBlock(residual_block),
+ cost_function);
+ EXPECT_EQ(problem.GetLossFunctionForResidualBlock(residual_block),
+ loss_function);
+}
+
TEST(Problem, CostFunctionsAreDeletedEvenWithRemovals) {
double y[4], z[5], w[4];
int num_destructions = 0;