blob: 9e3d8bdbb9827ee93435d90f0e0ab0d3b719ed14 [file] [log] [blame]
Keir Mierle8ebb0732012-04-30 23:09:08 -07001// Ceres Solver - A fast non-linear least squares minimizer
2// Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
3// http://code.google.com/p/ceres-solver/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7//
8// * Redistributions of source code must retain the above copyright notice,
9// this list of conditions and the following disclaimer.
10// * Redistributions in binary form must reproduce the above copyright notice,
11// this list of conditions and the following disclaimer in the documentation
12// and/or other materials provided with the distribution.
13// * Neither the name of Google Inc. nor the names of its contributors may be
14// used to endorse or promote products derived from this software without
15// specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
28//
29// Author: sameeragarwal@google.com (Sameer Agarwal)
30
#include "ceres/linear_least_squares_problems.h"

#include <algorithm>
#include <string>
#include <vector>
#include <glog/logging.h>
#include "ceres/block_sparse_matrix.h"
#include "ceres/block_structure.h"
#include "ceres/compressed_row_sparse_matrix.h"
#include "ceres/file.h"
#include "ceres/matrix_proto.h"
#include "ceres/triplet_sparse_matrix.h"
#include "ceres/internal/scoped_ptr.h"
#include "ceres/types.h"
44
45namespace ceres {
46namespace internal {
47
48LinearLeastSquaresProblem* CreateLinearLeastSquaresProblemFromId(int id) {
49 switch (id) {
50 case 0:
51 return LinearLeastSquaresProblem0();
52 case 1:
53 return LinearLeastSquaresProblem1();
54 case 2:
55 return LinearLeastSquaresProblem2();
56 case 3:
57 return LinearLeastSquaresProblem3();
58 default:
59 LOG(FATAL) << "Unknown problem id requested " << id;
60 }
61}
62
63#ifndef CERES_DONT_HAVE_PROTOCOL_BUFFERS
64LinearLeastSquaresProblem* CreateLinearLeastSquaresProblemFromFile(
65 const string& filename) {
66 LinearLeastSquaresProblemProto problem_proto;
67 {
68 string serialized_proto;
69 ReadFileToStringOrDie(filename, &serialized_proto);
70 CHECK(problem_proto.ParseFromString(serialized_proto));
71 }
72
73 LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;
74 const SparseMatrixProto& A = problem_proto.a();
75
76 if (A.has_block_matrix()) {
77 problem->A.reset(new BlockSparseMatrix(A));
78 } else if (A.has_triplet_matrix()) {
79 problem->A.reset(new TripletSparseMatrix(A));
80 } else {
81 problem->A.reset(new CompressedRowSparseMatrix(A));
82 }
83
84 if (problem_proto.b_size() > 0) {
85 problem->b.reset(new double[problem_proto.b_size()]);
86 for (int i = 0; i < problem_proto.b_size(); ++i) {
87 problem->b[i] = problem_proto.b(i);
88 }
89 }
90
91 if (problem_proto.d_size() > 0) {
92 problem->D.reset(new double[problem_proto.d_size()]);
93 for (int i = 0; i < problem_proto.d_size(); ++i) {
94 problem->D[i] = problem_proto.d(i);
95 }
96 }
97
98 if (problem_proto.d_size() > 0) {
99 if (problem_proto.x_size() > 0) {
100 problem->x_D.reset(new double[problem_proto.x_size()]);
101 for (int i = 0; i < problem_proto.x_size(); ++i) {
102 problem->x_D[i] = problem_proto.x(i);
103 }
104 }
105 } else {
106 if (problem_proto.x_size() > 0) {
107 problem->x.reset(new double[problem_proto.x_size()]);
108 for (int i = 0; i < problem_proto.x_size(); ++i) {
109 problem->x[i] = problem_proto.x(i);
110 }
111 }
112 }
113
114 problem->num_eliminate_blocks = 0;
115 if (problem_proto.has_num_eliminate_blocks()) {
116 problem->num_eliminate_blocks = problem_proto.num_eliminate_blocks();
117 }
118
119 return problem;
120}
121#else
122LinearLeastSquaresProblem* CreateLinearLeastSquaresProblemFromFile(
123 const string& filename) {
124 LOG(FATAL)
125 << "Loading a least squares problem from disk requires "
126 << "Ceres to be built with Protocol Buffers support.";
127 return NULL;
128}
129#endif // CERES_DONT_HAVE_PROTOCOL_BUFFERS
130
/*
A = [1   2
     3   4
     6 -10]

b = [  8
      18
     -18]

x = [2
     3]

D = [1
     2]

x_D = [1.78448275
       2.82327586]

x solves A x = b exactly. x_D is the minimizer of
|A x - b|^2 + |diag(D) x|^2, i.e. the least squares solution of the
augmented system [A; diag(D)] x = [b; 0].
*/
149LinearLeastSquaresProblem* LinearLeastSquaresProblem0() {
150 LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;
151
152 TripletSparseMatrix* A = new TripletSparseMatrix(3, 2, 6);
153 problem->b.reset(new double[3]);
154 problem->D.reset(new double[2]);
155
156 problem->x.reset(new double[2]);
157 problem->x_D.reset(new double[2]);
158
159 int* Ai = A->mutable_rows();
160 int* Aj = A->mutable_cols();
161 double* Ax = A->mutable_values();
162
163 int counter = 0;
164 for (int i = 0; i < 3; ++i) {
165 for (int j = 0; j< 2; ++j) {
166 Ai[counter]=i;
167 Aj[counter]=j;
168 ++counter;
169 }
170 };
171
172 Ax[0] = 1.;
173 Ax[1] = 2.;
174 Ax[2] = 3.;
175 Ax[3] = 4.;
176 Ax[4] = 6;
177 Ax[5] = -10;
178 A->set_num_nonzeros(6);
179 problem->A.reset(A);
180
181 problem->b[0] = 8;
182 problem->b[1] = 18;
183 problem->b[2] = -18;
184
185 problem->x[0] = 2.0;
186 problem->x[1] = 3.0;
187
188 problem->D[0] = 1;
189 problem->D[1] = 2;
190
191 problem->x_D[0] = 1.78448275;
192 problem->x_D[1] = 2.82327586;
193 return problem;
194}
195
196
/*
      A = [1 0 | 2 0 0
           3 0 | 0 4 0
           0 5 | 0 0 6
           0 7 | 8 0 0
           0 9 | 1 0 0
           0 0 | 1 1 1]

      b = [0
           1
           2
           3
           4
           5]

      c = A' * b = [ 3
                    67
                    33
                     9
                    17]

      A'A = [10   0   2  12   0
              0 155  65   0  30
              2  65  70   1   1
             12   0   1  17   1
              0  30   1   1  37]

      Eliminating the first two unknowns (the columns left of '|') from the
      normal equations A'A x = c gives the Schur complement S and the
      reduced right-hand side r for the remaining three unknowns:

      S = [ 42.3419  -1.4000 -11.5806
            -1.4000   2.6000   1.0000
           -11.5806   1.0000  31.1935]

      r = [ 4.3032
            5.4000
            4.0323]

      S\r = [ 0.2102
              2.1367
              0.1388]

      A\b = [-2.3061
              0.3172
              0.2102
              2.1367
              0.1388]

      The last three entries of A\b equal S\r. None of these values include
      the diagonal D.
*/
242// The following two functions create a TripletSparseMatrix and a
243// BlockSparseMatrix version of this problem.
244
245// TripletSparseMatrix version.
246LinearLeastSquaresProblem* LinearLeastSquaresProblem1() {
247 int num_rows = 6;
248 int num_cols = 5;
249
250 LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;
251 TripletSparseMatrix* A = new TripletSparseMatrix(num_rows,
252 num_cols,
253 num_rows * num_cols);
254 problem->b.reset(new double[num_rows]);
255 problem->D.reset(new double[num_cols]);
256 problem->num_eliminate_blocks = 2;
257
258 int* rows = A->mutable_rows();
259 int* cols = A->mutable_cols();
260 double* values = A->mutable_values();
261
262 int nnz = 0;
263
264 // Row 1
265 {
266 rows[nnz] = 0;
267 cols[nnz] = 0;
268 values[nnz++] = 1;
269
270 rows[nnz] = 0;
271 cols[nnz] = 2;
272 values[nnz++] = 2;
273 }
274
275 // Row 2
276 {
277 rows[nnz] = 1;
278 cols[nnz] = 0;
279 values[nnz++] = 3;
280
281 rows[nnz] = 1;
282 cols[nnz] = 3;
283 values[nnz++] = 4;
284 }
285
286 // Row 3
287 {
288 rows[nnz] = 2;
289 cols[nnz] = 1;
290 values[nnz++] = 5;
291
292 rows[nnz] = 2;
293 cols[nnz] = 4;
294 values[nnz++] = 6;
295 }
296
297 // Row 4
298 {
299 rows[nnz] = 3;
300 cols[nnz] = 1;
301 values[nnz++] = 7;
302
303 rows[nnz] = 3;
304 cols[nnz] = 2;
305 values[nnz++] = 8;
306 }
307
308 // Row 5
309 {
310 rows[nnz] = 4;
311 cols[nnz] = 1;
312 values[nnz++] = 9;
313
314 rows[nnz] = 4;
315 cols[nnz] = 2;
316 values[nnz++] = 1;
317 }
318
319 // Row 6
320 {
321 rows[nnz] = 5;
322 cols[nnz] = 2;
323 values[nnz++] = 1;
324
325 rows[nnz] = 5;
326 cols[nnz] = 3;
327 values[nnz++] = 1;
328
329 rows[nnz] = 5;
330 cols[nnz] = 4;
331 values[nnz++] = 1;
332 }
333
334 A->set_num_nonzeros(nnz);
335 CHECK(A->IsValid());
336
337 problem->A.reset(A);
338
339 for (int i = 0; i < num_cols; ++i) {
340 problem->D.get()[i] = 1;
341 }
342
343 for (int i = 0; i < num_rows; ++i) {
344 problem->b.get()[i] = i;
345 }
346
347 return problem;
348}
349
350// BlockSparseMatrix version
351LinearLeastSquaresProblem* LinearLeastSquaresProblem2() {
352 int num_rows = 6;
353 int num_cols = 5;
354
355 LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;
356
357 problem->b.reset(new double[num_rows]);
358 problem->D.reset(new double[num_cols]);
359 problem->num_eliminate_blocks = 2;
360
361 CompressedRowBlockStructure* bs = new CompressedRowBlockStructure;
362 scoped_array<double> values(new double[num_rows * num_cols]);
363
364 for (int c = 0; c < num_cols; ++c) {
365 bs->cols.push_back(Block());
366 bs->cols.back().size = 1;
367 bs->cols.back().position = c;
368 }
369
370 int nnz = 0;
371
372 // Row 1
373 {
374 values[nnz++] = 1;
375 values[nnz++] = 2;
376
377 bs->rows.push_back(CompressedRow());
378 CompressedRow& row = bs->rows.back();
379 row.block.size = 1;
380 row.block.position = 0;
381 row.cells.push_back(Cell(0, 0));
382 row.cells.push_back(Cell(2, 1));
383 }
384
385 // Row 2
386 {
387 values[nnz++] = 3;
388 values[nnz++] = 4;
389
390 bs->rows.push_back(CompressedRow());
391 CompressedRow& row = bs->rows.back();
392 row.block.size = 1;
393 row.block.position = 1;
394 row.cells.push_back(Cell(0, 2));
395 row.cells.push_back(Cell(3, 3));
396 }
397
398 // Row 3
399 {
400 values[nnz++] = 5;
401 values[nnz++] = 6;
402
403 bs->rows.push_back(CompressedRow());
404 CompressedRow& row = bs->rows.back();
405 row.block.size = 1;
406 row.block.position = 2;
407 row.cells.push_back(Cell(1, 4));
408 row.cells.push_back(Cell(4, 5));
409 }
410
411 // Row 4
412 {
413 values[nnz++] = 7;
414 values[nnz++] = 8;
415
416 bs->rows.push_back(CompressedRow());
417 CompressedRow& row = bs->rows.back();
418 row.block.size = 1;
419 row.block.position = 3;
420 row.cells.push_back(Cell(1, 6));
421 row.cells.push_back(Cell(2, 7));
422 }
423
424 // Row 5
425 {
426 values[nnz++] = 9;
427 values[nnz++] = 1;
428
429 bs->rows.push_back(CompressedRow());
430 CompressedRow& row = bs->rows.back();
431 row.block.size = 1;
432 row.block.position = 4;
433 row.cells.push_back(Cell(1, 8));
434 row.cells.push_back(Cell(2, 9));
435 }
436
437 // Row 6
438 {
439 values[nnz++] = 1;
440 values[nnz++] = 1;
441 values[nnz++] = 1;
442
443 bs->rows.push_back(CompressedRow());
444 CompressedRow& row = bs->rows.back();
445 row.block.size = 1;
446 row.block.position = 5;
447 row.cells.push_back(Cell(2, 10));
448 row.cells.push_back(Cell(3, 11));
449 row.cells.push_back(Cell(4, 12));
450 }
451
452 BlockSparseMatrix* A = new BlockSparseMatrix(bs);
453 memcpy(A->mutable_values(), values.get(), nnz * sizeof(*A->values()));
454
455 for (int i = 0; i < num_cols; ++i) {
456 problem->D.get()[i] = 1;
457 }
458
459 for (int i = 0; i < num_rows; ++i) {
460 problem->b.get()[i] = i;
461 }
462
463 problem->A.reset(A);
464
465 return problem;
466}
467
468
/*
      A = [1 0
           3 0
           0 5
           0 7
           0 9]

      b = [0
           1
           2
           3
           4]

      (5 rows and 2 columns, matching the code below.)
*/
484// BlockSparseMatrix version
485LinearLeastSquaresProblem* LinearLeastSquaresProblem3() {
486 int num_rows = 5;
487 int num_cols = 2;
488
489 LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;
490
491 problem->b.reset(new double[num_rows]);
492 problem->D.reset(new double[num_cols]);
493 problem->num_eliminate_blocks = 2;
494
495 CompressedRowBlockStructure* bs = new CompressedRowBlockStructure;
496 scoped_array<double> values(new double[num_rows * num_cols]);
497
498 for (int c = 0; c < num_cols; ++c) {
499 bs->cols.push_back(Block());
500 bs->cols.back().size = 1;
501 bs->cols.back().position = c;
502 }
503
504 int nnz = 0;
505
506 // Row 1
507 {
508 values[nnz++] = 1;
509 bs->rows.push_back(CompressedRow());
510 CompressedRow& row = bs->rows.back();
511 row.block.size = 1;
512 row.block.position = 0;
513 row.cells.push_back(Cell(0, 0));
514 }
515
516 // Row 2
517 {
518 values[nnz++] = 3;
519 bs->rows.push_back(CompressedRow());
520 CompressedRow& row = bs->rows.back();
521 row.block.size = 1;
522 row.block.position = 1;
523 row.cells.push_back(Cell(0, 1));
524 }
525
526 // Row 3
527 {
528 values[nnz++] = 5;
529 bs->rows.push_back(CompressedRow());
530 CompressedRow& row = bs->rows.back();
531 row.block.size = 1;
532 row.block.position = 2;
533 row.cells.push_back(Cell(1, 2));
534 }
535
536 // Row 4
537 {
538 values[nnz++] = 7;
539 bs->rows.push_back(CompressedRow());
540 CompressedRow& row = bs->rows.back();
541 row.block.size = 1;
542 row.block.position = 3;
543 row.cells.push_back(Cell(1, 3));
544 }
545
546 // Row 5
547 {
548 values[nnz++] = 9;
549 bs->rows.push_back(CompressedRow());
550 CompressedRow& row = bs->rows.back();
551 row.block.size = 1;
552 row.block.position = 4;
553 row.cells.push_back(Cell(1, 4));
554 }
555
556 BlockSparseMatrix* A = new BlockSparseMatrix(bs);
557 memcpy(A->mutable_values(), values.get(), nnz * sizeof(*A->values()));
558
559 for (int i = 0; i < num_cols; ++i) {
560 problem->D.get()[i] = 1;
561 }
562
563 for (int i = 0; i < num_rows; ++i) {
564 problem->b.get()[i] = i;
565 }
566
567 problem->A.reset(A);
568
569 return problem;
570}
571
572} // namespace internal
573} // namespace ceres