Add Solver::Options::IsValid.
This provides a user visible way to validate the Solver::Options
before calling Solve.
Change-Id: Ife84fd33532ab2ccb7ac95abe22735843db51fde
diff --git a/docs/source/solving.rst b/docs/source/solving.rst
index 2ad84fd..32b7fc5 100644
--- a/docs/source/solving.rst
+++ b/docs/source/solving.rst
@@ -791,9 +791,14 @@
.. class:: Solver::Options
- :class:`Solver::Options` controls the overall behavior of the
- solver. We list the various settings and their default values below.
+ :class:`Solver::Options` controls the overall behavior of the
+ solver. We list the various settings and their default values below.
+.. function:: bool Solver::Options::IsValid(string* error) const
+
+ Validate the values in the options struct and returns true on
+ success. If there is a problem, the method returns false with
+ ``error`` containing a textual description of the cause.
.. member:: MinimizerType Solver::Options::minimizer_type
diff --git a/docs/source/version_history.rst b/docs/source/version_history.rst
index 45298a5..1dcc002 100644
--- a/docs/source/version_history.rst
+++ b/docs/source/version_history.rst
@@ -7,6 +7,9 @@
HEAD
====
+#. Added ``Solver::Options::IsValid`` which allows users to validate
+ their solver configuration before calling ``Solve``.
+
Backward Incompatible API Changes
---------------------------------
diff --git a/include/ceres/solver.h b/include/ceres/solver.h
index 33ffb54..4723c75 100644
--- a/include/ceres/solver.h
+++ b/include/ceres/solver.h
@@ -1,5 +1,5 @@
// Ceres Solver - A fast non-linear least squares minimizer
-// Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
+// Copyright 2014 Google Inc. All rights reserved.
// http://code.google.com/p/ceres-solver/
//
// Redistribution and use in source and binary forms, with or without
@@ -126,6 +126,11 @@
update_state_every_iteration = false;
}
+ // Returns true if the options struct has a valid
+ // configuration. Returns false otherwise, and fills in *error
+ // with a message describing the problem.
+ bool IsValid(string* error) const;
+
// Minimizer options ----------------------------------------
// Ceres supports the two major families of optimization strategies -
diff --git a/internal/ceres/CMakeLists.txt b/internal/ceres/CMakeLists.txt
index a0dea4e..dd75a01 100644
--- a/internal/ceres/CMakeLists.txt
+++ b/internal/ceres/CMakeLists.txt
@@ -270,6 +270,7 @@
CERES_TEST(schur_eliminator)
CERES_TEST(single_linkage_clustering)
CERES_TEST(small_blas)
+ CERES_TEST(solver)
CERES_TEST(solver_impl)
# TODO(sameeragarwal): This test should ultimately be made
diff --git a/internal/ceres/solver.cc b/internal/ceres/solver.cc
index 7dcae7a..bec2e0c 100644
--- a/internal/ceres/solver.cc
+++ b/internal/ceres/solver.cc
@@ -1,5 +1,5 @@
// Ceres Solver - A fast non-linear least squares minimizer
-// Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
+// Copyright 2014 Google Inc. All rights reserved.
// http://code.google.com/p/ceres-solver/
//
// Redistribution and use in source and binary forms, with or without
@@ -43,6 +43,242 @@
namespace ceres {
namespace {
+#define OPTION_GT(x, y) \
+ if (options.x <= y) { \
+ *error = string("Invalid configuration. Violated constraint " \
+ "Solver::Options::" #x " > " #y); \
+ return false; \
+ }
+
+#define OPTION_GE(x, y) \
+ if (options.x < y) { \
+ *error = string("Invalid configuration. Violated constraint " \
+ "Solver::Options::" #x " >= " #y); \
+ return false; \
+ }
+
+#define OPTION_LE(x, y) \
+ if (options.x > y) { \
+ *error = string("Invalid configuration. Violated constraint " \
+ "Solver::Options::" #x " <= " #y); \
+ return false; \
+ }
+
+#define OPTION_LT(x, y) \
+ if (options.x >= y) { \
+ *error = string("Invalid configuration. Violated constraint " \
+ "Solver::Options::" #x " < " #y); \
+ return false; \
+ }
+
+#define OPTION_LE_OPTION(x, y) \
+ if (options.x > options.y) { \
+ *error = string("Invalid configuration. Violated constraint " \
+ "Solver::Options::" #x " <= " \
+ "Solver::Options::" #y); \
+ return false; \
+ }
+
+#define OPTION_LT_OPTION(x, y) \
+ if (options.x >= options.y) { \
+ *error = string("Invalid configuration. Violated constraint " \
+ "Solver::Options::" #x " < " \
+ "Solver::Options::" #y); \
+ return false; \
+ }
+
+bool CommonOptionsAreValid(const Solver::Options& options, string* error) {
+ OPTION_GE(max_num_iterations, 0);
+ OPTION_GE(max_solver_time_in_seconds, 0.0);
+ OPTION_GE(function_tolerance, 0.0);
+ OPTION_GE(gradient_tolerance, 0.0);
+ OPTION_GE(parameter_tolerance, 0.0);
+ OPTION_GT(num_threads, 0);
+ OPTION_GT(num_linear_solver_threads, 0);
+ if (options.check_gradients) {
+ OPTION_GT(gradient_check_relative_precision, 0.0);
+ OPTION_GT(numeric_derivative_relative_step_size, 0.0);
+ }
+ return true;
+}
+
+bool TrustRegionOptionsAreValid(const Solver::Options& options, string* error) {
+ OPTION_GT(initial_trust_region_radius, 0.0);
+ OPTION_GT(min_trust_region_radius, 0.0);
+ OPTION_GT(max_trust_region_radius, 0.0);
+ OPTION_LE_OPTION(min_trust_region_radius, max_trust_region_radius);
+ OPTION_LE_OPTION(min_trust_region_radius, initial_trust_region_radius);
+ OPTION_LE_OPTION(initial_trust_region_radius, max_trust_region_radius);
+ OPTION_GT(min_relative_decrease, 0.0);
+ OPTION_GE(min_lm_diagonal, 0.0);
+ OPTION_GE(max_lm_diagonal, 0.0);
+ OPTION_LE_OPTION(min_lm_diagonal, max_lm_diagonal);
+ OPTION_GE(max_num_consecutive_invalid_steps, 0);
+ OPTION_GT(eta, 0.0);
+ OPTION_GE(min_linear_solver_iterations, 1);
+ OPTION_GE(max_linear_solver_iterations, 1);
+ OPTION_LE_OPTION(min_linear_solver_iterations, max_linear_solver_iterations);
+
+ if (options.use_inner_iterations) {
+ OPTION_GE(inner_iteration_tolerance, 0.0);
+ }
+
+ if (options.use_nonmonotonic_steps) {
+ OPTION_GT(max_consecutive_nonmonotonic_steps, 0);
+ }
+
+ if (options.preconditioner_type == CLUSTER_JACOBI &&
+ options.sparse_linear_algebra_library_type != SUITE_SPARSE) {
+ *error = "CLUSTER_JACOBI requires "
+ "Solver::Options::sparse_linear_algebra_library_type to be "
+ "SUITE_SPARSE";
+ return false;
+ }
+
+ if (options.preconditioner_type == CLUSTER_TRIDIAGONAL &&
+ options.sparse_linear_algebra_library_type != SUITE_SPARSE) {
+ *error = "CLUSTER_TRIDIAGONAL requires "
+ "Solver::Options::sparse_linear_algebra_library_type to be "
+ "SUITE_SPARSE";
+ return false;
+ }
+
+#ifdef CERES_NO_LAPACK
+ if (options.dense_linear_algebra_library_type == LAPACK) {
+ if (options.type == DENSE_NORMAL_CHOLESKY) {
+ *error = "Can't use DENSE_NORMAL_CHOLESKY with LAPACK because "
+ "LAPACK was not enabled when Ceres was built.";
+ return false;
+ }
+
+ if (options.type == DENSE_QR) {
+ *error = "Can't use DENSE_QR with LAPACK because "
+ "LAPACK was not enabled when Ceres was built.";
+ return false;
+ }
+
+ if (options.type == DENSE_SCHUR) {
+ *error = "Can't use DENSE_SCHUR with LAPACK because "
+ "LAPACK was not enabled when Ceres was built.";
+ return false;
+ }
+ }
+#endif
+
+#ifdef CERES_NO_SUITESPARSE
+ if (options.sparse_linear_algebra_library_type == SUITE_SPARSE) {
+ if (options.type == SPARSE_NORMAL_CHOLESKY) {
+ *error = "Can't use SPARSE_NORMAL_CHOLESKY with SUITESPARSE because "
+ "SuiteSparse was not enabled when Ceres was built.";
+ return false;
+ }
+
+ if (options.type == SPARSE_SCHUR) {
+ *error = "Can't use SPARSE_SCHUR with SUITESPARSE because "
+ "SuiteSparse was not enabled when Ceres was built.";
+ return false;
+ }
+
+ if (options.preconditioner_type == CLUSTER_JACOBI) {
+ *error = "CLUSTER_JACOBI preconditioner not supported. "
+ "SuiteSparse was not enabled when Ceres was built."
+ return false;
+ }
+
+ if (options.preconditioner_type == CLUSTER_TRIDIAGONAL) {
+ *error = "CLUSTER_TRIDIAGONAL preconditioner not supported. "
+ "SuiteSparse was not enabled when Ceres was built."
+ return false;
+ }
+ }
+#endif
+
+#ifdef CERES_NO_CXSPARSE
+ if (options.sparse_linear_algebra_library_type == CX_SPARSE) {
+ if (options.type == SPARSE_NORMAL_CHOLESKY) {
+ *error = "Can't use SPARSE_NORMAL_CHOLESKY with CX_SPARSE because "
+ "CXSparse was not enabled when Ceres was built.";
+ return false;
+ }
+
+ if (options.type == SPARSE_SCHUR) {
+ *error = "Can't use SPARSE_SCHUR with CX_SPARSE because "
+ "CXSparse was not enabled when Ceres was built.";
+ return false;
+ }
+ }
+#endif
+
+ if (options.trust_region_strategy_type == DOGLEG) {
+ if (options.linear_solver_type == ITERATIVE_SCHUR ||
+ options.linear_solver_type == CGNR) {
+ *error = "DOGLEG only supports exact factorization based linear "
+ "solvers. If you want to use an iterative solver please "
+ "use LEVENBERG_MARQUARDT as the trust_region_strategy_type";
+ return false;
+ }
+ }
+
+ if (options.trust_region_minimizer_iterations_to_dump.size() > 0 &&
+ options.trust_region_problem_dump_format_type != CONSOLE &&
+ options.trust_region_problem_dump_directory.empty()) {
+ *error = "Solver::Options::trust_region_problem_dump_directory is empty.";
+ return false;
+ }
+
+ return true;
+}
+
+bool LineSearchOptionsAreValid(const Solver::Options& options, string* error) {
+ OPTION_GT(max_lbfgs_rank, 0);
+ OPTION_GT(min_line_search_step_size, 0.0);
+ OPTION_GT(max_line_search_step_contraction, 0.0);
+ OPTION_LT(max_line_search_step_contraction, 1.0);
+ OPTION_LT_OPTION(max_line_search_step_contraction,
+ min_line_search_step_contraction);
+ OPTION_LE(min_line_search_step_contraction, 1.0);
+ OPTION_GT(max_num_line_search_step_size_iterations, 0);
+ OPTION_GT(line_search_sufficient_function_decrease, 0.0);
+ OPTION_LT_OPTION(line_search_sufficient_function_decrease,
+ line_search_sufficient_curvature_decrease);
+ OPTION_LT(line_search_sufficient_curvature_decrease, 1.0);
+ OPTION_GT(max_line_search_step_expansion, 1.0);
+
+ if ((options.line_search_direction_type == ceres::BFGS ||
+ options.line_search_direction_type == ceres::LBFGS) &&
+ options.line_search_type != ceres::WOLFE) {
+ *error =
+ string("Invalid configuration: require line_search_type == "
+ "ceres::WOLFE when using (L)BFGS to ensure that underlying "
+ "assumptions are guaranteed to be satisfied.");
+ return false;
+ }
+
+ // Warn user if they have requested BISECTION interpolation, but constraints
+ // on max/min step size change during line search prevent bisection scaling
+ // from occurring. Warn only, as this is likely a user mistake, but one which
+ // does not prevent us from continuing.
+ LOG_IF(WARNING,
+ (options.line_search_interpolation_type == ceres::BISECTION &&
+ (options.max_line_search_step_contraction > 0.5 ||
+ options.min_line_search_step_contraction < 0.5)))
+ << "Line search interpolation type is BISECTION, but specified "
+ << "max_line_search_step_contraction: "
+ << options.max_line_search_step_contraction << ", and "
+ << "min_line_search_step_contraction: "
+ << options.min_line_search_step_contraction
+ << ", prevent bisection (0.5) scaling, continuing with solve regardless.";
+
+ return true;
+}
+
+#undef OPTION_GT
+#undef OPTION_GE
+#undef OPTION_LE
+#undef OPTION_LT
+#undef OPTION_LE_OPTION
+#undef OPTION_LT_OPTION
+
void StringifyOrdering(const vector<int>& ordering, string* report) {
if (ordering.size() == 0) {
internal::StringAppendF(report, "AUTOMATIC");
@@ -55,7 +291,20 @@
internal::StringAppendF(report, "%d", ordering.back());
}
-} // namespace
+} // namespace
+
+bool Solver::Options::IsValid(string* error) const {
+ if (!CommonOptionsAreValid(*this, error)) {
+ return false;
+ }
+
+ if (minimizer_type == TRUST_REGION) {
+ return TrustRegionOptionsAreValid(*this, error);
+ }
+
+ CHECK_EQ(minimizer_type, LINE_SEARCH);
+ return LineSearchOptionsAreValid(*this, error);
+}
Solver::~Solver() {}
@@ -63,8 +312,16 @@
Problem* problem,
Solver::Summary* summary) {
double start_time_seconds = internal::WallTimeInSeconds();
- internal::ProblemImpl* problem_impl =
- CHECK_NOTNULL(problem)->problem_impl_.get();
+ CHECK_NOTNULL(problem);
+ CHECK_NOTNULL(summary);
+
+ *summary = Summary();
+ if (!options.IsValid(&summary->message)) {
+ LOG(ERROR) << "Terminating: " << summary->message;
+ return;
+ }
+
+ internal::ProblemImpl* problem_impl = problem->problem_impl_.get();
internal::SolverImpl::Solve(options, problem_impl, summary);
summary->total_time_in_seconds =
internal::WallTimeInSeconds() - start_time_seconds;
diff --git a/internal/ceres/solver_impl.cc b/internal/ceres/solver_impl.cc
index 9b39156..421a4d9 100644
--- a/internal/ceres/solver_impl.cc
+++ b/internal/ceres/solver_impl.cc
@@ -1,5 +1,5 @@
// Ceres Solver - A fast non-linear least squares minimizer
-// Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
+// Copyright 2014 Google Inc. All rights reserved.
// http://code.google.com/p/ceres-solver/
//
// Redistribution and use in source and binary forms, with or without
@@ -61,86 +61,6 @@
namespace ceres {
namespace internal {
-namespace {
-
-bool LineSearchOptionsAreValid(const Solver::Options& options,
- string* message) {
- // Validate values for configuration parameters supplied by user.
- if ((options.line_search_direction_type == ceres::BFGS ||
- options.line_search_direction_type == ceres::LBFGS) &&
- options.line_search_type != ceres::WOLFE) {
- *message =
- string("Invalid configuration: require line_search_type == "
- "ceres::WOLFE when using (L)BFGS to ensure that underlying "
- "assumptions are guaranteed to be satisfied.");
- return false;
- }
- if (options.max_lbfgs_rank <= 0) {
- *message =
- string("Invalid configuration: require max_lbfgs_rank > 0");
- return false;
- }
- if (options.min_line_search_step_size <= 0.0) {
- *message =
- "Invalid configuration: require min_line_search_step_size > 0.0.";
- return false;
- }
- if (options.line_search_sufficient_function_decrease <= 0.0) {
- *message =
- string("Invalid configuration: require ") +
- string("line_search_sufficient_function_decrease > 0.0.");
- return false;
- }
- if (options.max_line_search_step_contraction <= 0.0 ||
- options.max_line_search_step_contraction >= 1.0) {
- *message = string("Invalid configuration: require ") +
- string("0.0 < max_line_search_step_contraction < 1.0.");
- return false;
- }
- if (options.min_line_search_step_contraction <=
- options.max_line_search_step_contraction ||
- options.min_line_search_step_contraction > 1.0) {
- *message = string("Invalid configuration: require ") +
- string("max_line_search_step_contraction < ") +
- string("min_line_search_step_contraction <= 1.0.");
- return false;
- }
- // Warn user if they have requested BISECTION interpolation, but constraints
- // on max/min step size change during line search prevent bisection scaling
- // from occurring. Warn only, as this is likely a user mistake, but one which
- // does not prevent us from continuing.
- LOG_IF(WARNING,
- (options.line_search_interpolation_type == ceres::BISECTION &&
- (options.max_line_search_step_contraction > 0.5 ||
- options.min_line_search_step_contraction < 0.5)))
- << "Line search interpolation type is BISECTION, but specified "
- << "max_line_search_step_contraction: "
- << options.max_line_search_step_contraction << ", and "
- << "min_line_search_step_contraction: "
- << options.min_line_search_step_contraction
- << ", prevent bisection (0.5) scaling, continuing with solve regardless.";
- if (options.max_num_line_search_step_size_iterations <= 0) {
- *message = string("Invalid configuration: require ") +
- string("max_num_line_search_step_size_iterations > 0.");
- return false;
- }
- if (options.line_search_sufficient_curvature_decrease <=
- options.line_search_sufficient_function_decrease ||
- options.line_search_sufficient_curvature_decrease > 1.0) {
- *message = string("Invalid configuration: require ") +
- string("line_search_sufficient_function_decrease < ") +
- string("line_search_sufficient_curvature_decrease < 1.0.");
- return false;
- }
- if (options.max_line_search_step_expansion <= 1.0) {
- *message = string("Invalid configuration: require ") +
- string("max_line_search_step_expansion > 1.0.");
- return false;
- }
- return true;
-}
-
-} // namespace
void SolverImpl::TrustRegionMinimize(
const Solver::Options& options,
@@ -271,7 +191,6 @@
<< " residual blocks, "
<< problem_impl->NumResiduals()
<< " residuals.";
- *CHECK_NOTNULL(summary) = Solver::Summary();
if (options.minimizer_type == TRUST_REGION) {
TrustRegionSolve(options, problem_impl, summary);
} else {
@@ -401,7 +320,7 @@
double post_process_start_time = WallTimeInSeconds();
summary->message =
- "Terminating: Function tolerance reached. "
+ "Function tolerance reached. "
"No non-constant parameter blocks found.";
summary->termination_type = CONVERGENCE;
VLOG_IF(1, options.logging_type != SILENT) << summary->message;
@@ -535,11 +454,6 @@
summary->nonlinear_conjugate_gradient_type =
original_options.nonlinear_conjugate_gradient_type;
- if (!LineSearchOptionsAreValid(original_options, &summary->message)) {
- LOG(ERROR) << summary->message;
- return;
- }
-
if (original_program->IsBoundsConstrained()) {
summary->message = "LINE_SEARCH Minimizer does not support bounds.";
LOG(ERROR) << "Terminating: " << summary->message;
@@ -568,7 +482,7 @@
summary->num_threads_given = original_options.num_threads;
summary->num_threads_used = options.num_threads;
- if (original_program->ParameterBlocksAreFinite(&summary->message)) {
+ if (!original_program->ParameterBlocksAreFinite(&summary->message)) {
LOG(ERROR) << "Terminating: " << summary->message;
return;
}
@@ -626,10 +540,12 @@
WallTimeInSeconds() - solver_start_time;
summary->message =
- "Terminating: Function tolerance reached. "
+ "Function tolerance reached. "
"No non-constant parameter blocks found.";
summary->termination_type = CONVERGENCE;
VLOG_IF(1, options.logging_type != SILENT) << summary->message;
+ summary->initial_cost = summary->fixed_cost;
+ summary->final_cost = summary->fixed_cost;
const double post_process_start_time = WallTimeInSeconds();
SetSummaryFinalCost(summary);
diff --git a/internal/ceres/solver_impl_test.cc b/internal/ceres/solver_impl_test.cc
index c22ac49..b0005a7 100644
--- a/internal/ceres/solver_impl_test.cc
+++ b/internal/ceres/solver_impl_test.cc
@@ -1,5 +1,5 @@
// Ceres Solver - A fast non-linear least squares minimizer
-// Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
+// Copyright 2014 Google Inc. All rights reserved.
// http://code.google.com/p/ceres-solver/
//
// Redistribution and use in source and binary forms, with or without
@@ -317,207 +317,6 @@
EXPECT_EQ(parameter_blocks[2]->user_state(), &y);
}
-#if defined(CERES_NO_SUITESPARSE) && defined(CERES_NO_CXSPARSE)
-TEST(SolverImpl, CreateLinearSolverNoSuiteSparse) {
- Solver::Options options;
- options.linear_solver_type = SPARSE_NORMAL_CHOLESKY;
- // CreateLinearSolver assumes a non-empty ordering.
- options.linear_solver_ordering.reset(new ParameterBlockOrdering);
- string message;
- EXPECT_FALSE(SolverImpl::CreateLinearSolver(&options, &message));
-}
-#endif
-
-TEST(SolverImpl, CreateLinearSolverNegativeMaxNumIterations) {
- Solver::Options options;
- options.linear_solver_type = DENSE_QR;
- options.max_linear_solver_iterations = -1;
- // CreateLinearSolver assumes a non-empty ordering.
- options.linear_solver_ordering.reset(new ParameterBlockOrdering);
- string message;
- EXPECT_EQ(SolverImpl::CreateLinearSolver(&options, &message),
- static_cast<LinearSolver*>(NULL));
-}
-
-TEST(SolverImpl, CreateLinearSolverNegativeMinNumIterations) {
- Solver::Options options;
- options.linear_solver_type = DENSE_QR;
- options.min_linear_solver_iterations = -1;
- // CreateLinearSolver assumes a non-empty ordering.
- options.linear_solver_ordering.reset(new ParameterBlockOrdering);
- string message;
- EXPECT_EQ(SolverImpl::CreateLinearSolver(&options, &message),
- static_cast<LinearSolver*>(NULL));
-}
-
-TEST(SolverImpl, CreateLinearSolverMaxLessThanMinIterations) {
- Solver::Options options;
- options.linear_solver_type = DENSE_QR;
- options.min_linear_solver_iterations = 10;
- options.max_linear_solver_iterations = 5;
- options.linear_solver_ordering.reset(new ParameterBlockOrdering);
- string message;
- EXPECT_EQ(SolverImpl::CreateLinearSolver(&options, &message),
- static_cast<LinearSolver*>(NULL));
-}
-
-TEST(SolverImpl, CreateLinearSolverDenseSchurMultipleThreads) {
- Solver::Options options;
- options.linear_solver_type = DENSE_SCHUR;
- options.num_linear_solver_threads = 2;
- // The Schur type solvers can only be created with the Ordering
- // contains at least one elimination group.
- options.linear_solver_ordering.reset(new ParameterBlockOrdering);
- double x;
- double y;
- options.linear_solver_ordering->AddElementToGroup(&x, 0);
- options.linear_solver_ordering->AddElementToGroup(&y, 0);
-
- string message;
- scoped_ptr<LinearSolver> solver(
- SolverImpl::CreateLinearSolver(&options, &message));
- EXPECT_TRUE(solver != NULL);
- EXPECT_EQ(options.linear_solver_type, DENSE_SCHUR);
- EXPECT_EQ(options.num_linear_solver_threads, 2);
-}
-
-TEST(SolverImpl, CreateIterativeLinearSolverForDogleg) {
- Solver::Options options;
- options.trust_region_strategy_type = DOGLEG;
- // CreateLinearSolver assumes a non-empty ordering.
- options.linear_solver_ordering.reset(new ParameterBlockOrdering);
- string message;
- options.linear_solver_type = ITERATIVE_SCHUR;
- EXPECT_EQ(SolverImpl::CreateLinearSolver(&options, &message),
- static_cast<LinearSolver*>(NULL));
-
- options.linear_solver_type = CGNR;
- EXPECT_EQ(SolverImpl::CreateLinearSolver(&options, &message),
- static_cast<LinearSolver*>(NULL));
-}
-
-TEST(SolverImpl, CreateLinearSolverNormalOperation) {
- Solver::Options options;
- scoped_ptr<LinearSolver> solver;
- options.linear_solver_type = DENSE_QR;
- // CreateLinearSolver assumes a non-empty ordering.
- options.linear_solver_ordering.reset(new ParameterBlockOrdering);
- string message;
- solver.reset(SolverImpl::CreateLinearSolver(&options, &message));
- EXPECT_EQ(options.linear_solver_type, DENSE_QR);
- EXPECT_TRUE(solver.get() != NULL);
-
- options.linear_solver_type = DENSE_NORMAL_CHOLESKY;
- solver.reset(SolverImpl::CreateLinearSolver(&options, &message));
- EXPECT_EQ(options.linear_solver_type, DENSE_NORMAL_CHOLESKY);
- EXPECT_TRUE(solver.get() != NULL);
-
-#ifndef CERES_NO_SUITESPARSE
- options.linear_solver_type = SPARSE_NORMAL_CHOLESKY;
- options.sparse_linear_algebra_library_type = SUITE_SPARSE;
- solver.reset(SolverImpl::CreateLinearSolver(&options, &message));
- EXPECT_EQ(options.linear_solver_type, SPARSE_NORMAL_CHOLESKY);
- EXPECT_TRUE(solver.get() != NULL);
-#endif
-
-#ifndef CERES_NO_CXSPARSE
- options.linear_solver_type = SPARSE_NORMAL_CHOLESKY;
- options.sparse_linear_algebra_library_type = CX_SPARSE;
- solver.reset(SolverImpl::CreateLinearSolver(&options, &message));
- EXPECT_EQ(options.linear_solver_type, SPARSE_NORMAL_CHOLESKY);
- EXPECT_TRUE(solver.get() != NULL);
-#endif
-
- double x;
- double y;
- options.linear_solver_ordering->AddElementToGroup(&x, 0);
- options.linear_solver_ordering->AddElementToGroup(&y, 0);
-
- options.linear_solver_type = DENSE_SCHUR;
- solver.reset(SolverImpl::CreateLinearSolver(&options, &message));
- EXPECT_EQ(options.linear_solver_type, DENSE_SCHUR);
- EXPECT_TRUE(solver.get() != NULL);
-
- options.linear_solver_type = SPARSE_SCHUR;
- solver.reset(SolverImpl::CreateLinearSolver(&options, &message));
-
-#if defined(CERES_NO_SUITESPARSE) && defined(CERES_NO_CXSPARSE)
- EXPECT_TRUE(SolverImpl::CreateLinearSolver(&options, &message) == NULL);
-#else
- EXPECT_TRUE(solver.get() != NULL);
- EXPECT_EQ(options.linear_solver_type, SPARSE_SCHUR);
-#endif
-
- options.linear_solver_type = ITERATIVE_SCHUR;
- solver.reset(SolverImpl::CreateLinearSolver(&options, &message));
- EXPECT_EQ(options.linear_solver_type, ITERATIVE_SCHUR);
- EXPECT_TRUE(solver.get() != NULL);
-}
-
-struct QuadraticCostFunction {
- template <typename T> bool operator()(const T* const x,
- T* residual) const {
- residual[0] = T(5.0) - *x;
- return true;
- }
-};
-
-struct RememberingCallback : public IterationCallback {
- explicit RememberingCallback(double *x) : calls(0), x(x) {}
- virtual ~RememberingCallback() {}
- virtual CallbackReturnType operator()(const IterationSummary& summary) {
- x_values.push_back(*x);
- return SOLVER_CONTINUE;
- }
- int calls;
- double *x;
- vector<double> x_values;
-};
-
-TEST(SolverImpl, UpdateStateEveryIterationOption) {
- double x = 50.0;
- const double original_x = x;
-
- scoped_ptr<CostFunction> cost_function(
- new AutoDiffCostFunction<QuadraticCostFunction, 1, 1>(
- new QuadraticCostFunction));
-
- Problem::Options problem_options;
- problem_options.cost_function_ownership = DO_NOT_TAKE_OWNERSHIP;
- ProblemImpl problem(problem_options);
- problem.AddResidualBlock(cost_function.get(), NULL, &x);
-
- Solver::Options options;
- options.linear_solver_type = DENSE_QR;
-
- RememberingCallback callback(&x);
- options.callbacks.push_back(&callback);
-
- Solver::Summary summary;
-
- int num_iterations;
-
- // First try: no updating.
- SolverImpl::Solve(options, &problem, &summary);
- num_iterations = summary.num_successful_steps +
- summary.num_unsuccessful_steps;
- EXPECT_GT(num_iterations, 1);
- for (int i = 0; i < callback.x_values.size(); ++i) {
- EXPECT_EQ(50.0, callback.x_values[i]);
- }
-
- // Second try: with updating
- x = 50.0;
- options.update_state_every_iteration = true;
- callback.x_values.clear();
- SolverImpl::Solve(options, &problem, &summary);
- num_iterations = summary.num_successful_steps +
- summary.num_unsuccessful_steps;
- EXPECT_GT(num_iterations, 1);
- EXPECT_EQ(original_x, callback.x_values[0]);
- EXPECT_NE(original_x, callback.x_values[1]);
-}
-
// The parameters must be in separate blocks so that they can be individually
// set constant or not.
struct Quadratic4DCostFunction {
@@ -579,44 +378,6 @@
EXPECT_TRUE(problem.program().IsValid());
}
-TEST(SolverImpl, NoParameterBlocks) {
- ProblemImpl problem_impl;
- Solver::Options options;
- Solver::Summary summary;
- SolverImpl::Solve(options, &problem_impl, &summary);
- EXPECT_EQ(summary.termination_type, CONVERGENCE);
- EXPECT_EQ(summary.message,
- "Terminating: Function tolerance reached. "
- "No non-constant parameter blocks found.");
-}
-
-TEST(SolverImpl, NoResiduals) {
- ProblemImpl problem_impl;
- Solver::Options options;
- Solver::Summary summary;
- double x = 1;
- problem_impl.AddParameterBlock(&x, 1);
- SolverImpl::Solve(options, &problem_impl, &summary);
- EXPECT_EQ(summary.termination_type, CONVERGENCE);
- EXPECT_EQ(summary.message,
- "Terminating: Function tolerance reached. "
- "No non-constant parameter blocks found.");
-}
-
-
-TEST(SolverImpl, ProblemIsConstant) {
- ProblemImpl problem_impl;
- Solver::Options options;
- Solver::Summary summary;
- double x = 1;
- problem_impl.AddResidualBlock(new UnaryIdentityCostFunction, NULL, &x);
- problem_impl.SetParameterBlockConstant(&x);
- SolverImpl::Solve(options, &problem_impl, &summary);
- EXPECT_EQ(summary.termination_type, CONVERGENCE);
- EXPECT_EQ(summary.initial_cost, 1.0 / 2.0);
- EXPECT_EQ(summary.final_cost, 1.0 / 2.0);
-}
-
TEST(SolverImpl, AlternateLinearSolverForSchurTypeLinearSolver) {
Solver::Options options;
diff --git a/internal/ceres/solver_test.cc b/internal/ceres/solver_test.cc
new file mode 100644
index 0000000..2a136f7
--- /dev/null
+++ b/internal/ceres/solver_test.cc
@@ -0,0 +1,298 @@
+// Ceres Solver - A fast non-linear least squares minimizer
+// Copyright 2014 Google Inc. All rights reserved.
+// http://code.google.com/p/ceres-solver/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// * Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// * Neither the name of Google Inc. nor the names of its contributors may be
+// used to endorse or promote products derived from this software without
+// specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//
+// Author: sameeragarwal@google.com (Sameer Agarwal)
+
+#include "ceres/solver.h"
+
+#include <limits>
+#include <cmath>
+#include <vector>
+#include "gtest/gtest.h"
+#include "ceres/internal/scoped_ptr.h"
+#include "ceres/autodiff_cost_function.h"
+#include "ceres/sized_cost_function.h"
+#include "ceres/problem.h"
+#include "ceres/problem_impl.h"
+
+namespace ceres {
+namespace internal {
+
+TEST(SolverOptions, DefaultTrustRegionOptionsAreValid) {
+ Solver::Options options;
+ options.minimizer_type = TRUST_REGION;
+ string error;
+ EXPECT_TRUE(options.IsValid(&error)) << error;
+}
+
+TEST(SolverOptions, DefaultLineSearchOptionsAreValid) {
+ Solver::Options options;
+ options.minimizer_type = LINE_SEARCH;
+ string error;
+ EXPECT_TRUE(options.IsValid(&error)) << error;
+}
+
+struct QuadraticCostFunctor {
+ template <typename T> bool operator()(const T* const x,
+ T* residual) const {
+ residual[0] = T(5.0) - *x;
+ return true;
+ }
+
+ static CostFunction* Create() {
+ return new AutoDiffCostFunction<QuadraticCostFunctor, 1, 1>(
+ new QuadraticCostFunctor);
+ }
+};
+
+struct RememberingCallback : public IterationCallback {
+ explicit RememberingCallback(double *x) : calls(0), x(x) {}
+ virtual ~RememberingCallback() {}
+ virtual CallbackReturnType operator()(const IterationSummary& summary) {
+ x_values.push_back(*x);
+ return SOLVER_CONTINUE;
+ }
+ int calls;
+ double *x;
+ vector<double> x_values;
+};
+
+TEST(Solver, UpdateStateEveryIterationOption) {
+ double x = 50.0;
+ const double original_x = x;
+
+ scoped_ptr<CostFunction> cost_function(QuadraticCostFunctor::Create());
+ Problem::Options problem_options;
+ problem_options.cost_function_ownership = DO_NOT_TAKE_OWNERSHIP;
+ Problem problem(problem_options);
+ problem.AddResidualBlock(cost_function.get(), NULL, &x);
+
+ Solver::Options options;
+ options.linear_solver_type = DENSE_QR;
+
+ RememberingCallback callback(&x);
+ options.callbacks.push_back(&callback);
+
+ Solver::Summary summary;
+
+ int num_iterations;
+
+ // First try: no updating.
+ Solve(options, &problem, &summary);
+ num_iterations = summary.num_successful_steps +
+ summary.num_unsuccessful_steps;
+ EXPECT_GT(num_iterations, 1);
+ for (int i = 0; i < callback.x_values.size(); ++i) {
+ EXPECT_EQ(50.0, callback.x_values[i]);
+ }
+
+ // Second try: with updating
+ x = 50.0;
+ options.update_state_every_iteration = true;
+ callback.x_values.clear();
+ Solve(options, &problem, &summary);
+ num_iterations = summary.num_successful_steps +
+ summary.num_unsuccessful_steps;
+ EXPECT_GT(num_iterations, 1);
+ EXPECT_EQ(original_x, callback.x_values[0]);
+ EXPECT_NE(original_x, callback.x_values[1]);
+}
+
+// The parameters must be in separate blocks so that they can be individually
+// set constant or not.
+struct Quadratic4DCostFunction {
+ template <typename T> bool operator()(const T* const x,
+ const T* const y,
+ const T* const z,
+ const T* const w,
+ T* residual) const {
+ // A 4-dimension axis-aligned quadratic.
+ residual[0] = T(10.0) - *x +
+ T(20.0) - *y +
+ T(30.0) - *z +
+ T(40.0) - *w;
+ return true;
+ }
+
+ static CostFunction* Create() {
+ return new AutoDiffCostFunction<Quadratic4DCostFunction, 1, 1, 1, 1, 1>(
+ new Quadratic4DCostFunction);
+ }
+};
+
+// A cost function that simply returns its argument.
+class UnaryIdentityCostFunction : public SizedCostFunction<1, 1> {
+ public:
+ virtual bool Evaluate(double const* const* parameters,
+ double* residuals,
+ double** jacobians) const {
+ residuals[0] = parameters[0][0];
+ if (jacobians != NULL && jacobians[0] != NULL) {
+ jacobians[0][0] = 1.0;
+ }
+ return true;
+ }
+};
+
+TEST(Solver, TrustRegionProblemHasNoParameterBlocks) {
+ Problem problem;
+ Solver::Options options;
+ options.minimizer_type = TRUST_REGION;
+ Solver::Summary summary;
+ Solve(options, &problem, &summary);
+ EXPECT_EQ(summary.termination_type, CONVERGENCE);
+ EXPECT_EQ(summary.message,
+ "Function tolerance reached. "
+ "No non-constant parameter blocks found.");
+}
+
+TEST(Solver, LineSearchProblemHasNoParameterBlocks) {
+ Problem problem;
+ Solver::Options options;
+ options.minimizer_type = LINE_SEARCH;
+ Solver::Summary summary;
+ Solve(options, &problem, &summary);
+ EXPECT_EQ(summary.termination_type, CONVERGENCE);
+ EXPECT_EQ(summary.message,
+ "Function tolerance reached. "
+ "No non-constant parameter blocks found.");
+}
+
+TEST(Solver, TrustRegionProblemHasZeroResiduals) {
+ Problem problem;
+ double x = 1;
+ problem.AddParameterBlock(&x, 1);
+ Solver::Options options;
+ options.minimizer_type = TRUST_REGION;
+ Solver::Summary summary;
+ Solve(options, &problem, &summary);
+ EXPECT_EQ(summary.termination_type, CONVERGENCE);
+ EXPECT_EQ(summary.message,
+ "Function tolerance reached. "
+ "No non-constant parameter blocks found.");
+}
+
+TEST(Solver, LineSearchProblemHasZeroResiduals) {
+ Problem problem;
+ double x = 1;
+ problem.AddParameterBlock(&x, 1);
+ Solver::Options options;
+ options.minimizer_type = LINE_SEARCH;
+ Solver::Summary summary;
+ Solve(options, &problem, &summary);
+ EXPECT_EQ(summary.termination_type, CONVERGENCE);
+ EXPECT_EQ(summary.message,
+ "Function tolerance reached. "
+ "No non-constant parameter blocks found.");
+}
+
+TEST(Solver, TrustRegionProblemIsConstant) {
+ Problem problem;
+ double x = 1;
+ problem.AddResidualBlock(new UnaryIdentityCostFunction, NULL, &x);
+ problem.SetParameterBlockConstant(&x);
+ Solver::Options options;
+ options.minimizer_type = TRUST_REGION;
+ Solver::Summary summary;
+ Solve(options, &problem, &summary);
+ EXPECT_EQ(summary.termination_type, CONVERGENCE);
+ EXPECT_EQ(summary.initial_cost, 1.0 / 2.0);
+ EXPECT_EQ(summary.final_cost, 1.0 / 2.0);
+}
+
+TEST(Solver, LineSearchProblemIsConstant) {
+ Problem problem;
+ double x = 1;
+ problem.AddResidualBlock(new UnaryIdentityCostFunction, NULL, &x);
+ problem.SetParameterBlockConstant(&x);
+ Solver::Options options;
+ options.minimizer_type = LINE_SEARCH;
+ Solver::Summary summary;
+ Solve(options, &problem, &summary);
+ EXPECT_EQ(summary.termination_type, CONVERGENCE);
+ EXPECT_EQ(summary.initial_cost, 1.0 / 2.0);
+ EXPECT_EQ(summary.final_cost, 1.0 / 2.0);
+}
+
+#if defined(CERES_NO_SUITESPARSE)
+TEST(Solver, SparseNormalCholeskyNoSuiteSparse) {
+ Solver::Options options;
+ options.sparse_linear_algebra_library_type = SUITE_SPARSE;
+ options.linear_solver_type = SPARSE_NORMAL_CHOLESKY;
+ string message;
+ EXPECT_FALSE(options.IsValid(&message));
+}
+#endif
+
+#if defined(CERES_NO_CXSPARSE)
+TEST(Solver, SparseNormalCholeskyNoCXSparse) {
+ Solver::Options options;
+ options.sparse_linear_algebra_library_type = CX_SPARSE;
+ options.linear_solver_type = SPARSE_NORMAL_CHOLESKY;
+ string message;
+ EXPECT_FALSE(options.IsValid(&message));
+}
+#endif
+
+TEST(Solver, IterativeLinearSolverForDogleg) {
+ Solver::Options options;
+ options.trust_region_strategy_type = DOGLEG;
+ string message;
+ options.linear_solver_type = ITERATIVE_SCHUR;
+ EXPECT_FALSE(options.IsValid(&message));
+
+ options.linear_solver_type = CGNR;
+ EXPECT_FALSE(options.IsValid(&message));
+}
+
+TEST(Solver, LinearSolverTypeNormalOperation) {
+ Solver::Options options;
+ options.linear_solver_type = DENSE_QR;
+
+ string message;
+ EXPECT_TRUE(options.IsValid(&message));
+
+ options.linear_solver_type = DENSE_NORMAL_CHOLESKY;
+ EXPECT_TRUE(options.IsValid(&message));
+
+ options.linear_solver_type = DENSE_SCHUR;
+ EXPECT_TRUE(options.IsValid(&message));
+
+ options.linear_solver_type = SPARSE_SCHUR;
+#if defined(CERES_NO_SUITESPARSE) && defined(CERES_NO_CXSPARSE)
+ EXPECT_FALSE(options.IsValid(&message));
+#else
+ EXPECT_TRUE(options.IsValid(&message));
+#endif
+
+ options.linear_solver_type = ITERATIVE_SCHUR;
+ EXPECT_TRUE(options.IsValid(&message));
+}
+
+} // namespace internal
+} // namespace ceres