blob: 71d49d51d9bbd356824d0996b0a29b4b072c3e53 [file] [log] [blame]
Keir Mierle8ebb0732012-04-30 23:09:08 -07001// Ceres Solver - A fast non-linear least squares minimizer
2// Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
3// http://code.google.com/p/ceres-solver/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7//
8// * Redistributions of source code must retain the above copyright notice,
9// this list of conditions and the following disclaimer.
10// * Redistributions in binary form must reproduce the above copyright notice,
11// this list of conditions and the following disclaimer in the documentation
12// and/or other materials provided with the distribution.
13// * Neither the name of Google Inc. nor the names of its contributors may be
14// used to endorse or promote products derived from this software without
15// specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
28//
29// Author: sameeragarwal@google.com (Sameer Agarwal)
30//
31// A preconditioned conjugate gradients solver
32// (ConjugateGradientsSolver) for positive semidefinite linear
33// systems.
34//
35// We have also augmented the termination criterion used by this
36// solver to support not just residual based termination but also
37// termination based on decrease in the value of the quadratic model
38// that CG optimizes.
39
40#include "ceres/conjugate_gradients_solver.h"
41
42#include <cmath>
43#include <cstddef>
44#include <glog/logging.h>
45#include "ceres/linear_operator.h"
46#include "ceres/internal/eigen.h"
47#include "ceres/types.h"
48
49namespace ceres {
50namespace internal {
namespace {

// Returns true when x is exactly zero (of either sign) or is +/- infinity.
// NaN is neither, so it yields false.
//
// std::isinf is spelled with its namespace on purpose: <cmath> is only
// required to declare the classification functions inside namespace std, and
// the unqualified name does not resolve on every standard library.
bool IsZeroOrInfinity(double x) {
  return (x == 0.0) || std::isinf(x);
}

// Constant used in the MATLAB implementation. The value is (approximately)
// the double precision machine epsilon, 2^-52 ~= 2.2204e-16, i.e. twice the
// unit roundoff.
const double kEpsilon = 2.2204e-16;

}  // namespace
61
62ConjugateGradientsSolver::ConjugateGradientsSolver(
63 const LinearSolver::Options& options)
64 : options_(options) {
65}
66
67LinearSolver::Summary ConjugateGradientsSolver::Solve(
68 LinearOperator* A,
69 const double* b,
70 const LinearSolver::PerSolveOptions& per_solve_options,
71 double* x) {
72 CHECK_NOTNULL(A);
73 CHECK_NOTNULL(x);
74 CHECK_NOTNULL(b);
75 CHECK_EQ(A->num_rows(), A->num_cols());
76
77 LinearSolver::Summary summary;
78 summary.termination_type = MAX_ITERATIONS;
79 summary.num_iterations = 0;
80
81 int num_cols = A->num_cols();
82 VectorRef xref(x, num_cols);
83 ConstVectorRef bref(b, num_cols);
84
85 double norm_b = bref.norm();
86 if (norm_b == 0.0) {
87 xref.setZero();
88 summary.termination_type = TOLERANCE;
89 return summary;
90 }
91
92 Vector r(num_cols);
93 Vector p(num_cols);
94 Vector z(num_cols);
95 Vector tmp(num_cols);
96
97 double tol_r = per_solve_options.r_tolerance * norm_b;
98
99 tmp.setZero();
100 A->RightMultiply(x, tmp.data());
101 r = bref - tmp;
102 double norm_r = r.norm();
103
104 if (norm_r <= tol_r) {
105 summary.termination_type = TOLERANCE;
106 return summary;
107 }
108
109 double rho = 1.0;
110
111 // Initial value of the quadratic model Q = x'Ax - 2 * b'x.
112 double Q0 = -1.0 * xref.dot(bref + r);
113
114 for (summary.num_iterations = 1;
115 summary.num_iterations < options_.max_num_iterations;
116 ++summary.num_iterations) {
117 VLOG(2) << "cg iteration " << summary.num_iterations;
118
119 // Apply preconditioner
120 if (per_solve_options.preconditioner != NULL) {
121 z.setZero();
122 per_solve_options.preconditioner->RightMultiply(r.data(), z.data());
123 } else {
124 z = r;
125 }
126
127 double last_rho = rho;
128 rho = r.dot(z);
129
130 if (IsZeroOrInfinity(rho)) {
131 LOG(ERROR) << "Numerical failure. rho = " << rho;
132 summary.termination_type = FAILURE;
133 break;
134 };
135
136 if (summary.num_iterations == 1) {
137 p = z;
138 } else {
139 double beta = rho / last_rho;
140 if (IsZeroOrInfinity(beta)) {
141 LOG(ERROR) << "Numerical failure. beta = " << beta;
142 summary.termination_type = FAILURE;
143 break;
144 }
145 p = z + beta * p;
146 }
147
148 Vector& q = z;
149 q.setZero();
150 A->RightMultiply(p.data(), q.data());
151 double pq = p.dot(q);
152
153 if ((pq <= 0) || isinf(pq)) {
154 LOG(ERROR) << "Numerical failure. pq = " << pq;
155 summary.termination_type = FAILURE;
156 break;
157 }
158
159 double alpha = rho / pq;
160 if (isinf(alpha)) {
161 LOG(ERROR) << "Numerical failure. alpha " << alpha;
162 summary.termination_type = FAILURE;
163 break;
164 }
165
166 xref = xref + alpha * p;
167
168 // Ideally we would just use the update r = r - alpha*q to keep
169 // track of the residual vector. However this estimate tends to
170 // drift over time due to round off errors. Thus every
171 // residual_reset_period iterations, we calculate the residual as
172 // r = b - Ax. We do not do this every iteration because this
173 // requires an additional matrix vector multiply which would
174 // double the complexity of the CG algorithm.
175 if (summary.num_iterations % options_.residual_reset_period == 0) {
176 tmp.setZero();
177 A->RightMultiply(x, tmp.data());
178 r = bref - tmp;
179 } else {
180 r = r - alpha * q;
181 }
182
183 // Quadratic model based termination.
184 // Q1 = x'Ax - 2 * b' x.
185 double Q1 = -1.0 * xref.dot(bref + r);
186
187 // For PSD matrices A, let
188 //
189 // Q(x) = x'Ax - 2b'x
190 //
191 // be the cost of the quadratic function defined by A and b. Then,
192 // the solver terminates at iteration i if
193 //
194 // i * (Q(x_i) - Q(x_i-1)) / Q(x_i) < q_tolerance.
195 //
196 // This termination criterion is more useful when using CG to
197 // solve the Newton step. This particular convergence test comes
198 // from Stephen Nash's work on truncated Newton
199 // methods. References:
200 //
Keir Mierle0a359d62012-05-05 20:33:46 -0700201 // 1. Stephen G. Nash & Ariela Sofer, Assessing A Search
202 // Direction Within A Truncated Newton Method, Operation
203 // Research Letters 9(1990) 219-221.
204 //
205 // 2. Stephen G. Nash, A Survey of Truncated Newton Methods,
206 // Journal of Computational and Applied Mathematics,
207 // 124(1-2), 45-59, 2000.
Keir Mierle8ebb0732012-04-30 23:09:08 -0700208 //
Keir Mierle8ebb0732012-04-30 23:09:08 -0700209 double zeta = summary.num_iterations * (Q1 - Q0) / Q1;
210 VLOG(2) << "Q termination: zeta " << zeta
211 << " " << per_solve_options.q_tolerance;
212 if (zeta < per_solve_options.q_tolerance) {
213 summary.termination_type = TOLERANCE;
214 break;
215 }
216 Q0 = Q1;
217
218 // Residual based termination.
219 norm_r = r. norm();
220 VLOG(2) << "R termination: norm_r " << norm_r
221 << " " << tol_r;
222 if (norm_r <= tol_r) {
223 summary.termination_type = TOLERANCE;
224 break;
225 }
226 }
227
228 return summary;
229};
230
231} // namespace internal
232} // namespace ceres