Remove MakeFunctionCall() and add test for Ternary
Change-Id: Icf798a939a9868bc66c295ef0867ec075d4860da
diff --git a/include/ceres/internal/expression_ref.h b/include/ceres/internal/expression_ref.h
index a1afd14..570a8d8 100644
--- a/include/ceres/internal/expression_ref.h
+++ b/include/ceres/internal/expression_ref.h
@@ -110,20 +110,15 @@
ExpressionRef operator/(ExpressionRef x, ExpressionRef y);
// Functions
-
-// Helper function to create a function call expression.
-// Users can generate code for their own custom functions by adding an overload
-// for ExpressionRef that maps to MakeFunctionCall. See below for examples.
-ExpressionRef MakeFunctionCall(const std::string& name,
- const std::vector<ExpressionRef>& params);
-
-#define CERES_DEFINE_UNARY_FUNCTION_CALL(name) \
- inline ExpressionRef name(ExpressionRef x) { \
- return MakeFunctionCall(#name, {x}); \
+#define CERES_DEFINE_UNARY_FUNCTION_CALL(name) \
+ inline ExpressionRef name(ExpressionRef x) { \
+ return ExpressionRef::Create( \
+ Expression::CreateFunctionCall(#name, {x.id})); \
}
#define CERES_DEFINE_BINARY_FUNCTION_CALL(name) \
inline ExpressionRef name(ExpressionRef x, ExpressionRef y) { \
- return MakeFunctionCall(#name, {x, y}); \
+ return ExpressionRef::Create( \
+ Expression::CreateFunctionCall(#name, {x.id, y.id})); \
}
CERES_DEFINE_UNARY_FUNCTION_CALL(abs);
CERES_DEFINE_UNARY_FUNCTION_CALL(acos);
diff --git a/internal/ceres/expression_ref.cc b/internal/ceres/expression_ref.cc
index e9cda42..8f7300c 100644
--- a/internal/ceres/expression_ref.cc
+++ b/internal/ceres/expression_ref.cc
@@ -112,20 +112,11 @@
Expression::CreateBinaryArithmetic("*", x.id, y.id));
}
-// Functions
-ExpressionRef MakeFunctionCall(const std::string& name,
- const std::vector<ExpressionRef>& params) {
- std::vector<ExpressionId> ids;
- for (auto p : params) {
- ids.push_back(p.id);
- }
- return ExpressionRef::Create(Expression::CreateFunctionCall(name, ids));
-}
-
ExpressionRef Ternary(ComparisonExpressionRef c,
ExpressionRef a,
ExpressionRef b) {
- return MakeFunctionCall("ternary", {c.id, a.id, b.id});
+ return ExpressionRef::Create(
+ Expression::CreateFunctionCall("Ternary", {c.id, a.id, b.id}));
}
#define CERES_DEFINE_EXPRESSION_COMPARISON_OPERATOR(op) \
diff --git a/internal/ceres/expression_test.cc b/internal/ceres/expression_test.cc
index 7dd5a61..86159d9 100644
--- a/internal/ceres/expression_test.cc
+++ b/internal/ceres/expression_test.cc
@@ -115,5 +115,37 @@
ASSERT_TRUE(graph.ExpressionForId(d.id).DirectlyDependsOn(c.id));
}
+TEST(Expression, Ternary) {
+ using T = ExpressionRef;
+
+ StartRecordingExpressions();
+ T a(2); // 0
+ T b(3); // 1
+ auto c = a < b; // 2
+ T d = Ternary(c, a, b); // 3
+ MakeOutput(d, "result"); // 4
+ auto graph = StopRecordingExpressions();
+
+ EXPECT_EQ(graph.Size(), 5);
+
+ // Expected code
+ // v_0 = 2;
+ // v_1 = 3;
+ // v_2 = v_0 < v_1;
+ // v_3 = Ternary(v_2, v_0, v_1);
+ // result = v_3;
+
+ ExpressionGraph reference;
+ // clang-format off
+ // Id, Type, Lhs, Value, Name, Arguments
+ reference.InsertExpression( 0, ExpressionType::COMPILE_TIME_CONSTANT, 0, {}, "", 2);
+ reference.InsertExpression( 1, ExpressionType::COMPILE_TIME_CONSTANT, 1, {}, "", 3);
+ reference.InsertExpression( 2, ExpressionType::BINARY_COMPARISON, 2, {0,1}, "<", 0);
+ reference.InsertExpression( 3, ExpressionType::FUNCTION_CALL, 3, {2,0,1}, "Ternary", 0);
+ reference.InsertExpression( 4, ExpressionType::OUTPUT_ASSIGNMENT, 4, {3}, "result", 0);
+ // clang-format on
+ EXPECT_EQ(reference, graph);
+}
+
} // namespace internal
} // namespace ceres