|  | // Ceres Solver - A fast non-linear least squares minimizer | 
|  | // Copyright 2019 Google Inc. All rights reserved. | 
|  | // http://code.google.com/p/ceres-solver/ | 
|  | // | 
|  | // Redistribution and use in source and binary forms, with or without | 
|  | // modification, are permitted provided that the following conditions are met: | 
|  | // | 
|  | // * Redistributions of source code must retain the above copyright notice, | 
|  | //   this list of conditions and the following disclaimer. | 
|  | // * Redistributions in binary form must reproduce the above copyright notice, | 
|  | //   this list of conditions and the following disclaimer in the documentation | 
|  | //   and/or other materials provided with the distribution. | 
|  | // * Neither the name of Google Inc. nor the names of its contributors may be | 
|  | //   used to endorse or promote products derived from this software without | 
|  | //   specific prior written permission. | 
|  | // | 
|  | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | 
|  | // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | 
|  | // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | 
|  | // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE | 
|  | // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | 
|  | // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | 
|  | // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | 
|  | // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | 
|  | // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | 
|  | // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | 
|  | // POSSIBILITY OF SUCH DAMAGE. | 
|  | // | 
|  | // Author: darius.rueckert@fau.de (Darius Rueckert) | 
|  | // | 
|  | #define CERES_CODEGEN | 
|  |  | 
#include <algorithm>
#include <string>
#include <vector>

#include "ceres/internal/code_generator.h"
#include "ceres/internal/expression_graph.h"
#include "ceres/internal/expression_ref.h"

