Autodiff Codegen Part 2: Conditionals

- Expression types for if/else blocks
- CERES_IF/ELSE macros
- Multiple assignment to the same variable

Change-Id: If529516f243f31823d1ef7b8827bb6f2390e418d
diff --git a/include/ceres/internal/expression.h b/include/ceres/internal/expression.h
index c990d93..5ab2d62 100644
--- a/include/ceres/internal/expression.h
+++ b/include/ceres/internal/expression.h
@@ -28,17 +28,144 @@
 //
 // Author: darius.rueckert@fau.de (Darius Rueckert)
 //
+// During code generation, your cost functor is converted into a list of
+// expressions stored in an expression graph. For each operator (+,-,=,...),
+// function call (sin,cos,...), and special keyword (if,else,...) the
+// appropriate ExpressionType is selected. On a high level all ExpressionTypes
+// are grouped into two different classes: Arithmetic expressions and control
+// expressions.
 //
-// This file contains the basic expression type, which is used during code
-// generation. Only assignment expressions of the following form are supported:
+// Part 1: Arithmetic Expressions
 //
-// result = [constant|binary_expr|functioncall]
+// Arithmetic expression are the most basic and common types. They are all of
+// the following form:
 //
-// Examples:
-// v_78 = v_28 / v_62;
-// v_97 = exp(v_20);
-// v_89 = 3.000000;
+// <lhs> = <rhs>
 //
+// <lhs> is the variable name on the left hand side of the assignment. <rhs> can
+// be different depending on the ExpressionType. It must evaluate to a single
+// scalar value though. Here are a few examples of arithmetic expressions (the
+// ExpressionType is given on the right):
+//
+// v_0 = 3.1415;        // COMPILE_TIME_CONSTANT
+// v_1 = v_0;           // ASSIGNMENT
+// v_2 = v_0 + v_1;     // PLUS
+// v_3 = v_2 / v_0;     // DIVISION
+// v_4 = sin(v_3);      // FUNCTION_CALL
+// v_5 = v_4 < v_3;     // BINARY_COMPARISON
+//
+// As you can see, the right hand side of each expression contains exactly one
+// operator/value/function call. If you write long expressions like
+//
+// T c = a + b - T(3) * a;
+//
+// it will broken up into the individual expressions like so:
+//
+// v_0 = a + b;
+// v_1 = 3;
+// v_2 = v_1 * a;
+// c   = v_0 - v_2;
+//
+// All arithmetic expressions are generated by operator and function
+// overloading. These overloads are defined in expression_ref.h.
+//
+//
+//
+// Part 2: Control Expressions
+//
+// Control expressions include special instructions that handle the control flow
+// of a program. So far, only if/else is supported, but while/for might come in
+// the future.
+//
+// Generating code for conditional jumps (if/else) is more complicated than
+// for arithmetic expressions. Let's look at a small example to see the
+// problems. After that we explain how these problems are solved in Ceres.
+//
+// 1    T a = parameters[0][0];
+// 2    T b = 1.0;
+// 3    if (a < b) {
+// 4      b = 3.0;
+// 5    } else {
+// 6      b = 4.0;
+// 7    }
+// 8    b += 1.0;
+// 9    residuals[0] = b;
+//
+// Problem 1.
+// We need to generate code for both branches. In C++ there is no way to execute
+// both branches of an if, but we need to execute them to generate the code.
+//
+// Problem 2.
+// The comparison a < b in line 3 is not convertible to bool. Since the value of
+// a is not known during code generation, the expression a < b can not be
+// evaluated. In fact, a < b will return an expression of type
+// BINARY_COMPARISON.
+//
+// Problem 3.
+// There is no way to record that an if was executed. "if" is a special operator
+// which cannot be overloaded. Therefore we can't generate code that contains
+// "if.
+//
+// Problem 4.
+// We have no information about "blocks" or "scopes" during code generation.
+// Even if we could overload the if-operator, there is now way to capture which
+// expression was executed in which branches of the if. For example, we generate
+// code for the else branch. How can we know that the else branch is finished?
+// Is line 8 inside the else-block or already outside?
+//
+// Solution.
+// Instead of using the keywords if/else we insert the macros
+// CERES_IF, CERES_ELSE and CERES_ENDIF. These macros just map to a function,
+// which inserts an expression into the graph. Here is how the example from
+// above looks like with the expanded macros:
+//
+// 1    T a = parameters[0][0];
+// 2    T b = 1.0;
+// 3    CreateIf(a < b); {
+// 4      b = 3.0;
+// 5    } CreateElse(); {
+// 6      b = 4.0;
+// 7    } CreateEndif();
+// 8    b += 1.0;
+// 9    residuals[0] = b;
+//
+// Problem 1 solved.
+// There are no branches during code generation, therefore both blocks are
+// evaluated.
+//
+// Problem 2 solved.
+// The function CreateIf(_) does not take a bool as argument, but an
+// ComparisonExpression. Later during code generation an actual "if" is created
+// with the condition as argument.
+//
+// Problem 3 solved.
+// We replaced "if" by a function call so we can record it now.
+//
+// Problem 4 solved.
+// Expressions are added into the graph in the correct order. That means, after
+// seeing a CreateIf() we know that all following expressions until CreateElse()
+// belong to the true-branch. Similar, all expression from CreateElse() to
+// CreateEndif() belong to the false-branch. This also works recursively with
+// nested ifs.
+//
+// If you want to use the AutoDiff code generation for your cost functors, you
+// have to replace all if/else by the CERES_IF, CERES_ELSE and CERES_ENDIF
+// macros. The example from above looks like this:
+//
+// 1    T a = parameters[0][0];
+// 2    T b = 1.0;
+// 3    CERES_IF (a < b) {
+// 4      b = 3.0;
+// 5    } CERES_ELSE {
+// 6      b = 4.0;
+// 7    } CERES_ENDIF;
+// 8    b += 1.0;
+// 9    residuals[0] = b;
+//
+// These macros don't have a negative impact on performance, because they only
+// expand to the CreateIf/.. functions in code generation mode. Otherwise they
+// expand to the if/else keywords. See expression_ref.h for the exact
+// definition.
 //
 #ifndef CERES_PUBLIC_EXPRESSION_H_
 #define CERES_PUBLIC_EXPRESSION_H_
