Example code for cubic interpolation. Example code demonstrates how a sampled function can be minimized. Also, in the process uncovered some deficiencies in the CubicInterpolator and BicubicInterpolator interfaces and fixed them. Change-Id: I18c8f670fbee076bf1e94d1f45c7477fd71640e8
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 62df0c0..cd53a1c 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt
@@ -65,10 +65,12 @@ ADD_EXECUTABLE(robust_curve_fitting robust_curve_fitting.cc) TARGET_LINK_LIBRARIES(robust_curve_fitting ceres) -ADD_EXECUTABLE(simple_bundle_adjuster - simple_bundle_adjuster.cc) +ADD_EXECUTABLE(simple_bundle_adjuster simple_bundle_adjuster.cc) TARGET_LINK_LIBRARIES(simple_bundle_adjuster ceres) +ADD_EXECUTABLE(sampled_function sampled_function.cc) +TARGET_LINK_LIBRARIES(sampled_function ceres) + IF (GFLAGS) # The CERES_GFLAGS_NAMESPACE compile definition is NOT stored in # CERES_COMPILE_OPTIONS (and thus config.h) as Ceres itself does not
diff --git a/examples/sampled_function.cc b/examples/sampled_function.cc new file mode 100644 index 0000000..700ef6d --- /dev/null +++ b/examples/sampled_function.cc
@@ -0,0 +1,88 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2014 Google Inc. All rights reserved. +// http://code.google.com/p/ceres-solver/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Author: sameeragarwal@google.com (Sameer Agarwal) +// +// A simple example of optimizing a sampled function by using cubic +// interpolation. + +#include "ceres/ceres.h" +#include "ceres/cubic_interpolation.h" +#include "glog/logging.h" + +using ceres::CubicInterpolator; +using ceres::AutoDiffCostFunction; +using ceres::CostFunction; +using ceres::Problem; +using ceres::Solver; +using ceres::Solve; + +// A simple cost functor that interfaces an interpolated table of +// values with automatic differentiation. +struct InterpolatedCostFunctor { + explicit InterpolatedCostFunctor(const CubicInterpolator& interpolator) + : interpolator_(interpolator) { + } + + template<typename T> bool operator()(const T* x, T* residuals) const { + return interpolator_.Evaluate(*x, residuals); + } + + static CostFunction* Create(const CubicInterpolator& interpolator) { + return new AutoDiffCostFunction<InterpolatedCostFunctor, 1, 1>( + new InterpolatedCostFunctor(interpolator)); + } + + private: + const CubicInterpolator& interpolator_; +}; + +int main(int argc, char** argv) { + google::InitGoogleLogging(argv[0]); + + // Evaluate the function f(x) = (x - 4.5)^2; + const int kNumSamples = 10; + double values[kNumSamples]; + for (int i = 0; i < 10; ++i) { + values[i] = (i - 4.5) * (i - 4.5); + } + + CubicInterpolator interpolator(values, kNumSamples); + + double x = 1.0; + Problem problem; + CostFunction* cost_function = InterpolatedCostFunctor::Create(interpolator); + problem.AddResidualBlock(cost_function, NULL, &x); + + Solver::Options options; + options.minimizer_progress_to_stdout = true; + Solver::Summary summary; + Solve(options, &problem, &summary); + std::cout << summary.BriefReport() << "\n"; + return 0; +}
diff --git a/include/ceres/cubic_interpolation.h b/include/ceres/cubic_interpolation.h index 2ade679..7e477c8 100644 --- a/include/ceres/cubic_interpolation.h +++ b/include/ceres/cubic_interpolation.h
@@ -70,7 +70,13 @@ // derivative. Returns false if x is out of bounds. bool Evaluate(double x, double* f, double* dfdx) const; - // Overload for Jets, which automatically accounts for the chain rule. + // The following two Evaluate overloads are needed for interfacing + // with automatic differentiation. The first is for when a scalar + // evaluation is done, and the second one is for when Jets are used. + bool Evaluate(const double& x, double* f) const { + return Evaluate(x, f, NULL); + } + template<typename JetT> bool Evaluate(const JetT& x, JetT* f) const { double dfdx; if (!Evaluate(x.a, &f->a, &dfdx)) { @@ -117,7 +123,13 @@ bool Evaluate(double r, double c, double* f, double* dfdr, double* dfdc) const; - // Overload for Jets, which automatically accounts for the chain rule. + // The following two Evaluate overloads are needed for interfacing + // with automatic differentiation. The first is for when a scalar + // evaluation is done, and the second one is for when Jets are used. + bool Evaluate(const double& r, const double& c, double* f) const { + return Evaluate(r, c, f, NULL, NULL); + } + template<typename JetT> bool Evaluate(const JetT& r, const JetT& c, JetT* f) const {