Refactor nist.cc to be compatible with TinySolver
Change-Id: Iec0455ff9fe327fe75dc63f5b80c2ecca2c48e55
diff --git a/examples/nist.cc b/examples/nist.cc
index 55a3c7c..754467c 100644
--- a/examples/nist.cc
+++ b/examples/nist.cc
@@ -1,5 +1,5 @@
// Ceres Solver - A fast non-linear least squares minimizer
-// Copyright 2015 Google Inc. All rights reserved.
+// Copyright 2017 Google Inc. All rights reserved.
// http://ceres-solver.org/
//
// Redistribution and use in source and binary forms, with or without
@@ -71,14 +71,18 @@
// Average LRE 2.3 4.3 4.0 6.8 4.4 9.4
// Winner 0 0 5 11 2 41
+#include <Eigen/Core>
+#include <fstream>
#include <iostream>
#include <iterator>
-#include <fstream>
+
#include "ceres/ceres.h"
+#include "ceres/tiny_solver.h"
+#include "ceres/tiny_solver_cost_function_adapter.h"
#include "gflags/gflags.h"
#include "glog/logging.h"
-#include "Eigen/Core"
+DEFINE_bool(use_tiny_solver, false, "Use TinySolver instead of Ceres::Solver");
DEFINE_string(nist_data_dir, "", "Directory containing the NIST non-linear"
"regression examples");
DEFINE_string(minimizer, "trust_region",
@@ -265,20 +269,22 @@
double certified_cost_;
};
-#define NIST_BEGIN(CostFunctionName) \
- struct CostFunctionName { \
- CostFunctionName(const double* const x, \
- const double* const y) \
- : x_(*x), y_(*y) {} \
- double x_; \
- double y_; \
- template <typename T> \
- bool operator()(const T* const b, T* residual) const { \
- const T y(y_); \
- const T x(x_); \
- residual[0] = y - (
+#define NIST_BEGIN(CostFunctionName) \
+ struct CostFunctionName { \
+ CostFunctionName(const double* const x, \
+ const double* const y, \
+ const int n) \
+ : x_(x), y_(y), n_(n) {} \
+ const double* x_; \
+ const double* y_; \
+ const int n_; \
+ template <typename T> \
+ bool operator()(const T* const b, T* residual) const { \
+ for (int i = 0; i < n_; ++i) { \
+ const T x(x_[i]); \
+ residual[i] = y_[i] - (
-#define NIST_END ); return true; }};
+#define NIST_END ); } return true; }};
// y = b1 * (b2+x)**(-1/b3) + e
NIST_BEGIN(Bennet5)
@@ -405,20 +411,22 @@
struct Nelson {
public:
- Nelson(const double* const x, const double* const y)
- : x1_(x[0]), x2_(x[1]), y_(y[0]) {}
+ Nelson(const double* const x, const double* const y, const int n)
+ : x_(x), y_(y), n_(n) {}
template <typename T>
bool operator()(const T* const b, T* residual) const {
// log[y] = b1 - b2*x1 * exp[-b3*x2] + e
- residual[0] = log(y_) - (b[0] - b[1] * x1_ * exp(-b[2] * x2_));
+ for (int i = 0; i < n_; ++i) {
+ residual[i] = log(y_[i]) - (b[0] - b[1] * x_[2 * i] * exp(-b[2] * x_[2 * i + 1]));
+ }
return true;
}
private:
- double x1_;
- double x2_;
- double y_;
+ const double* x_;
+ const double* y_;
+ const int n_;
};
static void SetNumericDiffOptions(ceres::NumericDiffOptions* options) {
@@ -426,138 +434,18 @@
options->ridders_relative_initial_step_size = FLAGS_ridders_step_size;
}
-string JoinPath(const string& dirname, const string& basename) {
-#ifdef _WIN32
- static const char separator = '\\';
-#else
- static const char separator = '/';
-#endif // _WIN32
-
- if ((!basename.empty() && basename[0] == separator) || dirname.empty()) {
- return basename;
- } else if (dirname[dirname.size() - 1] == separator) {
- return dirname + basename;
- } else {
- return dirname + string(&separator, 1) + basename;
- }
-}
-
-template <typename Model, int num_residuals, int num_parameters>
-int RegressionDriver(const string& filename,
- const ceres::Solver::Options& options) {
- NISTProblem nist_problem(JoinPath(FLAGS_nist_data_dir, filename));
- CHECK_EQ(num_residuals, nist_problem.response_size());
- CHECK_EQ(num_parameters, nist_problem.num_parameters());
-
- Matrix predictor = nist_problem.predictor();
- Matrix response = nist_problem.response();
- Matrix final_parameters = nist_problem.final_parameters();
-
- printf("%s\n", filename.c_str());
-
- // Each NIST problem comes with multiple starting points, so we
- // construct the problem from scratch for each case and solve it.
- int num_success = 0;
- for (int start = 0; start < nist_problem.num_starts(); ++start) {
- Matrix initial_parameters = nist_problem.initial_parameters(start);
-
- ceres::Problem problem;
- for (int i = 0; i < nist_problem.num_observations(); ++i) {
- Model* model = new Model(
- predictor.data() + nist_problem.predictor_size() * i,
- response.data() + nist_problem.response_size() * i);
- ceres::CostFunction* cost_function = NULL;
- if (FLAGS_use_numeric_diff) {
- ceres::NumericDiffOptions options;
- SetNumericDiffOptions(&options);
- if (FLAGS_numeric_diff_method == "central") {
- cost_function = new NumericDiffCostFunction<Model,
- ceres::CENTRAL,
- num_residuals,
- num_parameters>(
- model, ceres::TAKE_OWNERSHIP, num_residuals, options);
- } else if (FLAGS_numeric_diff_method == "forward") {
- cost_function = new NumericDiffCostFunction<Model,
- ceres::FORWARD,
- num_residuals,
- num_parameters>(
- model, ceres::TAKE_OWNERSHIP, num_residuals, options);
- } else if (FLAGS_numeric_diff_method == "ridders") {
- cost_function = new NumericDiffCostFunction<Model,
- ceres::RIDDERS,
- num_residuals,
- num_parameters>(
- model, ceres::TAKE_OWNERSHIP, num_residuals, options);
- } else {
- LOG(ERROR) << "Invalid numeric diff method specified";
- return 0;
- }
- } else {
- cost_function =
- new ceres::AutoDiffCostFunction<Model,
- num_residuals,
- num_parameters>(model);
- }
-
- problem.AddResidualBlock(cost_function,
- NULL,
- initial_parameters.data());
- }
-
- ceres::Solver::Summary summary;
- Solve(options, &problem, &summary);
-
- // Compute the LRE by comparing each component of the solution
- // with the ground truth, and taking the minimum.
- Matrix final_parameters = nist_problem.final_parameters();
- const double kMaxNumSignificantDigits = 11;
- double log_relative_error = kMaxNumSignificantDigits + 1;
- for (int i = 0; i < num_parameters; ++i) {
- const double tmp_lre =
- -std::log10(std::fabs(final_parameters(i) - initial_parameters(i)) /
- std::fabs(final_parameters(i)));
- // The maximum LRE is capped at 11 - the precision at which the
- // ground truth is known.
- //
- // The minimum LRE is capped at 0 - no digits match between the
- // computed solution and the ground truth.
- log_relative_error =
- std::min(log_relative_error,
- std::max(0.0, std::min(kMaxNumSignificantDigits, tmp_lre)));
- }
-
- const int kMinNumMatchingDigits = 4;
- if (log_relative_error > kMinNumMatchingDigits) {
- ++num_success;
- }
-
- printf("start: %d status: %s lre: %4.1f initial cost: %e final cost:%e "
- "certified cost: %e total iterations: %d\n",
- start + 1,
- log_relative_error < kMinNumMatchingDigits ? "FAILURE" : "SUCCESS",
- log_relative_error,
- summary.initial_cost,
- summary.final_cost,
- nist_problem.certified_cost(),
- (summary.num_successful_steps + summary.num_unsuccessful_steps));
- }
- return num_success;
-}
-
void SetMinimizerOptions(ceres::Solver::Options* options) {
- CHECK(ceres::StringToMinimizerType(FLAGS_minimizer,
- &options->minimizer_type));
+ CHECK(
+ ceres::StringToMinimizerType(FLAGS_minimizer, &options->minimizer_type));
CHECK(ceres::StringToLinearSolverType(FLAGS_linear_solver,
&options->linear_solver_type));
CHECK(ceres::StringToPreconditionerType(FLAGS_preconditioner,
&options->preconditioner_type));
CHECK(ceres::StringToTrustRegionStrategyType(
- FLAGS_trust_region_strategy,
- &options->trust_region_strategy_type));
+ FLAGS_trust_region_strategy, &options->trust_region_strategy_type));
CHECK(ceres::StringToDoglegType(FLAGS_dogleg, &options->dogleg_type));
CHECK(ceres::StringToLineSearchDirectionType(
- FLAGS_line_search_direction,
- &options->line_search_direction_type));
+ FLAGS_line_search_direction, &options->line_search_direction_type));
CHECK(ceres::StringToLineSearchType(FLAGS_line_search,
&options->line_search_type));
CHECK(ceres::StringToLineSearchInterpolationType(
@@ -582,57 +470,213 @@
options->parameter_tolerance = std::numeric_limits<double>::epsilon();
}
+string JoinPath(const string& dirname, const string& basename) {
+#ifdef _WIN32
+ static const char separator = '\\';
+#else
+ static const char separator = '/';
+#endif // _WIN32
+
+ if ((!basename.empty() && basename[0] == separator) || dirname.empty()) {
+ return basename;
+ } else if (dirname[dirname.size() - 1] == separator) {
+ return dirname + basename;
+ } else {
+ return dirname + string(&separator, 1) + basename;
+ }
+}
+
+template <typename Model, int num_parameters>
+CostFunction* CreateCostFunction(const Matrix& predictor,
+ const Matrix& response,
+ const int num_observations) {
+ Model* model =
+ new Model(predictor.data(), response.data(), num_observations);
+ ceres::CostFunction* cost_function = NULL;
+ if (FLAGS_use_numeric_diff) {
+ ceres::NumericDiffOptions options;
+ SetNumericDiffOptions(&options);
+ if (FLAGS_numeric_diff_method == "central") {
+ cost_function = new NumericDiffCostFunction<Model,
+ ceres::CENTRAL,
+ ceres::DYNAMIC,
+ num_parameters>(
+ model,
+ ceres::TAKE_OWNERSHIP,
+ num_observations,
+ options);
+ } else if (FLAGS_numeric_diff_method == "forward") {
+ cost_function = new NumericDiffCostFunction<Model,
+ ceres::FORWARD,
+ ceres::DYNAMIC,
+ num_parameters>(
+ model,
+ ceres::TAKE_OWNERSHIP,
+ num_observations,
+ options);
+ } else if (FLAGS_numeric_diff_method == "ridders") {
+ cost_function = new NumericDiffCostFunction<Model,
+ ceres::RIDDERS,
+ ceres::DYNAMIC,
+ num_parameters>(
+ model,
+ ceres::TAKE_OWNERSHIP,
+ num_observations,
+ options);
+ } else {
+ LOG(ERROR) << "Invalid numeric diff method specified";
+ return 0;
+ }
+ } else {
+ cost_function =
+ new ceres::AutoDiffCostFunction<Model, ceres::DYNAMIC, num_parameters>(
+ model, num_observations);
+ }
+ return cost_function;
+}
+
+double ComputeLRE(const Matrix& expected, const Matrix& actual) {
+ // Compute the LRE by comparing each component of the solution
+ // with the ground truth, and taking the minimum.
+ const double kMaxNumSignificantDigits = 11;
+ double log_relative_error = kMaxNumSignificantDigits + 1;
+ for (int i = 0; i < expected.cols(); ++i) {
+ const double tmp_lre = -std::log10(std::fabs(expected(i) - actual(i)) /
+ std::fabs(expected(i)));
+ // The maximum LRE is capped at 11 - the precision at which the
+ // ground truth is known.
+ //
+ // The minimum LRE is capped at 0 - no digits match between the
+ // computed solution and the ground truth.
+ log_relative_error =
+ std::min(log_relative_error,
+ std::max(0.0, std::min(kMaxNumSignificantDigits, tmp_lre)));
+ }
+ return log_relative_error;
+}
+
+template <typename Model, int num_parameters>
+int RegressionDriver(const string& filename) {
+ NISTProblem nist_problem(JoinPath(FLAGS_nist_data_dir, filename));
+ CHECK_EQ(num_parameters, nist_problem.num_parameters());
+
+ Matrix predictor = nist_problem.predictor();
+ Matrix response = nist_problem.response();
+ Matrix final_parameters = nist_problem.final_parameters();
+
+ printf("%s\n", filename.c_str());
+
+ // Each NIST problem comes with multiple starting points, so we
+ // construct the problem from scratch for each case and solve it.
+ int num_success = 0;
+ for (int start = 0; start < nist_problem.num_starts(); ++start) {
+ Matrix initial_parameters = nist_problem.initial_parameters(start);
+ ceres::CostFunction* cost_function = CreateCostFunction<Model, num_parameters>(
+ predictor, response, nist_problem.num_observations());
+
+ double initial_cost;
+ double final_cost;
+
+ if (!FLAGS_use_tiny_solver) {
+ ceres::Problem problem;
+ problem.AddResidualBlock(cost_function, NULL, initial_parameters.data());
+ ceres::Solver::Summary summary;
+ ceres::Solver::Options options;
+ SetMinimizerOptions(&options);
+ Solve(options, &problem, &summary);
+ initial_cost = summary.initial_cost;
+ final_cost = summary.final_cost;
+ } else {
+ ceres::TinySolverCostFunctionAdapter<Eigen::Dynamic, num_parameters> cfa(
+ *cost_function);
+ typedef
+ ceres::TinySolver<
+ ceres::TinySolverCostFunctionAdapter<Eigen::Dynamic, num_parameters>,
+ Eigen::LDLT<Eigen::Matrix<double, num_parameters, num_parameters>>> Solver;
+ Solver solver;
+ solver.options.max_iterations = FLAGS_num_iterations;
+ solver.options.error_threshold = std::numeric_limits<double>::epsilon();
+ solver.options.gradient_threshold = std::numeric_limits<double>::epsilon();
+ solver.options.relative_step_threshold = std::numeric_limits<double>::epsilon();
+
+ Eigen::Matrix<double, num_parameters,1> x;
+ x = initial_parameters.transpose();
+ typename Solver::Summary summary = solver.Solve(cfa, &x);
+ initial_parameters = x;
+ initial_cost = summary.initial_cost;
+ final_cost = summary.final_cost;
+ delete cost_function;
+ }
+
+ const double log_relative_error = ComputeLRE(nist_problem.final_parameters(),
+ initial_parameters);
+ const int kMinNumMatchingDigits = 4;
+ if (log_relative_error > kMinNumMatchingDigits) {
+ ++num_success;
+ }
+
+ printf(
+ "start: %d status: %s lre: %4.1f initial cost: %e final cost:%e "
+ "certified cost: %e\n",
+ start + 1,
+ log_relative_error < kMinNumMatchingDigits ? "FAILURE" : "SUCCESS",
+ log_relative_error,
+ initial_cost,
+ final_cost,
+ nist_problem.certified_cost());
+ }
+ return num_success;
+}
+
+
void SolveNISTProblems() {
if (FLAGS_nist_data_dir.empty()) {
LOG(FATAL) << "Must specify the directory containing the NIST problems";
}
- ceres::Solver::Options options;
- SetMinimizerOptions(&options);
-
cout << "Lower Difficulty\n";
int easy_success = 0;
- easy_success += RegressionDriver<Misra1a, 1, 2>("Misra1a.dat", options);
- easy_success += RegressionDriver<Chwirut, 1, 3>("Chwirut1.dat", options);
- easy_success += RegressionDriver<Chwirut, 1, 3>("Chwirut2.dat", options);
- easy_success += RegressionDriver<Lanczos, 1, 6>("Lanczos3.dat", options);
- easy_success += RegressionDriver<Gauss, 1, 8>("Gauss1.dat", options);
- easy_success += RegressionDriver<Gauss, 1, 8>("Gauss2.dat", options);
- easy_success += RegressionDriver<DanWood, 1, 2>("DanWood.dat", options);
- easy_success += RegressionDriver<Misra1b, 1, 2>("Misra1b.dat", options);
+ easy_success += RegressionDriver<Misra1a, 2>("Misra1a.dat");
+ easy_success += RegressionDriver<Chwirut, 3>("Chwirut1.dat");
+ easy_success += RegressionDriver<Chwirut, 3>("Chwirut2.dat");
+ easy_success += RegressionDriver<Lanczos, 6>("Lanczos3.dat");
+ easy_success += RegressionDriver<Gauss, 8>("Gauss1.dat");
+ easy_success += RegressionDriver<Gauss, 8>("Gauss2.dat");
+ easy_success += RegressionDriver<DanWood, 2>("DanWood.dat");
+ easy_success += RegressionDriver<Misra1b, 2>("Misra1b.dat");
cout << "\nMedium Difficulty\n";
int medium_success = 0;
- medium_success += RegressionDriver<Kirby2, 1, 5>("Kirby2.dat", options);
- medium_success += RegressionDriver<Hahn1, 1, 7>("Hahn1.dat", options);
- medium_success += RegressionDriver<Nelson, 1, 3>("Nelson.dat", options);
- medium_success += RegressionDriver<MGH17, 1, 5>("MGH17.dat", options);
- medium_success += RegressionDriver<Lanczos, 1, 6>("Lanczos1.dat", options);
- medium_success += RegressionDriver<Lanczos, 1, 6>("Lanczos2.dat", options);
- medium_success += RegressionDriver<Gauss, 1, 8>("Gauss3.dat", options);
- medium_success += RegressionDriver<Misra1c, 1, 2>("Misra1c.dat", options);
- medium_success += RegressionDriver<Misra1d, 1, 2>("Misra1d.dat", options);
- medium_success += RegressionDriver<Roszman1, 1, 4>("Roszman1.dat", options);
- medium_success += RegressionDriver<ENSO, 1, 9>("ENSO.dat", options);
+ medium_success += RegressionDriver<Kirby2, 5>("Kirby2.dat");
+ medium_success += RegressionDriver<Hahn1, 7>("Hahn1.dat");
+ medium_success += RegressionDriver<Nelson, 3>("Nelson.dat");
+ medium_success += RegressionDriver<MGH17, 5>("MGH17.dat");
+ medium_success += RegressionDriver<Lanczos, 6>("Lanczos1.dat");
+ medium_success += RegressionDriver<Lanczos, 6>("Lanczos2.dat");
+ medium_success += RegressionDriver<Gauss, 8>("Gauss3.dat");
+ medium_success += RegressionDriver<Misra1c, 2>("Misra1c.dat");
+ medium_success += RegressionDriver<Misra1d, 2>("Misra1d.dat");
+ medium_success += RegressionDriver<Roszman1, 4>("Roszman1.dat");
+ medium_success += RegressionDriver<ENSO, 9>("ENSO.dat");
cout << "\nHigher Difficulty\n";
int hard_success = 0;
- hard_success += RegressionDriver<MGH09, 1, 4>("MGH09.dat", options);
- hard_success += RegressionDriver<Thurber, 1, 7>("Thurber.dat", options);
- hard_success += RegressionDriver<BoxBOD, 1, 2>("BoxBOD.dat", options);
- hard_success += RegressionDriver<Rat42, 1, 3>("Rat42.dat", options);
- hard_success += RegressionDriver<MGH10, 1, 3>("MGH10.dat", options);
+ hard_success += RegressionDriver<MGH09, 4>("MGH09.dat");
+ hard_success += RegressionDriver<Thurber, 7>("Thurber.dat");
+ hard_success += RegressionDriver<BoxBOD, 2>("BoxBOD.dat");
+ hard_success += RegressionDriver<Rat42, 3>("Rat42.dat");
+ hard_success += RegressionDriver<MGH10, 3>("MGH10.dat");
- hard_success += RegressionDriver<Eckerle4, 1, 3>("Eckerle4.dat", options);
- hard_success += RegressionDriver<Rat43, 1, 4>("Rat43.dat", options);
- hard_success += RegressionDriver<Bennet5, 1, 3>("Bennett5.dat", options);
+ hard_success += RegressionDriver<Eckerle4, 3>("Eckerle4.dat");
+ hard_success += RegressionDriver<Rat43, 4>("Rat43.dat");
+ hard_success += RegressionDriver<Bennet5, 3>("Bennett5.dat");
cout << "\n";
cout << "Easy : " << easy_success << "/16\n";
cout << "Medium : " << medium_success << "/22\n";
cout << "Hard : " << hard_success << "/16\n";
- cout << "Total : "
- << easy_success + medium_success + hard_success << "/54\n";
+ cout << "Total : " << easy_success + medium_success + hard_success
+ << "/54\n";
}
} // namespace examples