@@ -68,26 +195,25 @@
   // residual[0] = v_51;
   OUTPUT_ASSIGNMENT,
 
-  // Trivial Assignment
-  // v_1 = v_0;
+  // Trivial assignment
+  // v_3 = v_1
   ASSIGNMENT,
 
   // Binary Arithmetic Operations
   // v_2 = v_0 + v_1
-  PLUS,
-  MINUS,
-  MULTIPLICATION,
-  DIVISION,
+  // The operator is stored in Expression::name_.
+  BINARY_ARITHMETIC,
 
   // Unary Arithmetic Operation
   // v_1 = -(v_0);
   // v_2 = +(v_1);
-  UNARY_MINUS,
-  UNARY_PLUS,
+  // The operator is stored in Expression::name_.
+  UNARY_ARITHMETIC,
 
   // Binary Comparison. (<,>,&&,...)
   // This is the only expressions which returns a 'bool'.
-  // const bool v_2 = v_0 < v_1
+  // v_2 = v_0 < v_1
+  // The operator is stored in Expression::name_.
   BINARY_COMPARISON,
 
   // The !-operator on logical expression.
@@ -102,6 +228,12 @@
   // v_3 = ternary(v_0,v_1,v_2);
   TERNARY,
 
+  // Conditional control expressions if/else/endif.
+  // These are special expressions, because they don't define a new variable.
+  IF,
+  ELSE,
+  ENDIF,
+
   // No Operation. A placeholder for an 'empty' expressions which will be
   // optimized out during code generation.
   NOP
@@ -129,11 +261,11 @@
   static ExpressionId CreateParameter(const std::string& name);
   static ExpressionId CreateOutputAssignment(ExpressionId v,
                                              const std::string& name);
