Validate ParameterBlocks before solve. Ensure that all parameter blocks have finite values before the minimizer is called. Change-Id: I15fd9c487247989626f799496bb8f5ea8728d6f0
diff --git a/internal/ceres/array_utils.cc b/internal/ceres/array_utils.cc index 673baa4..3eea042 100644 --- a/internal/ceres/array_utils.cc +++ b/internal/ceres/array_utils.cc
@@ -32,7 +32,10 @@ #include <cmath> #include <cstddef> +#include <string> + #include "ceres/fpclassify.h" +#include "ceres/stringprintf.h" namespace ceres { namespace internal { @@ -55,6 +58,20 @@ return true; } +int FindInvalidValue(const int size, const double* x) { + if (x == NULL) { + return size; + } + + for (int i = 0; i < size; ++i) { + if (!IsFinite(x[i]) || (x[i] == kImpossibleValue)) { + return i; + } + } + + return size; +}; + void InvalidateArray(const int size, double* x) { if (x != NULL) { for (int i = 0; i < size; ++i) { @@ -63,5 +80,19 @@ } } +void AppendArrayToString(const int size, const double* x, string* result) { + for (int i = 0; i < size; ++i) { + if (x == NULL) { + StringAppendF(result, "Not Computed "); + } else { + if (x[i] == kImpossibleValue) { + StringAppendF(result, "Uninitialized "); + } else { + StringAppendF(result, "%12g ", x[i]); + } + } + } +} + } // namespace internal } // namespace ceres
diff --git a/internal/ceres/array_utils.h b/internal/ceres/array_utils.h index 742f439..34fda6f 100644 --- a/internal/ceres/array_utils.h +++ b/internal/ceres/array_utils.h
@@ -57,6 +57,14 @@ // equal to the "impossible" value used by InvalidateArray. bool IsArrayValid(int size, const double* x); +// If the array contains an invalid value, return the index for it, +// otherwise return size. +int FindInvalidValue(const int size, const double* x); + +// Utility routine to print an array of doubles to a string. If the +// array pointer is NULL, it is treated as an array of zeros. +void AppendArrayToString(const int size, const double* x, string* result); + extern const double kImpossibleValue; } // namespace internal
diff --git a/internal/ceres/array_utils_test.cc b/internal/ceres/array_utils_test.cc index c19a44a..96e625d 100644 --- a/internal/ceres/array_utils_test.cc +++ b/internal/ceres/array_utils_test.cc
@@ -54,5 +54,22 @@ EXPECT_FALSE(IsArrayValid(3, x)); } +TEST(ArrayUtils, FindInvalidIndex) { + double x[3]; + x[0] = 0.0; + x[1] = 1.0; + x[2] = 2.0; + EXPECT_EQ(FindInvalidValue(3, x), 3); + x[1] = std::numeric_limits<double>::infinity(); + EXPECT_EQ(FindInvalidValue(3, x), 1); + x[1] = std::numeric_limits<double>::quiet_NaN(); + EXPECT_EQ(FindInvalidValue(3, x), 1); + x[1] = std::numeric_limits<double>::signaling_NaN(); + EXPECT_EQ(FindInvalidValue(3, x), 1); + EXPECT_EQ(FindInvalidValue(1, NULL), 1); + InvalidateArray(3, x); + EXPECT_EQ(FindInvalidValue(3, x), 0); +} + } // namespace internal } // namespace ceres
diff --git a/internal/ceres/residual_block_utils.cc b/internal/ceres/residual_block_utils.cc index 4d88a9f..d2564a7 100644 --- a/internal/ceres/residual_block_utils.cc +++ b/internal/ceres/residual_block_utils.cc
@@ -61,24 +61,6 @@ } } -// Utility routine to print an array of doubles to a string. If the -// array pointer is NULL, it is treated as an array of zeros. -namespace { -void AppendArrayToString(const int size, const double* x, string* result) { - for (int i = 0; i < size; ++i) { - if (x == NULL) { - StringAppendF(result, "Not Computed "); - } else { - if (x[i] == kImpossibleValue) { - StringAppendF(result, "Uninitialized "); - } else { - StringAppendF(result, "%12g ", x[i]); - } - } - } -} -} // namespace - string EvaluationToString(const ResidualBlock& block, double const* const* parameters, double* cost,
diff --git a/internal/ceres/solver_impl.cc b/internal/ceres/solver_impl.cc index 1001a55..0d55c05 100644 --- a/internal/ceres/solver_impl.cc +++ b/internal/ceres/solver_impl.cc
@@ -224,6 +224,28 @@ summary->num_residuals_reduced = program.NumResiduals(); } +bool ParameterBlocksAreFinite(const ProblemImpl* problem, + string* message) { + CHECK_NOTNULL(message); + const Program& program = problem->program(); + const vector<ParameterBlock*>& parameter_blocks = program.parameter_blocks(); + for (int i = 0; i < parameter_blocks.size(); ++i) { + const double* array = parameter_blocks[i]->user_state(); + const int size = parameter_blocks[i]->Size(); + const int invalid_index = FindInvalidValue(size, array); + if (invalid_index != size) { + *message = StringPrintf( + "ParameterBlock: %p with size %d has at least one invalid value.\n" + "First invalid value is at index: %d.\n" + "Parameter block values: ", + array, size, invalid_index); + AppendArrayToString(size, array, message); + return false; + } + } + return true; +} + bool LineSearchOptionsAreValid(const Solver::Options& options, string* message) { // Validate values for configuration parameters supplied by user. @@ -419,7 +441,7 @@ << " residual blocks, " << problem_impl->NumResiduals() << " residuals."; - + *CHECK_NOTNULL(summary) = Solver::Summary(); if (options.minimizer_type == TRUST_REGION) { TrustRegionSolve(options, problem_impl, summary); } else { @@ -440,9 +462,6 @@ Program* original_program = original_problem_impl->mutable_program(); ProblemImpl* problem_impl = original_problem_impl; - // Reset the summary object to its default values. - *CHECK_NOTNULL(summary) = Solver::Summary(); - summary->minimizer_type = TRUST_REGION; SummarizeGivenProgram(*original_program, summary); @@ -484,6 +503,11 @@ return; } + if (!ParameterBlocksAreFinite(problem_impl, &summary->message)) { + LOG(ERROR) << "Terminating: " << summary->message; + return; + } + event_logger.AddEvent("Init"); original_program->SetParameterBlockStatePtrsToUserStatePtrs(); @@ -704,9 +728,6 @@ Program* original_program = original_problem_impl->mutable_program(); ProblemImpl* problem_impl = original_problem_impl; - // Reset the summary object to its default values. - *CHECK_NOTNULL(summary) = Solver::Summary(); - SummarizeGivenProgram(*original_program, summary); summary->minimizer_type = LINE_SEARCH; summary->line_search_direction_type = @@ -746,6 +767,11 @@ summary->num_threads_given = original_options.num_threads; summary->num_threads_used = options.num_threads; + if (!ParameterBlocksAreFinite(problem_impl, &summary->message)) { + LOG(ERROR) << "Terminating: " << summary->message; + return; + } + if (original_options.linear_solver_ordering != NULL) { if (!IsOrderingValid(original_options, problem_impl, &summary->message)) { LOG(ERROR) << summary->message;
diff --git a/internal/ceres/solver_impl_test.cc b/internal/ceres/solver_impl_test.cc index d6faaff..1a810ed 100644 --- a/internal/ceres/solver_impl_test.cc +++ b/internal/ceres/solver_impl_test.cc
@@ -1085,5 +1085,18 @@ EXPECT_EQ(array, expected); } +TEST(SolverImpl, ProblemHasNanParameterBlocks) { + Problem problem; + double x[2]; + x[0] = 1.0; + x[1] = std::numeric_limits<double>::quiet_NaN(); + problem.AddResidualBlock(new MockCostFunctionBase<1, 2, 0, 0>(), NULL, x); + Solver::Options options; + Solver::Summary summary; + Solve(options, &problem, &summary); + EXPECT_EQ(summary.termination_type, FAILURE); + EXPECT_NE(summary.message.find("has at least one invalid value"), string::npos); +} + } // namespace internal } // namespace ceres