Add support for removing parameter and residual blocks.
This adds support for removing parameter and residual blocks.
There are two modes of operation: in the first, removals of
parameter blocks are expensive, since each removal requires
scanning all residual blocks to find ones that depend on the
removed parameter block. In the other (enabled via
Problem::Options::enable_fast_parameter_block_removal), extra memory
is sacrificed to maintain, for each parameter block, a list of the
residual blocks that depend on it, removing the need to scan. In
both cases, removing residual blocks is fast.
As a caveat, any removal destroys the ordering of the parameter
blocks, so the residuals or Jacobian returned from Solver::Solve()
are meaningless. There is some debate about the best way to handle
this; the details are left for a future change.
This also adds some memory overhead, even when fast removals are
not requested:
- 1 int32 per residual block, to track its position in the program.
- 1 pointer per parameter block, to store its dependent residual blocks.
Change-Id: I71dcac8656679329a15ee7fc12c0df07030c12af
diff --git a/internal/ceres/problem_test.cc b/internal/ceres/problem_test.cc
index 4afe1b5..55f355b 100644
--- a/internal/ceres/problem_test.cc
+++ b/internal/ceres/problem_test.cc
@@ -30,10 +30,14 @@
// keir@google.com (Keir Mierle)
#include "ceres/problem.h"
+#include "ceres/problem_impl.h"
#include "gtest/gtest.h"
#include "ceres/cost_function.h"
#include "ceres/local_parameterization.h"
+#include "ceres/map_util.h"
+#include "ceres/parameter_block.h"
+#include "ceres/program.h"
#include "ceres/sized_cost_function.h"
#include "ceres/internal/scoped_ptr.h"
@@ -293,11 +297,11 @@
class DestructorCountingCostFunction : public SizedCostFunction<3, 4, 5> {
public:
- explicit DestructorCountingCostFunction(int *counter)
- : counter_(counter) {}
+ explicit DestructorCountingCostFunction(int *num_destructions)
+ : num_destructions_(num_destructions) {}
virtual ~DestructorCountingCostFunction() {
- *counter_ += 1;
+ *num_destructions_ += 1;
}
virtual bool Evaluate(double const* const* parameters,
@@ -307,12 +311,12 @@
}
private:
- int* counter_;
+ int* num_destructions_;
};
TEST(Problem, ReusedCostFunctionsAreOnlyDeletedOnce) {
double y[4], z[5];
- int counter = 0;
+ int num_destructions = 0;
// Add a cost function multiple times and check to make sure that
// the destructor on the cost function is only called once.
@@ -321,15 +325,375 @@
problem.AddParameterBlock(y, 4);
problem.AddParameterBlock(z, 5);
- CostFunction* cost = new DestructorCountingCostFunction(&counter);
+ CostFunction* cost = new DestructorCountingCostFunction(&num_destructions);
problem.AddResidualBlock(cost, NULL, y, z);
problem.AddResidualBlock(cost, NULL, y, z);
problem.AddResidualBlock(cost, NULL, y, z);
+ EXPECT_EQ(3, problem.NumResidualBlocks());
}
// Check that the destructor was called only once.
- CHECK_EQ(counter, 1);
+ CHECK_EQ(num_destructions, 1);
}
+TEST(Problem, CostFunctionsAreDeletedEvenWithRemovals) {
+ double y[4], z[5], w[4];
+ int num_destructions = 0;
+ {
+ Problem problem;
+ problem.AddParameterBlock(y, 4);
+ problem.AddParameterBlock(z, 5);
+
+ CostFunction* cost_yz =
+ new DestructorCountingCostFunction(&num_destructions);
+ CostFunction* cost_wz =
+ new DestructorCountingCostFunction(&num_destructions);
+ ResidualBlock* r_yz = problem.AddResidualBlock(cost_yz, NULL, y, z);
+ ResidualBlock* r_wz = problem.AddResidualBlock(cost_wz, NULL, w, z);
+ EXPECT_EQ(2, problem.NumResidualBlocks());
+
+ // In the current implementation, removal must not delete the cost function.
+ problem.RemoveResidualBlock(r_yz);
+ EXPECT_EQ(0, num_destructions);
+ problem.RemoveResidualBlock(r_wz);
+ EXPECT_EQ(0, num_destructions);
+
+ EXPECT_EQ(0, problem.NumResidualBlocks());
+ } // Destroying the problem must delete both cost functions.
+ EXPECT_EQ(2, num_destructions);
+}
+
+// Parameterize the dynamic problem tests (e.g. for removing residual blocks)
+// on whether fast parameter block removal (the low-latency mode) is enabled.
+//
+// This tests against ProblemImpl instead of Problem in order to inspect the
+// state of the resulting Program; this is difficult with only the thin Problem
+// interface.
+struct DynamicProblem : public ::testing::TestWithParam<bool> {
+ DynamicProblem() {
+ Problem::Options options;
+ options.enable_fast_parameter_block_removal = GetParam();
+ problem.reset(new ProblemImpl(options));
+ }
+
+ ParameterBlock* GetParameterBlock(int block) {
+ return problem->program().parameter_blocks()[block];
+ }
+ ResidualBlock* GetResidualBlock(int block) {
+ return problem->program().residual_blocks()[block];
+ }
+
+ bool HasResidualBlock(ResidualBlock* residual_block) {
+ return std::find(problem->program().residual_blocks().begin(),
+ problem->program().residual_blocks().end(),
+ residual_block) != problem->program().residual_blocks().end();
+ }
+
+ // The remaining helpers inspect the residual-block back-pointers, which only
+ // exist when fast parameter block removal is enabled (GetParam() is true).
+ void ExpectParameterBlockContainsResidualBlock(
+ double* values,
+ ResidualBlock* residual_block) {
+ ParameterBlock* parameter_block =
+ FindOrDie(problem->parameter_map(), values);
+ EXPECT_TRUE(ContainsKey(*(parameter_block->mutable_residual_blocks()),
+ residual_block));
+ }
+
+ void ExpectSize(double* values, size_t size) {
+ ParameterBlock* parameter_block =
+ FindOrDie(problem->parameter_map(), values);
+ EXPECT_EQ(size, parameter_block->mutable_residual_blocks()->size());
+ }
+
+ // Degenerate case: no residual blocks depend on the parameter block.
+ void ExpectParameterBlockContains(double* values) {
+ ExpectSize(values, 0);
+ }
+
+ void ExpectParameterBlockContains(double* values,
+ ResidualBlock* r1) {
+ ExpectSize(values, 1);
+ ExpectParameterBlockContainsResidualBlock(values, r1);
+ }
+
+ void ExpectParameterBlockContains(double* values,
+ ResidualBlock* r1,
+ ResidualBlock* r2) {
+ ExpectSize(values, 2);
+ ExpectParameterBlockContainsResidualBlock(values, r1);
+ ExpectParameterBlockContainsResidualBlock(values, r2);
+ }
+
+ void ExpectParameterBlockContains(double* values,
+ ResidualBlock* r1,
+ ResidualBlock* r2,
+ ResidualBlock* r3) {
+ ExpectSize(values, 3);
+ ExpectParameterBlockContainsResidualBlock(values, r1);
+ ExpectParameterBlockContainsResidualBlock(values, r2);
+ ExpectParameterBlockContainsResidualBlock(values, r3);
+ }
+
+ void ExpectParameterBlockContains(double* values,
+ ResidualBlock* r1,
+ ResidualBlock* r2,
+ ResidualBlock* r3,
+ ResidualBlock* r4) {
+ ExpectSize(values, 4);
+ ExpectParameterBlockContainsResidualBlock(values, r1);
+ ExpectParameterBlockContainsResidualBlock(values, r2);
+ ExpectParameterBlockContainsResidualBlock(values, r3);
+ ExpectParameterBlockContainsResidualBlock(values, r4);
+ }
+
+ scoped_ptr<ProblemImpl> problem;
+ double y[4], z[5], w[3];
+};
+
+TEST_P(DynamicProblem, RemoveParameterBlockWithNoResiduals) {
+ problem->AddParameterBlock(y, 4);
+ problem->AddParameterBlock(z, 5);
+ problem->AddParameterBlock(w, 3);
+ ASSERT_EQ(3, problem->NumParameterBlocks());
+ ASSERT_EQ(0, problem->NumResidualBlocks());
+ EXPECT_EQ(y, GetParameterBlock(0)->user_state());
+ EXPECT_EQ(z, GetParameterBlock(1)->user_state());
+ EXPECT_EQ(w, GetParameterBlock(2)->user_state());
+
+ // Removal swaps the last parameter block into the removed one's slot, so
+ // removing the last block (w) is an edge case: remove it and add it back.
+ problem->RemoveParameterBlock(w);
+ ASSERT_EQ(2, problem->NumParameterBlocks());
+ ASSERT_EQ(0, problem->NumResidualBlocks());
+ EXPECT_EQ(y, GetParameterBlock(0)->user_state());
+ EXPECT_EQ(z, GetParameterBlock(1)->user_state());
+ problem->AddParameterBlock(w, 3);
+ ASSERT_EQ(3, problem->NumParameterBlocks());
+ ASSERT_EQ(0, problem->NumResidualBlocks());
+ EXPECT_EQ(y, GetParameterBlock(0)->user_state());
+ EXPECT_EQ(z, GetParameterBlock(1)->user_state());
+ EXPECT_EQ(w, GetParameterBlock(2)->user_state());
+
+ // Remove z from the middle; w (previously last) takes its slot. Re-add z.
+ problem->RemoveParameterBlock(z);
+ ASSERT_EQ(2, problem->NumParameterBlocks());
+ ASSERT_EQ(0, problem->NumResidualBlocks());
+ EXPECT_EQ(y, GetParameterBlock(0)->user_state());
+ EXPECT_EQ(w, GetParameterBlock(1)->user_state());
+ problem->AddParameterBlock(z, 5);
+ ASSERT_EQ(3, problem->NumParameterBlocks());
+ ASSERT_EQ(0, problem->NumResidualBlocks());
+ EXPECT_EQ(y, GetParameterBlock(0)->user_state());
+ EXPECT_EQ(w, GetParameterBlock(1)->user_state());
+ EXPECT_EQ(z, GetParameterBlock(2)->user_state());
+
+ // Now remove everything, one block at a time.
+ // y: the last block, z, is swapped into slot 0.
+ problem->RemoveParameterBlock(y);
+ ASSERT_EQ(2, problem->NumParameterBlocks());
+ ASSERT_EQ(0, problem->NumResidualBlocks());
+ EXPECT_EQ(z, GetParameterBlock(0)->user_state());
+ EXPECT_EQ(w, GetParameterBlock(1)->user_state());
+
+ // z: only w remains.
+ problem->RemoveParameterBlock(z);
+ ASSERT_EQ(1, problem->NumParameterBlocks());
+ ASSERT_EQ(0, problem->NumResidualBlocks());
+ EXPECT_EQ(w, GetParameterBlock(0)->user_state());
+
+ // w: the problem is now empty.
+ problem->RemoveParameterBlock(w);
+ EXPECT_EQ(0, problem->NumParameterBlocks());
+ EXPECT_EQ(0, problem->NumResidualBlocks());
+}
+
+TEST_P(DynamicProblem, RemoveParameterBlockWithResiduals) {
+ problem->AddParameterBlock(y, 4);
+ problem->AddParameterBlock(z, 5);
+ problem->AddParameterBlock(w, 3);
+ ASSERT_EQ(3, problem->NumParameterBlocks());
+ ASSERT_EQ(0, problem->NumResidualBlocks());
+ EXPECT_EQ(y, GetParameterBlock(0)->user_state());
+ EXPECT_EQ(z, GetParameterBlock(1)->user_state());
+ EXPECT_EQ(w, GetParameterBlock(2)->user_state());
+
+ // Add one residual block for every non-empty subset of {y, z, w}.
+ CostFunction* cost_yzw = new TernaryCostFunction(1, 4, 5, 3);
+ CostFunction* cost_yz = new BinaryCostFunction (1, 4, 5);
+ CostFunction* cost_yw = new BinaryCostFunction (1, 4, 3);
+ CostFunction* cost_zw = new BinaryCostFunction (1, 5, 3);
+ CostFunction* cost_y = new UnaryCostFunction (1, 4);
+ CostFunction* cost_z = new UnaryCostFunction (1, 5);
+ CostFunction* cost_w = new UnaryCostFunction (1, 3);
+
+ ResidualBlock* r_yzw = problem->AddResidualBlock(cost_yzw, NULL, y, z, w);
+ ResidualBlock* r_yz = problem->AddResidualBlock(cost_yz, NULL, y, z);
+ ResidualBlock* r_yw = problem->AddResidualBlock(cost_yw, NULL, y, w);
+ ResidualBlock* r_zw = problem->AddResidualBlock(cost_zw, NULL, z, w);
+ ResidualBlock* r_y = problem->AddResidualBlock(cost_y, NULL, y);
+ ResidualBlock* r_z = problem->AddResidualBlock(cost_z, NULL, z);
+ ResidualBlock* r_w = problem->AddResidualBlock(cost_w, NULL, w);
+
+ EXPECT_EQ(3, problem->NumParameterBlocks());
+ EXPECT_EQ(7, problem->NumResidualBlocks());
+
+ // Removing w also removes its dependent residuals r_yzw, r_yw, r_zw, r_w.
+ problem->RemoveParameterBlock(w);
+ ASSERT_EQ(2, problem->NumParameterBlocks());
+ ASSERT_EQ(3, problem->NumResidualBlocks());
+
+ ASSERT_FALSE(HasResidualBlock(r_yzw));
+ ASSERT_TRUE (HasResidualBlock(r_yz ));
+ ASSERT_FALSE(HasResidualBlock(r_yw ));
+ ASSERT_FALSE(HasResidualBlock(r_zw ));
+ ASSERT_TRUE (HasResidualBlock(r_y ));
+ ASSERT_TRUE (HasResidualBlock(r_z ));
+ ASSERT_FALSE(HasResidualBlock(r_w ));
+
+ // Removing z also removes r_yz and r_z, leaving only r_y.
+ problem->RemoveParameterBlock(z);
+ ASSERT_EQ(1, problem->NumParameterBlocks());
+ ASSERT_EQ(1, problem->NumResidualBlocks());
+
+ ASSERT_FALSE(HasResidualBlock(r_yzw));
+ ASSERT_FALSE(HasResidualBlock(r_yz ));
+ ASSERT_FALSE(HasResidualBlock(r_yw ));
+ ASSERT_FALSE(HasResidualBlock(r_zw ));
+ ASSERT_TRUE (HasResidualBlock(r_y ));
+ ASSERT_FALSE(HasResidualBlock(r_z ));
+ ASSERT_FALSE(HasResidualBlock(r_w ));
+
+ // Removing y also removes r_y; the problem is now empty.
+ problem->RemoveParameterBlock(y);
+ EXPECT_EQ(0, problem->NumParameterBlocks());
+ EXPECT_EQ(0, problem->NumResidualBlocks());
+}
+
+TEST_P(DynamicProblem, RemoveResidualBlock) {
+ problem->AddParameterBlock(y, 4);
+ problem->AddParameterBlock(z, 5);
+ problem->AddParameterBlock(w, 3);
+
+ // Add one residual block for every non-empty subset of {y, z, w}.
+ CostFunction* cost_yzw = new TernaryCostFunction(1, 4, 5, 3);
+ CostFunction* cost_yz = new BinaryCostFunction (1, 4, 5);
+ CostFunction* cost_yw = new BinaryCostFunction (1, 4, 3);
+ CostFunction* cost_zw = new BinaryCostFunction (1, 5, 3);
+ CostFunction* cost_y = new UnaryCostFunction (1, 4);
+ CostFunction* cost_z = new UnaryCostFunction (1, 5);
+ CostFunction* cost_w = new UnaryCostFunction (1, 3);
+
+ ResidualBlock* r_yzw = problem->AddResidualBlock(cost_yzw, NULL, y, z, w);
+ ResidualBlock* r_yz = problem->AddResidualBlock(cost_yz, NULL, y, z);
+ ResidualBlock* r_yw = problem->AddResidualBlock(cost_yw, NULL, y, w);
+ ResidualBlock* r_zw = problem->AddResidualBlock(cost_zw, NULL, z, w);
+ ResidualBlock* r_y = problem->AddResidualBlock(cost_y, NULL, y);
+ ResidualBlock* r_z = problem->AddResidualBlock(cost_z, NULL, z);
+ ResidualBlock* r_w = problem->AddResidualBlock(cost_w, NULL, w);
+
+ if (GetParam()) {
+ // With fast removal enabled, each parameter block keeps back-pointers to
+ // every residual block that depends on it.
+ ExpectParameterBlockContains(y, r_yzw, r_yz, r_yw, r_y);
+ ExpectParameterBlockContains(z, r_yzw, r_yz, r_zw, r_z);
+ ExpectParameterBlockContains(w, r_yzw, r_yw, r_zw, r_w);
+ } else {
+ // Otherwise no back-pointer sets are allocated at all.
+ EXPECT_TRUE(GetParameterBlock(0)->mutable_residual_blocks() == NULL);
+ EXPECT_TRUE(GetParameterBlock(1)->mutable_residual_blocks() == NULL);
+ EXPECT_TRUE(GetParameterBlock(2)->mutable_residual_blocks() == NULL);
+ }
+ EXPECT_EQ(3, problem->NumParameterBlocks());
+ EXPECT_EQ(7, problem->NumResidualBlocks());
+
+ // Remove residual blocks one at a time; no parameter block is ever removed.
+
+ // Remove r_yzw.
+ problem->RemoveResidualBlock(r_yzw);
+ ASSERT_EQ(3, problem->NumParameterBlocks());
+ ASSERT_EQ(6, problem->NumResidualBlocks());
+ if (GetParam()) {
+ ExpectParameterBlockContains(y, r_yz, r_yw, r_y);
+ ExpectParameterBlockContains(z, r_yz, r_zw, r_z);
+ ExpectParameterBlockContains(w, r_yw, r_zw, r_w);
+ }
+ ASSERT_TRUE (HasResidualBlock(r_yz ));
+ ASSERT_TRUE (HasResidualBlock(r_yw ));
+ ASSERT_TRUE (HasResidualBlock(r_zw ));
+ ASSERT_TRUE (HasResidualBlock(r_y ));
+ ASSERT_TRUE (HasResidualBlock(r_z ));
+ ASSERT_TRUE (HasResidualBlock(r_w ));
+
+ // Remove r_yw.
+ problem->RemoveResidualBlock(r_yw);
+ ASSERT_EQ(3, problem->NumParameterBlocks());
+ ASSERT_EQ(5, problem->NumResidualBlocks());
+ if (GetParam()) {
+ ExpectParameterBlockContains(y, r_yz, r_y);
+ ExpectParameterBlockContains(z, r_yz, r_zw, r_z);
+ ExpectParameterBlockContains(w, r_zw, r_w);
+ }
+ ASSERT_TRUE (HasResidualBlock(r_yz ));
+ ASSERT_TRUE (HasResidualBlock(r_zw ));
+ ASSERT_TRUE (HasResidualBlock(r_y ));
+ ASSERT_TRUE (HasResidualBlock(r_z ));
+ ASSERT_TRUE (HasResidualBlock(r_w ));
+
+ // Remove r_zw.
+ problem->RemoveResidualBlock(r_zw);
+ ASSERT_EQ(3, problem->NumParameterBlocks());
+ ASSERT_EQ(4, problem->NumResidualBlocks());
+ if (GetParam()) {
+ ExpectParameterBlockContains(y, r_yz, r_y);
+ ExpectParameterBlockContains(z, r_yz, r_z);
+ ExpectParameterBlockContains(w, r_w);
+ }
+ ASSERT_TRUE (HasResidualBlock(r_yz ));
+ ASSERT_TRUE (HasResidualBlock(r_y ));
+ ASSERT_TRUE (HasResidualBlock(r_z ));
+ ASSERT_TRUE (HasResidualBlock(r_w ));
+
+ // Remove r_w; w stays in the problem with no dependent residual blocks.
+ problem->RemoveResidualBlock(r_w);
+ ASSERT_EQ(3, problem->NumParameterBlocks());
+ ASSERT_EQ(3, problem->NumResidualBlocks());
+ if (GetParam()) {
+ ExpectParameterBlockContains(y, r_yz, r_y);
+ ExpectParameterBlockContains(z, r_yz, r_z);
+ ExpectParameterBlockContains(w);
+ }
+ ASSERT_TRUE (HasResidualBlock(r_yz ));
+ ASSERT_TRUE (HasResidualBlock(r_y ));
+ ASSERT_TRUE (HasResidualBlock(r_z ));
+
+ // Remove r_yz.
+ problem->RemoveResidualBlock(r_yz);
+ ASSERT_EQ(3, problem->NumParameterBlocks());
+ ASSERT_EQ(2, problem->NumResidualBlocks());
+ if (GetParam()) {
+ ExpectParameterBlockContains(y, r_y);
+ ExpectParameterBlockContains(z, r_z);
+ ExpectParameterBlockContains(w);
+ }
+ ASSERT_TRUE (HasResidualBlock(r_y ));
+ ASSERT_TRUE (HasResidualBlock(r_z ));
+
+ // Remove the last two; all three parameter blocks remain, now unused.
+ problem->RemoveResidualBlock(r_z);
+ problem->RemoveResidualBlock(r_y);
+ ASSERT_EQ(3, problem->NumParameterBlocks());
+ ASSERT_EQ(0, problem->NumResidualBlocks());
+ if (GetParam()) {
+ ExpectParameterBlockContains(y);
+ ExpectParameterBlockContains(z);
+ ExpectParameterBlockContains(w);
+ }
+}
+
+INSTANTIATE_TEST_CASE_P(OptionsInstantiation,
+ DynamicProblem, // Param: enable_fast_parameter_block_removal.
+ ::testing::Values(true, false));
+
} // namespace internal
} // namespace ceres