-  static ExpressionId CreateAssignment(ExpressionId v);
-  static ExpressionId CreateBinaryArithmetic(ExpressionType type,
+  static ExpressionId CreateAssignment(ExpressionId dst, ExpressionId src);
+  static ExpressionId CreateBinaryArithmetic(const std::string& op,
                                              ExpressionId l,
                                              ExpressionId r);
-  static ExpressionId CreateUnaryArithmetic(ExpressionType type,
+  static ExpressionId CreateUnaryArithmetic(const std::string& op,
                                             ExpressionId v);
   static ExpressionId CreateBinaryCompare(const std::string& name,
                                           ExpressionId l,
@@ -145,9 +277,19 @@
                                     ExpressionId if_true,
                                     ExpressionId if_false);
 
-  // Returns true if the expression type is one of the basic math-operators:
-  // +,-,*,/
-  bool IsArithmetic() const;
+  // Conditional control expressions are inserted into the graph but can't be
+  // referenced by other expressions. Therefore they don't return an
+  // ExpressionId.
+  static void CreateIf(ExpressionId condition);
+  static void CreateElse();
+  static void CreateEndIf();
+
+  // Returns true if this is an arithmetic expression.
+  // Arithmetic expressions must have a valid left hand side.
+  bool IsArithmeticExpression() const;
+
+  // Returns true if this is a control expression.
+  bool IsControlExpression() const;
 
   // If this expression is the compile time constant with the given value.
   // Used during optimization to collapse zero/one arithmetic operations.
@@ -170,16 +312,36 @@
   // Converts this expression into a NOP
   void MakeNop();
 
+  // Returns true if this expression has a valid lhs.
+  bool HasValidLhs() const { return lhs_id_ != kInvalidExpressionId; }
+
+  ExpressionType type() const { return type_; }
+  ExpressionId lhs_id() const { return lhs_id_; }
+  double value() const { return value_; }
+  const std::string& name() const { return name_; }
+  const std::vector<ExpressionId>& arguments() const { return arguments_; }
+
  private:
   // Only ExpressionGraph is allowed to call the constructor, because it manages
   // the memory and ids.
   friend class ExpressionGraph;
 
   // Private constructor. Use the "CreateXX" functions instead.
-  Expression(ExpressionType type, ExpressionId id);
+  Expression(ExpressionType type, ExpressionId lhs_id);
 
   ExpressionType type_ = ExpressionType::NOP;
-  const ExpressionId id_ = kInvalidExpressionId;
+
+  // If lhs_id_ >= 0, then this expression is assigned to v_<lhs_id>.
+  // For example:
+  //    v_1 = v_0 + v_0     (Type = PLUS)
+  //    v_3 = sin(v_1)      (Type = FUNCTION_CALL)
+  //      ^
+  //   lhs_id_
+  //
+  // If lhs_id_ == kInvalidExpressionId, then the expression type is not
+  // arithmetic. Currently, only the following types have lhs_id = invalid:
+  // IF,ELSE,ENDIF,NOP
+  const ExpressionId lhs_id_ = kInvalidExpressionId;
 
   // Expressions have different number of arguments. For example a binary "+"
   // has 2 parameters and a function call to "sin" has 1 parameter. Here, a
diff --git a/include/ceres/internal/expression_graph.h b/include/ceres/internal/expression_graph.h
index 446fddb..308528f 100644
--- a/include/ceres/internal/expression_graph.h
+++ b/include/ceres/internal/expression_graph.h
@@ -48,9 +48,25 @@
 // A is parent of B    <=>  A has B as a parameter    <=> A.DirectlyDependsOn(B)
 class ExpressionGraph {
  public:
-  // Creates an expression and adds it to expressions_.
-  // The returned reference will be invalid after this function is called again.
-  Expression& CreateExpression(ExpressionType type);
+  // Creates an arithmetic expression of the following form:
+  // <lhs> = <rhs>;
+  //
+  // For example:
+  //   CreateArithmeticExpression(PLUS, 5)
+  // will generate:
+  //   v_5 = __ + __;
+  // The place holders are then set by the CreateXX functions of Expression.
+  //
+  // If lhs_id == kInvalidExpressionId, then a new lhs_id will be generated and
+  // assigned to the created expression.
+  // Calling this function with a lhs_id that doesn't exist results in an
+  // error.
+  Expression& CreateArithmeticExpression(ExpressionType type,
+                                         ExpressionId lhs_id);
+
+  // Control expression don't have a left hand side.
+  // Supported types: IF/ELSE/ENDIF/NOP
+  Expression& CreateControlExpression(ExpressionType type);
 
   // Checks if A depends on B.
   // -> B is a descendant of A
diff --git a/include/ceres/internal/expression_ref.h b/include/ceres/internal/expression_ref.h
index 67ff227..f4b920f 100644
--- a/include/ceres/internal/expression_ref.h
+++ b/include/ceres/internal/expression_ref.h
@@ -52,8 +52,26 @@
   // it's automatically converted to the correct expression.
   explicit ExpressionRef(double compile_time_constant);
 
-  // Returns v_id
-  std::string ToString() const;
+  // Create an ASSIGNMENT expression from other to this.
+  //
+  // For example:
+  //   a = b;        // With a.id = 5 and b.id = 3
+  // will generate the following assignment:
+  //   v_5 = v_3;
+  //
+  // If this (lhs) ExpressionRef is currently not pointing to a variable
+  // (id==invalid), then we can eliminate the assignment by just letting "this"
+  // point to the same variable as "other".
+  //
+  // Example:
+  //   a = b;       // With a.id = invalid and b.id = 3
+  // will generate NO expression, but after this line the following will be
+  // true:
+  //    a.id == b.id == 3
+  //
+  // If 'other' is not pointing to a variable (id==invalid), we found an
+  // uninitialized assignment, which is handled as an error.
+  ExpressionRef& operator=(const ExpressionRef& other);
 
   // Compound operators
   ExpressionRef& operator+=(ExpressionRef x);
@@ -61,6 +79,8 @@
   ExpressionRef& operator*=(ExpressionRef x);
   ExpressionRef& operator/=(ExpressionRef x);
 
+  bool IsInitialized() const { return id != kInvalidExpressionId; }
+
   // The index into the ExpressionGraph data array.
   ExpressionId id = kInvalidExpressionId;
 
@@ -152,6 +172,22 @@
 
 #define CERES_EXPRESSION_RUNTIME_CONSTANT(_v) \
   ceres::internal::MakeRuntimeConstant<T>(_v, #_v)
+
+// The CERES_CODEGEN macro is defined by the build system only during code
+// generation. In all other cases the CERES_IF/ELSE macros just expand to the
+// if/else keywords.
+#ifdef CERES_CODEGEN
+#define CERES_IF(condition_) Expression::CreateIf((condition_).id);
+#define CERES_ELSE Expression::CreateElse();
+#define CERES_ENDIF Expression::CreateEndIf();
+#else
+// clang-format off
+#define CERES_IF(condition_) if (condition_) {
+#define CERES_ELSE } else {
+#define CERES_ENDIF }
+// clang-format on
+#endif
+
 }  // namespace internal
 
 // See jet.h for more info on this type.
diff --git a/internal/ceres/CMakeLists.txt b/internal/ceres/CMakeLists.txt
index 63e8540..15770ad 100644
--- a/internal/ceres/CMakeLists.txt
+++ b/internal/ceres/CMakeLists.txt
@@ -446,6 +446,7 @@
   ceres_test(evaluation_callback)
   ceres_test(evaluator)
   ceres_test(expression)
+  ceres_test(conditional_expressions)
   ceres_test(expression_graph)
   ceres_test(fixed_array)
   ceres_test(gradient_checker)
diff --git a/internal/ceres/conditional_expressions_test.cc b/internal/ceres/conditional_expressions_test.cc
new file mode 100644
index 0000000..72fe3ee
--- /dev/null
+++ b/internal/ceres/conditional_expressions_test.cc
@@ -0,0 +1,189 @@
+// Ceres Solver - A fast non-linear least squares minimizer
+// Copyright 2019 Google Inc. All rights reserved.
+// http://code.google.com/p/ceres-solver/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// * Redistributions of source code must retain the above copyright notice,
+//   this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright notice,
+//   this list of conditions and the following disclaimer in the documentation
+//   and/or other materials provided with the distribution.
+// * Neither the name of Google Inc. nor the names of its contributors may be
+//   used to endorse or promote products derived from this software without
+//   specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//
+// Author: darius.rueckert@fau.de (Darius Rueckert)
+//
+
+#define CERES_CODEGEN
+
+#include "expression_test.h"
+
+namespace ceres {
+namespace internal {
+
+TEST(Expression, AssignmentElimination) {
+  using T = ExpressionRef;
+
+  StartRecordingExpressions();
+  T a(2);
+  T b;
+  b = a;
+  auto graph = StopRecordingExpressions();
+
+  // b is invalid during the assignment so we expect no expression to be
+  // generated. The only expression in the graph should be the constant
+  // assignment to a.
+  EXPECT_EQ(graph.Size(), 1);
+
+  // Expected code
+  //   v_0 = 2;
+
+  // clang-format off
+  // Id  Type                   Lhs  Value Name  Arguments
+  TE(0,  COMPILE_TIME_CONSTANT, 0,   2,     "");
+  // clang-format on
+
+  // Variables after execution:
+  //
+  // a      <=> v_0
+  // b      <=> v_0
+  EXPECT_EQ(a.id, 0);
+  EXPECT_EQ(b.id, 0);
+}
+
+TEST(Expression, Assignment) {
+  using T = ExpressionRef;
+
+  StartRecordingExpressions();
+  T a(2);
+  T b(4);
+  b = a;
+  auto graph = StopRecordingExpressions();
+
+  // b is valid during the assignment so we expect an
+  // additional assignment expression.
+  EXPECT_EQ(graph.Size(), 3);
+
+  // Expected code
+  //   v_0 = 2;
+  //   v_1 = 4;
+  //   v_1 = v_0;
+
+  // clang-format off
+  // Id, Type, Lhs, Value, Name, Arguments
+  TE(  0,  COMPILE_TIME_CONSTANT,   0,   2,   "",   );
+  TE(  1,  COMPILE_TIME_CONSTANT,   1,   4,   "",   );
+  TE(  2,             ASSIGNMENT,   1,   0,   "",  0);
+  // clang-format on
+
+  // Variables after execution:
+  //
+  // a      <=> v_0
+  // b      <=> v_1
+  EXPECT_EQ(a.id, 0);
+  EXPECT_EQ(b.id, 1);
+}
+
+TEST(Expression, ConditionalMinimal) {
+  using T = ExpressionRef;
+
+  StartRecordingExpressions();
+  T a(2);
+  T b(3);
+  auto c = a < b;
+  CERES_IF(c) {}
+  CERES_ELSE {}
+  CERES_ENDIF
+  auto graph = StopRecordingExpressions();
+
+  // Expected code
+  //   v_0 = 2;
+  //   v_1 = 3;
+  //   v_2 = v_0 < v_1;
+  //   if(v_2);
+  //   else
+  //   endif
+
+  EXPECT_EQ(graph.Size(), 6);
+
+  // clang-format off
+  // Id, Type, Lhs, Value, Name, Arguments...
+  TE(  0, COMPILE_TIME_CONSTANT,   0,   2,   "",      );
+  TE(  1, COMPILE_TIME_CONSTANT,   1,   3,   "",      );
+  TE(  2,     BINARY_COMPARISON,   2,   0,  "<",  0, 1);
+  TE(  3,                    IF,  -1,   0,   "",     2);
+  TE(  4,                  ELSE,  -1,   0,   "",      );
+  TE(  5,                 ENDIF,  -1,   0,   "",      );
+  // clang-format on
+}
+
+TEST(Expression, ConditionalAssignment) {
+  using T = ExpressionRef;
+
+  StartRecordingExpressions();
+
+  T result;
+  T a(2);
+  T b(3);
+  auto c = a < b;
+  CERES_IF(c) { result = a + b; }
+  CERES_ELSE { result = a - b; }
+  CERES_ENDIF
+  result += a;
+  auto graph = StopRecordingExpressions();
+
+  // Expected code
+  //   v_0 = 2;
+  //   v_1 = 3;
+  //   v_2 = v_0 < v_1;
+  //   if(v_2);
+  //     v_4 = v_0 + v_1;
+  //   else
+  //     v_6 = v_0 - v_1;
+  //     v_4 = v_6
+  //   endif
+  //   v_9 = v_4 + v_0;
+  //   v_4 = v_9;
+
+  // clang-format off
+  // Id,   Type,                  Lhs, Value, Name, Arguments...
+  TE(  0,  COMPILE_TIME_CONSTANT,    0,    2,   "",      );
+  TE(  1,  COMPILE_TIME_CONSTANT,    1,    3,   "",      );
+  TE(  2,      BINARY_COMPARISON,    2,    0,  "<",  0, 1);
+  TE(  3,                     IF,   -1,    0,   "",     2);
+  TE(  4,      BINARY_ARITHMETIC,    4,    0,  "+",  0, 1);
+  TE(  5,                   ELSE,   -1,    0,   "",      );
+  TE(  6,      BINARY_ARITHMETIC,    6,    0,  "-",  0, 1);
+  TE(  7,             ASSIGNMENT,    4,    0,   "",  6   );
+  TE(  8,                  ENDIF,   -1,    0,   "",      );
+  TE(  9,      BINARY_ARITHMETIC,    9,    0,  "+",  4, 0);
+  TE( 10,             ASSIGNMENT,    4,    0,   "",  9   );
+  // clang-format on
+
+  // Variables after execution:
+  //
+  // a      <=> v_0
+  // b      <=> v_1
+  // result <=> v_4
+  EXPECT_EQ(a.id, 0);
+  EXPECT_EQ(b.id, 1);
+  EXPECT_EQ(result.id, 4);
+}
+
+}  // namespace internal
+}  // namespace ceres
diff --git a/internal/ceres/expression.cc b/internal/ceres/expression.cc
index 3edcc7d..4c8dd5c 100644
--- a/internal/ceres/expression.cc
+++ b/internal/ceres/expression.cc
@@ -37,112 +37,127 @@
 namespace ceres {
 namespace internal {
 
-static Expression& MakeExpression(ExpressionType type) {
+// Wrapper for ExpressionGraph::CreateArithmeticExpression, which checks if a
+// graph is currently active. See that function for an explanation.
+static Expression& MakeArithmeticExpression(
+    ExpressionType type, ExpressionId lhs_id = kInvalidExpressionId) {
   auto pool = GetCurrentExpressionGraph();
   CHECK(pool)
       << "The ExpressionGraph has to be created before using Expressions. This "
          "is achieved by calling ceres::StartRecordingExpressions.";
-  return pool->CreateExpression(type);
+  return pool->CreateArithmeticExpression(type, lhs_id);
+}
+
+// Wrapper for ExpressionGraph::CreateControlExpression.
+static Expression& MakeControlExpression(ExpressionType type) {
+  auto pool = GetCurrentExpressionGraph();
+  CHECK(pool)
+      << "The ExpressionGraph has to be created before using Expressions. This "
+         "is achieved by calling ceres::StartRecordingExpressions.";
+  return pool->CreateControlExpression(type);
 }
 
 ExpressionId Expression::CreateCompileTimeConstant(double v) {
-  auto& expr = MakeExpression(ExpressionType::COMPILE_TIME_CONSTANT);
+  auto& expr = MakeArithmeticExpression(ExpressionType::COMPILE_TIME_CONSTANT);
   expr.value_ = v;
-  return expr.id_;
+  return expr.lhs_id_;
 }
 
 ExpressionId Expression::CreateRuntimeConstant(const std::string& name) {
-  auto& expr = MakeExpression(ExpressionType::RUNTIME_CONSTANT);
+  auto& expr = MakeArithmeticExpression(ExpressionType::RUNTIME_CONSTANT);
   expr.name_ = name;
-  return expr.id_;
+  return expr.lhs_id_;
 }
 
 ExpressionId Expression::CreateParameter(const std::string& name) {
-  auto& expr = MakeExpression(ExpressionType::PARAMETER);
+  auto& expr = MakeArithmeticExpression(ExpressionType::PARAMETER);
   expr.name_ = name;
-  return expr.id_;
+  return expr.lhs_id_;
 }
 
-ExpressionId Expression::CreateAssignment(ExpressionId v) {
-  auto& expr = MakeExpression(ExpressionType::ASSIGNMENT);
-  expr.arguments_.push_back(v);
-  return expr.id_;
+ExpressionId Expression::CreateAssignment(ExpressionId dst, ExpressionId src) {
+  auto& expr = MakeArithmeticExpression(ExpressionType::ASSIGNMENT, dst);
+
+  expr.arguments_.push_back(src);
+  return expr.lhs_id_;
 }
 
-ExpressionId Expression::CreateUnaryArithmetic(ExpressionType type,
+ExpressionId Expression::CreateUnaryArithmetic(const std::string& op,
                                                ExpressionId v) {
-  auto& expr = MakeExpression(type);
+  auto& expr = MakeArithmeticExpression(ExpressionType::UNARY_ARITHMETIC);
+  expr.name_ = op;
   expr.arguments_.push_back(v);
-  return expr.id_;
+  return expr.lhs_id_;
 }
 
 ExpressionId Expression::CreateOutputAssignment(ExpressionId v,
                                                 const std::string& name) {
-  auto& expr = MakeExpression(ExpressionType::OUTPUT_ASSIGNMENT);
+  auto& expr = MakeArithmeticExpression(ExpressionType::OUTPUT_ASSIGNMENT);
   expr.arguments_.push_back(v);
   expr.name_ = name;
-  return expr.id_;
+  return expr.lhs_id_;
 }
 
 ExpressionId Expression::CreateFunctionCall(
     const std::string& name, const std::vector<ExpressionId>& params) {
-  auto& expr = MakeExpression(ExpressionType::FUNCTION_CALL);
+  auto& expr = MakeArithmeticExpression(ExpressionType::FUNCTION_CALL);
   expr.arguments_ = params;
   expr.name_ = name;
-  return expr.id_;
+  return expr.lhs_id_;
 }
 
 ExpressionId Expression::CreateTernary(ExpressionId condition,
                                        ExpressionId if_true,
                                        ExpressionId if_false) {
-  auto& expr = MakeExpression(ExpressionType::TERNARY);
+  auto& expr = MakeArithmeticExpression(ExpressionType::TERNARY);
   expr.arguments_.push_back(condition);
   expr.arguments_.push_back(if_true);
   expr.arguments_.push_back(if_false);
-  return expr.id_;
+  return expr.lhs_id_;
 }
 
 ExpressionId Expression::CreateBinaryCompare(const std::string& name,
                                              ExpressionId l,
                                              ExpressionId r) {
-  auto& expr = MakeExpression(ExpressionType::BINARY_COMPARISON);
+  auto& expr = MakeArithmeticExpression(ExpressionType::BINARY_COMPARISON);
   expr.arguments_.push_back(l);
   expr.arguments_.push_back(r);
   expr.name_ = name;
-  return expr.id_;
+  return expr.lhs_id_;
 }
 
 ExpressionId Expression::CreateLogicalNegation(ExpressionId v) {
-  auto& expr = MakeExpression(ExpressionType::LOGICAL_NEGATION);
+  auto& expr = MakeArithmeticExpression(ExpressionType::LOGICAL_NEGATION);
   expr.arguments_.push_back(v);
-  return expr.id_;
+  return expr.lhs_id_;
 }
 
-ExpressionId Expression::CreateBinaryArithmetic(ExpressionType type,
+ExpressionId Expression::CreateBinaryArithmetic(const std::string& op,
                                                 ExpressionId l,
                                                 ExpressionId r) {
-  auto& expr = MakeExpression(type);
+  auto& expr = MakeArithmeticExpression(ExpressionType::BINARY_ARITHMETIC);
+  expr.name_ = op;
   expr.arguments_.push_back(l);
   expr.arguments_.push_back(r);
-  return expr.id_;
+  return expr.lhs_id_;
 }
-Expression::Expression(ExpressionType type, ExpressionId id)
-    : type_(type), id_(id) {}
 
-bool Expression::IsArithmetic() const {
-  switch (type_) {
-    case ExpressionType::PLUS:
-    case ExpressionType::MULTIPLICATION:
-    case ExpressionType::DIVISION:
-    case ExpressionType::MINUS:
-    case ExpressionType::UNARY_MINUS:
-    case ExpressionType::UNARY_PLUS:
-      return true;
-    default:
-      return false;
-  }
+void Expression::CreateIf(ExpressionId condition) {
+  auto& expr = MakeControlExpression(ExpressionType::IF);
+  expr.arguments_.push_back(condition);
 }
 
+void Expression::CreateElse() { MakeControlExpression(ExpressionType::ELSE); }
+
+void Expression::CreateEndIf() { MakeControlExpression(ExpressionType::ENDIF); }
+
+Expression::Expression(ExpressionType type, ExpressionId id)
+    : type_(type), lhs_id_(id) {}
+
+bool Expression::IsArithmeticExpression() const { return HasValidLhs(); }
+
+bool Expression::IsControlExpression() const { return !HasValidLhs(); }
+
 bool Expression::IsReplaceableBy(const Expression& other) const {
   // Check everything except the id.
   return (type_ == other.type_ && name_ == other.name_ &&
@@ -150,7 +165,7 @@
 }
 
 void Expression::Replace(const Expression& other) {
-  if (other.id_ == id_) {
+  if (other.lhs_id_ == lhs_id_) {
     return;
   }
 
diff --git a/internal/ceres/expression_graph.cc b/internal/ceres/expression_graph.cc
index 0757b97..5c2b84e 100644
--- a/internal/ceres/expression_graph.cc
+++ b/internal/ceres/expression_graph.cc
@@ -55,9 +55,23 @@
 
 ExpressionGraph* GetCurrentExpressionGraph() { return expression_pool; }
 
-Expression& ExpressionGraph::CreateExpression(ExpressionType type) {
-  auto id = expressions_.size();
-  Expression expr(type, id);
+Expression& ExpressionGraph::CreateArithmeticExpression(ExpressionType type,
+                                                        ExpressionId lhs_id) {
+  if (lhs_id == kInvalidExpressionId) {
+    // We are creating a new temporary variable.
+    // -> The new lhs_id is the index into the graph
+    lhs_id = static_cast<ExpressionId>(expressions_.size());
+  } else {
+    // The left hand side already exists.
+  }
+
+  Expression expr(type, lhs_id);
+  expressions_.push_back(expr);
+  return expressions_.back();
+}
+
+Expression& ExpressionGraph::CreateControlExpression(ExpressionType type) {
+  Expression expr(type, kInvalidExpressionId);
   expressions_.push_back(expr);
   return expressions_.back();
 }
diff --git a/internal/ceres/expression_ref.cc b/internal/ceres/expression_ref.cc
index 75a7723..224a3bf 100644
--- a/internal/ceres/expression_ref.cc
+++ b/internal/ceres/expression_ref.cc
@@ -29,8 +29,7 @@
 // Author: darius.rueckert@fau.de (Darius Rueckert)
 
 #include "ceres/internal/expression_ref.h"
-#include "assert.h"
-#include "ceres/internal/expression.h"
+#include "glog/logging.h"
 
 namespace ceres {
 namespace internal {
@@ -41,11 +40,24 @@
   return ref;
 }
 
-std::string ExpressionRef::ToString() const { return std::to_string(id); }
-
 ExpressionRef::ExpressionRef(double compile_time_constant) {
-  (*this) = ExpressionRef::Create(
-      Expression::CreateCompileTimeConstant(compile_time_constant));
+  id = Expression::CreateCompileTimeConstant(compile_time_constant);
+}
+
+ExpressionRef& ExpressionRef::operator=(const ExpressionRef& other) {
+  // Assigning an uninitialized variable to another variable is an error.
+  CHECK(other.IsInitialized()) << "Uninitialized Assignment.";
+
+  if (IsInitialized()) {
+    // Create assignment from other -> this
+    Expression::CreateAssignment(id, other.id);
+  } else {
+    // Special case: "this" expressionref is invalid
+    //    -> Skip assignment
+    //    -> Let "this" point to the same variable as other
+    id = other.id;
+  }
+  return *this;
 }
 
 // Compound operators
@@ -71,33 +83,31 @@
 
 // Arith. Operators
 ExpressionRef operator-(ExpressionRef x) {
-  return ExpressionRef::Create(
-      Expression::CreateUnaryArithmetic(ExpressionType::UNARY_MINUS, x.id));
+  return ExpressionRef::Create(Expression::CreateUnaryArithmetic("-", x.id));
 }
 
 ExpressionRef operator+(ExpressionRef x) {
-  return ExpressionRef::Create(
-      Expression::CreateUnaryArithmetic(ExpressionType::UNARY_PLUS, x.id));
+  return ExpressionRef::Create(Expression::CreateUnaryArithmetic("+", x.id));
 }
 
 ExpressionRef operator+(ExpressionRef x, ExpressionRef y) {
   return ExpressionRef::Create(
-      Expression::CreateBinaryArithmetic(ExpressionType::PLUS, x.id, y.id));
+      Expression::CreateBinaryArithmetic("+", x.id, y.id));
 }
 
 ExpressionRef operator-(ExpressionRef x, ExpressionRef y) {
   return ExpressionRef::Create(
-      Expression::CreateBinaryArithmetic(ExpressionType::MINUS, x.id, y.id));
+      Expression::CreateBinaryArithmetic("-", x.id, y.id));
 }
 
 ExpressionRef operator/(ExpressionRef x, ExpressionRef y) {
   return ExpressionRef::Create(
-      Expression::CreateBinaryArithmetic(ExpressionType::DIVISION, x.id, y.id));
+      Expression::CreateBinaryArithmetic("/", x.id, y.id));
 }
 
 ExpressionRef operator*(ExpressionRef x, ExpressionRef y) {
-  return ExpressionRef::Create(Expression::CreateBinaryArithmetic(
-      ExpressionType::MULTIPLICATION, x.id, y.id));
+  return ExpressionRef::Create(
+      Expression::CreateBinaryArithmetic("*", x.id, y.id));
 }
 
 // Functions
diff --git a/internal/ceres/expression_test.cc b/internal/ceres/expression_test.cc
index 3e683c1..704a464 100644
--- a/internal/ceres/expression_test.cc
+++ b/internal/ceres/expression_test.cc
@@ -29,6 +29,8 @@
 // Author: darius.rueckert@fau.de (Darius Rueckert)
 //
 
+#define CERES_CODEGEN
+
 #include "ceres/internal/expression_graph.h"
 #include "ceres/internal/expression_ref.h"
 
@@ -48,10 +50,10 @@
 
   auto graph = StopRecordingExpressions();
 
-  ASSERT_FALSE(graph.ExpressionForId(a.id).IsArithmetic());
-  ASSERT_FALSE(graph.ExpressionForId(b.id).IsArithmetic());
-  ASSERT_TRUE(graph.ExpressionForId(c.id).IsArithmetic());
-  ASSERT_TRUE(graph.ExpressionForId(d.id).IsArithmetic());
+  ASSERT_FALSE(graph.ExpressionForId(a.id).IsArithmeticExpression());
+  ASSERT_FALSE(graph.ExpressionForId(b.id).IsArithmeticExpression());
+  ASSERT_TRUE(graph.ExpressionForId(c.id).IsArithmeticExpression());
+  ASSERT_TRUE(graph.ExpressionForId(d.id).IsArithmeticExpression());
 }
 
 TEST(Expression, IsCompileTimeConstantAndEqualTo) {
diff --git a/internal/ceres/expression_test.h b/internal/ceres/expression_test.h
new file mode 100644
index 0000000..de9e2c2
--- /dev/null
+++ b/internal/ceres/expression_test.h
@@ -0,0 +1,69 @@
+// Ceres Solver - A fast non-linear least squares minimizer
+// Copyright 2019 Google Inc. All rights reserved.
+// http://code.google.com/p/ceres-solver/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// * Redistributions of source code must retain the above copyright notice,
+//   this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright notice,
+//   this list of conditions and the following disclaimer in the documentation
+//   and/or other materials provided with the distribution.
+// * Neither the name of Google Inc. nor the names of its contributors may be
+//   used to endorse or promote products derived from this software without
+//   specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//
+// Author: darius.rueckert@fau.de (Darius Rueckert)
+//
+
+#ifndef CERES_PUBLIC_EXPRESSION_TEST_H_
+#define CERES_PUBLIC_EXPRESSION_TEST_H_
+
+#include "ceres/internal/expression_graph.h"
+#include "ceres/internal/expression_ref.h"
+
+#include "gtest/gtest.h"
+
+// This file adds a few helper functions to test Expressions and
+// ExpressionGraphs for correctness.
+namespace ceres {
+namespace internal {
+
+inline void TestExpression(const Expression& expr,
+                           ExpressionType type,
+                           ExpressionId lhs_id,
+                           double value,
+                           const std::string& name,
+                           const std::vector<ExpressionId>& arguments) {
+  EXPECT_EQ(static_cast<int>(expr.type()), static_cast<int>(type));
+  EXPECT_EQ(expr.lhs_id(), lhs_id);
+  EXPECT_EQ(expr.value(), value);
+  EXPECT_EQ(expr.name(), name);
+  EXPECT_EQ(expr.arguments(), arguments);
+}
+
+#define TE(_id, _type, _lhs_id, _value, _name, ...) \
+  TestExpression(graph.ExpressionForId(_id),        \
+                 ExpressionType::_type,             \
+                 _lhs_id,                           \
+                 _value,                            \
+                 _name,                             \
+                 {__VA_ARGS__})
+
+}  // namespace internal
+}  // namespace ceres
+
+#endif