Adding Wolfe line search algorithm and full BFGS search direction options.
Change-Id: I9d3fb117805bdfa5bc33613368f45ae8f10e0d79
diff --git a/internal/ceres/line_search_minimizer.cc b/internal/ceres/line_search_minimizer.cc
index 24aada3..2cc89fa 100644
--- a/internal/ceres/line_search_minimizer.cc
+++ b/internal/ceres/line_search_minimizer.cc
@@ -160,6 +160,8 @@
line_search_direction_options.nonlinear_conjugate_gradient_type =
options.nonlinear_conjugate_gradient_type;
line_search_direction_options.max_lbfgs_rank = options.max_lbfgs_rank;
+ line_search_direction_options.use_approximate_eigenvalue_bfgs_scaling =
+ options.use_approximate_eigenvalue_bfgs_scaling;
scoped_ptr<LineSearchDirection> line_search_direction(
LineSearchDirection::Create(line_search_direction_options));
@@ -170,15 +172,32 @@
options.line_search_interpolation_type;
line_search_options.min_step_size = options.min_line_search_step_size;
line_search_options.sufficient_decrease =
- options.armijo_sufficient_decrease;
- line_search_options.min_relative_step_size_change =
- options.min_armijo_relative_step_size_change;
- line_search_options.max_relative_step_size_change =
- options.max_armijo_relative_step_size_change;
+ options.line_search_sufficient_function_decrease;
+ line_search_options.max_step_contraction =
+ options.max_line_search_step_contraction;
+ line_search_options.min_step_contraction =
+ options.min_line_search_step_contraction;
+ line_search_options.max_num_iterations =
+ options.max_num_line_search_step_size_iterations;
+ line_search_options.sufficient_curvature_decrease =
+ options.line_search_sufficient_curvature_decrease;
+ line_search_options.max_step_expansion =
+ options.max_line_search_step_expansion;
line_search_options.function = &line_search_function;
- ArmijoLineSearch line_search;
+ scoped_ptr<LineSearch>
+ line_search(LineSearch::Create(options.line_search_type,
+ line_search_options,
+ &summary->error));
+ if (line_search.get() == NULL) {
+ LOG(ERROR) << "Ceres bug: Unable to create a LineSearch object, please "
+ << "contact the developers!, error: " << summary->error;
+ summary->termination_type = DID_NOT_RUN;
+ return;
+ }
+
LineSearch::Summary line_search_summary;
+ int num_line_search_direction_restarts = 0;
while (true) {
if (!RunCallbacks(options.callbacks, iteration_summary, summary)) {
@@ -215,9 +234,36 @@
&current_state.search_direction);
}
- if (!line_search_status) {
- LOG(WARNING) << "Line search direction computation failed. "
- "Resorting to steepest descent.";
+ if (!line_search_status &&
+ num_line_search_direction_restarts >=
+ options.max_num_line_search_direction_restarts) {
+ // Line search direction failed to generate a new direction, and we
+ // have already reached our specified maximum number of restarts,
+ // terminate optimization.
+ summary->error =
+ StringPrintf("Line search direction failure: specified "
+ "max_num_line_search_direction_restarts: %d reached.",
+ options.max_num_line_search_direction_restarts);
+ LOG(WARNING) << summary->error << " terminating optimization.";
+ summary->termination_type = NUMERICAL_FAILURE;
+ break;
+
+ } else if (!line_search_status) {
+ // Restart line search direction with gradient descent on first iteration
+ // as we have not yet reached our maximum number of restarts.
+ CHECK_LT(num_line_search_direction_restarts,
+ options.max_num_line_search_direction_restarts);
+
+ ++num_line_search_direction_restarts;
+ LOG(WARNING)
+ << "Line search direction algorithm: "
+ << LineSearchDirectionTypeToString(options.line_search_direction_type)
+ << ", failed to produce a valid new direction at iteration: "
+ << iteration_summary.iteration << ". Restarting, number of "
+ << "restarts: " << num_line_search_direction_restarts << " / "
+ << options.max_num_line_search_direction_restarts << " [max].";
+ line_search_direction.reset(
+ LineSearchDirection::Create(line_search_direction_options));
current_state.search_direction = -current_state.gradient;
}
@@ -227,16 +273,34 @@
// TODO(sameeragarwal): Refactor this into its own object and add
// explanations for the various choices.
- const double initial_step_size = (iteration_summary.iteration == 1)
+ //
+ // Note that we use !line_search_status to ensure that we treat cases when
+ // we restarted the line search direction equivalently to the first
+ // iteration.
+ const double initial_step_size =
+ (iteration_summary.iteration == 1 || !line_search_status)
? min(1.0, 1.0 / current_state.gradient_max_norm)
: min(1.0, 2.0 * (current_state.cost - previous_state.cost) /
current_state.directional_derivative);
+ // By definition, a line search should only ever move forwards along the
+ // specified search direction; the most likely cause of this being violated
+ // is a numerical failure in the line search direction calculation.
+ if (initial_step_size < 0.0) {
+ summary->error =
+ StringPrintf("Numerical failure in line search, initial_step_size is "
+ "negative: %.5e, directional_derivative: %.5e, "
+ "(current_cost - previous_cost): %.5e",
+ initial_step_size, current_state.directional_derivative,
+ (current_state.cost - previous_state.cost));
+ LOG(WARNING) << summary->error;
+ summary->termination_type = NUMERICAL_FAILURE;
+ break;
+ }
- line_search.Search(line_search_options,
- initial_step_size,
- current_state.cost,
- current_state.directional_derivative,
- &line_search_summary);
+ line_search->Search(initial_step_size,
+ current_state.cost,
+ current_state.directional_derivative,
+ &line_search_summary);
current_state.step_size = line_search_summary.optimal_step_size;
delta = current_state.step_size * current_state.search_direction;
@@ -282,7 +346,11 @@
iteration_summary.step_norm = delta.norm();
iteration_summary.step_size = current_state.step_size;
iteration_summary.line_search_function_evaluations =
- line_search_summary.num_evaluations;
+ line_search_summary.num_function_evaluations;
+ iteration_summary.line_search_gradient_evaluations =
+ line_search_summary.num_gradient_evaluations;
+ iteration_summary.line_search_iterations =
+ line_search_summary.num_iterations;
iteration_summary.iteration_time_in_seconds =
WallTimeInSeconds() - iteration_start_time;
iteration_summary.cumulative_time_in_seconds =