Add more inspection methods to Problem.

Problem::GetCostFunctionForResidualBlock
Problem::GetLossFunctionForResidualBlock

are added, so that users do not have to maintain this mapping
outside the Problem.

Change-Id: I38356dfa094b2c7eec90651dafeaf3a33c5f5f56
diff --git a/include/ceres/problem.h b/include/ceres/problem.h
index b1cb99a..f75ede3 100644
--- a/include/ceres/problem.h
+++ b/include/ceres/problem.h
@@ -368,6 +368,15 @@
       const ResidualBlockId residual_block,
       vector<double*>* parameter_blocks) const;
 
+  // Get the CostFunction for the given residual block.
+  const CostFunction* GetCostFunctionForResidualBlock(
+      const ResidualBlockId residual_block) const;
+
+  // Get the LossFunction for the given residual block. Returns NULL
+  // if no loss function is associated with this residual block.
+  const LossFunction* GetLossFunctionForResidualBlock(
+      const ResidualBlockId residual_block) const;
+
   // Get all the residual blocks that depend on the given parameter block.
   //
   // If Problem::Options::enable_fast_removal is true, then
diff --git a/internal/ceres/problem.cc b/internal/ceres/problem.cc
index 674694d..bbfaa98 100644
--- a/internal/ceres/problem.cc
+++ b/internal/ceres/problem.cc
@@ -251,6 +251,16 @@
                                                     parameter_blocks);
 }
 
+const CostFunction* Problem::GetCostFunctionForResidualBlock(
+    const ResidualBlockId residual_block) const {
+  return problem_impl_->GetCostFunctionForResidualBlock(residual_block);
+}
+
+const LossFunction* Problem::GetLossFunctionForResidualBlock(
+    const ResidualBlockId residual_block) const {
+  return problem_impl_->GetLossFunctionForResidualBlock(residual_block);
+}
+
 void Problem::GetResidualBlocksForParameterBlock(
     const double* values,
     vector<ResidualBlockId>* residual_blocks) const {
diff --git a/internal/ceres/problem_impl.cc b/internal/ceres/problem_impl.cc
index 7c86efb..67cac94 100644
--- a/internal/ceres/problem_impl.cc
+++ b/internal/ceres/problem_impl.cc
@@ -823,6 +823,16 @@
   }
 }
 
+const CostFunction* ProblemImpl::GetCostFunctionForResidualBlock(
+    const ResidualBlockId residual_block) const {
+  return residual_block->cost_function();
+}
+
+const LossFunction* ProblemImpl::GetLossFunctionForResidualBlock(
+    const ResidualBlockId residual_block) const {
+  return residual_block->loss_function();
+}
+
 void ProblemImpl::GetResidualBlocksForParameterBlock(
     const double* values,
     vector<ResidualBlockId>* residual_blocks) const {
diff --git a/internal/ceres/problem_impl.h b/internal/ceres/problem_impl.h
index 7b5547b..3d84de8 100644
--- a/internal/ceres/problem_impl.h
+++ b/internal/ceres/problem_impl.h
@@ -157,6 +157,11 @@
       const ResidualBlockId residual_block,
       vector<double*>* parameter_blocks) const;
 
+  const CostFunction* GetCostFunctionForResidualBlock(
+      const ResidualBlockId residual_block) const;
+  const LossFunction* GetLossFunctionForResidualBlock(
+      const ResidualBlockId residual_block) const;
+
   void GetResidualBlocksForParameterBlock(
       const double* values,
       vector<ResidualBlockId>* residual_blocks) const;
diff --git a/internal/ceres/problem_test.cc b/internal/ceres/problem_test.cc
index db082ec..36e4996 100644
--- a/internal/ceres/problem_test.cc
+++ b/internal/ceres/problem_test.cc
@@ -39,6 +39,7 @@
 #include "ceres/internal/eigen.h"
 #include "ceres/internal/scoped_ptr.h"
 #include "ceres/local_parameterization.h"
+#include "ceres/loss_function.h"
 #include "ceres/map_util.h"
 #include "ceres/parameter_block.h"
 #include "ceres/program.h"
@@ -342,6 +343,30 @@
   CHECK_EQ(num_destructions, 1);
 }
 
+TEST(Problem, GetCostFunctionForResidualBlock) {
+  double x[3];
+  Problem problem;
+  CostFunction* cost_function = new UnaryCostFunction(2, 3);
+  const ResidualBlockId residual_block =
+      problem.AddResidualBlock(cost_function, NULL, x);
+  EXPECT_EQ(problem.GetCostFunctionForResidualBlock(residual_block),
+            cost_function);
+  EXPECT_TRUE(problem.GetLossFunctionForResidualBlock(residual_block) == NULL);
+}
+
+TEST(Problem, GetLossFunctionForResidualBlock) {
+  double x[3];
+  Problem problem;
+  CostFunction* cost_function = new UnaryCostFunction(2, 3);
+  LossFunction* loss_function = new TrivialLoss();
+  const ResidualBlockId residual_block =
+      problem.AddResidualBlock(cost_function, loss_function, x);
+  EXPECT_EQ(problem.GetCostFunctionForResidualBlock(residual_block),
+            cost_function);
+  EXPECT_EQ(problem.GetLossFunctionForResidualBlock(residual_block),
+            loss_function);
+}
+
 TEST(Problem, CostFunctionsAreDeletedEvenWithRemovals) {
   double y[4], z[5], w[4];
   int num_destructions = 0;