Update return type in code generator and add tests for logical functions

Change-Id: I128f8ccc21c2c3a2c79674ac28dd9a5b26b58772
diff --git a/include/ceres/codegen/internal/code_generator.h b/include/ceres/codegen/internal/code_generator.h
index 01a1162..2b98146 100644
--- a/include/ceres/codegen/internal/code_generator.h
+++ b/include/ceres/codegen/internal/code_generator.h
@@ -102,9 +102,6 @@
   // If the expression does not have a valid name an error is generated.
   std::string VariableForExpressionId(ExpressionId id);
 
-  // Returns the type as a string of the left hand side.
-  static std::string DataTypeForExpression(ExpressionType type);
-
   // Adds one level of indentation. Called when an IF expression is encountered.
   void PushIndentation();
 
diff --git a/include/ceres/codegen/internal/expression.h b/include/ceres/codegen/internal/expression.h
index f75a0aa..7de7a77 100644
--- a/include/ceres/codegen/internal/expression.h
+++ b/include/ceres/codegen/internal/expression.h
@@ -249,6 +249,8 @@
   VOID,
 };
 
+std::string ExpressionReturnTypeToString(ExpressionReturnType type);
+
 // This class contains all data that is required to generate one line of code.
 // Each line has the following form:
 //
diff --git a/internal/ceres/code_generator.cc b/internal/ceres/code_generator.cc
index 3743ef6..dd0ab82 100644
--- a/internal/ceres/code_generator.cc
+++ b/internal/ceres/code_generator.cc
@@ -64,7 +64,7 @@
     // Example:    double v_0;
     //
     const std::string declaration_string =
-        indentation_ + DataTypeForExpression(expr.type()) + " " +
+        indentation_ + ExpressionReturnTypeToString(expr.return_type()) + " " +
         VariableForExpressionId(id) + ";";
     code.emplace_back(declaration_string);
   }
@@ -248,25 +248,6 @@
   return options_.variable_prefix + std::to_string(expr.lhs_id());
 }
 
