Adding Wolfe line search algorithm and full BFGS search direction options.
Change-Id: I9d3fb117805bdfa5bc33613368f45ae8f10e0d79
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index e9e5cef..9dfc80b 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -37,9 +37,6 @@
ADD_EXECUTABLE(helloworld_analytic_diff helloworld_analytic_diff.cc)
TARGET_LINK_LIBRARIES(helloworld_analytic_diff ceres)
-ADD_EXECUTABLE(powell powell.cc)
-TARGET_LINK_LIBRARIES(powell ceres)
-
ADD_EXECUTABLE(curve_fitting curve_fitting.cc)
TARGET_LINK_LIBRARIES(curve_fitting ceres)
@@ -55,6 +52,9 @@
TARGET_LINK_LIBRARIES(simple_bundle_adjuster ceres)
IF (GFLAGS)
+ ADD_EXECUTABLE(powell powell.cc)  # Built only with gflags: powell.cc now parses --minimizer.
+ TARGET_LINK_LIBRARIES(powell ceres)
+
ADD_EXECUTABLE(nist nist.cc)
TARGET_LINK_LIBRARIES(nist ceres)
diff --git a/examples/nist.cc b/examples/nist.cc
index 8e0f37e..1773a0f 100644
--- a/examples/nist.cc
+++ b/examples/nist.cc
@@ -81,6 +81,8 @@
DEFINE_string(nist_data_dir, "", "Directory containing the NIST non-linear"
"regression examples");
+DEFINE_string(minimizer, "trust_region",
+ "Minimizer type to use, choices are: line_search & trust_region");
DEFINE_string(trust_region_strategy, "levenberg_marquardt",
"Options are: levenberg_marquardt, dogleg");
DEFINE_string(dogleg, "traditional_dogleg",
@@ -90,6 +92,25 @@
"cgnr");
DEFINE_string(preconditioner, "jacobi", "Options are: "
"identity, jacobi");
+DEFINE_string(line_search, "armijo",
+ "Line search algorithm to use, choices are: armijo and wolfe.");
+DEFINE_string(line_search_direction, "lbfgs",
+ "Line search direction algorithm to use, choices: lbfgs, bfgs");
+DEFINE_int32(max_line_search_iterations, 20,
+ "Maximum number of iterations for each line search.");
+DEFINE_int32(max_line_search_restarts, 10,
+ "Maximum number of restarts of line search direction algorithm.");
+DEFINE_string(line_search_interpolation, "cubic",
+ "Degree of polynomial approximation in line search, "
+ "choices are: bisection, quadratic & cubic.");
+DEFINE_int32(lbfgs_rank, 20,
+ "Rank of L-BFGS inverse Hessian approximation in line search.");
+DEFINE_bool(approximate_eigenvalue_bfgs_scaling, false,
+ "Use approximate eigenvalue scaling in (L)BFGS line search.");
+DEFINE_double(sufficient_decrease, 1.0e-4,
+ "Line search Armijo sufficient (function) decrease factor.");
+DEFINE_double(sufficient_curvature_decrease, 0.9,
+ "Line search Wolfe sufficient curvature decrease factor.");
DEFINE_int32(num_iterations, 10000, "Number of iterations");
DEFINE_bool(nonmonotonic_steps, false, "Trust region algorithm can use"
" nonmonotic steps");
@@ -392,7 +413,7 @@
template <typename Model, int num_residuals, int num_parameters>
int RegressionDriver(const std::string& filename,
- const ceres::Solver::Options& options) {
+ const ceres::Solver::Options& options) {
NISTProblem nist_problem(FLAGS_nist_data_dir + filename);
CHECK_EQ(num_residuals, nist_problem.response_size());
CHECK_EQ(num_parameters, nist_problem.num_parameters());
@@ -446,18 +467,22 @@
++num_success;
}
- printf("start: %d status: %s lre: %4.1f initial cost: %e final cost:%e certified cost: %e\n",
+ printf("start: %d status: %s lre: %4.1f initial cost: %e final cost:%e "
+ "certified cost: %e total iterations: %d\n",
start + 1,
log_relative_error < kMinNumMatchingDigits ? "FAILURE" : "SUCCESS",
log_relative_error,
summary.initial_cost,
summary.final_cost,
- nist_problem.certified_cost());
+ nist_problem.certified_cost(),
+ (summary.num_successful_steps + summary.num_unsuccessful_steps));
}
return num_success;
}
void SetMinimizerOptions(ceres::Solver::Options* options) {
+ CHECK(ceres::StringToMinimizerType(FLAGS_minimizer,
+ &options->minimizer_type));
CHECK(ceres::StringToLinearSolverType(FLAGS_linear_solver,
&options->linear_solver_type));
CHECK(ceres::StringToPreconditionerType(FLAGS_preconditioner,
@@ -466,10 +491,28 @@
FLAGS_trust_region_strategy,
&options->trust_region_strategy_type));
CHECK(ceres::StringToDoglegType(FLAGS_dogleg, &options->dogleg_type));
+ CHECK(ceres::StringToLineSearchDirectionType(
+ FLAGS_line_search_direction,
+ &options->line_search_direction_type));
+ CHECK(ceres::StringToLineSearchType(FLAGS_line_search,
+ &options->line_search_type));
+ CHECK(ceres::StringToLineSearchInterpolationType(
+ FLAGS_line_search_interpolation,
+ &options->line_search_interpolation_type));
options->max_num_iterations = FLAGS_num_iterations;
options->use_nonmonotonic_steps = FLAGS_nonmonotonic_steps;
options->initial_trust_region_radius = FLAGS_initial_trust_region_radius;
+ options->max_lbfgs_rank = FLAGS_lbfgs_rank;
+ options->line_search_sufficient_function_decrease = FLAGS_sufficient_decrease;
+ options->line_search_sufficient_curvature_decrease =
+ FLAGS_sufficient_curvature_decrease;
+ options->max_num_line_search_step_size_iterations =
+ FLAGS_max_line_search_iterations;
+ options->max_num_line_search_direction_restarts =
+ FLAGS_max_line_search_restarts;
+ options->use_approximate_eigenvalue_bfgs_scaling =
+ FLAGS_approximate_eigenvalue_bfgs_scaling;
options->function_tolerance = 1e-18;
options->gradient_tolerance = 1e-18;
options->parameter_tolerance = 1e-18;
diff --git a/examples/powell.cc b/examples/powell.cc
index 4a41728..c0cba02 100644
--- a/examples/powell.cc
+++ b/examples/powell.cc
@@ -46,6 +46,7 @@
#include <vector>
#include "ceres/ceres.h"
+#include "gflags/gflags.h"
#include "glog/logging.h"
using ceres::AutoDiffCostFunction;
@@ -94,7 +95,11 @@
}
};
+DEFINE_string(minimizer, "trust_region",
+ "Minimizer type to use, choices are: line_search & trust_region");  // Parsed in main(); unrecognized values abort via LOG_IF(FATAL).
+
int main(int argc, char** argv) {
+ google::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
double x1 = 3.0;
@@ -119,23 +124,27 @@
NULL,
&x1, &x4);
- // Run the solver!
Solver::Options options;
- options.max_num_iterations = 30;
+ LOG_IF(FATAL, !ceres::StringToMinimizerType(FLAGS_minimizer,
+ &options.minimizer_type))
+ << "Invalid minimizer: " << FLAGS_minimizer
+ << ", valid options are: trust_region and line_search.";
+
+ options.max_num_iterations = 100;
options.linear_solver_type = ceres::DENSE_QR;
options.minimizer_progress_to_stdout = true;
- Solver::Summary summary;
-
std::cout << "Initial x1 = " << x1
<< ", x2 = " << x2
<< ", x3 = " << x3
<< ", x4 = " << x4
<< "\n";
+ // Run the solver!
+ Solver::Summary summary;
Solve(options, &problem, &summary);
- std::cout << summary.BriefReport() << "\n";
+ std::cout << summary.FullReport() << "\n";
std::cout << "Final x1 = " << x1
<< ", x2 = " << x2
<< ", x3 = " << x3