Add Solver::Options::IsValid. This provides a user visible way to validate the Solver::Options before calling Solve. Change-Id: Ife84fd33532ab2ccb7ac95abe22735843db51fde
diff --git a/docs/source/solving.rst b/docs/source/solving.rst index 2ad84fd..32b7fc5 100644 --- a/docs/source/solving.rst +++ b/docs/source/solving.rst
@@ -791,9 +791,14 @@ .. class:: Solver::Options - :class:`Solver::Options` controls the overall behavior of the - solver. We list the various settings and their default values below. + :class:`Solver::Options` controls the overall behavior of the + solver. We list the various settings and their default values below. +.. function:: bool Solver::Options::IsValid(string* error) const + + Validate the values in the options struct and returns true on + success. If there is a problem, the method returns false with + ``error`` containing a textual description of the cause. .. member:: MinimizerType Solver::Options::minimizer_type
diff --git a/docs/source/version_history.rst b/docs/source/version_history.rst index 45298a5..1dcc002 100644 --- a/docs/source/version_history.rst +++ b/docs/source/version_history.rst
@@ -7,6 +7,9 @@ HEAD ==== +#. Added ``Solver::Options::IsValid`` which allows users to validate + their solver configuration before calling ``Solve``. + Backward Incompatible API Changes ---------------------------------
diff --git a/include/ceres/solver.h b/include/ceres/solver.h index 33ffb54..4723c75 100644 --- a/include/ceres/solver.h +++ b/include/ceres/solver.h
@@ -1,5 +1,5 @@ // Ceres Solver - A fast non-linear least squares minimizer -// Copyright 2010, 2011, 2012 Google Inc. All rights reserved. +// Copyright 2014 Google Inc. All rights reserved. // http://code.google.com/p/ceres-solver/ // // Redistribution and use in source and binary forms, with or without @@ -126,6 +126,11 @@ update_state_every_iteration = false; } + // Returns true if the options struct has a valid + // configuration. Returns false otherwise, and fills in *error + // with a message describing the problem. + bool IsValid(string* error) const; + // Minimizer options ---------------------------------------- // Ceres supports the two major families of optimization strategies -
diff --git a/internal/ceres/CMakeLists.txt b/internal/ceres/CMakeLists.txt index a0dea4e..dd75a01 100644 --- a/internal/ceres/CMakeLists.txt +++ b/internal/ceres/CMakeLists.txt
@@ -270,6 +270,7 @@ CERES_TEST(schur_eliminator) CERES_TEST(single_linkage_clustering) CERES_TEST(small_blas) + CERES_TEST(solver) CERES_TEST(solver_impl) # TODO(sameeragarwal): This test should ultimately be made
diff --git a/internal/ceres/solver.cc b/internal/ceres/solver.cc index 7dcae7a..bec2e0c 100644 --- a/internal/ceres/solver.cc +++ b/internal/ceres/solver.cc
@@ -1,5 +1,5 @@ // Ceres Solver - A fast non-linear least squares minimizer -// Copyright 2010, 2011, 2012 Google Inc. All rights reserved. +// Copyright 2014 Google Inc. All rights reserved. // http://code.google.com/p/ceres-solver/ // // Redistribution and use in source and binary forms, with or without @@ -43,6 +43,242 @@ namespace ceres { namespace { +#define OPTION_GT(x, y) \ + if (options.x <= y) { \ + *error = string("Invalid configuration. Violated constraint " \ + "Solver::Options::" #x " > " #y); \ + return false; \ + } + +#define OPTION_GE(x, y) \ + if (options.x < y) { \ + *error = string("Invalid configuration. Violated constraint " \ + "Solver::Options::" #x " >= " #y); \ + return false; \ + } + +#define OPTION_LE(x, y) \ + if (options.x > y) { \ + *error = string("Invalid configuration. Violated constraint " \ + "Solver::Options::" #x " <= " #y); \ + return false; \ + } + +#define OPTION_LT(x, y) \ + if (options.x >= y) { \ + *error = string("Invalid configuration. Violated constraint " \ + "Solver::Options::" #x " < " #y); \ + return false; \ + } + +#define OPTION_LE_OPTION(x, y) \ + if (options.x > options.y) { \ + *error = string("Invalid configuration. Violated constraint " \ + "Solver::Options::" #x " <= " \ + "Solver::Options::" #y); \ + return false; \ + } + +#define OPTION_LT_OPTION(x, y) \ + if (options.x >= options.y) { \ + *error = string("Invalid configuration. Violated constraint " \ + "Solver::Options::" #x " < " \ + "Solver::Options::" #y); \ + return false; \ + } + +bool CommonOptionsAreValid(const Solver::Options& options, string* error) { + OPTION_GE(max_num_iterations, 0); + OPTION_GE(max_solver_time_in_seconds, 0.0); + OPTION_GE(function_tolerance, 0.0); + OPTION_GE(gradient_tolerance, 0.0); + OPTION_GE(parameter_tolerance, 0.0); + OPTION_GT(num_threads, 0); + OPTION_GT(num_linear_solver_threads, 0); + if (options.check_gradients) { + OPTION_GT(gradient_check_relative_precision, 0.0); + OPTION_GT(numeric_derivative_relative_step_size, 0.0); + } + return true; +} + +bool TrustRegionOptionsAreValid(const Solver::Options& options, string* error) { + OPTION_GT(initial_trust_region_radius, 0.0); + OPTION_GT(min_trust_region_radius, 0.0); + OPTION_GT(max_trust_region_radius, 0.0); + OPTION_LE_OPTION(min_trust_region_radius, max_trust_region_radius); + OPTION_LE_OPTION(min_trust_region_radius, initial_trust_region_radius); + OPTION_LE_OPTION(initial_trust_region_radius, max_trust_region_radius); + OPTION_GT(min_relative_decrease, 0.0); + OPTION_GE(min_lm_diagonal, 0.0); + OPTION_GE(max_lm_diagonal, 0.0); + OPTION_LE_OPTION(min_lm_diagonal, max_lm_diagonal); + OPTION_GE(max_num_consecutive_invalid_steps, 0); + OPTION_GT(eta, 0.0); + OPTION_GE(min_linear_solver_iterations, 1); + OPTION_GE(max_linear_solver_iterations, 1); + OPTION_LE_OPTION(min_linear_solver_iterations, max_linear_solver_iterations); + + if (options.use_inner_iterations) { + OPTION_GE(inner_iteration_tolerance, 0.0); + } + + if (options.use_nonmonotonic_steps) { + OPTION_GT(max_consecutive_nonmonotonic_steps, 0); + } + + if (options.preconditioner_type == CLUSTER_JACOBI && + options.sparse_linear_algebra_library_type != SUITE_SPARSE) { + *error = "CLUSTER_JACOBI requires " + "Solver::Options::sparse_linear_algebra_library_type to be " + "SUITE_SPARSE"; + return false; + } + + if (options.preconditioner_type == CLUSTER_TRIDIAGONAL && + options.sparse_linear_algebra_library_type != SUITE_SPARSE) { + *error = "CLUSTER_TRIDIAGONAL requires " + "Solver::Options::sparse_linear_algebra_library_type to be " + "SUITE_SPARSE"; + return false; + } + +#ifdef CERES_NO_LAPACK + if (options.dense_linear_algebra_library_type == LAPACK) { + if (options.type == DENSE_NORMAL_CHOLESKY) { + *error = "Can't use DENSE_NORMAL_CHOLESKY with LAPACK because " + "LAPACK was not enabled when Ceres was built."; + return false; + } + + if (options.type == DENSE_QR) { + *error = "Can't use DENSE_QR with LAPACK because " + "LAPACK was not enabled when Ceres was built."; + return false; + } + + if (options.type == DENSE_SCHUR) { + *error = "Can't use DENSE_SCHUR with LAPACK because " + "LAPACK was not enabled when Ceres was built."; + return false; + } + } +#endif + +#ifdef CERES_NO_SUITESPARSE + if (options.sparse_linear_algebra_library_type == SUITE_SPARSE) { + if (options.type == SPARSE_NORMAL_CHOLESKY) { + *error = "Can't use SPARSE_NORMAL_CHOLESKY with SUITESPARSE because " + "SuiteSparse was not enabled when Ceres was built."; + return false; + } + + if (options.type == SPARSE_SCHUR) { + *error = "Can't use SPARSE_SCHUR with SUITESPARSE because " + "SuiteSparse was not enabled when Ceres was built."; + return false; + } + + if (options.preconditioner_type == CLUSTER_JACOBI) { + *error = "CLUSTER_JACOBI preconditioner not supported. " + "SuiteSparse was not enabled when Ceres was built." + return false; + } + + if (options.preconditioner_type == CLUSTER_TRIDIAGONAL) { + *error = "CLUSTER_TRIDIAGONAL preconditioner not supported. " + "SuiteSparse was not enabled when Ceres was built." + return false; + } + } +#endif + +#ifdef CERES_NO_CXSPARSE + if (options.sparse_linear_algebra_library_type == CX_SPARSE) { + if (options.type == SPARSE_NORMAL_CHOLESKY) { + *error = "Can't use SPARSE_NORMAL_CHOLESKY with CX_SPARSE because " + "CXSparse was not enabled when Ceres was built."; + return false; + } + + if (options.type == SPARSE_SCHUR) { + *error = "Can't use SPARSE_SCHUR with CX_SPARSE because " + "CXSparse was not enabled when Ceres was built."; + return false; + } + } +#endif + + if (options.trust_region_strategy_type == DOGLEG) { + if (options.linear_solver_type == ITERATIVE_SCHUR || + options.linear_solver_type == CGNR) { + *error = "DOGLEG only supports exact factorization based linear " + "solvers. If you want to use an iterative solver please " + "use LEVENBERG_MARQUARDT as the trust_region_strategy_type"; + return false; + } + } + + if (options.trust_region_minimizer_iterations_to_dump.size() > 0 && + options.trust_region_problem_dump_format_type != CONSOLE && + options.trust_region_problem_dump_directory.empty()) { + *error = "Solver::Options::trust_region_problem_dump_directory is empty."; + return false; + } + + return true; +} + +bool LineSearchOptionsAreValid(const Solver::Options& options, string* error) { + OPTION_GT(max_lbfgs_rank, 0); + OPTION_GT(min_line_search_step_size, 0.0); + OPTION_GT(max_line_search_step_contraction, 0.0); + OPTION_LT(max_line_search_step_contraction, 1.0); + OPTION_LT_OPTION(max_line_search_step_contraction, + min_line_search_step_contraction); + OPTION_LE(min_line_search_step_contraction, 1.0); + OPTION_GT(max_num_line_search_step_size_iterations, 0); + OPTION_GT(line_search_sufficient_function_decrease, 0.0); + OPTION_LT_OPTION(line_search_sufficient_function_decrease, + line_search_sufficient_curvature_decrease); + OPTION_LT(line_search_sufficient_curvature_decrease, 1.0); + OPTION_GT(max_line_search_step_expansion, 1.0); + + if ((options.line_search_direction_type == ceres::BFGS || + options.line_search_direction_type == ceres::LBFGS) && + options.line_search_type != ceres::WOLFE) { + *error = + string("Invalid configuration: require line_search_type == " + "ceres::WOLFE when using (L)BFGS to ensure that underlying " + "assumptions are guaranteed to be satisfied."); + return false; + } + + // Warn user if they have requested BISECTION interpolation, but constraints + // on max/min step size change during line search prevent bisection scaling + // from occurring. Warn only, as this is likely a user mistake, but one which + // does not prevent us from continuing. + LOG_IF(WARNING, + (options.line_search_interpolation_type == ceres::BISECTION && + (options.max_line_search_step_contraction > 0.5 || + options.min_line_search_step_contraction < 0.5))) + << "Line search interpolation type is BISECTION, but specified " + << "max_line_search_step_contraction: " + << options.max_line_search_step_contraction << ", and " + << "min_line_search_step_contraction: " + << options.min_line_search_step_contraction + << ", prevent bisection (0.5) scaling, continuing with solve regardless."; + + return true; +} + +#undef OPTION_GT +#undef OPTION_GE +#undef OPTION_LE +#undef OPTION_LT +#undef OPTION_LE_OPTION +#undef OPTION_LT_OPTION + void StringifyOrdering(const vector<int>& ordering, string* report) { if (ordering.size() == 0) { internal::StringAppendF(report, "AUTOMATIC"); @@ -55,7 +291,20 @@ internal::StringAppendF(report, "%d", ordering.back()); } -} // namespace +} // namespace + +bool Solver::Options::IsValid(string* error) const { + if (!CommonOptionsAreValid(*this, error)) { + return false; + } + + if (minimizer_type == TRUST_REGION) { + return TrustRegionOptionsAreValid(*this, error); + } + + CHECK_EQ(minimizer_type, LINE_SEARCH); + return LineSearchOptionsAreValid(*this, error); +} Solver::~Solver() {} @@ -63,8 +312,16 @@ Problem* problem, Solver::Summary* summary) { double start_time_seconds = internal::WallTimeInSeconds(); - internal::ProblemImpl* problem_impl = - CHECK_NOTNULL(problem)->problem_impl_.get(); + CHECK_NOTNULL(problem); + CHECK_NOTNULL(summary); + + *summary = Summary(); + if (!options.IsValid(&summary->message)) { + LOG(ERROR) << "Terminating: " << summary->message; + return; + } + + internal::ProblemImpl* problem_impl = problem->problem_impl_.get(); internal::SolverImpl::Solve(options, problem_impl, summary); summary->total_time_in_seconds = internal::WallTimeInSeconds() - start_time_seconds;
diff --git a/internal/ceres/solver_impl.cc b/internal/ceres/solver_impl.cc index 9b39156..421a4d9 100644 --- a/internal/ceres/solver_impl.cc +++ b/internal/ceres/solver_impl.cc
@@ -1,5 +1,5 @@ // Ceres Solver - A fast non-linear least squares minimizer -// Copyright 2010, 2011, 2012 Google Inc. All rights reserved. +// Copyright 2014 Google Inc. All rights reserved. // http://code.google.com/p/ceres-solver/ // // Redistribution and use in source and binary forms, with or without @@ -61,86 +61,6 @@ namespace ceres { namespace internal { -namespace { - -bool LineSearchOptionsAreValid(const Solver::Options& options, - string* message) { - // Validate values for configuration parameters supplied by user. - if ((options.line_search_direction_type == ceres::BFGS || - options.line_search_direction_type == ceres::LBFGS) && - options.line_search_type != ceres::WOLFE) { - *message = - string("Invalid configuration: require line_search_type == " - "ceres::WOLFE when using (L)BFGS to ensure that underlying " - "assumptions are guaranteed to be satisfied."); - return false; - } - if (options.max_lbfgs_rank <= 0) { - *message = - string("Invalid configuration: require max_lbfgs_rank > 0"); - return false; - } - if (options.min_line_search_step_size <= 0.0) { - *message = - "Invalid configuration: require min_line_search_step_size > 0.0."; - return false; - } - if (options.line_search_sufficient_function_decrease <= 0.0) { - *message = - string("Invalid configuration: require ") + - string("line_search_sufficient_function_decrease > 0.0."); - return false; - } - if (options.max_line_search_step_contraction <= 0.0 || - options.max_line_search_step_contraction >= 1.0) { - *message = string("Invalid configuration: require ") + - string("0.0 < max_line_search_step_contraction < 1.0."); - return false; - } - if (options.min_line_search_step_contraction <= - options.max_line_search_step_contraction || - options.min_line_search_step_contraction > 1.0) { - *message = string("Invalid configuration: require ") + - string("max_line_search_step_contraction < ") + - string("min_line_search_step_contraction <= 1.0."); - return false; - } - // Warn user if they have requested BISECTION interpolation, but constraints - // on max/min step size change during line search prevent bisection scaling - // from occurring. Warn only, as this is likely a user mistake, but one which - // does not prevent us from continuing. - LOG_IF(WARNING, - (options.line_search_interpolation_type == ceres::BISECTION && - (options.max_line_search_step_contraction > 0.5 || - options.min_line_search_step_contraction < 0.5))) - << "Line search interpolation type is BISECTION, but specified " - << "max_line_search_step_contraction: " - << options.max_line_search_step_contraction << ", and " - << "min_line_search_step_contraction: " - << options.min_line_search_step_contraction - << ", prevent bisection (0.5) scaling, continuing with solve regardless."; - if (options.max_num_line_search_step_size_iterations <= 0) { - *message = string("Invalid configuration: require ") + - string("max_num_line_search_step_size_iterations > 0."); - return false; - } - if (options.line_search_sufficient_curvature_decrease <= - options.line_search_sufficient_function_decrease || - options.line_search_sufficient_curvature_decrease > 1.0) { - *message = string("Invalid configuration: require ") + - string("line_search_sufficient_function_decrease < ") + - string("line_search_sufficient_curvature_decrease < 1.0."); - return false; - } - if (options.max_line_search_step_expansion <= 1.0) { - *message = string("Invalid configuration: require ") + - string("max_line_search_step_expansion > 1.0."); - return false; - } - return true; -} - -} // namespace void SolverImpl::TrustRegionMinimize( const Solver::Options& options, @@ -271,7 +191,6 @@ << " residual blocks, " << problem_impl->NumResiduals() << " residuals."; - *CHECK_NOTNULL(summary) = Solver::Summary(); if (options.minimizer_type == TRUST_REGION) { TrustRegionSolve(options, problem_impl, summary); } else { @@ -401,7 +320,7 @@ double post_process_start_time = WallTimeInSeconds(); summary->message = - "Terminating: Function tolerance reached. " + "Function tolerance reached. " "No non-constant parameter blocks found."; summary->termination_type = CONVERGENCE; VLOG_IF(1, options.logging_type != SILENT) << summary->message; @@ -535,11 +454,6 @@ summary->nonlinear_conjugate_gradient_type = original_options.nonlinear_conjugate_gradient_type; - if (!LineSearchOptionsAreValid(original_options, &summary->message)) { - LOG(ERROR) << summary->message; - return; - } - if (original_program->IsBoundsConstrained()) { summary->message = "LINE_SEARCH Minimizer does not support bounds."; LOG(ERROR) << "Terminating: " << summary->message; @@ -568,7 +482,7 @@ summary->num_threads_given = original_options.num_threads; summary->num_threads_used = options.num_threads; - if (original_program->ParameterBlocksAreFinite(&summary->message)) { + if (!original_program->ParameterBlocksAreFinite(&summary->message)) { LOG(ERROR) << "Terminating: " << summary->message; return; } @@ -626,10 +540,12 @@ WallTimeInSeconds() - solver_start_time; summary->message = - "Terminating: Function tolerance reached. " + "Function tolerance reached. " "No non-constant parameter blocks found."; summary->termination_type = CONVERGENCE; VLOG_IF(1, options.logging_type != SILENT) << summary->message; + summary->initial_cost = summary->fixed_cost; + summary->final_cost = summary->fixed_cost; const double post_process_start_time = WallTimeInSeconds(); SetSummaryFinalCost(summary);
diff --git a/internal/ceres/solver_impl_test.cc b/internal/ceres/solver_impl_test.cc index c22ac49..b0005a7 100644 --- a/internal/ceres/solver_impl_test.cc +++ b/internal/ceres/solver_impl_test.cc
@@ -1,5 +1,5 @@ // Ceres Solver - A fast non-linear least squares minimizer -// Copyright 2010, 2011, 2012 Google Inc. All rights reserved. +// Copyright 2014 Google Inc. All rights reserved. // http://code.google.com/p/ceres-solver/ // // Redistribution and use in source and binary forms, with or without @@ -317,207 +317,6 @@ EXPECT_EQ(parameter_blocks[2]->user_state(), &y); } -#if defined(CERES_NO_SUITESPARSE) && defined(CERES_NO_CXSPARSE) -TEST(SolverImpl, CreateLinearSolverNoSuiteSparse) { - Solver::Options options; - options.linear_solver_type = SPARSE_NORMAL_CHOLESKY; - // CreateLinearSolver assumes a non-empty ordering. - options.linear_solver_ordering.reset(new ParameterBlockOrdering); - string message; - EXPECT_FALSE(SolverImpl::CreateLinearSolver(&options, &message)); -} -#endif - -TEST(SolverImpl, CreateLinearSolverNegativeMaxNumIterations) { - Solver::Options options; - options.linear_solver_type = DENSE_QR; - options.max_linear_solver_iterations = -1; - // CreateLinearSolver assumes a non-empty ordering. - options.linear_solver_ordering.reset(new ParameterBlockOrdering); - string message; - EXPECT_EQ(SolverImpl::CreateLinearSolver(&options, &message), - static_cast<LinearSolver*>(NULL)); -} - -TEST(SolverImpl, CreateLinearSolverNegativeMinNumIterations) { - Solver::Options options; - options.linear_solver_type = DENSE_QR; - options.min_linear_solver_iterations = -1; - // CreateLinearSolver assumes a non-empty ordering. - options.linear_solver_ordering.reset(new ParameterBlockOrdering); - string message; - EXPECT_EQ(SolverImpl::CreateLinearSolver(&options, &message), - static_cast<LinearSolver*>(NULL)); -} - -TEST(SolverImpl, CreateLinearSolverMaxLessThanMinIterations) { - Solver::Options options; - options.linear_solver_type = DENSE_QR; - options.min_linear_solver_iterations = 10; - options.max_linear_solver_iterations = 5; - options.linear_solver_ordering.reset(new ParameterBlockOrdering); - string message; - EXPECT_EQ(SolverImpl::CreateLinearSolver(&options, &message), - static_cast<LinearSolver*>(NULL)); -} - -TEST(SolverImpl, CreateLinearSolverDenseSchurMultipleThreads) { - Solver::Options options; - options.linear_solver_type = DENSE_SCHUR; - options.num_linear_solver_threads = 2; - // The Schur type solvers can only be created with the Ordering - // contains at least one elimination group. - options.linear_solver_ordering.reset(new ParameterBlockOrdering); - double x; - double y; - options.linear_solver_ordering->AddElementToGroup(&x, 0); - options.linear_solver_ordering->AddElementToGroup(&y, 0); - - string message; - scoped_ptr<LinearSolver> solver( - SolverImpl::CreateLinearSolver(&options, &message)); - EXPECT_TRUE(solver != NULL); - EXPECT_EQ(options.linear_solver_type, DENSE_SCHUR); - EXPECT_EQ(options.num_linear_solver_threads, 2); -} - -TEST(SolverImpl, CreateIterativeLinearSolverForDogleg) { - Solver::Options options; - options.trust_region_strategy_type = DOGLEG; - // CreateLinearSolver assumes a non-empty ordering. - options.linear_solver_ordering.reset(new ParameterBlockOrdering); - string message; - options.linear_solver_type = ITERATIVE_SCHUR; - EXPECT_EQ(SolverImpl::CreateLinearSolver(&options, &message), - static_cast<LinearSolver*>(NULL)); - - options.linear_solver_type = CGNR; - EXPECT_EQ(SolverImpl::CreateLinearSolver(&options, &message), - static_cast<LinearSolver*>(NULL)); -} - -TEST(SolverImpl, CreateLinearSolverNormalOperation) { - Solver::Options options; - scoped_ptr<LinearSolver> solver; - options.linear_solver_type = DENSE_QR; - // CreateLinearSolver assumes a non-empty ordering. - options.linear_solver_ordering.reset(new ParameterBlockOrdering); - string message; - solver.reset(SolverImpl::CreateLinearSolver(&options, &message)); - EXPECT_EQ(options.linear_solver_type, DENSE_QR); - EXPECT_TRUE(solver.get() != NULL); - - options.linear_solver_type = DENSE_NORMAL_CHOLESKY; - solver.reset(SolverImpl::CreateLinearSolver(&options, &message)); - EXPECT_EQ(options.linear_solver_type, DENSE_NORMAL_CHOLESKY); - EXPECT_TRUE(solver.get() != NULL); - -#ifndef CERES_NO_SUITESPARSE - options.linear_solver_type = SPARSE_NORMAL_CHOLESKY; - options.sparse_linear_algebra_library_type = SUITE_SPARSE; - solver.reset(SolverImpl::CreateLinearSolver(&options, &message)); - EXPECT_EQ(options.linear_solver_type, SPARSE_NORMAL_CHOLESKY); - EXPECT_TRUE(solver.get() != NULL); -#endif - -#ifndef CERES_NO_CXSPARSE - options.linear_solver_type = SPARSE_NORMAL_CHOLESKY; - options.sparse_linear_algebra_library_type = CX_SPARSE; - solver.reset(SolverImpl::CreateLinearSolver(&options, &message)); - EXPECT_EQ(options.linear_solver_type, SPARSE_NORMAL_CHOLESKY); - EXPECT_TRUE(solver.get() != NULL); -#endif - - double x; - double y; - options.linear_solver_ordering->AddElementToGroup(&x, 0); - options.linear_solver_ordering->AddElementToGroup(&y, 0); - - options.linear_solver_type = DENSE_SCHUR; - solver.reset(SolverImpl::CreateLinearSolver(&options, &message)); - EXPECT_EQ(options.linear_solver_type, DENSE_SCHUR); - EXPECT_TRUE(solver.get() != NULL); - - options.linear_solver_type = SPARSE_SCHUR; - solver.reset(SolverImpl::CreateLinearSolver(&options, &message)); - -#if defined(CERES_NO_SUITESPARSE) && defined(CERES_NO_CXSPARSE) - EXPECT_TRUE(SolverImpl::CreateLinearSolver(&options, &message) == NULL); -#else - EXPECT_TRUE(solver.get() != NULL); - EXPECT_EQ(options.linear_solver_type, SPARSE_SCHUR); -#endif - - options.linear_solver_type = ITERATIVE_SCHUR; - solver.reset(SolverImpl::CreateLinearSolver(&options, &message)); - EXPECT_EQ(options.linear_solver_type, ITERATIVE_SCHUR); - EXPECT_TRUE(solver.get() != NULL); -} - -struct QuadraticCostFunction { - template <typename T> bool operator()(const T* const x, - T* residual) const { - residual[0] = T(5.0) - *x; - return true; - } -}; - -struct RememberingCallback : public IterationCallback { - explicit RememberingCallback(double *x) : calls(0), x(x) {} - virtual ~RememberingCallback() {} - virtual CallbackReturnType operator()(const IterationSummary& summary) { - x_values.push_back(*x); - return SOLVER_CONTINUE; - } - int calls; - double *x; - vector<double> x_values; -}; - -TEST(SolverImpl, UpdateStateEveryIterationOption) { - double x = 50.0; - const double original_x = x; - - scoped_ptr<CostFunction> cost_function( - new AutoDiffCostFunction<QuadraticCostFunction, 1, 1>( - new QuadraticCostFunction)); - - Problem::Options problem_options; - problem_options.cost_function_ownership = DO_NOT_TAKE_OWNERSHIP; - ProblemImpl problem(problem_options); - problem.AddResidualBlock(cost_function.get(), NULL, &x); - - Solver::Options options; - options.linear_solver_type = DENSE_QR; - - RememberingCallback callback(&x); - options.callbacks.push_back(&callback); - - Solver::Summary summary; - - int num_iterations; - - // First try: no updating. - SolverImpl::Solve(options, &problem, &summary); - num_iterations = summary.num_successful_steps + - summary.num_unsuccessful_steps; - EXPECT_GT(num_iterations, 1); - for (int i = 0; i < callback.x_values.size(); ++i) { - EXPECT_EQ(50.0, callback.x_values[i]); - } - - // Second try: with updating - x = 50.0; - options.update_state_every_iteration = true; - callback.x_values.clear(); - SolverImpl::Solve(options, &problem, &summary); - num_iterations = summary.num_successful_steps + - summary.num_unsuccessful_steps; - EXPECT_GT(num_iterations, 1); - EXPECT_EQ(original_x, callback.x_values[0]); - EXPECT_NE(original_x, callback.x_values[1]); -} - // The parameters must be in separate blocks so that they can be individually // set constant or not. struct Quadratic4DCostFunction { @@ -579,44 +378,6 @@ EXPECT_TRUE(problem.program().IsValid()); } -TEST(SolverImpl, NoParameterBlocks) { - ProblemImpl problem_impl; - Solver::Options options; - Solver::Summary summary; - SolverImpl::Solve(options, &problem_impl, &summary); - EXPECT_EQ(summary.termination_type, CONVERGENCE); - EXPECT_EQ(summary.message, - "Terminating: Function tolerance reached. " - "No non-constant parameter blocks found."); -} - -TEST(SolverImpl, NoResiduals) { - ProblemImpl problem_impl; - Solver::Options options; - Solver::Summary summary; - double x = 1; - problem_impl.AddParameterBlock(&x, 1); - SolverImpl::Solve(options, &problem_impl, &summary); - EXPECT_EQ(summary.termination_type, CONVERGENCE); - EXPECT_EQ(summary.message, - "Terminating: Function tolerance reached. " - "No non-constant parameter blocks found."); -} - - -TEST(SolverImpl, ProblemIsConstant) { - ProblemImpl problem_impl; - Solver::Options options; - Solver::Summary summary; - double x = 1; - problem_impl.AddResidualBlock(new UnaryIdentityCostFunction, NULL, &x); - problem_impl.SetParameterBlockConstant(&x); - SolverImpl::Solve(options, &problem_impl, &summary); - EXPECT_EQ(summary.termination_type, CONVERGENCE); - EXPECT_EQ(summary.initial_cost, 1.0 / 2.0); - EXPECT_EQ(summary.final_cost, 1.0 / 2.0); -} - TEST(SolverImpl, AlternateLinearSolverForSchurTypeLinearSolver) { Solver::Options options;
diff --git a/internal/ceres/solver_test.cc b/internal/ceres/solver_test.cc new file mode 100644 index 0000000..2a136f7 --- /dev/null +++ b/internal/ceres/solver_test.cc
@@ -0,0 +1,298 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2014 Google Inc. All rights reserved. +// http://code.google.com/p/ceres-solver/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Author: sameeragarwal@google.com (Sameer Agarwal) + +#include "ceres/solver.h" + +#include <limits> +#include <cmath> +#include <vector> +#include "gtest/gtest.h" +#include "ceres/internal/scoped_ptr.h" +#include "ceres/autodiff_cost_function.h" +#include "ceres/sized_cost_function.h" +#include "ceres/problem.h" +#include "ceres/problem_impl.h" + +namespace ceres { +namespace internal { + +TEST(SolverOptions, DefaultTrustRegionOptionsAreValid) { + Solver::Options options; + options.minimizer_type = TRUST_REGION; + string error; + EXPECT_TRUE(options.IsValid(&error)) << error; +} + +TEST(SolverOptions, DefaultLineSearchOptionsAreValid) { + Solver::Options options; + options.minimizer_type = LINE_SEARCH; + string error; + EXPECT_TRUE(options.IsValid(&error)) << error; +} + +struct QuadraticCostFunctor { + template <typename T> bool operator()(const T* const x, + T* residual) const { + residual[0] = T(5.0) - *x; + return true; + } + + static CostFunction* Create() { + return new AutoDiffCostFunction<QuadraticCostFunctor, 1, 1>( + new QuadraticCostFunctor); + } +}; + +struct RememberingCallback : public IterationCallback { + explicit RememberingCallback(double *x) : calls(0), x(x) {} + virtual ~RememberingCallback() {} + virtual CallbackReturnType operator()(const IterationSummary& summary) { + x_values.push_back(*x); + return SOLVER_CONTINUE; + } + int calls; + double *x; + vector<double> x_values; +}; + +TEST(Solver, UpdateStateEveryIterationOption) { + double x = 50.0; + const double original_x = x; + + scoped_ptr<CostFunction> cost_function(QuadraticCostFunctor::Create()); + Problem::Options problem_options; + problem_options.cost_function_ownership = DO_NOT_TAKE_OWNERSHIP; + Problem problem(problem_options); + problem.AddResidualBlock(cost_function.get(), NULL, &x); + + Solver::Options options; + options.linear_solver_type = DENSE_QR; + + RememberingCallback callback(&x); + options.callbacks.push_back(&callback); + + Solver::Summary summary; + + int num_iterations; + + // First try: no updating. + Solve(options, &problem, &summary); + num_iterations = summary.num_successful_steps + + summary.num_unsuccessful_steps; + EXPECT_GT(num_iterations, 1); + for (int i = 0; i < callback.x_values.size(); ++i) { + EXPECT_EQ(50.0, callback.x_values[i]); + } + + // Second try: with updating + x = 50.0; + options.update_state_every_iteration = true; + callback.x_values.clear(); + Solve(options, &problem, &summary); + num_iterations = summary.num_successful_steps + + summary.num_unsuccessful_steps; + EXPECT_GT(num_iterations, 1); + EXPECT_EQ(original_x, callback.x_values[0]); + EXPECT_NE(original_x, callback.x_values[1]); +} + +// The parameters must be in separate blocks so that they can be individually +// set constant or not. +struct Quadratic4DCostFunction { + template <typename T> bool operator()(const T* const x, + const T* const y, + const T* const z, + const T* const w, + T* residual) const { + // A 4-dimension axis-aligned quadratic. + residual[0] = T(10.0) - *x + + T(20.0) - *y + + T(30.0) - *z + + T(40.0) - *w; + return true; + } + + static CostFunction* Create() { + return new AutoDiffCostFunction<Quadratic4DCostFunction, 1, 1, 1, 1, 1>( + new Quadratic4DCostFunction); + } +}; + +// A cost function that simply returns its argument. +class UnaryIdentityCostFunction : public SizedCostFunction<1, 1> { + public: + virtual bool Evaluate(double const* const* parameters, + double* residuals, + double** jacobians) const { + residuals[0] = parameters[0][0]; + if (jacobians != NULL && jacobians[0] != NULL) { + jacobians[0][0] = 1.0; + } + return true; + } +}; + +TEST(Solver, TrustRegionProblemHasNoParameterBlocks) { + Problem problem; + Solver::Options options; + options.minimizer_type = TRUST_REGION; + Solver::Summary summary; + Solve(options, &problem, &summary); + EXPECT_EQ(summary.termination_type, CONVERGENCE); + EXPECT_EQ(summary.message, + "Function tolerance reached. " + "No non-constant parameter blocks found."); +} + +TEST(Solver, LineSearchProblemHasNoParameterBlocks) { + Problem problem; + Solver::Options options; + options.minimizer_type = LINE_SEARCH; + Solver::Summary summary; + Solve(options, &problem, &summary); + EXPECT_EQ(summary.termination_type, CONVERGENCE); + EXPECT_EQ(summary.message, + "Function tolerance reached. " + "No non-constant parameter blocks found."); +} + +TEST(Solver, TrustRegionProblemHasZeroResiduals) { + Problem problem; + double x = 1; + problem.AddParameterBlock(&x, 1); + Solver::Options options; + options.minimizer_type = TRUST_REGION; + Solver::Summary summary; + Solve(options, &problem, &summary); + EXPECT_EQ(summary.termination_type, CONVERGENCE); + EXPECT_EQ(summary.message, + "Function tolerance reached. " + "No non-constant parameter blocks found."); +} + +TEST(Solver, LineSearchProblemHasZeroResiduals) { + Problem problem; + double x = 1; + problem.AddParameterBlock(&x, 1); + Solver::Options options; + options.minimizer_type = LINE_SEARCH; + Solver::Summary summary; + Solve(options, &problem, &summary); + EXPECT_EQ(summary.termination_type, CONVERGENCE); + EXPECT_EQ(summary.message, + "Function tolerance reached. " + "No non-constant parameter blocks found."); +} + +TEST(Solver, TrustRegionProblemIsConstant) { + Problem problem; + double x = 1; + problem.AddResidualBlock(new UnaryIdentityCostFunction, NULL, &x); + problem.SetParameterBlockConstant(&x); + Solver::Options options; + options.minimizer_type = TRUST_REGION; + Solver::Summary summary; + Solve(options, &problem, &summary); + EXPECT_EQ(summary.termination_type, CONVERGENCE); + EXPECT_EQ(summary.initial_cost, 1.0 / 2.0); + EXPECT_EQ(summary.final_cost, 1.0 / 2.0); +} + +TEST(Solver, LineSearchProblemIsConstant) { + Problem problem; + double x = 1; + problem.AddResidualBlock(new UnaryIdentityCostFunction, NULL, &x); + problem.SetParameterBlockConstant(&x); + Solver::Options options; + options.minimizer_type = LINE_SEARCH; + Solver::Summary summary; + Solve(options, &problem, &summary); + EXPECT_EQ(summary.termination_type, CONVERGENCE); + EXPECT_EQ(summary.initial_cost, 1.0 / 2.0); + EXPECT_EQ(summary.final_cost, 1.0 / 2.0); +} + +#if defined(CERES_NO_SUITESPARSE) +TEST(Solver, SparseNormalCholeskyNoSuiteSparse) { + Solver::Options options; + options.sparse_linear_algebra_library_type = SUITE_SPARSE; + options.linear_solver_type = SPARSE_NORMAL_CHOLESKY; + string message; + EXPECT_FALSE(options.IsValid(&message)); +} +#endif + +#if defined(CERES_NO_CXSPARSE) +TEST(Solver, SparseNormalCholeskyNoCXSparse) { + Solver::Options options; + options.sparse_linear_algebra_library_type = CX_SPARSE; + options.linear_solver_type = SPARSE_NORMAL_CHOLESKY; + string message; + EXPECT_FALSE(options.IsValid(&message)); +} +#endif + +TEST(Solver, IterativeLinearSolverForDogleg) { + Solver::Options options; + options.trust_region_strategy_type = DOGLEG; + string message; + options.linear_solver_type = ITERATIVE_SCHUR; + EXPECT_FALSE(options.IsValid(&message)); + + options.linear_solver_type = CGNR; + EXPECT_FALSE(options.IsValid(&message)); +} + +TEST(Solver, LinearSolverTypeNormalOperation) { + Solver::Options options; + options.linear_solver_type = DENSE_QR; + + string message; + EXPECT_TRUE(options.IsValid(&message)); + + options.linear_solver_type = DENSE_NORMAL_CHOLESKY; + EXPECT_TRUE(options.IsValid(&message)); + + options.linear_solver_type = DENSE_SCHUR; + EXPECT_TRUE(options.IsValid(&message)); + + options.linear_solver_type = SPARSE_SCHUR; +#if defined(CERES_NO_SUITESPARSE) && defined(CERES_NO_CXSPARSE) + EXPECT_FALSE(options.IsValid(&message)); +#else + EXPECT_TRUE(options.IsValid(&message)); +#endif + + options.linear_solver_type = ITERATIVE_SCHUR; + EXPECT_TRUE(options.IsValid(&message)); +} + +} // namespace internal +} // namespace ceres