Autodiff Codegen Part 3: CodeGenerator

Add the class CodeGenerator, which is able to convert
objects of ExpressionGraph into strings.

Change-Id: Iff94fdbd3f2e055a78871b1a9947f1ca5ae3cd17
diff --git a/include/ceres/internal/code_generator.h b/include/ceres/internal/code_generator.h
new file mode 100644
index 0000000..d629907
--- /dev/null
+++ b/include/ceres/internal/code_generator.h
@@ -0,0 +1,123 @@
+// Ceres Solver - A fast non-linear least squares minimizer
+// Copyright 2019 Google Inc. All rights reserved.
+// http://code.google.com/p/ceres-solver/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// * Redistributions of source code must retain the above copyright notice,
+//   this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright notice,
+//   this list of conditions and the following disclaimer in the documentation
+//   and/or other materials provided with the distribution.
+// * Neither the name of Google Inc. nor the names of its contributors may be
+//   used to endorse or promote products derived from this software without
+//   specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//
+// Author: darius.rueckert@fau.de (Darius Rueckert)
+//
+#ifndef CERES_PUBLIC_CODE_GENERATOR_H_
+#define CERES_PUBLIC_CODE_GENERATOR_H_
+
+#include "ceres/internal/expression.h"
+#include "ceres/internal/expression_graph.h"
+
+#include <string>
+#include <vector>
+
+namespace ceres {
+namespace internal {
+
+// This class is used to convert an expression graph into a string. The typical
+// pipeline is:
+//
+// 1. Record ExpressionGraph
+// 2. Optimize ExpressionGraph
+// 3. Generate C++ code (this class here)
+//
+// The CodeGenerator operates in the following way:
+//
+// 1. Print Header
+//    - The header string is defined in the options.
+//    - This is usually the function name including the parameter list.
+//
+// 2. Print Declarations
+//    - Declare all used variables
+//    - Example:
+//        double v_0;
+//        double v_1;
+//        bool v_3;
+//        ...
+//
+// 3. Print Code
+//    - Convert each expression line by line to a string
+//    - Example:
+//        v_2 = v_0 + v_1
+//        if (v_5) {
+//          v_2 = v_0;
+//          ....
+//
+class CodeGenerator {
+ public:
+  struct Options {
+    // Name of the function.
+    // Example:
+    //   bool Evaluate(const double* x, double* res)
+    std::string function_name = "";
+
+    // Number of spaces added for each level of indentation.
+    int indentation_spaces_per_level = 2;
+
+    // The prefix added to each variable name.
+    std::string variable_prefix = "v_";
+  };
+
+  CodeGenerator(const ExpressionGraph& graph, const Options& options);
+
+  // Generate the C++ code in the steps (1)-(3) described above.
+  // The result is a vector of strings, where each element is exactly one line
+  // of code. The order is important and must not be changed.
+  std::vector<std::string> Generate();
+
+ private:
+  // Converts a single expression given by id to a string.
+  // The format depends on the ExpressionType.
+  // See ExpressionType in expression.h for more detailed how the different
+  // lines will look like.
+  std::string ExpressionToString(ExpressionId id);
+
+  // Helper function to get the name of an expression.
+  // If the expression does not have a valid name an error is generated.
+  std::string VariableForExpressionId(ExpressionId id);
+
+  // Returns the type as a string of the left hand side.
+  static std::string DataTypeForExpression(ExpressionType type);
+
+  // Adds one level of indentation. Called when an IF expression is encountered.
+  void PushIndentation();
+
+  // Removes one level of indentation. Currently only used by ENDIF.
+  void PopIndentation();
+
+  const ExpressionGraph& graph_;
+  const Options options_;
+  std::string indentation_ = "";
+  static constexpr int kFloatingPointPrecision = 25;
+};
+
+}  // namespace internal
+}  // namespace ceres
+
+#endif  // CERES_PUBLIC_CODE_GENERATOR_H_
diff --git a/internal/ceres/CMakeLists.txt b/internal/ceres/CMakeLists.txt
index 148b058..a3bd8b5 100644
--- a/internal/ceres/CMakeLists.txt
+++ b/internal/ceres/CMakeLists.txt
@@ -60,6 +60,7 @@
     canonical_views_clustering.cc
     cgnr_solver.cc
     callbacks.cc
+    code_generator.cc
     compressed_col_sparse_matrix_utils.cc
     compressed_row_jacobian_writer.cc
     compressed_row_sparse_matrix.cc
