Remove MakeFunctionCall() and add test for Ternary Change-Id: Icf798a939a9868bc66c295ef0867ec075d4860da
diff --git a/include/ceres/internal/expression_ref.h b/include/ceres/internal/expression_ref.h index a1afd14..570a8d8 100644 --- a/include/ceres/internal/expression_ref.h +++ b/include/ceres/internal/expression_ref.h
@@ -110,20 +110,15 @@ ExpressionRef operator/(ExpressionRef x, ExpressionRef y); // Functions - -// Helper function to create a function call expression. -// Users can generate code for their own custom functions by adding an overload -// for ExpressionRef that maps to MakeFunctionCall. See below for examples. -ExpressionRef MakeFunctionCall(const std::string& name, - const std::vector<ExpressionRef>& params); - -#define CERES_DEFINE_UNARY_FUNCTION_CALL(name) \ - inline ExpressionRef name(ExpressionRef x) { \ - return MakeFunctionCall(#name, {x}); \ +#define CERES_DEFINE_UNARY_FUNCTION_CALL(name) \ + inline ExpressionRef name(ExpressionRef x) { \ + return ExpressionRef::Create( \ + Expression::CreateFunctionCall(#name, {x.id})); \ } #define CERES_DEFINE_BINARY_FUNCTION_CALL(name) \ inline ExpressionRef name(ExpressionRef x, ExpressionRef y) { \ - return MakeFunctionCall(#name, {x, y}); \ + return ExpressionRef::Create( \ + Expression::CreateFunctionCall(#name, {x.id, y.id})); \ } CERES_DEFINE_UNARY_FUNCTION_CALL(abs); CERES_DEFINE_UNARY_FUNCTION_CALL(acos);
diff --git a/internal/ceres/expression_ref.cc b/internal/ceres/expression_ref.cc index e9cda42..8f7300c 100644 --- a/internal/ceres/expression_ref.cc +++ b/internal/ceres/expression_ref.cc
@@ -112,20 +112,11 @@ Expression::CreateBinaryArithmetic("*", x.id, y.id)); } -// Functions -ExpressionRef MakeFunctionCall(const std::string& name, - const std::vector<ExpressionRef>& params) { - std::vector<ExpressionId> ids; - for (auto p : params) { - ids.push_back(p.id); - } - return ExpressionRef::Create(Expression::CreateFunctionCall(name, ids)); -} - ExpressionRef Ternary(ComparisonExpressionRef c, ExpressionRef a, ExpressionRef b) { - return MakeFunctionCall("ternary", {c.id, a.id, b.id}); + return ExpressionRef::Create( + Expression::CreateFunctionCall("Ternary", {c.id, a.id, b.id})); } #define CERES_DEFINE_EXPRESSION_COMPARISON_OPERATOR(op) \
diff --git a/internal/ceres/expression_test.cc b/internal/ceres/expression_test.cc index 7dd5a61..86159d9 100644 --- a/internal/ceres/expression_test.cc +++ b/internal/ceres/expression_test.cc
@@ -115,5 +115,37 @@ ASSERT_TRUE(graph.ExpressionForId(d.id).DirectlyDependsOn(c.id)); } +TEST(Expression, Ternary) { + using T = ExpressionRef; + + StartRecordingExpressions(); + T a(2); // 0 + T b(3); // 1 + auto c = a < b; // 2 + T d = Ternary(c, a, b); // 3 + MakeOutput(d, "result"); // 4 + auto graph = StopRecordingExpressions(); + + EXPECT_EQ(graph.Size(), 5); + + // Expected code + // v_0 = 2; + // v_1 = 3; + // v_2 = v_0 < v_1; + // v_3 = Ternary(v_2, v_0, v_1); + // result = v_3; + + ExpressionGraph reference; + // clang-format off + // Id, Type, Lhs, Value, Name, Arguments + reference.InsertExpression( 0, ExpressionType::COMPILE_TIME_CONSTANT, 0, {}, "", 2); + reference.InsertExpression( 1, ExpressionType::COMPILE_TIME_CONSTANT, 1, {}, "", 3); + reference.InsertExpression( 2, ExpressionType::BINARY_COMPARISON, 2, {0,1}, "<", 0); + reference.InsertExpression( 3, ExpressionType::FUNCTION_CALL, 3, {2,0,1}, "Ternary", 0); + reference.InsertExpression( 4, ExpressionType::OUTPUT_ASSIGNMENT, 4, {3}, "result", 0); + // clang-format on + EXPECT_EQ(reference, graph); +} + } // namespace internal } // namespace ceres