Validate ParameterBlocks before solve.

Ensure that all parameter blocks have finite values
before the minimizer is called.

Change-Id: I15fd9c487247989626f799496bb8f5ea8728d6f0
diff --git a/internal/ceres/array_utils.cc b/internal/ceres/array_utils.cc
index 673baa4..3eea042 100644
--- a/internal/ceres/array_utils.cc
+++ b/internal/ceres/array_utils.cc
@@ -32,7 +32,10 @@
 
 #include <cmath>
 #include <cstddef>
+#include <string>
+
 #include "ceres/fpclassify.h"
+#include "ceres/stringprintf.h"
 
 namespace ceres {
 namespace internal {
@@ -55,6 +58,20 @@
   return true;
 }
 
+int FindInvalidValue(const int size, const double* x) {
+  if (x == NULL) {
+    return size;
+  }
+
+  for (int i = 0; i < size; ++i) {
+    if (!IsFinite(x[i]) || (x[i] == kImpossibleValue))  {
+      return i;
+    }
+  }
+
+  return size;
+};
+
 void InvalidateArray(const int size, double* x) {
   if (x != NULL) {
     for (int i = 0; i < size; ++i) {
@@ -63,5 +80,19 @@
   }
 }
 
+void AppendArrayToString(const int size, const double* x, string* result) {
+  for (int i = 0; i < size; ++i) {
+    if (x == NULL) {
+      StringAppendF(result, "Not Computed  ");
+    } else {
+      if (x[i] == kImpossibleValue) {
+        StringAppendF(result, "Uninitialized ");
+      } else {
+        StringAppendF(result, "%12g ", x[i]);
+      }
+    }
+  }
+}
+
 }  // namespace internal
 }  // namespace ceres
diff --git a/internal/ceres/array_utils.h b/internal/ceres/array_utils.h
index 742f439..34fda6f 100644
--- a/internal/ceres/array_utils.h
+++ b/internal/ceres/array_utils.h
@@ -57,6 +57,14 @@
 // equal to the "impossible" value used by InvalidateArray.
 bool IsArrayValid(int size, const double* x);
 
+// If the array contains an invalid value, return the index for it,
+// otherwise return size.
+int FindInvalidValue(const int size, const double* x);
+
+// Utility routine to print an array of doubles to a string. If the
+// array pointer is NULL, it is treated as an array of zeros.
+void AppendArrayToString(const int size, const double* x, string* result);
+
 extern const double kImpossibleValue;
 
 }  // namespace internal
diff --git a/internal/ceres/array_utils_test.cc b/internal/ceres/array_utils_test.cc
index c19a44a..96e625d 100644
--- a/internal/ceres/array_utils_test.cc
+++ b/internal/ceres/array_utils_test.cc
@@ -54,5 +54,22 @@
   EXPECT_FALSE(IsArrayValid(3, x));
 }
 
+TEST(ArrayUtils, FindInvalidIndex) {
+  double x[3];
+  x[0] = 0.0;
+  x[1] = 1.0;
+  x[2] = 2.0;
+  EXPECT_EQ(FindInvalidValue(3, x), 3);
+  x[1] = std::numeric_limits<double>::infinity();
+  EXPECT_EQ(FindInvalidValue(3, x), 1);
+  x[1] = std::numeric_limits<double>::quiet_NaN();
+  EXPECT_EQ(FindInvalidValue(3, x), 1);
+  x[1] = std::numeric_limits<double>::signaling_NaN();
+  EXPECT_EQ(FindInvalidValue(3, x), 1);
+  EXPECT_EQ(FindInvalidValue(1, NULL), 1);
+  InvalidateArray(3, x);
+  EXPECT_EQ(FindInvalidValue(3, x), 0);
+}
+
 }  // namespace internal
 }  // namespace ceres
diff --git a/internal/ceres/residual_block_utils.cc b/internal/ceres/residual_block_utils.cc
index 4d88a9f..d2564a7 100644
--- a/internal/ceres/residual_block_utils.cc
+++ b/internal/ceres/residual_block_utils.cc
@@ -61,24 +61,6 @@
   }
 }
 
