blob: 8fd568ffcc3db6977afa2f660992a00fc78e5a0b [file] [log] [blame]
Keir Mierle8ebb0732012-04-30 23:09:08 -07001// Ceres Solver - A fast non-linear least squares minimizer
2// Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
3// http://code.google.com/p/ceres-solver/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7//
8// * Redistributions of source code must retain the above copyright notice,
9// this list of conditions and the following disclaimer.
10// * Redistributions in binary form must reproduce the above copyright notice,
11// this list of conditions and the following disclaimer in the documentation
12// and/or other materials provided with the distribution.
13// * Neither the name of Google Inc. nor the names of its contributors may be
14// used to endorse or promote products derived from this software without
15// specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
28//
29// Author: sameeragarwal@google.com (Sameer Agarwal)
30
31#include "ceres/compressed_row_sparse_matrix.h"
32
33#include <algorithm>
34#include <vector>
35#include "ceres/matrix_proto.h"
36#include "ceres/internal/port.h"
37
38namespace ceres {
39namespace internal {
40namespace {
41
// Strict weak ordering on positions into a TripletSparseMatrix's parallel
// row/column index arrays. Position x precedes position y when its entry
// lies in an earlier row, or in the same row but an earlier column. The
// TripletSparseMatrix constructor sorts a permutation with this functor so
// that entries can be copied out in compressed-row order.
struct RowColLessThan {
  RowColLessThan(const int* rows, const int* cols)
      : rows(rows), cols(cols) {
  }

  bool operator()(const int x, const int y) const {
    const int row_x = rows[x];
    const int row_y = rows[y];
    if (row_x != row_y) {
      return row_x < row_y;
    }
    // Same row: the column index breaks the tie.
    return cols[x] < cols[y];
  }

  // Non-owning pointers to the row and column index arrays being compared.
  const int* rows;
  const int* cols;
};
59
60} // namespace
61
62// This constructor gives you a semi-initialized CompressedRowSparseMatrix.
63CompressedRowSparseMatrix::CompressedRowSparseMatrix(int num_rows,
64 int num_cols,
65 int max_num_nonzeros) {
66 num_rows_ = num_rows;
67 num_cols_ = num_cols;
68 max_num_nonzeros_ = max_num_nonzeros;
69
70 VLOG(1) << "# of rows: " << num_rows_ << " # of columns: " << num_cols_
71 << " max_num_nonzeros: " << max_num_nonzeros_
72 << ". Allocating " << (num_rows_ + 1) * sizeof(int) + // NOLINT
73 max_num_nonzeros_ * sizeof(int) + // NOLINT
74 max_num_nonzeros_ * sizeof(double); // NOLINT
75
76 rows_.reset(new int[num_rows_ + 1]);
77 cols_.reset(new int[max_num_nonzeros_]);
78 values_.reset(new double[max_num_nonzeros_]);
79
80 fill(rows_.get(), rows_.get() + num_rows_ + 1, 0);
81 fill(cols_.get(), cols_.get() + max_num_nonzeros_, 0);
82 fill(values_.get(), values_.get() + max_num_nonzeros_, 0);
83}
84
85CompressedRowSparseMatrix::CompressedRowSparseMatrix(
86 const TripletSparseMatrix& m) {
87 num_rows_ = m.num_rows();
88 num_cols_ = m.num_cols();
89 max_num_nonzeros_ = m.max_num_nonzeros();
90
91 // index is the list of indices into the TripletSparseMatrix m.
92 vector<int> index(m.num_nonzeros(), 0);
93 for (int i = 0; i < m.num_nonzeros(); ++i) {
94 index[i] = i;
95 }
96
97 // Sort index such that the entries of m are ordered by row and ties
98 // are broken by column.
99 sort(index.begin(), index.end(), RowColLessThan(m.rows(), m.cols()));
100
101 VLOG(1) << "# of rows: " << num_rows_ << " # of columns: " << num_cols_
102 << " max_num_nonzeros: " << max_num_nonzeros_
103 << ". Allocating " << (num_rows_ + 1) * sizeof(int) + // NOLINT
104 max_num_nonzeros_ * sizeof(int) + // NOLINT
105 max_num_nonzeros_ * sizeof(double); // NOLINT
106
107 rows_.reset(new int[num_rows_ + 1]);
108 cols_.reset(new int[max_num_nonzeros_]);
109 values_.reset(new double[max_num_nonzeros_]);
110
111 // rows_ = 0
112 fill(rows_.get(), rows_.get() + num_rows_ + 1, 0);
113
114 // Copy the contents of the cols and values array in the order given
115 // by index and count the number of entries in each row.
116 for (int i = 0; i < m.num_nonzeros(); ++i) {
117 const int idx = index[i];
118 ++rows_[m.rows()[idx] + 1];
119 cols_[i] = m.cols()[idx];
120 values_[i] = m.values()[idx];
121 }
122
123 // Find the cumulative sum of the row counts.
124 for (int i = 1; i < num_rows_ + 1; ++i) {
125 rows_[i] += rows_[i-1];
126 }
127
128 CHECK_EQ(num_nonzeros(), m.num_nonzeros());
129}
130
131#ifndef CERES_DONT_HAVE_PROTOCOL_BUFFERS
132CompressedRowSparseMatrix::CompressedRowSparseMatrix(
133 const SparseMatrixProto& outer_proto) {
134 CHECK(outer_proto.has_compressed_row_matrix());
135
136 const CompressedRowSparseMatrixProto& proto =
137 outer_proto.compressed_row_matrix();
138
139 num_rows_ = proto.num_rows();
140 num_cols_ = proto.num_cols();
141
142 rows_.reset(new int[proto.rows_size()]);
143 cols_.reset(new int[proto.cols_size()]);
144 values_.reset(new double[proto.values_size()]);
145
146 for (int i = 0; i < proto.rows_size(); ++i) {
147 rows_[i] = proto.rows(i);
148 }
149
150 CHECK_EQ(proto.rows_size(), num_rows_ + 1);
151 CHECK_EQ(proto.cols_size(), proto.values_size());
152 CHECK_EQ(proto.cols_size(), rows_[num_rows_]);
153
154 for (int i = 0; i < proto.cols_size(); ++i) {
155 cols_[i] = proto.cols(i);
156 values_[i] = proto.values(i);
157 }
158
159 max_num_nonzeros_ = proto.cols_size();
160}
161#endif
162
163CompressedRowSparseMatrix::CompressedRowSparseMatrix(const double* diagonal,
164 int num_rows) {
165 CHECK_NOTNULL(diagonal);
166
167 num_rows_ = num_rows;
168 num_cols_ = num_rows;
169 max_num_nonzeros_ = num_rows;
170
171 rows_.reset(new int[num_rows_ + 1]);
172 cols_.reset(new int[num_rows_]);
173 values_.reset(new double[num_rows_]);
174
175 rows_[0] = 0;
176 for (int i = 0; i < num_rows_; ++i) {
177 cols_[i] = i;
178 values_[i] = diagonal[i];
179 rows_[i + 1] = i + 1;
180 }
181
182 CHECK_EQ(num_nonzeros(), num_rows);
183}
184
185CompressedRowSparseMatrix::~CompressedRowSparseMatrix() {
186}
187
188void CompressedRowSparseMatrix::SetZero() {
189 fill(values_.get(), values_.get() + num_nonzeros(), 0.0);
190}
191
192void CompressedRowSparseMatrix::RightMultiply(const double* x,
193 double* y) const {
194 CHECK_NOTNULL(x);
195 CHECK_NOTNULL(y);
196
197 for (int r = 0; r < num_rows_; ++r) {
198 for (int idx = rows_[r]; idx < rows_[r + 1]; ++idx) {
199 y[r] += values_[idx] * x[cols_[idx]];
200 }
201 }
202}
203
204void CompressedRowSparseMatrix::LeftMultiply(const double* x, double* y) const {
205 CHECK_NOTNULL(x);
206 CHECK_NOTNULL(y);
207
208 for (int r = 0; r < num_rows_; ++r) {
209 for (int idx = rows_[r]; idx < rows_[r + 1]; ++idx) {
210 y[cols_[idx]] += values_[idx] * x[r];
211 }
212 }
213}
214
215void CompressedRowSparseMatrix::SquaredColumnNorm(double* x) const {
216 CHECK_NOTNULL(x);
217
218 fill(x, x + num_cols_, 0.0);
219 for (int idx = 0; idx < rows_[num_rows_]; ++idx) {
220 x[cols_[idx]] += values_[idx] * values_[idx];
221 }
222}
223
224void CompressedRowSparseMatrix::ScaleColumns(const double* scale) {
225 CHECK_NOTNULL(scale);
226
227 for (int idx = 0; idx < rows_[num_rows_]; ++idx) {
228 values_[idx] *= scale[cols_[idx]];
229 }
230}
231
232void CompressedRowSparseMatrix::ToDenseMatrix(Matrix* dense_matrix) const {
233 CHECK_NOTNULL(dense_matrix);
234 dense_matrix->resize(num_rows_, num_cols_);
235 dense_matrix->setZero();
236
237 for (int r = 0; r < num_rows_; ++r) {
238 for (int idx = rows_[r]; idx < rows_[r + 1]; ++idx) {
239 (*dense_matrix)(r, cols_[idx]) = values_[idx];
240 }
241 }
242}
243
244#ifndef CERES_DONT_HAVE_PROTOCOL_BUFFERS
245void CompressedRowSparseMatrix::ToProto(SparseMatrixProto* outer_proto) const {
246 CHECK_NOTNULL(outer_proto);
247
248 outer_proto->Clear();
249 CompressedRowSparseMatrixProto* proto
250 = outer_proto->mutable_compressed_row_matrix();
251
252 proto->set_num_rows(num_rows_);
253 proto->set_num_cols(num_cols_);
254
255 for (int r = 0; r < num_rows_ + 1; ++r) {
256 proto->add_rows(rows_[r]);
257 }
258
259 for (int idx = 0; idx < rows_[num_rows_]; ++idx) {
260 proto->add_cols(cols_[idx]);
261 proto->add_values(values_[idx]);
262 }
263}
264#endif
265
266void CompressedRowSparseMatrix::DeleteRows(int delta_rows) {
267 CHECK_GE(delta_rows, 0);
268 CHECK_LE(delta_rows, num_rows_);
269
270 int new_num_rows = num_rows_ - delta_rows;
271
272 num_rows_ = new_num_rows;
273 int* new_rows = new int[num_rows_ + 1];
274 copy(rows_.get(), rows_.get() + num_rows_ + 1, new_rows);
275 rows_.reset(new_rows);
276}
277
278void CompressedRowSparseMatrix::AppendRows(const CompressedRowSparseMatrix& m) {
279 CHECK_EQ(m.num_cols(), num_cols_);
280
281 // Check if there is enough space. If not, then allocate new arrays
282 // to hold the combined matrix and copy the contents of this matrix
283 // into it.
284 if (max_num_nonzeros_ < num_nonzeros() + m.num_nonzeros()) {
285 int new_max_num_nonzeros = num_nonzeros() + m.num_nonzeros();
286
287 VLOG(1) << "Reallocating " << sizeof(int) * new_max_num_nonzeros; // NOLINT
288
289 int* new_cols = new int[new_max_num_nonzeros];
290 copy(cols_.get(), cols_.get() + max_num_nonzeros_, new_cols);
291 cols_.reset(new_cols);
292
293 double* new_values = new double[new_max_num_nonzeros];
294 copy(values_.get(), values_.get() + max_num_nonzeros_, new_values);
295 values_.reset(new_values);
296
297 max_num_nonzeros_ = new_max_num_nonzeros;
298 }
299
300 // Copy the contents of m into this matrix.
301 copy(m.cols(), m.cols() + m.num_nonzeros(), cols_.get() + num_nonzeros());
302 copy(m.values(),
303 m.values() + m.num_nonzeros(),
304 values_.get() + num_nonzeros());
305
306 // Create the new rows array to hold the enlarged matrix.
307 int* new_rows = new int[num_rows_ + m.num_rows() + 1];
308 // The first num_rows_ entries are the same
309 copy(rows_.get(), rows_.get() + num_rows_, new_rows);
310
311 // new_rows = [rows_, m.row() + rows_[num_rows_]]
312 fill(new_rows + num_rows_,
313 new_rows + num_rows_ + m.num_rows() + 1,
314 rows_[num_rows_]);
315
316 for (int r = 0; r < m.num_rows() + 1; ++r) {
317 new_rows[num_rows_ + r] += m.rows()[r];
318 }
319
320 rows_.reset(new_rows);
321 num_rows_ += m.num_rows();
322}
323
324} // namespace internal
325} // namespace ceres