Validate ParameterBlocks before solve.
Ensure that all parameter blocks have finite values
before the minimizer is called.
Change-Id: I15fd9c487247989626f799496bb8f5ea8728d6f0
diff --git a/internal/ceres/array_utils.cc b/internal/ceres/array_utils.cc
index 673baa4..3eea042 100644
--- a/internal/ceres/array_utils.cc
+++ b/internal/ceres/array_utils.cc
@@ -32,7 +32,10 @@
#include <cmath>
#include <cstddef>
+#include <string>
+
#include "ceres/fpclassify.h"
+#include "ceres/stringprintf.h"
namespace ceres {
namespace internal {
@@ -55,6 +58,20 @@
return true;
}
+int FindInvalidValue(const int size, const double* x) {
+ if (x == NULL) {
+ return size;
+ }
+
+ for (int i = 0; i < size; ++i) {
+ if (!IsFinite(x[i]) || (x[i] == kImpossibleValue)) {
+ return i;
+ }
+ }
+
+ return size;
+};
+
void InvalidateArray(const int size, double* x) {
if (x != NULL) {
for (int i = 0; i < size; ++i) {
@@ -63,5 +80,19 @@
}
}
+void AppendArrayToString(const int size, const double* x, string* result) {
+ for (int i = 0; i < size; ++i) {
+ if (x == NULL) {
+ StringAppendF(result, "Not Computed ");
+ } else {
+ if (x[i] == kImpossibleValue) {
+ StringAppendF(result, "Uninitialized ");
+ } else {
+ StringAppendF(result, "%12g ", x[i]);
+ }
+ }
+ }
+}
+
} // namespace internal
} // namespace ceres
diff --git a/internal/ceres/array_utils.h b/internal/ceres/array_utils.h
index 742f439..34fda6f 100644
--- a/internal/ceres/array_utils.h
+++ b/internal/ceres/array_utils.h
@@ -57,6 +57,14 @@
// equal to the "impossible" value used by InvalidateArray.
bool IsArrayValid(int size, const double* x);
+// If the array contains an invalid value, return the index for it,
+// otherwise return size.
+int FindInvalidValue(const int size, const double* x);
+
+// Utility routine to print an array of doubles to a string. If the
+// array pointer is NULL, it is treated as an array of zeros.
+void AppendArrayToString(const int size, const double* x, string* result);
+
extern const double kImpossibleValue;
} // namespace internal
diff --git a/internal/ceres/array_utils_test.cc b/internal/ceres/array_utils_test.cc
index c19a44a..96e625d 100644
--- a/internal/ceres/array_utils_test.cc
+++ b/internal/ceres/array_utils_test.cc
@@ -54,5 +54,22 @@
EXPECT_FALSE(IsArrayValid(3, x));
}
+TEST(ArrayUtils, FindInvalidIndex) {
+ double x[3];
+ x[0] = 0.0;
+ x[1] = 1.0;
+ x[2] = 2.0;
+ EXPECT_EQ(FindInvalidValue(3, x), 3);
+ x[1] = std::numeric_limits<double>::infinity();
+ EXPECT_EQ(FindInvalidValue(3, x), 1);
+ x[1] = std::numeric_limits<double>::quiet_NaN();
+ EXPECT_EQ(FindInvalidValue(3, x), 1);
+ x[1] = std::numeric_limits<double>::signaling_NaN();
+ EXPECT_EQ(FindInvalidValue(3, x), 1);
+ EXPECT_EQ(FindInvalidValue(1, NULL), 1);
+ InvalidateArray(3, x);
+ EXPECT_EQ(FindInvalidValue(3, x), 0);
+}
+
} // namespace internal
} // namespace ceres
diff --git a/internal/ceres/residual_block_utils.cc b/internal/ceres/residual_block_utils.cc
index 4d88a9f..d2564a7 100644
--- a/internal/ceres/residual_block_utils.cc
+++ b/internal/ceres/residual_block_utils.cc
@@ -61,24 +61,6 @@
}
}
-// Utility routine to print an array of doubles to a string. If the
-// array pointer is NULL, it is treated as an array of zeros.
-namespace {
-void AppendArrayToString(const int size, const double* x, string* result) {
- for (int i = 0; i < size; ++i) {
- if (x == NULL) {
- StringAppendF(result, "Not Computed ");
- } else {
- if (x[i] == kImpossibleValue) {
- StringAppendF(result, "Uninitialized ");
- } else {
- StringAppendF(result, "%12g ", x[i]);
- }
- }
- }
-}
-} // namespace
-
string EvaluationToString(const ResidualBlock& block,
double const* const* parameters,
double* cost,
diff --git a/internal/ceres/solver_impl.cc b/internal/ceres/solver_impl.cc
index 1001a55..0d55c05 100644
--- a/internal/ceres/solver_impl.cc
+++ b/internal/ceres/solver_impl.cc
@@ -224,6 +224,28 @@
summary->num_residuals_reduced = program.NumResiduals();
}
+bool ParameterBlocksAreFinite(const ProblemImpl* problem,
+ string* message) {
+ CHECK_NOTNULL(message);
+ const Program& program = problem->program();
+ const vector<ParameterBlock*>& parameter_blocks = program.parameter_blocks();
+ for (int i = 0; i < parameter_blocks.size(); ++i) {
+ const double* array = parameter_blocks[i]->user_state();
+ const int size = parameter_blocks[i]->Size();
+ const int invalid_index = FindInvalidValue(size, array);
+ if (invalid_index != size) {
+ *message = StringPrintf(
+ "ParameterBlock: %p with size %d has at least one invalid value.\n"
+ "First invalid value is at index: %d.\n"
+ "Parameter block values: ",
+ array, size, invalid_index);
+ AppendArrayToString(size, array, message);
+ return false;
+ }
+ }
+ return true;
+}
+
bool LineSearchOptionsAreValid(const Solver::Options& options,
string* message) {
// Validate values for configuration parameters supplied by user.
@@ -419,7 +441,7 @@
<< " residual blocks, "
<< problem_impl->NumResiduals()
<< " residuals.";
-
+ *CHECK_NOTNULL(summary) = Solver::Summary();
if (options.minimizer_type == TRUST_REGION) {
TrustRegionSolve(options, problem_impl, summary);
} else {
@@ -440,9 +462,6 @@
Program* original_program = original_problem_impl->mutable_program();
ProblemImpl* problem_impl = original_problem_impl;
- // Reset the summary object to its default values.
- *CHECK_NOTNULL(summary) = Solver::Summary();
-
summary->minimizer_type = TRUST_REGION;
SummarizeGivenProgram(*original_program, summary);
@@ -484,6 +503,11 @@
return;
}
+ if (!ParameterBlocksAreFinite(problem_impl, &summary->message)) {
+ LOG(ERROR) << "Terminating: " << summary->message;
+ return;
+ }
+
event_logger.AddEvent("Init");
original_program->SetParameterBlockStatePtrsToUserStatePtrs();
@@ -704,9 +728,6 @@
Program* original_program = original_problem_impl->mutable_program();
ProblemImpl* problem_impl = original_problem_impl;
- // Reset the summary object to its default values.
- *CHECK_NOTNULL(summary) = Solver::Summary();
-
SummarizeGivenProgram(*original_program, summary);
summary->minimizer_type = LINE_SEARCH;
summary->line_search_direction_type =
@@ -746,6 +767,11 @@
summary->num_threads_given = original_options.num_threads;
summary->num_threads_used = options.num_threads;
+ if (!ParameterBlocksAreFinite(problem_impl, &summary->message)) {
+ LOG(ERROR) << "Terminating: " << summary->message;
+ return;
+ }
+
if (original_options.linear_solver_ordering != NULL) {
if (!IsOrderingValid(original_options, problem_impl, &summary->message)) {
LOG(ERROR) << summary->message;
diff --git a/internal/ceres/solver_impl_test.cc b/internal/ceres/solver_impl_test.cc
index d6faaff..1a810ed 100644
--- a/internal/ceres/solver_impl_test.cc
+++ b/internal/ceres/solver_impl_test.cc
@@ -1085,5 +1085,18 @@
EXPECT_EQ(array, expected);
}
+TEST(SolverImpl, ProblemHasNanParameterBlocks) {
+ Problem problem;
+ double x[2];
+ x[0] = 1.0;
+ x[1] = std::numeric_limits<double>::quiet_NaN();
+ problem.AddResidualBlock(new MockCostFunctionBase<1, 2, 0, 0>(), NULL, x);
+ Solver::Options options;
+ Solver::Summary summary;
+ Solve(options, &problem, &summary);
+ EXPECT_EQ(summary.termination_type, FAILURE);
+ EXPECT_NE(summary.message.find("has at least one invalid value"), string::npos);
+}
+
} // namespace internal
} // namespace ceres