Refactor nist.cc to be compatible with TinySolver

The NIST cost functors now evaluate all the observations of a problem
as a single, dynamically sized residual block instead of one residual
block per observation. This allows the same functor to be used with
ceres::Solver or, via TinySolverCostFunctionAdapter, with
ceres::TinySolver. A new --use_tiny_solver flag selects the solver at
runtime.

Change-Id: Iec0455ff9fe327fe75dc63f5b80c2ecca2c48e55
diff --git a/examples/nist.cc b/examples/nist.cc
index 55a3c7c..754467c 100644
--- a/examples/nist.cc
+++ b/examples/nist.cc
@@ -1,5 +1,5 @@
 // Ceres Solver - A fast non-linear least squares minimizer
-// Copyright 2015 Google Inc. All rights reserved.
+// Copyright 2017 Google Inc. All rights reserved.
 // http://ceres-solver.org/
 //
 // Redistribution and use in source and binary forms, with or without
@@ -71,14 +71,18 @@
 // Average LRE     2.3      4.3       4.0  6.8      4.4    9.4
 //      Winner       0        0         5   11        2     41
 
+#include <Eigen/Core>
+#include <fstream>
 #include <iostream>
 #include <iterator>
-#include <fstream>
+
 #include "ceres/ceres.h"
+#include "ceres/tiny_solver.h"
+#include "ceres/tiny_solver_cost_function_adapter.h"
 #include "gflags/gflags.h"
 #include "glog/logging.h"
-#include "Eigen/Core"
 
+DEFINE_bool(use_tiny_solver, false, "Use TinySolver instead of ceres::Solver");
 DEFINE_string(nist_data_dir, "", "Directory containing the NIST non-linear"
               "regression examples");
 DEFINE_string(minimizer, "trust_region",
@@ -265,20 +269,22 @@
   double certified_cost_;
 };
 
-#define NIST_BEGIN(CostFunctionName) \
-  struct CostFunctionName { \
-    CostFunctionName(const double* const x, \
-                     const double* const y) \
-        : x_(*x), y_(*y) {} \
-    double x_; \
-    double y_; \
-    template <typename T> \
-    bool operator()(const T* const b, T* residual) const { \
-    const T y(y_); \
-    const T x(x_); \
-      residual[0] = y - (
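+// The NIST_BEGIN/NIST_END macro pair defines a cost functor for a NIST
+// problem. The functor now stores pointers to all the observations and
+// evaluates one residual per observation, so a whole problem is expressed
+// as a single, dynamically sized residual block.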
+#define NIST_BEGIN(CostFunctionName)                          \
+  struct CostFunctionName {                                   \
+    CostFunctionName(const double* const x,                   \
+                     const double* const y,                   \
+                     const int n)                             \
+        : x_(x), y_(y), n_(n) {}                              \
+    const double* x_;                                         \
+    const double* y_;                                         \
+    const int n_;                                             \
+    template <typename T>                                     \
+    bool operator()(const T* const b, T* residual) const {    \
+      for (int i = 0; i < n_; ++i) {                          \
+        const T x(x_[i]);                                     \
+        residual[i] = y_[i] - (
 
-#define NIST_END ); return true; }};
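+// NIST_END closes the residual expression, the loop over observations,
+// operator(), and the struct opened by NIST_BEGIN.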
+#define NIST_END ); } return true; }};
 
 // y = b1 * (b2+x)**(-1/b3)  +  e
 NIST_BEGIN(Bennet5)
@@ -405,20 +411,22 @@
 
 struct Nelson {
  public:
-  Nelson(const double* const x, const double* const y)
-      : x1_(x[0]), x2_(x[1]), y_(y[0]) {}
+  Nelson(const double* const x, const double* const y, const int n)
+      : x_(x), y_(y), n_(n) {}
 
   template <typename T>
   bool operator()(const T* const b, T* residual) const {
     // log[y] = b1 - b2*x1 * exp[-b3*x2]  +  e
-    residual[0] = log(y_) - (b[0] - b[1] * x1_ * exp(-b[2] * x2_));
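+    // Each observation has two predictor values stored consecutively in
+    // x_, so observation i uses x_[2 * i] and x_[2 * i + 1].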
+    for (int i = 0; i < n_; ++i) {
+      residual[i] =
+          log(y_[i]) - (b[0] - b[1] * x_[2 * i] * exp(-b[2] * x_[2 * i + 1]));
+    }
     return true;
   }
 
  private:
-  double x1_;
-  double x2_;
-  double y_;
+  const double* x_;
+  const double* y_;
+  const int n_;
 };
 
 static void SetNumericDiffOptions(ceres::NumericDiffOptions* options) {
@@ -426,138 +434,18 @@
   options->ridders_relative_initial_step_size = FLAGS_ridders_step_size;
 }
 
-string JoinPath(const string& dirname, const string& basename) {
-#ifdef _WIN32
-    static const char separator = '\\';
-#else
-    static const char separator = '/';
-#endif  // _WIN32
-
-  if ((!basename.empty() && basename[0] == separator) || dirname.empty()) {
-    return basename;
-  } else if (dirname[dirname.size() - 1] == separator) {
-    return dirname + basename;
-  } else {
-    return dirname + string(&separator, 1) + basename;
-  }
-}
-
-template <typename Model, int num_residuals, int num_parameters>
-int RegressionDriver(const string& filename,
-                     const ceres::Solver::Options& options) {
-  NISTProblem nist_problem(JoinPath(FLAGS_nist_data_dir, filename));
-  CHECK_EQ(num_residuals, nist_problem.response_size());
-  CHECK_EQ(num_parameters, nist_problem.num_parameters());
-
-  Matrix predictor = nist_problem.predictor();
-  Matrix response = nist_problem.response();
-  Matrix final_parameters = nist_problem.final_parameters();
-
-  printf("%s\n", filename.c_str());
-
-  // Each NIST problem comes with multiple starting points, so we
-  // construct the problem from scratch for each case and solve it.
-  int num_success = 0;
-  for (int start = 0; start < nist_problem.num_starts(); ++start) {
-    Matrix initial_parameters = nist_problem.initial_parameters(start);
-
-    ceres::Problem problem;
-    for (int i = 0; i < nist_problem.num_observations(); ++i) {
-      Model* model = new Model(
-          predictor.data() + nist_problem.predictor_size() * i,
-          response.data() + nist_problem.response_size() * i);
-      ceres::CostFunction* cost_function = NULL;
-      if (FLAGS_use_numeric_diff) {
-        ceres::NumericDiffOptions options;
-        SetNumericDiffOptions(&options);
-        if (FLAGS_numeric_diff_method == "central") {
-          cost_function = new NumericDiffCostFunction<Model,
-                                                      ceres::CENTRAL,
-                                                      num_residuals,
-                                                      num_parameters>(
-              model, ceres::TAKE_OWNERSHIP, num_residuals, options);
-        } else if (FLAGS_numeric_diff_method == "forward") {
-          cost_function = new NumericDiffCostFunction<Model,
-                                                      ceres::FORWARD,
-                                                      num_residuals,
-                                                      num_parameters>(
-              model, ceres::TAKE_OWNERSHIP, num_residuals, options);
-        } else if (FLAGS_numeric_diff_method == "ridders") {
-          cost_function = new NumericDiffCostFunction<Model,
-                                                      ceres::RIDDERS,
-                                                      num_residuals,
-                                                      num_parameters>(
-              model, ceres::TAKE_OWNERSHIP, num_residuals, options);
-        } else {
-          LOG(ERROR) << "Invalid numeric diff method specified";
-          return 0;
-        }
-      } else {
-         cost_function =
-             new ceres::AutoDiffCostFunction<Model,
-                                             num_residuals,
-                                             num_parameters>(model);
-      }
-
-      problem.AddResidualBlock(cost_function,
-                               NULL,
-                               initial_parameters.data());
-    }
-
-    ceres::Solver::Summary summary;
-    Solve(options, &problem, &summary);
-
-    // Compute the LRE by comparing each component of the solution
-    // with the ground truth, and taking the minimum.
-    Matrix final_parameters = nist_problem.final_parameters();
-    const double kMaxNumSignificantDigits = 11;
-    double log_relative_error = kMaxNumSignificantDigits + 1;
-    for (int i = 0; i < num_parameters; ++i) {
-      const double tmp_lre =
-          -std::log10(std::fabs(final_parameters(i) - initial_parameters(i)) /
-                      std::fabs(final_parameters(i)));
-      // The maximum LRE is capped at 11 - the precision at which the
-      // ground truth is known.
-      //
-      // The minimum LRE is capped at 0 - no digits match between the
-      // computed solution and the ground truth.
-      log_relative_error =
-          std::min(log_relative_error,
-                   std::max(0.0, std::min(kMaxNumSignificantDigits, tmp_lre)));
-    }
-
-    const int kMinNumMatchingDigits = 4;
-    if (log_relative_error > kMinNumMatchingDigits) {
-      ++num_success;
-    }
-
-    printf("start: %d status: %s lre: %4.1f initial cost: %e final cost:%e "
-           "certified cost: %e total iterations: %d\n",
-           start + 1,
-           log_relative_error < kMinNumMatchingDigits ? "FAILURE" : "SUCCESS",
-           log_relative_error,
-           summary.initial_cost,
-           summary.final_cost,
-           nist_problem.certified_cost(),
-           (summary.num_successful_steps + summary.num_unsuccessful_steps));
-  }
-  return num_success;
-}
-
 void SetMinimizerOptions(ceres::Solver::Options* options) {
-  CHECK(ceres::StringToMinimizerType(FLAGS_minimizer,
-                                     &options->minimizer_type));
+  CHECK(
+      ceres::StringToMinimizerType(FLAGS_minimizer, &options->minimizer_type));
   CHECK(ceres::StringToLinearSolverType(FLAGS_linear_solver,
                                         &options->linear_solver_type));
   CHECK(ceres::StringToPreconditionerType(FLAGS_preconditioner,
                                           &options->preconditioner_type));
   CHECK(ceres::StringToTrustRegionStrategyType(
-            FLAGS_trust_region_strategy,
-            &options->trust_region_strategy_type));
+      FLAGS_trust_region_strategy, &options->trust_region_strategy_type));
   CHECK(ceres::StringToDoglegType(FLAGS_dogleg, &options->dogleg_type));
   CHECK(ceres::StringToLineSearchDirectionType(
-      FLAGS_line_search_direction,
-      &options->line_search_direction_type));
+      FLAGS_line_search_direction, &options->line_search_direction_type));
   CHECK(ceres::StringToLineSearchType(FLAGS_line_search,
                                       &options->line_search_type));
   CHECK(ceres::StringToLineSearchInterpolationType(
@@ -582,57 +470,213 @@
   options->parameter_tolerance = std::numeric_limits<double>::epsilon();
 }
 
+string JoinPath(const string& dirname, const string& basename) {
+#ifdef _WIN32
+  static const char separator = '\\';
+#else
+  static const char separator = '/';
+#endif  // _WIN32
+
+  if ((!basename.empty() && basename[0] == separator) || dirname.empty()) {
+    return basename;
+  } else if (dirname[dirname.size() - 1] == separator) {
+    return dirname + basename;
+  } else {
+    return dirname + string(&separator, 1) + basename;
+  }
+}
+
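+// Builds a cost function covering all the observations of a NIST problem.
+// The Model functor produces one residual per observation, and is wrapped
+// in either a NumericDiffCostFunction or an AutoDiffCostFunction with
+// ceres::DYNAMIC residuals sized at runtime to num_observations.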
+template <typename Model, int num_parameters>
+CostFunction* CreateCostFunction(const Matrix& predictor,
+                                 const Matrix& response,
+                                 const int num_observations) {
+  Model* model =
+      new Model(predictor.data(), response.data(), num_observations);
+  ceres::CostFunction* cost_function = NULL;
+  if (FLAGS_use_numeric_diff) {
+    ceres::NumericDiffOptions options;
+    SetNumericDiffOptions(&options);
+    if (FLAGS_numeric_diff_method == "central") {
+      cost_function = new NumericDiffCostFunction<Model,
+                                                  ceres::CENTRAL,
+                                                  ceres::DYNAMIC,
+                                                  num_parameters>(
+          model,
+          ceres::TAKE_OWNERSHIP,
+          num_observations,
+          options);
+    } else if (FLAGS_numeric_diff_method == "forward") {
+      cost_function = new NumericDiffCostFunction<Model,
+                                                  ceres::FORWARD,
+                                                  ceres::DYNAMIC,
+                                                  num_parameters>(
+          model,
+          ceres::TAKE_OWNERSHIP,
+          num_observations,
+          options);
+    } else if (FLAGS_numeric_diff_method == "ridders") {
+      cost_function = new NumericDiffCostFunction<Model,
+                                                  ceres::RIDDERS,
+                                                  ceres::DYNAMIC,
+                                                  num_parameters>(
+          model,
+          ceres::TAKE_OWNERSHIP,
+          num_observations,
+          options);
+    } else {
+      LOG(ERROR) << "Invalid numeric diff method specified";
+      return NULL;
+    }
+  } else {
+    cost_function =
+        new ceres::AutoDiffCostFunction<Model, ceres::DYNAMIC, num_parameters>(
+            model, num_observations);
+  }
+  return cost_function;
+}
+
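+// Returns the log relative error (LRE): roughly the number of significant
+// digits in which the computed parameters agree with the certified values,
+// minimized over all parameters.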
+double ComputeLRE(const Matrix& expected, const Matrix& actual) {
+  // Compute the LRE by comparing each component of the solution
+  // with the ground truth, and taking the minimum.
+  const double kMaxNumSignificantDigits = 11;
+  double log_relative_error = kMaxNumSignificantDigits + 1;
+  for (int i = 0; i < expected.cols(); ++i) {
+    const double tmp_lre = -std::log10(std::fabs(expected(i) - actual(i)) /
+                                       std::fabs(expected(i)));
+    // The maximum LRE is capped at 11 - the precision at which the
+    // ground truth is known.
+    //
+    // The minimum LRE is capped at 0 - no digits match between the
+    // computed solution and the ground truth.
+    log_relative_error =
+        std::min(log_relative_error,
+                 std::max(0.0, std::min(kMaxNumSignificantDigits, tmp_lre)));
+  }
+  return log_relative_error;
+}
+
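+// Loads the given NIST problem, solves it from each of its certified
+// starting points with either ceres::Solver or TinySolver (depending on
+// --use_tiny_solver), and returns the number of starts for which more than
+// kMinNumMatchingDigits digits of the certified parameters were recovered.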
+template <typename Model, int num_parameters>
+int RegressionDriver(const string& filename) {
+  NISTProblem nist_problem(JoinPath(FLAGS_nist_data_dir, filename));
+  CHECK_EQ(num_parameters, nist_problem.num_parameters());
+
+  Matrix predictor = nist_problem.predictor();
+  Matrix response = nist_problem.response();
+  Matrix final_parameters = nist_problem.final_parameters();
+
+  printf("%s\n", filename.c_str());
+
+  // Each NIST problem comes with multiple starting points, so we
+  // construct the problem from scratch for each case and solve it.
+  int num_success = 0;
+  for (int start = 0; start < nist_problem.num_starts(); ++start) {
+    Matrix initial_parameters = nist_problem.initial_parameters(start);
+    ceres::CostFunction* cost_function =
+        CreateCostFunction<Model, num_parameters>(
+            predictor, response, nist_problem.num_observations());
+
+    double initial_cost;
+    double final_cost;
+
+    if (!FLAGS_use_tiny_solver) {
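+      // Solve with the full ceres::Solver; the Problem takes ownership of
+      // the cost function.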
+      ceres::Problem problem;
+      problem.AddResidualBlock(cost_function, NULL, initial_parameters.data());
+      ceres::Solver::Summary summary;
+      ceres::Solver::Options options;
+      SetMinimizerOptions(&options);
+      Solve(options, &problem, &summary);
+      initial_cost = summary.initial_cost;
+      final_cost = summary.final_cost;
+    } else {
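+      // Solve with TinySolver: wrap the Ceres cost function in a
+      // TinySolverCostFunctionAdapter and solve the dense normal equations
+      // with Eigen's LDLT.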
+      ceres::TinySolverCostFunctionAdapter<Eigen::Dynamic, num_parameters> cfa(
+          *cost_function);
+      typedef ceres::TinySolver<
+          ceres::TinySolverCostFunctionAdapter<Eigen::Dynamic, num_parameters>,
+          Eigen::LDLT<Eigen::Matrix<double, num_parameters, num_parameters>>>
+          Solver;
+      Solver solver;
+      solver.options.max_iterations = FLAGS_num_iterations;
+      solver.options.error_threshold = std::numeric_limits<double>::epsilon();
+      solver.options.gradient_threshold =
+          std::numeric_limits<double>::epsilon();
+      solver.options.relative_step_threshold =
+          std::numeric_limits<double>::epsilon();
+
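+      // TinySolver expects the parameters as a fixed-size column vector,
+      // while the NIST problem stores them as a row, hence the transpose.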
+      Eigen::Matrix<double, num_parameters, 1> x;
+      x = initial_parameters.transpose();
+      typename Solver::Summary summary = solver.Solve(cfa, &x);
+      initial_parameters = x;
+      initial_cost = summary.initial_cost;
+      final_cost = summary.final_cost;
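+      // Unlike ceres::Problem, TinySolver does not take ownership of the
+      // cost function, so delete it explicitly.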
+      delete cost_function;
+    }
+
+    const double log_relative_error =
+        ComputeLRE(final_parameters, initial_parameters);
+    const int kMinNumMatchingDigits = 4;
+    if (log_relative_error > kMinNumMatchingDigits) {
+      ++num_success;
+    }
+
+    printf(
+        "start: %d status: %s lre: %4.1f initial cost: %e final cost:%e "
+        "certified cost: %e\n",
+        start + 1,
+        log_relative_error < kMinNumMatchingDigits ? "FAILURE" : "SUCCESS",
+        log_relative_error,
+        initial_cost,
+        final_cost,
+        nist_problem.certified_cost());
+  }
+  return num_success;
+}
+
 void SolveNISTProblems() {
   if (FLAGS_nist_data_dir.empty()) {
     LOG(FATAL) << "Must specify the directory containing the NIST problems";
   }
 
-  ceres::Solver::Options options;
-  SetMinimizerOptions(&options);
-
   cout << "Lower Difficulty\n";
   int easy_success = 0;
-  easy_success += RegressionDriver<Misra1a,  1, 2>("Misra1a.dat",  options);
-  easy_success += RegressionDriver<Chwirut,  1, 3>("Chwirut1.dat", options);
-  easy_success += RegressionDriver<Chwirut,  1, 3>("Chwirut2.dat", options);
-  easy_success += RegressionDriver<Lanczos,  1, 6>("Lanczos3.dat", options);
-  easy_success += RegressionDriver<Gauss,    1, 8>("Gauss1.dat",   options);
-  easy_success += RegressionDriver<Gauss,    1, 8>("Gauss2.dat",   options);
-  easy_success += RegressionDriver<DanWood,  1, 2>("DanWood.dat",  options);
-  easy_success += RegressionDriver<Misra1b,  1, 2>("Misra1b.dat",  options);
+  easy_success += RegressionDriver<Misra1a, 2>("Misra1a.dat");
+  easy_success += RegressionDriver<Chwirut, 3>("Chwirut1.dat");
+  easy_success += RegressionDriver<Chwirut, 3>("Chwirut2.dat");
+  easy_success += RegressionDriver<Lanczos, 6>("Lanczos3.dat");
+  easy_success += RegressionDriver<Gauss, 8>("Gauss1.dat");
+  easy_success += RegressionDriver<Gauss, 8>("Gauss2.dat");
+  easy_success += RegressionDriver<DanWood, 2>("DanWood.dat");
+  easy_success += RegressionDriver<Misra1b, 2>("Misra1b.dat");
 
   cout << "\nMedium Difficulty\n";
   int medium_success = 0;
-  medium_success += RegressionDriver<Kirby2,   1, 5>("Kirby2.dat",   options);
-  medium_success += RegressionDriver<Hahn1,    1, 7>("Hahn1.dat",    options);
-  medium_success += RegressionDriver<Nelson,   1, 3>("Nelson.dat",   options);
-  medium_success += RegressionDriver<MGH17,    1, 5>("MGH17.dat",    options);
-  medium_success += RegressionDriver<Lanczos,  1, 6>("Lanczos1.dat", options);
-  medium_success += RegressionDriver<Lanczos,  1, 6>("Lanczos2.dat", options);
-  medium_success += RegressionDriver<Gauss,    1, 8>("Gauss3.dat",   options);
-  medium_success += RegressionDriver<Misra1c,  1, 2>("Misra1c.dat",  options);
-  medium_success += RegressionDriver<Misra1d,  1, 2>("Misra1d.dat",  options);
-  medium_success += RegressionDriver<Roszman1, 1, 4>("Roszman1.dat", options);
-  medium_success += RegressionDriver<ENSO,     1, 9>("ENSO.dat",     options);
+  medium_success += RegressionDriver<Kirby2, 5>("Kirby2.dat");
+  medium_success += RegressionDriver<Hahn1, 7>("Hahn1.dat");
+  medium_success += RegressionDriver<Nelson, 3>("Nelson.dat");
+  medium_success += RegressionDriver<MGH17, 5>("MGH17.dat");
+  medium_success += RegressionDriver<Lanczos, 6>("Lanczos1.dat");
+  medium_success += RegressionDriver<Lanczos, 6>("Lanczos2.dat");
+  medium_success += RegressionDriver<Gauss, 8>("Gauss3.dat");
+  medium_success += RegressionDriver<Misra1c, 2>("Misra1c.dat");
+  medium_success += RegressionDriver<Misra1d, 2>("Misra1d.dat");
+  medium_success += RegressionDriver<Roszman1, 4>("Roszman1.dat");
+  medium_success += RegressionDriver<ENSO, 9>("ENSO.dat");
 
   cout << "\nHigher Difficulty\n";
   int hard_success = 0;
-  hard_success += RegressionDriver<MGH09,    1, 4>("MGH09.dat",    options);
-  hard_success += RegressionDriver<Thurber,  1, 7>("Thurber.dat",  options);
-  hard_success += RegressionDriver<BoxBOD,   1, 2>("BoxBOD.dat",   options);
-  hard_success += RegressionDriver<Rat42,    1, 3>("Rat42.dat",    options);
-  hard_success += RegressionDriver<MGH10,    1, 3>("MGH10.dat",    options);
+  hard_success += RegressionDriver<MGH09, 4>("MGH09.dat");
+  hard_success += RegressionDriver<Thurber, 7>("Thurber.dat");
+  hard_success += RegressionDriver<BoxBOD, 2>("BoxBOD.dat");
+  hard_success += RegressionDriver<Rat42, 3>("Rat42.dat");
+  hard_success += RegressionDriver<MGH10, 3>("MGH10.dat");
 
-  hard_success += RegressionDriver<Eckerle4, 1, 3>("Eckerle4.dat", options);
-  hard_success += RegressionDriver<Rat43,    1, 4>("Rat43.dat",    options);
-  hard_success += RegressionDriver<Bennet5,  1, 3>("Bennett5.dat", options);
+  hard_success += RegressionDriver<Eckerle4, 3>("Eckerle4.dat");
+  hard_success += RegressionDriver<Rat43, 4>("Rat43.dat");
+  hard_success += RegressionDriver<Bennet5, 3>("Bennett5.dat");
 
   cout << "\n";
   cout << "Easy    : " << easy_success << "/16\n";
   cout << "Medium  : " << medium_success << "/22\n";
   cout << "Hard    : " << hard_success << "/16\n";
-  cout << "Total   : "
-       << easy_success + medium_success + hard_success << "/54\n";
+  cout << "Total   : " << easy_success + medium_success + hard_success
+       << "/54\n";
 }
 
 }  // namespace examples