Fix ExpressionRef
- Add missing copy constructor
- Add make functions for input/output expressions
- Add missing ExpressionRef overload for MakeRuntimeConstant
- Add ExpressionRef overloads for common functions (cos,exp,...)
- Make ExpressionRef constructor from double implicit
- Add simple test for using ExpressionRef in Jets.
- Rename expression test macro
Change-Id: I071dc6717e4034a281662021998a91e80636da27
diff --git a/include/ceres/internal/expression_ref.h b/include/ceres/internal/expression_ref.h
index f4b920f..e0386f0 100644
--- a/include/ceres/internal/expression_ref.h
+++ b/include/ceres/internal/expression_ref.h
@@ -50,7 +50,11 @@
// Create a compile time constant expression directly from a double value.
// This is important so that we can write T(3.14) in our code and
// it's automatically converted to the correct expression.
- explicit ExpressionRef(double compile_time_constant);
+ //
+ // This constructor is implicit, because the line
+ // T a(0);
+ // must work for T = Jet<ExpressionRef>.
+ ExpressionRef(double compile_time_constant);
// Create an ASSIGNMENT expression from other to this.
//
@@ -71,6 +75,7 @@
//
// If 'other' is not pointing to a variable (id==invalid), we found an
// uninitialized assignment, which is handled as an error.
+ ExpressionRef(const ExpressionRef& other);
ExpressionRef& operator=(const ExpressionRef& other);
// Compound operators
@@ -96,8 +101,45 @@
ExpressionRef operator/(ExpressionRef x, ExpressionRef y);
// Functions
-// TODO: Add all function supported by Jet.
-ExpressionRef sin(ExpressionRef x);
+
+// Helper function to create a function call expression.
+// Users can generate code for their own custom functions by adding an overload
+// for ExpressionRef that maps to MakeFunctionCall. See below for examples.
+ExpressionRef MakeFunctionCall(const std::string& name,
+ const std::vector<ExpressionRef>& params);
+
+#define CERES_DEFINE_UNARY_FUNCTION_CALL(name) \
+ inline ExpressionRef name(ExpressionRef x) { \
+ return MakeFunctionCall(#name, {x}); \
+ }
+#define CERES_DEFINE_BINARY_FUNCTION_CALL(name) \
+ inline ExpressionRef name(ExpressionRef x, ExpressionRef y) { \
+ return MakeFunctionCall(#name, {x, y}); \
+ }
+CERES_DEFINE_UNARY_FUNCTION_CALL(abs);
+CERES_DEFINE_UNARY_FUNCTION_CALL(acos);
+CERES_DEFINE_UNARY_FUNCTION_CALL(asin);
+CERES_DEFINE_UNARY_FUNCTION_CALL(atan);
+CERES_DEFINE_UNARY_FUNCTION_CALL(cbrt);
+CERES_DEFINE_UNARY_FUNCTION_CALL(ceil);
+CERES_DEFINE_UNARY_FUNCTION_CALL(cos);
+CERES_DEFINE_UNARY_FUNCTION_CALL(cosh);
+CERES_DEFINE_UNARY_FUNCTION_CALL(exp);
+CERES_DEFINE_UNARY_FUNCTION_CALL(exp2);
+CERES_DEFINE_UNARY_FUNCTION_CALL(floor);
+CERES_DEFINE_UNARY_FUNCTION_CALL(log);
+CERES_DEFINE_UNARY_FUNCTION_CALL(log2);
+CERES_DEFINE_UNARY_FUNCTION_CALL(sin);
+CERES_DEFINE_UNARY_FUNCTION_CALL(sinh);
+CERES_DEFINE_UNARY_FUNCTION_CALL(sqrt);
+CERES_DEFINE_UNARY_FUNCTION_CALL(tan);
+CERES_DEFINE_UNARY_FUNCTION_CALL(tanh);
+
+CERES_DEFINE_BINARY_FUNCTION_CALL(atan2);
+CERES_DEFINE_BINARY_FUNCTION_CALL(pow);
+
+#undef CERES_DEFINE_UNARY_FUNCTION_CALL
+#undef CERES_DEFINE_BINARY_FUNCTION_CALL
// This additonal type is required, so that we can detect invalid conditions
// during compile time. For example, the following should create a compile time
@@ -142,13 +184,21 @@
template <typename T>
struct RuntimeConstant {
using ReturnType = T;
- static inline ReturnType Get(double v, const char* name) { return v; }
+ static inline ReturnType Get(double v, const char* /* unused */) { return v; }
+};
+
+template <>
+struct RuntimeConstant<ExpressionRef> {
+ using ReturnType = ExpressionRef;
+ static inline ReturnType Get(double /* unused */, const char* name) {
+ return ExpressionRef::Create(Expression::CreateRuntimeConstant(name));
+ }
};
template <typename G, int N>
struct RuntimeConstant<Jet<G, N>> {
using ReturnType = Jet<G, N>;
- static inline Jet<G, N> Get(double v, const char* name) {
+ static inline Jet<G, N> Get(double v, const char* /* unused */) {
return Jet<G, N>(v);
}
};
@@ -156,11 +206,11 @@
template <int N>
struct RuntimeConstant<Jet<ExpressionRef, N>> {
using ReturnType = Jet<ExpressionRef, N>;
- static inline ReturnType Get(double v, const char* name) {
+ static inline ReturnType Get(double /* unused */, const char* name) {
// Note: The scalar value of v will be thrown away, because we don't need it
// during code generation.
- (void)v;
- return Jet<ExpressionRef, N>(Expression::CreateRuntimeConstant(name));
+ return Jet<ExpressionRef, N>(
+ ExpressionRef::Create(Expression::CreateRuntimeConstant(name)));
}
};
@@ -173,6 +223,13 @@
#define CERES_EXPRESSION_RUNTIME_CONSTANT(_v) \
ceres::internal::MakeRuntimeConstant<T>(_v, #_v)
+inline ExpressionRef MakeParameter(const std::string& name) {
+ return ExpressionRef::Create(Expression::CreateParameter(name));
+}
+inline ExpressionRef MakeOutput(ExpressionRef v, const std::string& name) {
+ return ExpressionRef::Create(Expression::CreateOutputAssignment(v.id, name));
+}
+
// The CERES_CODEGEN macro is defined by the build system only during code
// generation. In all other cases the CERES_IF/ELSE macros just expand to the
// if/else keywords.
diff --git a/internal/ceres/conditional_expressions_test.cc b/internal/ceres/conditional_expressions_test.cc
index 72fe3ee..d7b3979 100644
--- a/internal/ceres/conditional_expressions_test.cc
+++ b/internal/ceres/conditional_expressions_test.cc
@@ -55,7 +55,7 @@
// clang-format off
// Id Type Lhs Value Name Arguments
- TE(0, COMPILE_TIME_CONSTANT, 0, 2, "");
+ CHECK_EXPRESSION(0, COMPILE_TIME_CONSTANT, 0, 2, "",);
// clang-format on
// Variables after execution:
@@ -86,9 +86,9 @@
// clang-format off
// Id, Type, Lhs, Value, Name, Arguments
- TE( 0, COMPILE_TIME_CONSTANT, 0, 2, "", );
- TE( 1, COMPILE_TIME_CONSTANT, 1, 4, "", );
- TE( 2, ASSIGNMENT, 1, 0, "", 0);
+ CHECK_EXPRESSION( 0, COMPILE_TIME_CONSTANT, 0, 2, "", );
+ CHECK_EXPRESSION( 1, COMPILE_TIME_CONSTANT, 1, 4, "", );
+ CHECK_EXPRESSION( 2, ASSIGNMENT, 1, 0, "", 0);
// clang-format on
// Variables after execution:
@@ -123,12 +123,12 @@
// clang-format off
// Id, Type, Lhs, Value, Name, Arguments...
- TE( 0, COMPILE_TIME_CONSTANT, 0, 2, "", );
- TE( 1, COMPILE_TIME_CONSTANT, 1, 3, "", );
- TE( 2, BINARY_COMPARISON, 2, 0, "<", 0, 1);
- TE( 3, IF, -1, 0, "", 2);
- TE( 4, ELSE, -1, 0, "", );
- TE( 5, ENDIF, -1, 0, "", );
+ CHECK_EXPRESSION( 0, COMPILE_TIME_CONSTANT, 0, 2, "", );
+ CHECK_EXPRESSION( 1, COMPILE_TIME_CONSTANT, 1, 3, "", );
+ CHECK_EXPRESSION( 2, BINARY_COMPARISON, 2, 0, "<", 0, 1);
+ CHECK_EXPRESSION( 3, IF, -1, 0, "", 2);
+ CHECK_EXPRESSION( 4, ELSE, -1, 0, "", );
+ CHECK_EXPRESSION( 5, ENDIF, -1, 0, "", );
// clang-format on
}
@@ -162,17 +162,17 @@
// clang-format off
// Id, Type, Lhs, Value, Name, Arguments...
- TE( 0, COMPILE_TIME_CONSTANT, 0, 2, "", );
- TE( 1, COMPILE_TIME_CONSTANT, 1, 3, "", );
- TE( 2, BINARY_COMPARISON, 2, 0, "<", 0, 1);
- TE( 3, IF, -1, 0, "", 2);
- TE( 4, BINARY_ARITHMETIC, 4, 0, "+", 0, 1);
- TE( 5, ELSE, -1, 0, "", );
- TE( 6, BINARY_ARITHMETIC, 6, 0, "-", 0, 1);
- TE( 7, ASSIGNMENT, 4, 0, "", 6 );
- TE( 8, ENDIF, -1, 0, "", );
- TE( 9, BINARY_ARITHMETIC, 9, 0, "+", 4, 0);
- TE( 10, ASSIGNMENT, 4, 0, "", 9 );
+ CHECK_EXPRESSION( 0, COMPILE_TIME_CONSTANT, 0, 2, "", );
+ CHECK_EXPRESSION( 1, COMPILE_TIME_CONSTANT, 1, 3, "", );
+ CHECK_EXPRESSION( 2, BINARY_COMPARISON, 2, 0, "<", 0, 1);
+ CHECK_EXPRESSION( 3, IF, -1, 0, "", 2);
+ CHECK_EXPRESSION( 4, BINARY_ARITHMETIC, 4, 0, "+", 0, 1);
+ CHECK_EXPRESSION( 5, ELSE, -1, 0, "", );
+ CHECK_EXPRESSION( 6, BINARY_ARITHMETIC, 6, 0, "-", 0, 1);
+ CHECK_EXPRESSION( 7, ASSIGNMENT, 4, 0, "", 6 );
+ CHECK_EXPRESSION( 8, ENDIF, -1, 0, "", );
+ CHECK_EXPRESSION( 9, BINARY_ARITHMETIC, 9, 0, "+", 4, 0);
+ CHECK_EXPRESSION( 10, ASSIGNMENT, 4, 0, "", 9 );
// clang-format on
// Variables after execution:
diff --git a/internal/ceres/expression_ref.cc b/internal/ceres/expression_ref.cc
index 224a3bf..7c2bc5d 100644
--- a/internal/ceres/expression_ref.cc
+++ b/internal/ceres/expression_ref.cc
@@ -44,6 +44,8 @@
id = Expression::CreateCompileTimeConstant(compile_time_constant);
}
+ExpressionRef::ExpressionRef(const ExpressionRef& other) { *this = other; }
+
ExpressionRef& ExpressionRef::operator=(const ExpressionRef& other) {
// Assigning an uninitialized variable to another variable is an error.
CHECK(other.IsInitialized()) << "Uninitialized Assignment.";
@@ -111,8 +113,13 @@
}
// Functions
-ExpressionRef sin(ExpressionRef x) {
- return ExpressionRef::Create(Expression::CreateFunctionCall("sin", {x.id}));
+ExpressionRef MakeFunctionCall(const std::string& name,
+ const std::vector<ExpressionRef>& params) {
+ std::vector<ExpressionId> ids;
+ for (auto p : params) {
+ ids.push_back(p.id);
+ }
+ return ExpressionRef::Create(Expression::CreateFunctionCall(name, ids));
}
ExpressionRef Ternary(ComparisonExpressionRef c,
diff --git a/internal/ceres/expression_test.cc b/internal/ceres/expression_test.cc
index d020192..a8a8213 100644
--- a/internal/ceres/expression_test.cc
+++ b/internal/ceres/expression_test.cc
@@ -31,10 +31,8 @@
#define CERES_CODEGEN
-#include "ceres/internal/expression_graph.h"
-#include "ceres/internal/expression_ref.h"
-
-#include "gtest/gtest.h"
+#include "ceres/expression_test.h"
+#include "ceres/jet.h"
namespace ceres {
namespace internal {
@@ -115,6 +113,51 @@
ASSERT_TRUE(graph.ExpressionForId(d.id).DirectlyDependsOn(c.id));
}
+TEST(Expression, Jet) {
+ using T = Jet<ExpressionRef, 1>;
+
+ StartRecordingExpressions();
+
+ T a(2, 0);
+ T b = a * a;
+
+ auto graph = StopRecordingExpressions();
+
+ // b is valid during the assignment so we expect an
+ // additional assignment expression.
+ EXPECT_EQ(graph.Size(), 8);
+
+ // Expected code
+ // v_0 = 2;
+ // v_1 = 0;
+ // v_2 = 1;
+ // v_1 = v_2;
+ // v_3 = v_0 * v_0;
+ // v_4 = v_0 * v_1;
+ // v_5 = v_1 * v_0;
+ // v_6 = v_3 * v_4;
+ // v_7 = v_5 * v_6;
+
+ // clang-format off
+ // Id, Type, Lhs, Value, Name, Arguments
+ CHECK_EXPRESSION( 0, COMPILE_TIME_CONSTANT, 0, 2, "", );
+ CHECK_EXPRESSION( 1, COMPILE_TIME_CONSTANT, 1, 0, "", );
+ CHECK_EXPRESSION( 2, COMPILE_TIME_CONSTANT, 2, 1, "", );
+ CHECK_EXPRESSION( 3, ASSIGNMENT, 1, 0, "", 2 );
+ CHECK_EXPRESSION( 4, BINARY_ARITHMETIC, 4, 0, "*", 0, 0);
+ CHECK_EXPRESSION( 5, BINARY_ARITHMETIC, 5, 0, "*", 0, 1);
+ CHECK_EXPRESSION( 6, BINARY_ARITHMETIC, 6, 0, "*", 1, 0);
+ CHECK_EXPRESSION( 7, BINARY_ARITHMETIC, 7, 0, "+", 5, 6);
+ // clang-format on
+
+ // Variables after execution:
+ //
+ // b.a <=> v_4
+ // b.v[0] <=> v_7
+ EXPECT_EQ(b.a.id, 4);
+ EXPECT_EQ(b.v[0].id, 7);
+}
+
// Todo: remaining functions of Expression
} // namespace internal
diff --git a/internal/ceres/expression_test.h b/internal/ceres/expression_test.h
index de9e2c2..2be7f8e 100644
--- a/internal/ceres/expression_test.h
+++ b/internal/ceres/expression_test.h
@@ -55,12 +55,12 @@
EXPECT_EQ(expr.arguments(), arguments);
}
-#define TE(_id, _type, _lhs_id, _value, _name, ...) \
- TestExpression(graph.ExpressionForId(_id), \
- ExpressionType::_type, \
- _lhs_id, \
- _value, \
- _name, \
+#define CHECK_EXPRESSION(_id, _type, _lhs_id, _value, _name, ...) \
+ TestExpression(graph.ExpressionForId(_id), \
+ ExpressionType::_type, \
+ _lhs_id, \
+ _value, \
+ _name, \
{__VA_ARGS__})
} // namespace internal