blob: c0feb414e2432ba9f5d4febaa3e07fe1c4e55183 [file] [log] [blame]
// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2019 Google Inc. All rights reserved.
// http://code.google.com/p/ceres-solver/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
// * Neither the name of Google Inc. nor the names of its contributors may be
// used to endorse or promote products derived from this software without
// specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Author: darius.rueckert@fau.de (Darius Rueckert)
//
#define CERES_CODEGEN
#include "ceres/internal/code_generator.h"
#include "ceres/internal/expression_graph.h"
#include "ceres/internal/expression_ref.h"
#include "gtest/gtest.h"
namespace ceres {
namespace internal {
// Generates code for `graph` using default CodeGenerator options and compares
// the result line by line against `reference`.
//
// A difference in line count is reported as a test failure. The per-line
// comparison then covers only the lines present in both vectors, so a count
// mismatch can never index past the end of either one.
static void GenerateAndCheck(const ExpressionGraph& graph,
                             const std::vector<std::string>& reference) {
  CodeGenerator::Options generator_options;
  CodeGenerator gen(graph, generator_options);
  const auto code = gen.Generate();
  EXPECT_EQ(code.size(), reference.size());
  // EXPECT_EQ above does not abort the test, so the loop must be bounded by
  // both sizes: otherwise extra generated lines would read `reference` out of
  // bounds.
  for (size_t i = 0; i < code.size() && i < reference.size(); ++i) {
    EXPECT_EQ(code[i], reference[i]) << "Invalid Line: " << (i + 1);
  }
}
// Short alias used by the tests below. Operations on ExpressionRef values
// performed between StartRecordingExpressions() and
// StopRecordingExpressions() are what the tests feed to the code generator.
using T = ExpressionRef;
// A recording that captures no expressions must still generate a block: just
// the opening and closing braces.
TEST(CodeGenerator, Empty) {
  StartRecordingExpressions();
  const auto empty_graph = StopRecordingExpressions();
  GenerateAndCheck(empty_graph, {"{", "}"});
}
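// The remaining tests each exercise one ExpressionType (or a small group of
// related ones, such as the CERES_IF / CERES_ELSE / CERES_ENDIF control flow).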
// Every ExpressionRef constructed from a number is recorded as a compile-time
// constant. The generated code first declares one double per constant, then
// assigns each literal value.
TEST(CodeGenerator, COMPILE_TIME_CONSTANT) {
StartRecordingExpressions();
T a = T(0);
T b = T(123.5);
// `1 + 1` is evaluated by the C++ compiler before the ExpressionRef is
// constructed, so only the folded value `2` reaches the generated code.
T c = T(1 + 1);
T d; // Uninitialized variables should not generate code!
auto graph = StopRecordingExpressions();
std::vector<std::string> expected_code = {"{",
" double v_0;",
" double v_1;",
" double v_2;",
" v_0 = 0;",
" v_1 = 123.5;",
" v_2 = 2;",
"}"};
GenerateAndCheck(graph, expected_code);
}
// Inputs can enter the recording in two ways, and both are emitted as plain
// assignments from a named source:
// - CERES_LOCAL_VARIABLE(x) emits the C++ name `x`, not its current value.
// - MakeParameter("...") emits the given string verbatim.
TEST(CodeGenerator, INPUT_ASSIGNMENT) {
// This variable's name appears in the expected code below, so renaming it
// requires updating the expectation too. Its value never reaches the output.
double local_variable = 5.0;
StartRecordingExpressions();
T a = CERES_LOCAL_VARIABLE(local_variable);
T b = MakeParameter("parameters[0][0]");
T c = a + b;
auto graph = StopRecordingExpressions();
std::vector<std::string> expected_code = {"{",
" double v_0;",
" double v_1;",
" double v_2;",
" v_0 = local_variable;",
" v_1 = parameters[0][0];",
" v_2 = v_0 + v_1;",
"}"};
GenerateAndCheck(graph, expected_code);
}
// MakeOutput(value, "name") emits an assignment from the recorded variable to
// the given output expression, here `residual[0]`.
TEST(CodeGenerator, OUTPUT_ASSIGNMENT) {
  StartRecordingExpressions();
  T a = 1;
  // `b` is recorded (v_1) but is never written to an output.
  T b = 0;
  MakeOutput(a, "residual[0]");
  auto graph = StopRecordingExpressions();
  // NOTE(review): the generator also declares v_2, presumably for the
  // output-assignment expression itself, even though nothing assigns to it.
  std::vector<std::string> expected_code = {"{",
                                            " double v_0;",
                                            " double v_1;",
                                            " double v_2;",
                                            " v_0 = 1;",
                                            " v_1 = 0;",
                                            " residual[0] = v_0;",
                                            "}"};
  GenerateAndCheck(graph, expected_code);
}
// Assigning to an existing ExpressionRef writes to that ref's variable (v_0)
// instead of declaring a new one. Copy-constructing a ref emits nothing.
TEST(CodeGenerator, ASSIGNMENT) {
StartRecordingExpressions();
T a = T(0);
T b = T(1);
T c = a; // < This should not generate a line!
a = b;
a = a + b; // < Create temporary + assignment
auto graph = StopRecordingExpressions();
// NOTE(review): the gap in the numbering (no v_2) suggests that assignment
// expressions consume an id without declaring a variable. Confirm against
// ExpressionGraph.
std::vector<std::string> expected_code = {"{",
" double v_0;",
" double v_1;",
" double v_3;",
" v_0 = 0;",
" v_1 = 1;",
" v_0 = v_1;",
" v_3 = v_0 + v_1;",
" v_0 = v_3;",
"}"};
GenerateAndCheck(graph, expected_code);
}
// Each of +, -, * and / on two ExpressionRefs produces a new double variable
// holding `lhs op rhs`, with the operands kept in source order.
TEST(CodeGenerator, BINARY_ARITHMETIC_SIMPLE) {
StartRecordingExpressions();
T a = T(0);
T b = T(1);
T r1 = a + b;
T r2 = a - b;
T r3 = a * b;
T r4 = a / b;
auto graph = StopRecordingExpressions();
std::vector<std::string> expected_code = {"{",
" double v_0;",
" double v_1;",
" double v_2;",
" double v_3;",
" double v_4;",
" double v_5;",
" v_0 = 0;",
" v_1 = 1;",
" v_2 = v_0 + v_1;",
" v_3 = v_0 - v_1;",
" v_4 = v_0 * v_1;",
" v_5 = v_0 / v_1;",
"}"};
GenerateAndCheck(graph, expected_code);
}
TEST(CodeGenerator, BINARY_ARITHMETIC_COMPOUND) {
// For each binary compound arithmetic operation, two lines are generated:
// - The actual operation assigning to a new temporary variable
// - An assignment from the temporary to the lhs
StartRecordingExpressions();
T a = T(0);
T b = T(1);
b += a;
b -= a;
b *= a;
b /= a;
auto graph = StopRecordingExpressions();
// Only the temporaries (v_2, v_4, v_6, v_8) are declared. NOTE(review): the
// odd ids are presumably consumed by the assignments back to `b` (v_1).
std::vector<std::string> expected_code = {"{",
" double v_0;",
" double v_1;",
" double v_2;",
" double v_4;",
" double v_6;",
" double v_8;",
" v_0 = 0;",
" v_1 = 1;",
" v_2 = v_1 + v_0;",
" v_1 = v_2;",
" v_4 = v_1 - v_0;",
" v_1 = v_4;",
" v_6 = v_1 * v_0;",
" v_1 = v_6;",
" v_8 = v_1 / v_0;",
" v_1 = v_8;",
"}"};
GenerateAndCheck(graph, expected_code);
}
// Unary minus and unary plus each produce a new variable. Unary plus is
// emitted as written (`+v_0`) rather than being reduced to a plain copy.
TEST(CodeGenerator, UNARY_ARITHMETIC) {
StartRecordingExpressions();
T a = T(0);
T r1 = -a;
T r2 = +a;
auto graph = StopRecordingExpressions();
std::vector<std::string> expected_code = {"{",
" double v_0;",
" double v_1;",
" double v_2;",
" v_0 = 0;",
" v_1 = -v_0;",
" v_2 = +v_0;",
"}"};
GenerateAndCheck(graph, expected_code);
}
// Each of the six comparison operators produces a new variable. Unlike
// arithmetic results, these are declared as `bool` in the generated code.
TEST(CodeGenerator, BINARY_COMPARISON) {
StartRecordingExpressions();
T a = T(0);
T b = T(1);
auto r1 = a < b;
auto r2 = a <= b;
auto r3 = a > b;
auto r4 = a >= b;
auto r5 = a == b;
auto r6 = a != b;
auto graph = StopRecordingExpressions();
std::vector<std::string> expected_code = {"{",
" double v_0;",
" double v_1;",
" bool v_2;",
" bool v_3;",
" bool v_4;",
" bool v_5;",
" bool v_6;",
" bool v_7;",
" v_0 = 0;",
" v_1 = 1;",
" v_2 = v_0 < v_1;",
" v_3 = v_0 <= v_1;",
" v_4 = v_0 > v_1;",
" v_5 = v_0 >= v_1;",
" v_6 = v_0 == v_1;",
" v_7 = v_0 != v_1;",
"}"};
GenerateAndCheck(graph, expected_code);
}
TEST(CodeGenerator, LOGICAL_OPERATORS) {
// Tests binary logical operators &&, || and the unary logical operator !
// The operands are comparison results, and every logical result is also
// declared as `bool` in the generated code.
StartRecordingExpressions();
T a = T(0);
T b = T(1);
auto r1 = a < b;
auto r2 = a <= b;
auto r3 = r1 && r2;
auto r4 = r1 || r2;
auto r5 = !r1;
auto graph = StopRecordingExpressions();
std::vector<std::string> expected_code = {"{",
" double v_0;",
" double v_1;",
" bool v_2;",
" bool v_3;",
" bool v_4;",
" bool v_5;",
" bool v_6;",
" v_0 = 0;",
" v_1 = 1;",
" v_2 = v_0 < v_1;",
" v_3 = v_0 <= v_1;",
" v_4 = v_2 && v_3;",
" v_5 = v_2 || v_3;",
" v_6 = !v_2;",
"}"};
GenerateAndCheck(graph, expected_code);
}
// Each math function (the unary ones, plus binary atan2 and pow) is recorded
// as a call on the argument variable(s) and assigned to a new variable.
// The return values are deliberately discarded: as the expected code shows,
// the call is recorded as soon as it executes.
// NOTE(review): the unqualified names presumably resolve to the ExpressionRef
// overloads via argument-dependent lookup. Confirm if this test is moved.
TEST(CodeGenerator, FUNCTION_CALL) {
StartRecordingExpressions();
T a = T(0);
T b = T(1);
abs(a);
acos(a);
asin(a);
atan(a);
cbrt(a);
ceil(a);
cos(a);
cosh(a);
exp(a);
exp2(a);
floor(a);
log(a);
log2(a);
sin(a);
sinh(a);
sqrt(a);
tan(a);
tanh(a);
atan2(a, b);
pow(a, b);
auto graph = StopRecordingExpressions();
std::vector<std::string> expected_code = {"{",
" double v_0;",
" double v_1;",
" double v_2;",
" double v_3;",
" double v_4;",
" double v_5;",
" double v_6;",
" double v_7;",
" double v_8;",
" double v_9;",
" double v_10;",
" double v_11;",
" double v_12;",
" double v_13;",
" double v_14;",
" double v_15;",
" double v_16;",
" double v_17;",
" double v_18;",
" double v_19;",
" double v_20;",
" double v_21;",
" v_0 = 0;",
" v_1 = 1;",
" v_2 = abs(v_0);",
" v_3 = acos(v_0);",
" v_4 = asin(v_0);",
" v_5 = atan(v_0);",
" v_6 = cbrt(v_0);",
" v_7 = ceil(v_0);",
" v_8 = cos(v_0);",
" v_9 = cosh(v_0);",
" v_10 = exp(v_0);",
" v_11 = exp2(v_0);",
" v_12 = floor(v_0);",
" v_13 = log(v_0);",
" v_14 = log2(v_0);",
" v_15 = sin(v_0);",
" v_16 = sinh(v_0);",
" v_17 = sqrt(v_0);",
" v_18 = tan(v_0);",
" v_19 = tanh(v_0);",
" v_20 = atan2(v_0, v_1);",
" v_21 = pow(v_0, v_1);",
"}"};
GenerateAndCheck(graph, expected_code);
}
// CERES_IF / CERES_ELSE / CERES_ENDIF produce an if/else on the condition
// variable. Empty branches still emit their braces.
TEST(CodeGenerator, IF_SIMPLE) {
StartRecordingExpressions();
T a = T(0);
T b = T(1);
auto r1 = a < b;
CERES_IF(r1) {}
CERES_ELSE {}
CERES_ENDIF;
auto graph = StopRecordingExpressions();
std::vector<std::string> expected_code = {"{",
" double v_0;",
" double v_1;",
" bool v_2;",
" v_0 = 0;",
" v_1 = 1;",
" v_2 = v_0 < v_1;",
" if (v_2) {",
" } else {",
" }",
"}"};
GenerateAndCheck(graph, expected_code);
}
// Assignments inside a branch write to the variable declared before CERES_IF
// (v_3), so the output after CERES_ENDIF reads whichever branch ran.
// All declarations, including temporaries used only inside a branch (v_5,
// v_8), are emitted at the top of the generated block.
TEST(CodeGenerator, IF_ASSIGNMENT) {
StartRecordingExpressions();
T a = T(0);
T b = T(1);
auto r1 = a < b;
T result = 0;
CERES_IF(r1) { result = 5.0; }
CERES_ELSE { result = 6.0; }
CERES_ENDIF;
MakeOutput(result, "result");
auto graph = StopRecordingExpressions();
std::vector<std::string> expected_code = {"{",
" double v_0;",
" double v_1;",
" bool v_2;",
" double v_3;",
" double v_5;",
" double v_8;",
" double v_11;",
" v_0 = 0;",
" v_1 = 1;",
" v_2 = v_0 < v_1;",
" v_3 = 0;",
" if (v_2) {",
" v_5 = 5;",
" v_3 = v_5;",
" } else {",
" v_8 = 6;",
" v_3 = v_8;",
" }",
" result = v_3;",
"}"};
GenerateAndCheck(graph, expected_code);
}
// CERES_IF accepts a comparison expression directly, and CERES_IF blocks can
// be nested. The inner condition (v_7) is computed inside the outer branch.
// Every assignment to `result` targets the same variable (v_2).
TEST(CodeGenerator, IF_NESTED_ASSIGNMENT) {
StartRecordingExpressions();
T a = T(0);
T b = T(1);
T result = 0;
CERES_IF(a <= b) {
result = 5.0;
CERES_IF(a == b) { result = 7.0; }
CERES_ENDIF;
}
CERES_ELSE { result = 6.0; }
CERES_ENDIF;
MakeOutput(result, "result");
auto graph = StopRecordingExpressions();
std::vector<std::string> expected_code = {"{",
" double v_0;",
" double v_1;",
" double v_2;",
" bool v_3;",
" double v_5;",
" bool v_7;",
" double v_9;",
" double v_13;",
" double v_16;",
" v_0 = 0;",
" v_1 = 1;",
" v_2 = 0;",
" v_3 = v_0 <= v_1;",
" if (v_3) {",
" v_5 = 5;",
" v_2 = v_5;",
" v_7 = v_0 == v_1;",
" if (v_7) {",
" v_9 = 7;",
" v_2 = v_9;",
" }",
" } else {",
" v_13 = 6;",
" v_2 = v_13;",
" }",
" result = v_2;",
"}"};
GenerateAndCheck(graph, expected_code);
}
} // namespace internal
} // namespace ceres