Update return type in code generator and add tests for logical functions Change-Id: I128f8ccc21c2c3a2c79674ac28dd9a5b26b58772
diff --git a/include/ceres/codegen/internal/code_generator.h b/include/ceres/codegen/internal/code_generator.h index 01a1162..2b98146 100644 --- a/include/ceres/codegen/internal/code_generator.h +++ b/include/ceres/codegen/internal/code_generator.h
@@ -102,9 +102,6 @@ // If the expression does not have a valid name an error is generated. std::string VariableForExpressionId(ExpressionId id); - // Returns the type as a string of the left hand side. - static std::string DataTypeForExpression(ExpressionType type); - // Adds one level of indentation. Called when an IF expression is encountered. void PushIndentation();
diff --git a/include/ceres/codegen/internal/expression.h b/include/ceres/codegen/internal/expression.h index f75a0aa..7de7a77 100644 --- a/include/ceres/codegen/internal/expression.h +++ b/include/ceres/codegen/internal/expression.h
@@ -249,6 +249,8 @@ VOID, }; +std::string ExpressionReturnTypeToString(ExpressionReturnType type); + // This class contains all data that is required to generate one line of code. // Each line has the following form: //
diff --git a/internal/ceres/code_generator.cc b/internal/ceres/code_generator.cc index 3743ef6..dd0ab82 100644 --- a/internal/ceres/code_generator.cc +++ b/internal/ceres/code_generator.cc
@@ -64,7 +64,7 @@ // Example: double v_0; // const std::string declaration_string = - indentation_ + DataTypeForExpression(expr.type()) + " " + + indentation_ + ExpressionReturnTypeToString(expr.return_type()) + " " + VariableForExpressionId(id) + ";"; code.emplace_back(declaration_string); } @@ -248,25 +248,6 @@ return options_.variable_prefix + std::to_string(expr.lhs_id()); } -std::string CodeGenerator::DataTypeForExpression(ExpressionType type) { - std::string type_string; - switch (type) { - case ExpressionType::BINARY_COMPARISON: - case ExpressionType::LOGICAL_NEGATION: - type_string = "bool"; - break; - case ExpressionType::IF: - case ExpressionType::ELSE: - case ExpressionType::ENDIF: - case ExpressionType::NOP: - type_string = "void"; - break; - default: - type_string = "double"; - } - return type_string; -} - void CodeGenerator::PushIndentation() { for (int i = 0; i < options_.indentation_spaces_per_level; ++i) { indentation_.push_back(' ');
diff --git a/internal/ceres/codegen/code_generator_test.cc b/internal/ceres/codegen/code_generator_test.cc index f974b1a..a73ccac 100644 --- a/internal/ceres/codegen/code_generator_test.cc +++ b/internal/ceres/codegen/code_generator_test.cc
@@ -376,6 +376,32 @@ GenerateAndCheck(graph, expected_code); } +TEST(CodeGenerator, LOGICAL_FUNCTION_CALL) { + StartRecordingExpressions(); + T a = T(1); + + isfinite(a); + isinf(a); + isnan(a); + isnormal(a); + + auto graph = StopRecordingExpressions(); + + std::vector<std::string> expected_code = {"{", + " double v_0;", + " bool v_1;", + " bool v_2;", + " bool v_3;", + " bool v_4;", + " v_0 = 1;", + " v_1 = isfinite(v_0);", + " v_2 = isinf(v_0);", + " v_3 = isnan(v_0);", + " v_4 = isnormal(v_0);", + "}"}; + GenerateAndCheck(graph, expected_code); +} + TEST(CodeGenerator, IF_SIMPLE) { StartRecordingExpressions(); T a = T(0);
diff --git a/internal/ceres/codegen/expression_ref_test.cc b/internal/ceres/codegen/expression_ref_test.cc index 6d6971e..88d0562 100644 --- a/internal/ceres/codegen/expression_ref_test.cc +++ b/internal/ceres/codegen/expression_ref_test.cc
@@ -33,8 +33,8 @@ // #define CERES_CODEGEN -#include "ceres/codegen/internal/expression_graph.h" #include "ceres/codegen/internal/expression_ref.h" +#include "ceres/codegen/internal/expression_graph.h" #include "gtest/gtest.h" namespace ceres { @@ -176,7 +176,7 @@ EXPECT_EQ(reference, graph); } -TEST(CodeGenerator, BINARY_ARITHMETIC_COMPOUND) { +TEST(ExpressionRef, BINARY_ARITHMETIC_COMPOUND) { // For each binary compound arithmetic operation, two lines are generated: // - The actual operation assigning to a new temporary variable // - An assignment from the temporary to the lhs @@ -203,7 +203,7 @@ EXPECT_EQ(reference, graph); } -TEST(CodeGenerator, UNARY_ARITHMETIC) { +TEST(ExpressionRef, UNARY_ARITHMETIC) { StartRecordingExpressions(); T a = T(1); T r1 = -a; @@ -217,7 +217,7 @@ EXPECT_EQ(reference, graph); } -TEST(CodeGenerator, BINARY_COMPARISON) { +TEST(ExpressionRef, BINARY_COMPARISON) { using BOOL = ComparisonExpressionRef; StartRecordingExpressions(); T a = T(1); @@ -242,7 +242,7 @@ EXPECT_EQ(reference, graph); } -TEST(CodeGenerator, LOGICAL_OPERATORS) { +TEST(ExpressionRef, LOGICAL_OPERATORS) { using BOOL = ComparisonExpressionRef; // Tests binary logical operators &&, || and the unary logical operator ! StartRecordingExpressions(); @@ -266,7 +266,7 @@ EXPECT_EQ(reference, graph); } -TEST(CodeGenerator, FUNCTION_CALL) { +TEST(ExpressionRef, SCALAR_FUNCTION_CALL) { StartRecordingExpressions(); T a = T(1); T b = T(2); @@ -318,7 +318,25 @@ EXPECT_EQ(reference, graph); } -TEST(CodeGenerator, IF) { +TEST(ExpressionRef, LOGICAL_FUNCTION_CALL) { + StartRecordingExpressions(); + T a = T(1); + isfinite(a); + isinf(a); + isnan(a); + isnormal(a); + auto graph = StopRecordingExpressions(); + + ExpressionGraph reference; + reference.InsertBack(Expression::CreateCompileTimeConstant(1)); + reference.InsertBack(Expression::CreateLogicalFunctionCall("isfinite", {0})); + reference.InsertBack(Expression::CreateLogicalFunctionCall("isinf", {0})); + reference.InsertBack(Expression::CreateLogicalFunctionCall("isnan", {0})); + reference.InsertBack(Expression::CreateLogicalFunctionCall("isnormal", {0})); + EXPECT_EQ(reference, graph); +} + +TEST(ExpressionRef, IF) { StartRecordingExpressions(); T a = T(1); T b = T(2); @@ -336,7 +354,7 @@ EXPECT_EQ(reference, graph); } -TEST(CodeGenerator, IF_ELSE) { +TEST(ExpressionRef, IF_ELSE) { StartRecordingExpressions(); T a = T(1); T b = T(2); @@ -356,7 +374,7 @@ EXPECT_EQ(reference, graph); } -TEST(CodeGenerator, IF_NESTED) { +TEST(ExpressionRef, IF_NESTED) { StartRecordingExpressions(); T a = T(1); T b = T(2);
diff --git a/internal/ceres/expression.cc b/internal/ceres/expression.cc index 406899d..0fec5a0 100644 --- a/internal/ceres/expression.cc +++ b/internal/ceres/expression.cc
@@ -30,10 +30,25 @@ #include "ceres/codegen/internal/expression.h" #include <algorithm> +#include "glog/logging.h" namespace ceres { namespace internal { +std::string ExpressionReturnTypeToString(ExpressionReturnType type) { + switch (type) { + case ExpressionReturnType::SCALAR: + return "double"; + case ExpressionReturnType::BOOLEAN: + return "bool"; + case ExpressionReturnType::VOID: + return "void"; + default: + CHECK(false) << "Unknown ExpressionReturnType."; + return ""; + } +} + Expression::Expression(ExpressionType type, ExpressionReturnType return_type, ExpressionId lhs_id,