Export the structure of a problem to the public API

This adds three new public methods to ceres::Problem:

  Problem::GetResidualBlocks()
  Problem::GetParameterBlocksForResidualBlock()
  Problem::GetResidualBlocksForParameterBlock()

These permit access to the underlying graph structure of the problem.

Change-Id: I55a4c7f0e5f325f140cb4830e7a7070554594650
diff --git a/include/ceres/problem.h b/include/ceres/problem.h
index 663616d..cd433f9 100644
--- a/include/ceres/problem.h
+++ b/include/ceres/problem.h
@@ -1,5 +1,5 @@
 // Ceres Solver - A fast non-linear least squares minimizer
-// Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
+// Copyright 2013 Google Inc. All rights reserved.
 // http://code.google.com/p/ceres-solver/
 //
 // Redistribution and use in source and binary forms, with or without
@@ -341,6 +341,26 @@
   // parameter_block.size() == NumParameterBlocks.
   void GetParameterBlocks(vector<double*>* parameter_blocks) const;
 
+  // Fills the passed residual_blocks vector with pointers to the
+  // residual blocks currently in the problem. After this call,
+  // residual_blocks.size() == NumResidualBlocks.
+  void GetResidualBlocks(vector<ResidualBlockId>* residual_blocks) const;
+
+  // Get all the parameter blocks that depend on the given residual block.
+  void GetParameterBlocksForResidualBlock(
+      const ResidualBlockId residual_block,
+      vector<double*>* parameter_blocks) const;
+
+  // Get all the residual blocks that depend on the given parameter block.
+  //
+  // If Problem::Options::enable_fast_parameter_block_removal is true, then
+  // getting the residual blocks is fast and depends only on the number of
+  // residual blocks. Otherwise, getting the residual blocks for a parameter
+  // block will incur a scan of the entire Problem object.
+  void GetResidualBlocksForParameterBlock(
+      const double* values,
+      vector<ResidualBlockId>* residual_blocks) const;
+
   // Options struct to control Problem::Evaluate.
   struct EvaluateOptions {
     EvaluateOptions()
diff --git a/internal/ceres/problem.cc b/internal/ceres/problem.cc
index 403e96a..89821b9 100644
--- a/internal/ceres/problem.cc
+++ b/internal/ceres/problem.cc
@@ -1,5 +1,5 @@
 // Ceres Solver - A fast non-linear least squares minimizer
-// Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
+// Copyright 2013 Google Inc. All rights reserved.
 // http://code.google.com/p/ceres-solver/
 //
 // Redistribution and use in source and binary forms, with or without
@@ -218,4 +218,23 @@
   problem_impl_->GetParameterBlocks(parameter_blocks);
 }
 
+void Problem::GetResidualBlocks(
+    vector<ResidualBlockId>* residual_blocks) const {
+  problem_impl_->GetResidualBlocks(residual_blocks);
+}
+
+void Problem::GetParameterBlocksForResidualBlock(
+    const ResidualBlockId residual_block,
+    vector<double*>* parameter_blocks) const {
+  problem_impl_->GetParameterBlocksForResidualBlock(residual_block,
+                                                    parameter_blocks);
+}
+
+void Problem::GetResidualBlocksForParameterBlock(
+    const double* values,
+    vector<ResidualBlockId>* residual_blocks) const {
+  problem_impl_->GetResidualBlocksForParameterBlock(values,
+                                                    residual_blocks);
+}
+
 }  // namespace ceres
diff --git a/internal/ceres/problem_impl.cc b/internal/ceres/problem_impl.cc
index 0846de8..4197d59 100644
--- a/internal/ceres/problem_impl.cc
+++ b/internal/ceres/problem_impl.cc
@@ -1,5 +1,5 @@
 // Ceres Solver - A fast non-linear least squares minimizer
-// Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
+// Copyright 2013 Google Inc. All rights reserved.
 // http://code.google.com/p/ceres-solver/
 //
 // Redistribution and use in source and binary forms, with or without
@@ -27,7 +27,7 @@
 // POSSIBILITY OF SUCH DAMAGE.
 //
 // Author: sameeragarwal@google.com (Sameer Agarwal)
-//         keir@google.com (Keir Mierle)
+//         mierle@gmail.com (Keir Mierle)
 
 #include "ceres/problem_impl.h"
 
@@ -735,6 +735,56 @@
   }
 }
 
