Add the expression return type as a member to Expression

Before this patch the return type was implicitly defined by the
ExpressionType. This patch separates this connection and allows
each Expression to have one of the predefined types (scalar,
boolean, void).

This patch is required to add support for the functions isfinite,
isinf, isnan, and isnormal. These are function taking a double and
returning a bool.

This also moves some complexity of the code generator to the
Expression, because the generator can direclty get the c++ type.

Change-Id: I8b32bab1bfab2f668875e506d6f3b789a5d1f3fd
diff --git a/include/ceres/codegen/internal/expression.h b/include/ceres/codegen/internal/expression.h
index 808d741..f75a0aa 100644
--- a/include/ceres/codegen/internal/expression.h
+++ b/include/ceres/codegen/internal/expression.h
@@ -234,6 +234,21 @@
   NOP
 };
 
+enum class ExpressionReturnType {
+  // The expression returns a scalar value (float or double). Used for most
+  // arithmetic operations and function calls.
+  SCALAR,
+  // The expression returns a boolean value. Used for logical expressions
+  //   v_3 = v_1 < v_2
+  // and functions returning a bool
+  //   v_3 = isfinite(v_1);
+  BOOLEAN,
+  // The expressions doesn't return a value. Used for the control
+  // expressions
+  // and NOP.
+  VOID,
+};
+
 // This class contains all data that is required to generate one line of code.
 // Each line has the following form:
 //
@@ -253,6 +268,7 @@
   Expression() = default;
 
   Expression(ExpressionType type,
+             ExpressionReturnType return_type = ExpressionReturnType::VOID,
              ExpressionId lhs_id = kInvalidExpressionId,
              const std::vector<ExpressionId>& arguments = {},
              const std::string& name = "",
@@ -276,8 +292,10 @@
                                         ExpressionId l,
                                         ExpressionId r);
   static Expression CreateLogicalNegation(ExpressionId v);
-  static Expression CreateFunctionCall(const std::string& name,
-                                       const std::vector<ExpressionId>& params);
+  static Expression CreateScalarFunctionCall(
+      const std::string& name, const std::vector<ExpressionId>& params);
+  static Expression CreateLogicalFunctionCall(
+      const std::string& name, const std::vector<ExpressionId>& params);
   static Expression CreateIf(ExpressionId condition);
   static Expression CreateElse();
   static Expression CreateEndIf();
@@ -332,6 +350,7 @@
   bool IsSemanticallyEquivalentTo(const Expression& other) const;
 
   ExpressionType type() const { return type_; }
+  ExpressionReturnType return_type() const { return return_type_; }
   ExpressionId lhs_id() const { return lhs_id_; }
   double value() const { return value_; }
   const std::string& name() const { return name_; }
@@ -342,6 +361,7 @@
 
  private:
   ExpressionType type_ = ExpressionType::NOP;
+  ExpressionReturnType return_type_ = ExpressionReturnType::VOID;
 
   // If lhs_id_ >= 0, then this expression is assigned to v_<lhs_id>.
   // For example:
diff --git a/include/ceres/codegen/internal/expression_ref.h b/include/ceres/codegen/internal/expression_ref.h
index c888739..6a04edb 100644
--- a/include/ceres/codegen/internal/expression_ref.h
+++ b/include/ceres/codegen/internal/expression_ref.h
@@ -35,6 +35,7 @@
 #include <string>
 #include "ceres/codegen/internal/expression.h"
 #include "ceres/codegen/internal/types.h"
