Check validity of residual block before removal in RemoveResidualBlock.
- Breaking change: Problem::Options::enable_fast_parameter_block_removal
has been renamed to Problem::Options::enable_fast_removal, because it now
controls the removal behaviour of both parameter and residual blocks.
- Previously we did not check that the specified residual block to
remove in RemoveResidualBlock actually represented a valid residual
for the problem.
- This meant that Ceres would die unexpectedly if the user passed an
uninitialised residual_block, or more likely attempted to remove a
residual block that had already been removed automatically after
the user removed a parameter block on which it depended.
- RemoveResidualBlock now verifies that the given residual_block is valid
before removing it: if enable_fast_removal is enabled, by looking it up in
a hash set of all residuals maintained by ProblemImpl; otherwise, by a full
scan of the residual blocks.
Change-Id: I9ab178e2f68a74135f0a8e20905b16405c77a62b
diff --git a/internal/ceres/problem_impl.cc b/internal/ceres/problem_impl.cc
index 99f3f89..d9fb2af 100644
--- a/internal/ceres/problem_impl.cc
+++ b/internal/ceres/problem_impl.cc
@@ -142,7 +142,7 @@
// For dynamic problems, add the list of dependent residual blocks, which is
// empty to start.
- if (options_.enable_fast_parameter_block_removal) {
+ if (options_.enable_fast_removal) {
new_parameter_block->EnableResidualBlockDependencies();
}
parameter_block_map_[values] = new_parameter_block;
@@ -150,6 +150,26 @@
return new_parameter_block;
}
+// Removes residual_block from the problem WITHOUT validating it. Callers must
+// guarantee that it is currently part of the problem: the public
+// RemoveResidualBlock() checks this before calling, and RemoveParameterBlock()
+// only passes residuals that depend on the parameter block being removed.
+void ProblemImpl::InternalRemoveResidualBlock(ResidualBlock* residual_block) {
+  CHECK_NOTNULL(residual_block);
+
+  // If needed, remove the parameter dependencies on this residual block and
+  // drop it from the set of live residuals.
+  if (options_.enable_fast_removal) {
+    const int num_parameter_blocks_for_residual =
+        residual_block->NumParameterBlocks();
+    for (int i = 0; i < num_parameter_blocks_for_residual; ++i) {
+      residual_block->parameter_blocks()[i]
+          ->RemoveResidualBlock(residual_block);
+    }
+
+    // Erase by key rather than through an unchecked find() iterator: erasing
+    // end() is undefined behaviour, whereas erasing an absent key is a no-op.
+    residual_block_set_.erase(residual_block);
+  }
+  DeleteBlockInVector(program_->mutable_residual_blocks(), residual_block);
+}
+
// Deletes the residual block in question, assuming there are no other
// references to it inside the problem (e.g. by another parameter). Referenced
// cost and loss functions are tucked away for future deletion, since it is not
@@ -278,13 +298,18 @@
program_->residual_blocks_.size());
// Add dependencies on the residual to the parameter blocks.
- if (options_.enable_fast_parameter_block_removal) {
+ if (options_.enable_fast_removal) {
for (int i = 0; i < parameter_blocks.size(); ++i) {
parameter_block_ptrs[i]->AddResidualBlock(new_residual_block);
}
}
program_->residual_blocks_.push_back(new_residual_block);
+
+ if (options_.enable_fast_removal) {
+ residual_block_set_.insert(new_residual_block);
+ }
+
return new_residual_block;
}
@@ -475,30 +500,46 @@
+// Removes residual_block from the problem. CHECK-fails if residual_block is
+// not currently part of the problem, e.g. because it was already removed by
+// the user or was removed along with a parameter block it depended on.
void ProblemImpl::RemoveResidualBlock(ResidualBlock* residual_block) {
CHECK_NOTNULL(residual_block);
- // If needed, remove the parameter dependencies on this residual block.
- if (options_.enable_fast_parameter_block_removal) {
- const int num_parameter_blocks_for_residual =
- residual_block->NumParameterBlocks();
- for (int i = 0; i < num_parameter_blocks_for_residual; ++i) {
- residual_block->parameter_blocks()[i]
- ->RemoveResidualBlock(residual_block);
- }
+  // Verify that residual_block identifies a residual in the current problem,
+  // using the hash set when it is maintained and a linear scan otherwise.
+  bool residual_block_found = false;
+  if (options_.enable_fast_removal) {
+    residual_block_found =
+        residual_block_set_.find(residual_block) != residual_block_set_.end();
+  } else {
+    // Perform a full search over all current residuals.
+    residual_block_found =
+        std::find(program_->residual_blocks().begin(),
+                  program_->residual_blocks().end(),
+                  residual_block) != program_->residual_blocks().end();
}
- DeleteBlockInVector(program_->mutable_residual_blocks(), residual_block);
+
+  // The diagnostic is streamed into CHECK so that it is only formatted when
+  // the check fails, instead of on every successful removal.
+  CHECK(residual_block_found)
+      << StringPrintf("Residual block to remove: %p not found. This usually means "
+                      "one of three things have happened:\n"
+                      " 1) residual_block is uninitialised and points to a random "
+                      "area in memory.\n"
+                      " 2) residual_block represented a residual that was added to"
+                      " the problem, but referred to a parameter block which has "
+                      "since been removed, which removes all residuals which "
+                      "depend on that parameter block, and was thus removed.\n"
+                      " 3) residual_block referred to a residual that has already "
+                      "been removed from the problem (by the user).",
+                      residual_block);
+
+  InternalRemoveResidualBlock(residual_block);
}
void ProblemImpl::RemoveParameterBlock(double* values) {
ParameterBlock* parameter_block =
FindParameterBlockOrDie(parameter_block_map_, values);
- if (options_.enable_fast_parameter_block_removal) {
+ if (options_.enable_fast_removal) {
// Copy the dependent residuals from the parameter block because the set of
// dependents will change after each call to RemoveResidualBlock().
vector<ResidualBlock*> residual_blocks_to_remove(
parameter_block->mutable_residual_blocks()->begin(),
parameter_block->mutable_residual_blocks()->end());
for (int i = 0; i < residual_blocks_to_remove.size(); ++i) {
- RemoveResidualBlock(residual_blocks_to_remove[i]);
+ InternalRemoveResidualBlock(residual_blocks_to_remove[i]);
}
} else {
// Scan all the residual blocks to remove ones that depend on the parameter
@@ -510,7 +551,7 @@
const int num_parameter_blocks = residual_block->NumParameterBlocks();
for (int j = 0; j < num_parameter_blocks; ++j) {
if (residual_block->parameter_blocks()[j] == parameter_block) {
- RemoveResidualBlock(residual_block);
+ InternalRemoveResidualBlock(residual_block);
// The parameter blocks are guaranteed unique.
break;
}
@@ -784,7 +825,7 @@
FindParameterBlockOrDie(parameter_block_map_,
const_cast<double*>(values));
- if (options_.enable_fast_parameter_block_removal) {
+ if (options_.enable_fast_removal) {
// In this case the residual blocks that depend on the parameter block are
// stored in the parameter block already, so just copy them out.
CHECK_NOTNULL(residual_blocks)->resize(
diff --git a/internal/ceres/problem_impl.h b/internal/ceres/problem_impl.h
index 75bdc2b..e846c03 100644
--- a/internal/ceres/problem_impl.h
+++ b/internal/ceres/problem_impl.h
@@ -45,6 +45,7 @@
#include "ceres/internal/macros.h"
#include "ceres/internal/port.h"
#include "ceres/internal/scoped_ptr.h"
+#include "ceres/collections_port.h"
#include "ceres/problem.h"
#include "ceres/types.h"
@@ -63,6 +64,7 @@
class ProblemImpl {
public:
typedef map<double*, ParameterBlock*> ParameterMap;
+ typedef HashSet<ResidualBlock*> ResidualBlockSet;
ProblemImpl();
explicit ProblemImpl(const Problem::Options& options);
@@ -160,9 +162,15 @@
Program* mutable_program() { return program_.get(); }
const ParameterMap& parameter_map() const { return parameter_block_map_; }
+ // Returns the set of residual blocks currently in the problem. The set is
+ // only maintained when Problem::Options::enable_fast_removal is true, so
+ // calling this accessor otherwise is a fatal CHECK failure.
+ const ResidualBlockSet& residual_block_set() const {
+ CHECK(options_.enable_fast_removal)
+ << "Fast removal not enabled, residual_block_set is not maintained.";
+ return residual_block_set_;
+ }
private:
ParameterBlock* InternalAddParameterBlock(double* values, int size);
+ void InternalRemoveResidualBlock(ResidualBlock* residual_block);
bool InternalEvaluate(Program* program,
double* cost,
@@ -184,6 +192,9 @@
// The mapping from user pointers to parameter blocks.
map<double*, ParameterBlock*> parameter_block_map_;
+ // Iff enable_fast_removal is enabled, contains the current residual blocks.
+ ResidualBlockSet residual_block_set_;
+
// The actual parameter and residual blocks.
internal::scoped_ptr<internal::Program> program_;
diff --git a/internal/ceres/problem_test.cc b/internal/ceres/problem_test.cc
index 4507dfa..eb75e3a 100644
--- a/internal/ceres/problem_test.cc
+++ b/internal/ceres/problem_test.cc
@@ -378,7 +378,7 @@
struct DynamicProblem : public ::testing::TestWithParam<bool> {
DynamicProblem() {
Problem::Options options;
- options.enable_fast_parameter_block_removal = GetParam();
+ options.enable_fast_removal = GetParam();
problem.reset(new ProblemImpl(options));
}
@@ -390,9 +390,26 @@
}
bool HasResidualBlock(ResidualBlock* residual_block) {
- return find(problem->program().residual_blocks().begin(),
- problem->program().residual_blocks().end(),
- residual_block) != problem->program().residual_blocks().end();
+    // Membership according to the program's residual block vector, which is
+    // always maintained.
+    const bool in_program =
+        find(problem->program().residual_blocks().begin(),
+             problem->program().residual_blocks().end(),
+             residual_block) != problem->program().residual_blocks().end();
+    if (GetParam()) {
+      // With fast removal enabled the hash set must agree with the program.
+      // Checking agreement (rather than AND-ing the two lookups) ensures a
+      // stale entry in either container fails the test instead of being
+      // masked by ASSERT_FALSE(HasResidualBlock(...)).
+      const bool in_set =
+          problem->residual_block_set().find(residual_block) !=
+          problem->residual_block_set().end();
+      EXPECT_EQ(in_program, in_set);
+    }
+    return in_program;
+  }
+
+  // Returns the number of residual blocks in the problem and, when fast
+  // removal is enabled, verifies that the residual hash set has the same size.
+  int NumResidualBlocks() {
+    if (GetParam()) {
+      // Cast avoids a signed/unsigned comparison inside EXPECT_EQ.
+      EXPECT_EQ(problem->NumResidualBlocks(),
+                static_cast<int>(problem->residual_block_set().size()));
+    }
+    return problem->NumResidualBlocks();
}
// The next block of functions until the end are only for testing the
@@ -550,7 +567,7 @@
problem->AddParameterBlock(z, 5);
problem->AddParameterBlock(w, 3);
ASSERT_EQ(3, problem->NumParameterBlocks());
- ASSERT_EQ(0, problem->NumResidualBlocks());
+ ASSERT_EQ(0, NumResidualBlocks());
EXPECT_EQ(y, GetParameterBlock(0)->user_state());
EXPECT_EQ(z, GetParameterBlock(1)->user_state());
EXPECT_EQ(w, GetParameterBlock(2)->user_state());
@@ -559,12 +576,12 @@
// removing it.
problem->RemoveParameterBlock(w);
ASSERT_EQ(2, problem->NumParameterBlocks());
- ASSERT_EQ(0, problem->NumResidualBlocks());
+ ASSERT_EQ(0, NumResidualBlocks());
EXPECT_EQ(y, GetParameterBlock(0)->user_state());
EXPECT_EQ(z, GetParameterBlock(1)->user_state());
problem->AddParameterBlock(w, 3);
ASSERT_EQ(3, problem->NumParameterBlocks());
- ASSERT_EQ(0, problem->NumResidualBlocks());
+ ASSERT_EQ(0, NumResidualBlocks());
EXPECT_EQ(y, GetParameterBlock(0)->user_state());
EXPECT_EQ(z, GetParameterBlock(1)->user_state());
EXPECT_EQ(w, GetParameterBlock(2)->user_state());
@@ -572,12 +589,12 @@
// Now remove z, which is in the middle, and add it back.
problem->RemoveParameterBlock(z);
ASSERT_EQ(2, problem->NumParameterBlocks());
- ASSERT_EQ(0, problem->NumResidualBlocks());
+ ASSERT_EQ(0, NumResidualBlocks());
EXPECT_EQ(y, GetParameterBlock(0)->user_state());
EXPECT_EQ(w, GetParameterBlock(1)->user_state());
problem->AddParameterBlock(z, 5);
ASSERT_EQ(3, problem->NumParameterBlocks());
- ASSERT_EQ(0, problem->NumResidualBlocks());
+ ASSERT_EQ(0, NumResidualBlocks());
EXPECT_EQ(y, GetParameterBlock(0)->user_state());
EXPECT_EQ(w, GetParameterBlock(1)->user_state());
EXPECT_EQ(z, GetParameterBlock(2)->user_state());
@@ -586,20 +603,20 @@
// y
problem->RemoveParameterBlock(y);
ASSERT_EQ(2, problem->NumParameterBlocks());
- ASSERT_EQ(0, problem->NumResidualBlocks());
+ ASSERT_EQ(0, NumResidualBlocks());
EXPECT_EQ(z, GetParameterBlock(0)->user_state());
EXPECT_EQ(w, GetParameterBlock(1)->user_state());
// z
problem->RemoveParameterBlock(z);
ASSERT_EQ(1, problem->NumParameterBlocks());
- ASSERT_EQ(0, problem->NumResidualBlocks());
+ ASSERT_EQ(0, NumResidualBlocks());
EXPECT_EQ(w, GetParameterBlock(0)->user_state());
// w
problem->RemoveParameterBlock(w);
EXPECT_EQ(0, problem->NumParameterBlocks());
- EXPECT_EQ(0, problem->NumResidualBlocks());
+ EXPECT_EQ(0, NumResidualBlocks());
}
TEST_P(DynamicProblem, RemoveParameterBlockWithResiduals) {
@@ -607,7 +624,7 @@
problem->AddParameterBlock(z, 5);
problem->AddParameterBlock(w, 3);
ASSERT_EQ(3, problem->NumParameterBlocks());
- ASSERT_EQ(0, problem->NumResidualBlocks());
+ ASSERT_EQ(0, NumResidualBlocks());
EXPECT_EQ(y, GetParameterBlock(0)->user_state());
EXPECT_EQ(z, GetParameterBlock(1)->user_state());
EXPECT_EQ(w, GetParameterBlock(2)->user_state());
@@ -630,12 +647,12 @@
ResidualBlock* r_w = problem->AddResidualBlock(cost_w, NULL, w);
EXPECT_EQ(3, problem->NumParameterBlocks());
- EXPECT_EQ(7, problem->NumResidualBlocks());
+ EXPECT_EQ(7, NumResidualBlocks());
// Remove w, which should remove r_yzw, r_yw, r_zw, r_w.
problem->RemoveParameterBlock(w);
ASSERT_EQ(2, problem->NumParameterBlocks());
- ASSERT_EQ(3, problem->NumResidualBlocks());
+ ASSERT_EQ(3, NumResidualBlocks());
ASSERT_FALSE(HasResidualBlock(r_yzw));
ASSERT_TRUE (HasResidualBlock(r_yz ));
@@ -648,7 +665,7 @@
// Remove z, which will remove almost everything else.
problem->RemoveParameterBlock(z);
ASSERT_EQ(1, problem->NumParameterBlocks());
- ASSERT_EQ(1, problem->NumResidualBlocks());
+ ASSERT_EQ(1, NumResidualBlocks());
ASSERT_FALSE(HasResidualBlock(r_yzw));
ASSERT_FALSE(HasResidualBlock(r_yz ));
@@ -661,7 +678,7 @@
// Remove y; all gone.
problem->RemoveParameterBlock(y);
EXPECT_EQ(0, problem->NumParameterBlocks());
- EXPECT_EQ(0, problem->NumResidualBlocks());
+ EXPECT_EQ(0, NumResidualBlocks());
}
TEST_P(DynamicProblem, RemoveResidualBlock) {
@@ -699,14 +716,14 @@
EXPECT_TRUE(GetParameterBlock(2)->mutable_residual_blocks() == NULL);
}
EXPECT_EQ(3, problem->NumParameterBlocks());
- EXPECT_EQ(7, problem->NumResidualBlocks());
+ EXPECT_EQ(7, NumResidualBlocks());
// Remove each residual and check the state after each removal.
// Remove r_yzw.
problem->RemoveResidualBlock(r_yzw);
ASSERT_EQ(3, problem->NumParameterBlocks());
- ASSERT_EQ(6, problem->NumResidualBlocks());
+ ASSERT_EQ(6, NumResidualBlocks());
if (GetParam()) {
ExpectParameterBlockContains(y, r_yz, r_yw, r_y);
ExpectParameterBlockContains(z, r_yz, r_zw, r_z);
@@ -722,7 +739,7 @@
// Remove r_yw.
problem->RemoveResidualBlock(r_yw);
ASSERT_EQ(3, problem->NumParameterBlocks());
- ASSERT_EQ(5, problem->NumResidualBlocks());
+ ASSERT_EQ(5, NumResidualBlocks());
if (GetParam()) {
ExpectParameterBlockContains(y, r_yz, r_y);
ExpectParameterBlockContains(z, r_yz, r_zw, r_z);
@@ -737,7 +754,7 @@
// Remove r_zw.
problem->RemoveResidualBlock(r_zw);
ASSERT_EQ(3, problem->NumParameterBlocks());
- ASSERT_EQ(4, problem->NumResidualBlocks());
+ ASSERT_EQ(4, NumResidualBlocks());
if (GetParam()) {
ExpectParameterBlockContains(y, r_yz, r_y);
ExpectParameterBlockContains(z, r_yz, r_z);
@@ -751,7 +768,7 @@
// Remove r_w.
problem->RemoveResidualBlock(r_w);
ASSERT_EQ(3, problem->NumParameterBlocks());
- ASSERT_EQ(3, problem->NumResidualBlocks());
+ ASSERT_EQ(3, NumResidualBlocks());
if (GetParam()) {
ExpectParameterBlockContains(y, r_yz, r_y);
ExpectParameterBlockContains(z, r_yz, r_z);
@@ -764,7 +781,7 @@
// Remove r_yz.
problem->RemoveResidualBlock(r_yz);
ASSERT_EQ(3, problem->NumParameterBlocks());
- ASSERT_EQ(2, problem->NumResidualBlocks());
+ ASSERT_EQ(2, NumResidualBlocks());
if (GetParam()) {
ExpectParameterBlockContains(y, r_y);
ExpectParameterBlockContains(z, r_z);
@@ -777,7 +794,7 @@
problem->RemoveResidualBlock(r_z);
problem->RemoveResidualBlock(r_y);
ASSERT_EQ(3, problem->NumParameterBlocks());
- ASSERT_EQ(0, problem->NumResidualBlocks());
+ ASSERT_EQ(0, NumResidualBlocks());
if (GetParam()) {
ExpectParameterBlockContains(y);
ExpectParameterBlockContains(z);
@@ -785,6 +802,56 @@
}
}
+// Removing a residual block that is not (or is no longer) part of the problem
+// must CHECK-fail with a "not found" message, for both the hash-set and the
+// linear-scan validation paths. Valid removals must still work afterwards.
+TEST_P(DynamicProblem, RemoveInvalidResidualBlockDies) {
+  problem->AddParameterBlock(y, 4);
+  problem->AddParameterBlock(z, 5);
+  problem->AddParameterBlock(w, 3);
+
+  // Add all combinations of cost functions.
+  CostFunction* cost_yzw = new TernaryCostFunction(1, 4, 5, 3);
+  CostFunction* cost_yz  = new BinaryCostFunction(1, 4, 5);
+  CostFunction* cost_yw  = new BinaryCostFunction(1, 4, 3);
+  CostFunction* cost_zw  = new BinaryCostFunction(1, 5, 3);
+  CostFunction* cost_y   = new UnaryCostFunction(1, 4);
+  CostFunction* cost_z   = new UnaryCostFunction(1, 5);
+  CostFunction* cost_w   = new UnaryCostFunction(1, 3);
+
+  ResidualBlock* r_yzw = problem->AddResidualBlock(cost_yzw, NULL, y, z, w);
+  ResidualBlock* r_yz  = problem->AddResidualBlock(cost_yz,  NULL, y, z);
+  ResidualBlock* r_yw  = problem->AddResidualBlock(cost_yw,  NULL, y, w);
+  ResidualBlock* r_zw  = problem->AddResidualBlock(cost_zw,  NULL, z, w);
+  ResidualBlock* r_y   = problem->AddResidualBlock(cost_y,   NULL, y);
+  ResidualBlock* r_z   = problem->AddResidualBlock(cost_z,   NULL, z);
+  ResidualBlock* r_w   = problem->AddResidualBlock(cost_w,   NULL, w);
+
+  // Remove r_yzw.
+  problem->RemoveResidualBlock(r_yzw);
+  ASSERT_EQ(3, problem->NumParameterBlocks());
+  ASSERT_EQ(6, NumResidualBlocks());
+  // Attempt to remove r_yzw again.
+  EXPECT_DEATH_IF_SUPPORTED(problem->RemoveResidualBlock(r_yzw), "not found");
+
+  // Attempt to remove a cast pointer never added as a residual. The pointer
+  // is only compared/hashed by the validity check, never dereferenced.
+  int trash_memory = 1234;
+  ResidualBlock* invalid_residual =
+      reinterpret_cast<ResidualBlock*>(&trash_memory);
+  EXPECT_DEATH_IF_SUPPORTED(problem->RemoveResidualBlock(invalid_residual),
+                            "not found");
+
+  // Remove a parameter block, which in turn removes the dependent residuals
+  // then attempt to remove them directly.
+  problem->RemoveParameterBlock(z);
+  ASSERT_EQ(2, problem->NumParameterBlocks());
+  ASSERT_EQ(3, NumResidualBlocks());
+  EXPECT_DEATH_IF_SUPPORTED(problem->RemoveResidualBlock(r_yz), "not found");
+  EXPECT_DEATH_IF_SUPPORTED(problem->RemoveResidualBlock(r_zw), "not found");
+  EXPECT_DEATH_IF_SUPPORTED(problem->RemoveResidualBlock(r_z), "not found");
+
+  // The remaining residuals are still valid and can be removed normally.
+  problem->RemoveResidualBlock(r_yw);
+  problem->RemoveResidualBlock(r_w);
+  problem->RemoveResidualBlock(r_y);
+
+  // Verify the final state; NumResidualBlocks() also cross-checks the
+  // residual hash set when fast removal is enabled.
+  EXPECT_EQ(2, problem->NumParameterBlocks());
+  EXPECT_EQ(0, NumResidualBlocks());
+}
+
// Check that a null-terminated array, a, has the same elements as b.
template<typename T>
void ExpectVectorContainsUnordered(const T* a, const vector<T>& b) {