@@ -416,6 +417,7 @@
   ceres_test(block_sparse_matrix)
   ceres_test(c_api)
   ceres_test(canonical_views_clustering)
+  ceres_test(code_generator)
   ceres_test(compressed_col_sparse_matrix_utils)
   ceres_test(compressed_row_sparse_matrix)
   ceres_test(concurrent_queue)
diff --git a/internal/ceres/code_generator.cc b/internal/ceres/code_generator.cc
new file mode 100644
index 0000000..3af4bfb
--- /dev/null
+++ b/internal/ceres/code_generator.cc
@@ -0,0 +1,284 @@
+// Ceres Solver - A fast non-linear least squares minimizer
+// Copyright 2019 Google Inc. All rights reserved.
+// http://code.google.com/p/ceres-solver/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// * Redistributions of source code must retain the above copyright notice,
+//   this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright notice,
+//   this list of conditions and the following disclaimer in the documentation
+//   and/or other materials provided with the distribution.
+// * Neither the name of Google Inc. nor the names of its contributors may be
+//   used to endorse or promote products derived from this software without
+//   specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//
+// Author: darius.rueckert@fau.de (Darius Rueckert)
+
+#include "ceres/internal/code_generator.h"
+#include <sstream>
+#include "assert.h"
+#include "glog/logging.h"
+
+namespace ceres {
+namespace internal {
+
+CodeGenerator::CodeGenerator(const ExpressionGraph& graph,
+                             const Options& options)
+    : graph_(graph), options_(options) {}
+
+std::vector<std::string> CodeGenerator::Generate() {
+  std::vector<std::string> code;
+
+  // 1. Print the header
+  if (!options_.function_name.empty()) {
+    code.emplace_back(options_.function_name);
+  }
+
+  code.emplace_back("{");
+  PushIndentation();
+
+  // 2. Print declarations
+  for (ExpressionId id = 0; id < graph_.Size(); ++id) {
+    // By definition of the lhs_id, an expression defines a new variable only if
+    // the current_id is identical to the lhs_id.
+    const auto& expr = graph_.ExpressionForId(id);
+    if (id != expr.lhs_id()) {
+      continue;
+    }
+    //
+    // Format:     <type> <id>;
+    // Example:    double v_0;
+    //
+    const std::string declaration_string =
+        indentation_ + DataTypeForExpression(expr.type()) + " " +
+        VariableForExpressionId(id) + ";";
+    code.emplace_back(declaration_string);
+  }
+
+  // 3. Print code
+  for (ExpressionId id = 0; id < graph_.Size(); ++id) {
+    code.emplace_back(ExpressionToString(id));
+  }
+
+  PopIndentation();
+  CHECK(indentation_.empty()) << "IF - ENDIF missmatch detected.";
+  code.emplace_back("}");
+
+  return code;
+}
+
+std::string CodeGenerator::ExpressionToString(ExpressionId id) {
+  // An expression is converted into a string, by first adding the required
+  // indentation spaces and then adding a ExpressionType-specific string. The
+  // following list shows the exact output format for each ExpressionType. The
+  // placeholders <value>, <name>,... stand for the respective members value_,
+  // name_, ... of the current expression. ExpressionIds such as lhs_id and
+  // arguments are converted to the corresponding variable name (7 -> "v_7").
+
+  auto& expr = graph_.ExpressionForId(id);
+
+  std::stringstream result;
+  result.precision(kFloatingPointPrecision);
+
+  // Convert the variable names of lhs and arguments to string. This makes the
+  // big switch/case below more readable.
+  std::string lhs;
+  if (expr.HasValidLhs()) {
+    lhs = VariableForExpressionId(expr.lhs_id());
+  }
+  std::vector<std::string> args;
+  for (ExpressionId id : expr.arguments()) {
+    args.push_back(VariableForExpressionId(id));
+  }
+  auto value = expr.value();
+  const auto& name = expr.name();
+
+  switch (expr.type()) {
+    case ExpressionType::COMPILE_TIME_CONSTANT: {
+      //
+      // Format:     <lhs_id> = <value>;
+      // Example:    v_0      = 3.1415;
+      //
+      result << indentation_ << lhs << " = " << value << ";";
+      break;
+    }
+    case ExpressionType::INPUT_ASSIGNMENT: {
+      //
+      // Format:     <lhs_id> = <name>;
+      // Example:    v_0      = _observed_point_x;
+      //
+      result << indentation_ << lhs << " = " << name << ";";
+      break;
+    }
+    case ExpressionType::OUTPUT_ASSIGNMENT: {
+      //
+      // Format:     <name>      = <arguments[0]>;
+      // Example:    residual[0] = v_51;
+      //
+      result << indentation_ << name << " = " << args[0] << ";";
+      break;
+    }
+    case ExpressionType::ASSIGNMENT: {
+      //
+      // Format:     <lhs_id> = <arguments[0]>;
+      // Example:    v_1      = v_0;
+      //
+      result << indentation_ << lhs << " = " << args[0] << ";";
+      break;
+    }
+    case ExpressionType::BINARY_ARITHMETIC: {
+      //
+      // Format:     <lhs_id> = <arguments[0]> <name> <arguments[1]>;
+      // Example:    v_2      = v_0 + v_1;
+      //
+      result << indentation_ << lhs << " = " << args[0] << " " << name << " "
+             << args[1] << ";";
+      break;
+    }
+    case ExpressionType::UNARY_ARITHMETIC: {
+      //
+      // Format:     <lhs_id> = <name><arguments[0]>;
+      // Example:    v_1      = -v_0;
+      //
+      result << indentation_ << lhs << " = " << name << args[0] << ";";
+      break;
+    }
+    case ExpressionType::BINARY_COMPARISON: {
+      //
+      // Format:     <lhs_id> =  <arguments[0]> <name> <arguments[1]>;
+      // Example:    v_2   = v_0 < v_1;
+      //
+      result << indentation_ << lhs << " = " << args[0] << " " << name << " "
+             << args[1] << ";";
+      break;
+    }
+    case ExpressionType::LOGICAL_NEGATION: {
+      //
+      // Format:     <lhs_id> = !<arguments[0]>;
+      // Example:    v_1   = !v_0;
+      //
+      result << indentation_ << lhs << " = !" << args[0] << ";";
+      break;
+    }
+    case ExpressionType::FUNCTION_CALL: {
+      //
+      // Format:     <lhs_id> = <name>(<arguments[0]>, <arguments[1]>, ...);
+      // Example:    v_1   = sin(v_0);
+      //
+      result << indentation_ << lhs << " = " << name << "(";
+      result << (args.size() ? args[0] : "");
+      for (int i = 1; i < args.size(); ++i) {
+        result << ", " << args[i];
+      }
+      result << ");";
+      break;
+    }
+    case ExpressionType::IF: {
+      //
+      // Format:     if (<arguments[0]>) {
+      // Example:    if (v_0) {
+      // Special:    Adds 1 level of indentation for all following
+      //             expressions.
+      //
+      result << indentation_ << "if (" << args[0] << ") {";
+      PushIndentation();
+      break;
+    }
+    case ExpressionType::ELSE: {
+      //
+      // Format:     } else {
+      // Example:    } else {
+      // Special:    This expression is printed with one less level of
+      //             indentation.
+      //
+      PopIndentation();
+      result << indentation_ << "} else {";
+      PushIndentation();
+      break;
+    }
+    case ExpressionType::ENDIF: {
+      //
+      // Format:     }
+      // Example:    }
+      // Special:    Removes 1 level of indentation for this and all
+      //             following expressions.
+      //
+      PopIndentation();
+      result << indentation_ << "}";
+      break;
+    }
+    case ExpressionType::NOP: {
+      //
+      // Format:     // <NOP>
+      // Example:    // <NOP>
+      //
+      result << indentation_ << "// <NOP>";
+      break;
+    }
+    default:
+      CHECK(false) << "CodeGenerator::ToString for ExpressionType "
+                   << static_cast<int>(expr.type()) << " not implemented!";
+  }
+  return result.str();
+}
+
+std::string CodeGenerator::VariableForExpressionId(ExpressionId id) {
+  //
+  // Format:     <variable_prefix><id>
+  // Example:    v_42
+  //
+  auto& expr = graph_.ExpressionForId(id);
+  CHECK(expr.lhs_id() == id)
+      << "ExpressionId " << id
+      << " does not have a name (it has not been declared).";
+  return options_.variable_prefix + std::to_string(expr.lhs_id());
+}
+
+std::string CodeGenerator::DataTypeForExpression(ExpressionType type) {
+  std::string type_string;
+  switch (type) {
+    case ExpressionType::BINARY_COMPARISON:
+    case ExpressionType::LOGICAL_NEGATION:
+      type_string = "bool";
+      break;
+    case ExpressionType::IF:
+    case ExpressionType::ELSE:
+    case ExpressionType::ENDIF:
+    case ExpressionType::NOP:
+      type_string = "void";
+      break;
+    default:
+      type_string = "double";
+  }
+  return type_string;
+}
+
+void CodeGenerator::PushIndentation() {
+  for (int i = 0; i < options_.indentation_spaces_per_level; ++i) {
+    indentation_.push_back(' ');
+  }
+}
+
+void CodeGenerator::PopIndentation() {
+  for (int i = 0; i < options_.indentation_spaces_per_level; ++i) {
+    CHECK(!indentation_.empty()) << "IF - ENDIF missmatch detected.";
+    indentation_.pop_back();
+  }
+}
+
+}  // namespace internal
+}  // namespace ceres
diff --git a/internal/ceres/code_generator_test.cc b/internal/ceres/code_generator_test.cc
new file mode 100644
index 0000000..c0feb41
--- /dev/null
+++ b/internal/ceres/code_generator_test.cc
@@ -0,0 +1,465 @@
+// Ceres Solver - A fast non-linear least squares minimizer
+// Copyright 2019 Google Inc. All rights reserved.
+// http://code.google.com/p/ceres-solver/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// * Redistributions of source code must retain the above copyright notice,
+//   this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright notice,
+//   this list of conditions and the following disclaimer in the documentation
+//   and/or other materials provided with the distribution.
+// * Neither the name of Google Inc. nor the names of its contributors may be
+//   used to endorse or promote products derived from this software without
+//   specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//
+// Author: darius.rueckert@fau.de (Darius Rueckert)
+//
+#define CERES_CODEGEN
+
+#include "ceres/internal/code_generator.h"
+#include "ceres/internal/expression_graph.h"
+#include "ceres/internal/expression_ref.h"
+
+#include "gtest/gtest.h"
+
+namespace ceres {
+namespace internal {
+
+static void GenerateAndCheck(const ExpressionGraph& graph,
+                             const std::vector<std::string>& reference) {
+  CodeGenerator::Options generator_options;
+  CodeGenerator gen(graph, generator_options);
+  auto code = gen.Generate();
+  EXPECT_EQ(code.size(), reference.size());
+
+  for (int i = 0; i < code.size(); ++i) {
+    EXPECT_EQ(code[i], reference[i]) << "Invalid Line: " << (i + 1);
+  }
+}
+
+using T = ExpressionRef;
+
+TEST(CodeGenerator, Empty) {
+  StartRecordingExpressions();
+  auto graph = StopRecordingExpressions();
+  std::vector<std::string> expected_code = {"{", "}"};
+  GenerateAndCheck(graph, expected_code);
+}
+
+// Now we add one TEST for each ExpressionType.
+TEST(CodeGenerator, COMPILE_TIME_CONSTANT) {
+  StartRecordingExpressions();
+  T a = T(0);
+  T b = T(123.5);
+  T c = T(1 + 1);
+  T d;  // Uninitialized variables should not generate code!
+  auto graph = StopRecordingExpressions();
+  std::vector<std::string> expected_code = {"{",
+                                            "  double v_0;",
+                                            "  double v_1;",
+                                            "  double v_2;",
+                                            "  v_0 = 0;",
+                                            "  v_1 = 123.5;",
+                                            "  v_2 = 2;",
+                                            "}"};
+  GenerateAndCheck(graph, expected_code);
+}
+
+TEST(CodeGenerator, INPUT_ASSIGNMENT) {
+  double local_variable = 5.0;
+  StartRecordingExpressions();
+  T a = CERES_LOCAL_VARIABLE(local_variable);
+  T b = MakeParameter("parameters[0][0]");
+  T c = a + b;
+  auto graph = StopRecordingExpressions();
+  std::vector<std::string> expected_code = {"{",
+                                            "  double v_0;",
+                                            "  double v_1;",
+                                            "  double v_2;",
+                                            "  v_0 = local_variable;",
+                                            "  v_1 = parameters[0][0];",
+                                            "  v_2 = v_0 + v_1;",
+                                            "}"};
+  GenerateAndCheck(graph, expected_code);
+}
+
+TEST(CodeGenerator, OUTPUT_ASSIGNMENT) {
+  double local_variable = 5.0;
+  StartRecordingExpressions();
+  T a = 1;
+  T b = 0;
+  MakeOutput(a, "residual[0]");
+  auto graph = StopRecordingExpressions();
+  std::vector<std::string> expected_code = {"{",
+                                            "  double v_0;",
+                                            "  double v_1;",
+                                            "  double v_2;",
+                                            "  v_0 = 1;",
+                                            "  v_1 = 0;",
+                                            "  residual[0] = v_0;",
+                                            "}"};
+  GenerateAndCheck(graph, expected_code);
+}
+
+TEST(CodeGenerator, ASSIGNMENT) {
+  StartRecordingExpressions();
+  T a = T(0);
+  T b = T(1);
+  T c = a;  // < This should not generate a line!
+  a = b;
+  a = a + b;  // < Create temporary + assignment
+  auto graph = StopRecordingExpressions();
+  std::vector<std::string> expected_code = {"{",
+                                            "  double v_0;",
+                                            "  double v_1;",
+                                            "  double v_3;",
+                                            "  v_0 = 0;",
+                                            "  v_1 = 1;",
+                                            "  v_0 = v_1;",
+                                            "  v_3 = v_0 + v_1;",
+                                            "  v_0 = v_3;",
+                                            "}"};
+  GenerateAndCheck(graph, expected_code);
+}
+
+TEST(CodeGenerator, BINARY_ARITHMETIC_SIMPLE) {
+  StartRecordingExpressions();
+  T a = T(0);
+  T b = T(1);
+  T r1 = a + b;
+  T r2 = a - b;
+  T r3 = a * b;
+  T r4 = a / b;
+  auto graph = StopRecordingExpressions();
+  std::vector<std::string> expected_code = {"{",
+                                            "  double v_0;",
+                                            "  double v_1;",
+                                            "  double v_2;",
+                                            "  double v_3;",
+                                            "  double v_4;",
+                                            "  double v_5;",
+                                            "  v_0 = 0;",
+                                            "  v_1 = 1;",
+                                            "  v_2 = v_0 + v_1;",
+                                            "  v_3 = v_0 - v_1;",
+                                            "  v_4 = v_0 * v_1;",
+                                            "  v_5 = v_0 / v_1;",
+                                            "}"};
+  GenerateAndCheck(graph, expected_code);
+}
+
+TEST(CodeGenerator, BINARY_ARITHMETIC_COMPOUND) {
+  // For each binary compound arithmetic operation, two lines are generated:
+  //    - The actual operation assigning to a new temporary variable
+  //    - An assignment from the temporary to the lhs
+  StartRecordingExpressions();
+  T a = T(0);
+  T b = T(1);
+  b += a;
+  b -= a;
+  b *= a;
+  b /= a;
+  auto graph = StopRecordingExpressions();
+  std::vector<std::string> expected_code = {"{",
+                                            "  double v_0;",
+                                            "  double v_1;",
+                                            "  double v_2;",
+                                            "  double v_4;",
+                                            "  double v_6;",
+                                            "  double v_8;",
+                                            "  v_0 = 0;",
+                                            "  v_1 = 1;",
+                                            "  v_2 = v_1 + v_0;",
+                                            "  v_1 = v_2;",
+                                            "  v_4 = v_1 - v_0;",
+                                            "  v_1 = v_4;",
+                                            "  v_6 = v_1 * v_0;",
+                                            "  v_1 = v_6;",
+                                            "  v_8 = v_1 / v_0;",
+                                            "  v_1 = v_8;",
+                                            "}"};
+  GenerateAndCheck(graph, expected_code);
+}
+
+TEST(CodeGenerator, UNARY_ARITHMETIC) {
+  StartRecordingExpressions();
+  T a = T(0);
+  T r1 = -a;
+  T r2 = +a;
+  auto graph = StopRecordingExpressions();
+  std::vector<std::string> expected_code = {"{",
+                                            "  double v_0;",
+                                            "  double v_1;",
+                                            "  double v_2;",
+                                            "  v_0 = 0;",
+                                            "  v_1 = -v_0;",
+                                            "  v_2 = +v_0;",
+                                            "}"};
+  GenerateAndCheck(graph, expected_code);
+}
+
+TEST(CodeGenerator, BINARY_COMPARISON) {
+  StartRecordingExpressions();
+  T a = T(0);
+  T b = T(1);
+  auto r1 = a < b;
+  auto r2 = a <= b;
+  auto r3 = a > b;
+  auto r4 = a >= b;
+  auto r5 = a == b;
+  auto r6 = a != b;
+  auto graph = StopRecordingExpressions();
+  std::vector<std::string> expected_code = {"{",
+                                            "  double v_0;",
+                                            "  double v_1;",
+                                            "  bool v_2;",
+                                            "  bool v_3;",
+                                            "  bool v_4;",
+                                            "  bool v_5;",
+                                            "  bool v_6;",
+                                            "  bool v_7;",
+                                            "  v_0 = 0;",
+                                            "  v_1 = 1;",
+                                            "  v_2 = v_0 < v_1;",
+                                            "  v_3 = v_0 <= v_1;",
+                                            "  v_4 = v_0 > v_1;",
+                                            "  v_5 = v_0 >= v_1;",
+                                            "  v_6 = v_0 == v_1;",
+                                            "  v_7 = v_0 != v_1;",
+                                            "}"};
+  GenerateAndCheck(graph, expected_code);
+}
+
+TEST(CodeGenerator, LOGICAL_OPERATORS) {
+  // Tests binary logical operators &&, || and the unary logical operator !
+  StartRecordingExpressions();
+  T a = T(0);
+  T b = T(1);
+  auto r1 = a < b;
+  auto r2 = a <= b;
+
+  auto r3 = r1 && r2;
+  auto r4 = r1 || r2;
+  auto r5 = !r1;
+
+  auto graph = StopRecordingExpressions();
+  std::vector<std::string> expected_code = {"{",
+                                            "  double v_0;",
+                                            "  double v_1;",
+                                            "  bool v_2;",
+                                            "  bool v_3;",
+                                            "  bool v_4;",
+                                            "  bool v_5;",
+                                            "  bool v_6;",
+                                            "  v_0 = 0;",
+                                            "  v_1 = 1;",
+                                            "  v_2 = v_0 < v_1;",
+                                            "  v_3 = v_0 <= v_1;",
+                                            "  v_4 = v_2 && v_3;",
+                                            "  v_5 = v_2 || v_3;",
+                                            "  v_6 = !v_2;",
+                                            "}"};
+  GenerateAndCheck(graph, expected_code);
+}
+
+TEST(CodeGenerator, FUNCTION_CALL) {
+  StartRecordingExpressions();
+  T a = T(0);
+  T b = T(1);
+
+  abs(a);
+  acos(a);
+  asin(a);
+  atan(a);
+  cbrt(a);
+  ceil(a);
+  cos(a);
+  cosh(a);
+  exp(a);
+  exp2(a);
+  floor(a);
+  log(a);
+  log2(a);
+  sin(a);
+  sinh(a);
+  sqrt(a);
+  tan(a);
+  tanh(a);
+  atan2(a, b);
+  pow(a, b);
+
+  auto graph = StopRecordingExpressions();
+
+  std::vector<std::string> expected_code = {"{",
+                                            "  double v_0;",
+                                            "  double v_1;",
+                                            "  double v_2;",
+                                            "  double v_3;",
+                                            "  double v_4;",
+                                            "  double v_5;",
+                                            "  double v_6;",
+                                            "  double v_7;",
+                                            "  double v_8;",
+                                            "  double v_9;",
+                                            "  double v_10;",
+                                            "  double v_11;",
+                                            "  double v_12;",
+                                            "  double v_13;",
+                                            "  double v_14;",
+                                            "  double v_15;",
+                                            "  double v_16;",
+                                            "  double v_17;",
+                                            "  double v_18;",
+                                            "  double v_19;",
+                                            "  double v_20;",
+                                            "  double v_21;",
+                                            "  v_0 = 0;",
+                                            "  v_1 = 1;",
+                                            "  v_2 = abs(v_0);",
+                                            "  v_3 = acos(v_0);",
+                                            "  v_4 = asin(v_0);",
+                                            "  v_5 = atan(v_0);",
+                                            "  v_6 = cbrt(v_0);",
+                                            "  v_7 = ceil(v_0);",
+                                            "  v_8 = cos(v_0);",
+                                            "  v_9 = cosh(v_0);",
+                                            "  v_10 = exp(v_0);",
+                                            "  v_11 = exp2(v_0);",
+                                            "  v_12 = floor(v_0);",
+                                            "  v_13 = log(v_0);",
+                                            "  v_14 = log2(v_0);",
+                                            "  v_15 = sin(v_0);",
+                                            "  v_16 = sinh(v_0);",
+                                            "  v_17 = sqrt(v_0);",
+                                            "  v_18 = tan(v_0);",
+                                            "  v_19 = tanh(v_0);",
+                                            "  v_20 = atan2(v_0, v_1);",
+                                            "  v_21 = pow(v_0, v_1);",
+                                            "}"};
+  GenerateAndCheck(graph, expected_code);
+}
+
+TEST(CodeGenerator, IF_SIMPLE) {
+  StartRecordingExpressions();
+  T a = T(0);
+  T b = T(1);
+  auto r1 = a < b;
+  CERES_IF(r1) {}
+  CERES_ELSE {}
+  CERES_ENDIF;
+
+  auto graph = StopRecordingExpressions();
+  std::vector<std::string> expected_code = {"{",
+                                            "  double v_0;",
+                                            "  double v_1;",
+                                            "  bool v_2;",
+                                            "  v_0 = 0;",
+                                            "  v_1 = 1;",
+                                            "  v_2 = v_0 < v_1;",
+                                            "  if (v_2) {",
+                                            "  } else {",
+                                            "  }",
+                                            "}"};
+  GenerateAndCheck(graph, expected_code);
+}
+
+TEST(CodeGenerator, IF_ASSIGNMENT) {
+  StartRecordingExpressions();
+  T a = T(0);
+  T b = T(1);
+  auto r1 = a < b;
+
+  T result = 0;
+  CERES_IF(r1) { result = 5.0; }
+  CERES_ELSE { result = 6.0; }
+  CERES_ENDIF;
+  MakeOutput(result, "result");
+
+  auto graph = StopRecordingExpressions();
+  std::vector<std::string> expected_code = {"{",
+                                            "  double v_0;",
+                                            "  double v_1;",
+                                            "  bool v_2;",
+                                            "  double v_3;",
+                                            "  double v_5;",
+                                            "  double v_8;",
+                                            "  double v_11;",
+                                            "  v_0 = 0;",
+                                            "  v_1 = 1;",
+                                            "  v_2 = v_0 < v_1;",
+                                            "  v_3 = 0;",
+                                            "  if (v_2) {",
+                                            "    v_5 = 5;",
+                                            "    v_3 = v_5;",
+                                            "  } else {",
+                                            "    v_8 = 6;",
+                                            "    v_3 = v_8;",
+                                            "  }",
+                                            "  result = v_3;",
+                                            "}"};
+  GenerateAndCheck(graph, expected_code);
+}
+
+TEST(CodeGenerator, IF_NESTED_ASSIGNMENT) {
+  StartRecordingExpressions();
+  T a = T(0);
+  T b = T(1);
+
+  T result = 0;
+  CERES_IF(a <= b) {
+    result = 5.0;
+    CERES_IF(a == b) { result = 7.0; }
+    CERES_ENDIF;
+  }
+  CERES_ELSE { result = 6.0; }
+  CERES_ENDIF;
+  MakeOutput(result, "result");
+
+  auto graph = StopRecordingExpressions();
+  std::vector<std::string> expected_code = {"{",
+                                            "  double v_0;",
+                                            "  double v_1;",
+                                            "  double v_2;",
+                                            "  bool v_3;",
+                                            "  double v_5;",
+                                            "  bool v_7;",
+                                            "  double v_9;",
+                                            "  double v_13;",
+                                            "  double v_16;",
+                                            "  v_0 = 0;",
+                                            "  v_1 = 1;",
+                                            "  v_2 = 0;",
+                                            "  v_3 = v_0 <= v_1;",
+                                            "  if (v_3) {",
+                                            "    v_5 = 5;",
+                                            "    v_2 = v_5;",
+                                            "    v_7 = v_0 == v_1;",
+                                            "    if (v_7) {",
+                                            "      v_9 = 7;",
+                                            "      v_2 = v_9;",
+                                            "    }",
+                                            "  } else {",
+                                            "    v_13 = 6;",
+                                            "    v_2 = v_13;",
+                                            "  }",
+                                            "  result = v_2;",
+                                            "}"};
+  GenerateAndCheck(graph, expected_code);
+}
+
+}  // namespace internal
+}  // namespace ceres