+void ProblemImpl::GetResidualBlocks(
+    vector<ResidualBlockId>* residual_blocks) const {
+  CHECK_NOTNULL(residual_blocks);
+  *residual_blocks = program().residual_blocks();
+}
+
+void ProblemImpl::GetParameterBlocksForResidualBlock(
+    const ResidualBlockId residual_block,
+    vector<double*>* parameter_blocks) const {
+  int num_parameter_blocks = residual_block->NumParameterBlocks();
+  CHECK_NOTNULL(parameter_blocks)->resize(num_parameter_blocks);
+  for (int i = 0; i < num_parameter_blocks; ++i) {
+    (*parameter_blocks)[i] =
+        residual_block->parameter_blocks()[i]->mutable_user_state();
+  }
+}
+
+void ProblemImpl::GetResidualBlocksForParameterBlock(
+    const double* values,
+    vector<ResidualBlockId>* residual_blocks) const {
+  ParameterBlock* parameter_block =
+      FindParameterBlockOrDie(parameter_block_map_,
+                              const_cast<double*>(values));
+
+  if (options_.enable_fast_parameter_block_removal) {
+    // In this case the residual blocks that depend on the parameter block are
+    // stored in the parameter block already, so just copy them out.
+    CHECK_NOTNULL(residual_blocks)->resize(parameter_block->mutable_residual_blocks()->size());
+    std::copy(parameter_block->mutable_residual_blocks()->begin(),
+              parameter_block->mutable_residual_blocks()->end(),
+              residual_blocks->begin());
+    return;
+  }
+
+  // Find residual blocks that depend on the parameter block.
+  CHECK_NOTNULL(residual_blocks)->clear();
+  const int num_residual_blocks = NumResidualBlocks();
+  for (int i = 0; i < num_residual_blocks; ++i) {
+    ResidualBlock* residual_block =
+        (*(program_->mutable_residual_blocks()))[i];
+    const int num_parameter_blocks = residual_block->NumParameterBlocks();
+    for (int j = 0; j < num_parameter_blocks; ++j) {
+      if (residual_block->parameter_blocks()[j] == parameter_block) {
+        residual_blocks->push_back(residual_block);
+        // The parameter blocks are guaranteed unique.
+        break;
+      }
+    }
+  }
+}
 
 }  // namespace internal
 }  // namespace ceres
diff --git a/internal/ceres/problem_impl.h b/internal/ceres/problem_impl.h
index ace27f5..35c16cd 100644
--- a/internal/ceres/problem_impl.h
+++ b/internal/ceres/problem_impl.h
@@ -142,6 +142,15 @@
   int ParameterBlockSize(const double* parameter_block) const;
   int ParameterBlockLocalSize(const double* parameter_block) const;
   void GetParameterBlocks(vector<double*>* parameter_blocks) const;
+  void GetResidualBlocks(vector<ResidualBlockId>* residual_blocks) const;
+
+  void GetParameterBlocksForResidualBlock(
+      const ResidualBlockId residual_block,
+      vector<double*>* parameter_blocks) const;
+
+  void GetResidualBlocksForParameterBlock(
+      const double* values,
+      vector<ResidualBlockId>* residual_blocks) const;
 
   const Program& program() const { return *program_; }
   Program* mutable_program() { return program_.get(); }
diff --git a/internal/ceres/problem_test.cc b/internal/ceres/problem_test.cc
index 0944d3f..a7f4f0b 100644
--- a/internal/ceres/problem_test.cc
+++ b/internal/ceres/problem_test.cc
@@ -1,5 +1,5 @@
 // Ceres Solver - A fast non-linear least squares minimizer
-// Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
+// Copyright 2013 Google Inc. All rights reserved.
 // http://code.google.com/p/ceres-solver/
 //
 // Redistribution and use in source and binary forms, with or without
@@ -771,6 +771,141 @@
   }
 }
 
