Export the structure of a problem to the public API
This adds three new public methods to ceres::Problem:
Problem::GetResidualBlocks()
Problem::GetParameterBlocksForResidualBlock()
Problem::GetResidualBlocksForParameterBlock()
These permit access to the underlying graph structure of the problem.
Change-Id: I55a4c7f0e5f325f140cb4830e7a7070554594650
diff --git a/include/ceres/problem.h b/include/ceres/problem.h
index 663616d..cd433f9 100644
--- a/include/ceres/problem.h
+++ b/include/ceres/problem.h
@@ -1,5 +1,5 @@
// Ceres Solver - A fast non-linear least squares minimizer
-// Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
+// Copyright 2013 Google Inc. All rights reserved.
// http://code.google.com/p/ceres-solver/
//
// Redistribution and use in source and binary forms, with or without
@@ -341,6 +341,26 @@
// parameter_block.size() == NumParameterBlocks.
void GetParameterBlocks(vector<double*>* parameter_blocks) const;
+ // Fills the passed residual_blocks vector with pointers to the
+ // residual blocks currently in the problem. After this call,
+ // residual_blocks.size() == NumResidualBlocks.
+ void GetResidualBlocks(vector<ResidualBlockId>* residual_blocks) const;
+
+ // Get all the parameter blocks that depend on the given residual block.
+ void GetParameterBlocksForResidualBlock(
+ const ResidualBlockId residual_block,
+ vector<double*>* parameter_blocks) const;
+
+ // Get all the residual blocks that depend on the given parameter block.
+ //
+ // If Problem::Options::enable_fast_parameter_block_removal is true, then
+ // getting the residual blocks is fast and depends only on the number of
+ // residual blocks. Otherwise, getting the residual blocks for a parameter
+ // block will incur a scan of the entire Problem object.
+ void GetResidualBlocksForParameterBlock(
+ const double* values,
+ vector<ResidualBlockId>* residual_blocks) const;
+
// Options struct to control Problem::Evaluate.
struct EvaluateOptions {
EvaluateOptions()
diff --git a/internal/ceres/problem.cc b/internal/ceres/problem.cc
index 403e96a..89821b9 100644
--- a/internal/ceres/problem.cc
+++ b/internal/ceres/problem.cc
@@ -1,5 +1,5 @@
// Ceres Solver - A fast non-linear least squares minimizer
-// Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
+// Copyright 2013 Google Inc. All rights reserved.
// http://code.google.com/p/ceres-solver/
//
// Redistribution and use in source and binary forms, with or without
@@ -218,4 +218,23 @@
problem_impl_->GetParameterBlocks(parameter_blocks);
}
+void Problem::GetResidualBlocks(
+ vector<ResidualBlockId>* residual_blocks) const {
+ problem_impl_->GetResidualBlocks(residual_blocks);
+}
+
+void Problem::GetParameterBlocksForResidualBlock(
+ const ResidualBlockId residual_block,
+ vector<double*>* parameter_blocks) const {
+ problem_impl_->GetParameterBlocksForResidualBlock(residual_block,
+ parameter_blocks);
+}
+
+void Problem::GetResidualBlocksForParameterBlock(
+ const double* values,
+ vector<ResidualBlockId>* residual_blocks) const {
+ problem_impl_->GetResidualBlocksForParameterBlock(values,
+ residual_blocks);
+}
+
} // namespace ceres
diff --git a/internal/ceres/problem_impl.cc b/internal/ceres/problem_impl.cc
index 0846de8..4197d59 100644
--- a/internal/ceres/problem_impl.cc
+++ b/internal/ceres/problem_impl.cc
@@ -1,5 +1,5 @@
// Ceres Solver - A fast non-linear least squares minimizer
-// Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
+// Copyright 2013 Google Inc. All rights reserved.
// http://code.google.com/p/ceres-solver/
//
// Redistribution and use in source and binary forms, with or without
@@ -27,7 +27,7 @@
// POSSIBILITY OF SUCH DAMAGE.
//
// Author: sameeragarwal@google.com (Sameer Agarwal)
-// keir@google.com (Keir Mierle)
+// mierle@gmail.com (Keir Mierle)
#include "ceres/problem_impl.h"
@@ -735,6 +735,56 @@
}
}
+void ProblemImpl::GetResidualBlocks(
+ vector<ResidualBlockId>* residual_blocks) const {
+ CHECK_NOTNULL(residual_blocks);
+ *residual_blocks = program().residual_blocks();
+}
+
+void ProblemImpl::GetParameterBlocksForResidualBlock(
+ const ResidualBlockId residual_block,
+ vector<double*>* parameter_blocks) const {
+ int num_parameter_blocks = residual_block->NumParameterBlocks();
+ CHECK_NOTNULL(parameter_blocks)->resize(num_parameter_blocks);
+ for (int i = 0; i < num_parameter_blocks; ++i) {
+ (*parameter_blocks)[i] =
+ residual_block->parameter_blocks()[i]->mutable_user_state();
+ }
+}
+
+void ProblemImpl::GetResidualBlocksForParameterBlock(
+ const double* values,
+ vector<ResidualBlockId>* residual_blocks) const {
+ ParameterBlock* parameter_block =
+ FindParameterBlockOrDie(parameter_block_map_,
+ const_cast<double*>(values));
+
+ if (options_.enable_fast_parameter_block_removal) {
+ // In this case the residual blocks that depend on the parameter block are
+ // stored in the parameter block already, so just copy them out.
+ CHECK_NOTNULL(residual_blocks)->resize(parameter_block->mutable_residual_blocks()->size());
+ std::copy(parameter_block->mutable_residual_blocks()->begin(),
+ parameter_block->mutable_residual_blocks()->end(),
+ residual_blocks->begin());
+ return;
+ }
+
+ // Find residual blocks that depend on the parameter block.
+ CHECK_NOTNULL(residual_blocks)->clear();
+ const int num_residual_blocks = NumResidualBlocks();
+ for (int i = 0; i < num_residual_blocks; ++i) {
+ ResidualBlock* residual_block =
+ (*(program_->mutable_residual_blocks()))[i];
+ const int num_parameter_blocks = residual_block->NumParameterBlocks();
+ for (int j = 0; j < num_parameter_blocks; ++j) {
+ if (residual_block->parameter_blocks()[j] == parameter_block) {
+ residual_blocks->push_back(residual_block);
+ // The parameter blocks are guaranteed unique.
+ break;
+ }
+ }
+ }
+}
} // namespace internal
} // namespace ceres
diff --git a/internal/ceres/problem_impl.h b/internal/ceres/problem_impl.h
index ace27f5..35c16cd 100644
--- a/internal/ceres/problem_impl.h
+++ b/internal/ceres/problem_impl.h
@@ -142,6 +142,15 @@
int ParameterBlockSize(const double* parameter_block) const;
int ParameterBlockLocalSize(const double* parameter_block) const;
void GetParameterBlocks(vector<double*>* parameter_blocks) const;
+ void GetResidualBlocks(vector<ResidualBlockId>* residual_blocks) const;
+
+ void GetParameterBlocksForResidualBlock(
+ const ResidualBlockId residual_block,
+ vector<double*>* parameter_blocks) const;
+
+ void GetResidualBlocksForParameterBlock(
+ const double* values,
+ vector<ResidualBlockId>* residual_blocks) const;
const Program& program() const { return *program_; }
Program* mutable_program() { return program_.get(); }
diff --git a/internal/ceres/problem_test.cc b/internal/ceres/problem_test.cc
index 0944d3f..a7f4f0b 100644
--- a/internal/ceres/problem_test.cc
+++ b/internal/ceres/problem_test.cc
@@ -1,5 +1,5 @@
// Ceres Solver - A fast non-linear least squares minimizer
-// Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
+// Copyright 2013 Google Inc. All rights reserved.
// http://code.google.com/p/ceres-solver/
//
// Redistribution and use in source and binary forms, with or without
@@ -771,6 +771,141 @@
}
}
+// Check that a null-terminated array, a, has the same elements as b.
+template<typename T>
+void ExpectVectorContainsUnordered(const T* a, const vector<T>& b) {
+ // Compute the size of a.
+ int size = 0;
+ while (a[size]) {
+ ++size;
+ }
+ ASSERT_EQ(size, b.size());
+
+ // Sort a.
+ vector<T> a_sorted(size);
+ copy(a, a + size, a_sorted.begin());
+ sort(a_sorted.begin(), a_sorted.end());
+
+ // Sort b.
+ vector<T> b_sorted(b);
+ sort(b_sorted.begin(), b_sorted.end());
+
+ // Compare.
+ for (int i = 0; i < size; ++i) {
+ EXPECT_EQ(a_sorted[i], b_sorted[i]);
+ }
+}
+
+void ExpectProblemHasResidualBlocks(
+ const ProblemImpl &problem,
+ const ResidualBlockId *expected_residual_blocks) {
+ vector<ResidualBlockId> residual_blocks;
+ problem.GetResidualBlocks(&residual_blocks);
+ ExpectVectorContainsUnordered(expected_residual_blocks, residual_blocks);
+}
+
+TEST_P(DynamicProblem, GetXXXBlocksForYYYBlock) {
+ problem->AddParameterBlock(y, 4);
+ problem->AddParameterBlock(z, 5);
+ problem->AddParameterBlock(w, 3);
+
+ // Add all combinations of cost functions.
+ CostFunction* cost_yzw = new TernaryCostFunction(1, 4, 5, 3);
+ CostFunction* cost_yz = new BinaryCostFunction (1, 4, 5);
+ CostFunction* cost_yw = new BinaryCostFunction (1, 4, 3);
+ CostFunction* cost_zw = new BinaryCostFunction (1, 5, 3);
+ CostFunction* cost_y = new UnaryCostFunction (1, 4);
+ CostFunction* cost_z = new UnaryCostFunction (1, 5);
+ CostFunction* cost_w = new UnaryCostFunction (1, 3);
+
+ ResidualBlock* r_yzw = problem->AddResidualBlock(cost_yzw, NULL, y, z, w);
+ {
+ ResidualBlockId expected_residuals[] = {r_yzw, 0};
+ ExpectProblemHasResidualBlocks(*problem, expected_residuals);
+ }
+ ResidualBlock* r_yz = problem->AddResidualBlock(cost_yz, NULL, y, z);
+ {
+ ResidualBlockId expected_residuals[] = {r_yzw, r_yz, 0};
+ ExpectProblemHasResidualBlocks(*problem, expected_residuals);
+ }
+ ResidualBlock* r_yw = problem->AddResidualBlock(cost_yw, NULL, y, w);
+ {
+ ResidualBlock *expected_residuals[] = {r_yzw, r_yz, r_yw, 0};
+ ExpectProblemHasResidualBlocks(*problem, expected_residuals);
+ }
+ ResidualBlock* r_zw = problem->AddResidualBlock(cost_zw, NULL, z, w);
+ {
+ ResidualBlock *expected_residuals[] = {r_yzw, r_yz, r_yw, r_zw, 0};
+ ExpectProblemHasResidualBlocks(*problem, expected_residuals);
+ }
+ ResidualBlock* r_y = problem->AddResidualBlock(cost_y, NULL, y);
+ {
+ ResidualBlock *expected_residuals[] = {r_yzw, r_yz, r_yw, r_zw, r_y, 0};
+ ExpectProblemHasResidualBlocks(*problem, expected_residuals);
+ }
+ ResidualBlock* r_z = problem->AddResidualBlock(cost_z, NULL, z);
+ {
+ ResidualBlock *expected_residuals[] = {
+ r_yzw, r_yz, r_yw, r_zw, r_y, r_z, 0
+ };
+ ExpectProblemHasResidualBlocks(*problem, expected_residuals);
+ }
+ ResidualBlock* r_w = problem->AddResidualBlock(cost_w, NULL, w);
+ {
+ ResidualBlock *expected_residuals[] = {
+ r_yzw, r_yz, r_yw, r_zw, r_y, r_z, r_w, 0
+ };
+ ExpectProblemHasResidualBlocks(*problem, expected_residuals);
+ }
+
+ vector<double*> parameter_blocks;
+ vector<ResidualBlockId> residual_blocks;
+
+ // Check GetResidualBlocksForParameterBlock() for all parameter blocks.
+ struct GetResidualBlocksForParameterBlockTestCase {
+ double* parameter_block;
+ ResidualBlockId expected_residual_blocks[10];
+ };
+ GetResidualBlocksForParameterBlockTestCase get_residual_blocks_cases[] = {
+ { y, { r_yzw, r_yz, r_yw, r_y, NULL} },
+ { z, { r_yzw, r_yz, r_zw, r_z, NULL} },
+ { w, { r_yzw, r_yw, r_zw, r_w, NULL} },
+ { NULL }
+ };
+ for (int i = 0; get_residual_blocks_cases[i].parameter_block; ++i) {
+ problem->GetResidualBlocksForParameterBlock(
+ get_residual_blocks_cases[i].parameter_block,
+ &residual_blocks);
+ ExpectVectorContainsUnordered(
+ get_residual_blocks_cases[i].expected_residual_blocks,
+ residual_blocks);
+ }
+
+ // Check GetParameterBlocksForResidualBlock() for all residual blocks.
+ struct GetParameterBlocksForResidualBlockTestCase {
+ ResidualBlockId residual_block;
+ double* expected_parameter_blocks[10];
+ };
+ GetParameterBlocksForResidualBlockTestCase get_parameter_blocks_cases[] = {
+ { r_yzw, { y, z, w, NULL } },
+ { r_yz , { y, z, NULL } },
+ { r_yw , { y, w, NULL } },
+ { r_zw , { z, w, NULL } },
+ { r_y , { y, NULL } },
+ { r_z , { z, NULL } },
+ { r_w , { w, NULL } },
+ { NULL }
+ };
+ for (int i = 0; get_parameter_blocks_cases[i].residual_block; ++i) {
+ problem->GetParameterBlocksForResidualBlock(
+ get_parameter_blocks_cases[i].residual_block,
+ ¶meter_blocks);
+ ExpectVectorContainsUnordered(
+ get_parameter_blocks_cases[i].expected_parameter_blocks,
+ parameter_blocks);
+ }
+}
+
INSTANTIATE_TEST_CASE_P(OptionsInstantiation,
DynamicProblem,
::testing::Values(true, false));