blob: 020abfa6bae872ec604a6ff84a88d70580b37943 [file] [log] [blame]
Keir Mierle8ebb0732012-04-30 23:09:08 -07001// Ceres Solver - A fast non-linear least squares minimizer
2// Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
3// http://code.google.com/p/ceres-solver/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7//
8// * Redistributions of source code must retain the above copyright notice,
9// this list of conditions and the following disclaimer.
10// * Redistributions in binary form must reproduce the above copyright notice,
11// this list of conditions and the following disclaimer in the documentation
12// and/or other materials provided with the distribution.
13// * Neither the name of Google Inc. nor the names of its contributors may be
14// used to endorse or promote products derived from this software without
15// specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
28//
29// Author: keir@google.com (Keir Mierle)
30//
31// This tests the Levenberg-Marquardt loop using a direct Evaluator
32// implementation, rather than having a test that goes through all the Program
33// and Problem machinery.
34
35#include <cmath>
36#include "ceres/dense_qr_solver.h"
37#include "ceres/dense_sparse_matrix.h"
38#include "ceres/evaluator.h"
39#include "ceres/levenberg_marquardt.h"
40#include "ceres/linear_solver.h"
41#include "ceres/minimizer.h"
42#include "ceres/internal/port.h"
43#include "gtest/gtest.h"
44
45namespace ceres {
46namespace internal {
47
48// Templated Evaluator for Powell's function. The template parameters
49// indicate which of the four variables/columns of the jacobian are
50// active. This is equivalent to constructing a problem and using the
51// SubsetLocalParameterization. This allows us to test the support for
52// the Evaluator::Plus operation besides checking for the basic
53// performance of the LevenbergMarquardt algorithm.
54template <bool col1, bool col2, bool col3, bool col4>
55class PowellEvaluator2 : public Evaluator {
56 public:
57 PowellEvaluator2()
58 : num_active_cols_(
59 (col1 ? 1 : 0) +
60 (col2 ? 1 : 0) +
61 (col3 ? 1 : 0) +
62 (col4 ? 1 : 0)) {
63 VLOG(1) << "Columns: "
64 << col1 << " "
65 << col2 << " "
66 << col3 << " "
67 << col4;
68 }
69
70 virtual ~PowellEvaluator2() {}
71
72 // Implementation of Evaluator interface.
73 virtual SparseMatrix* CreateJacobian() const {
74 CHECK(col1 || col2 || col3 || col4);
75 DenseSparseMatrix* dense_jacobian =
76 new DenseSparseMatrix(NumResiduals(), NumEffectiveParameters());
77 dense_jacobian->SetZero();
78 return dense_jacobian;
79 }
80
81 virtual bool Evaluate(const double* state,
82 double* cost,
83 double* residuals,
84 SparseMatrix* jacobian) {
85 double x1 = state[0];
86 double x2 = state[1];
87 double x3 = state[2];
88 double x4 = state[3];
89
90 VLOG(1) << "State: "
91 << "x1=" << x1 << ", "
92 << "x2=" << x2 << ", "
93 << "x3=" << x3 << ", "
94 << "x4=" << x4 << ".";
95
96 double f1 = x1 + 10.0 * x2;
97 double f2 = sqrt(5.0) * (x3 - x4);
98 double f3 = pow(x2 - 2.0 * x3, 2.0);
99 double f4 = sqrt(10.0) * pow(x1 - x4, 2.0);
100
101 VLOG(1) << "Function: "
102 << "f1=" << f1 << ", "
103 << "f2=" << f2 << ", "
104 << "f3=" << f3 << ", "
105 << "f4=" << f4 << ".";
106
107 *cost = (f1*f1 + f2*f2 + f3*f3 + f4*f4) / 2.0;
108
109 VLOG(1) << "Cost: " << *cost;
110
111 if (residuals != NULL) {
112 residuals[0] = f1;
113 residuals[1] = f2;
114 residuals[2] = f3;
115 residuals[3] = f4;
116 }
117
118 if (jacobian != NULL) {
119 DenseSparseMatrix* dense_jacobian;
120 dense_jacobian = down_cast<DenseSparseMatrix*>(jacobian);
121 dense_jacobian->SetZero();
122
123 AlignedMatrixRef jacobian_matrix = dense_jacobian->mutable_matrix();
124 CHECK_EQ(jacobian_matrix.cols(), num_active_cols_);
125
126 int column_index = 0;
127 if (col1) {
128 jacobian_matrix.col(column_index++) <<
129 1.0,
130 0.0,
131 0.0,
132 sqrt(10) * 2.0 * (x1 - x4) * (1.0 - x4);
133 }
134 if (col2) {
135 jacobian_matrix.col(column_index++) <<
136 10.0,
137 0.0,
138 2.0*(x2 - 2.0*x3)*(1.0 - 2.0*x3),
139 0.0;
140 }
141
142 if (col3) {
143 jacobian_matrix.col(column_index++) <<
144 0.0,
145 sqrt(5.0),
146 2.0*(x2 - 2.0*x3)*(x2 - 2.0),
147 0.0;
148 }
149
150 if (col4) {
151 jacobian_matrix.col(column_index++) <<
152 0.0,
153 -sqrt(5.0),
154 0.0,
155 sqrt(10) * 2.0 * (x1 - x4) * (x1 - 1.0);
156 }
157 VLOG(1) << "\n" << jacobian_matrix;
158 }
159 return true;
160 }
161
162 virtual bool Plus(const double* state,
163 const double* delta,
164 double* state_plus_delta) const {
165 int delta_index = 0;
166 state_plus_delta[0] = (col1 ? state[0] + delta[delta_index++] : state[0]);
167 state_plus_delta[1] = (col2 ? state[1] + delta[delta_index++] : state[1]);
168 state_plus_delta[2] = (col3 ? state[2] + delta[delta_index++] : state[2]);
169 state_plus_delta[3] = (col4 ? state[3] + delta[delta_index++] : state[3]);
170 return true;
171 }
172
173 virtual int NumEffectiveParameters() const { return num_active_cols_; }
174 virtual int NumParameters() const { return 4; }
175 virtual int NumResiduals() const { return 4; }
176
177 private:
178 const int num_active_cols_;
179};
180
181// Templated function to hold a subset of the columns fixed and check
182// if the solver converges to the optimal values or not.
183template<bool col1, bool col2, bool col3, bool col4>
184void IsSolveSuccessful() {
185 LevenbergMarquardt lm;
186 Solver::Options solver_options;
187 Minimizer::Options minimizer_options(solver_options);
188 minimizer_options.gradient_tolerance = 1e-26;
189 minimizer_options.function_tolerance = 1e-26;
190 minimizer_options.parameter_tolerance = 1e-26;
191 LinearSolver::Options linear_solver_options;
192 DenseQRSolver linear_solver(linear_solver_options);
193
194 double initial_parameters[4] = { 3, -1, 0, 1.0 };
195 double final_parameters[4] = { -1.0, -1.0, -1.0, -1.0 };
196
197 // If the column is inactive, then set its value to the optimal
198 // value.
199 initial_parameters[0] = (col1 ? initial_parameters[0] : 0.0);
200 initial_parameters[1] = (col2 ? initial_parameters[1] : 0.0);
201 initial_parameters[2] = (col3 ? initial_parameters[2] : 0.0);
202 initial_parameters[3] = (col4 ? initial_parameters[3] : 0.0);
203
204 PowellEvaluator2<col1, col2, col3, col4> powell_evaluator;
205
206 Solver::Summary summary;
207 lm.Minimize(minimizer_options,
208 &powell_evaluator,
209 &linear_solver,
210 initial_parameters,
211 final_parameters,
212 &summary);
213
214 // The minimum is at x1 = x2 = x3 = x4 = 0.
215 EXPECT_NEAR(0.0, final_parameters[0], 0.001);
216 EXPECT_NEAR(0.0, final_parameters[1], 0.001);
217 EXPECT_NEAR(0.0, final_parameters[2], 0.001);
218 EXPECT_NEAR(0.0, final_parameters[3], 0.001);
219};
220
221TEST(LevenbergMarquardt, PowellsSingularFunction) {
222 // This case is excluded because this has a local minimum and does
223 // not find the optimum. This should not affect the correctness of
224 // this test since we are testing all the other 14 combinations of
225 // column activations.
226
227 // IsSolveSuccessful<true, true, false, true>();
228
229 IsSolveSuccessful<true, true, true, true>();
230 IsSolveSuccessful<true, true, true, false>();
231 IsSolveSuccessful<true, false, true, true>();
232 IsSolveSuccessful<false, true, true, true>();
233 IsSolveSuccessful<true, true, false, false>();
234 IsSolveSuccessful<true, false, true, false>();
235 IsSolveSuccessful<false, true, true, false>();
236 IsSolveSuccessful<true, false, false, true>();
237 IsSolveSuccessful<false, true, false, true>();
238 IsSolveSuccessful<false, false, true, true>();
239 IsSolveSuccessful<true, false, false, false>();
240 IsSolveSuccessful<false, true, false, false>();
241 IsSolveSuccessful<false, false, true, false>();
242 IsSolveSuccessful<false, false, false, true>();
243}
244
245
246} // namespace internal
247} // namespace ceres