Compute summary->fixed_cost while removing fixed residual blocks
Change-Id: Ib598316ad21606f1a06db48c341a5f1d69915b2b
diff --git a/internal/ceres/solver_impl.cc b/internal/ceres/solver_impl.cc
index 28627eb..493a890 100644
--- a/internal/ceres/solver_impl.cc
+++ b/internal/ceres/solver_impl.cc
@@ -236,7 +236,8 @@
// evaluator, and the linear solver.
scoped_ptr<Program> reduced_program(
- CreateReducedProgram(&options, problem_impl, &summary->error));
+ CreateReducedProgram(&options, problem_impl, &summary->fixed_cost,
+ &summary->error));
if (reduced_program == NULL) {
return;
}
@@ -321,11 +321,19 @@
// num_eliminate_blocks.
bool SolverImpl::RemoveFixedBlocksFromProgram(Program* program,
int* num_eliminate_blocks,
+ double* fixed_cost,
string* error) {
int original_num_eliminate_blocks = *num_eliminate_blocks;
vector<ParameterBlock*>* parameter_blocks =
program->mutable_parameter_blocks();
+ scoped_array<double> residual_block_evaluate_scratch;
+ if (fixed_cost != NULL) {
+ residual_block_evaluate_scratch.reset(
+ new double[program->MaxScratchDoublesNeededForEvaluate()]);
+ *fixed_cost = 0.0;
+ }
+
// Mark all the parameters as unused. Abuse the index member of the parameter
// blocks for the marking.
for (int i = 0; i < parameter_blocks->size(); ++i) {
@@ -355,6 +363,17 @@
if (!all_constant) {
(*residual_blocks)[j++] = (*residual_blocks)[i];
+ } else if (fixed_cost != NULL) {
+ // all_constant holds, so this residual block is dropped from the program;
+ // add its cost to *fixed_cost so the reported total still includes it.
+ double cost = 0.0;
+ if (!residual_block->Evaluate(
+ &cost, NULL, NULL, residual_block_evaluate_scratch.get())) {
+ *error = StringPrintf("Evaluation of the residual %d failed during "
+ "removal of fixed residual blocks.", i);
+ return false;
+ }
+ *fixed_cost += cost;
}
}
residual_blocks->resize(j);
@@ -387,6 +406,7 @@
Program* SolverImpl::CreateReducedProgram(Solver::Options* options,
ProblemImpl* problem_impl,
+ double* fixed_cost,  // Receives the cost of the removed fixed residual blocks.
string* error) {
Program* original_program = problem_impl->mutable_program();
scoped_ptr<Program> transformed_program(new Program(*original_program));
@@ -417,6 +437,7 @@
if (!RemoveFixedBlocksFromProgram(transformed_program.get(),
&num_eliminate_blocks,
+ fixed_cost,
error)) {
return NULL;
}