-std::string CodeGenerator::DataTypeForExpression(ExpressionType type) {
-  std::string type_string;
-  switch (type) {
-    case ExpressionType::BINARY_COMPARISON:
-    case ExpressionType::LOGICAL_NEGATION:
-      type_string = "bool";
-      break;
-    case ExpressionType::IF:
-    case ExpressionType::ELSE:
-    case ExpressionType::ENDIF:
-    case ExpressionType::NOP:
-      type_string = "void";
-      break;
-    default:
-      type_string = "double";
-  }
-  return type_string;
-}
-
 void CodeGenerator::PushIndentation() {
   for (int i = 0; i < options_.indentation_spaces_per_level; ++i) {
     indentation_.push_back(' ');
diff --git a/internal/ceres/codegen/code_generator_test.cc b/internal/ceres/codegen/code_generator_test.cc
index f974b1a..a73ccac 100644
--- a/internal/ceres/codegen/code_generator_test.cc
+++ b/internal/ceres/codegen/code_generator_test.cc
@@ -376,6 +376,32 @@
   GenerateAndCheck(graph, expected_code);
 }
 
+TEST(CodeGenerator, LOGICAL_FUNCTION_CALL) {
+  StartRecordingExpressions();
+  T a = T(1);
+
+  isfinite(a);
+  isinf(a);
+  isnan(a);
+  isnormal(a);
+
+  auto graph = StopRecordingExpressions();
+
+  std::vector<std::string> expected_code = {"{",
+                                            "  double v_0;",
+                                            "  bool v_1;",
+                                            "  bool v_2;",
+                                            "  bool v_3;",
+                                            "  bool v_4;",
+                                            "  v_0 = 1;",
+                                            "  v_1 = isfinite(v_0);",
+                                            "  v_2 = isinf(v_0);",
+                                            "  v_3 = isnan(v_0);",
+                                            "  v_4 = isnormal(v_0);",
+                                            "}"};
+  GenerateAndCheck(graph, expected_code);
+}
+
 TEST(CodeGenerator, IF_SIMPLE) {
   StartRecordingExpressions();
   T a = T(0);
diff --git a/internal/ceres/codegen/expression_ref_test.cc b/internal/ceres/codegen/expression_ref_test.cc
index 6d6971e..88d0562 100644
--- a/internal/ceres/codegen/expression_ref_test.cc
+++ b/internal/ceres/codegen/expression_ref_test.cc
@@ -33,8 +33,8 @@
 //
 #define CERES_CODEGEN
 
-#include "ceres/codegen/internal/expression_graph.h"
 #include "ceres/codegen/internal/expression_ref.h"
+#include "ceres/codegen/internal/expression_graph.h"
 #include "gtest/gtest.h"
 
 namespace ceres {
@@ -176,7 +176,7 @@
   EXPECT_EQ(reference, graph);
 }
 
-TEST(CodeGenerator, BINARY_ARITHMETIC_COMPOUND) {
+TEST(ExpressionRef, BINARY_ARITHMETIC_COMPOUND) {
   // For each binary compound arithmetic operation, two lines are generated:
   //    - The actual operation assigning to a new temporary variable
   //    - An assignment from the temporary to the lhs
@@ -203,7 +203,7 @@
   EXPECT_EQ(reference, graph);
 }
 
-TEST(CodeGenerator, UNARY_ARITHMETIC) {
+TEST(ExpressionRef, UNARY_ARITHMETIC) {
   StartRecordingExpressions();
   T a = T(1);
   T r1 = -a;
@@ -217,7 +217,7 @@
   EXPECT_EQ(reference, graph);
 }
 
-TEST(CodeGenerator, BINARY_COMPARISON) {
+TEST(ExpressionRef, BINARY_COMPARISON) {
   using BOOL = ComparisonExpressionRef;
   StartRecordingExpressions();
   T a = T(1);
@@ -242,7 +242,7 @@
   EXPECT_EQ(reference, graph);
 }
 
-TEST(CodeGenerator, LOGICAL_OPERATORS) {
+TEST(ExpressionRef, LOGICAL_OPERATORS) {
   using BOOL = ComparisonExpressionRef;
   // Tests binary logical operators &&, || and the unary logical operator !
   StartRecordingExpressions();
@@ -266,7 +266,7 @@
   EXPECT_EQ(reference, graph);
 }
 
-TEST(CodeGenerator, FUNCTION_CALL) {
+TEST(ExpressionRef, SCALAR_FUNCTION_CALL) {
   StartRecordingExpressions();
   T a = T(1);
   T b = T(2);
@@ -318,7 +318,25 @@
   EXPECT_EQ(reference, graph);
 }
 
-TEST(CodeGenerator, IF) {
+TEST(ExpressionRef, LOGICAL_FUNCTION_CALL) {
+  StartRecordingExpressions();
+  T a = T(1);
+  isfinite(a);
+  isinf(a);
+  isnan(a);
+  isnormal(a);
+  auto graph = StopRecordingExpressions();
+
+  ExpressionGraph reference;
+  reference.InsertBack(Expression::CreateCompileTimeConstant(1));
+  reference.InsertBack(Expression::CreateLogicalFunctionCall("isfinite", {0}));
+  reference.InsertBack(Expression::CreateLogicalFunctionCall("isinf", {0}));
+  reference.InsertBack(Expression::CreateLogicalFunctionCall("isnan", {0}));
+  reference.InsertBack(Expression::CreateLogicalFunctionCall("isnormal", {0}));
+  EXPECT_EQ(reference, graph);
+}
+
+TEST(ExpressionRef, IF) {
   StartRecordingExpressions();
   T a = T(1);
   T b = T(2);
@@ -336,7 +354,7 @@
   EXPECT_EQ(reference, graph);
 }
 
-TEST(CodeGenerator, IF_ELSE) {
+TEST(ExpressionRef, IF_ELSE) {
   StartRecordingExpressions();
   T a = T(1);
   T b = T(2);
@@ -356,7 +374,7 @@
   EXPECT_EQ(reference, graph);
 }
 
-TEST(CodeGenerator, IF_NESTED) {
+TEST(ExpressionRef, IF_NESTED) {
   StartRecordingExpressions();
   T a = T(1);
   T b = T(2);
diff --git a/internal/ceres/expression.cc b/internal/ceres/expression.cc
index 406899d..0fec5a0 100644
--- a/internal/ceres/expression.cc
+++ b/internal/ceres/expression.cc
@@ -30,10 +30,25 @@
 
 #include "ceres/codegen/internal/expression.h"
 #include <algorithm>
+#include "glog/logging.h"
 
 namespace ceres {
 namespace internal {
 
+std::string ExpressionReturnTypeToString(ExpressionReturnType type) {
+  switch (type) {
+    case ExpressionReturnType::SCALAR:
+      return "double";
+    case ExpressionReturnType::BOOLEAN:
+      return "bool";
+    case ExpressionReturnType::VOID:
+      return "void";
+    default:
+      CHECK(false) << "Unknown ExpressionReturnType.";
+      return "";
+  }
+}
+
 Expression::Expression(ExpressionType type,
                        ExpressionReturnType return_type,
                        ExpressionId lhs_id,