Moved AutoDiffCodeGen macros to a separate (public) header
User defined cost functors have to include this new header
instead of internal/expression_ref.h. This hides some
complexity of ExpressionRef and reduces compile time outside
of code generation mode.
This patch also removes the dependency ExpressionRef->Jet.
Change-Id: Ie3f93648775e14881dc5cfab213bbc983c6cfeee
diff --git a/include/ceres/internal/code_generator.h b/include/ceres/codegen/internal/code_generator.h
similarity index 93%
rename from include/ceres/internal/code_generator.h
rename to include/ceres/codegen/internal/code_generator.h
index d629907..01a1162 100644
--- a/include/ceres/internal/code_generator.h
+++ b/include/ceres/codegen/internal/code_generator.h
@@ -28,11 +28,11 @@
//
// Author: darius.rueckert@fau.de (Darius Rueckert)
//
-#ifndef CERES_PUBLIC_CODE_GENERATOR_H_
-#define CERES_PUBLIC_CODE_GENERATOR_H_
+#ifndef CERES_PUBLIC_CODEGEN_INTERNAL_CODE_GENERATOR_H_
+#define CERES_PUBLIC_CODEGEN_INTERNAL_CODE_GENERATOR_H_
-#include "ceres/internal/expression.h"
-#include "ceres/internal/expression_graph.h"
+#include "ceres/codegen/internal/expression.h"
+#include "ceres/codegen/internal/expression_graph.h"
#include <string>
#include <vector>
@@ -120,4 +120,4 @@
} // namespace internal
} // namespace ceres
-#endif // CERES_PUBLIC_CODE_GENERATOR_H_
+#endif // CERES_PUBLIC_CODEGEN_INTERNAL_CODE_GENERATOR_H_
diff --git a/include/ceres/internal/expression.h b/include/ceres/codegen/internal/expression.h
similarity index 98%
rename from include/ceres/internal/expression.h
rename to include/ceres/codegen/internal/expression.h
index 09d4ceb..bf07a21 100644
--- a/include/ceres/internal/expression.h
+++ b/include/ceres/codegen/internal/expression.h
@@ -167,8 +167,8 @@
// expand to the if/else keywords. See expression_ref.h for the exact
// definition.
//
-#ifndef CERES_PUBLIC_EXPRESSION_H_
-#define CERES_PUBLIC_EXPRESSION_H_
+#ifndef CERES_PUBLIC_CODEGEN_INTERNAL_EXPRESSION_H_
+#define CERES_PUBLIC_CODEGEN_INTERNAL_EXPRESSION_H_
#include <string>
#include <vector>
@@ -370,4 +370,5 @@
} // namespace internal
} // namespace ceres
-#endif
+
+#endif // CERES_PUBLIC_CODEGEN_INTERNAL_EXPRESSION_H_
diff --git a/include/ceres/internal/expression_graph.h b/include/ceres/codegen/internal/expression_graph.h
similarity index 96%
rename from include/ceres/internal/expression_graph.h
rename to include/ceres/codegen/internal/expression_graph.h
index 644af0b..22fbba6 100644
--- a/include/ceres/internal/expression_graph.h
+++ b/include/ceres/codegen/internal/expression_graph.h
@@ -28,8 +28,8 @@
//
// Author: darius.rueckert@fau.de (Darius Rueckert)
-#ifndef CERES_PUBLIC_EXPRESSION_TREE_H_
-#define CERES_PUBLIC_EXPRESSION_TREE_H_
+#ifndef CERES_PUBLIC_CODEGEN_INTERNAL_EXPRESSION_GRAPH_H_
+#define CERES_PUBLIC_CODEGEN_INTERNAL_EXPRESSION_GRAPH_H_
#include <vector>
@@ -117,4 +117,5 @@
} // namespace internal
} // namespace ceres
-#endif
+
+#endif // CERES_PUBLIC_CODEGEN_INTERNAL_EXPRESSION_GRAPH_H_
diff --git a/include/ceres/internal/expression_ref.h b/include/ceres/codegen/internal/expression_ref.h
similarity index 85%
rename from include/ceres/internal/expression_ref.h
rename to include/ceres/codegen/internal/expression_ref.h
index b79ccb4..c26a4ab 100644
--- a/include/ceres/internal/expression_ref.h
+++ b/include/ceres/codegen/internal/expression_ref.h
@@ -33,8 +33,8 @@
#define CERES_PUBLIC_EXPRESSION_REF_H_
#include <string>
-#include "ceres/jet.h"
-#include "expression.h"
+#include "ceres/codegen/internal/types.h"
+#include "ceres/codegen/internal/expression.h"
namespace ceres {
namespace internal {
@@ -205,38 +205,13 @@
const ComparisonExpressionRef& y);
ComparisonExpressionRef operator!(const ComparisonExpressionRef& x);
-// This struct is used to mark numbers which are constant over
-// multiple invocations but can differ between instances.
-template <typename T>
-struct InputAssignment {
- using ReturnType = T;
- static inline ReturnType Get(double v, const char* /* unused */) { return v; }
-};
-
template <>
struct InputAssignment<ExpressionRef> {
using ReturnType = ExpressionRef;
static inline ReturnType Get(double /* unused */, const char* name) {
- return ExpressionRef::Create(Expression::CreateInputAssignment(name));
- }
-};
-
-template <typename G, int N>
-struct InputAssignment<Jet<G, N>> {
- using ReturnType = Jet<G, N>;
- static inline Jet<G, N> Get(double v, const char* /* unused */) {
- return Jet<G, N>(v);
- }
-};
-
-template <int N>
-struct InputAssignment<Jet<ExpressionRef, N>> {
- using ReturnType = Jet<ExpressionRef, N>;
- static inline ReturnType Get(double /* unused */, const char* name) {
// Note: The scalar value of v will be thrown away, because we don't need it
// during code generation.
- return Jet<ExpressionRef, N>(
- ExpressionRef::Create(Expression::CreateInputAssignment(name)));
+ return ExpressionRef::Create(Expression::CreateInputAssignment(name));
}
};
@@ -246,13 +221,6 @@
return InputAssignment<T>::Get(v, name);
}
-// This macro should be used for local variables in cost functors. Using local
-// variables directly, will compile their current value into the code.
-// Example:
-// T x = CERES_LOCAL_VARIABLE(observed_x_);
-#define CERES_LOCAL_VARIABLE(_v) \
- ceres::internal::MakeInputAssignment<T>(_v, #_v)
-
inline ExpressionRef MakeParameter(const std::string& name) {
return ExpressionRef::Create(Expression::CreateInputAssignment(name));
}
@@ -261,24 +229,8 @@
return ExpressionRef::Create(Expression::CreateOutputAssignment(v.id, name));
}
-// The CERES_CODEGEN macro is defined by the build system only during code
-// generation. In all other cases the CERES_IF/ELSE macros just expand to the
-// if/else keywords.
-#ifdef CERES_CODEGEN
-#define CERES_IF(condition_) Expression::CreateIf((condition_).id);
-#define CERES_ELSE Expression::CreateElse();
-#define CERES_ENDIF Expression::CreateEndIf();
-#else
-// clang-format off
-#define CERES_IF(condition_) if (condition_) {
-#define CERES_ELSE } else {
-#define CERES_ENDIF }
-// clang-format on
-#endif
-
} // namespace internal
-// See jet.h for more info on this type.
template <>
struct ComparisonReturnType<internal::ExpressionRef> {
using type = internal::ComparisonExpressionRef;
diff --git a/include/ceres/codegen/internal/types.h b/include/ceres/codegen/internal/types.h
new file mode 100644
index 0000000..5d32d5b
--- /dev/null
+++ b/include/ceres/codegen/internal/types.h
@@ -0,0 +1,67 @@
+// Ceres Solver - A fast non-linear least squares minimizer
+// Copyright 2019 Google Inc. All rights reserved.
+// http://code.google.com/p/ceres-solver/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// * Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// * Neither the name of Google Inc. nor the names of its contributors may be
+// used to endorse or promote products derived from this software without
+// specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//
+// Author: darius.rueckert@fau.de (Darius Rueckert)
+//
+#ifndef CERES_PUBLIC_CODEGEN_INTERNAL_TYPES_H_
+#define CERES_PUBLIC_CODEGEN_INTERNAL_TYPES_H_
+
+#include "ceres/codegen/macros.h"
+
+namespace ceres {
+// The return type of a comparison, for example from <, &&, ==.
+//
+// In the context of traditional Ceres Jet operations, this would
+// always be a bool. However, in the autodiff code generation context,
+// the return is always an expression, and so a different type must be
+// used as a return from comparisons.
+//
+// In the autodiff codegen context, this function is overloaded so that 'type'
+// is one of the autodiff code generation expression types.
+template <typename T>
+struct ComparisonReturnType {
+ using type = bool;
+};
+
+namespace internal {
+// The InputAssignment struct is used to implement the CERES_LOCAL_VARIABLE
+// macro defined in macros.h. The input is a double variable and the
+// corresponding name as a string. During execution mode (T==double or
+// T==Jet<double>) the variable name is unused and this function only returns
+// the value. During code generation (T==ExpressionRef) the value is unused and
+// an assignment expression is created. For example:
+// v_0 = observed_x;
+template <typename T>
+struct InputAssignment {
+ static inline T Get(double v, const char* /* unused */) { return v; }
+};
+
+} // namespace internal
+} // namespace ceres
+
+#endif // CERES_PUBLIC_CODEGEN_INTERNAL_TYPES_H_
diff --git a/include/ceres/codegen/macros.h b/include/ceres/codegen/macros.h
new file mode 100644
index 0000000..9750075
--- /dev/null
+++ b/include/ceres/codegen/macros.h
@@ -0,0 +1,117 @@
+// Ceres Solver - A fast non-linear least squares minimizer
+// Copyright 2019 Google Inc. All rights reserved.
+// http://code.google.com/p/ceres-solver/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// * Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// * Neither the name of Google Inc. nor the names of its contributors may be
+// used to endorse or promote products derived from this software without
+// specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//
+// Author: darius.rueckert@fau.de (Darius Rueckert)
+//
+// This file defines the required macros to use local variables and if-else
+// branches in the AutoDiffCodeGen system. This is also the only file that
+// should be included by a user-defined cost functor.
+//
+// To generate code for your cost functor the following steps have to be
+// implemented (see below for a full example):
+//
+// 1. Include this file
+// 2. Wrap accesses to local variables in the CERES_LOCAL_VARIABLE macro.
+// 3. Replace if, else by CERES_IF, CERES_ELSE and add CERES_ENDIF
+// 4. Add a default constructor
+//
+// Example - my_cost_functor.h
+// ========================================================
+// #include "ceres/rotation.h"
+// #include "ceres/autodiff_codegen_macros.h"
+//
+// struct MyReprojectionError {
+// MyReprojectionError(double observed_x, double observed_y)
+// : observed_x(observed_x), observed_y(observed_y) {}
+//
+// // The cost functor must be default constructible!
+// MyReprojectionError() = default;
+//
+// template <typename T>
+// bool operator()(const T* const camera,
+// const T* const point,
+// T* residuals) const {
+// T p[3];
+// AngleAxisRotatePoint(camera, point, p);
+// p[0] += camera[3];
+// p[1] += camera[4];
+// p[2] += camera[5];
+//
+// // The if block is written using the macros!
+// CERES_IF(p[2] < T(0)) {
+// p[0] = -p[0];
+// p[1] = -p[1];
+// p[2] = -p[2];
+// } CERES_ELSE {
+// p[0] += T(1.0);
+// }CERES_ENDIF;
+//
+// const T& focal = camera[6];
+// const T predicted_x = focal * p[0];
+// const T predicted_y = focal * p[1];
+//
+// // The read-access to the local variables observed_x and observed_y are
+// // wrapped in the CERES_LOCAL_VARIABLE macro!
+// residuals[0] = predicted_x - CERES_LOCAL_VARIABLE(T, observed_x);
+// residuals[1] = predicted_y - CERES_LOCAL_VARIABLE(T, observed_y);
+// return true;
+// }
+// double observed_x;
+// double observed_y;
+// };
+//
+// ========================================================
+//
+// This file defines the following macros:
+//
+// CERES_LOCAL_VARIABLE
+// CERES_IF
+// CERES_ELSE
+// CERES_ENDIF
+//
+#ifndef CERES_PUBLIC_CODEGEN_MACROS_H_
+#define CERES_PUBLIC_CODEGEN_MACROS_H_
+
+// The CERES_CODEGEN macro is defined by the build system only during code
+// generation.
+#ifndef CERES_CODEGEN
+#define CERES_LOCAL_VARIABLE(_template_type, _local_variable) (_local_variable)
+#define CERES_IF(condition_) if (condition_)
+#define CERES_ELSE else
+#define CERES_ENDIF
+#else
+#define CERES_LOCAL_VARIABLE(_template_type, _local_variable) \
+ ceres::internal::InputAssignment<_template_type>::Get(_local_variable, \
+ #_local_variable)
+#define CERES_IF(condition_) \
+ ceres::internal::Expression::CreateIf((condition_).id);
+#define CERES_ELSE ceres::internal::Expression::CreateElse();
+#define CERES_ENDIF ceres::internal::Expression::CreateEndIf();
+#endif
+
+#endif // CERES_PUBLIC_CODEGEN_MACROS_H_
diff --git a/include/ceres/jet.h b/include/ceres/jet.h
index 25d11e6..fb7afce 100644
--- a/include/ceres/jet.h
+++ b/include/ceres/jet.h
@@ -164,24 +164,11 @@
#include <string>
#include "Eigen/Core"
+#include "ceres/codegen/internal/types.h"
#include "ceres/internal/port.h"
namespace ceres {
-// The return type of a Jet comparison, for example from <, &&, ==.
-//
-// In the context of traditional Ceres Jet operations, this would
-// always be a bool. However, in the autodiff code generation context,
-// the return is always an expression, and so a different type must be
-// used as a return from comparisons.
-//
-// In the autodiff codegen context, this function is overloaded so that 'type'
-// is one of the autodiff code generation expression types.
-template <typename T>
-struct ComparisonReturnType {
- using type = bool;
-};
-
template <typename T, int N>
struct Jet {
enum { DIMENSION = N };
@@ -894,6 +881,18 @@
return s;
}
+namespace internal {
+// In the context of AutoDiffCodeGen, local variables can be added using the
+// CERES_LOCAL_VARIABLE macro defined in ceres/codegen/macros.h. This partial
+// specialization defined how local variables of type double are converted to
+// Jet<T>.
+template <typename T, int N>
+struct InputAssignment<Jet<T, N>> {
+ static inline Jet<T, N> Get(double v, const char* name) {
+ return Jet<T, N>(InputAssignment<T>::Get(v, name));
+ }
+};
+} // namespace internal
} // namespace ceres
namespace Eigen {
diff --git a/internal/ceres/code_generator.cc b/internal/ceres/code_generator.cc
index 3af4bfb..3743ef6 100644
--- a/internal/ceres/code_generator.cc
+++ b/internal/ceres/code_generator.cc
@@ -28,7 +28,7 @@
//
// Author: darius.rueckert@fau.de (Darius Rueckert)
-#include "ceres/internal/code_generator.h"
+#include "ceres/codegen/internal/code_generator.h"
#include <sstream>
#include "assert.h"
#include "glog/logging.h"
diff --git a/internal/ceres/code_generator_test.cc b/internal/ceres/code_generator_test.cc
index c8908aa..2ac7ee3 100644
--- a/internal/ceres/code_generator_test.cc
+++ b/internal/ceres/code_generator_test.cc
@@ -30,10 +30,9 @@
//
#define CERES_CODEGEN
-#include "ceres/internal/code_generator.h"
-#include "ceres/internal/expression_graph.h"
-#include "ceres/internal/expression_ref.h"
-
+#include "ceres/codegen/internal/code_generator.h"
+#include "ceres/codegen/internal/expression_graph.h"
+#include "ceres/codegen/internal/expression_ref.h"
#include "gtest/gtest.h"
namespace ceres {
@@ -82,7 +81,7 @@
TEST(CodeGenerator, INPUT_ASSIGNMENT) {
double local_variable = 5.0;
StartRecordingExpressions();
- T a = CERES_LOCAL_VARIABLE(local_variable);
+ T a = CERES_LOCAL_VARIABLE(T, local_variable);
T b = MakeParameter("parameters[0][0]");
T c = a + b;
auto graph = StopRecordingExpressions();
diff --git a/internal/ceres/conditional_expressions_test.cc b/internal/ceres/conditional_expressions_test.cc
index c2220b5..3fcb7a3 100644
--- a/internal/ceres/conditional_expressions_test.cc
+++ b/internal/ceres/conditional_expressions_test.cc
@@ -31,9 +31,8 @@
#define CERES_CODEGEN
-#include "ceres/internal/expression_graph.h"
-#include "ceres/internal/expression_ref.h"
-
+#include "ceres/codegen/internal/expression_graph.h"
+#include "ceres/codegen/internal/expression_ref.h"
#include "gtest/gtest.h"
namespace ceres {
diff --git a/internal/ceres/expression.cc b/internal/ceres/expression.cc
index 54450b8..b2471e7 100644
--- a/internal/ceres/expression.cc
+++ b/internal/ceres/expression.cc
@@ -28,10 +28,9 @@
//
// Author: darius.rueckert@fau.de (Darius Rueckert)
-#include "ceres/internal/expression.h"
+#include "ceres/codegen/internal/expression.h"
#include <algorithm>
-
-#include "ceres/internal/expression_graph.h"
+#include "ceres/codegen/internal/expression_graph.h"
#include "glog/logging.h"
namespace ceres {
diff --git a/internal/ceres/expression_graph.cc b/internal/ceres/expression_graph.cc
index 511a98b..6d02c1c 100644
--- a/internal/ceres/expression_graph.cc
+++ b/internal/ceres/expression_graph.cc
@@ -28,7 +28,7 @@
//
// Author: darius.rueckert@fau.de (Darius Rueckert)
-#include "ceres/internal/expression_graph.h"
+#include "ceres/codegen/internal/expression_graph.h"
#include "glog/logging.h"
namespace ceres {
diff --git a/internal/ceres/expression_graph_test.cc b/internal/ceres/expression_graph_test.cc
index faaa10c..23e68ee 100644
--- a/internal/ceres/expression_graph_test.cc
+++ b/internal/ceres/expression_graph_test.cc
@@ -30,8 +30,8 @@
//
// Test expression creation and logic.
-#include "ceres/internal/expression_graph.h"
-#include "ceres/internal/expression_ref.h"
+#include "ceres/codegen/internal/expression_graph.h"
+#include "ceres/codegen/internal/expression_ref.h"
#include "gtest/gtest.h"
diff --git a/internal/ceres/expression_ref.cc b/internal/ceres/expression_ref.cc
index 5a330b9..bb8f920 100644
--- a/internal/ceres/expression_ref.cc
+++ b/internal/ceres/expression_ref.cc
@@ -28,7 +28,7 @@
//
// Author: darius.rueckert@fau.de (Darius Rueckert)
-#include "ceres/internal/expression_ref.h"
+#include "ceres/codegen/internal/expression_ref.h"
#include "glog/logging.h"
namespace ceres {
diff --git a/internal/ceres/expression_test.cc b/internal/ceres/expression_test.cc
index 84d1fa4..25fd7f6 100644
--- a/internal/ceres/expression_test.cc
+++ b/internal/ceres/expression_test.cc
@@ -31,8 +31,8 @@
#define CERES_CODEGEN
-#include "ceres/internal/expression_graph.h"
-#include "ceres/internal/expression_ref.h"
+#include "ceres/codegen/internal/expression_graph.h"
+#include "ceres/codegen/internal/expression_ref.h"
#include "ceres/jet.h"
#include "gtest/gtest.h"