Better reporting on the NIST problems.
Change-Id: I7cf774ec3242c0612dbe52fc233c3fc6cff3f031
diff --git a/examples/nist.cc b/examples/nist.cc
index 611b7e6..696bd67 100644
--- a/examples/nist.cc
+++ b/examples/nist.cc
@@ -310,7 +310,7 @@
};
template <typename Model, int num_residuals, int num_parameters>
-void RegressionDriver(const std::string& filename,
+int RegressionDriver(const std::string& filename,
const ceres::Solver::Options& options) {
NISTProblem nist_problem(FLAGS_nist_data_dir + filename);
CHECK_EQ(num_residuals, nist_problem.response_size());
@@ -353,16 +353,71 @@
}
Solve(options, &problem, &summaries.back());
double certified_cost = summaries.back().initial_cost;
-
std::cout << filename << std::endl;
+
+ int num_success = 0;
for (int i = 0; i < nist_problem.num_starts(); ++i) {
- std::cout << "start " << i + 1 << ": "
- << " relative difference : "
- << (summaries[i].final_cost - certified_cost) / certified_cost
- << " termination: "
+ const int num_matching_digits =
+ -std::log10(1e-18 +
+ fabs(summaries[i].final_cost - certified_cost)
+ / certified_cost);
+ std::cout << "start " << i + 1 << " " ;
+ if (num_matching_digits > 4) {
+ ++num_success;
+ std::cout << "SUCCESS";
+ } else {
+ std::cout << "FAILURE";
+ }
+ std::cout << " digits: " << num_matching_digits;
+ std::cout << " termination: "
<< ceres::SolverTerminationTypeToString(summaries[i].termination_type)
<< std::endl;
}
+ return num_success;
+}
+
+void SolveNISTProblems(const ceres::Solver::Options& options) {
+ std::cout << "Lower Difficulty\n";
+ int easy_success = 0;
+ easy_success += RegressionDriver<Misra1a, 1, 2>("Misra1a.dat", options);
+ easy_success += RegressionDriver<Chwirut, 1, 3>("Chwirut1.dat", options);
+ easy_success += RegressionDriver<Chwirut, 1, 3>("Chwirut2.dat", options);
+ easy_success += RegressionDriver<Lanczos, 1, 6>("Lanczos3.dat", options);
+ easy_success += RegressionDriver<Gauss, 1, 8>("Gauss1.dat", options);
+ easy_success += RegressionDriver<Gauss, 1, 8>("Gauss2.dat", options);
+ easy_success += RegressionDriver<DanWood, 1, 2>("DanWood.dat", options);
+ easy_success += RegressionDriver<Misra1b, 1, 2>("Misra1b.dat", options);
+
+ std::cout << "\nMedium Difficulty\n";
+ int medium_success = 0;
+ medium_success += RegressionDriver<Kirby2, 1, 5>("Kirby2.dat", options);
+ medium_success += RegressionDriver<Hahn1, 1, 7>("Hahn1.dat", options);
+ medium_success += RegressionDriver<Nelson, 1, 3>("Nelson.dat", options);
+ medium_success += RegressionDriver<MGH17, 1, 5>("MGH17.dat", options);
+ medium_success += RegressionDriver<Lanczos, 1, 6>("Lanczos1.dat", options);
+ medium_success += RegressionDriver<Lanczos, 1, 6>("Lanczos2.dat", options);
+ medium_success += RegressionDriver<Gauss, 1, 8>("Gauss3.dat", options);
+ medium_success += RegressionDriver<Misra1c, 1, 2>("Misra1c.dat", options);
+ medium_success += RegressionDriver<Misra1d, 1, 2>("Misra1d.dat", options);
+ medium_success += RegressionDriver<Roszman1, 1, 4>("Roszman1.dat", options);
+ medium_success += RegressionDriver<ENSO, 1, 9>("ENSO.dat", options);
+
+ std::cout << "\nHigher Difficulty\n";
+ int hard_success = 0;
+ hard_success += RegressionDriver<MGH09, 1, 4>("MGH09.dat", options);
+ hard_success += RegressionDriver<Thurber, 1, 7>("Thurber.dat", options);
+ hard_success += RegressionDriver<BoxBOD, 1, 2>("BoxBOD.dat", options);
+ hard_success += RegressionDriver<Rat42, 1, 3>("Rat42.dat", options);
+ hard_success += RegressionDriver<MGH10, 1, 3>("MGH10.dat", options);
+ hard_success += RegressionDriver<Eckerle4, 1, 3>("Eckerle4.dat", options);
+ hard_success += RegressionDriver<Rat43, 1, 4>("Rat43.dat", options);
+ hard_success += RegressionDriver<Bennet5, 1, 3>("Bennett5.dat", options);
+
+ std::cout << "\n";
+ std::cout << "Easy : " << easy_success << "/16\n";
+ std::cout << "Medium : " << medium_success << "/22\n";
+ std::cout << "Hard : " << hard_success << "/16\n";
+ std::cout << "Total : " << easy_success + medium_success + hard_success << "/54\n";
}
int main(int argc, char** argv) {
@@ -373,43 +428,12 @@
// linear solvers.
ceres::Solver::Options options;
options.linear_solver_type = ceres::DENSE_QR;
- options.max_num_iterations = 2000;
+ options.max_num_iterations = 1000;
options.function_tolerance *= 1e-10;
options.gradient_tolerance *= 1e-10;
options.parameter_tolerance *= 1e-10;
- std::cout << "Lower Difficulty\n";
- RegressionDriver<Misra1a, 1, 2>("Misra1a.dat", options);
- RegressionDriver<Chwirut, 1, 3>("Chwirut1.dat", options);
- RegressionDriver<Chwirut, 1, 3>("Chwirut2.dat", options);
- RegressionDriver<Lanczos, 1, 6>("Lanczos3.dat", options);
- RegressionDriver<Gauss, 1, 8>("Gauss1.dat", options);
- RegressionDriver<Gauss, 1, 8>("Gauss2.dat", options);
- RegressionDriver<DanWood, 1, 2>("DanWood.dat", options);
- RegressionDriver<Misra1b, 1, 2>("Misra1b.dat", options);
-
- std::cout << "\nAverage Difficulty\n";
- RegressionDriver<Kirby2, 1, 5>("Kirby2.dat", options);
- RegressionDriver<Hahn1, 1, 7>("Hahn1.dat", options);
- RegressionDriver<Nelson, 1, 3>("Nelson.dat", options);
- RegressionDriver<MGH17, 1, 5>("MGH17.dat", options);
- RegressionDriver<Lanczos, 1, 6>("Lanczos1.dat", options);
- RegressionDriver<Lanczos, 1, 6>("Lanczos2.dat", options);
- RegressionDriver<Gauss, 1, 8>("Gauss3.dat", options);
- RegressionDriver<Misra1c, 1, 2>("Misra1c.dat", options);
- RegressionDriver<Misra1d, 1, 2>("Misra1d.dat", options);
- RegressionDriver<Roszman1, 1, 4>("Roszman1.dat", options);
- RegressionDriver<ENSO, 1, 9>("ENSO.dat", options);
-
- std::cout << "\nHigher Difficulty\n";
- RegressionDriver<MGH09, 1, 4>("MGH09.dat", options);
- RegressionDriver<Thurber, 1, 7>("Thurber.dat", options);
- RegressionDriver<BoxBOD, 1, 2>("BoxBOD.dat", options);
- RegressionDriver<Rat42, 1, 3>("Rat42.dat", options);
- RegressionDriver<MGH10, 1, 3>("MGH10.dat", options);
- RegressionDriver<Eckerle4, 1, 3>("Eckerle4.dat", options);
- RegressionDriver<Rat43, 1, 4>("Rat43.dat", options);
- RegressionDriver<Bennet5, 1, 3>("Bennett5.dat", options);
+ SolveNISTProblems(options);
return 0;
};