Refactor FunctionSample & LineSearchFunction
1. Move FunctionSample to its own .h/.cc files.
2. Migrate LineSearchFunction::Evaluate to use FunctionSample
for input and output.
Change-Id: I8bfb97e1900d95a4686c9621dda5b584458b45c0
diff --git a/internal/ceres/CMakeLists.txt b/internal/ceres/CMakeLists.txt
index 357fae7..21669d0 100644
--- a/internal/ceres/CMakeLists.txt
+++ b/internal/ceres/CMakeLists.txt
@@ -65,6 +65,7 @@
evaluator.cc
eigensparse.cc
file.cc
+ function_sample.cc
gradient_checker.cc
gradient_checking_cost_function.cc
gradient_problem.cc
diff --git a/internal/ceres/function_sample.cc b/internal/ceres/function_sample.cc
new file mode 100644
index 0000000..01f3136
--- /dev/null
+++ b/internal/ceres/function_sample.cc
@@ -0,0 +1,44 @@
+// Ceres Solver - A fast non-linear least squares minimizer
+// Copyright 2015 Google Inc. All rights reserved.
+// http://ceres-solver.org/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// * Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// * Neither the name of Google Inc. nor the names of its contributors may be
+// used to endorse or promote products derived from this software without
+// specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//
+// Author: sameeragarwal@google.com (Sameer Agarwal)
+
+#include "ceres/function_sample.h"
+#include "ceres/stringprintf.h"
+
+namespace ceres {
+namespace internal {
+
+std::string FunctionSample::ToDebugString() const {
+ return StringPrintf("[x: %.8e, value: %.8e, gradient: %.8e, "
+ "value_is_valid: %d, gradient_is_valid: %d]",
+ x, value, gradient, value_is_valid, gradient_is_valid);
+}
+
+} // namespace internal
+} // namespace ceres
diff --git a/internal/ceres/function_sample.h b/internal/ceres/function_sample.h
new file mode 100644
index 0000000..e4356c6
--- /dev/null
+++ b/internal/ceres/function_sample.h
@@ -0,0 +1,61 @@
+// Ceres Solver - A fast non-linear least squares minimizer
+// Copyright 2017 Google Inc. All rights reserved.
+// http://ceres-solver.org/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// * Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// * Neither the name of Google Inc. nor the names of its contributors may be
+// used to endorse or promote products derived from this software without
+// specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//
+// Author: sameeragarwal@google.com (Sameer Agarwal)
+
+#ifndef CERES_INTERNAL_FUNCTION_SAMPLE_H_
+#define CERES_INTERNAL_FUNCTION_SAMPLE_H_
+
+#include <string>
+
+namespace ceres {
+namespace internal {
+
+// Clients can use this struct to communicate the value of the
+// function and or its gradient at a given point x.
+struct FunctionSample {
+ FunctionSample()
+ : x(0.0),
+ value(0.0),
+ value_is_valid(false),
+ gradient(0.0),
+ gradient_is_valid(false) {
+ }
+ std::string ToDebugString() const;
+
+ double x;
+ double value; // value = f(x)
+ bool value_is_valid;
+ double gradient; // gradient = f'(x)
+ bool gradient_is_valid;
+};
+
+} // namespace internal
+} // namespace ceres
+
+#endif // CERES_INTERNAL_FUNCTION_SAMPLE_H_
diff --git a/internal/ceres/gradient_problem_solver.cc b/internal/ceres/gradient_problem_solver.cc
index ba36dd7..488bff5 100644
--- a/internal/ceres/gradient_problem_solver.cc
+++ b/internal/ceres/gradient_problem_solver.cc
@@ -262,10 +262,10 @@
static_cast<int>(iterations.size()));
StringAppendF(&report, "\nTime (in seconds):\n");
- StringAppendF(&report, "\n Cost evaluation %23.4f(%d)\n",
+ StringAppendF(&report, "\n Cost evaluation %23.4f (%d)\n",
cost_evaluation_time_in_seconds,
num_cost_evaluations);
- StringAppendF(&report, " Gradient evaluation %23.4f(%d)\n",
+ StringAppendF(&report, " Gradient evaluation %23.4f (%d)\n",
gradient_evaluation_time_in_seconds,
num_gradient_evaluations);
StringAppendF(&report, " Polynomial minimization %17.4f\n",
diff --git a/internal/ceres/line_search.cc b/internal/ceres/line_search.cc
index 9cdcb7b..3d946dc 100644
--- a/internal/ceres/line_search.cc
+++ b/internal/ceres/line_search.cc
@@ -33,14 +33,15 @@
#include <iomanip>
#include <iostream> // NOLINT
-#include "glog/logging.h"
#include "ceres/evaluator.h"
-#include "ceres/internal/eigen.h"
#include "ceres/fpclassify.h"
+#include "ceres/function_sample.h"
+#include "ceres/internal/eigen.h"
#include "ceres/map_util.h"
#include "ceres/polynomial.h"
#include "ceres/stringprintf.h"
#include "ceres/wall_time.h"
+#include "glog/logging.h"
namespace ceres {
namespace internal {
@@ -124,27 +125,37 @@
direction_ = direction;
}
-bool LineSearchFunction::Evaluate(double x, double* f, double* g) {
- scaled_direction_ = x * direction_;
+void LineSearchFunction::Evaluate(const double x,
+ const bool evaluate_gradient,
+ FunctionSample* output) {
+ output->x = x;
+ output->value_is_valid = false;
+ output->gradient_is_valid = false;
+
+ scaled_direction_ = output->x * direction_;
if (!evaluator_->Plus(position_.data(),
scaled_direction_.data(),
evaluation_point_.data())) {
- return false;
+ return;
}
- if (g == NULL) {
- return (evaluator_->Evaluate(evaluation_point_.data(),
- f, NULL, NULL, NULL) &&
- IsFinite(*f));
+ const bool eval_status =
+ evaluator_->Evaluate(evaluation_point_.data(),
+ &(output->value),
+ NULL,
+ evaluate_gradient ? gradient_.data() : NULL,
+ NULL);
+
+ if (!eval_status || !IsFinite(output->value)) {
+ return;
}
- if (!evaluator_->Evaluate(evaluation_point_.data(),
- f, NULL, gradient_.data(), NULL)) {
- return false;
+ output->value_is_valid = true;
+ if (evaluate_gradient) {
+ output->gradient = direction_.dot(gradient_);
}
-
- *g = direction_.dot(gradient_);
- return IsFinite(*f) && IsFinite(*g);
+ output->gradient_is_valid = IsFinite(output->gradient);
+ return;
}
double LineSearchFunction::DirectionInfinityNorm() const {
@@ -289,31 +300,22 @@
const FunctionSample initial_position =
ValueAndGradientSample(0.0, initial_cost, initial_gradient);
- FunctionSample previous = ValueAndGradientSample(0.0, 0.0, 0.0);
- previous.value_is_valid = false;
-
- FunctionSample current = ValueAndGradientSample(step_size_estimate, 0.0, 0.0);
- current.value_is_valid = false;
+ const double descent_direction_max_norm = function->DirectionInfinityNorm();
+ FunctionSample previous;
+ FunctionSample current;
// As the Armijo line search algorithm always uses the initial point, for
// which both the function value and derivative are known, when fitting a
// minimizing polynomial, we can fit up to a quadratic without requiring the
// gradient at the current query point.
- const bool interpolation_uses_gradient_at_current_sample =
- options().interpolation_type == CUBIC;
- const double descent_direction_max_norm = function->DirectionInfinityNorm();
+ const bool kEvaluateGradient = options().interpolation_type == CUBIC;
++summary->num_function_evaluations;
- if (interpolation_uses_gradient_at_current_sample) {
+ if (kEvaluateGradient) {
++summary->num_gradient_evaluations;
}
- current.value_is_valid =
- function->Evaluate(current.x,
- ¤t.value,
- interpolation_uses_gradient_at_current_sample
- ? ¤t.gradient : NULL);
- current.gradient_is_valid =
- interpolation_uses_gradient_at_current_sample && current.value_is_valid;
+
+ function->Evaluate(step_size_estimate, kEvaluateGradient, ¤t);
while (!current.value_is_valid ||
current.value > (initial_cost
+ options().sufficient_decrease
@@ -354,19 +356,13 @@
}
previous = current;
- current.x = step_size;
++summary->num_function_evaluations;
- if (interpolation_uses_gradient_at_current_sample) {
+ if (kEvaluateGradient) {
++summary->num_gradient_evaluations;
}
- current.value_is_valid =
- function->Evaluate(current.x,
- ¤t.value,
- interpolation_uses_gradient_at_current_sample
- ? ¤t.gradient : NULL);
- current.gradient_is_valid =
- interpolation_uses_gradient_at_current_sample && current.value_is_valid;
+
+ function->Evaluate(step_size, kEvaluateGradient, ¤t);
}
summary->optimal_step_size = current.x;
@@ -515,8 +511,7 @@
LineSearchFunction* function = options().function;
FunctionSample previous = initial_position;
- FunctionSample current = ValueAndGradientSample(step_size_estimate, 0.0, 0.0);
- current.value_is_valid = false;
+ FunctionSample current;
const double descent_direction_max_norm =
function->DirectionInfinityNorm();
@@ -535,12 +530,8 @@
// issues).
++summary->num_function_evaluations;
++summary->num_gradient_evaluations;
- current.value_is_valid =
- function->Evaluate(current.x,
- ¤t.value,
- ¤t.gradient);
- current.gradient_is_valid = current.value_is_valid;
-
+ const bool kEvaluateGradient = true;
+ function->Evaluate(step_size_estimate, kEvaluateGradient, ¤t);
while (true) {
++summary->num_iterations;
@@ -670,15 +661,9 @@
}
previous = current.value_is_valid ? current : previous;
- current.x = step_size;
-
++summary->num_function_evaluations;
++summary->num_gradient_evaluations;
- current.value_is_valid =
- function->Evaluate(current.x,
- ¤t.value,
- ¤t.gradient);
- current.gradient_is_valid = current.value_is_valid;
+ function->Evaluate(step_size, kEvaluateGradient, ¤t);
}
// Ensure that even if a valid bracket was found, we will only mark a zoom
@@ -799,7 +784,7 @@
const FunctionSample unused_previous;
DCHECK(!unused_previous.value_is_valid);
const double polynomial_minimization_start_time = WallTimeInSeconds();
- solution->x =
+ const double step_size =
this->InterpolatingPolynomialMinimizingStepSize(
options().interpolation_type,
lower_bound_step,
@@ -823,12 +808,9 @@
// to numerical issues).
++summary->num_function_evaluations;
++summary->num_gradient_evaluations;
- solution->value_is_valid =
- function->Evaluate(solution->x,
- &solution->value,
- &solution->gradient);
- solution->gradient_is_valid = solution->value_is_valid;
- if (!solution->value_is_valid) {
+ const bool kEvaluateGradient = true;
+ function->Evaluate(step_size, kEvaluateGradient, solution);
+ if (!solution->value_is_valid || !solution->gradient_is_valid) {
summary->error =
StringPrintf("Line search failed: Wolfe Zoom phase found "
"step_size: %.5e, for which function is invalid, "
diff --git a/internal/ceres/line_search.h b/internal/ceres/line_search.h
index 6a21cbe..b3f03b4 100644
--- a/internal/ceres/line_search.h
+++ b/internal/ceres/line_search.h
@@ -234,6 +234,7 @@
public:
explicit LineSearchFunction(Evaluator* evaluator);
void Init(const Vector& position, const Vector& direction);
+
// Evaluate the line search objective
//
// f(x) = p(position + x * direction)
@@ -241,11 +242,16 @@
// Where, p is the objective function of the general optimization
// problem.
//
- // g is the gradient f'(x) at x.
+ // evaluate_gradient controls whether the gradient will be evaluated
+ // or not.
//
- // f must not be null. The gradient is computed only if g is not null.
- bool Evaluate(double x, double* f, double* g);
+ // On return output->value_is_valid and output->gradient_is_valid
+ // indicate whether output->value and output->gradient can be used
+ // or not.
+ void Evaluate(double x, bool evaluate_gradient, FunctionSample* output);
+
double DirectionInfinityNorm() const;
+
// Resets to now, the start point for the results from TimeStatistics().
void ResetTimeStatistics();
void TimeStatistics(double* cost_evaluation_time_in_seconds,
diff --git a/internal/ceres/line_search_minimizer.cc b/internal/ceres/line_search_minimizer.cc
index ca1bc6c..9516ba2 100644
--- a/internal/ceres/line_search_minimizer.cc
+++ b/internal/ceres/line_search_minimizer.cc
@@ -130,7 +130,7 @@
iteration_summary.linear_solver_iterations = 0;
iteration_summary.step_solver_time_in_seconds = 0;
- // Do initial cost and Jacobian evaluation.
+ // Do initial cost and gradient evaluation.
if (!Evaluate(evaluator, x, ¤t_state, &summary->message)) {
summary->termination_type = FAILURE;
summary->message = "Initial cost and jacobian evaluation failed. "
@@ -142,9 +142,8 @@
summary->initial_cost = current_state.cost + summary->fixed_cost;
iteration_summary.cost = current_state.cost + summary->fixed_cost;
- iteration_summary.gradient_max_norm = current_state.gradient_max_norm;
iteration_summary.gradient_norm = sqrt(current_state.gradient_squared_norm);
-
+ iteration_summary.gradient_max_norm = current_state.gradient_max_norm;
if (iteration_summary.gradient_max_norm <= options.gradient_tolerance) {
summary->message = StringPrintf("Gradient tolerance reached. "
"Gradient max norm: %e <= %e",
diff --git a/internal/ceres/polynomial.cc b/internal/ceres/polynomial.cc
index aef17bb..6462bdd 100644
--- a/internal/ceres/polynomial.cc
+++ b/internal/ceres/polynomial.cc
@@ -36,14 +36,13 @@
#include <vector>
#include "Eigen/Dense"
+#include "ceres/function_sample.h"
#include "ceres/internal/port.h"
-#include "ceres/stringprintf.h"
#include "glog/logging.h"
namespace ceres {
namespace internal {
-using std::string;
using std::vector;
namespace {
@@ -327,12 +326,6 @@
}
}
-string FunctionSample::ToDebugString() const {
- return StringPrintf("[x: %.8e, value: %.8e, gradient: %.8e, "
- "value_is_valid: %d, gradient_is_valid: %d]",
- x, value, gradient, value_is_valid, gradient_is_valid);
-}
-
Vector FindInterpolatingPolynomial(const vector<FunctionSample>& samples) {
const int num_samples = samples.size();
int num_constraints = 0;
diff --git a/internal/ceres/polynomial.h b/internal/ceres/polynomial.h
index 09a64c5..3e09bae 100644
--- a/internal/ceres/polynomial.h
+++ b/internal/ceres/polynomial.h
@@ -32,7 +32,6 @@
#ifndef CERES_INTERNAL_POLYNOMIAL_SOLVER_H_
#define CERES_INTERNAL_POLYNOMIAL_SOLVER_H_
-#include <string>
#include <vector>
#include "ceres/internal/eigen.h"
#include "ceres/internal/port.h"
@@ -40,6 +39,8 @@
namespace ceres {
namespace internal {
+struct FunctionSample;
+
// All polynomials are assumed to be the form
//
// sum_{i=0}^N polynomial(i) x^{N-i}.
@@ -84,27 +85,6 @@
double* optimal_x,
double* optimal_value);
-// Structure for storing sample values of a function.
-//
-// Clients can use this struct to communicate the value of the
-// function and or its gradient at a given point x.
-struct FunctionSample {
- FunctionSample()
- : x(0.0),
- value(0.0),
- value_is_valid(false),
- gradient(0.0),
- gradient_is_valid(false) {
- }
- std::string ToDebugString() const;
-
- double x;
- double value; // value = f(x)
- bool value_is_valid;
- double gradient; // gradient = f'(x)
- bool gradient_is_valid;
-};
-
// Given a set of function value and/or gradient samples, find a
// polynomial whose value and gradients are exactly equal to the ones
// in samples.
diff --git a/internal/ceres/polynomial_test.cc b/internal/ceres/polynomial_test.cc
index d7026ed..00c8534 100644
--- a/internal/ceres/polynomial_test.cc
+++ b/internal/ceres/polynomial_test.cc
@@ -36,6 +36,7 @@
#include <cstddef>
#include <algorithm>
#include "gtest/gtest.h"
+#include "ceres/function_sample.h"
#include "ceres/test_util.h"
namespace ceres {
diff --git a/jni/Android.mk b/jni/Android.mk
index 551cd45..8dbb593 100644
--- a/jni/Android.mk
+++ b/jni/Android.mk
@@ -148,6 +148,7 @@
$(CERES_SRC_PATH)/eigensparse.cc \
$(CERES_SRC_PATH)/evaluator.cc \
$(CERES_SRC_PATH)/file.cc \
+ $(CERES_SRC_PATH)/function_sample.cc \
$(CERES_SRC_PATH)/gradient_checker.cc \
$(CERES_SRC_PATH)/gradient_checking_cost_function.cc \
$(CERES_SRC_PATH)/gradient_problem.cc \