#include "gtest/gtest.h"
|  |  | 
|  | namespace ceres { | 
|  | namespace internal { | 
|  |  | 
|  | static void GenerateAndCheck(const ExpressionGraph& graph, | 
|  | const std::vector<std::string>& reference) { | 
|  | CodeGenerator::Options generator_options; | 
|  | CodeGenerator gen(graph, generator_options); | 
|  | auto code = gen.Generate(); | 
|  | EXPECT_EQ(code.size(), reference.size()); | 
|  |  | 
|  | for (int i = 0; i < code.size(); ++i) { | 
|  | EXPECT_EQ(code[i], reference[i]) << "Invalid Line: " << (i + 1); | 
|  | } | 
|  | } | 
|  |  | 
// Short alias for ExpressionRef, used as the scalar type in every test below.
using T = ExpressionRef;
|  |  | 
|  | TEST(CodeGenerator, Empty) { | 
|  | StartRecordingExpressions(); | 
|  | auto graph = StopRecordingExpressions(); | 
|  | std::vector<std::string> expected_code = {"{", "}"}; | 
|  | GenerateAndCheck(graph, expected_code); | 
|  | } | 
|  |  | 
// The following tests each cover one ExpressionType.
|  | TEST(CodeGenerator, COMPILE_TIME_CONSTANT) { | 
|  | StartRecordingExpressions(); | 
|  | T a = T(0); | 
|  | T b = T(123.5); | 
|  | T c = T(1 + 1); | 
|  | T d;  // Uninitialized variables should not generate code! | 
|  | auto graph = StopRecordingExpressions(); | 
|  | std::vector<std::string> expected_code = {"{", | 
|  | "  double v_0;", | 
|  | "  double v_1;", | 
|  | "  double v_2;", | 
|  | "  v_0 = 0;", | 
|  | "  v_1 = 123.5;", | 
|  | "  v_2 = 2;", | 
|  | "}"}; | 
|  | GenerateAndCheck(graph, expected_code); | 
|  | } | 
|  |  | 
|  | TEST(CodeGenerator, INPUT_ASSIGNMENT) { | 
|  | double local_variable = 5.0; | 
|  | StartRecordingExpressions(); | 
|  | T a = CERES_LOCAL_VARIABLE(local_variable); | 
|  | T b = MakeParameter("parameters[0][0]"); | 
|  | T c = a + b; | 
|  | auto graph = StopRecordingExpressions(); | 
|  | std::vector<std::string> expected_code = {"{", | 
|  | "  double v_0;", | 
|  | "  double v_1;", | 
|  | "  double v_2;", | 
|  | "  v_0 = local_variable;", | 
|  | "  v_1 = parameters[0][0];", | 
|  | "  v_2 = v_0 + v_1;", | 
|  | "}"}; | 
|  | GenerateAndCheck(graph, expected_code); | 
|  | } | 
|  |  | 
|  | TEST(CodeGenerator, OUTPUT_ASSIGNMENT) { | 
|  | double local_variable = 5.0; | 
|  | StartRecordingExpressions(); | 
|  | T a = 1; | 
|  | T b = 0; | 
|  | MakeOutput(a, "residual[0]"); | 
|  | auto graph = StopRecordingExpressions(); | 
|  | std::vector<std::string> expected_code = {"{", | 
|  | "  double v_0;", | 
|  | "  double v_1;", | 
|  | "  double v_2;", | 
|  | "  v_0 = 1;", | 
|  | "  v_1 = 0;", | 
|  | "  residual[0] = v_0;", | 
|  | "}"}; | 
|  | GenerateAndCheck(graph, expected_code); | 
|  | } | 
|  |  | 
|  | TEST(CodeGenerator, ASSIGNMENT) { | 
|  | StartRecordingExpressions(); | 
|  | T a = T(0); | 
|  | T b = T(1); | 
|  | T c = a;  // < This should not generate a line! | 
|  | a = b; | 
|  | a = a + b;  // < Create temporary + assignment | 
|  | auto graph = StopRecordingExpressions(); | 
|  | std::vector<std::string> expected_code = {"{", | 
|  | "  double v_0;", | 
|  | "  double v_1;", | 
|  | "  double v_3;", | 
|  | "  v_0 = 0;", | 
|  | "  v_1 = 1;", | 
|  | "  v_0 = v_1;", | 
|  | "  v_3 = v_0 + v_1;", | 
|  | "  v_0 = v_3;", | 
|  | "}"}; | 
|  | GenerateAndCheck(graph, expected_code); | 
|  | } | 
|  |  | 
|  | TEST(CodeGenerator, BINARY_ARITHMETIC_SIMPLE) { | 
|  | StartRecordingExpressions(); | 
|  | T a = T(0); | 
|  | T b = T(1); | 
|  | T r1 = a + b; | 
|  | T r2 = a - b; | 
|  | T r3 = a * b; | 
|  | T r4 = a / b; | 
|  | auto graph = StopRecordingExpressions(); | 
|  | std::vector<std::string> expected_code = {"{", | 
|  | "  double v_0;", | 
|  | "  double v_1;", | 
|  | "  double v_2;", | 
|  | "  double v_3;", | 
|  | "  double v_4;", | 
|  | "  double v_5;", | 
|  | "  v_0 = 0;", | 
|  | "  v_1 = 1;", | 
|  | "  v_2 = v_0 + v_1;", | 
|  | "  v_3 = v_0 - v_1;", | 
|  | "  v_4 = v_0 * v_1;", | 
|  | "  v_5 = v_0 / v_1;", | 
|  | "}"}; | 
|  | GenerateAndCheck(graph, expected_code); | 
|  | } | 
|  |  | 
|  | TEST(CodeGenerator, BINARY_ARITHMETIC_COMPOUND) { | 
|  | // For each binary compound arithmetic operation, two lines are generated: | 
|  | //    - The actual operation assigning to a new temporary variable | 
|  | //    - An assignment from the temporary to the lhs | 
|  | StartRecordingExpressions(); | 
|  | T a = T(0); | 
|  | T b = T(1); | 
|  | b += a; | 
|  | b -= a; | 
|  | b *= a; | 
|  | b /= a; | 
|  | auto graph = StopRecordingExpressions(); | 
|  | std::vector<std::string> expected_code = {"{", | 
|  | "  double v_0;", | 
|  | "  double v_1;", | 
|  | "  double v_2;", | 
|  | "  double v_4;", | 
|  | "  double v_6;", | 
|  | "  double v_8;", | 
|  | "  v_0 = 0;", | 
|  | "  v_1 = 1;", | 
|  | "  v_2 = v_1 + v_0;", | 
|  | "  v_1 = v_2;", | 
|  | "  v_4 = v_1 - v_0;", | 
|  | "  v_1 = v_4;", | 
|  | "  v_6 = v_1 * v_0;", | 
|  | "  v_1 = v_6;", | 
|  | "  v_8 = v_1 / v_0;", | 
|  | "  v_1 = v_8;", | 
|  | "}"}; | 
|  | GenerateAndCheck(graph, expected_code); | 
|  | } | 
|  |  | 
|  | TEST(CodeGenerator, UNARY_ARITHMETIC) { | 
|  | StartRecordingExpressions(); | 
|  | T a = T(0); | 
|  | T r1 = -a; | 
|  | T r2 = +a; | 
|  | auto graph = StopRecordingExpressions(); | 
|  | std::vector<std::string> expected_code = {"{", | 
|  | "  double v_0;", | 
|  | "  double v_1;", | 
|  | "  double v_2;", | 
|  | "  v_0 = 0;", | 
|  | "  v_1 = -v_0;", | 
|  | "  v_2 = +v_0;", | 
|  | "}"}; | 
|  | GenerateAndCheck(graph, expected_code); | 
|  | } | 
|  |  | 
|  | TEST(CodeGenerator, BINARY_COMPARISON) { | 
|  | StartRecordingExpressions(); | 
|  | T a = T(0); | 
|  | T b = T(1); | 
|  | auto r1 = a < b; | 
|  | auto r2 = a <= b; | 
|  | auto r3 = a > b; | 
|  | auto r4 = a >= b; | 
|  | auto r5 = a == b; | 
|  | auto r6 = a != b; | 
|  | auto graph = StopRecordingExpressions(); | 
|  | std::vector<std::string> expected_code = {"{", | 
|  | "  double v_0;", | 
|  | "  double v_1;", | 
|  | "  bool v_2;", | 
|  | "  bool v_3;", | 
|  | "  bool v_4;", | 
|  | "  bool v_5;", | 
|  | "  bool v_6;", | 
|  | "  bool v_7;", | 
|  | "  v_0 = 0;", | 
|  | "  v_1 = 1;", | 
|  | "  v_2 = v_0 < v_1;", | 
|  | "  v_3 = v_0 <= v_1;", | 
|  | "  v_4 = v_0 > v_1;", | 
|  | "  v_5 = v_0 >= v_1;", | 
|  | "  v_6 = v_0 == v_1;", | 
|  | "  v_7 = v_0 != v_1;", | 
|  | "}"}; | 
|  | GenerateAndCheck(graph, expected_code); | 
|  | } | 
|  |  | 
|  | TEST(CodeGenerator, LOGICAL_OPERATORS) { | 
|  | // Tests binary logical operators &&, || and the unary logical operator ! | 
|  | StartRecordingExpressions(); | 
|  | T a = T(0); | 
|  | T b = T(1); | 
|  | auto r1 = a < b; | 
|  | auto r2 = a <= b; | 
|  |  | 
|  | auto r3 = r1 && r2; | 
|  | auto r4 = r1 || r2; | 
|  | auto r5 = !r1; | 
|  |  | 
|  | auto graph = StopRecordingExpressions(); | 
|  | std::vector<std::string> expected_code = {"{", | 
|  | "  double v_0;", | 
|  | "  double v_1;", | 
|  | "  bool v_2;", | 
|  | "  bool v_3;", | 
|  | "  bool v_4;", | 
|  | "  bool v_5;", | 
|  | "  bool v_6;", | 
|  | "  v_0 = 0;", | 
|  | "  v_1 = 1;", | 
|  | "  v_2 = v_0 < v_1;", | 
|  | "  v_3 = v_0 <= v_1;", | 
|  | "  v_4 = v_2 && v_3;", | 
|  | "  v_5 = v_2 || v_3;", | 
|  | "  v_6 = !v_2;", | 
|  | "}"}; | 
|  | GenerateAndCheck(graph, expected_code); | 
|  | } | 
|  |  | 
|  | TEST(CodeGenerator, FUNCTION_CALL) { | 
|  | StartRecordingExpressions(); | 
|  | T a = T(0); | 
|  | T b = T(1); | 
|  |  | 
|  | abs(a); | 
|  | acos(a); | 
|  | asin(a); | 
|  | atan(a); | 
|  | cbrt(a); | 
|  | ceil(a); | 
|  | cos(a); | 
|  | cosh(a); | 
|  | exp(a); | 
|  | exp2(a); | 
|  | floor(a); | 
|  | log(a); | 
|  | log2(a); | 
|  | sin(a); | 
|  | sinh(a); | 
|  | sqrt(a); | 
|  | tan(a); | 
|  | tanh(a); | 
|  | atan2(a, b); | 
|  | pow(a, b); | 
|  |  | 
|  | auto graph = StopRecordingExpressions(); | 
|  |  | 
|  | std::vector<std::string> expected_code = {"{", | 
|  | "  double v_0;", | 
|  | "  double v_1;", | 
|  | "  double v_2;", | 
|  | "  double v_3;", | 
|  | "  double v_4;", | 
|  | "  double v_5;", | 
|  | "  double v_6;", | 
|  | "  double v_7;", | 
|  | "  double v_8;", | 
|  | "  double v_9;", | 
|  | "  double v_10;", | 
|  | "  double v_11;", | 
|  | "  double v_12;", | 
|  | "  double v_13;", | 
|  | "  double v_14;", | 
|  | "  double v_15;", | 
|  | "  double v_16;", | 
|  | "  double v_17;", | 
|  | "  double v_18;", | 
|  | "  double v_19;", | 
|  | "  double v_20;", | 
|  | "  double v_21;", | 
|  | "  v_0 = 0;", | 
|  | "  v_1 = 1;", | 
|  | "  v_2 = abs(v_0);", | 
|  | "  v_3 = acos(v_0);", | 
|  | "  v_4 = asin(v_0);", | 
|  | "  v_5 = atan(v_0);", | 
|  | "  v_6 = cbrt(v_0);", | 
|  | "  v_7 = ceil(v_0);", | 
|  | "  v_8 = cos(v_0);", | 
|  | "  v_9 = cosh(v_0);", | 
|  | "  v_10 = exp(v_0);", | 
|  | "  v_11 = exp2(v_0);", | 
|  | "  v_12 = floor(v_0);", | 
|  | "  v_13 = log(v_0);", | 
|  | "  v_14 = log2(v_0);", | 
|  | "  v_15 = sin(v_0);", | 
|  | "  v_16 = sinh(v_0);", | 
|  | "  v_17 = sqrt(v_0);", | 
|  | "  v_18 = tan(v_0);", | 
|  | "  v_19 = tanh(v_0);", | 
|  | "  v_20 = atan2(v_0, v_1);", | 
|  | "  v_21 = pow(v_0, v_1);", | 
|  | "}"}; | 
|  | GenerateAndCheck(graph, expected_code); | 
|  | } | 
|  |  | 
|  | TEST(CodeGenerator, IF_SIMPLE) { | 
|  | StartRecordingExpressions(); | 
|  | T a = T(0); | 
|  | T b = T(1); | 
|  | auto r1 = a < b; | 
|  | CERES_IF(r1) {} | 
|  | CERES_ELSE {} | 
|  | CERES_ENDIF; | 
|  |  | 
|  | auto graph = StopRecordingExpressions(); | 
|  | std::vector<std::string> expected_code = {"{", | 
|  | "  double v_0;", | 
|  | "  double v_1;", | 
|  | "  bool v_2;", | 
|  | "  v_0 = 0;", | 
|  | "  v_1 = 1;", | 
|  | "  v_2 = v_0 < v_1;", | 
|  | "  if (v_2) {", | 
|  | "  } else {", | 
|  | "  }", | 
|  | "}"}; | 
|  | GenerateAndCheck(graph, expected_code); | 
|  | } | 
|  |  | 
|  | TEST(CodeGenerator, IF_ASSIGNMENT) { | 
|  | StartRecordingExpressions(); | 
|  | T a = T(0); | 
|  | T b = T(1); | 
|  | auto r1 = a < b; | 
|  |  | 
|  | T result = 0; | 
|  | CERES_IF(r1) { result = 5.0; } | 
|  | CERES_ELSE { result = 6.0; } | 
|  | CERES_ENDIF; | 
|  | MakeOutput(result, "result"); | 
|  |  | 
|  | auto graph = StopRecordingExpressions(); | 
|  | std::vector<std::string> expected_code = {"{", | 
|  | "  double v_0;", | 
|  | "  double v_1;", | 
|  | "  bool v_2;", | 
|  | "  double v_3;", | 
|  | "  double v_5;", | 
|  | "  double v_8;", | 
|  | "  double v_11;", | 
|  | "  v_0 = 0;", | 
|  | "  v_1 = 1;", | 
|  | "  v_2 = v_0 < v_1;", | 
|  | "  v_3 = 0;", | 
|  | "  if (v_2) {", | 
|  | "    v_5 = 5;", | 
|  | "    v_3 = v_5;", | 
|  | "  } else {", | 
|  | "    v_8 = 6;", | 
|  | "    v_3 = v_8;", | 
|  | "  }", | 
|  | "  result = v_3;", | 
|  | "}"}; | 
|  | GenerateAndCheck(graph, expected_code); | 
|  | } | 
|  |  | 
|  | TEST(CodeGenerator, IF_NESTED_ASSIGNMENT) { | 
|  | StartRecordingExpressions(); | 
|  | T a = T(0); | 
|  | T b = T(1); | 
|  |  | 
|  | T result = 0; | 
|  | CERES_IF(a <= b) { | 
|  | result = 5.0; | 
|  | CERES_IF(a == b) { result = 7.0; } | 
|  | CERES_ENDIF; | 
|  | } | 
|  | CERES_ELSE { result = 6.0; } | 
|  | CERES_ENDIF; | 
|  | MakeOutput(result, "result"); | 
|  |  | 
|  | auto graph = StopRecordingExpressions(); | 
|  | std::vector<std::string> expected_code = {"{", | 
|  | "  double v_0;", | 
|  | "  double v_1;", | 
|  | "  double v_2;", | 
|  | "  bool v_3;", | 
|  | "  double v_5;", | 
|  | "  bool v_7;", | 
|  | "  double v_9;", | 
|  | "  double v_13;", | 
|  | "  double v_16;", | 
|  | "  v_0 = 0;", | 
|  | "  v_1 = 1;", | 
|  | "  v_2 = 0;", | 
|  | "  v_3 = v_0 <= v_1;", | 
|  | "  if (v_3) {", | 
|  | "    v_5 = 5;", | 
|  | "    v_2 = v_5;", | 
|  | "    v_7 = v_0 == v_1;", | 
|  | "    if (v_7) {", | 
|  | "      v_9 = 7;", | 
|  | "      v_2 = v_9;", | 
|  | "    }", | 
|  | "  } else {", | 
|  | "    v_13 = 6;", | 
|  | "    v_2 = v_13;", | 
|  | "  }", | 
|  | "  result = v_2;", | 
|  | "}"}; | 
|  | GenerateAndCheck(graph, expected_code); | 
|  | } | 
|  |  | 
|  | }  // namespace internal | 
|  | }  // namespace ceres |