blob: f44160748d205f8e603f0cbc30d53ea5657597e0 [file] [log] [blame]
Sameer Agarwal9883fc32012-11-30 12:32:43 -08001// Ceres Solver - A fast non-linear least squares minimizer
2// Copyright 2012 Google Inc. All rights reserved.
3// http://code.google.com/p/ceres-solver/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7//
8// * Redistributions of source code must retain the above copyright notice,
9// this list of conditions and the following disclaimer.
10// * Redistributions in binary form must reproduce the above copyright notice,
11// this list of conditions and the following disclaimer in the documentation
12// and/or other materials provided with the distribution.
13// * Neither the name of Google Inc. nor the names of its contributors may be
14// used to endorse or promote products derived from this software without
15// specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
28//
29// Author: sameeragarwal@google.com (Sameer Agarwal)
30
#include "ceres/line_search_direction.h"

#include <cmath>

#include "ceres/internal/eigen.h"
#include "ceres/line_search_minimizer.h"
#include "ceres/low_rank_inverse_hessian.h"
#include "glog/logging.h"
35#include "glog/logging.h"
36
37namespace ceres {
38namespace internal {
39
40class SteepestDescent : public LineSearchDirection {
41 public:
42 virtual ~SteepestDescent() {}
43 bool NextDirection(const LineSearchMinimizer::State& previous,
44 const LineSearchMinimizer::State& current,
45 Vector* search_direction) {
46 *search_direction = -current.gradient;
47 return true;
48 }
49};
50
51class NonlinearConjugateGradient : public LineSearchDirection {
52 public:
53 NonlinearConjugateGradient(const NonlinearConjugateGradientType type,
54 const double function_tolerance)
55 : type_(type),
56 function_tolerance_(function_tolerance) {
57 }
58
59 bool NextDirection(const LineSearchMinimizer::State& previous,
60 const LineSearchMinimizer::State& current,
61 Vector* search_direction) {
62 double beta = 0.0;
63 Vector gradient_change;
64 switch (type_) {
65 case FLETCHER_REEVES:
66 beta = current.gradient_squared_norm / previous.gradient_squared_norm;
67 break;
68 case POLAK_RIBIRERE:
69 gradient_change = current.gradient - previous.gradient;
70 beta = (current.gradient.dot(gradient_change) /
71 previous.gradient_squared_norm);
72 break;
73 case HESTENES_STIEFEL:
74 gradient_change = current.gradient - previous.gradient;
75 beta = (current.gradient.dot(gradient_change) /
76 previous.search_direction.dot(gradient_change));
77 break;
78 default:
79 LOG(FATAL) << "Unknown nonlinear conjugate gradient type: " << type_;
80 }
81
82 *search_direction = -current.gradient + beta * previous.search_direction;
83 const double directional_derivative = current. gradient.dot(*search_direction);
84 if (directional_derivative > -function_tolerance_) {
85 LOG(WARNING) << "Restarting non-linear conjugate gradients: "
86 << directional_derivative;
87 *search_direction = -current.gradient;
88 };
89
90 return true;
91 }
92
93 private:
94 const NonlinearConjugateGradientType type_;
95 const double function_tolerance_;
96};
97
98class LBFGS : public LineSearchDirection {
99 public:
100 LBFGS(const int num_parameters, const int max_lbfgs_rank)
101 : low_rank_inverse_hessian_(num_parameters, max_lbfgs_rank) {}
102
103 virtual ~LBFGS() {}
104
105 bool NextDirection(const LineSearchMinimizer::State& previous,
106 const LineSearchMinimizer::State& current,
107 Vector* search_direction) {
108 low_rank_inverse_hessian_.Update(
109 previous.search_direction * previous.step_size,
110 current.gradient - previous.gradient);
111 search_direction->setZero();
112 low_rank_inverse_hessian_.RightMultiply(current.gradient.data(),
113 search_direction->data());
114 *search_direction *= -1.0;
115 return true;
116 }
117
118 private:
119 LowRankInverseHessian low_rank_inverse_hessian_;
120};
121
122LineSearchDirection*
123LineSearchDirection::Create(LineSearchDirection::Options& options) {
124 if (options.type == STEEPEST_DESCENT) {
125 return new SteepestDescent;
126 }
127
128 if (options.type == NONLINEAR_CONJUGATE_GRADIENT) {
129 return new NonlinearConjugateGradient(
130 options.nonlinear_conjugate_gradient_type,
131 options.function_tolerance);
132 }
133
134 if (options.type == ceres::LBFGS) {
135 return new ceres::internal::LBFGS(options.num_parameters,
136 options.max_lbfgs_rank);
137 }
138
139 LOG(ERROR) << "Unknown line search direction type: " << options.type;
140 return NULL;
141}
142
143} // namespace internal
144} // namespace ceres