Fix ExpressionRef - Add missing copy constructor - Add make functions for input/output expressions - Add missing ExpressionRef overload for MakeRuntimeConstant - Add ExpressionRef overloads for common functions (cos,exp,...) - Make ExpressionRef constructor from double implicit - Add simple test for using ExpressionRef in Jets. - Rename expression test macro Change-Id: I071dc6717e4034a281662021998a91e80636da27
diff --git a/include/ceres/internal/expression_ref.h b/include/ceres/internal/expression_ref.h index f4b920f..e0386f0 100644 --- a/include/ceres/internal/expression_ref.h +++ b/include/ceres/internal/expression_ref.h
@@ -50,7 +50,11 @@ // Create a compile time constant expression directly from a double value. // This is important so that we can write T(3.14) in our code and // it's automatically converted to the correct expression. - explicit ExpressionRef(double compile_time_constant); + // + // This constructor is implicit, because the line + // T a(0); + // must work for T = Jet<ExpressionRef>. + ExpressionRef(double compile_time_constant); // Create an ASSIGNMENT expression from other to this. // @@ -71,6 +75,7 @@ // // If 'other' is not pointing to a variable (id==invalid), we found an // uninitialized assignment, which is handled as an error. + ExpressionRef(const ExpressionRef& other); ExpressionRef& operator=(const ExpressionRef& other); // Compound operators @@ -96,8 +101,45 @@ ExpressionRef operator/(ExpressionRef x, ExpressionRef y); // Functions -// TODO: Add all function supported by Jet. -ExpressionRef sin(ExpressionRef x); + +// Helper function to create a function call expression. +// Users can generate code for their own custom functions by adding an overload +// for ExpressionRef that maps to MakeFunctionCall. See below for examples. +ExpressionRef MakeFunctionCall(const std::string& name, + const std::vector<ExpressionRef>& params); + +#define CERES_DEFINE_UNARY_FUNCTION_CALL(name) \ + inline ExpressionRef name(ExpressionRef x) { \ + return MakeFunctionCall(#name, {x}); \ + } +#define CERES_DEFINE_BINARY_FUNCTION_CALL(name) \ + inline ExpressionRef name(ExpressionRef x, ExpressionRef y) { \ + return MakeFunctionCall(#name, {x, y}); \ + } +CERES_DEFINE_UNARY_FUNCTION_CALL(abs); +CERES_DEFINE_UNARY_FUNCTION_CALL(acos); +CERES_DEFINE_UNARY_FUNCTION_CALL(asin); +CERES_DEFINE_UNARY_FUNCTION_CALL(atan); +CERES_DEFINE_UNARY_FUNCTION_CALL(cbrt); +CERES_DEFINE_UNARY_FUNCTION_CALL(ceil); +CERES_DEFINE_UNARY_FUNCTION_CALL(cos); +CERES_DEFINE_UNARY_FUNCTION_CALL(cosh); +CERES_DEFINE_UNARY_FUNCTION_CALL(exp); +CERES_DEFINE_UNARY_FUNCTION_CALL(exp2); +CERES_DEFINE_UNARY_FUNCTION_CALL(floor); +CERES_DEFINE_UNARY_FUNCTION_CALL(log); +CERES_DEFINE_UNARY_FUNCTION_CALL(log2); +CERES_DEFINE_UNARY_FUNCTION_CALL(sin); +CERES_DEFINE_UNARY_FUNCTION_CALL(sinh); +CERES_DEFINE_UNARY_FUNCTION_CALL(sqrt); +CERES_DEFINE_UNARY_FUNCTION_CALL(tan); +CERES_DEFINE_UNARY_FUNCTION_CALL(tanh); + +CERES_DEFINE_BINARY_FUNCTION_CALL(atan2); +CERES_DEFINE_BINARY_FUNCTION_CALL(pow); + +#undef CERES_DEFINE_UNARY_FUNCTION_CALL +#undef CERES_DEFINE_BINARY_FUNCTION_CALL // This additonal type is required, so that we can detect invalid conditions // during compile time. For example, the following should create a compile time @@ -142,13 +184,21 @@ template <typename T> struct RuntimeConstant { using ReturnType = T; - static inline ReturnType Get(double v, const char* name) { return v; } + static inline ReturnType Get(double v, const char* /* unused */) { return v; } +}; + +template <> +struct RuntimeConstant<ExpressionRef> { + using ReturnType = ExpressionRef; + static inline ReturnType Get(double /* unused */, const char* name) { + return ExpressionRef::Create(Expression::CreateRuntimeConstant(name)); + } }; template <typename G, int N> struct RuntimeConstant<Jet<G, N>> { using ReturnType = Jet<G, N>; - static inline Jet<G, N> Get(double v, const char* name) { + static inline Jet<G, N> Get(double v, const char* /* unused */) { return Jet<G, N>(v); } }; @@ -156,11 +206,11 @@ template <int N> struct RuntimeConstant<Jet<ExpressionRef, N>> { using ReturnType = Jet<ExpressionRef, N>; - static inline ReturnType Get(double v, const char* name) { + static inline ReturnType Get(double /* unused */, const char* name) { // Note: The scalar value of v will be thrown away, because we don't need it // during code generation. - (void)v; - return Jet<ExpressionRef, N>(Expression::CreateRuntimeConstant(name)); + return Jet<ExpressionRef, N>( + ExpressionRef::Create(Expression::CreateRuntimeConstant(name))); } }; @@ -173,6 +223,13 @@ #define CERES_EXPRESSION_RUNTIME_CONSTANT(_v) \ ceres::internal::MakeRuntimeConstant<T>(_v, #_v) +inline ExpressionRef MakeParameter(const std::string& name) { + return ExpressionRef::Create(Expression::CreateParameter(name)); +} +inline ExpressionRef MakeOutput(ExpressionRef v, const std::string& name) { + return ExpressionRef::Create(Expression::CreateOutputAssignment(v.id, name)); +} + // The CERES_CODEGEN macro is defined by the build system only during code // generation. In all other cases the CERES_IF/ELSE macros just expand to the // if/else keywords.
diff --git a/internal/ceres/conditional_expressions_test.cc b/internal/ceres/conditional_expressions_test.cc index 72fe3ee..d7b3979 100644 --- a/internal/ceres/conditional_expressions_test.cc +++ b/internal/ceres/conditional_expressions_test.cc
@@ -55,7 +55,7 @@ // clang-format off // Id Type Lhs Value Name Arguments - TE(0, COMPILE_TIME_CONSTANT, 0, 2, ""); + CHECK_EXPRESSION(0, COMPILE_TIME_CONSTANT, 0, 2, "",); // clang-format on // Variables after execution: @@ -86,9 +86,9 @@ // clang-format off // Id, Type, Lhs, Value, Name, Arguments - TE( 0, COMPILE_TIME_CONSTANT, 0, 2, "", ); - TE( 1, COMPILE_TIME_CONSTANT, 1, 4, "", ); - TE( 2, ASSIGNMENT, 1, 0, "", 0); + CHECK_EXPRESSION( 0, COMPILE_TIME_CONSTANT, 0, 2, "", ); + CHECK_EXPRESSION( 1, COMPILE_TIME_CONSTANT, 1, 4, "", ); + CHECK_EXPRESSION( 2, ASSIGNMENT, 1, 0, "", 0); // clang-format on // Variables after execution: @@ -123,12 +123,12 @@ // clang-format off // Id, Type, Lhs, Value, Name, Arguments... - TE( 0, COMPILE_TIME_CONSTANT, 0, 2, "", ); - TE( 1, COMPILE_TIME_CONSTANT, 1, 3, "", ); - TE( 2, BINARY_COMPARISON, 2, 0, "<", 0, 1); - TE( 3, IF, -1, 0, "", 2); - TE( 4, ELSE, -1, 0, "", ); - TE( 5, ENDIF, -1, 0, "", ); + CHECK_EXPRESSION( 0, COMPILE_TIME_CONSTANT, 0, 2, "", ); + CHECK_EXPRESSION( 1, COMPILE_TIME_CONSTANT, 1, 3, "", ); + CHECK_EXPRESSION( 2, BINARY_COMPARISON, 2, 0, "<", 0, 1); + CHECK_EXPRESSION( 3, IF, -1, 0, "", 2); + CHECK_EXPRESSION( 4, ELSE, -1, 0, "", ); + CHECK_EXPRESSION( 5, ENDIF, -1, 0, "", ); // clang-format on } @@ -162,17 +162,17 @@ // clang-format off // Id, Type, Lhs, Value, Name, Arguments... - TE( 0, COMPILE_TIME_CONSTANT, 0, 2, "", ); - TE( 1, COMPILE_TIME_CONSTANT, 1, 3, "", ); - TE( 2, BINARY_COMPARISON, 2, 0, "<", 0, 1); - TE( 3, IF, -1, 0, "", 2); - TE( 4, BINARY_ARITHMETIC, 4, 0, "+", 0, 1); - TE( 5, ELSE, -1, 0, "", ); - TE( 6, BINARY_ARITHMETIC, 6, 0, "-", 0, 1); - TE( 7, ASSIGNMENT, 4, 0, "", 6 ); - TE( 8, ENDIF, -1, 0, "", ); - TE( 9, BINARY_ARITHMETIC, 9, 0, "+", 4, 0); - TE( 10, ASSIGNMENT, 4, 0, "", 9 ); + CHECK_EXPRESSION( 0, COMPILE_TIME_CONSTANT, 0, 2, "", ); + CHECK_EXPRESSION( 1, COMPILE_TIME_CONSTANT, 1, 3, "", ); + CHECK_EXPRESSION( 2, BINARY_COMPARISON, 2, 0, "<", 0, 1); + CHECK_EXPRESSION( 3, IF, -1, 0, "", 2); + CHECK_EXPRESSION( 4, BINARY_ARITHMETIC, 4, 0, "+", 0, 1); + CHECK_EXPRESSION( 5, ELSE, -1, 0, "", ); + CHECK_EXPRESSION( 6, BINARY_ARITHMETIC, 6, 0, "-", 0, 1); + CHECK_EXPRESSION( 7, ASSIGNMENT, 4, 0, "", 6 ); + CHECK_EXPRESSION( 8, ENDIF, -1, 0, "", ); + CHECK_EXPRESSION( 9, BINARY_ARITHMETIC, 9, 0, "+", 4, 0); + CHECK_EXPRESSION( 10, ASSIGNMENT, 4, 0, "", 9 ); // clang-format on // Variables after execution:
diff --git a/internal/ceres/expression_ref.cc b/internal/ceres/expression_ref.cc index 224a3bf..7c2bc5d 100644 --- a/internal/ceres/expression_ref.cc +++ b/internal/ceres/expression_ref.cc
@@ -44,6 +44,8 @@ id = Expression::CreateCompileTimeConstant(compile_time_constant); } +ExpressionRef::ExpressionRef(const ExpressionRef& other) { *this = other; } + ExpressionRef& ExpressionRef::operator=(const ExpressionRef& other) { // Assigning an uninitialized variable to another variable is an error. CHECK(other.IsInitialized()) << "Uninitialized Assignment."; @@ -111,8 +113,13 @@ } // Functions -ExpressionRef sin(ExpressionRef x) { - return ExpressionRef::Create(Expression::CreateFunctionCall("sin", {x.id})); +ExpressionRef MakeFunctionCall(const std::string& name, + const std::vector<ExpressionRef>& params) { + std::vector<ExpressionId> ids; + for (auto p : params) { + ids.push_back(p.id); + } + return ExpressionRef::Create(Expression::CreateFunctionCall(name, ids)); } ExpressionRef Ternary(ComparisonExpressionRef c,
diff --git a/internal/ceres/expression_test.cc b/internal/ceres/expression_test.cc index d020192..a8a8213 100644 --- a/internal/ceres/expression_test.cc +++ b/internal/ceres/expression_test.cc
@@ -31,10 +31,8 @@ #define CERES_CODEGEN -#include "ceres/internal/expression_graph.h" -#include "ceres/internal/expression_ref.h" - -#include "gtest/gtest.h" +#include "ceres/expression_test.h" +#include "ceres/jet.h" namespace ceres { namespace internal { @@ -115,6 +113,51 @@ ASSERT_TRUE(graph.ExpressionForId(d.id).DirectlyDependsOn(c.id)); } +TEST(Expression, Jet) { + using T = Jet<ExpressionRef, 1>; + + StartRecordingExpressions(); + + T a(2, 0); + T b = a * a; + + auto graph = StopRecordingExpressions(); + + // b is valid during the assignment so we expect an + // additional assignment expression. + EXPECT_EQ(graph.Size(), 8); + + // Expected code + // v_0 = 2; + // v_1 = 0; + // v_2 = 1; + // v_1 = v_2; + // v_3 = v_0 * v_0; + // v_4 = v_0 * v_1; + // v_5 = v_1 * v_0; + // v_6 = v_3 * v_4; + // v_7 = v_5 * v_6; + + // clang-format off + // Id, Type, Lhs, Value, Name, Arguments + CHECK_EXPRESSION( 0, COMPILE_TIME_CONSTANT, 0, 2, "", ); + CHECK_EXPRESSION( 1, COMPILE_TIME_CONSTANT, 1, 0, "", ); + CHECK_EXPRESSION( 2, COMPILE_TIME_CONSTANT, 2, 1, "", ); + CHECK_EXPRESSION( 3, ASSIGNMENT, 1, 0, "", 2 ); + CHECK_EXPRESSION( 4, BINARY_ARITHMETIC, 4, 0, "*", 0, 0); + CHECK_EXPRESSION( 5, BINARY_ARITHMETIC, 5, 0, "*", 0, 1); + CHECK_EXPRESSION( 6, BINARY_ARITHMETIC, 6, 0, "*", 1, 0); + CHECK_EXPRESSION( 7, BINARY_ARITHMETIC, 7, 0, "+", 5, 6); + // clang-format on + + // Variables after execution: + // + // b.a <=> v_4 + // b.v[0] <=> v_7 + EXPECT_EQ(b.a.id, 4); + EXPECT_EQ(b.v[0].id, 7); +} + // Todo: remaining functions of Expression } // namespace internal
diff --git a/internal/ceres/expression_test.h b/internal/ceres/expression_test.h index de9e2c2..2be7f8e 100644 --- a/internal/ceres/expression_test.h +++ b/internal/ceres/expression_test.h
@@ -55,12 +55,12 @@ EXPECT_EQ(expr.arguments(), arguments); } -#define TE(_id, _type, _lhs_id, _value, _name, ...) \ - TestExpression(graph.ExpressionForId(_id), \ - ExpressionType::_type, \ - _lhs_id, \ - _value, \ - _name, \ +#define CHECK_EXPRESSION(_id, _type, _lhs_id, _value, _name, ...) \ + TestExpression(graph.ExpressionForId(_id), \ + ExpressionType::_type, \ + _lhs_id, \ + _value, \ + _name, \ {__VA_ARGS__}) } // namespace internal