+// Check that a null-terminated array, a, has the same elements as b.
+template<typename T>
+void ExpectVectorContainsUnordered(const T* a, const vector<T>& b) {
+  // Compute the size of a.
+  int size = 0;
+  while (a[size]) {
+    ++size;
+  }
+  ASSERT_EQ(size, b.size());
+
+  // Sort a.
+  vector<T> a_sorted(size);
+  copy(a, a + size, a_sorted.begin());
+  sort(a_sorted.begin(), a_sorted.end());
+
+  // Sort b.
+  vector<T> b_sorted(b);
+  sort(b_sorted.begin(), b_sorted.end());
+
+  // Compare.
+  for (int i = 0; i < size; ++i) {
+    EXPECT_EQ(a_sorted[i], b_sorted[i]);
+  }
+}
+
+void ExpectProblemHasResidualBlocks(
+    const ProblemImpl &problem,
+    const ResidualBlockId *expected_residual_blocks) {
+  vector<ResidualBlockId> residual_blocks;
+  problem.GetResidualBlocks(&residual_blocks);
+  ExpectVectorContainsUnordered(expected_residual_blocks, residual_blocks);
+}
+
+TEST_P(DynamicProblem, GetXXXBlocksForYYYBlock) {
+  problem->AddParameterBlock(y, 4);
+  problem->AddParameterBlock(z, 5);
+  problem->AddParameterBlock(w, 3);
+
+  // Add all combinations of cost functions.
+  CostFunction* cost_yzw = new TernaryCostFunction(1, 4, 5, 3);
+  CostFunction* cost_yz  = new BinaryCostFunction (1, 4, 5);
+  CostFunction* cost_yw  = new BinaryCostFunction (1, 4, 3);
+  CostFunction* cost_zw  = new BinaryCostFunction (1, 5, 3);
+  CostFunction* cost_y   = new UnaryCostFunction  (1, 4);
+  CostFunction* cost_z   = new UnaryCostFunction  (1, 5);
+  CostFunction* cost_w   = new UnaryCostFunction  (1, 3);
+
+  ResidualBlock* r_yzw = problem->AddResidualBlock(cost_yzw, NULL, y, z, w);
+  {
+    ResidualBlockId expected_residuals[] = {r_yzw, 0};
+    ExpectProblemHasResidualBlocks(*problem, expected_residuals);
+  }
+  ResidualBlock* r_yz  = problem->AddResidualBlock(cost_yz,  NULL, y, z);
+  {
+    ResidualBlockId expected_residuals[] = {r_yzw, r_yz, 0};
+    ExpectProblemHasResidualBlocks(*problem, expected_residuals);
+  }
+  ResidualBlock* r_yw  = problem->AddResidualBlock(cost_yw,  NULL, y, w);
+  {
+    ResidualBlock *expected_residuals[] = {r_yzw, r_yz, r_yw, 0};
+    ExpectProblemHasResidualBlocks(*problem, expected_residuals);
+  }
+  ResidualBlock* r_zw  = problem->AddResidualBlock(cost_zw,  NULL, z, w);
+  {
+    ResidualBlock *expected_residuals[] = {r_yzw, r_yz, r_yw, r_zw, 0};
+    ExpectProblemHasResidualBlocks(*problem, expected_residuals);
+  }
+  ResidualBlock* r_y   = problem->AddResidualBlock(cost_y,   NULL, y);
+  {
+    ResidualBlock *expected_residuals[] = {r_yzw, r_yz, r_yw, r_zw, r_y, 0};
+    ExpectProblemHasResidualBlocks(*problem, expected_residuals);
+  }
+  ResidualBlock* r_z   = problem->AddResidualBlock(cost_z,   NULL, z);
+  {
+    ResidualBlock *expected_residuals[] = {
+      r_yzw, r_yz, r_yw, r_zw, r_y, r_z, 0
+    };
+    ExpectProblemHasResidualBlocks(*problem, expected_residuals);
+  }
+  ResidualBlock* r_w   = problem->AddResidualBlock(cost_w,   NULL, w);
+  {
+    ResidualBlock *expected_residuals[] = {
+      r_yzw, r_yz, r_yw, r_zw, r_y, r_z, r_w, 0
+    };
+    ExpectProblemHasResidualBlocks(*problem, expected_residuals);
+  }
+
+  vector<double*> parameter_blocks;
+  vector<ResidualBlockId> residual_blocks;
+
+  // Check GetResidualBlocksForParameterBlock() for all parameter blocks.
+  struct GetResidualBlocksForParameterBlockTestCase {
+    double* parameter_block;
+    ResidualBlockId expected_residual_blocks[10];
+  };
+  GetResidualBlocksForParameterBlockTestCase get_residual_blocks_cases[] = {
+    { y, { r_yzw, r_yz, r_yw, r_y, NULL} },
+    { z, { r_yzw, r_yz, r_zw, r_z, NULL} },
+    { w, { r_yzw, r_yw, r_zw, r_w, NULL} },
+    { NULL }
+  };
+  for (int i = 0; get_residual_blocks_cases[i].parameter_block; ++i) {
+    problem->GetResidualBlocksForParameterBlock(
+        get_residual_blocks_cases[i].parameter_block,
+        &residual_blocks);
+    ExpectVectorContainsUnordered(
+        get_residual_blocks_cases[i].expected_residual_blocks,
+        residual_blocks);
+  }
+
+  // Check GetParameterBlocksForResidualBlock() for all residual blocks.
+  struct GetParameterBlocksForResidualBlockTestCase {
+    ResidualBlockId residual_block;
+    double* expected_parameter_blocks[10];
+  };
+  GetParameterBlocksForResidualBlockTestCase get_parameter_blocks_cases[] = {
+    { r_yzw, { y, z, w, NULL } },
+    { r_yz , { y, z, NULL } },
+    { r_yw , { y, w, NULL } },
+    { r_zw , { z, w, NULL } },
+    { r_y  , { y, NULL } },
+    { r_z  , { z, NULL } },
+    { r_w  , { w, NULL } },
+    { NULL }
+  };
+  for (int i = 0; get_parameter_blocks_cases[i].residual_block; ++i) {
+    problem->GetParameterBlocksForResidualBlock(
+        get_parameter_blocks_cases[i].residual_block,
+        &parameter_blocks);
+    ExpectVectorContainsUnordered(
+        get_parameter_blocks_cases[i].expected_parameter_blocks,
+        parameter_blocks);
+  }
+}
+
 INSTANTIATE_TEST_CASE_P(OptionsInstantiation,
                         DynamicProblem,
                         ::testing::Values(true, false));