-// Utility routine to print an array of doubles to a string. If the
-// array pointer is NULL, it is treated as an array of zeros.
-namespace {
-void AppendArrayToString(const int size, const double* x, string* result) {
-  for (int i = 0; i < size; ++i) {
-    if (x == NULL) {
-      StringAppendF(result, "Not Computed  ");
-    } else {
-      if (x[i] == kImpossibleValue) {
-        StringAppendF(result, "Uninitialized ");
-      } else {
-        StringAppendF(result, "%12g ", x[i]);
-      }
-    }
-  }
-}
-}  // namespace
-
 string EvaluationToString(const ResidualBlock& block,
                           double const* const* parameters,
                           double* cost,
diff --git a/internal/ceres/solver_impl.cc b/internal/ceres/solver_impl.cc
index 1001a55..0d55c05 100644
--- a/internal/ceres/solver_impl.cc
+++ b/internal/ceres/solver_impl.cc
@@ -224,6 +224,28 @@
   summary->num_residuals_reduced = program.NumResiduals();
 }
 
+bool ParameterBlocksAreFinite(const ProblemImpl* problem,
+                              string* message) {
+  CHECK_NOTNULL(message);
+  const Program& program = problem->program();
+  const vector<ParameterBlock*>& parameter_blocks = program.parameter_blocks();
+  for (int i = 0; i < parameter_blocks.size(); ++i) {
+    const double* array = parameter_blocks[i]->user_state();
+    const int size = parameter_blocks[i]->Size();
+    const int invalid_index = FindInvalidValue(size, array);
+    if (invalid_index != size) {
+      *message = StringPrintf(
+          "ParameterBlock: %p with size %d has at least one invalid value.\n"
+          "First invalid value is at index: %d.\n"
+          "Parameter block values: ",
+          array, size, invalid_index);
+      AppendArrayToString(size, array, message);
+      return false;
+    }
+  }
+  return true;
+}
+
 bool LineSearchOptionsAreValid(const Solver::Options& options,
                                string* message) {
   // Validate values for configuration parameters supplied by user.
@@ -419,7 +441,7 @@
           << " residual blocks, "
           << problem_impl->NumResiduals()
           << " residuals.";
-
+  *CHECK_NOTNULL(summary) = Solver::Summary();
   if (options.minimizer_type == TRUST_REGION) {
     TrustRegionSolve(options, problem_impl, summary);
   } else {
@@ -440,9 +462,6 @@
   Program* original_program = original_problem_impl->mutable_program();
   ProblemImpl* problem_impl = original_problem_impl;
 
-  // Reset the summary object to its default values.
-  *CHECK_NOTNULL(summary) = Solver::Summary();
-
   summary->minimizer_type = TRUST_REGION;
 
   SummarizeGivenProgram(*original_program, summary);
@@ -484,6 +503,11 @@
     return;
   }
 
+  if (!ParameterBlocksAreFinite(problem_impl, &summary->message)) {
+    LOG(ERROR) << "Terminating: " << summary->message;
+    return;
+  }
+
   event_logger.AddEvent("Init");
 
   original_program->SetParameterBlockStatePtrsToUserStatePtrs();
@@ -704,9 +728,6 @@
   Program* original_program = original_problem_impl->mutable_program();
   ProblemImpl* problem_impl = original_problem_impl;
 
-  // Reset the summary object to its default values.
-  *CHECK_NOTNULL(summary) = Solver::Summary();
-
   SummarizeGivenProgram(*original_program, summary);
   summary->minimizer_type = LINE_SEARCH;
   summary->line_search_direction_type =
@@ -746,6 +767,11 @@
   summary->num_threads_given = original_options.num_threads;
   summary->num_threads_used = options.num_threads;
 
+  if (!ParameterBlocksAreFinite(problem_impl, &summary->message)) {
+    LOG(ERROR) << "Terminating: " << summary->message;
+    return;
+  }
+
   if (original_options.linear_solver_ordering != NULL) {
     if (!IsOrderingValid(original_options, problem_impl, &summary->message)) {
       LOG(ERROR) << summary->message;
diff --git a/internal/ceres/solver_impl_test.cc b/internal/ceres/solver_impl_test.cc
index d6faaff..1a810ed 100644
--- a/internal/ceres/solver_impl_test.cc
+++ b/internal/ceres/solver_impl_test.cc
@@ -1085,5 +1085,18 @@
   EXPECT_EQ(array, expected);
 }
 
+TEST(SolverImpl, ProblemHasNanParameterBlocks) {
+  Problem problem;
+  double x[2];
+  x[0] = 1.0;
+  x[1] = std::numeric_limits<double>::quiet_NaN();
+  problem.AddResidualBlock(new MockCostFunctionBase<1, 2, 0, 0>(), NULL, x);
+  Solver::Options options;
+  Solver::Summary summary;
+  Solve(options, &problem, &summary);
+  EXPECT_EQ(summary.termination_type, FAILURE);
+  EXPECT_NE(summary.message.find("has at least one invalid value"), string::npos);
+}
+
 }  // namespace internal
 }  // namespace ceres