Compute summary->fixed_cost while the residual blocks are removed
Change-Id: Ib598316ad21606f1a06db48c341a5f1d69915b2b
diff --git a/internal/ceres/solver_impl_test.cc b/internal/ceres/solver_impl_test.cc
index 36dd959..a6d6aac 100644
--- a/internal/ceres/solver_impl_test.cc
+++ b/internal/ceres/solver_impl_test.cc
@@ -41,6 +41,20 @@
namespace ceres {
namespace internal {
+// A cost function that simply returns its argument: r(x) = x, dr/dx = 1.
+class UnaryIdentityCostFunction : public SizedCostFunction<1, 1> {
+ public:
+  virtual bool Evaluate(double const* const* parameters,
+                        double* residuals,
+                        double** jacobians) const {
+    const double x = parameters[0][0];
+    residuals[0] = x;  // Identity residual.
+    if (jacobians == NULL || jacobians[0] == NULL) return true;
+    jacobians[0][0] = 1.0;  // The derivative of x w.r.t. x is constant.
+    return true;
+  }
+};
+
// Templated base class for the CostFunction signatures.
template <int kNumResiduals, int N0, int N1, int N2>
class MockCostFunctionBase : public
@@ -77,6 +91,7 @@
Program program(*problem.mutable_program());
EXPECT_TRUE(SolverImpl::RemoveFixedBlocksFromProgram(&program,
&num_eliminate_blocks,
+ NULL,
&error));
EXPECT_EQ(program.NumParameterBlocks(), 3);
EXPECT_EQ(program.NumResidualBlocks(), 3);
@@ -90,6 +105,7 @@
Program program(problem.program());
EXPECT_TRUE(SolverImpl::RemoveFixedBlocksFromProgram(&program,
&num_eliminate_blocks,
+ NULL,
&error));
EXPECT_EQ(program.NumParameterBlocks(), 3);
EXPECT_EQ(program.NumResidualBlocks(), 3);
@@ -110,6 +126,7 @@
string error;
EXPECT_TRUE(SolverImpl::RemoveFixedBlocksFromProgram(&program,
&num_eliminate_blocks,
+ NULL,
&error));
EXPECT_EQ(program.NumParameterBlocks(), 0);
EXPECT_EQ(program.NumResidualBlocks(), 0);
@@ -131,6 +148,7 @@
string error;
EXPECT_TRUE(SolverImpl::RemoveFixedBlocksFromProgram(&program,
&num_eliminate_blocks,
+ NULL,
&error));
EXPECT_EQ(program.NumParameterBlocks(), 0);
EXPECT_EQ(program.NumResidualBlocks(), 0);
@@ -155,6 +173,7 @@
string error;
EXPECT_TRUE(SolverImpl::RemoveFixedBlocksFromProgram(&program,
&num_eliminate_blocks,
+ NULL,
&error));
EXPECT_EQ(program.NumParameterBlocks(), 1);
EXPECT_EQ(program.NumResidualBlocks(), 1);
@@ -180,12 +199,47 @@
string error;
EXPECT_TRUE(SolverImpl::RemoveFixedBlocksFromProgram(&program,
&num_eliminate_blocks,
+ NULL,
&error));
EXPECT_EQ(program.NumParameterBlocks(), 2);
EXPECT_EQ(program.NumResidualBlocks(), 2);
EXPECT_EQ(num_eliminate_blocks, 1);
}
+TEST(SolverImpl, RemoveFixedBlocksFixedCost) {
+  ProblemImpl problem;
+  double x = 1.23;
+  double y = 4.56;
+  double z = 7.89;
+
+  problem.AddParameterBlock(&x, 1);
+  problem.AddParameterBlock(&y, 1);
+  problem.AddParameterBlock(&z, 1);
+  problem.AddResidualBlock(new UnaryIdentityCostFunction(), NULL, &x);
+  problem.AddResidualBlock(new TernaryCostFunction(), NULL, &x, &y, &z);
+  problem.AddResidualBlock(new BinaryCostFunction(), NULL, &x, &y);
+  problem.SetParameterBlockConstant(&x);
+
+  int num_eliminate_blocks = 2;
+  double fixed_cost = 0.0;
+  Program program(problem.program());
+  // Block 0 depends only on the constant x, so it is the one to be removed.
+  double expected_fixed_cost = 0.0;
+  ResidualBlock* expected_removed_block = program.residual_blocks()[0];
+  scoped_array<double> scratch(new double[expected_removed_block->NumScratchDoublesForEvaluate()]);
+  ASSERT_TRUE(expected_removed_block->Evaluate(&expected_fixed_cost, NULL, NULL, scratch.get()));
+
+  string error;
+  EXPECT_TRUE(SolverImpl::RemoveFixedBlocksFromProgram(&program,
+                                                       &num_eliminate_blocks,
+                                                       &fixed_cost,
+                                                       &error));
+  EXPECT_EQ(program.NumParameterBlocks(), 2);
+  EXPECT_EQ(program.NumResidualBlocks(), 2);
+  EXPECT_EQ(num_eliminate_blocks, 1);
+  EXPECT_DOUBLE_EQ(fixed_cost, expected_fixed_cost);
+}
+
TEST(SolverImpl, ReorderResidualBlockNonSchurSolver) {
ProblemImpl problem;
double x;
@@ -322,7 +376,7 @@
// marking the index to -1 at the same time. x and y also get indices.
string error;
scoped_ptr<Program> reduced_program(
- SolverImpl::CreateReducedProgram(&options, &problem, &error));
+ SolverImpl::CreateReducedProgram(&options, &problem, NULL, &error));
const vector<ResidualBlock*>& residual_blocks =
problem.program().residual_blocks();