Better reporting on the NIST problems. Change-Id: I7cf774ec3242c0612dbe52fc233c3fc6cff3f031
diff --git a/examples/nist.cc b/examples/nist.cc index 611b7e6..696bd67 100644 --- a/examples/nist.cc +++ b/examples/nist.cc
@@ -310,7 +310,7 @@ }; template <typename Model, int num_residuals, int num_parameters> -void RegressionDriver(const std::string& filename, +int RegressionDriver(const std::string& filename, const ceres::Solver::Options& options) { NISTProblem nist_problem(FLAGS_nist_data_dir + filename); CHECK_EQ(num_residuals, nist_problem.response_size()); @@ -353,16 +353,71 @@ } Solve(options, &problem, &summaries.back()); double certified_cost = summaries.back().initial_cost; - std::cout << filename << std::endl; + + int num_success = 0; for (int i = 0; i < nist_problem.num_starts(); ++i) { - std::cout << "start " << i + 1 << ": " - << " relative difference : " - << (summaries[i].final_cost - certified_cost) / certified_cost - << " termination: " + const int num_matching_digits = + -std::log10(1e-18 + + fabs(summaries[i].final_cost - certified_cost) + / certified_cost); + std::cout << "start " << i + 1 << " " ; + if (num_matching_digits > 4) { + ++num_success; + std::cout << "SUCCESS"; + } else { + std::cout << "FAILURE"; + } + std::cout << " digits: " << num_matching_digits; + std::cout << " termination: " << ceres::SolverTerminationTypeToString(summaries[i].termination_type) << std::endl; } + return num_success; +} + +void SolveNISTProblems(const ceres::Solver::Options& options) { + std::cout << "Lower Difficulty\n"; + int easy_success = 0; + easy_success += RegressionDriver<Misra1a, 1, 2>("Misra1a.dat", options); + easy_success += RegressionDriver<Chwirut, 1, 3>("Chwirut1.dat", options); + easy_success += RegressionDriver<Chwirut, 1, 3>("Chwirut2.dat", options); + easy_success += RegressionDriver<Lanczos, 1, 6>("Lanczos3.dat", options); + easy_success += RegressionDriver<Gauss, 1, 8>("Gauss1.dat", options); + easy_success += RegressionDriver<Gauss, 1, 8>("Gauss2.dat", options); + easy_success += RegressionDriver<DanWood, 1, 2>("DanWood.dat", options); + easy_success += RegressionDriver<Misra1b, 1, 2>("Misra1b.dat", options); + + std::cout << "\nMedium Difficulty\n"; + int medium_success = 0; + medium_success += RegressionDriver<Kirby2, 1, 5>("Kirby2.dat", options); + medium_success += RegressionDriver<Hahn1, 1, 7>("Hahn1.dat", options); + medium_success += RegressionDriver<Nelson, 1, 3>("Nelson.dat", options); + medium_success += RegressionDriver<MGH17, 1, 5>("MGH17.dat", options); + medium_success += RegressionDriver<Lanczos, 1, 6>("Lanczos1.dat", options); + medium_success += RegressionDriver<Lanczos, 1, 6>("Lanczos2.dat", options); + medium_success += RegressionDriver<Gauss, 1, 8>("Gauss3.dat", options); + medium_success += RegressionDriver<Misra1c, 1, 2>("Misra1c.dat", options); + medium_success += RegressionDriver<Misra1d, 1, 2>("Misra1d.dat", options); + medium_success += RegressionDriver<Roszman1, 1, 4>("Roszman1.dat", options); + medium_success += RegressionDriver<ENSO, 1, 9>("ENSO.dat", options); + + std::cout << "\nHigher Difficulty\n"; + int hard_success = 0; + hard_success += RegressionDriver<MGH09, 1, 4>("MGH09.dat", options); + hard_success += RegressionDriver<Thurber, 1, 7>("Thurber.dat", options); + hard_success += RegressionDriver<BoxBOD, 1, 2>("BoxBOD.dat", options); + hard_success += RegressionDriver<Rat42, 1, 3>("Rat42.dat", options); + hard_success += RegressionDriver<MGH10, 1, 3>("MGH10.dat", options); + hard_success += RegressionDriver<Eckerle4, 1, 3>("Eckerle4.dat", options); + hard_success += RegressionDriver<Rat43, 1, 4>("Rat43.dat", options); + hard_success += RegressionDriver<Bennet5, 1, 3>("Bennett5.dat", options); + + std::cout << "\n"; + std::cout << "Easy : " << easy_success << "/16\n"; + std::cout << "Medium : " << medium_success << "/22\n"; + std::cout << "Hard : " << hard_success << "/16\n"; + std::cout << "Total : " << easy_success + medium_success + hard_success << "/54\n"; } int main(int argc, char** argv) { @@ -373,43 +428,12 @@ // linear solvers. ceres::Solver::Options options; options.linear_solver_type = ceres::DENSE_QR; - options.max_num_iterations = 2000; + options.max_num_iterations = 1000; options.function_tolerance *= 1e-10; options.gradient_tolerance *= 1e-10; options.parameter_tolerance *= 1e-10; - std::cout << "Lower Difficulty\n"; - RegressionDriver<Misra1a, 1, 2>("Misra1a.dat", options); - RegressionDriver<Chwirut, 1, 3>("Chwirut1.dat", options); - RegressionDriver<Chwirut, 1, 3>("Chwirut2.dat", options); - RegressionDriver<Lanczos, 1, 6>("Lanczos3.dat", options); - RegressionDriver<Gauss, 1, 8>("Gauss1.dat", options); - RegressionDriver<Gauss, 1, 8>("Gauss2.dat", options); - RegressionDriver<DanWood, 1, 2>("DanWood.dat", options); - RegressionDriver<Misra1b, 1, 2>("Misra1b.dat", options); - - std::cout << "\nAverage Difficulty\n"; - RegressionDriver<Kirby2, 1, 5>("Kirby2.dat", options); - RegressionDriver<Hahn1, 1, 7>("Hahn1.dat", options); - RegressionDriver<Nelson, 1, 3>("Nelson.dat", options); - RegressionDriver<MGH17, 1, 5>("MGH17.dat", options); - RegressionDriver<Lanczos, 1, 6>("Lanczos1.dat", options); - RegressionDriver<Lanczos, 1, 6>("Lanczos2.dat", options); - RegressionDriver<Gauss, 1, 8>("Gauss3.dat", options); - RegressionDriver<Misra1c, 1, 2>("Misra1c.dat", options); - RegressionDriver<Misra1d, 1, 2>("Misra1d.dat", options); - RegressionDriver<Roszman1, 1, 4>("Roszman1.dat", options); - RegressionDriver<ENSO, 1, 9>("ENSO.dat", options); - - std::cout << "\nHigher Difficulty\n"; - RegressionDriver<MGH09, 1, 4>("MGH09.dat", options); - RegressionDriver<Thurber, 1, 7>("Thurber.dat", options); - RegressionDriver<BoxBOD, 1, 2>("BoxBOD.dat", options); - RegressionDriver<Rat42, 1, 3>("Rat42.dat", options); - RegressionDriver<MGH10, 1, 3>("MGH10.dat", options); - RegressionDriver<Eckerle4, 1, 3>("Eckerle4.dat", options); - RegressionDriver<Rat43, 1, 4>("Rat43.dat", options); - RegressionDriver<Bennet5, 1, 3>("Bennett5.dat", options); + SolveNISTProblems(options); return 0; };