// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
// http://code.google.com/p/ceres-solver/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// * Neither the name of Google Inc. nor the names of its contributors may be
//   used to endorse or promote products derived from this software without
//   specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Author: sameeragarwal@google.com (Sameer Agarwal)

#include "ceres/linear_least_squares_problems.h"

#include <string>
#include <vector>
#include <glog/logging.h>
#include "ceres/block_sparse_matrix.h"
#include "ceres/block_structure.h"
#include "ceres/compressed_row_sparse_matrix.h"
#include "ceres/file.h"
#include "ceres/matrix_proto.h"
#include "ceres/triplet_sparse_matrix.h"
#include "ceres/internal/scoped_ptr.h"
#include "ceres/types.h"

namespace ceres {
namespace internal {

// Returns the test problem selected by `id`.
//
// Ids 0 through 3 are forwarded to LinearLeastSquaresProblem0() ...
// LinearLeastSquaresProblem3() respectively and the pointer produced by
// that factory is returned unchanged. Any other id is reported through
// LOG(FATAL).
LinearLeastSquaresProblem* CreateLinearLeastSquaresProblemFromId(int id) {
  switch (id) {
    case 0:
      return LinearLeastSquaresProblem0();
    case 1:
      return LinearLeastSquaresProblem1();
    case 2:
      return LinearLeastSquaresProblem2();
    case 3:
      return LinearLeastSquaresProblem3();
    default:
      LOG(FATAL) << "Unknown problem id requested " << id;
  }

  // Explicit return so that no control path flows off the end of this
  // non-void function, mirroring the LOG(FATAL) + return NULL pattern used
  // by the protocol-buffer-less CreateLinearLeastSquaresProblemFromFile().
  return NULL;
}

#ifndef CERES_DONT_HAVE_PROTOCOL_BUFFERS
// Reads a serialized LinearLeastSquaresProblemProto from `filename` and
// converts it into a newly allocated LinearLeastSquaresProblem.
//
// The matrix A is materialized as a BlockSparseMatrix or a
// TripletSparseMatrix when the proto carries the corresponding field, and as
// a CompressedRowSparseMatrix otherwise. The vectors b and D are copied only
// when they are non-empty. A non-empty x is stored in x_D when D is present
// and in x when it is not. num_eliminate_blocks defaults to 0.
LinearLeastSquaresProblem* CreateLinearLeastSquaresProblemFromFile(
    const string& filename) {
  LinearLeastSquaresProblemProto problem_proto;
  {
    // The raw bytes are only needed for parsing; keep them in their own
    // scope so they go away before the problem is assembled.
    string file_contents;
    ReadFileToStringOrDie(filename, &file_contents);
    CHECK(problem_proto.ParseFromString(file_contents));
  }

  LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;

  // Matrix: choose the concrete representation from the fields present.
  const SparseMatrixProto& matrix_proto = problem_proto.a();
  if (matrix_proto.has_block_matrix()) {
    problem->A.reset(new BlockSparseMatrix(matrix_proto));
  } else if (matrix_proto.has_triplet_matrix()) {
    problem->A.reset(new TripletSparseMatrix(matrix_proto));
  } else {
    problem->A.reset(new CompressedRowSparseMatrix(matrix_proto));
  }

  // Right hand side.
  const int b_size = problem_proto.b_size();
  if (b_size > 0) {
    double* rhs = new double[b_size];
    for (int i = 0; i < b_size; ++i) {
      rhs[i] = problem_proto.b(i);
    }
    problem->b.reset(rhs);
  }

  // Diagonal D.
  const int d_size = problem_proto.d_size();
  if (d_size > 0) {
    double* diagonal = new double[d_size];
    for (int i = 0; i < d_size; ++i) {
      diagonal[i] = problem_proto.d(i);
    }
    problem->D.reset(diagonal);
  }

  // Solution vector: it is stored as x_D when a D vector was supplied and
  // as x otherwise.
  const int x_size = problem_proto.x_size();
  if (x_size > 0) {
    double* solution = new double[x_size];
    for (int i = 0; i < x_size; ++i) {
      solution[i] = problem_proto.x(i);
    }
    if (d_size > 0) {
      problem->x_D.reset(solution);
    } else {
      problem->x.reset(solution);
    }
  }

  problem->num_eliminate_blocks =
      problem_proto.has_num_eliminate_blocks()
          ? problem_proto.num_eliminate_blocks()
          : 0;

  return problem;
}
#else
// Variant compiled when CERES_DONT_HAVE_PROTOCOL_BUFFERS is defined: problems
// cannot be deserialized, so the call is reported through LOG(FATAL).
// `filename` is not used.
LinearLeastSquaresProblem* CreateLinearLeastSquaresProblemFromFile(
    const string& filename) {
  LOG(FATAL) << "Loading a least squares problem from disk requires "
             << "Ceres to be built with Protocol Buffers support.";

  // Keeps every control path of this non-void function returning a value.
  return NULL;
}
#endif  // CERES_DONT_HAVE_PROTOCOL_BUFFERS

/*
A = [1   2]
    [3   4]
    [6 -10]

b = [  8
      18
     -18]

x = [2
     3]

D = [1
     2]

x_D = [1.78448275;
       2.82327586;]
 */
// Builds the dense 3x2 problem documented in the comment above, with A stored
// as a TripletSparseMatrix holding all six entries in row-major order. Both
// reference solutions, x and x_D, are filled in along with b and D.
LinearLeastSquaresProblem* LinearLeastSquaresProblem0() {
  const int kNumRows = 3;
  const int kNumCols = 2;
  const int kNumEntries = kNumRows * kNumCols;

  // Entries of A, row by row.
  static const double kEntries[kNumEntries] = {1, 2,
                                               3, 4,
                                               6, -10};

  TripletSparseMatrix* A =
      new TripletSparseMatrix(kNumRows, kNumCols, kNumEntries);
  int* row_indices = A->mutable_rows();
  int* col_indices = A->mutable_cols();
  double* entries = A->mutable_values();
  for (int k = 0; k < kNumEntries; ++k) {
    // Row-major layout: entry k lives at (k / cols, k % cols).
    row_indices[k] = k / kNumCols;
    col_indices[k] = k % kNumCols;
    entries[k] = kEntries[k];
  }
  A->set_num_nonzeros(kNumEntries);

  LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;
  problem->A.reset(A);

  problem->b.reset(new double[kNumRows]);
  problem->b[0] = 8;
  problem->b[1] = 18;
  problem->b[2] = -18;

  problem->D.reset(new double[kNumCols]);
  problem->D[0] = 1;
  problem->D[1] = 2;

  problem->x.reset(new double[kNumCols]);
  problem->x[0] = 2.0;
  problem->x[1] = 3.0;

  problem->x_D.reset(new double[kNumCols]);
  problem->x_D[0] = 1.78448275;
  problem->x_D[1] = 2.82327586;

  return problem;
}


/*
      A = [1 0  | 2 0 0
           3 0  | 0 4 0
           0 5  | 0 0 6
           0 7  | 8 0 0
           0 9  | 1 0 0
           0 0  | 1 1 1]

      b = [0
           1
           2
           3
           4
           5]

      c = A'* b = [ 3
                   67
                   33
                    9
                   17]

      A'A = [10    0    2   12   0
              0  155   65    0  30
              2   65   70    1   1
             12    0    1   17   1
              0   30    1    1  37]

      S = [ 42.3419  -1.4000  -11.5806
            -1.4000   2.6000    1.0000
           -11.5806   1.0000   31.1935]

      r = [ 4.3032
            5.4000
            4.0323]

      S\r = [ 0.2102
              2.1367
              0.1388]

      A\b = [-2.3061
              0.3172
              0.2102
              2.1367
              0.1388]
*/
// The following two functions create a TripletSparseMatrix and a
// BlockSparseMatrix version of this problem.

// TripletSparseMatrix version.
//
// A is the 6x5 matrix from the comment above, stored as 13 (row, col, value)
// triplets. D is all ones, b[i] = i, and the first two column blocks are
// marked for elimination.
LinearLeastSquaresProblem* LinearLeastSquaresProblem1() {
  const int num_rows = 6;
  const int num_cols = 5;

  // Non-zero entries of A, listed row by row.
  static const int kRowIndex[] = {0, 0,
                                  1, 1,
                                  2, 2,
                                  3, 3,
                                  4, 4,
                                  5, 5, 5};
  static const int kColIndex[] = {0, 2,
                                  0, 3,
                                  1, 4,
                                  1, 2,
                                  1, 2,
                                  2, 3, 4};
  static const double kEntry[] = {1, 2,
                                  3, 4,
                                  5, 6,
                                  7, 8,
                                  9, 1,
                                  1, 1, 1};
  const int num_nonzeros =
      static_cast<int>(sizeof(kEntry) / sizeof(kEntry[0]));

  LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;
  problem->num_eliminate_blocks = 2;

  // Storage is reserved for a fully dense matrix even though only
  // num_nonzeros entries are populated.
  TripletSparseMatrix* A = new TripletSparseMatrix(num_rows,
                                                   num_cols,
                                                   num_rows * num_cols);
  int* rows = A->mutable_rows();
  int* cols = A->mutable_cols();
  double* values = A->mutable_values();
  for (int k = 0; k < num_nonzeros; ++k) {
    rows[k] = kRowIndex[k];
    cols[k] = kColIndex[k];
    values[k] = kEntry[k];
  }
  A->set_num_nonzeros(num_nonzeros);
  CHECK(A->IsValid());
  problem->A.reset(A);

  problem->D.reset(new double[num_cols]);
  for (int c = 0; c < num_cols; ++c) {
    problem->D[c] = 1;
  }

  problem->b.reset(new double[num_rows]);
  for (int r = 0; r < num_rows; ++r) {
    problem->b[r] = r;
  }

  return problem;
}

// BlockSparseMatrix version
//
// Same 6x5 problem as LinearLeastSquaresProblem1(), but A is described by a
// CompressedRowBlockStructure in which every row block and every column
// block has size 1. Cell k of the structure refers to value slot k, so the
// values are laid out row by row.
LinearLeastSquaresProblem* LinearLeastSquaresProblem2() {
  const int num_rows = 6;
  const int num_cols = 5;

  // Number of cells in each row, followed by the column block and value of
  // every cell in row-major order.
  static const int kCellsInRow[num_rows] = {2, 2, 2, 2, 2, 3};
  static const int kCellColumn[] = {0, 2,
                                    0, 3,
                                    1, 4,
                                    1, 2,
                                    1, 2,
                                    2, 3, 4};
  static const double kCellValue[] = {1, 2,
                                      3, 4,
                                      5, 6,
                                      7, 8,
                                      9, 1,
                                      1, 1, 1};

  CompressedRowBlockStructure* bs = new CompressedRowBlockStructure;

  for (int c = 0; c < num_cols; ++c) {
    Block column_block;
    column_block.size = 1;
    column_block.position = c;
    bs->cols.push_back(column_block);
  }

  int nnz = 0;
  for (int r = 0; r < num_rows; ++r) {
    bs->rows.push_back(CompressedRow());
    CompressedRow& row = bs->rows.back();
    row.block.size = 1;
    row.block.position = r;
    for (int k = 0; k < kCellsInRow[r]; ++k, ++nnz) {
      // Each cell is 1x1, so its value offset equals its running index.
      row.cells.push_back(Cell(kCellColumn[nnz], nnz));
    }
  }

  // The matrix is created from the finished block structure and its value
  // array is then filled in cell order.
  BlockSparseMatrix* A = new BlockSparseMatrix(bs);
  double* matrix_values = A->mutable_values();
  for (int k = 0; k < nnz; ++k) {
    matrix_values[k] = kCellValue[k];
  }

  LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;
  problem->num_eliminate_blocks = 2;
  problem->A.reset(A);

  problem->D.reset(new double[num_cols]);
  for (int c = 0; c < num_cols; ++c) {
    problem->D[c] = 1;
  }

  problem->b.reset(new double[num_rows]);
  for (int r = 0; r < num_rows; ++r) {
    problem->b[r] = r;
  }

  return problem;
}


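/*
      A = [1 0
           3 0
           0 5
           0 7
           0 9
           0 0]

      b = [0
           1
           2
           3
           4
           5]

      Note: the function below sets num_rows = 5, so only the first five
      rows of A and the first five entries of b are actually constructed;
      the trailing all-zero row of A and its entry in b are not
      materialized.
*/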
// BlockSparseMatrix version
//
// Builds a 5x2 matrix with exactly one non-zero per row, using 1x1 row and
// column blocks. D is all ones, b[i] = i, and num_eliminate_blocks is 2,
// which equals the number of column blocks.
LinearLeastSquaresProblem* LinearLeastSquaresProblem3() {
  const int num_rows = 5;
  const int num_cols = 2;

  // Column block and value of the single cell in each row.
  static const int kColumnOfRow[num_rows] = {0, 0, 1, 1, 1};
  static const double kValueOfRow[num_rows] = {1, 3, 5, 7, 9};

  CompressedRowBlockStructure* bs = new CompressedRowBlockStructure;

  for (int c = 0; c < num_cols; ++c) {
    Block column_block;
    column_block.size = 1;
    column_block.position = c;
    bs->cols.push_back(column_block);
  }

  for (int r = 0; r < num_rows; ++r) {
    bs->rows.push_back(CompressedRow());
    CompressedRow& row = bs->rows.back();
    row.block.size = 1;
    row.block.position = r;
    // One 1x1 cell per row, so the value offset is the row index.
    row.cells.push_back(Cell(kColumnOfRow[r], r));
  }

  // The matrix is created from the finished block structure and its value
  // array is then filled in row order.
  BlockSparseMatrix* A = new BlockSparseMatrix(bs);
  double* matrix_values = A->mutable_values();
  for (int r = 0; r < num_rows; ++r) {
    matrix_values[r] = kValueOfRow[r];
  }

  LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;
  problem->num_eliminate_blocks = 2;
  problem->A.reset(A);

  problem->D.reset(new double[num_cols]);
  for (int c = 0; c < num_cols; ++c) {
    problem->D[c] = 1;
  }

  problem->b.reset(new double[num_rows]);
  for (int r = 0; r < num_rows; ++r) {
    problem->b[r] = r;
  }

  return problem;
}

}  // namespace internal
}  // namespace ceres
