Fix ExpressionRef

- Add missing copy constructor
- Add make functions for input/output expressions
- Add missing ExpressionRef overload for MakeRuntimeConstant
- Add ExpressionRef overloads for common functions (cos,exp,...)
- Make ExpressionRef constructor from double implicit
- Add simple test for using ExpressionRef in Jets.
- Rename expression test macro

Change-Id: I071dc6717e4034a281662021998a91e80636da27
diff --git a/include/ceres/internal/expression_ref.h b/include/ceres/internal/expression_ref.h
index f4b920f..e0386f0 100644
--- a/include/ceres/internal/expression_ref.h
+++ b/include/ceres/internal/expression_ref.h
@@ -50,7 +50,11 @@
   // Create a compile time constant expression directly from a double value.
   // This is important so that we can write T(3.14) in our code and
   // it's automatically converted to the correct expression.
-  explicit ExpressionRef(double compile_time_constant);
+  //
+  // This constructor is implicit, because the line
+  //    T a(0);
+  // must work for T = Jet<ExpressionRef>.
+  ExpressionRef(double compile_time_constant);
 
   // Create an ASSIGNMENT expression from other to this.
   //
@@ -71,6 +75,7 @@
   //
   // If 'other' is not pointing to a variable (id==invalid), we found an
   // uninitialized assignment, which is handled as an error.
+  ExpressionRef(const ExpressionRef& other);
   ExpressionRef& operator=(const ExpressionRef& other);
 
   // Compound operators
@@ -96,8 +101,45 @@
 ExpressionRef operator/(ExpressionRef x, ExpressionRef y);
 
 // Functions
