Moved AutoDiffCodeGen macros to a separate (public) header User defined cost functors have to include this new header instead of internal/expression_ref.h. This hides some complexity of ExpressionRef and reduces compile time outside of code generation mode. This patch also removes the dependency ExpressionRef->Jet. Change-Id: Ie3f93648775e14881dc5cfab213bbc983c6cfeee
diff --git a/include/ceres/internal/code_generator.h b/include/ceres/codegen/internal/code_generator.h similarity index 93% rename from include/ceres/internal/code_generator.h rename to include/ceres/codegen/internal/code_generator.h index d629907..01a1162 100644 --- a/include/ceres/internal/code_generator.h +++ b/include/ceres/codegen/internal/code_generator.h
@@ -28,11 +28,11 @@ // // Author: darius.rueckert@fau.de (Darius Rueckert) // -#ifndef CERES_PUBLIC_CODE_GENERATOR_H_ -#define CERES_PUBLIC_CODE_GENERATOR_H_ +#ifndef CERES_PUBLIC_CODEGEN_INTERNAL_CODE_GENERATOR_H_ +#define CERES_PUBLIC_CODEGEN_INTERNAL_CODE_GENERATOR_H_ -#include "ceres/internal/expression.h" -#include "ceres/internal/expression_graph.h" +#include "ceres/codegen/internal/expression.h" +#include "ceres/codegen/internal/expression_graph.h" #include <string> #include <vector> @@ -120,4 +120,4 @@ } // namespace internal } // namespace ceres -#endif // CERES_PUBLIC_CODE_GENERATOR_H_ +#endif // CERES_PUBLIC_CODEGEN_INTERNAL_CODE_GENERATOR_H_
diff --git a/include/ceres/internal/expression.h b/include/ceres/codegen/internal/expression.h similarity index 98% rename from include/ceres/internal/expression.h rename to include/ceres/codegen/internal/expression.h index 09d4ceb..bf07a21 100644 --- a/include/ceres/internal/expression.h +++ b/include/ceres/codegen/internal/expression.h
@@ -167,8 +167,8 @@ // expand to the if/else keywords. See expression_ref.h for the exact // definition. // -#ifndef CERES_PUBLIC_EXPRESSION_H_ -#define CERES_PUBLIC_EXPRESSION_H_ +#ifndef CERES_PUBLIC_CODEGEN_INTERNAL_EXPRESSION_H_ +#define CERES_PUBLIC_CODEGEN_INTERNAL_EXPRESSION_H_ #include <string> #include <vector> @@ -370,4 +370,5 @@ } // namespace internal } // namespace ceres -#endif + +#endif // CERES_PUBLIC_CODEGEN_INTERNAL_EXPRESSION_H_
diff --git a/include/ceres/internal/expression_graph.h b/include/ceres/codegen/internal/expression_graph.h similarity index 96% rename from include/ceres/internal/expression_graph.h rename to include/ceres/codegen/internal/expression_graph.h index 644af0b..22fbba6 100644 --- a/include/ceres/internal/expression_graph.h +++ b/include/ceres/codegen/internal/expression_graph.h
@@ -28,8 +28,8 @@ // // Author: darius.rueckert@fau.de (Darius Rueckert) -#ifndef CERES_PUBLIC_EXPRESSION_TREE_H_ -#define CERES_PUBLIC_EXPRESSION_TREE_H_ +#ifndef CERES_PUBLIC_CODEGEN_INTERNAL_EXPRESSION_GRAPH_H_ +#define CERES_PUBLIC_CODEGEN_INTERNAL_EXPRESSION_GRAPH_H_ #include <vector> @@ -117,4 +117,5 @@ } // namespace internal } // namespace ceres -#endif + +#endif // CERES_PUBLIC_CODEGEN_INTERNAL_EXPRESSION_GRAPH_H_
diff --git a/include/ceres/internal/expression_ref.h b/include/ceres/codegen/internal/expression_ref.h similarity index 85% rename from include/ceres/internal/expression_ref.h rename to include/ceres/codegen/internal/expression_ref.h index b79ccb4..c26a4ab 100644 --- a/include/ceres/internal/expression_ref.h +++ b/include/ceres/codegen/internal/expression_ref.h
@@ -33,8 +33,8 @@ #define CERES_PUBLIC_EXPRESSION_REF_H_ #include <string> -#include "ceres/jet.h" -#include "expression.h" +#include "ceres/codegen/internal/types.h" +#include "ceres/codegen/internal/expression.h" namespace ceres { namespace internal { @@ -205,38 +205,13 @@ const ComparisonExpressionRef& y); ComparisonExpressionRef operator!(const ComparisonExpressionRef& x); -// This struct is used to mark numbers which are constant over -// multiple invocations but can differ between instances. -template <typename T> -struct InputAssignment { - using ReturnType = T; - static inline ReturnType Get(double v, const char* /* unused */) { return v; } -}; - template <> struct InputAssignment<ExpressionRef> { using ReturnType = ExpressionRef; static inline ReturnType Get(double /* unused */, const char* name) { - return ExpressionRef::Create(Expression::CreateInputAssignment(name)); - } -}; - -template <typename G, int N> -struct InputAssignment<Jet<G, N>> { - using ReturnType = Jet<G, N>; - static inline Jet<G, N> Get(double v, const char* /* unused */) { - return Jet<G, N>(v); - } -}; - -template <int N> -struct InputAssignment<Jet<ExpressionRef, N>> { - using ReturnType = Jet<ExpressionRef, N>; - static inline ReturnType Get(double /* unused */, const char* name) { // Note: The scalar value of v will be thrown away, because we don't need it // during code generation. - return Jet<ExpressionRef, N>( - ExpressionRef::Create(Expression::CreateInputAssignment(name))); + return ExpressionRef::Create(Expression::CreateInputAssignment(name)); } }; @@ -246,13 +221,6 @@ return InputAssignment<T>::Get(v, name); } -// This macro should be used for local variables in cost functors. Using local -// variables directly, will compile their current value into the code. -// Example: -// T x = CERES_LOCAL_VARIABLE(observed_x_); -#define CERES_LOCAL_VARIABLE(_v) \ - ceres::internal::MakeInputAssignment<T>(_v, #_v) - inline ExpressionRef MakeParameter(const std::string& name) { return ExpressionRef::Create(Expression::CreateInputAssignment(name)); } @@ -261,24 +229,8 @@ return ExpressionRef::Create(Expression::CreateOutputAssignment(v.id, name)); } -// The CERES_CODEGEN macro is defined by the build system only during code -// generation. In all other cases the CERES_IF/ELSE macros just expand to the -// if/else keywords. -#ifdef CERES_CODEGEN -#define CERES_IF(condition_) Expression::CreateIf((condition_).id); -#define CERES_ELSE Expression::CreateElse(); -#define CERES_ENDIF Expression::CreateEndIf(); -#else -// clang-format off -#define CERES_IF(condition_) if (condition_) { -#define CERES_ELSE } else { -#define CERES_ENDIF } -// clang-format on -#endif - } // namespace internal -// See jet.h for more info on this type. template <> struct ComparisonReturnType<internal::ExpressionRef> { using type = internal::ComparisonExpressionRef;
diff --git a/include/ceres/codegen/internal/types.h b/include/ceres/codegen/internal/types.h new file mode 100644 index 0000000..5d32d5b --- /dev/null +++ b/include/ceres/codegen/internal/types.h
@@ -0,0 +1,67 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2019 Google Inc. All rights reserved. +// http://code.google.com/p/ceres-solver/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Author: darius.rueckert@fau.de (Darius Rueckert) +// +#ifndef CERES_PUBLIC_CODEGEN_INTERNAL_TYPES_H_ +#define CERES_PUBLIC_CODEGEN_INTERNAL_TYPES_H_ + +#include "ceres/codegen/macros.h" + +namespace ceres { +// The return type of a comparison, for example from <, &&, ==. +// +// In the context of traditional Ceres Jet operations, this would +// always be a bool. However, in the autodiff code generation context, +// the return is always an expression, and so a different type must be +// used as a return from comparisons. +// +// In the autodiff codegen context, this function is overloaded so that 'type' +// is one of the autodiff code generation expression types. +template <typename T> +struct ComparisonReturnType { + using type = bool; +}; + +namespace internal { +// The InputAssignment struct is used to implement the CERES_LOCAL_VARIABLE +// macro defined in macros.h. The input is a double variable and the +// corresponding name as a string. During execution mode (T==double or +// T==Jet<double>) the variable name is unused and this function only returns +// the value. During code generation (T==ExpressionRef) the value is unused and +// an assignment expression is created. For example: +// v_0 = observed_x; +template <typename T> +struct InputAssignment { + static inline T Get(double v, const char* /* unused */) { return v; } +}; + +} // namespace internal +} // namespace ceres + +#endif // CERES_PUBLIC_CODEGEN_INTERNAL_TYPES_H_
diff --git a/include/ceres/codegen/macros.h b/include/ceres/codegen/macros.h new file mode 100644 index 0000000..9750075 --- /dev/null +++ b/include/ceres/codegen/macros.h
@@ -0,0 +1,117 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2019 Google Inc. All rights reserved. +// http://code.google.com/p/ceres-solver/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Author: darius.rueckert@fau.de (Darius Rueckert) +// +// This file defines the required macros to use local variables and if-else +// branches in the AutoDiffCodeGen system. This is also the only file that +// should be included by a user-defined cost functor. +// +// To generate code for your cost functor the following steps have to be +// implemented (see below for a full example): +// +// 1. Include this file +// 2. Wrap accesses to local variables in the CERES_LOCAL_VARIABLE macro. +// 3. Replace if, else by CERES_IF, CERES_ELSE and add CERES_ENDIF +// 4. Add a default constructor +// +// Example - my_cost_functor.h +// ======================================================== +// #include "ceres/rotation.h" +// #include "ceres/autodiff_codegen_macros.h" +// +// struct MyReprojectionError { +// MyReprojectionError(double observed_x, double observed_y) +// : observed_x(observed_x), observed_y(observed_y) {} +// +// // The cost functor must be default constructible! +// MyReprojectionError() = default; +// +// template <typename T> +// bool operator()(const T* const camera, +// const T* const point, +// T* residuals) const { +// T p[3]; +// AngleAxisRotatePoint(camera, point, p); +// p[0] += camera[3]; +// p[1] += camera[4]; +// p[2] += camera[5]; +// +// // The if block is written using the macros! +// CERES_IF(p[2] < T(0)) { +// p[0] = -p[0]; +// p[1] = -p[1]; +// p[2] = -p[2]; +// } CERES_ELSE { +// p[0] += T(1.0); +// }CERES_ENDIF; +// +// const T& focal = camera[6]; +// const T predicted_x = focal * p[0]; +// const T predicted_y = focal * p[1]; +// +// // The read-access to the local variables observed_x and observed_y are +// // wrapped in the CERES_LOCAL_VARIABLE macro! +// residuals[0] = predicted_x - CERES_LOCAL_VARIABLE(T, observed_x); +// residuals[1] = predicted_y - CERES_LOCAL_VARIABLE(T, observed_y); +// return true; +// } +// double observed_x; +// double observed_y; +// }; +// +// ======================================================== +// +// This file defines the following macros: +// +// CERES_LOCAL_VARIABLE +// CERES_IF +// CERES_ELSE +// CERES_ENDIF +// +#ifndef CERES_PUBLIC_CODEGEN_MACROS_H_ +#define CERES_PUBLIC_CODEGEN_MACROS_H_ + +// The CERES_CODEGEN macro is defined by the build system only during code +// generation. +#ifndef CERES_CODEGEN +#define CERES_LOCAL_VARIABLE(_template_type, _local_variable) (_local_variable) +#define CERES_IF(condition_) if (condition_) +#define CERES_ELSE else +#define CERES_ENDIF +#else +#define CERES_LOCAL_VARIABLE(_template_type, _local_variable) \ + ceres::internal::InputAssignment<_template_type>::Get(_local_variable, \ + #_local_variable) +#define CERES_IF(condition_) \ + ceres::internal::Expression::CreateIf((condition_).id); +#define CERES_ELSE ceres::internal::Expression::CreateElse(); +#define CERES_ENDIF ceres::internal::Expression::CreateEndIf(); +#endif + +#endif // CERES_PUBLIC_CODEGEN_MACROS_H_
diff --git a/include/ceres/jet.h b/include/ceres/jet.h index 25d11e6..fb7afce 100644 --- a/include/ceres/jet.h +++ b/include/ceres/jet.h
@@ -164,24 +164,11 @@ #include <string> #include "Eigen/Core" +#include "ceres/codegen/internal/types.h" #include "ceres/internal/port.h" namespace ceres { -// The return type of a Jet comparison, for example from <, &&, ==. -// -// In the context of traditional Ceres Jet operations, this would -// always be a bool. However, in the autodiff code generation context, -// the return is always an expression, and so a different type must be -// used as a return from comparisons. -// -// In the autodiff codegen context, this function is overloaded so that 'type' -// is one of the autodiff code generation expression types. -template <typename T> -struct ComparisonReturnType { - using type = bool; -}; - template <typename T, int N> struct Jet { enum { DIMENSION = N }; @@ -894,6 +881,18 @@ return s; } +namespace internal { +// In the context of AutoDiffCodeGen, local variables can be added using the +// CERES_LOCAL_VARIABLE macro defined in ceres/codegen/macros.h. This partial +// specialization defined how local variables of type double are converted to +// Jet<T>. +template <typename T, int N> +struct InputAssignment<Jet<T, N>> { + static inline Jet<T, N> Get(double v, const char* name) { + return Jet<T, N>(InputAssignment<T>::Get(v, name)); + } +}; +} // namespace internal } // namespace ceres namespace Eigen {
diff --git a/internal/ceres/code_generator.cc b/internal/ceres/code_generator.cc index 3af4bfb..3743ef6 100644 --- a/internal/ceres/code_generator.cc +++ b/internal/ceres/code_generator.cc
@@ -28,7 +28,7 @@ // // Author: darius.rueckert@fau.de (Darius Rueckert) -#include "ceres/internal/code_generator.h" +#include "ceres/codegen/internal/code_generator.h" #include <sstream> #include "assert.h" #include "glog/logging.h"
diff --git a/internal/ceres/code_generator_test.cc b/internal/ceres/code_generator_test.cc index c8908aa..2ac7ee3 100644 --- a/internal/ceres/code_generator_test.cc +++ b/internal/ceres/code_generator_test.cc
@@ -30,10 +30,9 @@ // #define CERES_CODEGEN -#include "ceres/internal/code_generator.h" -#include "ceres/internal/expression_graph.h" -#include "ceres/internal/expression_ref.h" - +#include "ceres/codegen/internal/code_generator.h" +#include "ceres/codegen/internal/expression_graph.h" +#include "ceres/codegen/internal/expression_ref.h" #include "gtest/gtest.h" namespace ceres { @@ -82,7 +81,7 @@ TEST(CodeGenerator, INPUT_ASSIGNMENT) { double local_variable = 5.0; StartRecordingExpressions(); - T a = CERES_LOCAL_VARIABLE(local_variable); + T a = CERES_LOCAL_VARIABLE(T, local_variable); T b = MakeParameter("parameters[0][0]"); T c = a + b; auto graph = StopRecordingExpressions();
diff --git a/internal/ceres/conditional_expressions_test.cc b/internal/ceres/conditional_expressions_test.cc index c2220b5..3fcb7a3 100644 --- a/internal/ceres/conditional_expressions_test.cc +++ b/internal/ceres/conditional_expressions_test.cc
@@ -31,9 +31,8 @@ #define CERES_CODEGEN -#include "ceres/internal/expression_graph.h" -#include "ceres/internal/expression_ref.h" - +#include "ceres/codegen/internal/expression_graph.h" +#include "ceres/codegen/internal/expression_ref.h" #include "gtest/gtest.h" namespace ceres {
diff --git a/internal/ceres/expression.cc b/internal/ceres/expression.cc index 54450b8..b2471e7 100644 --- a/internal/ceres/expression.cc +++ b/internal/ceres/expression.cc
@@ -28,10 +28,9 @@ // // Author: darius.rueckert@fau.de (Darius Rueckert) -#include "ceres/internal/expression.h" +#include "ceres/codegen/internal/expression.h" #include <algorithm> - -#include "ceres/internal/expression_graph.h" +#include "ceres/codegen/internal/expression_graph.h" #include "glog/logging.h" namespace ceres {
diff --git a/internal/ceres/expression_graph.cc b/internal/ceres/expression_graph.cc index 511a98b..6d02c1c 100644 --- a/internal/ceres/expression_graph.cc +++ b/internal/ceres/expression_graph.cc
@@ -28,7 +28,7 @@ // // Author: darius.rueckert@fau.de (Darius Rueckert) -#include "ceres/internal/expression_graph.h" +#include "ceres/codegen/internal/expression_graph.h" #include "glog/logging.h" namespace ceres {
diff --git a/internal/ceres/expression_graph_test.cc b/internal/ceres/expression_graph_test.cc index faaa10c..23e68ee 100644 --- a/internal/ceres/expression_graph_test.cc +++ b/internal/ceres/expression_graph_test.cc
@@ -30,8 +30,8 @@ // // Test expression creation and logic. -#include "ceres/internal/expression_graph.h" -#include "ceres/internal/expression_ref.h" +#include "ceres/codegen/internal/expression_graph.h" +#include "ceres/codegen/internal/expression_ref.h" #include "gtest/gtest.h"
diff --git a/internal/ceres/expression_ref.cc b/internal/ceres/expression_ref.cc index 5a330b9..bb8f920 100644 --- a/internal/ceres/expression_ref.cc +++ b/internal/ceres/expression_ref.cc
@@ -28,7 +28,7 @@ // // Author: darius.rueckert@fau.de (Darius Rueckert) -#include "ceres/internal/expression_ref.h" +#include "ceres/codegen/internal/expression_ref.h" #include "glog/logging.h" namespace ceres {
diff --git a/internal/ceres/expression_test.cc b/internal/ceres/expression_test.cc index 84d1fa4..25fd7f6 100644 --- a/internal/ceres/expression_test.cc +++ b/internal/ceres/expression_test.cc
@@ -31,8 +31,8 @@ #define CERES_CODEGEN -#include "ceres/internal/expression_graph.h" -#include "ceres/internal/expression_ref.h" +#include "ceres/codegen/internal/expression_graph.h" +#include "ceres/codegen/internal/expression_ref.h" #include "ceres/jet.h" #include "gtest/gtest.h"