blob: 6cd1b1393f7f20be0fae4700437a6ba81033ebef [file] [log] [blame]
Sameer Agarwal056ba9b2019-01-01 06:24:15 -08001// Ceres Solver - A fast non-linear least squares minimizer
Sergiu Deitsch91773742023-06-10 21:01:25 +02002// Copyright 2024 Google Inc. All rights reserved.
Sameer Agarwal056ba9b2019-01-01 06:24:15 -08003// http://ceres-solver.org/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7//
8// * Redistributions of source code must retain the above copyright notice,
9// this list of conditions and the following disclaimer.
10// * Redistributions in binary form must reproduce the above copyright notice,
11// this list of conditions and the following disclaimer in the documentation
12// and/or other materials provided with the distribution.
13// * Neither the name of Google Inc. nor the names of its contributors may be
14// used to endorse or promote products derived from this software without
15// specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
28//
29// Author: sameeragarwal@google.com (Sameer Agarwal)
30
Sameer Agarwal056ba9b2019-01-01 06:24:15 -080031#ifndef CERES_PUBLIC_AUTODIFF_FIRST_ORDER_FUNCTION_H_
32#define CERES_PUBLIC_AUTODIFF_FIRST_ORDER_FUNCTION_H_
33
34#include <memory>
Sergiu Deitsch91773742023-06-10 21:01:25 +020035#include <type_traits>
Sameer Agarwal056ba9b2019-01-01 06:24:15 -080036
37#include "ceres/first_order_function.h"
38#include "ceres/internal/eigen.h"
39#include "ceres/internal/fixed_array.h"
40#include "ceres/jet.h"
41#include "ceres/types.h"
42
43namespace ceres {
44
45// Create FirstOrderFunctions as needed by the GradientProblem
46// framework, with gradients computed via automatic
47// differentiation. For more information on automatic differentiation,
48// see the Wikipedia article at
49// http://en.wikipedia.org/wiki/Automatic_differentiation
50//
51// To get an auto differentiated function, you must define a class
52// with a templated operator() (a functor) that computes the cost
53// function in terms of the template parameter T. The autodiff
54// framework substitutes appropriate "jet" objects for T in order to
55// compute the derivative when necessary, but this is hidden, and you
56// should write the function as if T were a scalar type (e.g. a
57// double-precision floating point number).
58//
59// The function must write the computed value to the last argument
60// (the only non-const one) and return true to indicate
61// success.
62//
63// For example, consider a scalar error e = x'y - a, where both x and y are
64// two-dimensional column vector parameters, the prime sign indicates
65// transposition, and a is a constant.
66//
67// To write an auto-differentiable FirstOrderFunction for the above model, first
68// define the object
69//
70// class QuadraticCostFunctor {
71// public:
72// explicit QuadraticCostFunctor(double a) : a_(a) {}
73// template <typename T>
74// bool operator()(const T* const xy, T* cost) const {
75// const T* const x = xy;
76// const T* const y = xy + 2;
77// *cost = x[0] * y[0] + x[1] * y[1] - T(a_);
78// return true;
79// }
80//
81// private:
82// double a_;
83// };
84//
85// Note that in the declaration of operator() the input parameters xy come
86// first, and are passed as const pointers to arrays of T. The
87// output is the last parameter.
88//
Johannes Beck25e1cdb2019-03-17 21:35:49 +010089// Then given this class definition, the auto differentiated FirstOrderFunction
90// for it can be constructed as follows.
Sameer Agarwal056ba9b2019-01-01 06:24:15 -080091//
92// FirstOrderFunction* function =
93// new AutoDiffFirstOrderFunction<QuadraticCostFunctor, 4>(
94// new QuadraticCostFunctor(1.0));
95//
96// In the instantiation above, the template parameter following
97// "QuadraticCostFunctor", "4", describes the functor as computing a
98// 1-dimensional output from a four-dimensional vector.
99//
100// WARNING: Since the functor will get instantiated with different types for
101// T, you must convert from other numeric types to T before mixing
102// computations with other variables of type T. In the example above, this is
103// seen where instead of using a_ directly, a_ is wrapped with T(a_).
104
105template <typename FirstOrderFunctor, int kNumParameters>
Sameer Agarwal8fe8ebc2022-02-18 15:51:17 -0800106class AutoDiffFirstOrderFunction final : public FirstOrderFunction {
Sameer Agarwal056ba9b2019-01-01 06:24:15 -0800107 public:
108 // Takes ownership of functor.
109 explicit AutoDiffFirstOrderFunction(FirstOrderFunctor* functor)
Sergiu Deitsch91773742023-06-10 21:01:25 +0200110 : AutoDiffFirstOrderFunction{
111 std::unique_ptr<FirstOrderFunctor>{functor}} {}
112
113 explicit AutoDiffFirstOrderFunction(
114 std::unique_ptr<FirstOrderFunctor> functor)
115 : functor_(std::move(functor)) {
Sameer Agarwal056ba9b2019-01-01 06:24:15 -0800116 static_assert(kNumParameters > 0, "kNumParameters must be positive");
117 }
118
Sergiu Deitsch91773742023-06-10 21:01:25 +0200119 template <class... Args,
120 std::enable_if_t<std::is_constructible_v<FirstOrderFunctor,
121 Args&&...>>* = nullptr>
122 explicit AutoDiffFirstOrderFunction(Args&&... args)
123 : AutoDiffFirstOrderFunction{
124 std::make_unique<FirstOrderFunctor>(std::forward<Args>(args)...)} {}
125
Sameer Agarwale4577dd2019-07-13 11:19:27 +0200126 bool Evaluate(const double* const parameters,
127 double* cost,
128 double* gradient) const override {
Sameer Agarwal056ba9b2019-01-01 06:24:15 -0800129 if (gradient == nullptr) {
130 return (*functor_)(parameters, cost);
131 }
132
Sergiu Deitschc8658c82022-02-20 02:22:17 +0100133 using JetT = Jet<double, kNumParameters>;
Sameer Agarwal056ba9b2019-01-01 06:24:15 -0800134 internal::FixedArray<JetT, (256 * 7) / sizeof(JetT)> x(kNumParameters);
135 for (int i = 0; i < kNumParameters; ++i) {
136 x[i].a = parameters[i];
137 x[i].v.setZero();
138 x[i].v[i] = 1.0;
139 }
140
141 JetT output;
142 output.a = kImpossibleValue;
143 output.v.setConstant(kImpossibleValue);
144
Johannes Beck25e1cdb2019-03-17 21:35:49 +0100145 if (!(*functor_)(x.data(), &output)) {
Sameer Agarwal056ba9b2019-01-01 06:24:15 -0800146 return false;
147 }
148
149 *cost = output.a;
150 VectorRef(gradient, kNumParameters) = output.v;
151 return true;
152 }
153
Sameer Agarwale4577dd2019-07-13 11:19:27 +0200154 int NumParameters() const override { return kNumParameters; }
Sameer Agarwal056ba9b2019-01-01 06:24:15 -0800155
Sameer Agarwal8fe8ebc2022-02-18 15:51:17 -0800156 const FirstOrderFunctor& functor() const { return *functor_; }
Alex Stewartce966902022-02-06 21:14:16 +0000157
Sameer Agarwal056ba9b2019-01-01 06:24:15 -0800158 private:
159 std::unique_ptr<FirstOrderFunctor> functor_;
160};
161
162} // namespace ceres
163
164#endif // CERES_PUBLIC_AUTODIFF_FIRST_ORDER_FUNCTION_H_