Autodiff Codegen Part 3: CodeGenerator Add the class CodeGenerator, which is able to convert objects of ExpressionGraph into strings. Change-Id: Iff94fdbd3f2e055a78871b1a9947f1ca5ae3cd17
diff --git a/include/ceres/internal/code_generator.h b/include/ceres/internal/code_generator.h new file mode 100644 index 0000000..d629907 --- /dev/null +++ b/include/ceres/internal/code_generator.h
@@ -0,0 +1,123 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2019 Google Inc. All rights reserved. +// http://code.google.com/p/ceres-solver/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Author: darius.rueckert@fau.de (Darius Rueckert) +// +#ifndef CERES_PUBLIC_CODE_GENERATOR_H_ +#define CERES_PUBLIC_CODE_GENERATOR_H_ + +#include "ceres/internal/expression.h" +#include "ceres/internal/expression_graph.h" + +#include <string> +#include <vector> + +namespace ceres { +namespace internal { + +// This class is used to convert an expression graph into a string. The typical +// pipeline is: +// +// 1. Record ExpressionGraph +// 2. Optimize ExpressionGraph +// 3. Generate C++ code (this class here) +// +// The CodeGenerator operates in the following way: +// +// 1. Print Header +// - The header string is defined in the options. +// - This is usually the function name including the parameter list. +// +// 2. Print Declarations +// - Declare all used variables +// - Example: +// double v_0; +// double v_1; +// bool v_3; +// ... +// +// 3. Print Code +// - Convert each expression line by line to a string +// - Example: +// v_2 = v_0 + v_1 +// if (v_5) { +// v_2 = v_0; +// .... +// +class CodeGenerator { + public: + struct Options { + // Name of the function. + // Example: + // bool Evaluate(const double* x, double* res) + std::string function_name = ""; + + // Number of spaces added for each level of indentation. + int indentation_spaces_per_level = 2; + + // The prefix added to each variable name. + std::string variable_prefix = "v_"; + }; + + CodeGenerator(const ExpressionGraph& graph, const Options& options); + + // Generate the C++ code in the steps (1)-(3) described above. + // The result is a vector of strings, where each element is exactly one line + // of code. The order is important and must not be changed. + std::vector<std::string> Generate(); + + private: + // Converts a single expression given by id to a string. + // The format depends on the ExpressionType. + // See ExpressionType in expression.h for more detailed how the different + // lines will look like. + std::string ExpressionToString(ExpressionId id); + + // Helper function to get the name of an expression. + // If the expression does not have a valid name an error is generated. + std::string VariableForExpressionId(ExpressionId id); + + // Returns the type as a string of the left hand side. + static std::string DataTypeForExpression(ExpressionType type); + + // Adds one level of indentation. Called when an IF expression is encountered. + void PushIndentation(); + + // Removes one level of indentation. Currently only used by ENDIF. + void PopIndentation(); + + const ExpressionGraph& graph_; + const Options options_; + std::string indentation_ = ""; + static constexpr int kFloatingPointPrecision = 25; +}; + +} // namespace internal +} // namespace ceres + +#endif // CERES_PUBLIC_CODE_GENERATOR_H_
diff --git a/internal/ceres/CMakeLists.txt b/internal/ceres/CMakeLists.txt index 148b058..a3bd8b5 100644 --- a/internal/ceres/CMakeLists.txt +++ b/internal/ceres/CMakeLists.txt
@@ -60,6 +60,7 @@ canonical_views_clustering.cc cgnr_solver.cc callbacks.cc + code_generator.cc compressed_col_sparse_matrix_utils.cc compressed_row_jacobian_writer.cc compressed_row_sparse_matrix.cc @@ -416,6 +417,7 @@ ceres_test(block_sparse_matrix) ceres_test(c_api) ceres_test(canonical_views_clustering) + ceres_test(code_generator) ceres_test(compressed_col_sparse_matrix_utils) ceres_test(compressed_row_sparse_matrix) ceres_test(concurrent_queue)
diff --git a/internal/ceres/code_generator.cc b/internal/ceres/code_generator.cc new file mode 100644 index 0000000..3af4bfb --- /dev/null +++ b/internal/ceres/code_generator.cc
@@ -0,0 +1,284 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2019 Google Inc. All rights reserved. +// http://code.google.com/p/ceres-solver/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Author: darius.rueckert@fau.de (Darius Rueckert) + +#include "ceres/internal/code_generator.h" +#include <sstream> +#include "assert.h" +#include "glog/logging.h" + +namespace ceres { +namespace internal { + +CodeGenerator::CodeGenerator(const ExpressionGraph& graph, + const Options& options) + : graph_(graph), options_(options) {} + +std::vector<std::string> CodeGenerator::Generate() { + std::vector<std::string> code; + + // 1. Print the header + if (!options_.function_name.empty()) { + code.emplace_back(options_.function_name); + } + + code.emplace_back("{"); + PushIndentation(); + + // 2. Print declarations + for (ExpressionId id = 0; id < graph_.Size(); ++id) { + // By definition of the lhs_id, an expression defines a new variable only if + // the current_id is identical to the lhs_id. + const auto& expr = graph_.ExpressionForId(id); + if (id != expr.lhs_id()) { + continue; + } + // + // Format: <type> <id>; + // Example: double v_0; + // + const std::string declaration_string = + indentation_ + DataTypeForExpression(expr.type()) + " " + + VariableForExpressionId(id) + ";"; + code.emplace_back(declaration_string); + } + + // 3. Print code + for (ExpressionId id = 0; id < graph_.Size(); ++id) { + code.emplace_back(ExpressionToString(id)); + } + + PopIndentation(); + CHECK(indentation_.empty()) << "IF - ENDIF missmatch detected."; + code.emplace_back("}"); + + return code; +} + +std::string CodeGenerator::ExpressionToString(ExpressionId id) { + // An expression is converted into a string, by first adding the required + // indentation spaces and then adding a ExpressionType-specific string. The + // following list shows the exact output format for each ExpressionType. The + // placeholders <value>, <name>,... stand for the respective members value_, + // name_, ... of the current expression. ExpressionIds such as lhs_id and + // arguments are converted to the corresponding variable name (7 -> "v_7"). + + auto& expr = graph_.ExpressionForId(id); + + std::stringstream result; + result.precision(kFloatingPointPrecision); + + // Convert the variable names of lhs and arguments to string. This makes the + // big switch/case below more readable. + std::string lhs; + if (expr.HasValidLhs()) { + lhs = VariableForExpressionId(expr.lhs_id()); + } + std::vector<std::string> args; + for (ExpressionId id : expr.arguments()) { + args.push_back(VariableForExpressionId(id)); + } + auto value = expr.value(); + const auto& name = expr.name(); + + switch (expr.type()) { + case ExpressionType::COMPILE_TIME_CONSTANT: { + // + // Format: <lhs_id> = <value>; + // Example: v_0 = 3.1415; + // + result << indentation_ << lhs << " = " << value << ";"; + break; + } + case ExpressionType::INPUT_ASSIGNMENT: { + // + // Format: <lhs_id> = <name>; + // Example: v_0 = _observed_point_x; + // + result << indentation_ << lhs << " = " << name << ";"; + break; + } + case ExpressionType::OUTPUT_ASSIGNMENT: { + // + // Format: <name> = <arguments[0]>; + // Example: residual[0] = v_51; + // + result << indentation_ << name << " = " << args[0] << ";"; + break; + } + case ExpressionType::ASSIGNMENT: { + // + // Format: <lhs_id> = <arguments[0]>; + // Example: v_1 = v_0; + // + result << indentation_ << lhs << " = " << args[0] << ";"; + break; + } + case ExpressionType::BINARY_ARITHMETIC: { + // + // Format: <lhs_id> = <arguments[0]> <name> <arguments[1]>; + // Example: v_2 = v_0 + v_1; + // + result << indentation_ << lhs << " = " << args[0] << " " << name << " " + << args[1] << ";"; + break; + } + case ExpressionType::UNARY_ARITHMETIC: { + // + // Format: <lhs_id> = <name><arguments[0]>; + // Example: v_1 = -v_0; + // + result << indentation_ << lhs << " = " << name << args[0] << ";"; + break; + } + case ExpressionType::BINARY_COMPARISON: { + // + // Format: <lhs_id> = <arguments[0]> <name> <arguments[1]>; + // Example: v_2 = v_0 < v_1; + // + result << indentation_ << lhs << " = " << args[0] << " " << name << " " + << args[1] << ";"; + break; + } + case ExpressionType::LOGICAL_NEGATION: { + // + // Format: <lhs_id> = !<arguments[0]>; + // Example: v_1 = !v_0; + // + result << indentation_ << lhs << " = !" << args[0] << ";"; + break; + } + case ExpressionType::FUNCTION_CALL: { + // + // Format: <lhs_id> = <name>(<arguments[0]>, <arguments[1]>, ...); + // Example: v_1 = sin(v_0); + // + result << indentation_ << lhs << " = " << name << "("; + result << (args.size() ? args[0] : ""); + for (int i = 1; i < args.size(); ++i) { + result << ", " << args[i]; + } + result << ");"; + break; + } + case ExpressionType::IF: { + // + // Format: if (<arguments[0]>) { + // Example: if (v_0) { + // Special: Adds 1 level of indentation for all following + // expressions. + // + result << indentation_ << "if (" << args[0] << ") {"; + PushIndentation(); + break; + } + case ExpressionType::ELSE: { + // + // Format: } else { + // Example: } else { + // Special: This expression is printed with one less level of + // indentation. + // + PopIndentation(); + result << indentation_ << "} else {"; + PushIndentation(); + break; + } + case ExpressionType::ENDIF: { + // + // Format: } + // Example: } + // Special: Removes 1 level of indentation for this and all + // following expressions. + // + PopIndentation(); + result << indentation_ << "}"; + break; + } + case ExpressionType::NOP: { + // + // Format: // <NOP> + // Example: // <NOP> + // + result << indentation_ << "// <NOP>"; + break; + } + default: + CHECK(false) << "CodeGenerator::ToString for ExpressionType " + << static_cast<int>(expr.type()) << " not implemented!"; + } + return result.str(); +} + +std::string CodeGenerator::VariableForExpressionId(ExpressionId id) { + // + // Format: <variable_prefix><id> + // Example: v_42 + // + auto& expr = graph_.ExpressionForId(id); + CHECK(expr.lhs_id() == id) + << "ExpressionId " << id + << " does not have a name (it has not been declared)."; + return options_.variable_prefix + std::to_string(expr.lhs_id()); +} + +std::string CodeGenerator::DataTypeForExpression(ExpressionType type) { + std::string type_string; + switch (type) { + case ExpressionType::BINARY_COMPARISON: + case ExpressionType::LOGICAL_NEGATION: + type_string = "bool"; + break; + case ExpressionType::IF: + case ExpressionType::ELSE: + case ExpressionType::ENDIF: + case ExpressionType::NOP: + type_string = "void"; + break; + default: + type_string = "double"; + } + return type_string; +} + +void CodeGenerator::PushIndentation() { + for (int i = 0; i < options_.indentation_spaces_per_level; ++i) { + indentation_.push_back(' '); + } +} + +void CodeGenerator::PopIndentation() { + for (int i = 0; i < options_.indentation_spaces_per_level; ++i) { + CHECK(!indentation_.empty()) << "IF - ENDIF missmatch detected."; + indentation_.pop_back(); + } +} + +} // namespace internal +} // namespace ceres
diff --git a/internal/ceres/code_generator_test.cc b/internal/ceres/code_generator_test.cc new file mode 100644 index 0000000..c0feb41 --- /dev/null +++ b/internal/ceres/code_generator_test.cc
@@ -0,0 +1,465 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2019 Google Inc. All rights reserved. +// http://code.google.com/p/ceres-solver/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Author: darius.rueckert@fau.de (Darius Rueckert) +// +#define CERES_CODEGEN + +#include "ceres/internal/code_generator.h" +#include "ceres/internal/expression_graph.h" +#include "ceres/internal/expression_ref.h" + +#include "gtest/gtest.h" + +namespace ceres { +namespace internal { + +static void GenerateAndCheck(const ExpressionGraph& graph, + const std::vector<std::string>& reference) { + CodeGenerator::Options generator_options; + CodeGenerator gen(graph, generator_options); + auto code = gen.Generate(); + EXPECT_EQ(code.size(), reference.size()); + + for (int i = 0; i < code.size(); ++i) { + EXPECT_EQ(code[i], reference[i]) << "Invalid Line: " << (i + 1); + } +} + +using T = ExpressionRef; + +TEST(CodeGenerator, Empty) { + StartRecordingExpressions(); + auto graph = StopRecordingExpressions(); + std::vector<std::string> expected_code = {"{", "}"}; + GenerateAndCheck(graph, expected_code); +} + +// Now we add one TEST for each ExpressionType. +TEST(CodeGenerator, COMPILE_TIME_CONSTANT) { + StartRecordingExpressions(); + T a = T(0); + T b = T(123.5); + T c = T(1 + 1); + T d; // Uninitialized variables should not generate code! + auto graph = StopRecordingExpressions(); + std::vector<std::string> expected_code = {"{", + " double v_0;", + " double v_1;", + " double v_2;", + " v_0 = 0;", + " v_1 = 123.5;", + " v_2 = 2;", + "}"}; + GenerateAndCheck(graph, expected_code); +} + +TEST(CodeGenerator, INPUT_ASSIGNMENT) { + double local_variable = 5.0; + StartRecordingExpressions(); + T a = CERES_LOCAL_VARIABLE(local_variable); + T b = MakeParameter("parameters[0][0]"); + T c = a + b; + auto graph = StopRecordingExpressions(); + std::vector<std::string> expected_code = {"{", + " double v_0;", + " double v_1;", + " double v_2;", + " v_0 = local_variable;", + " v_1 = parameters[0][0];", + " v_2 = v_0 + v_1;", + "}"}; + GenerateAndCheck(graph, expected_code); +} + +TEST(CodeGenerator, OUTPUT_ASSIGNMENT) { + double local_variable = 5.0; + StartRecordingExpressions(); + T a = 1; + T b = 0; + MakeOutput(a, "residual[0]"); + auto graph = StopRecordingExpressions(); + std::vector<std::string> expected_code = {"{", + " double v_0;", + " double v_1;", + " double v_2;", + " v_0 = 1;", + " v_1 = 0;", + " residual[0] = v_0;", + "}"}; + GenerateAndCheck(graph, expected_code); +} + +TEST(CodeGenerator, ASSIGNMENT) { + StartRecordingExpressions(); + T a = T(0); + T b = T(1); + T c = a; // < This should not generate a line! + a = b; + a = a + b; // < Create temporary + assignment + auto graph = StopRecordingExpressions(); + std::vector<std::string> expected_code = {"{", + " double v_0;", + " double v_1;", + " double v_3;", + " v_0 = 0;", + " v_1 = 1;", + " v_0 = v_1;", + " v_3 = v_0 + v_1;", + " v_0 = v_3;", + "}"}; + GenerateAndCheck(graph, expected_code); +} + +TEST(CodeGenerator, BINARY_ARITHMETIC_SIMPLE) { + StartRecordingExpressions(); + T a = T(0); + T b = T(1); + T r1 = a + b; + T r2 = a - b; + T r3 = a * b; + T r4 = a / b; + auto graph = StopRecordingExpressions(); + std::vector<std::string> expected_code = {"{", + " double v_0;", + " double v_1;", + " double v_2;", + " double v_3;", + " double v_4;", + " double v_5;", + " v_0 = 0;", + " v_1 = 1;", + " v_2 = v_0 + v_1;", + " v_3 = v_0 - v_1;", + " v_4 = v_0 * v_1;", + " v_5 = v_0 / v_1;", + "}"}; + GenerateAndCheck(graph, expected_code); +} + +TEST(CodeGenerator, BINARY_ARITHMETIC_COMPOUND) { + // For each binary compound arithmetic operation, two lines are generated: + // - The actual operation assigning to a new temporary variable + // - An assignment from the temporary to the lhs + StartRecordingExpressions(); + T a = T(0); + T b = T(1); + b += a; + b -= a; + b *= a; + b /= a; + auto graph = StopRecordingExpressions(); + std::vector<std::string> expected_code = {"{", + " double v_0;", + " double v_1;", + " double v_2;", + " double v_4;", + " double v_6;", + " double v_8;", + " v_0 = 0;", + " v_1 = 1;", + " v_2 = v_1 + v_0;", + " v_1 = v_2;", + " v_4 = v_1 - v_0;", + " v_1 = v_4;", + " v_6 = v_1 * v_0;", + " v_1 = v_6;", + " v_8 = v_1 / v_0;", + " v_1 = v_8;", + "}"}; + GenerateAndCheck(graph, expected_code); +} + +TEST(CodeGenerator, UNARY_ARITHMETIC) { + StartRecordingExpressions(); + T a = T(0); + T r1 = -a; + T r2 = +a; + auto graph = StopRecordingExpressions(); + std::vector<std::string> expected_code = {"{", + " double v_0;", + " double v_1;", + " double v_2;", + " v_0 = 0;", + " v_1 = -v_0;", + " v_2 = +v_0;", + "}"}; + GenerateAndCheck(graph, expected_code); +} + +TEST(CodeGenerator, BINARY_COMPARISON) { + StartRecordingExpressions(); + T a = T(0); + T b = T(1); + auto r1 = a < b; + auto r2 = a <= b; + auto r3 = a > b; + auto r4 = a >= b; + auto r5 = a == b; + auto r6 = a != b; + auto graph = StopRecordingExpressions(); + std::vector<std::string> expected_code = {"{", + " double v_0;", + " double v_1;", + " bool v_2;", + " bool v_3;", + " bool v_4;", + " bool v_5;", + " bool v_6;", + " bool v_7;", + " v_0 = 0;", + " v_1 = 1;", + " v_2 = v_0 < v_1;", + " v_3 = v_0 <= v_1;", + " v_4 = v_0 > v_1;", + " v_5 = v_0 >= v_1;", + " v_6 = v_0 == v_1;", + " v_7 = v_0 != v_1;", + "}"}; + GenerateAndCheck(graph, expected_code); +} + +TEST(CodeGenerator, LOGICAL_OPERATORS) { + // Tests binary logical operators &&, || and the unary logical operator ! + StartRecordingExpressions(); + T a = T(0); + T b = T(1); + auto r1 = a < b; + auto r2 = a <= b; + + auto r3 = r1 && r2; + auto r4 = r1 || r2; + auto r5 = !r1; + + auto graph = StopRecordingExpressions(); + std::vector<std::string> expected_code = {"{", + " double v_0;", + " double v_1;", + " bool v_2;", + " bool v_3;", + " bool v_4;", + " bool v_5;", + " bool v_6;", + " v_0 = 0;", + " v_1 = 1;", + " v_2 = v_0 < v_1;", + " v_3 = v_0 <= v_1;", + " v_4 = v_2 && v_3;", + " v_5 = v_2 || v_3;", + " v_6 = !v_2;", + "}"}; + GenerateAndCheck(graph, expected_code); +} + +TEST(CodeGenerator, FUNCTION_CALL) { + StartRecordingExpressions(); + T a = T(0); + T b = T(1); + + abs(a); + acos(a); + asin(a); + atan(a); + cbrt(a); + ceil(a); + cos(a); + cosh(a); + exp(a); + exp2(a); + floor(a); + log(a); + log2(a); + sin(a); + sinh(a); + sqrt(a); + tan(a); + tanh(a); + atan2(a, b); + pow(a, b); + + auto graph = StopRecordingExpressions(); + + std::vector<std::string> expected_code = {"{", + " double v_0;", + " double v_1;", + " double v_2;", + " double v_3;", + " double v_4;", + " double v_5;", + " double v_6;", + " double v_7;", + " double v_8;", + " double v_9;", + " double v_10;", + " double v_11;", + " double v_12;", + " double v_13;", + " double v_14;", + " double v_15;", + " double v_16;", + " double v_17;", + " double v_18;", + " double v_19;", + " double v_20;", + " double v_21;", + " v_0 = 0;", + " v_1 = 1;", + " v_2 = abs(v_0);", + " v_3 = acos(v_0);", + " v_4 = asin(v_0);", + " v_5 = atan(v_0);", + " v_6 = cbrt(v_0);", + " v_7 = ceil(v_0);", + " v_8 = cos(v_0);", + " v_9 = cosh(v_0);", + " v_10 = exp(v_0);", + " v_11 = exp2(v_0);", + " v_12 = floor(v_0);", + " v_13 = log(v_0);", + " v_14 = log2(v_0);", + " v_15 = sin(v_0);", + " v_16 = sinh(v_0);", + " v_17 = sqrt(v_0);", + " v_18 = tan(v_0);", + " v_19 = tanh(v_0);", + " v_20 = atan2(v_0, v_1);", + " v_21 = pow(v_0, v_1);", + "}"}; + GenerateAndCheck(graph, expected_code); +} + +TEST(CodeGenerator, IF_SIMPLE) { + StartRecordingExpressions(); + T a = T(0); + T b = T(1); + auto r1 = a < b; + CERES_IF(r1) {} + CERES_ELSE {} + CERES_ENDIF; + + auto graph = StopRecordingExpressions(); + std::vector<std::string> expected_code = {"{", + " double v_0;", + " double v_1;", + " bool v_2;", + " v_0 = 0;", + " v_1 = 1;", + " v_2 = v_0 < v_1;", + " if (v_2) {", + " } else {", + " }", + "}"}; + GenerateAndCheck(graph, expected_code); +} + +TEST(CodeGenerator, IF_ASSIGNMENT) { + StartRecordingExpressions(); + T a = T(0); + T b = T(1); + auto r1 = a < b; + + T result = 0; + CERES_IF(r1) { result = 5.0; } + CERES_ELSE { result = 6.0; } + CERES_ENDIF; + MakeOutput(result, "result"); + + auto graph = StopRecordingExpressions(); + std::vector<std::string> expected_code = {"{", + " double v_0;", + " double v_1;", + " bool v_2;", + " double v_3;", + " double v_5;", + " double v_8;", + " double v_11;", + " v_0 = 0;", + " v_1 = 1;", + " v_2 = v_0 < v_1;", + " v_3 = 0;", + " if (v_2) {", + " v_5 = 5;", + " v_3 = v_5;", + " } else {", + " v_8 = 6;", + " v_3 = v_8;", + " }", + " result = v_3;", + "}"}; + GenerateAndCheck(graph, expected_code); +} + +TEST(CodeGenerator, IF_NESTED_ASSIGNMENT) { + StartRecordingExpressions(); + T a = T(0); + T b = T(1); + + T result = 0; + CERES_IF(a <= b) { + result = 5.0; + CERES_IF(a == b) { result = 7.0; } + CERES_ENDIF; + } + CERES_ELSE { result = 6.0; } + CERES_ENDIF; + MakeOutput(result, "result"); + + auto graph = StopRecordingExpressions(); + std::vector<std::string> expected_code = {"{", + " double v_0;", + " double v_1;", + " double v_2;", + " bool v_3;", + " double v_5;", + " bool v_7;", + " double v_9;", + " double v_13;", + " double v_16;", + " v_0 = 0;", + " v_1 = 1;", + " v_2 = 0;", + " v_3 = v_0 <= v_1;", + " if (v_3) {", + " v_5 = 5;", + " v_2 = v_5;", + " v_7 = v_0 == v_1;", + " if (v_7) {", + " v_9 = 7;", + " v_2 = v_9;", + " }", + " } else {", + " v_13 = 6;", + " v_2 = v_13;", + " }", + " result = v_2;", + "}"}; + GenerateAndCheck(graph, expected_code); +} + +} // namespace internal +} // namespace ceres