Add support for removing parameter and residual blocks.
This adds support for removing parameter and residual blocks.
There are two modes of operation: in the first, removals of
parameter blocks are expensive, since each removal requires
scanning all residual blocks to find ones that depend on the
removed parameter. In the other, extra memory is sacrificed to
maintain a list of the residuals that depend on each parameter block,
removing the need to scan. In both cases, removing residual blocks
is fast.
As a caveat, any removal destroys the ordering of the parameters,
so the residuals or jacobian returned from Solver::Solve() are
meaningless. There is some debate on the best way to handle this;
the details remain for a future change.
This also adds some overhead, even in the case that fast removals
are not requested:
- 1 int32 to each residual, to track its position in the program.
- 1 pointer to each parameter, to store the dependent residuals.
Change-Id: I71dcac8656679329a15ee7fc12c0df07030c12af
diff --git a/internal/ceres/parameter_block.h b/internal/ceres/parameter_block.h
index f20805c..4fcafe0 100644
--- a/internal/ceres/parameter_block.h
+++ b/internal/ceres/parameter_block.h
@@ -1,5 +1,5 @@
// Ceres Solver - A fast non-linear least squares minimizer
-// Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
+// Copyright 2010, 2011, 2012, 2013 Google Inc. All rights reserved.
// http://code.google.com/p/ceres-solver/
//
// Redistribution and use in source and binary forms, with or without
@@ -34,6 +34,7 @@
#include <cstdlib>
#include <string>
#include "ceres/array_utils.h"
+#include "ceres/collections_port.h"
#include "ceres/integral_types.h"
#include "ceres/internal/eigen.h"
#include "ceres/internal/port.h"
@@ -46,6 +47,7 @@
namespace internal {
class ProblemImpl;
+class ResidualBlock;
// The parameter block encodes the location of the user's original value, and
// also the "current state" of the parameter. The evaluator uses whatever is in
@@ -58,13 +60,28 @@
// responsible for the proper disposal of the local parameterization.
class ParameterBlock {
public:
- ParameterBlock(double* user_state, int size) {
- Init(user_state, size, NULL);
+ // TODO(keir): Decide what data structure is best here. Should this be a set?
+ // Probably not, because sets are memory inefficient. However, if it's a
+ // vector, you can get into pathological linear performance when removing a
+ // residual block from a problem where all the residual blocks depend on one
+ // parameter; for example, shared focal length in a bundle adjustment
+ // problem. It might be worth making a custom structure that is just an array
+ // when it is small, but transitions to a hash set when it has more elements.
+ //
+ // For now, use a hash set.
+ typedef HashSet<ResidualBlock*> ResidualBlockSet;
+
+ // Create a parameter block with the user state, size, and index specified.
+ // The size is the size of the parameter block and the index is the position
+ // of the parameter block inside a Program (if any).
+ ParameterBlock(double* user_state, int size, int index) {
+ Init(user_state, size, index, NULL);
}
ParameterBlock(double* user_state,
int size,
+ int index,
LocalParameterization* local_parameterization) {
- Init(user_state, size, local_parameterization);
+ Init(user_state, size, index, local_parameterization);
}
// The size of the parameter block.
@@ -187,12 +204,43 @@
delta_offset_);
}
+ void EnableResidualBlockDependencies() {
+ CHECK(residual_blocks_ == NULL)
+ << "Ceres bug: There is already a residual block collection "
+ << "for parameter block: " << ToString();
+ residual_blocks_ = new ResidualBlockSet;
+ }
+
+ void AddResidualBlock(ResidualBlock* residual_block) {
+ CHECK(residual_blocks_ != NULL)
+ << "Ceres bug: The residual block collection is null for parameter "
+ << "block: " << ToString();
+ residual_blocks_->insert(residual_block);
+ }
+
+ void RemoveResidualBlock(ResidualBlock* residual_block) {
+ CHECK(residual_blocks_ != NULL)
+ << "Ceres bug: The residual block collection is null for parameter "
+ << "block: " << ToString();
+ CHECK(residual_blocks_->find(residual_block) != residual_blocks_->end())
+ << "Ceres bug: Missing residual for parameter block: " << ToString();
+ residual_blocks_->erase(residual_block);
+ }
+
+ // This is only intended for iterating; perhaps this should only expose
+ // .begin() and .end().
+ ResidualBlockSet* mutable_residual_blocks() {
+ return residual_blocks_;
+ }
+
private:
void Init(double* user_state,
int size,
+ int index,
LocalParameterization* local_parameterization) {
user_state_ = user_state;
size_ = size;
+ index_ = index;
is_constant_ = false;
state_ = user_state_;
@@ -201,9 +249,10 @@
SetParameterization(local_parameterization);
}
- index_ = -1;
state_offset_ = -1;
delta_offset_ = -1;
+
+ residual_blocks_ = NULL;
}
bool UpdateLocalParameterizationJacobian() {
@@ -261,6 +310,9 @@
// The offset of this parameter block inside a larger delta vector.
int32 delta_offset_;
+ // If non-null, contains the residual blocks this parameter block is in.
+ ResidualBlockSet* residual_blocks_;
+
// Necessary so ProblemImpl can clean up the parameterizations.
friend class ProblemImpl;
};
diff --git a/internal/ceres/parameter_block_test.cc b/internal/ceres/parameter_block_test.cc
index 35998dc..09156f8 100644
--- a/internal/ceres/parameter_block_test.cc
+++ b/internal/ceres/parameter_block_test.cc
@@ -38,7 +38,7 @@
TEST(ParameterBlock, SetLocalParameterization) {
double x[3] = { 1.0, 2.0, 3.0 };
- ParameterBlock parameter_block(x, 3);
+ ParameterBlock parameter_block(x, 3, -1);
// The indices to set constant within the parameter block (used later).
vector<int> indices;
@@ -111,7 +111,7 @@
TEST(ParameterBlock, SetStateUpdatesLocalParameterizationJacobian) {
TestParameterization test_parameterization;
double x[1] = { 1.0 };
- ParameterBlock parameter_block(x, 1, &test_parameterization);
+ ParameterBlock parameter_block(x, 1, -1, &test_parameterization);
EXPECT_EQ(2.0, *parameter_block.LocalParameterizationJacobian());
@@ -122,7 +122,7 @@
TEST(ParameterBlock, PlusWithNoLocalParameterization) {
double x[2] = { 1.0, 2.0 };
- ParameterBlock parameter_block(x, 2);
+ ParameterBlock parameter_block(x, 2, -1);
double delta[2] = { 0.2, 0.3 };
double x_plus_delta[2];
@@ -164,7 +164,7 @@
TEST(ParameterBlock, DetectBadLocalParameterization) {
double x = 1;
BadLocalParameterization bad_parameterization;
- ParameterBlock parameter_block(&x, 1, &bad_parameterization);
+ ParameterBlock parameter_block(&x, 1, -1, &bad_parameterization);
double y = 2;
EXPECT_FALSE(parameter_block.SetState(&y));
}
diff --git a/internal/ceres/problem.cc b/internal/ceres/problem.cc
index 7ee5b5c..c8f4a21 100644
--- a/internal/ceres/problem.cc
+++ b/internal/ceres/problem.cc
@@ -36,8 +36,6 @@
namespace ceres {
-class ResidualBlock;
-
Problem::Problem() : problem_impl_(new internal::ProblemImpl) {}
Problem::Problem(const Problem::Options& options)
: problem_impl_(new internal::ProblemImpl(options)) {}
@@ -156,6 +154,14 @@
problem_impl_->AddParameterBlock(values, size, local_parameterization);
}
+void Problem::RemoveResidualBlock(ResidualBlockId residual_block) {
+ problem_impl_->RemoveResidualBlock(residual_block);
+}
+
+void Problem::RemoveParameterBlock(double* values) {
+ problem_impl_->RemoveParameterBlock(values);
+}
+
void Problem::SetParameterBlockConstant(double* values) {
problem_impl_->SetParameterBlockConstant(values);
}
diff --git a/internal/ceres/problem_impl.cc b/internal/ceres/problem_impl.cc
index e9d23ec..6154ddf 100644
--- a/internal/ceres/problem_impl.cc
+++ b/internal/ceres/problem_impl.cc
@@ -118,12 +118,58 @@
}
}
- ParameterBlock* new_parameter_block = new ParameterBlock(values, size);
+ // Pass the index of the new parameter block as well to keep the index in
+ // sync with the position of the parameter in the program's parameter vector.
+ ParameterBlock* new_parameter_block =
+ new ParameterBlock(values, size, program_->parameter_blocks_.size());
+
+ // For dynamic problems, add the list of dependent residual blocks, which is
+ // empty to start.
+ if (options_.enable_fast_parameter_block_removal) {
+ new_parameter_block->EnableResidualBlockDependencies();
+ }
parameter_block_map_[values] = new_parameter_block;
program_->parameter_blocks_.push_back(new_parameter_block);
return new_parameter_block;
}
+// Deletes the residual block in question, assuming there are no other
+// references to it inside the problem (e.g. by another parameter). Referenced
+// cost and loss functions are tucked away for future deletion, since it is not
+// possible to know whether other parts of the problem depend on them without
+// doing a full scan.
+void ProblemImpl::DeleteBlock(ResidualBlock* residual_block) {
+ // The const casts here are legit, since ResidualBlock holds these
+ // pointers as const pointers but we have ownership of them and
+ // have the right to destroy them when the destructor is called.
+ if (options_.cost_function_ownership == TAKE_OWNERSHIP &&
+ residual_block->cost_function() != NULL) {
+ cost_functions_to_delete_.push_back(
+ const_cast<CostFunction*>(residual_block->cost_function()));
+ }
+ if (options_.loss_function_ownership == TAKE_OWNERSHIP &&
+ residual_block->loss_function() != NULL) {
+ loss_functions_to_delete_.push_back(
+ const_cast<LossFunction*>(residual_block->loss_function()));
+ }
+ delete residual_block;
+}
+
+// Deletes the parameter block in question, assuming there are no other
+// references to it inside the problem (e.g. by any residual blocks).
+// Referenced parameterizations are tucked away for future deletion, since it
+// is not possible to know whether other parts of the problem depend on them
+// without doing a full scan.
+void ProblemImpl::DeleteBlock(ParameterBlock* parameter_block) {
+ if (options_.local_parameterization_ownership == TAKE_OWNERSHIP &&
+ parameter_block->local_parameterization() != NULL) {
+ local_parameterizations_to_delete_.push_back(
+ parameter_block->mutable_local_parameterization());
+ }
+ parameter_block_map_.erase(parameter_block->mutable_user_state());
+ delete parameter_block;
+}
+
ProblemImpl::ProblemImpl() : program_(new internal::Program) {}
ProblemImpl::ProblemImpl(const Problem::Options& options)
: options_(options),
@@ -132,54 +178,27 @@
ProblemImpl::~ProblemImpl() {
// Collect the unique cost/loss functions and delete the residuals.
const int num_residual_blocks = program_->residual_blocks_.size();
-
- vector<CostFunction*> cost_functions;
- cost_functions.reserve(num_residual_blocks);
-
- vector<LossFunction*> loss_functions;
- loss_functions.reserve(num_residual_blocks);
-
+ cost_functions_to_delete_.reserve(num_residual_blocks);
+ loss_functions_to_delete_.reserve(num_residual_blocks);
for (int i = 0; i < program_->residual_blocks_.size(); ++i) {
- ResidualBlock* residual_block = program_->residual_blocks_[i];
-
- // The const casts here are legit, since ResidualBlock holds these
- // pointers as const pointers but we have ownership of them and
- // have the right to destroy them when the destructor is called.
- if (options_.cost_function_ownership == TAKE_OWNERSHIP) {
- cost_functions.push_back(
- const_cast<CostFunction*>(residual_block->cost_function()));
- }
- if (options_.loss_function_ownership == TAKE_OWNERSHIP) {
- loss_functions.push_back(
- const_cast<LossFunction*>(residual_block->loss_function()));
- }
-
- delete residual_block;
+ DeleteBlock(program_->residual_blocks_[i]);
}
// Collect the unique parameterizations and delete the parameters.
- vector<LocalParameterization*> local_parameterizations;
for (int i = 0; i < program_->parameter_blocks_.size(); ++i) {
- ParameterBlock* parameter_block = program_->parameter_blocks_[i];
-
- if (options_.local_parameterization_ownership == TAKE_OWNERSHIP) {
- local_parameterizations.push_back(
- parameter_block->local_parameterization_);
- }
-
- delete parameter_block;
+ DeleteBlock(program_->parameter_blocks_[i]);
}
// Delete the owned cost/loss functions and parameterizations.
- STLDeleteUniqueContainerPointers(local_parameterizations.begin(),
- local_parameterizations.end());
- STLDeleteUniqueContainerPointers(cost_functions.begin(),
- cost_functions.end());
- STLDeleteUniqueContainerPointers(loss_functions.begin(),
- loss_functions.end());
+ STLDeleteUniqueContainerPointers(local_parameterizations_to_delete_.begin(),
+ local_parameterizations_to_delete_.end());
+ STLDeleteUniqueContainerPointers(cost_functions_to_delete_.begin(),
+ cost_functions_to_delete_.end());
+ STLDeleteUniqueContainerPointers(loss_functions_to_delete_.begin(),
+ loss_functions_to_delete_.end());
}
-const ResidualBlock* ProblemImpl::AddResidualBlock(
+ResidualBlock* ProblemImpl::AddResidualBlock(
CostFunction* cost_function,
LossFunction* loss_function,
const vector<double*>& parameter_blocks) {
@@ -238,14 +257,23 @@
ResidualBlock* new_residual_block =
new ResidualBlock(cost_function,
loss_function,
- parameter_block_ptrs);
+ parameter_block_ptrs,
+ program_->residual_blocks_.size());
+
+ // Add dependencies on the residual to the parameter blocks.
+ if (options_.enable_fast_parameter_block_removal) {
+ for (int i = 0; i < parameter_blocks.size(); ++i) {
+ parameter_block_ptrs[i]->AddResidualBlock(new_residual_block);
+ }
+ }
+
program_->residual_blocks_.push_back(new_residual_block);
return new_residual_block;
}
// Unfortunately, macros don't help much to reduce this code, and var args don't
// work because of the ambiguous case that there is no loss function.
-const ResidualBlock* ProblemImpl::AddResidualBlock(
+ResidualBlock* ProblemImpl::AddResidualBlock(
CostFunction* cost_function,
LossFunction* loss_function,
double* x0) {
@@ -254,7 +282,7 @@
return AddResidualBlock(cost_function, loss_function, residual_parameters);
}
-const ResidualBlock* ProblemImpl::AddResidualBlock(
+ResidualBlock* ProblemImpl::AddResidualBlock(
CostFunction* cost_function,
LossFunction* loss_function,
double* x0, double* x1) {
@@ -264,7 +292,7 @@
return AddResidualBlock(cost_function, loss_function, residual_parameters);
}
-const ResidualBlock* ProblemImpl::AddResidualBlock(
+ResidualBlock* ProblemImpl::AddResidualBlock(
CostFunction* cost_function,
LossFunction* loss_function,
double* x0, double* x1, double* x2) {
@@ -275,7 +303,7 @@
return AddResidualBlock(cost_function, loss_function, residual_parameters);
}
-const ResidualBlock* ProblemImpl::AddResidualBlock(
+ResidualBlock* ProblemImpl::AddResidualBlock(
CostFunction* cost_function,
LossFunction* loss_function,
double* x0, double* x1, double* x2, double* x3) {
@@ -287,7 +315,7 @@
return AddResidualBlock(cost_function, loss_function, residual_parameters);
}
-const ResidualBlock* ProblemImpl::AddResidualBlock(
+ResidualBlock* ProblemImpl::AddResidualBlock(
CostFunction* cost_function,
LossFunction* loss_function,
double* x0, double* x1, double* x2, double* x3, double* x4) {
@@ -300,7 +328,7 @@
return AddResidualBlock(cost_function, loss_function, residual_parameters);
}
-const ResidualBlock* ProblemImpl::AddResidualBlock(
+ResidualBlock* ProblemImpl::AddResidualBlock(
CostFunction* cost_function,
LossFunction* loss_function,
double* x0, double* x1, double* x2, double* x3, double* x4, double* x5) {
@@ -314,7 +342,7 @@
return AddResidualBlock(cost_function, loss_function, residual_parameters);
}
-const ResidualBlock* ProblemImpl::AddResidualBlock(
+ResidualBlock* ProblemImpl::AddResidualBlock(
CostFunction* cost_function,
LossFunction* loss_function,
double* x0, double* x1, double* x2, double* x3, double* x4, double* x5,
@@ -330,7 +358,7 @@
return AddResidualBlock(cost_function, loss_function, residual_parameters);
}
-const ResidualBlock* ProblemImpl::AddResidualBlock(
+ResidualBlock* ProblemImpl::AddResidualBlock(
CostFunction* cost_function,
LossFunction* loss_function,
double* x0, double* x1, double* x2, double* x3, double* x4, double* x5,
@@ -347,7 +375,7 @@
return AddResidualBlock(cost_function, loss_function, residual_parameters);
}
-const ResidualBlock* ProblemImpl::AddResidualBlock(
+ResidualBlock* ProblemImpl::AddResidualBlock(
CostFunction* cost_function,
LossFunction* loss_function,
double* x0, double* x1, double* x2, double* x3, double* x4, double* x5,
@@ -365,7 +393,7 @@
return AddResidualBlock(cost_function, loss_function, residual_parameters);
}
-const ResidualBlock* ProblemImpl::AddResidualBlock(
+ResidualBlock* ProblemImpl::AddResidualBlock(
CostFunction* cost_function,
LossFunction* loss_function,
double* x0, double* x1, double* x2, double* x3, double* x4, double* x5,
@@ -399,6 +427,77 @@
}
}
+// Delete a block from a vector of blocks, maintaining the indexing invariant.
+// This is done in constant time by moving an element from the end of the
+// vector over the element to remove, then popping the last element. It
+// destroys the ordering in the interest of speed.
+template<typename Block>
+void ProblemImpl::DeleteBlockInVector(vector<Block*>* mutable_blocks,
+ Block* block_to_remove) {
+ CHECK_EQ((*mutable_blocks)[block_to_remove->index()], block_to_remove)
+ << "You found a Ceres bug! Block: " << block_to_remove->ToString();
+
+ // Prepare the to-be-moved block for the new, lower-in-index position by
+ // setting the index to the block's final location.
+ Block* tmp = mutable_blocks->back();
+ tmp->set_index(block_to_remove->index());
+
+ // Overwrite the to-be-deleted block with the one at the end.
+ (*mutable_blocks)[block_to_remove->index()] = tmp;
+
+ DeleteBlock(block_to_remove);
+
+ // The block is gone so shrink the vector of blocks accordingly.
+ mutable_blocks->pop_back();
+}
+
+void ProblemImpl::RemoveResidualBlock(ResidualBlock* residual_block) {
+ CHECK_NOTNULL(residual_block);
+
+ // If needed, remove the parameter dependencies on this residual block.
+ if (options_.enable_fast_parameter_block_removal) {
+ const int num_parameter_blocks_for_residual =
+ residual_block->NumParameterBlocks();
+ for (int i = 0; i < num_parameter_blocks_for_residual; ++i) {
+ residual_block->parameter_blocks()[i]
+ ->RemoveResidualBlock(residual_block);
+ }
+ }
+ DeleteBlockInVector(program_->mutable_residual_blocks(), residual_block);
+}
+
+void ProblemImpl::RemoveParameterBlock(double* values) {
+ ParameterBlock* parameter_block = FindOrDie(parameter_block_map_, values);
+
+ if (options_.enable_fast_parameter_block_removal) {
+ // Copy the dependent residuals from the parameter block because the set of
+ // dependents will change after each call to RemoveResidualBlock().
+ vector<ResidualBlock*> residual_blocks_to_remove(
+ parameter_block->mutable_residual_blocks()->begin(),
+ parameter_block->mutable_residual_blocks()->end());
+ for (int i = 0; i < residual_blocks_to_remove.size(); ++i) {
+ RemoveResidualBlock(residual_blocks_to_remove[i]);
+ }
+ } else {
+ // Scan all the residual blocks to remove ones that depend on the parameter
+ // block. Do the scan backwards since the vector changes while iterating.
+ const int num_residual_blocks = NumResidualBlocks();
+ for (int i = num_residual_blocks - 1; i >= 0; --i) {
+ ResidualBlock* residual_block =
+ (*(program_->mutable_residual_blocks()))[i];
+ const int num_parameter_blocks = residual_block->NumParameterBlocks();
+ for (int i = 0; i < num_parameter_blocks; ++i) {
+ if (residual_block->parameter_blocks()[i] == parameter_block) {
+ RemoveResidualBlock(residual_block);
+ // The parameter blocks are guaranteed unique.
+ break;
+ }
+ }
+ }
+ }
+ DeleteBlockInVector(program_->mutable_parameter_blocks(), parameter_block);
+}
+
void ProblemImpl::SetParameterBlockConstant(double* values) {
FindOrDie(parameter_block_map_, values)->SetConstant();
}
diff --git a/internal/ceres/problem_impl.h b/internal/ceres/problem_impl.h
index 82a1956..536e73a 100644
--- a/internal/ceres/problem_impl.h
+++ b/internal/ceres/problem_impl.h
@@ -118,6 +118,10 @@
void AddParameterBlock(double* values,
int size,
LocalParameterization* local_parameterization);
+
+ void RemoveResidualBlock(ResidualBlock* residual_block);
+ void RemoveParameterBlock(double* values);
+
void SetParameterBlockConstant(double* values);
void SetParameterBlockVariable(double* values);
void SetParameterization(double* values,
@@ -135,12 +139,33 @@
private:
ParameterBlock* InternalAddParameterBlock(double* values, int size);
+ // Delete the arguments in question. These differ from the Remove* functions
+ // in that they do not clean up references to the block to delete; they
+ // merely delete them.
+ template<typename Block>
+ void DeleteBlockInVector(vector<Block*>* mutable_blocks,
+ Block* block_to_remove);
+ void DeleteBlock(ResidualBlock* residual_block);
+ void DeleteBlock(ParameterBlock* parameter_block);
+
const Problem::Options options_;
// The mapping from user pointers to parameter blocks.
map<double*, ParameterBlock*> parameter_block_map_;
+ // The actual parameter and residual blocks.
internal::scoped_ptr<internal::Program> program_;
+
+ // When removing residual and parameter blocks, cost/loss functions and
+ // parameterizations have ambiguous ownership. Instead of scanning the entire
+ // problem to see if the cost/loss/parameterization is shared with other
+ // residual or parameter blocks, buffer them until destruction.
+ //
+ // TODO(keir): See if it makes sense to use sets instead.
+ vector<CostFunction*> cost_functions_to_delete_;
+ vector<LossFunction*> loss_functions_to_delete_;
+ vector<LocalParameterization*> local_parameterizations_to_delete_;
+
CERES_DISALLOW_COPY_AND_ASSIGN(ProblemImpl);
};
diff --git a/internal/ceres/problem_test.cc b/internal/ceres/problem_test.cc
index 4afe1b5..55f355b 100644
--- a/internal/ceres/problem_test.cc
+++ b/internal/ceres/problem_test.cc
@@ -30,10 +30,14 @@
// keir@google.com (Keir Mierle)
#include "ceres/problem.h"
+#include "ceres/problem_impl.h"
#include "gtest/gtest.h"
#include "ceres/cost_function.h"
#include "ceres/local_parameterization.h"
+#include "ceres/map_util.h"
+#include "ceres/parameter_block.h"
+#include "ceres/program.h"
#include "ceres/sized_cost_function.h"
#include "ceres/internal/scoped_ptr.h"
@@ -293,11 +297,11 @@
class DestructorCountingCostFunction : public SizedCostFunction<3, 4, 5> {
public:
- explicit DestructorCountingCostFunction(int *counter)
- : counter_(counter) {}
+ explicit DestructorCountingCostFunction(int *num_destructions)
+ : num_destructions_(num_destructions) {}
virtual ~DestructorCountingCostFunction() {
- *counter_ += 1;
+ *num_destructions_ += 1;
}
virtual bool Evaluate(double const* const* parameters,
@@ -307,12 +311,12 @@
}
private:
- int* counter_;
+ int* num_destructions_;
};
TEST(Problem, ReusedCostFunctionsAreOnlyDeletedOnce) {
double y[4], z[5];
- int counter = 0;
+ int num_destructions = 0;
// Add a cost function multiple times and check to make sure that
// the destructor on the cost function is only called once.
@@ -321,15 +325,375 @@
problem.AddParameterBlock(y, 4);
problem.AddParameterBlock(z, 5);
- CostFunction* cost = new DestructorCountingCostFunction(&counter);
+ CostFunction* cost = new DestructorCountingCostFunction(&num_destructions);
problem.AddResidualBlock(cost, NULL, y, z);
problem.AddResidualBlock(cost, NULL, y, z);
problem.AddResidualBlock(cost, NULL, y, z);
+ EXPECT_EQ(3, problem.NumResidualBlocks());
}
// Check that the destructor was called only once.
- CHECK_EQ(counter, 1);
+ CHECK_EQ(num_destructions, 1);
}
+TEST(Problem, CostFunctionsAreDeletedEvenWithRemovals) {
+ double y[4], z[5], w[4];
+ int num_destructions = 0;
+ {
+ Problem problem;
+ problem.AddParameterBlock(y, 4);
+ problem.AddParameterBlock(z, 5);
+
+ CostFunction* cost_yz =
+ new DestructorCountingCostFunction(&num_destructions);
+ CostFunction* cost_wz =
+ new DestructorCountingCostFunction(&num_destructions);
+ ResidualBlock* r_yz = problem.AddResidualBlock(cost_yz, NULL, y, z);
+ ResidualBlock* r_wz = problem.AddResidualBlock(cost_wz, NULL, w, z);
+ EXPECT_EQ(2, problem.NumResidualBlocks());
+
+ // In the current implementation, the destructor shouldn't get run yet.
+ problem.RemoveResidualBlock(r_yz);
+ CHECK_EQ(num_destructions, 0);
+ problem.RemoveResidualBlock(r_wz);
+ CHECK_EQ(num_destructions, 0);
+
+ EXPECT_EQ(0, problem.NumResidualBlocks());
+ }
+ CHECK_EQ(num_destructions, 2);
+}
+
+// Make the dynamic problem tests (e.g. for removing residual blocks)
+// parameterized on whether the low-latency mode is enabled or not.
+//
+// This tests against ProblemImpl instead of Problem in order to inspect the
+// state of the resulting Program; this is difficult with only the thin Problem
+// interface.
+struct DynamicProblem : public ::testing::TestWithParam<bool> {
+ DynamicProblem() {
+ Problem::Options options;
+ options.enable_fast_parameter_block_removal = GetParam();
+ problem.reset(new ProblemImpl(options));
+ }
+
+ ParameterBlock* GetParameterBlock(int block) {
+ return problem->program().parameter_blocks()[block];
+ }
+ ResidualBlock* GetResidualBlock(int block) {
+ return problem->program().residual_blocks()[block];
+ }
+
+ bool HasResidualBlock(ResidualBlock* residual_block) {
+ return find(problem->program().residual_blocks().begin(),
+ problem->program().residual_blocks().end(),
+ residual_block) != problem->program().residual_blocks().end();
+ }
+
+ // The next block of functions until the end are only for testing the
+ // residual block removals.
+ void ExpectParameterBlockContainsResidualBlock(
+ double* values,
+ ResidualBlock* residual_block) {
+ ParameterBlock* parameter_block =
+ FindOrDie(problem->parameter_map(), values);
+ EXPECT_TRUE(ContainsKey(*(parameter_block->mutable_residual_blocks()),
+ residual_block));
+ }
+
+ void ExpectSize(double* values, int size) {
+ ParameterBlock* parameter_block =
+ FindOrDie(problem->parameter_map(), values);
+ EXPECT_EQ(size, parameter_block->mutable_residual_blocks()->size());
+ }
+
+ // Degenerate case.
+ void ExpectParameterBlockContains(double* values) {
+ ExpectSize(values, 0);
+ }
+
+ void ExpectParameterBlockContains(double* values,
+ ResidualBlock* r1) {
+ ExpectSize(values, 1);
+ ExpectParameterBlockContainsResidualBlock(values, r1);
+ }
+
+ void ExpectParameterBlockContains(double* values,
+ ResidualBlock* r1,
+ ResidualBlock* r2) {
+ ExpectSize(values, 2);
+ ExpectParameterBlockContainsResidualBlock(values, r1);
+ ExpectParameterBlockContainsResidualBlock(values, r2);
+ }
+
+ void ExpectParameterBlockContains(double* values,
+ ResidualBlock* r1,
+ ResidualBlock* r2,
+ ResidualBlock* r3) {
+ ExpectSize(values, 3);
+ ExpectParameterBlockContainsResidualBlock(values, r1);
+ ExpectParameterBlockContainsResidualBlock(values, r2);
+ ExpectParameterBlockContainsResidualBlock(values, r3);
+ }
+
+ void ExpectParameterBlockContains(double* values,
+ ResidualBlock* r1,
+ ResidualBlock* r2,
+ ResidualBlock* r3,
+ ResidualBlock* r4) {
+ ExpectSize(values, 4);
+ ExpectParameterBlockContainsResidualBlock(values, r1);
+ ExpectParameterBlockContainsResidualBlock(values, r2);
+ ExpectParameterBlockContainsResidualBlock(values, r3);
+ ExpectParameterBlockContainsResidualBlock(values, r4);
+ }
+
+ scoped_ptr<ProblemImpl> problem;
+ double y[4], z[5], w[3];
+};
+
+TEST_P(DynamicProblem, RemoveParameterBlockWithNoResiduals) {
+ problem->AddParameterBlock(y, 4);
+ problem->AddParameterBlock(z, 5);
+ problem->AddParameterBlock(w, 3);
+ ASSERT_EQ(3, problem->NumParameterBlocks());
+ ASSERT_EQ(0, problem->NumResidualBlocks());
+ EXPECT_EQ(y, GetParameterBlock(0)->user_state());
+ EXPECT_EQ(z, GetParameterBlock(1)->user_state());
+ EXPECT_EQ(w, GetParameterBlock(2)->user_state());
+
+ // w is at the end, which might break the swapping logic so try adding and
+ // removing it.
+ problem->RemoveParameterBlock(w);
+ ASSERT_EQ(2, problem->NumParameterBlocks());
+ ASSERT_EQ(0, problem->NumResidualBlocks());
+ EXPECT_EQ(y, GetParameterBlock(0)->user_state());
+ EXPECT_EQ(z, GetParameterBlock(1)->user_state());
+ problem->AddParameterBlock(w, 3);
+ ASSERT_EQ(3, problem->NumParameterBlocks());
+ ASSERT_EQ(0, problem->NumResidualBlocks());
+ EXPECT_EQ(y, GetParameterBlock(0)->user_state());
+ EXPECT_EQ(z, GetParameterBlock(1)->user_state());
+ EXPECT_EQ(w, GetParameterBlock(2)->user_state());
+
+ // Now remove z, which is in the middle, and add it back.
+ problem->RemoveParameterBlock(z);
+ ASSERT_EQ(2, problem->NumParameterBlocks());
+ ASSERT_EQ(0, problem->NumResidualBlocks());
+ EXPECT_EQ(y, GetParameterBlock(0)->user_state());
+ EXPECT_EQ(w, GetParameterBlock(1)->user_state());
+ problem->AddParameterBlock(z, 5);
+ ASSERT_EQ(3, problem->NumParameterBlocks());
+ ASSERT_EQ(0, problem->NumResidualBlocks());
+ EXPECT_EQ(y, GetParameterBlock(0)->user_state());
+ EXPECT_EQ(w, GetParameterBlock(1)->user_state());
+ EXPECT_EQ(z, GetParameterBlock(2)->user_state());
+
+ // Now remove everything.
+ // y
+ problem->RemoveParameterBlock(y);
+ ASSERT_EQ(2, problem->NumParameterBlocks());
+ ASSERT_EQ(0, problem->NumResidualBlocks());
+ EXPECT_EQ(z, GetParameterBlock(0)->user_state());
+ EXPECT_EQ(w, GetParameterBlock(1)->user_state());
+
+ // z
+ problem->RemoveParameterBlock(z);
+ ASSERT_EQ(1, problem->NumParameterBlocks());
+ ASSERT_EQ(0, problem->NumResidualBlocks());
+ EXPECT_EQ(w, GetParameterBlock(0)->user_state());
+
+ // w
+ problem->RemoveParameterBlock(w);
+ EXPECT_EQ(0, problem->NumParameterBlocks());
+ EXPECT_EQ(0, problem->NumResidualBlocks());
+}
+
+TEST_P(DynamicProblem, RemoveParameterBlockWithResiduals) {
+ problem->AddParameterBlock(y, 4);
+ problem->AddParameterBlock(z, 5);
+ problem->AddParameterBlock(w, 3);
+ ASSERT_EQ(3, problem->NumParameterBlocks());
+ ASSERT_EQ(0, problem->NumResidualBlocks());
+ EXPECT_EQ(y, GetParameterBlock(0)->user_state());
+ EXPECT_EQ(z, GetParameterBlock(1)->user_state());
+ EXPECT_EQ(w, GetParameterBlock(2)->user_state());
+
+ // Add all combinations of cost functions.
+ CostFunction* cost_yzw = new TernaryCostFunction(1, 4, 5, 3);
+ CostFunction* cost_yz = new BinaryCostFunction (1, 4, 5);
+ CostFunction* cost_yw = new BinaryCostFunction (1, 4, 3);
+ CostFunction* cost_zw = new BinaryCostFunction (1, 5, 3);
+ CostFunction* cost_y = new UnaryCostFunction (1, 4);
+ CostFunction* cost_z = new UnaryCostFunction (1, 5);
+ CostFunction* cost_w = new UnaryCostFunction (1, 3);
+
+ ResidualBlock* r_yzw = problem->AddResidualBlock(cost_yzw, NULL, y, z, w);
+ ResidualBlock* r_yz = problem->AddResidualBlock(cost_yz, NULL, y, z);
+ ResidualBlock* r_yw = problem->AddResidualBlock(cost_yw, NULL, y, w);
+ ResidualBlock* r_zw = problem->AddResidualBlock(cost_zw, NULL, z, w);
+ ResidualBlock* r_y = problem->AddResidualBlock(cost_y, NULL, y);
+ ResidualBlock* r_z = problem->AddResidualBlock(cost_z, NULL, z);
+ ResidualBlock* r_w = problem->AddResidualBlock(cost_w, NULL, w);
+
+ EXPECT_EQ(3, problem->NumParameterBlocks());
+ EXPECT_EQ(7, problem->NumResidualBlocks());
+
+ // Remove w, which should remove r_yzw, r_yw, r_zw, r_w.
+ problem->RemoveParameterBlock(w);
+ ASSERT_EQ(2, problem->NumParameterBlocks());
+ ASSERT_EQ(3, problem->NumResidualBlocks());
+
+ ASSERT_FALSE(HasResidualBlock(r_yzw));
+ ASSERT_TRUE (HasResidualBlock(r_yz ));
+ ASSERT_FALSE(HasResidualBlock(r_yw ));
+ ASSERT_FALSE(HasResidualBlock(r_zw ));
+ ASSERT_TRUE (HasResidualBlock(r_y ));
+ ASSERT_TRUE (HasResidualBlock(r_z ));
+ ASSERT_FALSE(HasResidualBlock(r_w ));
+
+ // Remove z, which will remove almost everything else.
+ problem->RemoveParameterBlock(z);
+ ASSERT_EQ(1, problem->NumParameterBlocks());
+ ASSERT_EQ(1, problem->NumResidualBlocks());
+
+ ASSERT_FALSE(HasResidualBlock(r_yzw));
+ ASSERT_FALSE(HasResidualBlock(r_yz ));
+ ASSERT_FALSE(HasResidualBlock(r_yw ));
+ ASSERT_FALSE(HasResidualBlock(r_zw ));
+ ASSERT_TRUE (HasResidualBlock(r_y ));
+ ASSERT_FALSE(HasResidualBlock(r_z ));
+ ASSERT_FALSE(HasResidualBlock(r_w ));
+
+ // Remove y; all gone.
+ problem->RemoveParameterBlock(y);
+ EXPECT_EQ(0, problem->NumParameterBlocks());
+ EXPECT_EQ(0, problem->NumResidualBlocks());
+}
+
+TEST_P(DynamicProblem, RemoveResidualBlock) {
+ problem->AddParameterBlock(y, 4);
+ problem->AddParameterBlock(z, 5);
+ problem->AddParameterBlock(w, 3);
+
+ // Add one cost function for every non-empty subset of {y, z, w}.
+ CostFunction* cost_yzw = new TernaryCostFunction(1, 4, 5, 3);
+ CostFunction* cost_yz = new BinaryCostFunction (1, 4, 5);
+ CostFunction* cost_yw = new BinaryCostFunction (1, 4, 3);
+ CostFunction* cost_zw = new BinaryCostFunction (1, 5, 3);
+ CostFunction* cost_y = new UnaryCostFunction (1, 4);
+ CostFunction* cost_z = new UnaryCostFunction (1, 5);
+ CostFunction* cost_w = new UnaryCostFunction (1, 3);
+
+ ResidualBlock* r_yzw = problem->AddResidualBlock(cost_yzw, NULL, y, z, w);
+ ResidualBlock* r_yz = problem->AddResidualBlock(cost_yz, NULL, y, z);
+ ResidualBlock* r_yw = problem->AddResidualBlock(cost_yw, NULL, y, w);
+ ResidualBlock* r_zw = problem->AddResidualBlock(cost_zw, NULL, z, w);
+ ResidualBlock* r_y = problem->AddResidualBlock(cost_y, NULL, y);
+ ResidualBlock* r_z = problem->AddResidualBlock(cost_z, NULL, z);
+ ResidualBlock* r_w = problem->AddResidualBlock(cost_w, NULL, w);
+
+ if (GetParam()) {
+ // When the test parameter is true, each parameter block should hold
+ // back-pointers to the residual blocks that depend on it.
+ ExpectParameterBlockContains(y, r_yzw, r_yz, r_yw, r_y);
+ ExpectParameterBlockContains(z, r_yzw, r_yz, r_zw, r_z);
+ ExpectParameterBlockContains(w, r_yzw, r_yw, r_zw, r_w);
+ } else {
+ // Otherwise, mutable_residual_blocks() should be NULL for every parameter block.
+ EXPECT_TRUE(GetParameterBlock(0)->mutable_residual_blocks() == NULL);
+ EXPECT_TRUE(GetParameterBlock(1)->mutable_residual_blocks() == NULL);
+ EXPECT_TRUE(GetParameterBlock(2)->mutable_residual_blocks() == NULL);
+ }
+ EXPECT_EQ(3, problem->NumParameterBlocks());
+ EXPECT_EQ(7, problem->NumResidualBlocks());
+
+ // Remove the residual blocks one step at a time; after each step, check the
+
+ // Remove r_yzw.
+ problem->RemoveResidualBlock(r_yzw);
+ ASSERT_EQ(3, problem->NumParameterBlocks());
+ ASSERT_EQ(6, problem->NumResidualBlocks());
+ if (GetParam()) {
+ ExpectParameterBlockContains(y, r_yz, r_yw, r_y);
+ ExpectParameterBlockContains(z, r_yz, r_zw, r_z);
+ ExpectParameterBlockContains(w, r_yw, r_zw, r_w);
+ }
+ ASSERT_TRUE (HasResidualBlock(r_yz ));
+ ASSERT_TRUE (HasResidualBlock(r_yw ));
+ ASSERT_TRUE (HasResidualBlock(r_zw ));
+ ASSERT_TRUE (HasResidualBlock(r_y ));
+ ASSERT_TRUE (HasResidualBlock(r_z ));
+ ASSERT_TRUE (HasResidualBlock(r_w ));
+
+ // Remove r_yw.
+ problem->RemoveResidualBlock(r_yw);
+ ASSERT_EQ(3, problem->NumParameterBlocks());
+ ASSERT_EQ(5, problem->NumResidualBlocks());
+ if (GetParam()) {
+ ExpectParameterBlockContains(y, r_yz, r_y);
+ ExpectParameterBlockContains(z, r_yz, r_zw, r_z);
+ ExpectParameterBlockContains(w, r_zw, r_w);
+ }
+ ASSERT_TRUE (HasResidualBlock(r_yz ));
+ ASSERT_TRUE (HasResidualBlock(r_zw ));
+ ASSERT_TRUE (HasResidualBlock(r_y ));
+ ASSERT_TRUE (HasResidualBlock(r_z ));
+ ASSERT_TRUE (HasResidualBlock(r_w ));
+
+ // Remove r_zw.
+ problem->RemoveResidualBlock(r_zw);
+ ASSERT_EQ(3, problem->NumParameterBlocks());
+ ASSERT_EQ(4, problem->NumResidualBlocks());
+ if (GetParam()) {
+ ExpectParameterBlockContains(y, r_yz, r_y);
+ ExpectParameterBlockContains(z, r_yz, r_z);
+ ExpectParameterBlockContains(w, r_w);
+ }
+ ASSERT_TRUE (HasResidualBlock(r_yz ));
+ ASSERT_TRUE (HasResidualBlock(r_y ));
+ ASSERT_TRUE (HasResidualBlock(r_z ));
+ ASSERT_TRUE (HasResidualBlock(r_w ));
+
+ // Remove r_w.
+ problem->RemoveResidualBlock(r_w);
+ ASSERT_EQ(3, problem->NumParameterBlocks());
+ ASSERT_EQ(3, problem->NumResidualBlocks());
+ if (GetParam()) {
+ ExpectParameterBlockContains(y, r_yz, r_y);
+ ExpectParameterBlockContains(z, r_yz, r_z);
+ ExpectParameterBlockContains(w);
+ }
+ ASSERT_TRUE (HasResidualBlock(r_yz ));
+ ASSERT_TRUE (HasResidualBlock(r_y ));
+ ASSERT_TRUE (HasResidualBlock(r_z ));
+
+ // Remove r_yz.
+ problem->RemoveResidualBlock(r_yz);
+ ASSERT_EQ(3, problem->NumParameterBlocks());
+ ASSERT_EQ(2, problem->NumResidualBlocks());
+ if (GetParam()) {
+ ExpectParameterBlockContains(y, r_y);
+ ExpectParameterBlockContains(z, r_z);
+ ExpectParameterBlockContains(w);
+ }
+ ASSERT_TRUE (HasResidualBlock(r_y ));
+ ASSERT_TRUE (HasResidualBlock(r_z ));
+
+ // Remove the last two, r_z and r_y.
+ problem->RemoveResidualBlock(r_z);
+ problem->RemoveResidualBlock(r_y);
+ ASSERT_EQ(3, problem->NumParameterBlocks());
+ ASSERT_EQ(0, problem->NumResidualBlocks());
+ if (GetParam()) {
+ ExpectParameterBlockContains(y);
+ ExpectParameterBlockContains(z);
+ ExpectParameterBlockContains(w);
+ }
+}
+
+INSTANTIATE_TEST_CASE_P(OptionsInstantiation,
+ DynamicProblem,
+ ::testing::Values(true, false));  // Run each DynamicProblem test with both values of its bool parameter.
+
} // namespace internal
} // namespace ceres
diff --git a/internal/ceres/residual_block.cc b/internal/ceres/residual_block.cc
index bdb88b1..7f78960 100644
--- a/internal/ceres/residual_block.cc
+++ b/internal/ceres/residual_block.cc
@@ -49,12 +49,14 @@
ResidualBlock::ResidualBlock(const CostFunction* cost_function,
const LossFunction* loss_function,
- const vector<ParameterBlock*>& parameter_blocks)
+ const vector<ParameterBlock*>& parameter_blocks,
+ int index)
: cost_function_(cost_function),
loss_function_(loss_function),
parameter_blocks_(
new ParameterBlock* [
- cost_function->parameter_block_sizes().size()]) {
+ cost_function->parameter_block_sizes().size()]),
+ index_(index) {
std::copy(parameter_blocks.begin(),
parameter_blocks.end(),
parameter_blocks_.get());
diff --git a/internal/ceres/residual_block.h b/internal/ceres/residual_block.h
index e0a06e7..3921d1d 100644
--- a/internal/ceres/residual_block.h
+++ b/internal/ceres/residual_block.h
@@ -34,11 +34,13 @@
#ifndef CERES_INTERNAL_RESIDUAL_BLOCK_H_
#define CERES_INTERNAL_RESIDUAL_BLOCK_H_
+#include <string>
#include <vector>
#include "ceres/cost_function.h"
#include "ceres/internal/port.h"
#include "ceres/internal/scoped_ptr.h"
+#include "ceres/stringprintf.h"
#include "ceres/types.h"
namespace ceres {
@@ -64,9 +66,13 @@
// loss functions, and parameter blocks.
class ResidualBlock {
public:
+ // Construct the residual block with the given cost/loss functions and
+ // parameter blocks. The loss function may be null. The index is the index of
+ // the residual block in the Program's residual_blocks array.
ResidualBlock(const CostFunction* cost_function,
const LossFunction* loss_function,
- const vector<ParameterBlock*>& parameter_blocks);
+ const vector<ParameterBlock*>& parameter_blocks,
+ int index);
// Evaluates the residual term, storing the scalar cost in *cost, the residual
// components in *residuals, and the jacobians between the parameters and
@@ -112,10 +118,23 @@
// The minimum amount of scratch space needed to pass to Evaluate().
int NumScratchDoublesForEvaluate() const;
+ // Accessors for this residual block's index in its containing array.
+ int index() const { return index_; }
+ void set_index(int index) { index_ = index; }
+
+ string ToString() const {  // Describes this block by its index; const because it only reads index_.
+ return StringPrintf("{residual block; index=%d}", index_);
+ }
+
private:
const CostFunction* cost_function_;
const LossFunction* loss_function_;
scoped_array<ParameterBlock*> parameter_blocks_;
+
+ // The index of the residual, typically in a Program. This is only to permit
+ // switching from a ResidualBlock* to an index in the Program's array, needed
+ // to do efficient removals.
+ int32 index_;
};
} // namespace internal
diff --git a/internal/ceres/residual_block_test.cc b/internal/ceres/residual_block_test.cc
index 92b79f6..fddd44e 100644
--- a/internal/ceres/residual_block_test.cc
+++ b/internal/ceres/residual_block_test.cc
@@ -77,13 +77,13 @@
// Prepare the parameter blocks.
double values_x[2];
- ParameterBlock x(values_x, 2);
+ ParameterBlock x(values_x, 2, -1);
double values_y[3];
- ParameterBlock y(values_y, 3);
+ ParameterBlock y(values_y, 3, -1);
double values_z[4];
- ParameterBlock z(values_z, 4);
+ ParameterBlock z(values_z, 4, -1);
vector<ParameterBlock*> parameters;
parameters.push_back(&x);
@@ -93,7 +93,7 @@
TernaryCostFunction cost_function(3, 2, 3, 4);
// Create the object under tests.
- ResidualBlock residual_block(&cost_function, NULL, parameters);
+ ResidualBlock residual_block(&cost_function, NULL, parameters, -1);
// Verify getters.
EXPECT_EQ(&cost_function, residual_block.cost_function());
@@ -204,13 +204,13 @@
// Prepare the parameter blocks.
double values_x[2];
- ParameterBlock x(values_x, 2);
+ ParameterBlock x(values_x, 2, -1);
double values_y[3];
- ParameterBlock y(values_y, 3);
+ ParameterBlock y(values_y, 3, -1);
double values_z[4];
- ParameterBlock z(values_z, 4);
+ ParameterBlock z(values_z, 4, -1);
vector<ParameterBlock*> parameters;
parameters.push_back(&x);
@@ -232,7 +232,7 @@
LocallyParameterizedCostFunction cost_function;
// Create the object under tests.
- ResidualBlock residual_block(&cost_function, NULL, parameters);
+ ResidualBlock residual_block(&cost_function, NULL, parameters, -1);
// Verify getters.
EXPECT_EQ(&cost_function, residual_block.cost_function());
diff --git a/internal/ceres/residual_block_utils_test.cc b/internal/ceres/residual_block_utils_test.cc
index db9ad6d..24723b3 100644
--- a/internal/ceres/residual_block_utils_test.cc
+++ b/internal/ceres/residual_block_utils_test.cc
@@ -45,13 +45,14 @@
// with one residual succeeds with true or dies.
void CheckEvaluation(const CostFunction& cost_function, bool is_good) {
double x = 1.0;
- ParameterBlock parameter_block(&x, 1);
+ ParameterBlock parameter_block(&x, 1, -1);
vector<ParameterBlock*> parameter_blocks;
parameter_blocks.push_back(&parameter_block);
ResidualBlock residual_block(&cost_function,
NULL,
- parameter_blocks);
+ parameter_blocks,
+ -1);
scoped_array<double> scratch(
new double[residual_block.NumScratchDoublesForEvaluate()]);