-// TODO: Add all function supported by Jet.
-ExpressionRef sin(ExpressionRef x);
+
+// Helper function to create a function call expression.
+// Users can generate code for their own custom functions by adding an overload
+// for ExpressionRef that maps to MakeFunctionCall. See below for examples.
+ExpressionRef MakeFunctionCall(const std::string& name,
+                               const std::vector<ExpressionRef>& params);
+
+#define CERES_DEFINE_UNARY_FUNCTION_CALL(name) \
+  inline ExpressionRef name(ExpressionRef x) { \
+    return MakeFunctionCall(#name, {x});       \
+  }
+#define CERES_DEFINE_BINARY_FUNCTION_CALL(name)                 \
+  inline ExpressionRef name(ExpressionRef x, ExpressionRef y) { \
+    return MakeFunctionCall(#name, {x, y});                     \
+  }
+CERES_DEFINE_UNARY_FUNCTION_CALL(abs);
+CERES_DEFINE_UNARY_FUNCTION_CALL(acos);
+CERES_DEFINE_UNARY_FUNCTION_CALL(asin);
+CERES_DEFINE_UNARY_FUNCTION_CALL(atan);
+CERES_DEFINE_UNARY_FUNCTION_CALL(cbrt);
+CERES_DEFINE_UNARY_FUNCTION_CALL(ceil);
+CERES_DEFINE_UNARY_FUNCTION_CALL(cos);
+CERES_DEFINE_UNARY_FUNCTION_CALL(cosh);
+CERES_DEFINE_UNARY_FUNCTION_CALL(exp);
+CERES_DEFINE_UNARY_FUNCTION_CALL(exp2);
+CERES_DEFINE_UNARY_FUNCTION_CALL(floor);
+CERES_DEFINE_UNARY_FUNCTION_CALL(log);
+CERES_DEFINE_UNARY_FUNCTION_CALL(log2);
+CERES_DEFINE_UNARY_FUNCTION_CALL(sin);
+CERES_DEFINE_UNARY_FUNCTION_CALL(sinh);
+CERES_DEFINE_UNARY_FUNCTION_CALL(sqrt);
+CERES_DEFINE_UNARY_FUNCTION_CALL(tan);
+CERES_DEFINE_UNARY_FUNCTION_CALL(tanh);
+
+CERES_DEFINE_BINARY_FUNCTION_CALL(atan2);
+CERES_DEFINE_BINARY_FUNCTION_CALL(pow);
+
+#undef CERES_DEFINE_UNARY_FUNCTION_CALL
+#undef CERES_DEFINE_BINARY_FUNCTION_CALL
 
 // This additonal type is required, so that we can detect invalid conditions
 // during compile time. For example, the following should create a compile time
@@ -142,13 +184,21 @@
 template <typename T>
 struct RuntimeConstant {
   using ReturnType = T;
-  static inline ReturnType Get(double v, const char* name) { return v; }
+  static inline ReturnType Get(double v, const char* /* unused */) { return v; }
+};
+
+template <>
+struct RuntimeConstant<ExpressionRef> {
+  using ReturnType = ExpressionRef;
+  static inline ReturnType Get(double /* unused */, const char* name) {
+    return ExpressionRef::Create(Expression::CreateRuntimeConstant(name));
+  }
 };
 
 template <typename G, int N>
 struct RuntimeConstant<Jet<G, N>> {
   using ReturnType = Jet<G, N>;
-  static inline Jet<G, N> Get(double v, const char* name) {
+  static inline Jet<G, N> Get(double v, const char* /* unused */) {
     return Jet<G, N>(v);
   }
 };
@@ -156,11 +206,11 @@
 template <int N>
 struct RuntimeConstant<Jet<ExpressionRef, N>> {
   using ReturnType = Jet<ExpressionRef, N>;
-  static inline ReturnType Get(double v, const char* name) {
+  static inline ReturnType Get(double /* unused */, const char* name) {
     // Note: The scalar value of v will be thrown away, because we don't need it
     // during code generation.
-    (void)v;
-    return Jet<ExpressionRef, N>(Expression::CreateRuntimeConstant(name));
+    return Jet<ExpressionRef, N>(
+        ExpressionRef::Create(Expression::CreateRuntimeConstant(name)));
   }
 };
 
@@ -173,6 +223,13 @@
 #define CERES_EXPRESSION_RUNTIME_CONSTANT(_v) \
   ceres::internal::MakeRuntimeConstant<T>(_v, #_v)
 
+inline ExpressionRef MakeParameter(const std::string& name) {
+  return ExpressionRef::Create(Expression::CreateParameter(name));
+}
+inline ExpressionRef MakeOutput(ExpressionRef v, const std::string& name) {
+  return ExpressionRef::Create(Expression::CreateOutputAssignment(v.id, name));
+}
+
 // The CERES_CODEGEN macro is defined by the build system only during code
 // generation. In all other cases the CERES_IF/ELSE macros just expand to the
 // if/else keywords.
diff --git a/internal/ceres/conditional_expressions_test.cc b/internal/ceres/conditional_expressions_test.cc
index 72fe3ee..d7b3979 100644
--- a/internal/ceres/conditional_expressions_test.cc
+++ b/internal/ceres/conditional_expressions_test.cc
@@ -55,7 +55,7 @@
 
   // clang-format off
   // Id  Type                   Lhs  Value Name  Arguments
-  TE(0,  COMPILE_TIME_CONSTANT, 0,   2,     "");
+  CHECK_EXPRESSION(0,  COMPILE_TIME_CONSTANT, 0,   2,     "",);
   // clang-format on
 
   // Variables after execution:
@@ -86,9 +86,9 @@
 
   // clang-format off
   // Id, Type, Lhs, Value, Name, Arguments
-  TE(  0,  COMPILE_TIME_CONSTANT,   0,   2,   "",   );
-  TE(  1,  COMPILE_TIME_CONSTANT,   1,   4,   "",   );
-  TE(  2,             ASSIGNMENT,   1,   0,   "",  0);
+  CHECK_EXPRESSION(  0,  COMPILE_TIME_CONSTANT,   0,   2,   "",   );
+  CHECK_EXPRESSION(  1,  COMPILE_TIME_CONSTANT,   1,   4,   "",   );
+  CHECK_EXPRESSION(  2,             ASSIGNMENT,   1,   0,   "",  0);
   // clang-format on
 
   // Variables after execution:
@@ -123,12 +123,12 @@
 
   // clang-format off
   // Id, Type, Lhs, Value, Name, Arguments...
-  TE(  0, COMPILE_TIME_CONSTANT,   0,   2,   "",      );
-  TE(  1, COMPILE_TIME_CONSTANT,   1,   3,   "",      );
-  TE(  2,     BINARY_COMPARISON,   2,   0,  "<",  0, 1);
-  TE(  3,                    IF,  -1,   0,   "",     2);
-  TE(  4,                  ELSE,  -1,   0,   "",      );
-  TE(  5,                 ENDIF,  -1,   0,   "",      );
+  CHECK_EXPRESSION(  0, COMPILE_TIME_CONSTANT,   0,   2,   "",      );
+  CHECK_EXPRESSION(  1, COMPILE_TIME_CONSTANT,   1,   3,   "",      );
+  CHECK_EXPRESSION(  2,     BINARY_COMPARISON,   2,   0,  "<",  0, 1);
+  CHECK_EXPRESSION(  3,                    IF,  -1,   0,   "",     2);
+  CHECK_EXPRESSION(  4,                  ELSE,  -1,   0,   "",      );
+  CHECK_EXPRESSION(  5,                 ENDIF,  -1,   0,   "",      );
   // clang-format on
 }
 
@@ -162,17 +162,17 @@
 
   // clang-format off
   // Id,   Type,                  Lhs, Value, Name, Arguments...
-  TE(  0,  COMPILE_TIME_CONSTANT,    0,    2,   "",      );
-  TE(  1,  COMPILE_TIME_CONSTANT,    1,    3,   "",      );
-  TE(  2,      BINARY_COMPARISON,    2,    0,  "<",  0, 1);
-  TE(  3,                     IF,   -1,    0,   "",     2);
-  TE(  4,      BINARY_ARITHMETIC,    4,    0,  "+",  0, 1);
-  TE(  5,                   ELSE,   -1,    0,   "",      );
-  TE(  6,      BINARY_ARITHMETIC,    6,    0,  "-",  0, 1);
-  TE(  7,             ASSIGNMENT,    4,    0,   "",  6   );
-  TE(  8,                  ENDIF,   -1,    0,   "",      );
-  TE(  9,      BINARY_ARITHMETIC,    9,    0,  "+",  4, 0);
-  TE( 10,             ASSIGNMENT,    4,    0,   "",  9   );
+  CHECK_EXPRESSION(  0,  COMPILE_TIME_CONSTANT,    0,    2,   "",      );
+  CHECK_EXPRESSION(  1,  COMPILE_TIME_CONSTANT,    1,    3,   "",      );
+  CHECK_EXPRESSION(  2,      BINARY_COMPARISON,    2,    0,  "<",  0, 1);
+  CHECK_EXPRESSION(  3,                     IF,   -1,    0,   "",     2);
+  CHECK_EXPRESSION(  4,      BINARY_ARITHMETIC,    4,    0,  "+",  0, 1);
+  CHECK_EXPRESSION(  5,                   ELSE,   -1,    0,   "",      );
+  CHECK_EXPRESSION(  6,      BINARY_ARITHMETIC,    6,    0,  "-",  0, 1);
+  CHECK_EXPRESSION(  7,             ASSIGNMENT,    4,    0,   "",  6   );
+  CHECK_EXPRESSION(  8,                  ENDIF,   -1,    0,   "",      );
+  CHECK_EXPRESSION(  9,      BINARY_ARITHMETIC,    9,    0,  "+",  4, 0);
+  CHECK_EXPRESSION( 10,             ASSIGNMENT,    4,    0,   "",  9   );
   // clang-format on
 
   // Variables after execution:
diff --git a/internal/ceres/expression_ref.cc b/internal/ceres/expression_ref.cc
index 224a3bf..7c2bc5d 100644
--- a/internal/ceres/expression_ref.cc
+++ b/internal/ceres/expression_ref.cc
@@ -44,6 +44,8 @@
   id = Expression::CreateCompileTimeConstant(compile_time_constant);
 }
 
+ExpressionRef::ExpressionRef(const ExpressionRef& other) { *this = other; }
+
 ExpressionRef& ExpressionRef::operator=(const ExpressionRef& other) {
   // Assigning an uninitialized variable to another variable is an error.
   CHECK(other.IsInitialized()) << "Uninitialized Assignment.";
@@ -111,8 +113,13 @@
 }
 
 // Functions
-ExpressionRef sin(ExpressionRef x) {
-  return ExpressionRef::Create(Expression::CreateFunctionCall("sin", {x.id}));
+ExpressionRef MakeFunctionCall(const std::string& name,
+                               const std::vector<ExpressionRef>& params) {
+  std::vector<ExpressionId> ids;
+  for (auto p : params) {
+    ids.push_back(p.id);
+  }
+  return ExpressionRef::Create(Expression::CreateFunctionCall(name, ids));
 }
 
 ExpressionRef Ternary(ComparisonExpressionRef c,
diff --git a/internal/ceres/expression_test.cc b/internal/ceres/expression_test.cc
index d020192..a8a8213 100644
--- a/internal/ceres/expression_test.cc
+++ b/internal/ceres/expression_test.cc
@@ -31,10 +31,8 @@
 
 #define CERES_CODEGEN
 
-#include "ceres/internal/expression_graph.h"
-#include "ceres/internal/expression_ref.h"
-
-#include "gtest/gtest.h"
+#include "ceres/expression_test.h"
+#include "ceres/jet.h"
 
 namespace ceres {
 namespace internal {
@@ -115,6 +113,51 @@
   ASSERT_TRUE(graph.ExpressionForId(d.id).DirectlyDependsOn(c.id));
 }
 
+TEST(Expression, Jet) {
+  using T = Jet<ExpressionRef, 1>;
+
+  StartRecordingExpressions();
+
+  T a(2, 0);
+  T b = a * a;
+
+  auto graph = StopRecordingExpressions();
+
+  // b is valid during the assignment so we expect an
+  // additional assignment expression.
+  EXPECT_EQ(graph.Size(), 8);
+
+  // Expected code
+  //   v_0 = 2;
+  //   v_1 = 0;
+  //   v_2 = 1;
+  //   v_1 = v_2;
+  //   v_3 = v_0 * v_0;
+  //   v_4 = v_0 * v_1;
+  //   v_5 = v_1 * v_0;
+  //   v_6 = v_3 * v_4;
+  //   v_7 = v_5 * v_6;
+
+  // clang-format off
+  // Id, Type, Lhs, Value, Name, Arguments
+  CHECK_EXPRESSION(  0,  COMPILE_TIME_CONSTANT,   0,   2,   "",     );
+  CHECK_EXPRESSION(  1,  COMPILE_TIME_CONSTANT,   1,   0,   "",     );
+  CHECK_EXPRESSION(  2,  COMPILE_TIME_CONSTANT,   2,   1,   "",     );
+  CHECK_EXPRESSION(  3,             ASSIGNMENT,   1,   0,   "", 2   );
+  CHECK_EXPRESSION(  4,      BINARY_ARITHMETIC,   4,   0,  "*", 0, 0);
+  CHECK_EXPRESSION(  5,      BINARY_ARITHMETIC,   5,   0,  "*", 0, 1);
+  CHECK_EXPRESSION(  6,      BINARY_ARITHMETIC,   6,   0,  "*", 1, 0);
+  CHECK_EXPRESSION(  7,      BINARY_ARITHMETIC,   7,   0,  "+", 5, 6);
+  // clang-format on
+
+  // Variables after execution:
+  //
+  // b.a         <=> v_4
+  // b.v[0]      <=> v_7
+  EXPECT_EQ(b.a.id, 4);
+  EXPECT_EQ(b.v[0].id, 7);
+}
+
 // Todo: remaining functions of Expression
 
 }  // namespace internal
diff --git a/internal/ceres/expression_test.h b/internal/ceres/expression_test.h
index de9e2c2..2be7f8e 100644
--- a/internal/ceres/expression_test.h
+++ b/internal/ceres/expression_test.h
@@ -55,12 +55,12 @@
   EXPECT_EQ(expr.arguments(), arguments);
 }
 
-#define TE(_id, _type, _lhs_id, _value, _name, ...) \
-  TestExpression(graph.ExpressionForId(_id),        \
-                 ExpressionType::_type,             \
-                 _lhs_id,                           \
-                 _value,                            \
-                 _name,                             \
+#define CHECK_EXPRESSION(_id, _type, _lhs_id, _value, _name, ...) \
+  TestExpression(graph.ExpressionForId(_id),                      \
+                 ExpressionType::_type,                           \
+                 _lhs_id,                                         \
+                 _value,                                          \
+                 _name,                                           \
                  {__VA_ARGS__})
 
 }  // namespace internal