+
 namespace ceres {
 namespace internal {
 
@@ -130,15 +131,15 @@
 ExpressionRef operator/(const ExpressionRef& x, const ExpressionRef& y);
 
 // Functions
-#define CERES_DEFINE_UNARY_FUNCTION_CALL(name)          \
-  inline ExpressionRef name(const ExpressionRef& x) {   \
-    return AddExpressionToGraph(                        \
-        Expression::CreateFunctionCall(#name, {x.id})); \
+#define CERES_DEFINE_UNARY_FUNCTION_CALL(name)                \
+  inline ExpressionRef name(const ExpressionRef& x) {         \
+    return AddExpressionToGraph(                              \
+        Expression::CreateScalarFunctionCall(#name, {x.id})); \
   }
 #define CERES_DEFINE_BINARY_FUNCTION_CALL(name)                               \
   inline ExpressionRef name(const ExpressionRef& x, const ExpressionRef& y) { \
     return AddExpressionToGraph(                                              \
-        Expression::CreateFunctionCall(#name, {x.id, y.id}));                 \
+        Expression::CreateScalarFunctionCall(#name, {x.id, y.id}));           \
   }
 CERES_DEFINE_UNARY_FUNCTION_CALL(abs);
 CERES_DEFINE_UNARY_FUNCTION_CALL(acos);
@@ -209,6 +210,19 @@
                                    const ComparisonExpressionRef& y);
 ComparisonExpressionRef operator!(const ComparisonExpressionRef& x);
 
+#define CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL(name)          \
+  inline ComparisonExpressionRef name(const ExpressionRef& x) { \
+    return ComparisonExpressionRef(AddExpressionToGraph(        \
+        Expression::CreateLogicalFunctionCall(#name, {x.id}))); \
+  }
+
+CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL(isfinite);
+CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL(isinf);
+CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL(isnan);
+CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL(isnormal);
+
+#undef CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL
+
 template <>
 struct InputAssignment<ExpressionRef> {
   using ReturnType = ExpressionRef;
diff --git a/internal/ceres/codegen/expression_ref_test.cc b/internal/ceres/codegen/expression_ref_test.cc
index aeb9d2b..6d6971e 100644
--- a/internal/ceres/codegen/expression_ref_test.cc
+++ b/internal/ceres/codegen/expression_ref_test.cc
@@ -295,26 +295,26 @@
   ExpressionGraph reference;
   reference.InsertBack(Expression::CreateCompileTimeConstant(1));
   reference.InsertBack(Expression::CreateCompileTimeConstant(2));
-  reference.InsertBack(Expression::CreateFunctionCall("abs", {0}));
-  reference.InsertBack(Expression::CreateFunctionCall("acos", {0}));
-  reference.InsertBack(Expression::CreateFunctionCall("asin", {0}));
-  reference.InsertBack(Expression::CreateFunctionCall("atan", {0}));
-  reference.InsertBack(Expression::CreateFunctionCall("cbrt", {0}));
-  reference.InsertBack(Expression::CreateFunctionCall("ceil", {0}));
-  reference.InsertBack(Expression::CreateFunctionCall("cos", {0}));
-  reference.InsertBack(Expression::CreateFunctionCall("cosh", {0}));
-  reference.InsertBack(Expression::CreateFunctionCall("exp", {0}));
-  reference.InsertBack(Expression::CreateFunctionCall("exp2", {0}));
-  reference.InsertBack(Expression::CreateFunctionCall("floor", {0}));
-  reference.InsertBack(Expression::CreateFunctionCall("log", {0}));
-  reference.InsertBack(Expression::CreateFunctionCall("log2", {0}));
-  reference.InsertBack(Expression::CreateFunctionCall("sin", {0}));
-  reference.InsertBack(Expression::CreateFunctionCall("sinh", {0}));
-  reference.InsertBack(Expression::CreateFunctionCall("sqrt", {0}));
-  reference.InsertBack(Expression::CreateFunctionCall("tan", {0}));
-  reference.InsertBack(Expression::CreateFunctionCall("tanh", {0}));
-  reference.InsertBack(Expression::CreateFunctionCall("atan2", {0, 1}));
-  reference.InsertBack(Expression::CreateFunctionCall("pow", {0, 1}));
+  reference.InsertBack(Expression::CreateScalarFunctionCall("abs", {0}));
+  reference.InsertBack(Expression::CreateScalarFunctionCall("acos", {0}));
+  reference.InsertBack(Expression::CreateScalarFunctionCall("asin", {0}));
+  reference.InsertBack(Expression::CreateScalarFunctionCall("atan", {0}));
+  reference.InsertBack(Expression::CreateScalarFunctionCall("cbrt", {0}));
+  reference.InsertBack(Expression::CreateScalarFunctionCall("ceil", {0}));
+  reference.InsertBack(Expression::CreateScalarFunctionCall("cos", {0}));
+  reference.InsertBack(Expression::CreateScalarFunctionCall("cosh", {0}));
+  reference.InsertBack(Expression::CreateScalarFunctionCall("exp", {0}));
+  reference.InsertBack(Expression::CreateScalarFunctionCall("exp2", {0}));
+  reference.InsertBack(Expression::CreateScalarFunctionCall("floor", {0}));
+  reference.InsertBack(Expression::CreateScalarFunctionCall("log", {0}));
+  reference.InsertBack(Expression::CreateScalarFunctionCall("log2", {0}));
+  reference.InsertBack(Expression::CreateScalarFunctionCall("sin", {0}));
+  reference.InsertBack(Expression::CreateScalarFunctionCall("sinh", {0}));
+  reference.InsertBack(Expression::CreateScalarFunctionCall("sqrt", {0}));
+  reference.InsertBack(Expression::CreateScalarFunctionCall("tan", {0}));
+  reference.InsertBack(Expression::CreateScalarFunctionCall("tanh", {0}));
+  reference.InsertBack(Expression::CreateScalarFunctionCall("atan2", {0, 1}));
+  reference.InsertBack(Expression::CreateScalarFunctionCall("pow", {0, 1}));
   EXPECT_EQ(reference, graph);
 }
 
diff --git a/internal/ceres/codegen/expression_test.cc b/internal/ceres/codegen/expression_test.cc
index 395bbaf..ae96ea0 100644
--- a/internal/ceres/codegen/expression_test.cc
+++ b/internal/ceres/codegen/expression_test.cc
@@ -39,11 +39,13 @@
 
 TEST(Expression, ConstructorAndAccessors) {
   Expression expr(ExpressionType::LOGICAL_NEGATION,
+                  ExpressionReturnType::BOOLEAN,
                   12345,
                   {1, 5, 8, 10},
                   "TestConstructor",
                   57.25);
   EXPECT_EQ(expr.type(), ExpressionType::LOGICAL_NEGATION);
+  EXPECT_EQ(expr.return_type(), ExpressionReturnType::BOOLEAN);
   EXPECT_EQ(expr.lhs_id(), 12345);
   EXPECT_EQ(expr.arguments(), std::vector<ExpressionId>({1, 5, 8, 10}));
   EXPECT_EQ(expr.name(), "TestConstructor");
@@ -51,54 +53,129 @@
 }
 
 TEST(Expression, CreateFunctions) {
-  // clang-format off
   // The default constructor creates a NOP!
-  EXPECT_EQ(Expression(), Expression(
-            ExpressionType::NOP, kInvalidExpressionId, {}, "", 0));
+  EXPECT_EQ(Expression(),
+            Expression(ExpressionType::NOP,
+                       ExpressionReturnType::VOID,
+                       kInvalidExpressionId,
+                       {},
+                       "",
+                       0));
 
-  EXPECT_EQ(Expression::CreateCompileTimeConstant(72), Expression(
-            ExpressionType::COMPILE_TIME_CONSTANT, kInvalidExpressionId, {}, "", 72));
+  EXPECT_EQ(Expression::CreateCompileTimeConstant(72),
+            Expression(ExpressionType::COMPILE_TIME_CONSTANT,
+                       ExpressionReturnType::SCALAR,
+                       kInvalidExpressionId,
+                       {},
+                       "",
+                       72));
 
-  EXPECT_EQ(Expression::CreateInputAssignment("arguments[0][0]"), Expression(
-            ExpressionType::INPUT_ASSIGNMENT, kInvalidExpressionId, {}, "arguments[0][0]", 0));
+  EXPECT_EQ(Expression::CreateInputAssignment("arguments[0][0]"),
+            Expression(ExpressionType::INPUT_ASSIGNMENT,
+                       ExpressionReturnType::SCALAR,
+                       kInvalidExpressionId,
+                       {},
+                       "arguments[0][0]",
+                       0));
 
-  EXPECT_EQ(Expression::CreateOutputAssignment(ExpressionId(5), "residuals[3]"), Expression(
-            ExpressionType::OUTPUT_ASSIGNMENT, kInvalidExpressionId, {5}, "residuals[3]", 0));
+  EXPECT_EQ(Expression::CreateOutputAssignment(ExpressionId(5), "residuals[3]"),
+            Expression(ExpressionType::OUTPUT_ASSIGNMENT,
+                       ExpressionReturnType::SCALAR,
+                       kInvalidExpressionId,
+                       {5},
+                       "residuals[3]",
+                       0));
 
-  EXPECT_EQ(Expression::CreateAssignment(ExpressionId(3), ExpressionId(5)), Expression(
-            ExpressionType::ASSIGNMENT, 3, {5}, "", 0));
+  EXPECT_EQ(Expression::CreateAssignment(ExpressionId(3), ExpressionId(5)),
+            Expression(ExpressionType::ASSIGNMENT,
+                       ExpressionReturnType::SCALAR,
+                       3,
+                       {5},
+                       "",
+                       0));
 
-  EXPECT_EQ(Expression::CreateBinaryArithmetic("+", ExpressionId(3),ExpressionId(5)), Expression(
-            ExpressionType::BINARY_ARITHMETIC, kInvalidExpressionId, {3,5}, "+", 0));
+  EXPECT_EQ(
+      Expression::CreateBinaryArithmetic("+", ExpressionId(3), ExpressionId(5)),
+      Expression(ExpressionType::BINARY_ARITHMETIC,
+                 ExpressionReturnType::SCALAR,
+                 kInvalidExpressionId,
+                 {3, 5},
+                 "+",
+                 0));
 
-  EXPECT_EQ(Expression::CreateUnaryArithmetic("-", ExpressionId(5)), Expression(
-            ExpressionType::UNARY_ARITHMETIC, kInvalidExpressionId, {5}, "-", 0));
+  EXPECT_EQ(Expression::CreateUnaryArithmetic("-", ExpressionId(5)),
+            Expression(ExpressionType::UNARY_ARITHMETIC,
+                       ExpressionReturnType::SCALAR,
+                       kInvalidExpressionId,
+                       {5},
+                       "-",
+                       0));
 
-  EXPECT_EQ(Expression::CreateBinaryCompare("<",ExpressionId(3),ExpressionId(5)), Expression(
-            ExpressionType::BINARY_COMPARISON, kInvalidExpressionId, {3,5}, "<", 0));
+  EXPECT_EQ(
+      Expression::CreateBinaryCompare("<", ExpressionId(3), ExpressionId(5)),
+      Expression(ExpressionType::BINARY_COMPARISON,
+                 ExpressionReturnType::BOOLEAN,
+                 kInvalidExpressionId,
+                 {3, 5},
+                 "<",
+                 0));
 
-  EXPECT_EQ(Expression::CreateLogicalNegation(ExpressionId(5)), Expression(
-            ExpressionType::LOGICAL_NEGATION, kInvalidExpressionId, {5}, "", 0));
+  EXPECT_EQ(Expression::CreateLogicalNegation(ExpressionId(5)),
+            Expression(ExpressionType::LOGICAL_NEGATION,
+                       ExpressionReturnType::BOOLEAN,
+                       kInvalidExpressionId,
+                       {5},
+                       "",
+                       0));
 
-  EXPECT_EQ(Expression::CreateFunctionCall("pow",{ExpressionId(3),ExpressionId(5)}), Expression(
-            ExpressionType::FUNCTION_CALL, kInvalidExpressionId, {3,5}, "pow", 0));
+  EXPECT_EQ(Expression::CreateScalarFunctionCall(
+                "pow", {ExpressionId(3), ExpressionId(5)}),
+            Expression(ExpressionType::FUNCTION_CALL,
+                       ExpressionReturnType::SCALAR,
+                       kInvalidExpressionId,
+                       {3, 5},
+                       "pow",
+                       0));
 
-  EXPECT_EQ(Expression::CreateIf(ExpressionId(5)), Expression(
-            ExpressionType::IF, kInvalidExpressionId, {5}, "", 0));
+  EXPECT_EQ(
+      Expression::CreateLogicalFunctionCall("isfinite", {ExpressionId(3)}),
+      Expression(ExpressionType::FUNCTION_CALL,
+                 ExpressionReturnType::BOOLEAN,
+                 kInvalidExpressionId,
+                 {3},
+                 "isfinite",
+                 0));
 
-  EXPECT_EQ(Expression::CreateElse(), Expression(
-            ExpressionType::ELSE, kInvalidExpressionId, {}, "", 0));
+  EXPECT_EQ(Expression::CreateIf(ExpressionId(5)),
+            Expression(ExpressionType::IF,
+                       ExpressionReturnType::VOID,
+                       kInvalidExpressionId,
+                       {5},
+                       "",
+                       0));
 
-  EXPECT_EQ(Expression::CreateEndIf(), Expression(
-            ExpressionType::ENDIF, kInvalidExpressionId, {}, "", 0));
-  // clang-format on
+  EXPECT_EQ(Expression::CreateElse(),
+            Expression(ExpressionType::ELSE,
+                       ExpressionReturnType::VOID,
+                       kInvalidExpressionId,
+                       {},
+                       "",
+                       0));
+
+  EXPECT_EQ(Expression::CreateEndIf(),
+            Expression(ExpressionType::ENDIF,
+                       ExpressionReturnType::VOID,
+                       kInvalidExpressionId,
+                       {},
+                       "",
+                       0));
 }
 
 TEST(Expression, IsArithmeticExpression) {
   ASSERT_TRUE(
       Expression::CreateCompileTimeConstant(5).IsArithmeticExpression());
-  ASSERT_TRUE(
-      Expression::CreateFunctionCall("pow", {3, 5}).IsArithmeticExpression());
+  ASSERT_TRUE(Expression::CreateScalarFunctionCall("pow", {3, 5})
+                  .IsArithmeticExpression());
   // Logical expression are also arithmetic!
   ASSERT_TRUE(
       Expression::CreateBinaryCompare("<", 3, 5).IsArithmeticExpression());
@@ -111,8 +188,8 @@
   // In the current implementation this is the exact opposite of
   // IsArithmeticExpression.
   ASSERT_FALSE(Expression::CreateCompileTimeConstant(5).IsControlExpression());
-  ASSERT_FALSE(
-      Expression::CreateFunctionCall("pow", {3, 5}).IsControlExpression());
+  ASSERT_FALSE(Expression::CreateScalarFunctionCall("pow", {3, 5})
+                   .IsControlExpression());
   ASSERT_FALSE(
       Expression::CreateBinaryCompare("<", 3, 5).IsControlExpression());
   ASSERT_TRUE(Expression::CreateIf(5).IsControlExpression());
@@ -180,7 +257,13 @@
   expr1.Replace(expr2);
 
   // expr1 should now be an assignment from 7 to 13
-  EXPECT_EQ(expr1, Expression(ExpressionType::ASSIGNMENT, 13, {7}, "", 0));
+  EXPECT_EQ(expr1,
+            Expression(ExpressionType::ASSIGNMENT,
+                       ExpressionReturnType::SCALAR,
+                       13,
+                       {7},
+                       "",
+                       0));
 }
 
 TEST(Expression, DirectlyDependsOn) {
@@ -199,7 +282,12 @@
   expr1.MakeNop();
 
   EXPECT_EQ(expr1,
-            Expression(ExpressionType::NOP, kInvalidExpressionId, {}, "", 0));
+            Expression(ExpressionType::NOP,
+                       ExpressionReturnType::VOID,
+                       kInvalidExpressionId,
+                       {},
+                       "",
+                       0));
 }
 
 TEST(Expression, IsSemanticallyEquivalentTo) {
diff --git a/internal/ceres/expression.cc b/internal/ceres/expression.cc
index 27e15cb..406899d 100644
--- a/internal/ceres/expression.cc
+++ b/internal/ceres/expression.cc
@@ -35,69 +35,108 @@
 namespace internal {
 
 Expression::Expression(ExpressionType type,
+                       ExpressionReturnType return_type,
                        ExpressionId lhs_id,
                        const std::vector<ExpressionId>& arguments,
                        const std::string& name,
                        double value)
     : type_(type),
+      return_type_(return_type),
       lhs_id_(lhs_id),
       arguments_(arguments),
       name_(name),
       value_(value) {}
 
 Expression Expression::CreateCompileTimeConstant(double v) {
-  return Expression(
-      ExpressionType::COMPILE_TIME_CONSTANT, kInvalidExpressionId, {}, "", v);
+  return Expression(ExpressionType::COMPILE_TIME_CONSTANT,
+                    ExpressionReturnType::SCALAR,
+                    kInvalidExpressionId,
+                    {},
+                    "",
+                    v);
 }
 
 Expression Expression::CreateInputAssignment(const std::string& name) {
-  return Expression(
-      ExpressionType::INPUT_ASSIGNMENT, kInvalidExpressionId, {}, name);
+  return Expression(ExpressionType::INPUT_ASSIGNMENT,
+                    ExpressionReturnType::SCALAR,
+                    kInvalidExpressionId,
+                    {},
+                    name);
 }
 
 Expression Expression::CreateOutputAssignment(ExpressionId v,
                                               const std::string& name) {
-  return Expression(
-      ExpressionType::OUTPUT_ASSIGNMENT, kInvalidExpressionId, {v}, name);
+  return Expression(ExpressionType::OUTPUT_ASSIGNMENT,
+                    ExpressionReturnType::SCALAR,
+                    kInvalidExpressionId,
+                    {v},
+                    name);
 }
 
 Expression Expression::CreateAssignment(ExpressionId dst, ExpressionId src) {
-  return Expression(ExpressionType::ASSIGNMENT, dst, {src});
+  return Expression(
+      ExpressionType::ASSIGNMENT, ExpressionReturnType::SCALAR, dst, {src});
 }
 
 Expression Expression::CreateBinaryArithmetic(const std::string& op,
                                               ExpressionId l,
                                               ExpressionId r) {
-  return Expression(
-      ExpressionType::BINARY_ARITHMETIC, kInvalidExpressionId, {l, r}, op);
+  return Expression(ExpressionType::BINARY_ARITHMETIC,
+                    ExpressionReturnType::SCALAR,
+                    kInvalidExpressionId,
+                    {l, r},
+                    op);
 }
 
 Expression Expression::CreateUnaryArithmetic(const std::string& op,
                                              ExpressionId v) {
-  return Expression(
-      ExpressionType::UNARY_ARITHMETIC, kInvalidExpressionId, {v}, op);
+  return Expression(ExpressionType::UNARY_ARITHMETIC,
+                    ExpressionReturnType::SCALAR,
+                    kInvalidExpressionId,
+                    {v},
+                    op);
 }
 
 Expression Expression::CreateBinaryCompare(const std::string& name,
                                            ExpressionId l,
                                            ExpressionId r) {
-  return Expression(
-      ExpressionType::BINARY_COMPARISON, kInvalidExpressionId, {l, r}, name);
+  return Expression(ExpressionType::BINARY_COMPARISON,
+                    ExpressionReturnType::BOOLEAN,
+                    kInvalidExpressionId,
+                    {l, r},
+                    name);
 }
 
 Expression Expression::CreateLogicalNegation(ExpressionId v) {
-  return Expression(
-      ExpressionType::LOGICAL_NEGATION, kInvalidExpressionId, {v});
+  return Expression(ExpressionType::LOGICAL_NEGATION,
+                    ExpressionReturnType::BOOLEAN,
+                    kInvalidExpressionId,
+                    {v});
 }
 
-Expression Expression::CreateFunctionCall(
+Expression Expression::CreateScalarFunctionCall(
     const std::string& name, const std::vector<ExpressionId>& params) {
-  return Expression(
-      ExpressionType::FUNCTION_CALL, kInvalidExpressionId, params, name);
+  return Expression(ExpressionType::FUNCTION_CALL,
+                    ExpressionReturnType::SCALAR,
+                    kInvalidExpressionId,
+                    params,
+                    name);
+}
+
+Expression Expression::CreateLogicalFunctionCall(
+    const std::string& name, const std::vector<ExpressionId>& params) {
+  return Expression(ExpressionType::FUNCTION_CALL,
+                    ExpressionReturnType::BOOLEAN,
+                    kInvalidExpressionId,
+                    params,
+                    name);
 }
 
 Expression Expression::CreateIf(ExpressionId condition) {
-  return Expression(ExpressionType::IF, kInvalidExpressionId, {condition});
+  return Expression(ExpressionType::IF,
+                    ExpressionReturnType::VOID,
+                    kInvalidExpressionId,
+                    {condition});
 }
 
 Expression Expression::CreateElse() { return Expression(ExpressionType::ELSE); }
@@ -147,9 +186,9 @@
 }
 
 bool Expression::operator==(const Expression& other) const {
-  return type() == other.type() && name() == other.name() &&
-         value() == other.value() && lhs_id() == other.lhs_id() &&
-         arguments() == other.arguments();
+  return type() == other.type() && return_type() == other.return_type() &&
+         name() == other.name() && value() == other.value() &&
+         lhs_id() == other.lhs_id() && arguments() == other.arguments();
 }
 
 bool Expression::IsSemanticallyEquivalentTo(const Expression& other) const {
diff --git a/internal/ceres/expression_graph.cc b/internal/ceres/expression_graph.cc
index d9b33e9..7bae8d2 100644
--- a/internal/ceres/expression_graph.cc
+++ b/internal/ceres/expression_graph.cc
@@ -91,7 +91,7 @@
                              const Expression& expression) {
   ExpressionId last_expression_id = Size() - 1;
   // Increase size by adding a dummy expression.
-  expressions_.push_back(Expression(ExpressionType::NOP, kInvalidExpressionId));
+  expressions_.push_back(Expression());
 
   // Move everything after id back and update references
   for (ExpressionId id = last_expression_id; id >= location; --id) {
diff --git a/internal/ceres/expression_ref.cc b/internal/ceres/expression_ref.cc
index 2738b08..7c43595 100644
--- a/internal/ceres/expression_ref.cc
+++ b/internal/ceres/expression_ref.cc
@@ -165,7 +165,7 @@
                       const ExpressionRef& x,
                       const ExpressionRef& y) {
   return AddExpressionToGraph(
-      Expression::CreateFunctionCall("Ternary", {c.id, x.id, y.id}));
+      Expression::CreateScalarFunctionCall("Ternary", {c.id, x.id, y.id}));
 }
 
 #define CERES_DEFINE_EXPRESSION_COMPARISON_OPERATOR(op)         \