Update return type in code generator and add tests for logical functions
Change-Id: I128f8ccc21c2c3a2c79674ac28dd9a5b26b58772
diff --git a/include/ceres/codegen/internal/code_generator.h b/include/ceres/codegen/internal/code_generator.h
index 01a1162..2b98146 100644
--- a/include/ceres/codegen/internal/code_generator.h
+++ b/include/ceres/codegen/internal/code_generator.h
@@ -102,9 +102,6 @@
// If the expression does not have a valid name an error is generated.
std::string VariableForExpressionId(ExpressionId id);
- // Returns the type as a string of the left hand side.
- static std::string DataTypeForExpression(ExpressionType type);
-
// Adds one level of indentation. Called when an IF expression is encountered.
void PushIndentation();
diff --git a/include/ceres/codegen/internal/expression.h b/include/ceres/codegen/internal/expression.h
index f75a0aa..7de7a77 100644
--- a/include/ceres/codegen/internal/expression.h
+++ b/include/ceres/codegen/internal/expression.h
@@ -249,6 +249,8 @@
VOID,
};
+std::string ExpressionReturnTypeToString(ExpressionReturnType type);
+
// This class contains all data that is required to generate one line of code.
// Each line has the following form:
//
diff --git a/internal/ceres/code_generator.cc b/internal/ceres/code_generator.cc
index 3743ef6..dd0ab82 100644
--- a/internal/ceres/code_generator.cc
+++ b/internal/ceres/code_generator.cc
@@ -64,7 +64,7 @@
// Example: double v_0;
//
const std::string declaration_string =
- indentation_ + DataTypeForExpression(expr.type()) + " " +
+ indentation_ + ExpressionReturnTypeToString(expr.return_type()) + " " +
VariableForExpressionId(id) + ";";
code.emplace_back(declaration_string);
}
@@ -248,25 +248,6 @@
return options_.variable_prefix + std::to_string(expr.lhs_id());
}
-std::string CodeGenerator::DataTypeForExpression(ExpressionType type) {
- std::string type_string;
- switch (type) {
- case ExpressionType::BINARY_COMPARISON:
- case ExpressionType::LOGICAL_NEGATION:
- type_string = "bool";
- break;
- case ExpressionType::IF:
- case ExpressionType::ELSE:
- case ExpressionType::ENDIF:
- case ExpressionType::NOP:
- type_string = "void";
- break;
- default:
- type_string = "double";
- }
- return type_string;
-}
-
void CodeGenerator::PushIndentation() {
for (int i = 0; i < options_.indentation_spaces_per_level; ++i) {
indentation_.push_back(' ');
diff --git a/internal/ceres/codegen/code_generator_test.cc b/internal/ceres/codegen/code_generator_test.cc
index f974b1a..a73ccac 100644
--- a/internal/ceres/codegen/code_generator_test.cc
+++ b/internal/ceres/codegen/code_generator_test.cc
@@ -376,6 +376,32 @@
GenerateAndCheck(graph, expected_code);
}
+TEST(CodeGenerator, LOGICAL_FUNCTION_CALL) {
+ StartRecordingExpressions();
+ T a = T(1);
+
+ isfinite(a);
+ isinf(a);
+ isnan(a);
+ isnormal(a);
+
+ auto graph = StopRecordingExpressions();
+
+ std::vector<std::string> expected_code = {"{",
+ " double v_0;",
+ " bool v_1;",
+ " bool v_2;",
+ " bool v_3;",
+ " bool v_4;",
+ " v_0 = 1;",
+ " v_1 = isfinite(v_0);",
+ " v_2 = isinf(v_0);",
+ " v_3 = isnan(v_0);",
+ " v_4 = isnormal(v_0);",
+ "}"};
+ GenerateAndCheck(graph, expected_code);
+}
+
TEST(CodeGenerator, IF_SIMPLE) {
StartRecordingExpressions();
T a = T(0);
diff --git a/internal/ceres/codegen/expression_ref_test.cc b/internal/ceres/codegen/expression_ref_test.cc
index 6d6971e..88d0562 100644
--- a/internal/ceres/codegen/expression_ref_test.cc
+++ b/internal/ceres/codegen/expression_ref_test.cc
@@ -33,8 +33,8 @@
//
#define CERES_CODEGEN
-#include "ceres/codegen/internal/expression_graph.h"
#include "ceres/codegen/internal/expression_ref.h"
+#include "ceres/codegen/internal/expression_graph.h"
#include "gtest/gtest.h"
namespace ceres {
@@ -176,7 +176,7 @@
EXPECT_EQ(reference, graph);
}
-TEST(CodeGenerator, BINARY_ARITHMETIC_COMPOUND) {
+TEST(ExpressionRef, BINARY_ARITHMETIC_COMPOUND) {
// For each binary compound arithmetic operation, two lines are generated:
// - The actual operation assigning to a new temporary variable
// - An assignment from the temporary to the lhs
@@ -203,7 +203,7 @@
EXPECT_EQ(reference, graph);
}
-TEST(CodeGenerator, UNARY_ARITHMETIC) {
+TEST(ExpressionRef, UNARY_ARITHMETIC) {
StartRecordingExpressions();
T a = T(1);
T r1 = -a;
@@ -217,7 +217,7 @@
EXPECT_EQ(reference, graph);
}
-TEST(CodeGenerator, BINARY_COMPARISON) {
+TEST(ExpressionRef, BINARY_COMPARISON) {
using BOOL = ComparisonExpressionRef;
StartRecordingExpressions();
T a = T(1);
@@ -242,7 +242,7 @@
EXPECT_EQ(reference, graph);
}
-TEST(CodeGenerator, LOGICAL_OPERATORS) {
+TEST(ExpressionRef, LOGICAL_OPERATORS) {
using BOOL = ComparisonExpressionRef;
// Tests binary logical operators &&, || and the unary logical operator !
StartRecordingExpressions();
@@ -266,7 +266,7 @@
EXPECT_EQ(reference, graph);
}
-TEST(CodeGenerator, FUNCTION_CALL) {
+TEST(ExpressionRef, SCALAR_FUNCTION_CALL) {
StartRecordingExpressions();
T a = T(1);
T b = T(2);
@@ -318,7 +318,25 @@
EXPECT_EQ(reference, graph);
}
-TEST(CodeGenerator, IF) {
+TEST(ExpressionRef, LOGICAL_FUNCTION_CALL) {
+ StartRecordingExpressions();
+ T a = T(1);
+ isfinite(a);
+ isinf(a);
+ isnan(a);
+ isnormal(a);
+ auto graph = StopRecordingExpressions();
+
+ ExpressionGraph reference;
+ reference.InsertBack(Expression::CreateCompileTimeConstant(1));
+ reference.InsertBack(Expression::CreateLogicalFunctionCall("isfinite", {0}));
+ reference.InsertBack(Expression::CreateLogicalFunctionCall("isinf", {0}));
+ reference.InsertBack(Expression::CreateLogicalFunctionCall("isnan", {0}));
+ reference.InsertBack(Expression::CreateLogicalFunctionCall("isnormal", {0}));
+ EXPECT_EQ(reference, graph);
+}
+
+TEST(ExpressionRef, IF) {
StartRecordingExpressions();
T a = T(1);
T b = T(2);
@@ -336,7 +354,7 @@
EXPECT_EQ(reference, graph);
}
-TEST(CodeGenerator, IF_ELSE) {
+TEST(ExpressionRef, IF_ELSE) {
StartRecordingExpressions();
T a = T(1);
T b = T(2);
@@ -356,7 +374,7 @@
EXPECT_EQ(reference, graph);
}
-TEST(CodeGenerator, IF_NESTED) {
+TEST(ExpressionRef, IF_NESTED) {
StartRecordingExpressions();
T a = T(1);
T b = T(2);
diff --git a/internal/ceres/expression.cc b/internal/ceres/expression.cc
index 406899d..0fec5a0 100644
--- a/internal/ceres/expression.cc
+++ b/internal/ceres/expression.cc
@@ -30,10 +30,25 @@
#include "ceres/codegen/internal/expression.h"
#include <algorithm>
+#include "glog/logging.h"
namespace ceres {
namespace internal {
+std::string ExpressionReturnTypeToString(ExpressionReturnType type) {
+ switch (type) {
+ case ExpressionReturnType::SCALAR:
+ return "double";
+ case ExpressionReturnType::BOOLEAN:
+ return "bool";
+ case ExpressionReturnType::VOID:
+ return "void";
+ default:
+ CHECK(false) << "Unknown ExpressionReturnType.";
+ return "";
+ }
+}
+
Expression::Expression(ExpressionType type,
ExpressionReturnType return_type,
ExpressionId lhs_id,