Check validity of residual block before removal in RemoveResidualBlock. - Breaking change: Problem::Options::enable_fast_parameter_block_removal is now Problem::Options::enable_fast_removal, as it now controls the behaviour for both parameter and residual blocks. - Previously we did not check that the specified residual block to remove in RemoveResidualBlock actually represented a valid residual for the problem. - This meant that Ceres would die unexpectedly if the user passed an uninitialised residual_block, or more likely attempted to remove a residual block that had already been removed automatically after the user removed a parameter block upon on which it was dependent. - RemoveResidualBlock now verifies the validity of the given residual_block to remove. Either by checking against a hash set of all residuals maintained in ProblemImpl iff enable_fast_removal is enabled. Or by a full scan of the residual blocks if not. Change-Id: I9ab178e2f68a74135f0a8e20905b16405c77a62b
diff --git a/docs/source/modeling.rst b/docs/source/modeling.rst index 33a4098..f06ef0d 100644 --- a/docs/source/modeling.rst +++ b/docs/source/modeling.rst
@@ -1363,7 +1363,10 @@ Remove a residual block from the problem. Any parameters that the residual block depends on are not removed. The cost and loss functions for the residual block will not get deleted immediately; won't happen until the - problem itself is deleted. + problem itself is deleted. If Problem::Options::enable_fast_removal is + true, then the removal is fast (almost constant time). Otherwise, removing a + residual block will incur a scan of the entire Problem object to verify that + the residual_block represents a valid residual in the problem. **WARNING:** Removing a residual or parameter block will destroy the implicit ordering, rendering the jacobian or residuals returned @@ -1378,7 +1381,7 @@ of the problem (similar to cost/loss functions in residual block removal). Any residual blocks that depend on the parameter are also removed, as described above in RemoveResidualBlock(). If - Problem::Options::enable_fast_parameter_block_removal is true, then + Problem::Options::enable_fast_removal is true, then the removal is fast (almost constant time). Otherwise, removing a parameter block will incur a scan of the entire Problem object. @@ -1456,7 +1459,7 @@ Get all the residual blocks that depend on the given parameter block. - If `Problem::Options::enable_fast_parameter_block_removal` is + If `Problem::Options::enable_fast_removal` is `true`, then getting the residual blocks is fast and depends only on the number of residual blocks. Otherwise, getting the residual blocks for a parameter block will incur a scan of the entire
diff --git a/include/ceres/problem.h b/include/ceres/problem.h index a7d7815..bed792a 100644 --- a/include/ceres/problem.h +++ b/include/ceres/problem.h
@@ -124,7 +124,7 @@ : cost_function_ownership(TAKE_OWNERSHIP), loss_function_ownership(TAKE_OWNERSHIP), local_parameterization_ownership(TAKE_OWNERSHIP), - enable_fast_parameter_block_removal(false), + enable_fast_removal(false), disable_all_safety_checks(false) {} // These flags control whether the Problem object owns the cost @@ -138,17 +138,21 @@ Ownership loss_function_ownership; Ownership local_parameterization_ownership; - // If true, trades memory for a faster RemoveParameterBlock() operation. + // If true, trades memory for faster RemoveResidualBlock() and + // RemoveParameterBlock() operations. // - // RemoveParameterBlock() takes time proportional to the size of the entire - // Problem. If you only remove parameter blocks from the Problem - // occassionaly, this may be acceptable. However, if you are modifying the - // Problem frequently, and have memory to spare, then flip this switch to + // By default, RemoveParameterBlock() and RemoveResidualBlock() take time + // proportional to the size of the entire problem. If you only ever remove + // parameters or residuals from the problem occassionally, this might be + // acceptable. However, if you have memory to spare, enable this option to // make RemoveParameterBlock() take time proportional to the number of - // residual blocks that depend on it. The increase in memory usage is an - // additonal hash set per parameter block containing all the residuals that - // depend on the parameter block. - bool enable_fast_parameter_block_removal; + // residual blocks that depend on it, and RemoveResidualBlock() take (on + // average) constant time. + // + // The increase in memory usage is twofold: an additonal hash set per + // parameter block containing all the residuals that depend on the parameter + // block; and a hash set in the problem containing all residuals. + bool enable_fast_removal; // By default, Ceres performs a variety of safety checks when constructing // the problem. There is a small but measurable performance penalty to @@ -276,7 +280,7 @@ // residual blocks that depend on the parameter are also removed, as // described above in RemoveResidualBlock(). // - // If Problem::Options::enable_fast_parameter_block_removal is true, then the + // If Problem::Options::enable_fast_removal is true, then the // removal is fast (almost constant time). Otherwise, removing a parameter // block will incur a scan of the entire Problem object. // @@ -362,7 +366,7 @@ // Get all the residual blocks that depend on the given parameter block. // - // If Problem::Options::enable_fast_parameter_block_removal is true, then + // If Problem::Options::enable_fast_removal is true, then // getting the residual blocks is fast and depends only on the number of // residual blocks. Otherwise, getting the residual blocks for a parameter // block will incur a scan of the entire Problem object.
diff --git a/internal/ceres/problem_impl.cc b/internal/ceres/problem_impl.cc index 99f3f89..d9fb2af 100644 --- a/internal/ceres/problem_impl.cc +++ b/internal/ceres/problem_impl.cc
@@ -142,7 +142,7 @@ // For dynamic problems, add the list of dependent residual blocks, which is // empty to start. - if (options_.enable_fast_parameter_block_removal) { + if (options_.enable_fast_removal) { new_parameter_block->EnableResidualBlockDependencies(); } parameter_block_map_[values] = new_parameter_block; @@ -150,6 +150,26 @@ return new_parameter_block; } +void ProblemImpl::InternalRemoveResidualBlock(ResidualBlock* residual_block) { + CHECK_NOTNULL(residual_block); + // Perform no check on the validity of residual_block, that is handled in + // the public method: RemoveResidualBlock(). + + // If needed, remove the parameter dependencies on this residual block. + if (options_.enable_fast_removal) { + const int num_parameter_blocks_for_residual = + residual_block->NumParameterBlocks(); + for (int i = 0; i < num_parameter_blocks_for_residual; ++i) { + residual_block->parameter_blocks()[i] + ->RemoveResidualBlock(residual_block); + } + + ResidualBlockSet::iterator it = residual_block_set_.find(residual_block); + residual_block_set_.erase(it); + } + DeleteBlockInVector(program_->mutable_residual_blocks(), residual_block); +} + // Deletes the residual block in question, assuming there are no other // references to it inside the problem (e.g. by another parameter). Referenced // cost and loss functions are tucked away for future deletion, since it is not @@ -278,13 +298,18 @@ program_->residual_blocks_.size()); // Add dependencies on the residual to the parameter blocks. - if (options_.enable_fast_parameter_block_removal) { + if (options_.enable_fast_removal) { for (int i = 0; i < parameter_blocks.size(); ++i) { parameter_block_ptrs[i]->AddResidualBlock(new_residual_block); } } program_->residual_blocks_.push_back(new_residual_block); + + if (options_.enable_fast_removal) { + residual_block_set_.insert(new_residual_block); + } + return new_residual_block; } @@ -475,30 +500,46 @@ void ProblemImpl::RemoveResidualBlock(ResidualBlock* residual_block) { CHECK_NOTNULL(residual_block); - // If needed, remove the parameter dependencies on this residual block. - if (options_.enable_fast_parameter_block_removal) { - const int num_parameter_blocks_for_residual = - residual_block->NumParameterBlocks(); - for (int i = 0; i < num_parameter_blocks_for_residual; ++i) { - residual_block->parameter_blocks()[i] - ->RemoveResidualBlock(residual_block); - } + // Verify that residual_block identifies a residual in the current problem. + const string residual_not_found_message = + StringPrintf("Residual block to remove: %p not found. This usually means " + "one of three things have happened:\n" + " 1) residual_block is uninitialised and points to a random " + "area in memory.\n" + " 2) residual_block represented a residual that was added to" + " the problem, but referred to a parameter block which has " + "since been removed, which removes all residuals which " + "depend on that parameter block, and was thus removed.\n" + " 3) residual_block referred to a residual that has already " + "been removed from the problem (by the user).", + residual_block); + if (options_.enable_fast_removal) { + CHECK(residual_block_set_.find(residual_block) != + residual_block_set_.end()) + << residual_not_found_message; + } else { + // Perform a full search over all current residuals. + CHECK(std::find(program_->residual_blocks().begin(), + program_->residual_blocks().end(), + residual_block) != program_->residual_blocks().end()) + << residual_not_found_message; } - DeleteBlockInVector(program_->mutable_residual_blocks(), residual_block); + + InternalRemoveResidualBlock(residual_block); } void ProblemImpl::RemoveParameterBlock(double* values) { ParameterBlock* parameter_block = FindParameterBlockOrDie(parameter_block_map_, values); - if (options_.enable_fast_parameter_block_removal) { + if (options_.enable_fast_removal) { // Copy the dependent residuals from the parameter block because the set of // dependents will change after each call to RemoveResidualBlock(). vector<ResidualBlock*> residual_blocks_to_remove( parameter_block->mutable_residual_blocks()->begin(), parameter_block->mutable_residual_blocks()->end()); for (int i = 0; i < residual_blocks_to_remove.size(); ++i) { - RemoveResidualBlock(residual_blocks_to_remove[i]); + InternalRemoveResidualBlock(residual_blocks_to_remove[i]); } } else { // Scan all the residual blocks to remove ones that depend on the parameter @@ -510,7 +551,7 @@ const int num_parameter_blocks = residual_block->NumParameterBlocks(); for (int j = 0; j < num_parameter_blocks; ++j) { if (residual_block->parameter_blocks()[j] == parameter_block) { - RemoveResidualBlock(residual_block); + InternalRemoveResidualBlock(residual_block); // The parameter blocks are guaranteed unique. break; } @@ -784,7 +825,7 @@ FindParameterBlockOrDie(parameter_block_map_, const_cast<double*>(values)); - if (options_.enable_fast_parameter_block_removal) { + if (options_.enable_fast_removal) { // In this case the residual blocks that depend on the parameter block are // stored in the parameter block already, so just copy them out. CHECK_NOTNULL(residual_blocks)->resize(
diff --git a/internal/ceres/problem_impl.h b/internal/ceres/problem_impl.h index 75bdc2b..e846c03 100644 --- a/internal/ceres/problem_impl.h +++ b/internal/ceres/problem_impl.h
@@ -45,6 +45,7 @@ #include "ceres/internal/macros.h" #include "ceres/internal/port.h" #include "ceres/internal/scoped_ptr.h" +#include "ceres/collections_port.h" #include "ceres/problem.h" #include "ceres/types.h" @@ -63,6 +64,7 @@ class ProblemImpl { public: typedef map<double*, ParameterBlock*> ParameterMap; + typedef HashSet<ResidualBlock*> ResidualBlockSet; ProblemImpl(); explicit ProblemImpl(const Problem::Options& options); @@ -160,9 +162,15 @@ Program* mutable_program() { return program_.get(); } const ParameterMap& parameter_map() const { return parameter_block_map_; } + const ResidualBlockSet& residual_block_set() const { + CHECK(options_.enable_fast_removal) + << "Fast removal not enabled, residual_block_set is not maintained."; + return residual_block_set_; + } private: ParameterBlock* InternalAddParameterBlock(double* values, int size); + void InternalRemoveResidualBlock(ResidualBlock* residual_block); bool InternalEvaluate(Program* program, double* cost, @@ -184,6 +192,9 @@ // The mapping from user pointers to parameter blocks. map<double*, ParameterBlock*> parameter_block_map_; + // Iff enable_fast_removal is enabled, contains the current residual blocks. + ResidualBlockSet residual_block_set_; + // The actual parameter and residual blocks. internal::scoped_ptr<internal::Program> program_;
diff --git a/internal/ceres/problem_test.cc b/internal/ceres/problem_test.cc index 4507dfa..eb75e3a 100644 --- a/internal/ceres/problem_test.cc +++ b/internal/ceres/problem_test.cc
@@ -378,7 +378,7 @@ struct DynamicProblem : public ::testing::TestWithParam<bool> { DynamicProblem() { Problem::Options options; - options.enable_fast_parameter_block_removal = GetParam(); + options.enable_fast_removal = GetParam(); problem.reset(new ProblemImpl(options)); } @@ -390,9 +390,26 @@ } bool HasResidualBlock(ResidualBlock* residual_block) { - return find(problem->program().residual_blocks().begin(), - problem->program().residual_blocks().end(), - residual_block) != problem->program().residual_blocks().end(); + bool have_residual_block = true; + if (GetParam()) { + have_residual_block &= + (problem->residual_block_set().find(residual_block) != + problem->residual_block_set().end()); + } + have_residual_block &= + find(problem->program().residual_blocks().begin(), + problem->program().residual_blocks().end(), + residual_block) != problem->program().residual_blocks().end(); + return have_residual_block; + } + + int NumResidualBlocks() { + // Verify that the hash set of residuals is maintained consistently. + if (GetParam()) { + EXPECT_EQ(problem->residual_block_set().size(), + problem->NumResidualBlocks()); + } + return problem->NumResidualBlocks(); } // The next block of functions until the end are only for testing the @@ -550,7 +567,7 @@ problem->AddParameterBlock(z, 5); problem->AddParameterBlock(w, 3); ASSERT_EQ(3, problem->NumParameterBlocks()); - ASSERT_EQ(0, problem->NumResidualBlocks()); + ASSERT_EQ(0, NumResidualBlocks()); EXPECT_EQ(y, GetParameterBlock(0)->user_state()); EXPECT_EQ(z, GetParameterBlock(1)->user_state()); EXPECT_EQ(w, GetParameterBlock(2)->user_state()); @@ -559,12 +576,12 @@ // removing it. problem->RemoveParameterBlock(w); ASSERT_EQ(2, problem->NumParameterBlocks()); - ASSERT_EQ(0, problem->NumResidualBlocks()); + ASSERT_EQ(0, NumResidualBlocks()); EXPECT_EQ(y, GetParameterBlock(0)->user_state()); EXPECT_EQ(z, GetParameterBlock(1)->user_state()); problem->AddParameterBlock(w, 3); ASSERT_EQ(3, problem->NumParameterBlocks()); - ASSERT_EQ(0, problem->NumResidualBlocks()); + ASSERT_EQ(0, NumResidualBlocks()); EXPECT_EQ(y, GetParameterBlock(0)->user_state()); EXPECT_EQ(z, GetParameterBlock(1)->user_state()); EXPECT_EQ(w, GetParameterBlock(2)->user_state()); @@ -572,12 +589,12 @@ // Now remove z, which is in the middle, and add it back. problem->RemoveParameterBlock(z); ASSERT_EQ(2, problem->NumParameterBlocks()); - ASSERT_EQ(0, problem->NumResidualBlocks()); + ASSERT_EQ(0, NumResidualBlocks()); EXPECT_EQ(y, GetParameterBlock(0)->user_state()); EXPECT_EQ(w, GetParameterBlock(1)->user_state()); problem->AddParameterBlock(z, 5); ASSERT_EQ(3, problem->NumParameterBlocks()); - ASSERT_EQ(0, problem->NumResidualBlocks()); + ASSERT_EQ(0, NumResidualBlocks()); EXPECT_EQ(y, GetParameterBlock(0)->user_state()); EXPECT_EQ(w, GetParameterBlock(1)->user_state()); EXPECT_EQ(z, GetParameterBlock(2)->user_state()); @@ -586,20 +603,20 @@ // y problem->RemoveParameterBlock(y); ASSERT_EQ(2, problem->NumParameterBlocks()); - ASSERT_EQ(0, problem->NumResidualBlocks()); + ASSERT_EQ(0, NumResidualBlocks()); EXPECT_EQ(z, GetParameterBlock(0)->user_state()); EXPECT_EQ(w, GetParameterBlock(1)->user_state()); // z problem->RemoveParameterBlock(z); ASSERT_EQ(1, problem->NumParameterBlocks()); - ASSERT_EQ(0, problem->NumResidualBlocks()); + ASSERT_EQ(0, NumResidualBlocks()); EXPECT_EQ(w, GetParameterBlock(0)->user_state()); // w problem->RemoveParameterBlock(w); EXPECT_EQ(0, problem->NumParameterBlocks()); - EXPECT_EQ(0, problem->NumResidualBlocks()); + EXPECT_EQ(0, NumResidualBlocks()); } TEST_P(DynamicProblem, RemoveParameterBlockWithResiduals) { @@ -607,7 +624,7 @@ problem->AddParameterBlock(z, 5); problem->AddParameterBlock(w, 3); ASSERT_EQ(3, problem->NumParameterBlocks()); - ASSERT_EQ(0, problem->NumResidualBlocks()); + ASSERT_EQ(0, NumResidualBlocks()); EXPECT_EQ(y, GetParameterBlock(0)->user_state()); EXPECT_EQ(z, GetParameterBlock(1)->user_state()); EXPECT_EQ(w, GetParameterBlock(2)->user_state()); @@ -630,12 +647,12 @@ ResidualBlock* r_w = problem->AddResidualBlock(cost_w, NULL, w); EXPECT_EQ(3, problem->NumParameterBlocks()); - EXPECT_EQ(7, problem->NumResidualBlocks()); + EXPECT_EQ(7, NumResidualBlocks()); // Remove w, which should remove r_yzw, r_yw, r_zw, r_w. problem->RemoveParameterBlock(w); ASSERT_EQ(2, problem->NumParameterBlocks()); - ASSERT_EQ(3, problem->NumResidualBlocks()); + ASSERT_EQ(3, NumResidualBlocks()); ASSERT_FALSE(HasResidualBlock(r_yzw)); ASSERT_TRUE (HasResidualBlock(r_yz )); @@ -648,7 +665,7 @@ // Remove z, which will remove almost everything else. problem->RemoveParameterBlock(z); ASSERT_EQ(1, problem->NumParameterBlocks()); - ASSERT_EQ(1, problem->NumResidualBlocks()); + ASSERT_EQ(1, NumResidualBlocks()); ASSERT_FALSE(HasResidualBlock(r_yzw)); ASSERT_FALSE(HasResidualBlock(r_yz )); @@ -661,7 +678,7 @@ // Remove y; all gone. problem->RemoveParameterBlock(y); EXPECT_EQ(0, problem->NumParameterBlocks()); - EXPECT_EQ(0, problem->NumResidualBlocks()); + EXPECT_EQ(0, NumResidualBlocks()); } TEST_P(DynamicProblem, RemoveResidualBlock) { @@ -699,14 +716,14 @@ EXPECT_TRUE(GetParameterBlock(2)->mutable_residual_blocks() == NULL); } EXPECT_EQ(3, problem->NumParameterBlocks()); - EXPECT_EQ(7, problem->NumResidualBlocks()); + EXPECT_EQ(7, NumResidualBlocks()); // Remove each residual and check the state after each removal. // Remove r_yzw. problem->RemoveResidualBlock(r_yzw); ASSERT_EQ(3, problem->NumParameterBlocks()); - ASSERT_EQ(6, problem->NumResidualBlocks()); + ASSERT_EQ(6, NumResidualBlocks()); if (GetParam()) { ExpectParameterBlockContains(y, r_yz, r_yw, r_y); ExpectParameterBlockContains(z, r_yz, r_zw, r_z); @@ -722,7 +739,7 @@ // Remove r_yw. problem->RemoveResidualBlock(r_yw); ASSERT_EQ(3, problem->NumParameterBlocks()); - ASSERT_EQ(5, problem->NumResidualBlocks()); + ASSERT_EQ(5, NumResidualBlocks()); if (GetParam()) { ExpectParameterBlockContains(y, r_yz, r_y); ExpectParameterBlockContains(z, r_yz, r_zw, r_z); @@ -737,7 +754,7 @@ // Remove r_zw. problem->RemoveResidualBlock(r_zw); ASSERT_EQ(3, problem->NumParameterBlocks()); - ASSERT_EQ(4, problem->NumResidualBlocks()); + ASSERT_EQ(4, NumResidualBlocks()); if (GetParam()) { ExpectParameterBlockContains(y, r_yz, r_y); ExpectParameterBlockContains(z, r_yz, r_z); @@ -751,7 +768,7 @@ // Remove r_w. problem->RemoveResidualBlock(r_w); ASSERT_EQ(3, problem->NumParameterBlocks()); - ASSERT_EQ(3, problem->NumResidualBlocks()); + ASSERT_EQ(3, NumResidualBlocks()); if (GetParam()) { ExpectParameterBlockContains(y, r_yz, r_y); ExpectParameterBlockContains(z, r_yz, r_z); @@ -764,7 +781,7 @@ // Remove r_yz. problem->RemoveResidualBlock(r_yz); ASSERT_EQ(3, problem->NumParameterBlocks()); - ASSERT_EQ(2, problem->NumResidualBlocks()); + ASSERT_EQ(2, NumResidualBlocks()); if (GetParam()) { ExpectParameterBlockContains(y, r_y); ExpectParameterBlockContains(z, r_z); @@ -777,7 +794,7 @@ problem->RemoveResidualBlock(r_z); problem->RemoveResidualBlock(r_y); ASSERT_EQ(3, problem->NumParameterBlocks()); - ASSERT_EQ(0, problem->NumResidualBlocks()); + ASSERT_EQ(0, NumResidualBlocks()); if (GetParam()) { ExpectParameterBlockContains(y); ExpectParameterBlockContains(z); @@ -785,6 +802,56 @@ } } +TEST_P(DynamicProblem, RemoveInvalidResidualBlockDies) { + problem->AddParameterBlock(y, 4); + problem->AddParameterBlock(z, 5); + problem->AddParameterBlock(w, 3); + + // Add all combinations of cost functions. + CostFunction* cost_yzw = new TernaryCostFunction(1, 4, 5, 3); + CostFunction* cost_yz = new BinaryCostFunction (1, 4, 5); + CostFunction* cost_yw = new BinaryCostFunction (1, 4, 3); + CostFunction* cost_zw = new BinaryCostFunction (1, 5, 3); + CostFunction* cost_y = new UnaryCostFunction (1, 4); + CostFunction* cost_z = new UnaryCostFunction (1, 5); + CostFunction* cost_w = new UnaryCostFunction (1, 3); + + ResidualBlock* r_yzw = problem->AddResidualBlock(cost_yzw, NULL, y, z, w); + ResidualBlock* r_yz = problem->AddResidualBlock(cost_yz, NULL, y, z); + ResidualBlock* r_yw = problem->AddResidualBlock(cost_yw, NULL, y, w); + ResidualBlock* r_zw = problem->AddResidualBlock(cost_zw, NULL, z, w); + ResidualBlock* r_y = problem->AddResidualBlock(cost_y, NULL, y); + ResidualBlock* r_z = problem->AddResidualBlock(cost_z, NULL, z); + ResidualBlock* r_w = problem->AddResidualBlock(cost_w, NULL, w); + + // Remove r_yzw. + problem->RemoveResidualBlock(r_yzw); + ASSERT_EQ(3, problem->NumParameterBlocks()); + ASSERT_EQ(6, NumResidualBlocks()); + // Attempt to remove r_yzw again. + EXPECT_DEATH_IF_SUPPORTED(problem->RemoveResidualBlock(r_yzw), "not found"); + + // Attempt to remove a cast pointer never added as a residual. + int trash_memory = 1234; + ResidualBlock* invalid_residual = + reinterpret_cast<ResidualBlock*>(&trash_memory); + EXPECT_DEATH_IF_SUPPORTED(problem->RemoveResidualBlock(invalid_residual), + "not found"); + + // Remove a parameter block, which in turn removes the dependent residuals + // then attempt to remove them directly. + problem->RemoveParameterBlock(z); + ASSERT_EQ(2, problem->NumParameterBlocks()); + ASSERT_EQ(3, NumResidualBlocks()); + EXPECT_DEATH_IF_SUPPORTED(problem->RemoveResidualBlock(r_yz), "not found"); + EXPECT_DEATH_IF_SUPPORTED(problem->RemoveResidualBlock(r_zw), "not found"); + EXPECT_DEATH_IF_SUPPORTED(problem->RemoveResidualBlock(r_z), "not found"); + + problem->RemoveResidualBlock(r_yw); + problem->RemoveResidualBlock(r_w); + problem->RemoveResidualBlock(r_y); +} + // Check that a null-terminated array, a, has the same elements as b. template<typename T> void ExpectVectorContainsUnordered(const T* a, const vector<T>& b) {