Autodiff Codegen Part 3: CodeGenerator
Add the class CodeGenerator, which is able to convert
objects of ExpressionGraph into strings.
Change-Id: Iff94fdbd3f2e055a78871b1a9947f1ca5ae3cd17
diff --git a/include/ceres/internal/code_generator.h b/include/ceres/internal/code_generator.h
new file mode 100644
index 0000000..d629907
--- /dev/null
+++ b/include/ceres/internal/code_generator.h
@@ -0,0 +1,123 @@
+// Ceres Solver - A fast non-linear least squares minimizer
+// Copyright 2019 Google Inc. All rights reserved.
+// http://code.google.com/p/ceres-solver/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// * Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// * Neither the name of Google Inc. nor the names of its contributors may be
+// used to endorse or promote products derived from this software without
+// specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//
+// Author: darius.rueckert@fau.de (Darius Rueckert)
+//
+#ifndef CERES_PUBLIC_CODE_GENERATOR_H_
+#define CERES_PUBLIC_CODE_GENERATOR_H_
+
+#include "ceres/internal/expression.h"
+#include "ceres/internal/expression_graph.h"
+
+#include <string>
+#include <vector>
+
+namespace ceres {
+namespace internal {
+
+// This class is used to convert an expression graph into a string. The typical
+// pipeline is:
+//
+// 1. Record ExpressionGraph
+// 2. Optimize ExpressionGraph
+// 3. Generate C++ code (this class here)
+//
+// The CodeGenerator operates in the following way:
+//
+// 1. Print Header
+// - The header string is defined in the options.
+// - This is usually the function name including the parameter list.
+//
+// 2. Print Declarations
+// - Declare all used variables
+// - Example:
+// double v_0;
+// double v_1;
+// bool v_3;
+// ...
+//
+// 3. Print Code
+// - Convert each expression line by line to a string
+// - Example:
+// v_2 = v_0 + v_1
+// if (v_5) {
+// v_2 = v_0;
+// ....
+//
+class CodeGenerator {
+ public:
+ struct Options {
+ // Name of the function.
+ // Example:
+ // bool Evaluate(const double* x, double* res)
+ std::string function_name = "";
+
+ // Number of spaces added for each level of indentation.
+ int indentation_spaces_per_level = 2;
+
+ // The prefix added to each variable name.
+ std::string variable_prefix = "v_";
+ };
+
+ CodeGenerator(const ExpressionGraph& graph, const Options& options);
+
+ // Generate the C++ code in the steps (1)-(3) described above.
+ // The result is a vector of strings, where each element is exactly one line
+ // of code. The order is important and must not be changed.
+ std::vector<std::string> Generate();
+
+ private:
+ // Converts a single expression given by id to a string.
+ // The format depends on the ExpressionType.
+ // See ExpressionType in expression.h for more detailed how the different
+ // lines will look like.
+ std::string ExpressionToString(ExpressionId id);
+
+ // Helper function to get the name of an expression.
+ // If the expression does not have a valid name an error is generated.
+ std::string VariableForExpressionId(ExpressionId id);
+
+ // Returns the type as a string of the left hand side.
+ static std::string DataTypeForExpression(ExpressionType type);
+
+ // Adds one level of indentation. Called when an IF expression is encountered.
+ void PushIndentation();
+
+ // Removes one level of indentation. Currently only used by ENDIF.
+ void PopIndentation();
+
+ const ExpressionGraph& graph_;
+ const Options options_;
+ std::string indentation_ = "";
+ static constexpr int kFloatingPointPrecision = 25;
+};
+
+} // namespace internal
+} // namespace ceres
+
+#endif // CERES_PUBLIC_CODE_GENERATOR_H_
diff --git a/internal/ceres/CMakeLists.txt b/internal/ceres/CMakeLists.txt
index 148b058..a3bd8b5 100644
--- a/internal/ceres/CMakeLists.txt
+++ b/internal/ceres/CMakeLists.txt
@@ -60,6 +60,7 @@
canonical_views_clustering.cc
cgnr_solver.cc
callbacks.cc
+ code_generator.cc
compressed_col_sparse_matrix_utils.cc
compressed_row_jacobian_writer.cc
compressed_row_sparse_matrix.cc
@@ -416,6 +417,7 @@
ceres_test(block_sparse_matrix)
ceres_test(c_api)
ceres_test(canonical_views_clustering)
+ ceres_test(code_generator)
ceres_test(compressed_col_sparse_matrix_utils)
ceres_test(compressed_row_sparse_matrix)
ceres_test(concurrent_queue)
diff --git a/internal/ceres/code_generator.cc b/internal/ceres/code_generator.cc
new file mode 100644
index 0000000..3af4bfb
--- /dev/null
+++ b/internal/ceres/code_generator.cc
@@ -0,0 +1,284 @@
+// Ceres Solver - A fast non-linear least squares minimizer
+// Copyright 2019 Google Inc. All rights reserved.
+// http://code.google.com/p/ceres-solver/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// * Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// * Neither the name of Google Inc. nor the names of its contributors may be
+// used to endorse or promote products derived from this software without
+// specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//
+// Author: darius.rueckert@fau.de (Darius Rueckert)
+
+#include "ceres/internal/code_generator.h"
+#include <sstream>
+#include "assert.h"
+#include "glog/logging.h"
+
+namespace ceres {
+namespace internal {
+
+CodeGenerator::CodeGenerator(const ExpressionGraph& graph,
+ const Options& options)
+ : graph_(graph), options_(options) {}
+
+std::vector<std::string> CodeGenerator::Generate() {
+ std::vector<std::string> code;
+
+ // 1. Print the header
+ if (!options_.function_name.empty()) {
+ code.emplace_back(options_.function_name);
+ }
+
+ code.emplace_back("{");
+ PushIndentation();
+
+ // 2. Print declarations
+ for (ExpressionId id = 0; id < graph_.Size(); ++id) {
+ // By definition of the lhs_id, an expression defines a new variable only if
+ // the current_id is identical to the lhs_id.
+ const auto& expr = graph_.ExpressionForId(id);
+ if (id != expr.lhs_id()) {
+ continue;
+ }
+ //
+ // Format: <type> <id>;
+ // Example: double v_0;
+ //
+ const std::string declaration_string =
+ indentation_ + DataTypeForExpression(expr.type()) + " " +
+ VariableForExpressionId(id) + ";";
+ code.emplace_back(declaration_string);
+ }
+
+ // 3. Print code
+ for (ExpressionId id = 0; id < graph_.Size(); ++id) {
+ code.emplace_back(ExpressionToString(id));
+ }
+
+ PopIndentation();
+ CHECK(indentation_.empty()) << "IF - ENDIF missmatch detected.";
+ code.emplace_back("}");
+
+ return code;
+}
+
+std::string CodeGenerator::ExpressionToString(ExpressionId id) {
+ // An expression is converted into a string, by first adding the required
+ // indentation spaces and then adding a ExpressionType-specific string. The
+ // following list shows the exact output format for each ExpressionType. The
+ // placeholders <value>, <name>,... stand for the respective members value_,
+ // name_, ... of the current expression. ExpressionIds such as lhs_id and
+ // arguments are converted to the corresponding variable name (7 -> "v_7").
+
+ auto& expr = graph_.ExpressionForId(id);
+
+ std::stringstream result;
+ result.precision(kFloatingPointPrecision);
+
+ // Convert the variable names of lhs and arguments to string. This makes the
+ // big switch/case below more readable.
+ std::string lhs;
+ if (expr.HasValidLhs()) {
+ lhs = VariableForExpressionId(expr.lhs_id());
+ }
+ std::vector<std::string> args;
+ for (ExpressionId id : expr.arguments()) {
+ args.push_back(VariableForExpressionId(id));
+ }
+ auto value = expr.value();
+ const auto& name = expr.name();
+
+ switch (expr.type()) {
+ case ExpressionType::COMPILE_TIME_CONSTANT: {
+ //
+ // Format: <lhs_id> = <value>;
+ // Example: v_0 = 3.1415;
+ //
+ result << indentation_ << lhs << " = " << value << ";";
+ break;
+ }
+ case ExpressionType::INPUT_ASSIGNMENT: {
+ //
+ // Format: <lhs_id> = <name>;
+ // Example: v_0 = _observed_point_x;
+ //
+ result << indentation_ << lhs << " = " << name << ";";
+ break;
+ }
+ case ExpressionType::OUTPUT_ASSIGNMENT: {
+ //
+ // Format: <name> = <arguments[0]>;
+ // Example: residual[0] = v_51;
+ //
+ result << indentation_ << name << " = " << args[0] << ";";
+ break;
+ }
+ case ExpressionType::ASSIGNMENT: {
+ //
+ // Format: <lhs_id> = <arguments[0]>;
+ // Example: v_1 = v_0;
+ //
+ result << indentation_ << lhs << " = " << args[0] << ";";
+ break;
+ }
+ case ExpressionType::BINARY_ARITHMETIC: {
+ //
+ // Format: <lhs_id> = <arguments[0]> <name> <arguments[1]>;
+ // Example: v_2 = v_0 + v_1;
+ //
+ result << indentation_ << lhs << " = " << args[0] << " " << name << " "
+ << args[1] << ";";
+ break;
+ }
+ case ExpressionType::UNARY_ARITHMETIC: {
+ //
+ // Format: <lhs_id> = <name><arguments[0]>;
+ // Example: v_1 = -v_0;
+ //
+ result << indentation_ << lhs << " = " << name << args[0] << ";";
+ break;
+ }
+ case ExpressionType::BINARY_COMPARISON: {
+ //
+ // Format: <lhs_id> = <arguments[0]> <name> <arguments[1]>;
+ // Example: v_2 = v_0 < v_1;
+ //
+ result << indentation_ << lhs << " = " << args[0] << " " << name << " "
+ << args[1] << ";";
+ break;
+ }
+ case ExpressionType::LOGICAL_NEGATION: {
+ //
+ // Format: <lhs_id> = !<arguments[0]>;
+ // Example: v_1 = !v_0;
+ //
+ result << indentation_ << lhs << " = !" << args[0] << ";";
+ break;
+ }
+ case ExpressionType::FUNCTION_CALL: {
+ //
+ // Format: <lhs_id> = <name>(<arguments[0]>, <arguments[1]>, ...);
+ // Example: v_1 = sin(v_0);
+ //
+ result << indentation_ << lhs << " = " << name << "(";
+ result << (args.size() ? args[0] : "");
+ for (int i = 1; i < args.size(); ++i) {
+ result << ", " << args[i];
+ }
+ result << ");";
+ break;
+ }
+ case ExpressionType::IF: {
+ //
+ // Format: if (<arguments[0]>) {
+ // Example: if (v_0) {
+ // Special: Adds 1 level of indentation for all following
+ // expressions.
+ //
+ result << indentation_ << "if (" << args[0] << ") {";
+ PushIndentation();
+ break;
+ }
+ case ExpressionType::ELSE: {
+ //
+ // Format: } else {
+ // Example: } else {
+ // Special: This expression is printed with one less level of
+ // indentation.
+ //
+ PopIndentation();
+ result << indentation_ << "} else {";
+ PushIndentation();
+ break;
+ }
+ case ExpressionType::ENDIF: {
+ //
+ // Format: }
+ // Example: }
+ // Special: Removes 1 level of indentation for this and all
+ // following expressions.
+ //
+ PopIndentation();
+ result << indentation_ << "}";
+ break;
+ }
+ case ExpressionType::NOP: {
+ //
+ // Format: // <NOP>
+ // Example: // <NOP>
+ //
+ result << indentation_ << "// <NOP>";
+ break;
+ }
+ default:
+ CHECK(false) << "CodeGenerator::ToString for ExpressionType "
+ << static_cast<int>(expr.type()) << " not implemented!";
+ }
+ return result.str();
+}
+
+std::string CodeGenerator::VariableForExpressionId(ExpressionId id) {
+ //
+ // Format: <variable_prefix><id>
+ // Example: v_42
+ //
+ auto& expr = graph_.ExpressionForId(id);
+ CHECK(expr.lhs_id() == id)
+ << "ExpressionId " << id
+ << " does not have a name (it has not been declared).";
+ return options_.variable_prefix + std::to_string(expr.lhs_id());
+}
+
+std::string CodeGenerator::DataTypeForExpression(ExpressionType type) {
+ std::string type_string;
+ switch (type) {
+ case ExpressionType::BINARY_COMPARISON:
+ case ExpressionType::LOGICAL_NEGATION:
+ type_string = "bool";
+ break;
+ case ExpressionType::IF:
+ case ExpressionType::ELSE:
+ case ExpressionType::ENDIF:
+ case ExpressionType::NOP:
+ type_string = "void";
+ break;
+ default:
+ type_string = "double";
+ }
+ return type_string;
+}
+
+void CodeGenerator::PushIndentation() {
+ for (int i = 0; i < options_.indentation_spaces_per_level; ++i) {
+ indentation_.push_back(' ');
+ }
+}
+
+void CodeGenerator::PopIndentation() {
+ for (int i = 0; i < options_.indentation_spaces_per_level; ++i) {
+ CHECK(!indentation_.empty()) << "IF - ENDIF missmatch detected.";
+ indentation_.pop_back();
+ }
+}
+
+} // namespace internal
+} // namespace ceres
diff --git a/internal/ceres/code_generator_test.cc b/internal/ceres/code_generator_test.cc
new file mode 100644
index 0000000..c0feb41
--- /dev/null
+++ b/internal/ceres/code_generator_test.cc
@@ -0,0 +1,465 @@
+// Ceres Solver - A fast non-linear least squares minimizer
+// Copyright 2019 Google Inc. All rights reserved.
+// http://code.google.com/p/ceres-solver/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// * Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// * Neither the name of Google Inc. nor the names of its contributors may be
+// used to endorse or promote products derived from this software without
+// specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//
+// Author: darius.rueckert@fau.de (Darius Rueckert)
+//
+#define CERES_CODEGEN
+
+#include "ceres/internal/code_generator.h"
+#include "ceres/internal/expression_graph.h"
+#include "ceres/internal/expression_ref.h"
+
+#include "gtest/gtest.h"
+
+namespace ceres {
+namespace internal {
+
+static void GenerateAndCheck(const ExpressionGraph& graph,
+ const std::vector<std::string>& reference) {
+ CodeGenerator::Options generator_options;
+ CodeGenerator gen(graph, generator_options);
+ auto code = gen.Generate();
+ EXPECT_EQ(code.size(), reference.size());
+
+ for (int i = 0; i < code.size(); ++i) {
+ EXPECT_EQ(code[i], reference[i]) << "Invalid Line: " << (i + 1);
+ }
+}
+
+using T = ExpressionRef;
+
+TEST(CodeGenerator, Empty) {
+ StartRecordingExpressions();
+ auto graph = StopRecordingExpressions();
+ std::vector<std::string> expected_code = {"{", "}"};
+ GenerateAndCheck(graph, expected_code);
+}
+
+// Now we add one TEST for each ExpressionType.
+TEST(CodeGenerator, COMPILE_TIME_CONSTANT) {
+ StartRecordingExpressions();
+ T a = T(0);
+ T b = T(123.5);
+ T c = T(1 + 1);
+ T d; // Uninitialized variables should not generate code!
+ auto graph = StopRecordingExpressions();
+ std::vector<std::string> expected_code = {"{",
+ " double v_0;",
+ " double v_1;",
+ " double v_2;",
+ " v_0 = 0;",
+ " v_1 = 123.5;",
+ " v_2 = 2;",
+ "}"};
+ GenerateAndCheck(graph, expected_code);
+}
+
+TEST(CodeGenerator, INPUT_ASSIGNMENT) {
+ double local_variable = 5.0;
+ StartRecordingExpressions();
+ T a = CERES_LOCAL_VARIABLE(local_variable);
+ T b = MakeParameter("parameters[0][0]");
+ T c = a + b;
+ auto graph = StopRecordingExpressions();
+ std::vector<std::string> expected_code = {"{",
+ " double v_0;",
+ " double v_1;",
+ " double v_2;",
+ " v_0 = local_variable;",
+ " v_1 = parameters[0][0];",
+ " v_2 = v_0 + v_1;",
+ "}"};
+ GenerateAndCheck(graph, expected_code);
+}
+
+TEST(CodeGenerator, OUTPUT_ASSIGNMENT) {
+ double local_variable = 5.0;
+ StartRecordingExpressions();
+ T a = 1;
+ T b = 0;
+ MakeOutput(a, "residual[0]");
+ auto graph = StopRecordingExpressions();
+ std::vector<std::string> expected_code = {"{",
+ " double v_0;",
+ " double v_1;",
+ " double v_2;",
+ " v_0 = 1;",
+ " v_1 = 0;",
+ " residual[0] = v_0;",
+ "}"};
+ GenerateAndCheck(graph, expected_code);
+}
+
+TEST(CodeGenerator, ASSIGNMENT) {
+ StartRecordingExpressions();
+ T a = T(0);
+ T b = T(1);
+ T c = a; // < This should not generate a line!
+ a = b;
+ a = a + b; // < Create temporary + assignment
+ auto graph = StopRecordingExpressions();
+ std::vector<std::string> expected_code = {"{",
+ " double v_0;",
+ " double v_1;",
+ " double v_3;",
+ " v_0 = 0;",
+ " v_1 = 1;",
+ " v_0 = v_1;",
+ " v_3 = v_0 + v_1;",
+ " v_0 = v_3;",
+ "}"};
+ GenerateAndCheck(graph, expected_code);
+}
+
+TEST(CodeGenerator, BINARY_ARITHMETIC_SIMPLE) {
+ StartRecordingExpressions();
+ T a = T(0);
+ T b = T(1);
+ T r1 = a + b;
+ T r2 = a - b;
+ T r3 = a * b;
+ T r4 = a / b;
+ auto graph = StopRecordingExpressions();
+ std::vector<std::string> expected_code = {"{",
+ " double v_0;",
+ " double v_1;",
+ " double v_2;",
+ " double v_3;",
+ " double v_4;",
+ " double v_5;",
+ " v_0 = 0;",
+ " v_1 = 1;",
+ " v_2 = v_0 + v_1;",
+ " v_3 = v_0 - v_1;",
+ " v_4 = v_0 * v_1;",
+ " v_5 = v_0 / v_1;",
+ "}"};
+ GenerateAndCheck(graph, expected_code);
+}
+
+TEST(CodeGenerator, BINARY_ARITHMETIC_COMPOUND) {
+ // For each binary compound arithmetic operation, two lines are generated:
+ // - The actual operation assigning to a new temporary variable
+ // - An assignment from the temporary to the lhs
+ StartRecordingExpressions();
+ T a = T(0);
+ T b = T(1);
+ b += a;
+ b -= a;
+ b *= a;
+ b /= a;
+ auto graph = StopRecordingExpressions();
+ std::vector<std::string> expected_code = {"{",
+ " double v_0;",
+ " double v_1;",
+ " double v_2;",
+ " double v_4;",
+ " double v_6;",
+ " double v_8;",
+ " v_0 = 0;",
+ " v_1 = 1;",
+ " v_2 = v_1 + v_0;",
+ " v_1 = v_2;",
+ " v_4 = v_1 - v_0;",
+ " v_1 = v_4;",
+ " v_6 = v_1 * v_0;",
+ " v_1 = v_6;",
+ " v_8 = v_1 / v_0;",
+ " v_1 = v_8;",
+ "}"};
+ GenerateAndCheck(graph, expected_code);
+}
+
+TEST(CodeGenerator, UNARY_ARITHMETIC) {
+ StartRecordingExpressions();
+ T a = T(0);
+ T r1 = -a;
+ T r2 = +a;
+ auto graph = StopRecordingExpressions();
+ std::vector<std::string> expected_code = {"{",
+ " double v_0;",
+ " double v_1;",
+ " double v_2;",
+ " v_0 = 0;",
+ " v_1 = -v_0;",
+ " v_2 = +v_0;",
+ "}"};
+ GenerateAndCheck(graph, expected_code);
+}
+
+TEST(CodeGenerator, BINARY_COMPARISON) {
+ StartRecordingExpressions();
+ T a = T(0);
+ T b = T(1);
+ auto r1 = a < b;
+ auto r2 = a <= b;
+ auto r3 = a > b;
+ auto r4 = a >= b;
+ auto r5 = a == b;
+ auto r6 = a != b;
+ auto graph = StopRecordingExpressions();
+ std::vector<std::string> expected_code = {"{",
+ " double v_0;",
+ " double v_1;",
+ " bool v_2;",
+ " bool v_3;",
+ " bool v_4;",
+ " bool v_5;",
+ " bool v_6;",
+ " bool v_7;",
+ " v_0 = 0;",
+ " v_1 = 1;",
+ " v_2 = v_0 < v_1;",
+ " v_3 = v_0 <= v_1;",
+ " v_4 = v_0 > v_1;",
+ " v_5 = v_0 >= v_1;",
+ " v_6 = v_0 == v_1;",
+ " v_7 = v_0 != v_1;",
+ "}"};
+ GenerateAndCheck(graph, expected_code);
+}
+
+TEST(CodeGenerator, LOGICAL_OPERATORS) {
+ // Tests binary logical operators &&, || and the unary logical operator !
+ StartRecordingExpressions();
+ T a = T(0);
+ T b = T(1);
+ auto r1 = a < b;
+ auto r2 = a <= b;
+
+ auto r3 = r1 && r2;
+ auto r4 = r1 || r2;
+ auto r5 = !r1;
+
+ auto graph = StopRecordingExpressions();
+ std::vector<std::string> expected_code = {"{",
+ " double v_0;",
+ " double v_1;",
+ " bool v_2;",
+ " bool v_3;",
+ " bool v_4;",
+ " bool v_5;",
+ " bool v_6;",
+ " v_0 = 0;",
+ " v_1 = 1;",
+ " v_2 = v_0 < v_1;",
+ " v_3 = v_0 <= v_1;",
+ " v_4 = v_2 && v_3;",
+ " v_5 = v_2 || v_3;",
+ " v_6 = !v_2;",
+ "}"};
+ GenerateAndCheck(graph, expected_code);
+}
+
+TEST(CodeGenerator, FUNCTION_CALL) {
+ StartRecordingExpressions();
+ T a = T(0);
+ T b = T(1);
+
+ abs(a);
+ acos(a);
+ asin(a);
+ atan(a);
+ cbrt(a);
+ ceil(a);
+ cos(a);
+ cosh(a);
+ exp(a);
+ exp2(a);
+ floor(a);
+ log(a);
+ log2(a);
+ sin(a);
+ sinh(a);
+ sqrt(a);
+ tan(a);
+ tanh(a);
+ atan2(a, b);
+ pow(a, b);
+
+ auto graph = StopRecordingExpressions();
+
+ std::vector<std::string> expected_code = {"{",
+ " double v_0;",
+ " double v_1;",
+ " double v_2;",
+ " double v_3;",
+ " double v_4;",
+ " double v_5;",
+ " double v_6;",
+ " double v_7;",
+ " double v_8;",
+ " double v_9;",
+ " double v_10;",
+ " double v_11;",
+ " double v_12;",
+ " double v_13;",
+ " double v_14;",
+ " double v_15;",
+ " double v_16;",
+ " double v_17;",
+ " double v_18;",
+ " double v_19;",
+ " double v_20;",
+ " double v_21;",
+ " v_0 = 0;",
+ " v_1 = 1;",
+ " v_2 = abs(v_0);",
+ " v_3 = acos(v_0);",
+ " v_4 = asin(v_0);",
+ " v_5 = atan(v_0);",
+ " v_6 = cbrt(v_0);",
+ " v_7 = ceil(v_0);",
+ " v_8 = cos(v_0);",
+ " v_9 = cosh(v_0);",
+ " v_10 = exp(v_0);",
+ " v_11 = exp2(v_0);",
+ " v_12 = floor(v_0);",
+ " v_13 = log(v_0);",
+ " v_14 = log2(v_0);",
+ " v_15 = sin(v_0);",
+ " v_16 = sinh(v_0);",
+ " v_17 = sqrt(v_0);",
+ " v_18 = tan(v_0);",
+ " v_19 = tanh(v_0);",
+ " v_20 = atan2(v_0, v_1);",
+ " v_21 = pow(v_0, v_1);",
+ "}"};
+ GenerateAndCheck(graph, expected_code);
+}
+
+TEST(CodeGenerator, IF_SIMPLE) {
+ StartRecordingExpressions();
+ T a = T(0);
+ T b = T(1);
+ auto r1 = a < b;
+ CERES_IF(r1) {}
+ CERES_ELSE {}
+ CERES_ENDIF;
+
+ auto graph = StopRecordingExpressions();
+ std::vector<std::string> expected_code = {"{",
+ " double v_0;",
+ " double v_1;",
+ " bool v_2;",
+ " v_0 = 0;",
+ " v_1 = 1;",
+ " v_2 = v_0 < v_1;",
+ " if (v_2) {",
+ " } else {",
+ " }",
+ "}"};
+ GenerateAndCheck(graph, expected_code);
+}
+
+TEST(CodeGenerator, IF_ASSIGNMENT) {
+ StartRecordingExpressions();
+ T a = T(0);
+ T b = T(1);
+ auto r1 = a < b;
+
+ T result = 0;
+ CERES_IF(r1) { result = 5.0; }
+ CERES_ELSE { result = 6.0; }
+ CERES_ENDIF;
+ MakeOutput(result, "result");
+
+ auto graph = StopRecordingExpressions();
+ std::vector<std::string> expected_code = {"{",
+ " double v_0;",
+ " double v_1;",
+ " bool v_2;",
+ " double v_3;",
+ " double v_5;",
+ " double v_8;",
+ " double v_11;",
+ " v_0 = 0;",
+ " v_1 = 1;",
+ " v_2 = v_0 < v_1;",
+ " v_3 = 0;",
+ " if (v_2) {",
+ " v_5 = 5;",
+ " v_3 = v_5;",
+ " } else {",
+ " v_8 = 6;",
+ " v_3 = v_8;",
+ " }",
+ " result = v_3;",
+ "}"};
+ GenerateAndCheck(graph, expected_code);
+}
+
+TEST(CodeGenerator, IF_NESTED_ASSIGNMENT) {
+ StartRecordingExpressions();
+ T a = T(0);
+ T b = T(1);
+
+ T result = 0;
+ CERES_IF(a <= b) {
+ result = 5.0;
+ CERES_IF(a == b) { result = 7.0; }
+ CERES_ENDIF;
+ }
+ CERES_ELSE { result = 6.0; }
+ CERES_ENDIF;
+ MakeOutput(result, "result");
+
+ auto graph = StopRecordingExpressions();
+ std::vector<std::string> expected_code = {"{",
+ " double v_0;",
+ " double v_1;",
+ " double v_2;",
+ " bool v_3;",
+ " double v_5;",
+ " bool v_7;",
+ " double v_9;",
+ " double v_13;",
+ " double v_16;",
+ " v_0 = 0;",
+ " v_1 = 1;",
+ " v_2 = 0;",
+ " v_3 = v_0 <= v_1;",
+ " if (v_3) {",
+ " v_5 = 5;",
+ " v_2 = v_5;",
+ " v_7 = v_0 == v_1;",
+ " if (v_7) {",
+ " v_9 = 7;",
+ " v_2 = v_9;",
+ " }",
+ " } else {",
+ " v_13 = 6;",
+ " v_2 = v_13;",
+ " }",
+ " result = v_2;",
+ "}"};
+ GenerateAndCheck(graph, expected_code);
+}
+
+} // namespace internal
+} // namespace ceres