Improve testing of the codegen system
- Move codegen tests to a sub directory
- Add tests for all functions of Expression, ExpressionRef, and ExpressionGraph
- Respect dependencies during tests: The ExpressionGraph test doesn't
use ExpressionRef anymore.
The new tests revealed a few bugs so the following changes were made:
- Expression::MakeNop now resets the current expression with the default
constructed NOP expression
- ExpressionGraph::Insert now updates the lhs_id the same way as
InsertBack()
Change-Id: I6a18925c1e4d972c29ec1219f2073b4eaf2df737
diff --git a/include/ceres/codegen/internal/expression.h b/include/ceres/codegen/internal/expression.h
index 7d93a4d..808d741 100644
--- a/include/ceres/codegen/internal/expression.h
+++ b/include/ceres/codegen/internal/expression.h
@@ -249,6 +249,7 @@
// ExpressionGraph (see expression_graph.h).
class Expression {
public:
+ // Creates a NOP expression.
Expression() = default;
Expression(ExpressionType type,
diff --git a/include/ceres/codegen/internal/expression_graph.h b/include/ceres/codegen/internal/expression_graph.h
index 30b5a78..331a2d9 100644
--- a/include/ceres/codegen/internal/expression_graph.h
+++ b/include/ceres/codegen/internal/expression_graph.h
@@ -53,6 +53,9 @@
bool DependsOn(ExpressionId A, ExpressionId B) const;
bool operator==(const ExpressionGraph& other) const;
+ bool operator!=(const ExpressionGraph& other) const {
+ return !(*this == other);
+ }
Expression& ExpressionForId(ExpressionId id) { return expressions_[id]; }
const Expression& ExpressionForId(ExpressionId id) const {
diff --git a/internal/ceres/CMakeLists.txt b/internal/ceres/CMakeLists.txt
index 0733f47..af13f35 100644
--- a/internal/ceres/CMakeLists.txt
+++ b/internal/ceres/CMakeLists.txt
@@ -504,33 +504,7 @@
add_subdirectory(generated_bundle_adjustment_tests)
if(CODE_GENERATION)
- # Testing the AutoDiffCodegen system is more complicated, because function- and
- # constructor calls have side-effects. In C++ the evaluation order and
- # the elision of copies is implementation defined. Between different compilers,
- # some expression might be evaluated in a different order or some copies might be
- # removed.
- #
- # Therefore, we run tests that expect a particular compiler behaviour only on gcc.
- #
- # The semantic tests, which check the correctness by executing the generated code
- # are still run on all platforms.
- if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
- set(CXX_FLAGS_OLD ${CMAKE_CXX_FLAGS})
- # From the man page:
- # The C++ standard allows an implementation to omit creating a
- # temporary which is only used to initialize another object of the
- # same type. Specifying -fno-elide-constructors disables that
- # optimization, and forces G++ to call the copy constructor in all cases.
- # We use this flag to get the same results between different versions of
- # gcc and different optimization levels. Without this flag, testing would
- # be very painfull and might break when a new compiler version is released.
- set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-elide-constructors")
- ceres_test(expression)
- ceres_test(code_generator)
- ceres_test(conditional_expressions)
- ceres_test(expression_graph)
- set(CMAKE_CXX_FLAGS ${CXX_FLAGS_OLD})
- endif()
+ add_subdirectory(codegen)
endif()
endif (BUILD_TESTING AND GFLAGS)
diff --git a/internal/ceres/codegen/CMakeLists.txt b/internal/ceres/codegen/CMakeLists.txt
new file mode 100644
index 0000000..7362aef
--- /dev/null
+++ b/internal/ceres/codegen/CMakeLists.txt
@@ -0,0 +1,57 @@
+# Ceres Solver - A fast non-linear least squares minimizer
+# Copyright 2019 Google Inc. All rights reserved.
+# http://ceres-solver.org/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice,
+# this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+# * Neither the name of Google Inc. nor the names of its contributors may be
+# used to endorse or promote products derived from this software without
+# specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+# Author: darius.rueckert@fau.de (Darius Rueckert)
+#
+# Testing the AutoDiffCodegen system is more complicated, because function- and
+# constructor calls have side-effects. In C++ the evaluation order and
+# the elision of copies is implementation defined. Between different compilers,
+# some expression might be evaluated in a different order or some copies might be
+# removed.
+#
+# Therefore, we run tests that expect a particular compiler behaviour only on gcc.
+#
+# The semantic tests, which check the correctness by executing the generated code
+# are still run on all platforms.
+if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
+ set(CXX_FLAGS_OLD ${CMAKE_CXX_FLAGS})
+ # From the man page:
+ # The C++ standard allows an implementation to omit creating a
+ # temporary which is only used to initialize another object of the
+ # same type. Specifying -fno-elide-constructors disables that
+ # optimization, and forces G++ to call the copy constructor in all cases.
+ # We use this flag to get the same results between different versions of
+ # gcc and different optimization levels. Without this flag, testing would
+ # be very painfull and might break when a new compiler version is released.
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-elide-constructors")
+ ceres_test(expression)
+ ceres_test(expression_graph)
+ ceres_test(expression_ref)
+ ceres_test(code_generator)
+ set(CMAKE_CXX_FLAGS ${CXX_FLAGS_OLD})
+endif()
diff --git a/internal/ceres/code_generator_test.cc b/internal/ceres/codegen/code_generator_test.cc
similarity index 90%
rename from internal/ceres/code_generator_test.cc
rename to internal/ceres/codegen/code_generator_test.cc
index 2ac7ee3..f974b1a 100644
--- a/internal/ceres/code_generator_test.cc
+++ b/internal/ceres/codegen/code_generator_test.cc
@@ -97,7 +97,6 @@
}
TEST(CodeGenerator, OUTPUT_ASSIGNMENT) {
- double local_variable = 5.0;
StartRecordingExpressions();
T a = 1;
T b = 0;
@@ -139,8 +138,8 @@
TEST(CodeGenerator, BINARY_ARITHMETIC_SIMPLE) {
StartRecordingExpressions();
- T a = T(0);
- T b = T(1);
+ T a = T(1);
+ T b = T(2);
T r1 = a + b;
T r2 = a - b;
T r3 = a * b;
@@ -153,8 +152,8 @@
" double v_3;",
" double v_4;",
" double v_5;",
- " v_0 = 0;",
- " v_1 = 1;",
+ " v_0 = 1;",
+ " v_1 = 2;",
" v_2 = v_0 + v_1;",
" v_3 = v_0 - v_1;",
" v_4 = v_0 * v_1;",
@@ -163,17 +162,40 @@
GenerateAndCheck(graph, expected_code);
}
+TEST(CodeGenerator, BINARY_ARITHMETIC_NESTED) {
+ StartRecordingExpressions();
+ T a = T(1);
+ T b = T(2);
+ T r1 = b - a * (a + b) / a;
+ auto graph = StopRecordingExpressions();
+ std::vector<std::string> expected_code = {"{",
+ " double v_0;",
+ " double v_1;",
+ " double v_2;",
+ " double v_3;",
+ " double v_4;",
+ " double v_5;",
+ " v_0 = 1;",
+ " v_1 = 2;",
+ " v_2 = v_0 + v_1;",
+ " v_3 = v_0 * v_2;",
+ " v_4 = v_3 / v_0;",
+ " v_5 = v_1 - v_4;",
+ "}"};
+ GenerateAndCheck(graph, expected_code);
+}
+
TEST(CodeGenerator, BINARY_ARITHMETIC_COMPOUND) {
// For each binary compound arithmetic operation, two lines are generated:
// - The actual operation assigning to a new temporary variable
// - An assignment from the temporary to the lhs
StartRecordingExpressions();
- T a = T(0);
- T b = T(1);
- b += a;
- b -= a;
- b *= a;
- b /= a;
+ T a = T(1);
+ T b = T(2);
+ a += b;
+ a -= b;
+ a *= b;
+ a /= b;
auto graph = StopRecordingExpressions();
std::vector<std::string> expected_code = {"{",
" double v_0;",
@@ -182,16 +204,16 @@
" double v_4;",
" double v_6;",
" double v_8;",
- " v_0 = 0;",
- " v_1 = 1;",
- " v_2 = v_1 + v_0;",
- " v_1 = v_2;",
- " v_4 = v_1 - v_0;",
- " v_1 = v_4;",
- " v_6 = v_1 * v_0;",
- " v_1 = v_6;",
- " v_8 = v_1 / v_0;",
- " v_1 = v_8;",
+ " v_0 = 1;",
+ " v_1 = 2;",
+ " v_2 = v_0 + v_1;",
+ " v_0 = v_2;",
+ " v_4 = v_0 - v_1;",
+ " v_0 = v_4;",
+ " v_6 = v_0 * v_1;",
+ " v_0 = v_6;",
+ " v_8 = v_0 / v_1;",
+ " v_0 = v_8;",
"}"};
GenerateAndCheck(graph, expected_code);
}
diff --git a/internal/ceres/codegen/expression_graph_test.cc b/internal/ceres/codegen/expression_graph_test.cc
new file mode 100644
index 0000000..b08a9b3
--- /dev/null
+++ b/internal/ceres/codegen/expression_graph_test.cc
@@ -0,0 +1,182 @@
+// Ceres Solver - A fast non-linear least squares minimizer
+// Copyright 2019 Google Inc. All rights reserved.
+// http://code.google.com/p/ceres-solver/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// * Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// * Neither the name of Google Inc. nor the names of its contributors may be
+// used to endorse or promote products derived from this software without
+// specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//
+// Author: darius.rueckert@fau.de (Darius Rueckert)
+//
+// This file tests the ExpressionGraph class. This test depends on the
+// correctness of Expression.
+//
+#include "ceres/codegen/internal/expression.h"
+#include "ceres/codegen/internal/expression_graph.h"
+
+#include "gtest/gtest.h"
+
+namespace ceres {
+namespace internal {
+
+TEST(ExpressionGraph, Size) {
+ ExpressionGraph graph;
+ EXPECT_EQ(graph.Size(), 0);
+ // Insert 3 NOPs
+ graph.InsertBack(Expression());
+ graph.InsertBack(Expression());
+ graph.InsertBack(Expression());
+ EXPECT_EQ(graph.Size(), 3);
+}
+
+TEST(ExpressionGraph, Recording) {
+ EXPECT_EQ(GetCurrentExpressionGraph(), nullptr);
+ StartRecordingExpressions();
+ EXPECT_NE(GetCurrentExpressionGraph(), nullptr);
+ auto graph = StopRecordingExpressions();
+ EXPECT_EQ(graph, ExpressionGraph());
+ EXPECT_EQ(GetCurrentExpressionGraph(), nullptr);
+}
+
+TEST(ExpressionGraph, InsertBackControl) {
+ // Control expression are inserted to the back without any modifications.
+ auto expr1 = Expression::CreateIf(ExpressionId(0));
+ auto expr2 = Expression::CreateElse();
+ auto expr3 = Expression::CreateEndIf();
+
+ ExpressionGraph graph;
+ graph.InsertBack(expr1);
+ graph.InsertBack(expr2);
+ graph.InsertBack(expr3);
+
+ EXPECT_EQ(graph.Size(), 3);
+ EXPECT_EQ(graph.ExpressionForId(0), expr1);
+ EXPECT_EQ(graph.ExpressionForId(1), expr2);
+ EXPECT_EQ(graph.ExpressionForId(2), expr3);
+}
+
+TEST(ExpressionGraph, InsertBackNewVariable) {
+ // If an arithmetic expression with lhs=kinvalidValue is inserted in the back,
+ // then a new variable name is created and set to the lhs_id.
+ auto expr1 = Expression::CreateCompileTimeConstant(42);
+ auto expr2 = Expression::CreateCompileTimeConstant(10);
+ auto expr3 =
+ Expression::CreateBinaryArithmetic("+", ExpressionId(0), ExpressionId(1));
+
+ ExpressionGraph graph;
+ graph.InsertBack(expr1);
+ graph.InsertBack(expr2);
+ graph.InsertBack(expr3);
+ EXPECT_EQ(graph.Size(), 3);
+
+ // The ExpressionGraph has a copy of the inserted expression with the correct
+ // lhs_ids. We set them here manually for comparision.
+ expr1.set_lhs_id(0);
+ expr2.set_lhs_id(1);
+ expr3.set_lhs_id(2);
+ EXPECT_EQ(graph.ExpressionForId(0), expr1);
+ EXPECT_EQ(graph.ExpressionForId(1), expr2);
+ EXPECT_EQ(graph.ExpressionForId(2), expr3);
+}
+
+TEST(ExpressionGraph, InsertBackExistingVariable) {
+ auto expr1 = Expression::CreateCompileTimeConstant(42);
+ auto expr2 = Expression::CreateCompileTimeConstant(10);
+ auto expr3 = Expression::CreateAssignment(1, 0);
+
+ ExpressionGraph graph;
+ graph.InsertBack(expr1);
+ graph.InsertBack(expr2);
+ graph.InsertBack(expr3);
+ EXPECT_EQ(graph.Size(), 3);
+
+ expr1.set_lhs_id(0);
+ expr2.set_lhs_id(1);
+ expr3.set_lhs_id(1);
+ EXPECT_EQ(graph.ExpressionForId(0), expr1);
+ EXPECT_EQ(graph.ExpressionForId(1), expr2);
+ EXPECT_EQ(graph.ExpressionForId(2), expr3);
+}
+
+TEST(ExpressionGraph, DependsOn) {
+ ExpressionGraph graph;
+ graph.InsertBack(Expression::CreateCompileTimeConstant(42));
+ graph.InsertBack(Expression::CreateCompileTimeConstant(10));
+ graph.InsertBack(Expression::CreateBinaryArithmetic(
+ "+", ExpressionId(0), ExpressionId(1)));
+ graph.InsertBack(Expression::CreateBinaryArithmetic(
+ "+", ExpressionId(2), ExpressionId(0)));
+
+ // Code:
+ // v_0 = 42
+ // v_1 = 10
+ // v_2 = v_0 + v_1
+ // v_3 = v_2 + v_0
+
+ // Direct dependencies dependency check
+ ASSERT_TRUE(graph.DependsOn(2, 0));
+ ASSERT_TRUE(graph.DependsOn(2, 1));
+ ASSERT_TRUE(graph.DependsOn(3, 2));
+ ASSERT_TRUE(graph.DependsOn(3, 0));
+ ASSERT_FALSE(graph.DependsOn(1, 0));
+ ASSERT_FALSE(graph.DependsOn(1, 1));
+ ASSERT_FALSE(graph.DependsOn(2, 3));
+
+ // Recursive test
+ ASSERT_TRUE(graph.DependsOn(3, 1));
+}
+
+TEST(ExpressionGraph, InsertExpression_UpdateReferences) {
+ // This test checks if references to shifted expressions are updated
+ // accordingly.
+ ExpressionGraph graph;
+ graph.InsertBack(Expression::CreateCompileTimeConstant(42));
+ graph.InsertBack(Expression::CreateCompileTimeConstant(10));
+ graph.InsertBack(Expression::CreateBinaryArithmetic(
+ "+", ExpressionId(0), ExpressionId(1)));
+ // Code:
+ // v_0 = 42
+ // v_1 = 10
+ // v_2 = v_0 + v_1
+
+ // Insert another compile time constant at the beginning
+ graph.Insert(0, Expression::CreateCompileTimeConstant(5));
+ // This should shift all indices like this:
+ // v_0 = 5
+ // v_1 = 42
+ // v_2 = 10
+ // v_3 = v_1 + v_2
+
+ // Test by inserting it in the correct order
+ ExpressionGraph ref;
+ ref.InsertBack(Expression::CreateCompileTimeConstant(5));
+ ref.InsertBack(Expression::CreateCompileTimeConstant(42));
+ ref.InsertBack(Expression::CreateCompileTimeConstant(10));
+ ref.InsertBack(Expression::CreateBinaryArithmetic(
+ "+", ExpressionId(1), ExpressionId(2)));
+ EXPECT_EQ(graph.Size(), ref.Size());
+ EXPECT_EQ(graph, ref);
+}
+
+} // namespace internal
+} // namespace ceres
diff --git a/internal/ceres/codegen/expression_ref_test.cc b/internal/ceres/codegen/expression_ref_test.cc
new file mode 100644
index 0000000..aeb9d2b
--- /dev/null
+++ b/internal/ceres/codegen/expression_ref_test.cc
@@ -0,0 +1,387 @@
+// Ceres Solver - A fast non-linear least squares minimizer
+// Copyright 2019 Google Inc. All rights reserved.
+// http://code.google.com/p/ceres-solver/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// * Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// * Neither the name of Google Inc. nor the names of its contributors may be
+// used to endorse or promote products derived from this software without
+// specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//
+// Author: darius.rueckert@fau.de (Darius Rueckert)
+//
+// This file tests the ExpressionRef class. This test depends on the
+// correctness of Expression and ExpressionGraph.
+//
+#define CERES_CODEGEN
+
+#include "ceres/codegen/internal/expression_graph.h"
+#include "ceres/codegen/internal/expression_ref.h"
+#include "gtest/gtest.h"
+
+namespace ceres {
+namespace internal {
+
+using T = ExpressionRef;
+
+TEST(ExpressionRef, COMPILE_TIME_CONSTANT) {
+ StartRecordingExpressions();
+ T a = T(0);
+ T b = T(123.5);
+ T c = T(1 + 1);
+ T d; // Uninitialized variables should not generate code!
+ auto graph = StopRecordingExpressions();
+
+ ExpressionGraph reference;
+ reference.InsertBack(Expression::CreateCompileTimeConstant(0));
+ reference.InsertBack(Expression::CreateCompileTimeConstant(123.5));
+ reference.InsertBack(Expression::CreateCompileTimeConstant(2));
+ EXPECT_EQ(reference, graph);
+}
+
+TEST(ExpressionRef, INPUT_ASSIGNMENT) {
+ double local_variable = 5.0;
+ StartRecordingExpressions();
+ T a = CERES_LOCAL_VARIABLE(T, local_variable);
+ T b = MakeParameter("parameters[0][0]");
+ auto graph = StopRecordingExpressions();
+
+ ExpressionGraph reference;
+ reference.InsertBack(Expression::CreateInputAssignment("local_variable"));
+ reference.InsertBack(Expression::CreateInputAssignment("parameters[0][0]"));
+ EXPECT_EQ(reference, graph);
+}
+
+TEST(ExpressionRef, OUTPUT_ASSIGNMENT) {
+ StartRecordingExpressions();
+ T a = 1;
+ T b = 0;
+ MakeOutput(a, "residual[0]");
+ auto graph = StopRecordingExpressions();
+
+ ExpressionGraph reference;
+ reference.InsertBack(Expression::CreateCompileTimeConstant(1));
+ reference.InsertBack(Expression::CreateCompileTimeConstant(0));
+ reference.InsertBack(Expression::CreateOutputAssignment(0, "residual[0]"));
+ EXPECT_EQ(reference, graph);
+}
+
+TEST(ExpressionRef, Assignment) {
+ StartRecordingExpressions();
+ T a = 1;
+ T b = 2;
+ b = a;
+ auto graph = StopRecordingExpressions();
+ EXPECT_EQ(graph.Size(), 3);
+
+ ExpressionGraph reference;
+ reference.InsertBack(Expression::CreateCompileTimeConstant(1));
+ reference.InsertBack(Expression::CreateCompileTimeConstant(2));
+ reference.InsertBack(Expression::CreateAssignment(1, 0));
+ EXPECT_EQ(reference, graph);
+}
+
+TEST(ExpressionRef, AssignmentCreate) {
+ StartRecordingExpressions();
+ T a = 2;
+ T b = a;
+ auto graph = StopRecordingExpressions();
+ EXPECT_EQ(graph.Size(), 2);
+
+ ExpressionGraph reference;
+ reference.InsertBack(Expression::CreateCompileTimeConstant(2));
+ reference.InsertBack(Expression::CreateAssignment(kInvalidExpressionId, 0));
+ EXPECT_EQ(reference, graph);
+}
+
+TEST(ExpressionRef, MoveAssignmentCreate) {
+ StartRecordingExpressions();
+ T a = 2;
+ T b = std::move(a);
+ auto graph = StopRecordingExpressions();
+ EXPECT_EQ(graph.Size(), 1);
+
+ ExpressionGraph reference;
+ reference.InsertBack(Expression::CreateCompileTimeConstant(2));
+ EXPECT_EQ(reference, graph);
+}
+
+TEST(ExpressionRef, MoveAssignment) {
+ StartRecordingExpressions();
+ T a = 1;
+ T b = 2;
+ b = std::move(a);
+ auto graph = StopRecordingExpressions();
+ EXPECT_EQ(graph.Size(), 3);
+
+ ExpressionGraph reference;
+ reference.InsertBack(Expression::CreateCompileTimeConstant(1));
+ reference.InsertBack(Expression::CreateCompileTimeConstant(2));
+ reference.InsertBack(Expression::CreateAssignment(1, 0));
+ EXPECT_EQ(reference, graph);
+}
+
+TEST(ExpressionRef, BINARY_ARITHMETIC_SIMPLE) {
+ StartRecordingExpressions();
+ T a = T(1);
+ T b = T(2);
+ T r1 = a + b;
+ T r2 = a - b;
+ T r3 = a * b;
+ T r4 = a / b;
+ auto graph = StopRecordingExpressions();
+
+ ExpressionGraph reference;
+ reference.InsertBack(Expression::CreateCompileTimeConstant(1));
+ reference.InsertBack(Expression::CreateCompileTimeConstant(2));
+ reference.InsertBack(Expression::CreateBinaryArithmetic("+", 0, 1));
+ reference.InsertBack(Expression::CreateBinaryArithmetic("-", 0, 1));
+ reference.InsertBack(Expression::CreateBinaryArithmetic("*", 0, 1));
+ reference.InsertBack(Expression::CreateBinaryArithmetic("/", 0, 1));
+ EXPECT_EQ(reference, graph);
+}
+
+TEST(ExpressionRef, BINARY_ARITHMETIC_NESTED) {
+ StartRecordingExpressions();
+ T a = T(1);
+ T b = T(2);
+ T r1 = b - a * (a + b) / a;
+ auto graph = StopRecordingExpressions();
+
+ ExpressionGraph reference;
+ reference.InsertBack(Expression::CreateCompileTimeConstant(1));
+ reference.InsertBack(Expression::CreateCompileTimeConstant(2));
+ reference.InsertBack(Expression::CreateBinaryArithmetic("+", 0, 1));
+ reference.InsertBack(Expression::CreateBinaryArithmetic("*", 0, 2));
+ reference.InsertBack(Expression::CreateBinaryArithmetic("/", 3, 0));
+ reference.InsertBack(Expression::CreateBinaryArithmetic("-", 1, 4));
+ EXPECT_EQ(reference, graph);
+}
+
+TEST(CodeGenerator, BINARY_ARITHMETIC_COMPOUND) {
+ // For each binary compound arithmetic operation, two lines are generated:
+ // - The actual operation assigning to a new temporary variable
+ // - An assignment from the temporary to the lhs
+ StartRecordingExpressions();
+ T a = T(1);
+ T b = T(2);
+ a += b;
+ a -= b;
+ a *= b;
+ a /= b;
+ auto graph = StopRecordingExpressions();
+
+ ExpressionGraph reference;
+ reference.InsertBack(Expression::CreateCompileTimeConstant(1));
+ reference.InsertBack(Expression::CreateCompileTimeConstant(2));
+ reference.InsertBack(Expression::CreateBinaryArithmetic("+", 0, 1));
+ reference.InsertBack(Expression::CreateAssignment(0, 2));
+ reference.InsertBack(Expression::CreateBinaryArithmetic("-", 0, 1));
+ reference.InsertBack(Expression::CreateAssignment(0, 4));
+ reference.InsertBack(Expression::CreateBinaryArithmetic("*", 0, 1));
+ reference.InsertBack(Expression::CreateAssignment(0, 6));
+ reference.InsertBack(Expression::CreateBinaryArithmetic("/", 0, 1));
+ reference.InsertBack(Expression::CreateAssignment(0, 8));
+ EXPECT_EQ(reference, graph);
+}
+
+TEST(CodeGenerator, UNARY_ARITHMETIC) {
+ StartRecordingExpressions();
+ T a = T(1);
+ T r1 = -a;
+ T r2 = +a;
+ auto graph = StopRecordingExpressions();
+
+ ExpressionGraph reference;
+ reference.InsertBack(Expression::CreateCompileTimeConstant(1));
+ reference.InsertBack(Expression::CreateUnaryArithmetic("-", 0));
+ reference.InsertBack(Expression::CreateUnaryArithmetic("+", 0));
+ EXPECT_EQ(reference, graph);
+}
+
+TEST(CodeGenerator, BINARY_COMPARISON) {
+ using BOOL = ComparisonExpressionRef;
+ StartRecordingExpressions();
+ T a = T(1);
+ T b = T(2);
+ BOOL r1 = a < b;
+ BOOL r2 = a <= b;
+ BOOL r3 = a > b;
+ BOOL r4 = a >= b;
+ BOOL r5 = a == b;
+ BOOL r6 = a != b;
+ auto graph = StopRecordingExpressions();
+
+ ExpressionGraph reference;
+ reference.InsertBack(Expression::CreateCompileTimeConstant(1));
+ reference.InsertBack(Expression::CreateCompileTimeConstant(2));
+ reference.InsertBack(Expression::CreateBinaryCompare("<", 0, 1));
+ reference.InsertBack(Expression::CreateBinaryCompare("<=", 0, 1));
+ reference.InsertBack(Expression::CreateBinaryCompare(">", 0, 1));
+ reference.InsertBack(Expression::CreateBinaryCompare(">=", 0, 1));
+ reference.InsertBack(Expression::CreateBinaryCompare("==", 0, 1));
+ reference.InsertBack(Expression::CreateBinaryCompare("!=", 0, 1));
+ EXPECT_EQ(reference, graph);
+}
+
+TEST(CodeGenerator, LOGICAL_OPERATORS) {
+ using BOOL = ComparisonExpressionRef;
+ // Tests binary logical operators &&, || and the unary logical operator !
+ StartRecordingExpressions();
+ T a = T(1);
+ T b = T(2);
+ BOOL r1 = a < b;
+ BOOL r2 = a <= b;
+ BOOL r3 = r1 && r2;
+ BOOL r4 = r1 || r2;
+ BOOL r5 = !r1;
+ auto graph = StopRecordingExpressions();
+
+ ExpressionGraph reference;
+ reference.InsertBack(Expression::CreateCompileTimeConstant(1));
+ reference.InsertBack(Expression::CreateCompileTimeConstant(2));
+ reference.InsertBack(Expression::CreateBinaryCompare("<", 0, 1));
+ reference.InsertBack(Expression::CreateBinaryCompare("<=", 0, 1));
+ reference.InsertBack(Expression::CreateBinaryCompare("&&", 2, 3));
+ reference.InsertBack(Expression::CreateBinaryCompare("||", 2, 3));
+ reference.InsertBack(Expression::CreateLogicalNegation(2));
+ EXPECT_EQ(reference, graph);
+}
+
+TEST(CodeGenerator, FUNCTION_CALL) {
+ StartRecordingExpressions();
+ T a = T(1);
+ T b = T(2);
+ abs(a);
+ acos(a);
+ asin(a);
+ atan(a);
+ cbrt(a);
+ ceil(a);
+ cos(a);
+ cosh(a);
+ exp(a);
+ exp2(a);
+ floor(a);
+ log(a);
+ log2(a);
+ sin(a);
+ sinh(a);
+ sqrt(a);
+ tan(a);
+ tanh(a);
+ atan2(a, b);
+ pow(a, b);
+ auto graph = StopRecordingExpressions();
+
+ ExpressionGraph reference;
+ reference.InsertBack(Expression::CreateCompileTimeConstant(1));
+ reference.InsertBack(Expression::CreateCompileTimeConstant(2));
+ reference.InsertBack(Expression::CreateFunctionCall("abs", {0}));
+ reference.InsertBack(Expression::CreateFunctionCall("acos", {0}));
+ reference.InsertBack(Expression::CreateFunctionCall("asin", {0}));
+ reference.InsertBack(Expression::CreateFunctionCall("atan", {0}));
+ reference.InsertBack(Expression::CreateFunctionCall("cbrt", {0}));
+ reference.InsertBack(Expression::CreateFunctionCall("ceil", {0}));
+ reference.InsertBack(Expression::CreateFunctionCall("cos", {0}));
+ reference.InsertBack(Expression::CreateFunctionCall("cosh", {0}));
+ reference.InsertBack(Expression::CreateFunctionCall("exp", {0}));
+ reference.InsertBack(Expression::CreateFunctionCall("exp2", {0}));
+ reference.InsertBack(Expression::CreateFunctionCall("floor", {0}));
+ reference.InsertBack(Expression::CreateFunctionCall("log", {0}));
+ reference.InsertBack(Expression::CreateFunctionCall("log2", {0}));
+ reference.InsertBack(Expression::CreateFunctionCall("sin", {0}));
+ reference.InsertBack(Expression::CreateFunctionCall("sinh", {0}));
+ reference.InsertBack(Expression::CreateFunctionCall("sqrt", {0}));
+ reference.InsertBack(Expression::CreateFunctionCall("tan", {0}));
+ reference.InsertBack(Expression::CreateFunctionCall("tanh", {0}));
+ reference.InsertBack(Expression::CreateFunctionCall("atan2", {0, 1}));
+ reference.InsertBack(Expression::CreateFunctionCall("pow", {0, 1}));
+ EXPECT_EQ(reference, graph);
+}
+
+TEST(CodeGenerator, IF) {
+ StartRecordingExpressions();
+ T a = T(1);
+ T b = T(2);
+ auto r1 = a < b;
+ CERES_IF(r1) {}
+ CERES_ENDIF;
+ auto graph = StopRecordingExpressions();
+
+ ExpressionGraph reference;
+ reference.InsertBack(Expression::CreateCompileTimeConstant(1));
+ reference.InsertBack(Expression::CreateCompileTimeConstant(2));
+ reference.InsertBack(Expression::CreateBinaryCompare("<", 0, 1));
+ reference.InsertBack(Expression::CreateIf(2));
+ reference.InsertBack(Expression::CreateEndIf());
+ EXPECT_EQ(reference, graph);
+}
+
+TEST(CodeGenerator, IF_ELSE) {
+ StartRecordingExpressions();
+ T a = T(1);
+ T b = T(2);
+ auto r1 = a < b;
+ CERES_IF(r1) {}
+ CERES_ELSE {}
+ CERES_ENDIF;
+ auto graph = StopRecordingExpressions();
+
+ ExpressionGraph reference;
+ reference.InsertBack(Expression::CreateCompileTimeConstant(1));
+ reference.InsertBack(Expression::CreateCompileTimeConstant(2));
+ reference.InsertBack(Expression::CreateBinaryCompare("<", 0, 1));
+ reference.InsertBack(Expression::CreateIf(2));
+ reference.InsertBack(Expression::CreateElse());
+ reference.InsertBack(Expression::CreateEndIf());
+ EXPECT_EQ(reference, graph);
+}
+
+TEST(CodeGenerator, IF_NESTED) {
+ StartRecordingExpressions();
+ T a = T(1);
+ T b = T(2);
+ auto r1 = a < b;
+ auto r2 = a == b;
+ CERES_IF(r1) {
+ CERES_IF(r2) {}
+ CERES_ENDIF;
+ }
+ CERES_ELSE {}
+ CERES_ENDIF;
+ auto graph = StopRecordingExpressions();
+
+ ExpressionGraph reference;
+ reference.InsertBack(Expression::CreateCompileTimeConstant(1));
+ reference.InsertBack(Expression::CreateCompileTimeConstant(2));
+ reference.InsertBack(Expression::CreateBinaryCompare("<", 0, 1));
+ reference.InsertBack(Expression::CreateBinaryCompare("==", 0, 1));
+ reference.InsertBack(Expression::CreateIf(2));
+ reference.InsertBack(Expression::CreateIf(3));
+ reference.InsertBack(Expression::CreateEndIf());
+ reference.InsertBack(Expression::CreateElse());
+ reference.InsertBack(Expression::CreateEndIf());
+ EXPECT_EQ(reference, graph);
+}
+
+} // namespace internal
+} // namespace ceres
diff --git a/internal/ceres/expression_test.cc b/internal/ceres/codegen/expression_test.cc
similarity index 63%
rename from internal/ceres/expression_test.cc
rename to internal/ceres/codegen/expression_test.cc
index cab196b..395bbaf 100644
--- a/internal/ceres/expression_test.cc
+++ b/internal/ceres/codegen/expression_test.cc
@@ -28,11 +28,10 @@
//
// Author: darius.rueckert@fau.de (Darius Rueckert)
//
-#define CERES_CODEGEN
-
-#include "ceres/codegen/internal/expression_graph.h"
-#include "ceres/codegen/internal/expression_ref.h"
-#include "ceres/jet.h"
+// This file tests the Expression class. For each member function one test is
+// included here.
+//
+#include "ceres/codegen/internal/expression.h"
#include "gtest/gtest.h"
namespace ceres {
@@ -168,135 +167,61 @@
ASSERT_FALSE(expr1.IsReplaceableBy(expr3));
}
+TEST(Expression, Replace) {
+ auto expr1 =
+ Expression::CreateBinaryArithmetic("+", ExpressionId(3), ExpressionId(5));
+ expr1.set_lhs_id(13);
+
+ auto expr2 =
+ Expression::CreateAssignment(kInvalidExpressionId, ExpressionId(7));
+
+ // We replace the arithmetic expr1 by an assignment from the variable 7. This
+ // is the typical usecase in subexpression elimination.
+ expr1.Replace(expr2);
+
+ // expr1 should now be an assignment from 7 to 13
+ EXPECT_EQ(expr1, Expression(ExpressionType::ASSIGNMENT, 13, {7}, "", 0));
+}
+
TEST(Expression, DirectlyDependsOn) {
- using T = ExpressionRef;
+ auto expr1 =
+ Expression::CreateBinaryArithmetic("+", ExpressionId(3), ExpressionId(5));
- StartRecordingExpressions();
+ ASSERT_TRUE(expr1.DirectlyDependsOn(ExpressionId(3)));
+ ASSERT_TRUE(expr1.DirectlyDependsOn(ExpressionId(5)));
+ ASSERT_FALSE(expr1.DirectlyDependsOn(ExpressionId(kInvalidExpressionId)));
+ ASSERT_FALSE(expr1.DirectlyDependsOn(ExpressionId(42)));
+}
+TEST(Expression, MakeNop) {
+ auto expr1 =
+ Expression::CreateBinaryArithmetic("+", ExpressionId(3), ExpressionId(5));
- T unused(6);
- T a(2), b(3);
- T c = a + b;
- T d = c + a;
+ expr1.MakeNop();
- auto graph = StopRecordingExpressions();
-
- ASSERT_FALSE(graph.ExpressionForId(a.id).DirectlyDependsOn(unused.id));
- ASSERT_TRUE(graph.ExpressionForId(c.id).DirectlyDependsOn(a.id));
- ASSERT_TRUE(graph.ExpressionForId(c.id).DirectlyDependsOn(b.id));
- ASSERT_TRUE(graph.ExpressionForId(d.id).DirectlyDependsOn(a.id));
- ASSERT_FALSE(graph.ExpressionForId(d.id).DirectlyDependsOn(b.id));
- ASSERT_TRUE(graph.ExpressionForId(d.id).DirectlyDependsOn(c.id));
+ EXPECT_EQ(expr1,
+ Expression(ExpressionType::NOP, kInvalidExpressionId, {}, "", 0));
}
-TEST(Expression, Ternary) {
- using T = ExpressionRef;
+TEST(Expression, IsSemanticallyEquivalentTo) {
+ // Create 2 identical expression
+ auto expr1 =
+ Expression::CreateBinaryArithmetic("+", ExpressionId(3), ExpressionId(5));
- StartRecordingExpressions();
- T a(2); // 0
- T b(3); // 1
- auto c = a < b; // 2
- T d = Ternary(c, a, b); // 3
- MakeOutput(d, "result"); // 4
- auto graph = StopRecordingExpressions();
+ auto expr2 =
+ Expression::CreateBinaryArithmetic("+", ExpressionId(3), ExpressionId(5));
- EXPECT_EQ(graph.Size(), 5);
+ ASSERT_TRUE(expr1.IsSemanticallyEquivalentTo(expr1));
+ ASSERT_TRUE(expr1.IsSemanticallyEquivalentTo(expr2));
- // Expected code
- // v_0 = 2;
- // v_1 = 3;
- // v_2 = v_0 < v_1;
- // v_3 = Ternary(v_2, v_0, v_1);
- // result = v_3;
+ auto expr3 =
+ Expression::CreateBinaryArithmetic("+", ExpressionId(3), ExpressionId(8));
- ExpressionGraph reference;
- // clang-format off
- // Id, Type, Lhs, Value, Name, Arguments
- reference.InsertBack({ ExpressionType::COMPILE_TIME_CONSTANT, 0, {}, "", 2});
- reference.InsertBack({ ExpressionType::COMPILE_TIME_CONSTANT, 1, {}, "", 3});
- reference.InsertBack({ ExpressionType::BINARY_COMPARISON, 2, {0,1}, "<", 0});
- reference.InsertBack({ ExpressionType::FUNCTION_CALL, 3, {2,0,1}, "Ternary", 0});
- reference.InsertBack({ ExpressionType::OUTPUT_ASSIGNMENT, 4, {3}, "result", 0});
- // clang-format on
- EXPECT_EQ(reference, graph);
-}
+ ASSERT_TRUE(expr1.IsSemanticallyEquivalentTo(expr3));
-TEST(Expression, Assignment) {
- using T = ExpressionRef;
+ auto expr4 =
+ Expression::CreateBinaryArithmetic("-", ExpressionId(3), ExpressionId(5));
- StartRecordingExpressions();
- T a = 1;
- T b = 2;
- b = a;
- auto graph = StopRecordingExpressions();
-
- EXPECT_EQ(graph.Size(), 3);
-
- ExpressionGraph reference;
- // clang-format off
- // Id, Type, Lhs, Value, Name, Arguments
- reference.InsertBack({ ExpressionType::COMPILE_TIME_CONSTANT, 0, {}, "", 1});
- reference.InsertBack({ ExpressionType::COMPILE_TIME_CONSTANT, 1, {}, "", 2});
- reference.InsertBack({ ExpressionType::ASSIGNMENT, 1, {0}, "", 0});
- // clang-format on
- EXPECT_EQ(reference, graph);
-}
-
-TEST(Expression, AssignmentCreate) {
- using T = ExpressionRef;
-
- StartRecordingExpressions();
- T a = 2;
- T b = a;
- auto graph = StopRecordingExpressions();
-
- EXPECT_EQ(graph.Size(), 2);
-
- ExpressionGraph reference;
- // clang-format off
- // Id, Type, Lhs, Value, Name, Arguments
- reference.InsertBack({ ExpressionType::COMPILE_TIME_CONSTANT, 0, {}, "", 2});
- reference.InsertBack({ ExpressionType::ASSIGNMENT, 1, {0}, "", 0});
- // clang-format on
- EXPECT_EQ(reference, graph);
-}
-
-TEST(Expression, MoveAssignmentCreate) {
- using T = ExpressionRef;
-
- StartRecordingExpressions();
- T a = 1;
- T b = std::move(a);
- auto graph = StopRecordingExpressions();
-
- EXPECT_EQ(graph.Size(), 1);
-
- ExpressionGraph reference;
- // clang-format off
- // Id, Type, Lhs, Value, Name, Arguments
- reference.InsertBack({ ExpressionType::COMPILE_TIME_CONSTANT, 0, {}, "", 1});
- // clang-format on
- EXPECT_EQ(reference, graph);
-}
-
-TEST(Expression, MoveAssignment) {
- using T = ExpressionRef;
-
- StartRecordingExpressions();
- T a = 1;
- T b = 2;
- b = std::move(a);
- auto graph = StopRecordingExpressions();
-
- EXPECT_EQ(graph.Size(), 3);
-
- ExpressionGraph reference;
- // clang-format off
- // Id, Type, Lhs, Value, Name, Arguments
- reference.InsertBack({ ExpressionType::COMPILE_TIME_CONSTANT, 0, {}, "", 1});
- reference.InsertBack({ ExpressionType::COMPILE_TIME_CONSTANT, 1, {}, "", 2});
- reference.InsertBack({ ExpressionType::ASSIGNMENT, 1, {0}, "", 0});
- // clang-format on
- EXPECT_EQ(reference, graph);
+ ASSERT_FALSE(expr1.IsSemanticallyEquivalentTo(expr4));
}
} // namespace internal
diff --git a/internal/ceres/conditional_expressions_test.cc b/internal/ceres/conditional_expressions_test.cc
deleted file mode 100644
index 9c499ea..0000000
--- a/internal/ceres/conditional_expressions_test.cc
+++ /dev/null
@@ -1,132 +0,0 @@
-// Ceres Solver - A fast non-linear least squares minimizer
-// Copyright 2019 Google Inc. All rights reserved.
-// http://code.google.com/p/ceres-solver/
-//
-// Redistribution and use in source and binary forms, with or without
-// modification, are permitted provided that the following conditions are met:
-//
-// * Redistributions of source code must retain the above copyright notice,
-// this list of conditions and the following disclaimer.
-// * Redistributions in binary form must reproduce the above copyright notice,
-// this list of conditions and the following disclaimer in the documentation
-// and/or other materials provided with the distribution.
-// * Neither the name of Google Inc. nor the names of its contributors may be
-// used to endorse or promote products derived from this software without
-// specific prior written permission.
-//
-// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
-// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
-// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
-// POSSIBILITY OF SUCH DAMAGE.
-//
-// Author: darius.rueckert@fau.de (Darius Rueckert)
-//
-
-#define CERES_CODEGEN
-
-#include "ceres/codegen/internal/expression_graph.h"
-#include "ceres/codegen/internal/expression_ref.h"
-#include "gtest/gtest.h"
-
-namespace ceres {
-namespace internal {
-
-TEST(Expression, ConditionalMinimal) {
- using T = ExpressionRef;
-
- StartRecordingExpressions();
- T a(2);
- T b(3);
- auto c = a < b;
- CERES_IF(c) {}
- CERES_ELSE {}
- CERES_ENDIF
- auto graph = StopRecordingExpressions();
-
- // Expected code
- // v_0 = 2;
- // v_1 = 3;
- // v_2 = v_0 < v_1;
- // if(v_2);
- // else
- // endif
-
- EXPECT_EQ(graph.Size(), 6);
-
- ExpressionGraph reference;
- // clang-format off
- // Id, Type, Lhs, Value, Name, Arguments...
- reference.InsertBack({ExpressionType::COMPILE_TIME_CONSTANT, 0, {} , "", 2});
- reference.InsertBack({ExpressionType::COMPILE_TIME_CONSTANT, 1, {} , "", 3});
- reference.InsertBack({ ExpressionType::BINARY_COMPARISON, 2, {0, 1} , "<", 0});
- reference.InsertBack({ ExpressionType::IF, -1, {2} , "", 0});
- reference.InsertBack({ ExpressionType::ELSE, -1, {} , "", 0});
- reference.InsertBack({ ExpressionType::ENDIF, -1, {} , "", 0});
- // clang-format on
- EXPECT_EQ(reference, graph);
-}
-
-TEST(Expression, ConditionalAssignment) {
- using T = ExpressionRef;
-
- StartRecordingExpressions();
-
- T result;
- T a(2);
- T b(3);
- auto c = a < b;
- CERES_IF(c) { result = a + b; }
- CERES_ELSE { result = a - b; }
- CERES_ENDIF
- result += a;
- auto graph = StopRecordingExpressions();
-
- // Expected code
- // v_0 = 2;
- // v_1 = 3;
- // v_2 = v_0 < v_1;
- // if(v_2);
- // v_4 = v_0 + v_1;
- // else
- // v_6 = v_0 - v_1;
- // v_4 = v_6
- // endif
- // v_9 = v_4 + v_0;
- // v_4 = v_9;
-
- ExpressionGraph reference;
- // clang-format off
- // Id, Type, Lhs, Value, Name, Arguments...
- reference.InsertBack({ExpressionType::COMPILE_TIME_CONSTANT, 0, {} , "", 2});
- reference.InsertBack({ExpressionType::COMPILE_TIME_CONSTANT, 1, {} , "", 3});
- reference.InsertBack({ ExpressionType::BINARY_COMPARISON, 2, {0, 1}, "<", 0});
- reference.InsertBack({ ExpressionType::IF, -1, {2} , "", 0});
- reference.InsertBack({ ExpressionType::BINARY_ARITHMETIC, 4, {0, 1}, "+", 0});
- reference.InsertBack({ ExpressionType::ELSE, -1, {} , "", 0});
- reference.InsertBack({ ExpressionType::BINARY_ARITHMETIC, 6, {0, 1}, "-", 0});
- reference.InsertBack({ ExpressionType::ASSIGNMENT, 4, {6} , "", 0});
- reference.InsertBack({ ExpressionType::ENDIF, -1, {} , "", 0});
- reference.InsertBack({ ExpressionType::BINARY_ARITHMETIC, 9, {4, 0}, "+", 0});
- reference.InsertBack({ ExpressionType::ASSIGNMENT, 4, {9} , "", 0});
- // clang-format on
- EXPECT_EQ(reference, graph);
-
- // Variables after execution:
- //
- // a <=> v_0
- // b <=> v_1
- // result <=> v_4
- EXPECT_EQ(a.id, 0);
- EXPECT_EQ(b.id, 1);
- EXPECT_EQ(result.id, 4);
-}
-
-} // namespace internal
-} // namespace ceres
diff --git a/internal/ceres/expression.cc b/internal/ceres/expression.cc
index ca92eaf..27e15cb 100644
--- a/internal/ceres/expression.cc
+++ b/internal/ceres/expression.cc
@@ -142,8 +142,8 @@
}
void Expression::MakeNop() {
- type_ = ExpressionType::NOP;
- arguments_.clear();
+ // The default constructor creates a NOP expression!
+ *this = Expression();
}
bool Expression::operator==(const Expression& other) const {
diff --git a/internal/ceres/expression_graph.cc b/internal/ceres/expression_graph.cc
index c0761af..d9b33e9 100644
--- a/internal/ceres/expression_graph.cc
+++ b/internal/ceres/expression_graph.cc
@@ -108,8 +108,17 @@
expressions_[id + 1] = expression;
}
- // Insert new expression at the correct place
- expressions_[location] = expression;
+ if (expression.IsControlExpression() ||
+ expression.lhs_id() != kInvalidExpressionId) {
+ // Insert new expression at the correct place
+ expressions_[location] = expression;
+ } else {
+ // Arithmetic expression with invalid lhs
+ // -> Set lhs to location
+ Expression copy = expression;
+ copy.set_lhs_id(location);
+ expressions_[location] = copy;
+ }
}
ExpressionId ExpressionGraph::InsertBack(const Expression& expression) {
diff --git a/internal/ceres/expression_graph_test.cc b/internal/ceres/expression_graph_test.cc
deleted file mode 100644
index 65b019c..0000000
--- a/internal/ceres/expression_graph_test.cc
+++ /dev/null
@@ -1,160 +0,0 @@
-// Ceres Solver - A fast non-linear least squares minimizer
-// Copyright 2019 Google Inc. All rights reserved.
-// http://code.google.com/p/ceres-solver/
-//
-// Redistribution and use in source and binary forms, with or without
-// modification, are permitted provided that the following conditions are met:
-//
-// * Redistributions of source code must retain the above copyright notice,
-// this list of conditions and the following disclaimer.
-// * Redistributions in binary form must reproduce the above copyright notice,
-// this list of conditions and the following disclaimer in the documentation
-// and/or other materials provided with the distribution.
-// * Neither the name of Google Inc. nor the names of its contributors may be
-// used to endorse or promote products derived from this software without
-// specific prior written permission.
-//
-// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
-// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
-// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
-// POSSIBILITY OF SUCH DAMAGE.
-//
-// Author: darius.rueckert@fau.de (Darius Rueckert)
-//
-// Test expression creation and logic.
-
-#include "ceres/codegen/internal/expression_graph.h"
-#include "ceres/codegen/internal/expression_ref.h"
-
-#include "gtest/gtest.h"
-
-namespace ceres {
-namespace internal {
-
-TEST(ExpressionGraph, Creation) {
- using T = ExpressionRef;
- ExpressionGraph graph;
-
- StartRecordingExpressions();
- graph = StopRecordingExpressions();
- EXPECT_EQ(graph.Size(), 0);
-
- StartRecordingExpressions();
- T a(1);
- T b(2);
- T c = a + b;
- graph = StopRecordingExpressions();
- EXPECT_EQ(graph.Size(), 3);
-}
-
-TEST(ExpressionGraph, Dependencies) {
- using T = ExpressionRef;
-
- StartRecordingExpressions();
-
- T unused(6);
- T a(2), b(3);
- T c = a + b;
- T d = c + a;
-
- auto tree = StopRecordingExpressions();
-
- // Recursive dependency check
- ASSERT_TRUE(tree.DependsOn(d.id, c.id));
- ASSERT_TRUE(tree.DependsOn(d.id, a.id));
- ASSERT_TRUE(tree.DependsOn(d.id, b.id));
- ASSERT_FALSE(tree.DependsOn(d.id, unused.id));
-}
-
-TEST(ExpressionGraph, InsertExpression_UpdateReferences) {
- // This test checks if references to shifted expressions are updated
- // accordingly.
- using T = ExpressionRef;
- StartRecordingExpressions();
- T a(2); // 0
- T b(3); // 1
- T c = a + b; // 2
- auto graph = StopRecordingExpressions();
-
- // Test if 'a' and 'c' are actually at location 0 and 2
- auto& a_expr = graph.ExpressionForId(0);
- EXPECT_EQ(a_expr.type(), ExpressionType::COMPILE_TIME_CONSTANT);
- EXPECT_EQ(a_expr.value(), 2);
-
- // At this point 'c' should have 0 and 1 as arguments.
- auto& c_expr = graph.ExpressionForId(2);
- EXPECT_EQ(c_expr.type(), ExpressionType::BINARY_ARITHMETIC);
- EXPECT_EQ(c_expr.arguments()[0], 0);
- EXPECT_EQ(c_expr.arguments()[1], 1);
-
- // We insert at the beginning, which shifts everything by one spot.
- graph.Insert(0, {ExpressionType::COMPILE_TIME_CONSTANT, 0, {}, "", 10.2});
-
- // Test if 'a' and 'c' are actually at location 1 and 3
- auto& a_expr2 = graph.ExpressionForId(1);
- EXPECT_EQ(a_expr2.type(), ExpressionType::COMPILE_TIME_CONSTANT);
- EXPECT_EQ(a_expr2.value(), 2);
-
- // At this point 'c' should have 1 and 2 as arguments.
- auto& c_expr2 = graph.ExpressionForId(3);
- EXPECT_EQ(c_expr2.type(), ExpressionType::BINARY_ARITHMETIC);
- EXPECT_EQ(c_expr2.arguments()[0], 1);
- EXPECT_EQ(c_expr2.arguments()[1], 2);
-}
-
-TEST(ExpressionGraph, InsertExpression) {
- using T = ExpressionRef;
-
- StartRecordingExpressions();
-
- {
- T a(2); // 0
- T b(3); // 1
- T five = 5; // 2
- T tmp = a + five; // 3
- a = tmp; // 4
- T c = a + b; // 5
- T d = a * b; // 6
- T e = c + d; // 7
- MakeOutput(e, "result"); // 8
- }
- auto reference = StopRecordingExpressions();
- EXPECT_EQ(reference.Size(), 9);
-
- StartRecordingExpressions();
-
- {
- // The expressions 2,3,4 from above are missing.
- T a(2); // 0
- T b(3); // 1
- T c = a + b; // 2
- T d = a * b; // 3
- T e = c + d; // 4
- MakeOutput(e, "result"); // 5
- }
-
- auto graph1 = StopRecordingExpressions();
- EXPECT_EQ(graph1.Size(), 6);
- ASSERT_FALSE(reference == graph1);
-
- // We manually insert the 3 missing expressions
- // clang-format off
- graph1.Insert(2,{ ExpressionType::COMPILE_TIME_CONSTANT, 2, {}, "", 5});
- graph1.Insert(3,{ ExpressionType::BINARY_ARITHMETIC, 3, {0, 2}, "+", 0});
- graph1.Insert(4,{ ExpressionType::ASSIGNMENT, 0, {3}, "", 0});
- // clang-format on
-
- // Now the graphs are identical!
- EXPECT_EQ(graph1.Size(), 9);
- ASSERT_TRUE(reference == graph1);
-}
-
-} // namespace internal
-} // namespace ceres