Export the structure of a problem to the public API This adds three new public methods to ceres::Problem: Problem::GetResidualBlocks() Problem::GetParameterBlocksForResidualBlock() Problem::GetResidualBlocksForParameterBlock() These permit access to the underlying graph structure of the problem. Change-Id: I55a4c7f0e5f325f140cb4830e7a7070554594650
diff --git a/include/ceres/problem.h b/include/ceres/problem.h index 663616d..cd433f9 100644 --- a/include/ceres/problem.h +++ b/include/ceres/problem.h
@@ -1,5 +1,5 @@ // Ceres Solver - A fast non-linear least squares minimizer -// Copyright 2010, 2011, 2012 Google Inc. All rights reserved. +// Copyright 2013 Google Inc. All rights reserved. // http://code.google.com/p/ceres-solver/ // // Redistribution and use in source and binary forms, with or without @@ -341,6 +341,26 @@ // parameter_block.size() == NumParameterBlocks. void GetParameterBlocks(vector<double*>* parameter_blocks) const; + // Fills the passed residual_blocks vector with pointers to the + // residual blocks currently in the problem. After this call, + // residual_blocks.size() == NumResidualBlocks. + void GetResidualBlocks(vector<ResidualBlockId>* residual_blocks) const; + + // Get all the parameter blocks that depend on the given residual block. + void GetParameterBlocksForResidualBlock( + const ResidualBlockId residual_block, + vector<double*>* parameter_blocks) const; + + // Get all the residual blocks that depend on the given parameter block. + // + // If Problem::Options::enable_fast_parameter_block_removal is true, then + // getting the residual blocks is fast and depends only on the number of + // residual blocks. Otherwise, getting the residual blocks for a parameter + // block will incur a scan of the entire Problem object. + void GetResidualBlocksForParameterBlock( + const double* values, + vector<ResidualBlockId>* residual_blocks) const; + // Options struct to control Problem::Evaluate. struct EvaluateOptions { EvaluateOptions()
diff --git a/internal/ceres/problem.cc b/internal/ceres/problem.cc index 403e96a..89821b9 100644 --- a/internal/ceres/problem.cc +++ b/internal/ceres/problem.cc
@@ -1,5 +1,5 @@ // Ceres Solver - A fast non-linear least squares minimizer -// Copyright 2010, 2011, 2012 Google Inc. All rights reserved. +// Copyright 2013 Google Inc. All rights reserved. // http://code.google.com/p/ceres-solver/ // // Redistribution and use in source and binary forms, with or without @@ -218,4 +218,23 @@ problem_impl_->GetParameterBlocks(parameter_blocks); } +void Problem::GetResidualBlocks( + vector<ResidualBlockId>* residual_blocks) const { + problem_impl_->GetResidualBlocks(residual_blocks); +} + +void Problem::GetParameterBlocksForResidualBlock( + const ResidualBlockId residual_block, + vector<double*>* parameter_blocks) const { + problem_impl_->GetParameterBlocksForResidualBlock(residual_block, + parameter_blocks); +} + +void Problem::GetResidualBlocksForParameterBlock( + const double* values, + vector<ResidualBlockId>* residual_blocks) const { + problem_impl_->GetResidualBlocksForParameterBlock(values, + residual_blocks); +} + } // namespace ceres
diff --git a/internal/ceres/problem_impl.cc b/internal/ceres/problem_impl.cc index 0846de8..4197d59 100644 --- a/internal/ceres/problem_impl.cc +++ b/internal/ceres/problem_impl.cc
@@ -1,5 +1,5 @@ // Ceres Solver - A fast non-linear least squares minimizer -// Copyright 2010, 2011, 2012 Google Inc. All rights reserved. +// Copyright 2013 Google Inc. All rights reserved. // http://code.google.com/p/ceres-solver/ // // Redistribution and use in source and binary forms, with or without @@ -27,7 +27,7 @@ // POSSIBILITY OF SUCH DAMAGE. // // Author: sameeragarwal@google.com (Sameer Agarwal) -// keir@google.com (Keir Mierle) +// mierle@gmail.com (Keir Mierle) #include "ceres/problem_impl.h" @@ -735,6 +735,56 @@ } } +void ProblemImpl::GetResidualBlocks( + vector<ResidualBlockId>* residual_blocks) const { + CHECK_NOTNULL(residual_blocks); + *residual_blocks = program().residual_blocks(); +} + +void ProblemImpl::GetParameterBlocksForResidualBlock( + const ResidualBlockId residual_block, + vector<double*>* parameter_blocks) const { + int num_parameter_blocks = residual_block->NumParameterBlocks(); + CHECK_NOTNULL(parameter_blocks)->resize(num_parameter_blocks); + for (int i = 0; i < num_parameter_blocks; ++i) { + (*parameter_blocks)[i] = + residual_block->parameter_blocks()[i]->mutable_user_state(); + } +} + +void ProblemImpl::GetResidualBlocksForParameterBlock( + const double* values, + vector<ResidualBlockId>* residual_blocks) const { + ParameterBlock* parameter_block = + FindParameterBlockOrDie(parameter_block_map_, + const_cast<double*>(values)); + + if (options_.enable_fast_parameter_block_removal) { + // In this case the residual blocks that depend on the parameter block are + // stored in the parameter block already, so just copy them out. + CHECK_NOTNULL(residual_blocks)->resize(parameter_block->mutable_residual_blocks()->size()); + std::copy(parameter_block->mutable_residual_blocks()->begin(), + parameter_block->mutable_residual_blocks()->end(), + residual_blocks->begin()); + return; + } + + // Find residual blocks that depend on the parameter block. + CHECK_NOTNULL(residual_blocks)->clear(); + const int num_residual_blocks = NumResidualBlocks(); + for (int i = 0; i < num_residual_blocks; ++i) { + ResidualBlock* residual_block = + (*(program_->mutable_residual_blocks()))[i]; + const int num_parameter_blocks = residual_block->NumParameterBlocks(); + for (int j = 0; j < num_parameter_blocks; ++j) { + if (residual_block->parameter_blocks()[j] == parameter_block) { + residual_blocks->push_back(residual_block); + // The parameter blocks are guaranteed unique. + break; + } + } + } +} } // namespace internal } // namespace ceres
diff --git a/internal/ceres/problem_impl.h b/internal/ceres/problem_impl.h index ace27f5..35c16cd 100644 --- a/internal/ceres/problem_impl.h +++ b/internal/ceres/problem_impl.h
@@ -142,6 +142,15 @@ int ParameterBlockSize(const double* parameter_block) const; int ParameterBlockLocalSize(const double* parameter_block) const; void GetParameterBlocks(vector<double*>* parameter_blocks) const; + void GetResidualBlocks(vector<ResidualBlockId>* residual_blocks) const; + + void GetParameterBlocksForResidualBlock( + const ResidualBlockId residual_block, + vector<double*>* parameter_blocks) const; + + void GetResidualBlocksForParameterBlock( + const double* values, + vector<ResidualBlockId>* residual_blocks) const; const Program& program() const { return *program_; } Program* mutable_program() { return program_.get(); }
diff --git a/internal/ceres/problem_test.cc b/internal/ceres/problem_test.cc index 0944d3f..a7f4f0b 100644 --- a/internal/ceres/problem_test.cc +++ b/internal/ceres/problem_test.cc
@@ -1,5 +1,5 @@ // Ceres Solver - A fast non-linear least squares minimizer -// Copyright 2010, 2011, 2012 Google Inc. All rights reserved. +// Copyright 2013 Google Inc. All rights reserved. // http://code.google.com/p/ceres-solver/ // // Redistribution and use in source and binary forms, with or without @@ -771,6 +771,141 @@ } } +// Check that a null-terminated array, a, has the same elements as b. +template<typename T> +void ExpectVectorContainsUnordered(const T* a, const vector<T>& b) { + // Compute the size of a. + int size = 0; + while (a[size]) { + ++size; + } + ASSERT_EQ(size, b.size()); + + // Sort a. + vector<T> a_sorted(size); + copy(a, a + size, a_sorted.begin()); + sort(a_sorted.begin(), a_sorted.end()); + + // Sort b. + vector<T> b_sorted(b); + sort(b_sorted.begin(), b_sorted.end()); + + // Compare. + for (int i = 0; i < size; ++i) { + EXPECT_EQ(a_sorted[i], b_sorted[i]); + } +} + +void ExpectProblemHasResidualBlocks( + const ProblemImpl &problem, + const ResidualBlockId *expected_residual_blocks) { + vector<ResidualBlockId> residual_blocks; + problem.GetResidualBlocks(&residual_blocks); + ExpectVectorContainsUnordered(expected_residual_blocks, residual_blocks); +} + +TEST_P(DynamicProblem, GetXXXBlocksForYYYBlock) { + problem->AddParameterBlock(y, 4); + problem->AddParameterBlock(z, 5); + problem->AddParameterBlock(w, 3); + + // Add all combinations of cost functions. + CostFunction* cost_yzw = new TernaryCostFunction(1, 4, 5, 3); + CostFunction* cost_yz = new BinaryCostFunction (1, 4, 5); + CostFunction* cost_yw = new BinaryCostFunction (1, 4, 3); + CostFunction* cost_zw = new BinaryCostFunction (1, 5, 3); + CostFunction* cost_y = new UnaryCostFunction (1, 4); + CostFunction* cost_z = new UnaryCostFunction (1, 5); + CostFunction* cost_w = new UnaryCostFunction (1, 3); + + ResidualBlock* r_yzw = problem->AddResidualBlock(cost_yzw, NULL, y, z, w); + { + ResidualBlockId expected_residuals[] = {r_yzw, 0}; + ExpectProblemHasResidualBlocks(*problem, expected_residuals); + } + ResidualBlock* r_yz = problem->AddResidualBlock(cost_yz, NULL, y, z); + { + ResidualBlockId expected_residuals[] = {r_yzw, r_yz, 0}; + ExpectProblemHasResidualBlocks(*problem, expected_residuals); + } + ResidualBlock* r_yw = problem->AddResidualBlock(cost_yw, NULL, y, w); + { + ResidualBlock *expected_residuals[] = {r_yzw, r_yz, r_yw, 0}; + ExpectProblemHasResidualBlocks(*problem, expected_residuals); + } + ResidualBlock* r_zw = problem->AddResidualBlock(cost_zw, NULL, z, w); + { + ResidualBlock *expected_residuals[] = {r_yzw, r_yz, r_yw, r_zw, 0}; + ExpectProblemHasResidualBlocks(*problem, expected_residuals); + } + ResidualBlock* r_y = problem->AddResidualBlock(cost_y, NULL, y); + { + ResidualBlock *expected_residuals[] = {r_yzw, r_yz, r_yw, r_zw, r_y, 0}; + ExpectProblemHasResidualBlocks(*problem, expected_residuals); + } + ResidualBlock* r_z = problem->AddResidualBlock(cost_z, NULL, z); + { + ResidualBlock *expected_residuals[] = { + r_yzw, r_yz, r_yw, r_zw, r_y, r_z, 0 + }; + ExpectProblemHasResidualBlocks(*problem, expected_residuals); + } + ResidualBlock* r_w = problem->AddResidualBlock(cost_w, NULL, w); + { + ResidualBlock *expected_residuals[] = { + r_yzw, r_yz, r_yw, r_zw, r_y, r_z, r_w, 0 + }; + ExpectProblemHasResidualBlocks(*problem, expected_residuals); + } + + vector<double*> parameter_blocks; + vector<ResidualBlockId> residual_blocks; + + // Check GetResidualBlocksForParameterBlock() for all parameter blocks. + struct GetResidualBlocksForParameterBlockTestCase { + double* parameter_block; + ResidualBlockId expected_residual_blocks[10]; + }; + GetResidualBlocksForParameterBlockTestCase get_residual_blocks_cases[] = { + { y, { r_yzw, r_yz, r_yw, r_y, NULL} }, + { z, { r_yzw, r_yz, r_zw, r_z, NULL} }, + { w, { r_yzw, r_yw, r_zw, r_w, NULL} }, + { NULL } + }; + for (int i = 0; get_residual_blocks_cases[i].parameter_block; ++i) { + problem->GetResidualBlocksForParameterBlock( + get_residual_blocks_cases[i].parameter_block, + &residual_blocks); + ExpectVectorContainsUnordered( + get_residual_blocks_cases[i].expected_residual_blocks, + residual_blocks); + } + + // Check GetParameterBlocksForResidualBlock() for all residual blocks. + struct GetParameterBlocksForResidualBlockTestCase { + ResidualBlockId residual_block; + double* expected_parameter_blocks[10]; + }; + GetParameterBlocksForResidualBlockTestCase get_parameter_blocks_cases[] = { + { r_yzw, { y, z, w, NULL } }, + { r_yz , { y, z, NULL } }, + { r_yw , { y, w, NULL } }, + { r_zw , { z, w, NULL } }, + { r_y , { y, NULL } }, + { r_z , { z, NULL } }, + { r_w , { w, NULL } }, + { NULL } + }; + for (int i = 0; get_parameter_blocks_cases[i].residual_block; ++i) { + problem->GetParameterBlocksForResidualBlock( + get_parameter_blocks_cases[i].residual_block, + ¶meter_blocks); + ExpectVectorContainsUnordered( + get_parameter_blocks_cases[i].expected_parameter_blocks, + parameter_blocks); + } +} + INSTANTIATE_TEST_CASE_P(OptionsInstantiation, DynamicProblem, ::testing::Values(true, false));