Improve testing of the codegen system - Move codegen tests to a sub directory - Add tests for all functions of Expression, ExpressionRef, and ExpressionGraph - Respect dependencies during tests: The ExpressionGraph test doesn't use ExpressionRef anymore. The new tests revealed a few bugs so the following changes were made: - Expression::MakeNop now resets the current expression with the default constructed NOP expression - ExpressionGraph::Insert now updates the lhs_id the same way as InsertBack() Change-Id: I6a18925c1e4d972c29ec1219f2073b4eaf2df737
diff --git a/include/ceres/codegen/internal/expression.h b/include/ceres/codegen/internal/expression.h index 7d93a4d..808d741 100644 --- a/include/ceres/codegen/internal/expression.h +++ b/include/ceres/codegen/internal/expression.h
@@ -249,6 +249,7 @@ // ExpressionGraph (see expression_graph.h). class Expression { public: + // Creates a NOP expression. Expression() = default; Expression(ExpressionType type,
diff --git a/include/ceres/codegen/internal/expression_graph.h b/include/ceres/codegen/internal/expression_graph.h index 30b5a78..331a2d9 100644 --- a/include/ceres/codegen/internal/expression_graph.h +++ b/include/ceres/codegen/internal/expression_graph.h
@@ -53,6 +53,9 @@ bool DependsOn(ExpressionId A, ExpressionId B) const; bool operator==(const ExpressionGraph& other) const; + bool operator!=(const ExpressionGraph& other) const { + return !(*this == other); + } Expression& ExpressionForId(ExpressionId id) { return expressions_[id]; } const Expression& ExpressionForId(ExpressionId id) const {
diff --git a/internal/ceres/CMakeLists.txt b/internal/ceres/CMakeLists.txt index 0733f47..af13f35 100644 --- a/internal/ceres/CMakeLists.txt +++ b/internal/ceres/CMakeLists.txt
@@ -504,33 +504,7 @@ add_subdirectory(generated_bundle_adjustment_tests) if(CODE_GENERATION) - # Testing the AutoDiffCodegen system is more complicated, because function- and - # constructor calls have side-effects. In C++ the evaluation order and - # the elision of copies is implementation defined. Between different compilers, - # some expression might be evaluated in a different order or some copies might be - # removed. - # - # Therefore, we run tests that expect a particular compiler behaviour only on gcc. - # - # The semantic tests, which check the correctness by executing the generated code - # are still run on all platforms. - if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - set(CXX_FLAGS_OLD ${CMAKE_CXX_FLAGS}) - # From the man page: - # The C++ standard allows an implementation to omit creating a - # temporary which is only used to initialize another object of the - # same type. Specifying -fno-elide-constructors disables that - # optimization, and forces G++ to call the copy constructor in all cases. - # We use this flag to get the same results between different versions of - # gcc and different optimization levels. Without this flag, testing would - # be very painfull and might break when a new compiler version is released. - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-elide-constructors") - ceres_test(expression) - ceres_test(code_generator) - ceres_test(conditional_expressions) - ceres_test(expression_graph) - set(CMAKE_CXX_FLAGS ${CXX_FLAGS_OLD}) - endif() + add_subdirectory(codegen) endif() endif (BUILD_TESTING AND GFLAGS)
diff --git a/internal/ceres/codegen/CMakeLists.txt b/internal/ceres/codegen/CMakeLists.txt new file mode 100644 index 0000000..7362aef --- /dev/null +++ b/internal/ceres/codegen/CMakeLists.txt
@@ -0,0 +1,57 @@ +# Ceres Solver - A fast non-linear least squares minimizer +# Copyright 2019 Google Inc. All rights reserved. +# http://ceres-solver.org/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of Google Inc. nor the names of its contributors may be +# used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# Author: darius.rueckert@fau.de (Darius Rueckert) +# +# Testing the AutoDiffCodegen system is more complicated, because function- and +# constructor calls have side-effects. In C++ the evaluation order and +# the elision of copies is implementation defined. Between different compilers, +# some expression might be evaluated in a different order or some copies might be +# removed. +# +# Therefore, we run tests that expect a particular compiler behaviour only on gcc. +# +# The semantic tests, which check the correctness by executing the generated code +# are still run on all platforms. +if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + set(CXX_FLAGS_OLD ${CMAKE_CXX_FLAGS}) + # From the man page: + # The C++ standard allows an implementation to omit creating a + # temporary which is only used to initialize another object of the + # same type. Specifying -fno-elide-constructors disables that + # optimization, and forces G++ to call the copy constructor in all cases. + # We use this flag to get the same results between different versions of + # gcc and different optimization levels. Without this flag, testing would + # be very painfull and might break when a new compiler version is released. + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-elide-constructors") + ceres_test(expression) + ceres_test(expression_graph) + ceres_test(expression_ref) + ceres_test(code_generator) + set(CMAKE_CXX_FLAGS ${CXX_FLAGS_OLD}) +endif()
diff --git a/internal/ceres/code_generator_test.cc b/internal/ceres/codegen/code_generator_test.cc similarity index 90% rename from internal/ceres/code_generator_test.cc rename to internal/ceres/codegen/code_generator_test.cc index 2ac7ee3..f974b1a 100644 --- a/internal/ceres/code_generator_test.cc +++ b/internal/ceres/codegen/code_generator_test.cc
@@ -97,7 +97,6 @@ } TEST(CodeGenerator, OUTPUT_ASSIGNMENT) { - double local_variable = 5.0; StartRecordingExpressions(); T a = 1; T b = 0; @@ -139,8 +138,8 @@ TEST(CodeGenerator, BINARY_ARITHMETIC_SIMPLE) { StartRecordingExpressions(); - T a = T(0); - T b = T(1); + T a = T(1); + T b = T(2); T r1 = a + b; T r2 = a - b; T r3 = a * b; @@ -153,8 +152,8 @@ " double v_3;", " double v_4;", " double v_5;", - " v_0 = 0;", - " v_1 = 1;", + " v_0 = 1;", + " v_1 = 2;", " v_2 = v_0 + v_1;", " v_3 = v_0 - v_1;", " v_4 = v_0 * v_1;", @@ -163,17 +162,40 @@ GenerateAndCheck(graph, expected_code); } +TEST(CodeGenerator, BINARY_ARITHMETIC_NESTED) { + StartRecordingExpressions(); + T a = T(1); + T b = T(2); + T r1 = b - a * (a + b) / a; + auto graph = StopRecordingExpressions(); + std::vector<std::string> expected_code = {"{", + " double v_0;", + " double v_1;", + " double v_2;", + " double v_3;", + " double v_4;", + " double v_5;", + " v_0 = 1;", + " v_1 = 2;", + " v_2 = v_0 + v_1;", + " v_3 = v_0 * v_2;", + " v_4 = v_3 / v_0;", + " v_5 = v_1 - v_4;", + "}"}; + GenerateAndCheck(graph, expected_code); +} + TEST(CodeGenerator, BINARY_ARITHMETIC_COMPOUND) { // For each binary compound arithmetic operation, two lines are generated: // - The actual operation assigning to a new temporary variable // - An assignment from the temporary to the lhs StartRecordingExpressions(); - T a = T(0); - T b = T(1); - b += a; - b -= a; - b *= a; - b /= a; + T a = T(1); + T b = T(2); + a += b; + a -= b; + a *= b; + a /= b; auto graph = StopRecordingExpressions(); std::vector<std::string> expected_code = {"{", " double v_0;", @@ -182,16 +204,16 @@ " double v_4;", " double v_6;", " double v_8;", - " v_0 = 0;", - " v_1 = 1;", - " v_2 = v_1 + v_0;", - " v_1 = v_2;", - " v_4 = v_1 - v_0;", - " v_1 = v_4;", - " v_6 = v_1 * v_0;", - " v_1 = v_6;", - " v_8 = v_1 / v_0;", - " v_1 = v_8;", + " v_0 = 1;", + " v_1 = 2;", + " v_2 = v_0 + v_1;", + " v_0 = v_2;", + " v_4 = v_0 - v_1;", + " v_0 = v_4;", + " v_6 = v_0 * v_1;", + " v_0 = v_6;", + " v_8 = v_0 / v_1;", + " v_0 = v_8;", "}"}; GenerateAndCheck(graph, expected_code); }
diff --git a/internal/ceres/codegen/expression_graph_test.cc b/internal/ceres/codegen/expression_graph_test.cc new file mode 100644 index 0000000..b08a9b3 --- /dev/null +++ b/internal/ceres/codegen/expression_graph_test.cc
@@ -0,0 +1,182 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2019 Google Inc. All rights reserved. +// http://code.google.com/p/ceres-solver/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Author: darius.rueckert@fau.de (Darius Rueckert) +// +// This file tests the ExpressionGraph class. This test depends on the +// correctness of Expression. +// +#include "ceres/codegen/internal/expression.h" +#include "ceres/codegen/internal/expression_graph.h" + +#include "gtest/gtest.h" + +namespace ceres { +namespace internal { + +TEST(ExpressionGraph, Size) { + ExpressionGraph graph; + EXPECT_EQ(graph.Size(), 0); + // Insert 3 NOPs + graph.InsertBack(Expression()); + graph.InsertBack(Expression()); + graph.InsertBack(Expression()); + EXPECT_EQ(graph.Size(), 3); +} + +TEST(ExpressionGraph, Recording) { + EXPECT_EQ(GetCurrentExpressionGraph(), nullptr); + StartRecordingExpressions(); + EXPECT_NE(GetCurrentExpressionGraph(), nullptr); + auto graph = StopRecordingExpressions(); + EXPECT_EQ(graph, ExpressionGraph()); + EXPECT_EQ(GetCurrentExpressionGraph(), nullptr); +} + +TEST(ExpressionGraph, InsertBackControl) { + // Control expression are inserted to the back without any modifications. + auto expr1 = Expression::CreateIf(ExpressionId(0)); + auto expr2 = Expression::CreateElse(); + auto expr3 = Expression::CreateEndIf(); + + ExpressionGraph graph; + graph.InsertBack(expr1); + graph.InsertBack(expr2); + graph.InsertBack(expr3); + + EXPECT_EQ(graph.Size(), 3); + EXPECT_EQ(graph.ExpressionForId(0), expr1); + EXPECT_EQ(graph.ExpressionForId(1), expr2); + EXPECT_EQ(graph.ExpressionForId(2), expr3); +} + +TEST(ExpressionGraph, InsertBackNewVariable) { + // If an arithmetic expression with lhs=kinvalidValue is inserted in the back, + // then a new variable name is created and set to the lhs_id. + auto expr1 = Expression::CreateCompileTimeConstant(42); + auto expr2 = Expression::CreateCompileTimeConstant(10); + auto expr3 = + Expression::CreateBinaryArithmetic("+", ExpressionId(0), ExpressionId(1)); + + ExpressionGraph graph; + graph.InsertBack(expr1); + graph.InsertBack(expr2); + graph.InsertBack(expr3); + EXPECT_EQ(graph.Size(), 3); + + // The ExpressionGraph has a copy of the inserted expression with the correct + // lhs_ids. We set them here manually for comparision. + expr1.set_lhs_id(0); + expr2.set_lhs_id(1); + expr3.set_lhs_id(2); + EXPECT_EQ(graph.ExpressionForId(0), expr1); + EXPECT_EQ(graph.ExpressionForId(1), expr2); + EXPECT_EQ(graph.ExpressionForId(2), expr3); +} + +TEST(ExpressionGraph, InsertBackExistingVariable) { + auto expr1 = Expression::CreateCompileTimeConstant(42); + auto expr2 = Expression::CreateCompileTimeConstant(10); + auto expr3 = Expression::CreateAssignment(1, 0); + + ExpressionGraph graph; + graph.InsertBack(expr1); + graph.InsertBack(expr2); + graph.InsertBack(expr3); + EXPECT_EQ(graph.Size(), 3); + + expr1.set_lhs_id(0); + expr2.set_lhs_id(1); + expr3.set_lhs_id(1); + EXPECT_EQ(graph.ExpressionForId(0), expr1); + EXPECT_EQ(graph.ExpressionForId(1), expr2); + EXPECT_EQ(graph.ExpressionForId(2), expr3); +} + +TEST(ExpressionGraph, DependsOn) { + ExpressionGraph graph; + graph.InsertBack(Expression::CreateCompileTimeConstant(42)); + graph.InsertBack(Expression::CreateCompileTimeConstant(10)); + graph.InsertBack(Expression::CreateBinaryArithmetic( + "+", ExpressionId(0), ExpressionId(1))); + graph.InsertBack(Expression::CreateBinaryArithmetic( + "+", ExpressionId(2), ExpressionId(0))); + + // Code: + // v_0 = 42 + // v_1 = 10 + // v_2 = v_0 + v_1 + // v_3 = v_2 + v_0 + + // Direct dependencies dependency check + ASSERT_TRUE(graph.DependsOn(2, 0)); + ASSERT_TRUE(graph.DependsOn(2, 1)); + ASSERT_TRUE(graph.DependsOn(3, 2)); + ASSERT_TRUE(graph.DependsOn(3, 0)); + ASSERT_FALSE(graph.DependsOn(1, 0)); + ASSERT_FALSE(graph.DependsOn(1, 1)); + ASSERT_FALSE(graph.DependsOn(2, 3)); + + // Recursive test + ASSERT_TRUE(graph.DependsOn(3, 1)); +} + +TEST(ExpressionGraph, InsertExpression_UpdateReferences) { + // This test checks if references to shifted expressions are updated + // accordingly. + ExpressionGraph graph; + graph.InsertBack(Expression::CreateCompileTimeConstant(42)); + graph.InsertBack(Expression::CreateCompileTimeConstant(10)); + graph.InsertBack(Expression::CreateBinaryArithmetic( + "+", ExpressionId(0), ExpressionId(1))); + // Code: + // v_0 = 42 + // v_1 = 10 + // v_2 = v_0 + v_1 + + // Insert another compile time constant at the beginning + graph.Insert(0, Expression::CreateCompileTimeConstant(5)); + // This should shift all indices like this: + // v_0 = 5 + // v_1 = 42 + // v_2 = 10 + // v_3 = v_1 + v_2 + + // Test by inserting it in the correct order + ExpressionGraph ref; + ref.InsertBack(Expression::CreateCompileTimeConstant(5)); + ref.InsertBack(Expression::CreateCompileTimeConstant(42)); + ref.InsertBack(Expression::CreateCompileTimeConstant(10)); + ref.InsertBack(Expression::CreateBinaryArithmetic( + "+", ExpressionId(1), ExpressionId(2))); + EXPECT_EQ(graph.Size(), ref.Size()); + EXPECT_EQ(graph, ref); +} + +} // namespace internal +} // namespace ceres
diff --git a/internal/ceres/codegen/expression_ref_test.cc b/internal/ceres/codegen/expression_ref_test.cc new file mode 100644 index 0000000..aeb9d2b --- /dev/null +++ b/internal/ceres/codegen/expression_ref_test.cc
@@ -0,0 +1,387 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2019 Google Inc. All rights reserved. +// http://code.google.com/p/ceres-solver/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Author: darius.rueckert@fau.de (Darius Rueckert) +// +// This file tests the ExpressionRef class. This test depends on the +// correctness of Expression and ExpressionGraph. +// +#define CERES_CODEGEN + +#include "ceres/codegen/internal/expression_graph.h" +#include "ceres/codegen/internal/expression_ref.h" +#include "gtest/gtest.h" + +namespace ceres { +namespace internal { + +using T = ExpressionRef; + +TEST(ExpressionRef, COMPILE_TIME_CONSTANT) { + StartRecordingExpressions(); + T a = T(0); + T b = T(123.5); + T c = T(1 + 1); + T d; // Uninitialized variables should not generate code! + auto graph = StopRecordingExpressions(); + + ExpressionGraph reference; + reference.InsertBack(Expression::CreateCompileTimeConstant(0)); + reference.InsertBack(Expression::CreateCompileTimeConstant(123.5)); + reference.InsertBack(Expression::CreateCompileTimeConstant(2)); + EXPECT_EQ(reference, graph); +} + +TEST(ExpressionRef, INPUT_ASSIGNMENT) { + double local_variable = 5.0; + StartRecordingExpressions(); + T a = CERES_LOCAL_VARIABLE(T, local_variable); + T b = MakeParameter("parameters[0][0]"); + auto graph = StopRecordingExpressions(); + + ExpressionGraph reference; + reference.InsertBack(Expression::CreateInputAssignment("local_variable")); + reference.InsertBack(Expression::CreateInputAssignment("parameters[0][0]")); + EXPECT_EQ(reference, graph); +} + +TEST(ExpressionRef, OUTPUT_ASSIGNMENT) { + StartRecordingExpressions(); + T a = 1; + T b = 0; + MakeOutput(a, "residual[0]"); + auto graph = StopRecordingExpressions(); + + ExpressionGraph reference; + reference.InsertBack(Expression::CreateCompileTimeConstant(1)); + reference.InsertBack(Expression::CreateCompileTimeConstant(0)); + reference.InsertBack(Expression::CreateOutputAssignment(0, "residual[0]")); + EXPECT_EQ(reference, graph); +} + +TEST(ExpressionRef, Assignment) { + StartRecordingExpressions(); + T a = 1; + T b = 2; + b = a; + auto graph = StopRecordingExpressions(); + EXPECT_EQ(graph.Size(), 3); + + ExpressionGraph reference; + reference.InsertBack(Expression::CreateCompileTimeConstant(1)); + reference.InsertBack(Expression::CreateCompileTimeConstant(2)); + reference.InsertBack(Expression::CreateAssignment(1, 0)); + EXPECT_EQ(reference, graph); +} + +TEST(ExpressionRef, AssignmentCreate) { + StartRecordingExpressions(); + T a = 2; + T b = a; + auto graph = StopRecordingExpressions(); + EXPECT_EQ(graph.Size(), 2); + + ExpressionGraph reference; + reference.InsertBack(Expression::CreateCompileTimeConstant(2)); + reference.InsertBack(Expression::CreateAssignment(kInvalidExpressionId, 0)); + EXPECT_EQ(reference, graph); +} + +TEST(ExpressionRef, MoveAssignmentCreate) { + StartRecordingExpressions(); + T a = 2; + T b = std::move(a); + auto graph = StopRecordingExpressions(); + EXPECT_EQ(graph.Size(), 1); + + ExpressionGraph reference; + reference.InsertBack(Expression::CreateCompileTimeConstant(2)); + EXPECT_EQ(reference, graph); +} + +TEST(ExpressionRef, MoveAssignment) { + StartRecordingExpressions(); + T a = 1; + T b = 2; + b = std::move(a); + auto graph = StopRecordingExpressions(); + EXPECT_EQ(graph.Size(), 3); + + ExpressionGraph reference; + reference.InsertBack(Expression::CreateCompileTimeConstant(1)); + reference.InsertBack(Expression::CreateCompileTimeConstant(2)); + reference.InsertBack(Expression::CreateAssignment(1, 0)); + EXPECT_EQ(reference, graph); +} + +TEST(ExpressionRef, BINARY_ARITHMETIC_SIMPLE) { + StartRecordingExpressions(); + T a = T(1); + T b = T(2); + T r1 = a + b; + T r2 = a - b; + T r3 = a * b; + T r4 = a / b; + auto graph = StopRecordingExpressions(); + + ExpressionGraph reference; + reference.InsertBack(Expression::CreateCompileTimeConstant(1)); + reference.InsertBack(Expression::CreateCompileTimeConstant(2)); + reference.InsertBack(Expression::CreateBinaryArithmetic("+", 0, 1)); + reference.InsertBack(Expression::CreateBinaryArithmetic("-", 0, 1)); + reference.InsertBack(Expression::CreateBinaryArithmetic("*", 0, 1)); + reference.InsertBack(Expression::CreateBinaryArithmetic("/", 0, 1)); + EXPECT_EQ(reference, graph); +} + +TEST(ExpressionRef, BINARY_ARITHMETIC_NESTED) { + StartRecordingExpressions(); + T a = T(1); + T b = T(2); + T r1 = b - a * (a + b) / a; + auto graph = StopRecordingExpressions(); + + ExpressionGraph reference; + reference.InsertBack(Expression::CreateCompileTimeConstant(1)); + reference.InsertBack(Expression::CreateCompileTimeConstant(2)); + reference.InsertBack(Expression::CreateBinaryArithmetic("+", 0, 1)); + reference.InsertBack(Expression::CreateBinaryArithmetic("*", 0, 2)); + reference.InsertBack(Expression::CreateBinaryArithmetic("/", 3, 0)); + reference.InsertBack(Expression::CreateBinaryArithmetic("-", 1, 4)); + EXPECT_EQ(reference, graph); +} + +TEST(CodeGenerator, BINARY_ARITHMETIC_COMPOUND) { + // For each binary compound arithmetic operation, two lines are generated: + // - The actual operation assigning to a new temporary variable + // - An assignment from the temporary to the lhs + StartRecordingExpressions(); + T a = T(1); + T b = T(2); + a += b; + a -= b; + a *= b; + a /= b; + auto graph = StopRecordingExpressions(); + + ExpressionGraph reference; + reference.InsertBack(Expression::CreateCompileTimeConstant(1)); + reference.InsertBack(Expression::CreateCompileTimeConstant(2)); + reference.InsertBack(Expression::CreateBinaryArithmetic("+", 0, 1)); + reference.InsertBack(Expression::CreateAssignment(0, 2)); + reference.InsertBack(Expression::CreateBinaryArithmetic("-", 0, 1)); + reference.InsertBack(Expression::CreateAssignment(0, 4)); + reference.InsertBack(Expression::CreateBinaryArithmetic("*", 0, 1)); + reference.InsertBack(Expression::CreateAssignment(0, 6)); + reference.InsertBack(Expression::CreateBinaryArithmetic("/", 0, 1)); + reference.InsertBack(Expression::CreateAssignment(0, 8)); + EXPECT_EQ(reference, graph); +} + +TEST(CodeGenerator, UNARY_ARITHMETIC) { + StartRecordingExpressions(); + T a = T(1); + T r1 = -a; + T r2 = +a; + auto graph = StopRecordingExpressions(); + + ExpressionGraph reference; + reference.InsertBack(Expression::CreateCompileTimeConstant(1)); + reference.InsertBack(Expression::CreateUnaryArithmetic("-", 0)); + reference.InsertBack(Expression::CreateUnaryArithmetic("+", 0)); + EXPECT_EQ(reference, graph); +} + +TEST(CodeGenerator, BINARY_COMPARISON) { + using BOOL = ComparisonExpressionRef; + StartRecordingExpressions(); + T a = T(1); + T b = T(2); + BOOL r1 = a < b; + BOOL r2 = a <= b; + BOOL r3 = a > b; + BOOL r4 = a >= b; + BOOL r5 = a == b; + BOOL r6 = a != b; + auto graph = StopRecordingExpressions(); + + ExpressionGraph reference; + reference.InsertBack(Expression::CreateCompileTimeConstant(1)); + reference.InsertBack(Expression::CreateCompileTimeConstant(2)); + reference.InsertBack(Expression::CreateBinaryCompare("<", 0, 1)); + reference.InsertBack(Expression::CreateBinaryCompare("<=", 0, 1)); + reference.InsertBack(Expression::CreateBinaryCompare(">", 0, 1)); + reference.InsertBack(Expression::CreateBinaryCompare(">=", 0, 1)); + reference.InsertBack(Expression::CreateBinaryCompare("==", 0, 1)); + reference.InsertBack(Expression::CreateBinaryCompare("!=", 0, 1)); + EXPECT_EQ(reference, graph); +} + +TEST(CodeGenerator, LOGICAL_OPERATORS) { + using BOOL = ComparisonExpressionRef; + // Tests binary logical operators &&, || and the unary logical operator ! + StartRecordingExpressions(); + T a = T(1); + T b = T(2); + BOOL r1 = a < b; + BOOL r2 = a <= b; + BOOL r3 = r1 && r2; + BOOL r4 = r1 || r2; + BOOL r5 = !r1; + auto graph = StopRecordingExpressions(); + + ExpressionGraph reference; + reference.InsertBack(Expression::CreateCompileTimeConstant(1)); + reference.InsertBack(Expression::CreateCompileTimeConstant(2)); + reference.InsertBack(Expression::CreateBinaryCompare("<", 0, 1)); + reference.InsertBack(Expression::CreateBinaryCompare("<=", 0, 1)); + reference.InsertBack(Expression::CreateBinaryCompare("&&", 2, 3)); + reference.InsertBack(Expression::CreateBinaryCompare("||", 2, 3)); + reference.InsertBack(Expression::CreateLogicalNegation(2)); + EXPECT_EQ(reference, graph); +} + +TEST(CodeGenerator, FUNCTION_CALL) { + StartRecordingExpressions(); + T a = T(1); + T b = T(2); + abs(a); + acos(a); + asin(a); + atan(a); + cbrt(a); + ceil(a); + cos(a); + cosh(a); + exp(a); + exp2(a); + floor(a); + log(a); + log2(a); + sin(a); + sinh(a); + sqrt(a); + tan(a); + tanh(a); + atan2(a, b); + pow(a, b); + auto graph = StopRecordingExpressions(); + + ExpressionGraph reference; + reference.InsertBack(Expression::CreateCompileTimeConstant(1)); + reference.InsertBack(Expression::CreateCompileTimeConstant(2)); + reference.InsertBack(Expression::CreateFunctionCall("abs", {0})); + reference.InsertBack(Expression::CreateFunctionCall("acos", {0})); + reference.InsertBack(Expression::CreateFunctionCall("asin", {0})); + reference.InsertBack(Expression::CreateFunctionCall("atan", {0})); + reference.InsertBack(Expression::CreateFunctionCall("cbrt", {0})); + reference.InsertBack(Expression::CreateFunctionCall("ceil", {0})); + reference.InsertBack(Expression::CreateFunctionCall("cos", {0})); + reference.InsertBack(Expression::CreateFunctionCall("cosh", {0})); + reference.InsertBack(Expression::CreateFunctionCall("exp", {0})); + reference.InsertBack(Expression::CreateFunctionCall("exp2", {0})); + reference.InsertBack(Expression::CreateFunctionCall("floor", {0})); + reference.InsertBack(Expression::CreateFunctionCall("log", {0})); + reference.InsertBack(Expression::CreateFunctionCall("log2", {0})); + reference.InsertBack(Expression::CreateFunctionCall("sin", {0})); + reference.InsertBack(Expression::CreateFunctionCall("sinh", {0})); + reference.InsertBack(Expression::CreateFunctionCall("sqrt", {0})); + reference.InsertBack(Expression::CreateFunctionCall("tan", {0})); + reference.InsertBack(Expression::CreateFunctionCall("tanh", {0})); + reference.InsertBack(Expression::CreateFunctionCall("atan2", {0, 1})); + reference.InsertBack(Expression::CreateFunctionCall("pow", {0, 1})); + EXPECT_EQ(reference, graph); +} + +TEST(CodeGenerator, IF) { + StartRecordingExpressions(); + T a = T(1); + T b = T(2); + auto r1 = a < b; + CERES_IF(r1) {} + CERES_ENDIF; + auto graph = StopRecordingExpressions(); + + ExpressionGraph reference; + reference.InsertBack(Expression::CreateCompileTimeConstant(1)); + reference.InsertBack(Expression::CreateCompileTimeConstant(2)); + reference.InsertBack(Expression::CreateBinaryCompare("<", 0, 1)); + reference.InsertBack(Expression::CreateIf(2)); + reference.InsertBack(Expression::CreateEndIf()); + EXPECT_EQ(reference, graph); +} + +TEST(CodeGenerator, IF_ELSE) { + StartRecordingExpressions(); + T a = T(1); + T b = T(2); + auto r1 = a < b; + CERES_IF(r1) {} + CERES_ELSE {} + CERES_ENDIF; + auto graph = StopRecordingExpressions(); + + ExpressionGraph reference; + reference.InsertBack(Expression::CreateCompileTimeConstant(1)); + reference.InsertBack(Expression::CreateCompileTimeConstant(2)); + reference.InsertBack(Expression::CreateBinaryCompare("<", 0, 1)); + reference.InsertBack(Expression::CreateIf(2)); + reference.InsertBack(Expression::CreateElse()); + reference.InsertBack(Expression::CreateEndIf()); + EXPECT_EQ(reference, graph); +} + +TEST(CodeGenerator, IF_NESTED) { + StartRecordingExpressions(); + T a = T(1); + T b = T(2); + auto r1 = a < b; + auto r2 = a == b; + CERES_IF(r1) { + CERES_IF(r2) {} + CERES_ENDIF; + } + CERES_ELSE {} + CERES_ENDIF; + auto graph = StopRecordingExpressions(); + + ExpressionGraph reference; + reference.InsertBack(Expression::CreateCompileTimeConstant(1)); + reference.InsertBack(Expression::CreateCompileTimeConstant(2)); + reference.InsertBack(Expression::CreateBinaryCompare("<", 0, 1)); + reference.InsertBack(Expression::CreateBinaryCompare("==", 0, 1)); + reference.InsertBack(Expression::CreateIf(2)); + reference.InsertBack(Expression::CreateIf(3)); + reference.InsertBack(Expression::CreateEndIf()); + reference.InsertBack(Expression::CreateElse()); + reference.InsertBack(Expression::CreateEndIf()); + EXPECT_EQ(reference, graph); +} + +} // namespace internal +} // namespace ceres
diff --git a/internal/ceres/expression_test.cc b/internal/ceres/codegen/expression_test.cc similarity index 63% rename from internal/ceres/expression_test.cc rename to internal/ceres/codegen/expression_test.cc index cab196b..395bbaf 100644 --- a/internal/ceres/expression_test.cc +++ b/internal/ceres/codegen/expression_test.cc
@@ -28,11 +28,10 @@ // // Author: darius.rueckert@fau.de (Darius Rueckert) // -#define CERES_CODEGEN - -#include "ceres/codegen/internal/expression_graph.h" -#include "ceres/codegen/internal/expression_ref.h" -#include "ceres/jet.h" +// This file tests the Expression class. For each member function one test is +// included here. +// +#include "ceres/codegen/internal/expression.h" #include "gtest/gtest.h" namespace ceres { @@ -168,135 +167,61 @@ ASSERT_FALSE(expr1.IsReplaceableBy(expr3)); } +TEST(Expression, Replace) { + auto expr1 = + Expression::CreateBinaryArithmetic("+", ExpressionId(3), ExpressionId(5)); + expr1.set_lhs_id(13); + + auto expr2 = + Expression::CreateAssignment(kInvalidExpressionId, ExpressionId(7)); + + // We replace the arithmetic expr1 by an assignment from the variable 7. This + // is the typical usecase in subexpression elimination. + expr1.Replace(expr2); + + // expr1 should now be an assignment from 7 to 13 + EXPECT_EQ(expr1, Expression(ExpressionType::ASSIGNMENT, 13, {7}, "", 0)); +} + TEST(Expression, DirectlyDependsOn) { - using T = ExpressionRef; + auto expr1 = + Expression::CreateBinaryArithmetic("+", ExpressionId(3), ExpressionId(5)); - StartRecordingExpressions(); + ASSERT_TRUE(expr1.DirectlyDependsOn(ExpressionId(3))); + ASSERT_TRUE(expr1.DirectlyDependsOn(ExpressionId(5))); + ASSERT_FALSE(expr1.DirectlyDependsOn(ExpressionId(kInvalidExpressionId))); + ASSERT_FALSE(expr1.DirectlyDependsOn(ExpressionId(42))); +} +TEST(Expression, MakeNop) { + auto expr1 = + Expression::CreateBinaryArithmetic("+", ExpressionId(3), ExpressionId(5)); - T unused(6); - T a(2), b(3); - T c = a + b; - T d = c + a; + expr1.MakeNop(); - auto graph = StopRecordingExpressions(); - - ASSERT_FALSE(graph.ExpressionForId(a.id).DirectlyDependsOn(unused.id)); - ASSERT_TRUE(graph.ExpressionForId(c.id).DirectlyDependsOn(a.id)); - ASSERT_TRUE(graph.ExpressionForId(c.id).DirectlyDependsOn(b.id)); - ASSERT_TRUE(graph.ExpressionForId(d.id).DirectlyDependsOn(a.id)); - ASSERT_FALSE(graph.ExpressionForId(d.id).DirectlyDependsOn(b.id)); - ASSERT_TRUE(graph.ExpressionForId(d.id).DirectlyDependsOn(c.id)); + EXPECT_EQ(expr1, + Expression(ExpressionType::NOP, kInvalidExpressionId, {}, "", 0)); } -TEST(Expression, Ternary) { - using T = ExpressionRef; +TEST(Expression, IsSemanticallyEquivalentTo) { + // Create 2 identical expression + auto expr1 = + Expression::CreateBinaryArithmetic("+", ExpressionId(3), ExpressionId(5)); - StartRecordingExpressions(); - T a(2); // 0 - T b(3); // 1 - auto c = a < b; // 2 - T d = Ternary(c, a, b); // 3 - MakeOutput(d, "result"); // 4 - auto graph = StopRecordingExpressions(); + auto expr2 = + Expression::CreateBinaryArithmetic("+", ExpressionId(3), ExpressionId(5)); - EXPECT_EQ(graph.Size(), 5); + ASSERT_TRUE(expr1.IsSemanticallyEquivalentTo(expr1)); + ASSERT_TRUE(expr1.IsSemanticallyEquivalentTo(expr2)); - // Expected code - // v_0 = 2; - // v_1 = 3; - // v_2 = v_0 < v_1; - // v_3 = Ternary(v_2, v_0, v_1); - // result = v_3; + auto expr3 = + Expression::CreateBinaryArithmetic("+", ExpressionId(3), ExpressionId(8)); - ExpressionGraph reference; - // clang-format off - // Id, Type, Lhs, Value, Name, Arguments - reference.InsertBack({ ExpressionType::COMPILE_TIME_CONSTANT, 0, {}, "", 2}); - reference.InsertBack({ ExpressionType::COMPILE_TIME_CONSTANT, 1, {}, "", 3}); - reference.InsertBack({ ExpressionType::BINARY_COMPARISON, 2, {0,1}, "<", 0}); - reference.InsertBack({ ExpressionType::FUNCTION_CALL, 3, {2,0,1}, "Ternary", 0}); - reference.InsertBack({ ExpressionType::OUTPUT_ASSIGNMENT, 4, {3}, "result", 0}); - // clang-format on - EXPECT_EQ(reference, graph); -} + ASSERT_TRUE(expr1.IsSemanticallyEquivalentTo(expr3)); -TEST(Expression, Assignment) { - using T = ExpressionRef; + auto expr4 = + Expression::CreateBinaryArithmetic("-", ExpressionId(3), ExpressionId(5)); - StartRecordingExpressions(); - T a = 1; - T b = 2; - b = a; - auto graph = StopRecordingExpressions(); - - EXPECT_EQ(graph.Size(), 3); - - ExpressionGraph reference; - // clang-format off - // Id, Type, Lhs, Value, Name, Arguments - reference.InsertBack({ ExpressionType::COMPILE_TIME_CONSTANT, 0, {}, "", 1}); - reference.InsertBack({ ExpressionType::COMPILE_TIME_CONSTANT, 1, {}, "", 2}); - reference.InsertBack({ ExpressionType::ASSIGNMENT, 1, {0}, "", 0}); - // clang-format on - EXPECT_EQ(reference, graph); -} - -TEST(Expression, AssignmentCreate) { - using T = ExpressionRef; - - StartRecordingExpressions(); - T a = 2; - T b = a; - auto graph = StopRecordingExpressions(); - - EXPECT_EQ(graph.Size(), 2); - - ExpressionGraph reference; - // clang-format off - // Id, Type, Lhs, Value, Name, Arguments - reference.InsertBack({ ExpressionType::COMPILE_TIME_CONSTANT, 0, {}, "", 2}); - reference.InsertBack({ ExpressionType::ASSIGNMENT, 1, {0}, "", 0}); - // clang-format on - EXPECT_EQ(reference, graph); -} - -TEST(Expression, MoveAssignmentCreate) { - using T = ExpressionRef; - - StartRecordingExpressions(); - T a = 1; - T b = std::move(a); - auto graph = StopRecordingExpressions(); - - EXPECT_EQ(graph.Size(), 1); - - ExpressionGraph reference; - // clang-format off - // Id, Type, Lhs, Value, Name, Arguments - reference.InsertBack({ ExpressionType::COMPILE_TIME_CONSTANT, 0, {}, "", 1}); - // clang-format on - EXPECT_EQ(reference, graph); -} - -TEST(Expression, MoveAssignment) { - using T = ExpressionRef; - - StartRecordingExpressions(); - T a = 1; - T b = 2; - b = std::move(a); - auto graph = StopRecordingExpressions(); - - EXPECT_EQ(graph.Size(), 3); - - ExpressionGraph reference; - // clang-format off - // Id, Type, Lhs, Value, Name, Arguments - reference.InsertBack({ ExpressionType::COMPILE_TIME_CONSTANT, 0, {}, "", 1}); - reference.InsertBack({ ExpressionType::COMPILE_TIME_CONSTANT, 1, {}, "", 2}); - reference.InsertBack({ ExpressionType::ASSIGNMENT, 1, {0}, "", 0}); - // clang-format on - EXPECT_EQ(reference, graph); + ASSERT_FALSE(expr1.IsSemanticallyEquivalentTo(expr4)); } } // namespace internal
diff --git a/internal/ceres/conditional_expressions_test.cc b/internal/ceres/conditional_expressions_test.cc deleted file mode 100644 index 9c499ea..0000000 --- a/internal/ceres/conditional_expressions_test.cc +++ /dev/null
@@ -1,132 +0,0 @@ -// Ceres Solver - A fast non-linear least squares minimizer -// Copyright 2019 Google Inc. All rights reserved. -// http://code.google.com/p/ceres-solver/ -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// * Redistributions of source code must retain the above copyright notice, -// this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above copyright notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// * Neither the name of Google Inc. nor the names of its contributors may be -// used to endorse or promote products derived from this software without -// specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -// POSSIBILITY OF SUCH DAMAGE. -// -// Author: darius.rueckert@fau.de (Darius Rueckert) -// - -#define CERES_CODEGEN - -#include "ceres/codegen/internal/expression_graph.h" -#include "ceres/codegen/internal/expression_ref.h" -#include "gtest/gtest.h" - -namespace ceres { -namespace internal { - -TEST(Expression, ConditionalMinimal) { - using T = ExpressionRef; - - StartRecordingExpressions(); - T a(2); - T b(3); - auto c = a < b; - CERES_IF(c) {} - CERES_ELSE {} - CERES_ENDIF - auto graph = StopRecordingExpressions(); - - // Expected code - // v_0 = 2; - // v_1 = 3; - // v_2 = v_0 < v_1; - // if(v_2); - // else - // endif - - EXPECT_EQ(graph.Size(), 6); - - ExpressionGraph reference; - // clang-format off - // Id, Type, Lhs, Value, Name, Arguments... - reference.InsertBack({ExpressionType::COMPILE_TIME_CONSTANT, 0, {} , "", 2}); - reference.InsertBack({ExpressionType::COMPILE_TIME_CONSTANT, 1, {} , "", 3}); - reference.InsertBack({ ExpressionType::BINARY_COMPARISON, 2, {0, 1} , "<", 0}); - reference.InsertBack({ ExpressionType::IF, -1, {2} , "", 0}); - reference.InsertBack({ ExpressionType::ELSE, -1, {} , "", 0}); - reference.InsertBack({ ExpressionType::ENDIF, -1, {} , "", 0}); - // clang-format on - EXPECT_EQ(reference, graph); -} - -TEST(Expression, ConditionalAssignment) { - using T = ExpressionRef; - - StartRecordingExpressions(); - - T result; - T a(2); - T b(3); - auto c = a < b; - CERES_IF(c) { result = a + b; } - CERES_ELSE { result = a - b; } - CERES_ENDIF - result += a; - auto graph = StopRecordingExpressions(); - - // Expected code - // v_0 = 2; - // v_1 = 3; - // v_2 = v_0 < v_1; - // if(v_2); - // v_4 = v_0 + v_1; - // else - // v_6 = v_0 - v_1; - // v_4 = v_6 - // endif - // v_9 = v_4 + v_0; - // v_4 = v_9; - - ExpressionGraph reference; - // clang-format off - // Id, Type, Lhs, Value, Name, Arguments... - reference.InsertBack({ExpressionType::COMPILE_TIME_CONSTANT, 0, {} , "", 2}); - reference.InsertBack({ExpressionType::COMPILE_TIME_CONSTANT, 1, {} , "", 3}); - reference.InsertBack({ ExpressionType::BINARY_COMPARISON, 2, {0, 1}, "<", 0}); - reference.InsertBack({ ExpressionType::IF, -1, {2} , "", 0}); - reference.InsertBack({ ExpressionType::BINARY_ARITHMETIC, 4, {0, 1}, "+", 0}); - reference.InsertBack({ ExpressionType::ELSE, -1, {} , "", 0}); - reference.InsertBack({ ExpressionType::BINARY_ARITHMETIC, 6, {0, 1}, "-", 0}); - reference.InsertBack({ ExpressionType::ASSIGNMENT, 4, {6} , "", 0}); - reference.InsertBack({ ExpressionType::ENDIF, -1, {} , "", 0}); - reference.InsertBack({ ExpressionType::BINARY_ARITHMETIC, 9, {4, 0}, "+", 0}); - reference.InsertBack({ ExpressionType::ASSIGNMENT, 4, {9} , "", 0}); - // clang-format on - EXPECT_EQ(reference, graph); - - // Variables after execution: - // - // a <=> v_0 - // b <=> v_1 - // result <=> v_4 - EXPECT_EQ(a.id, 0); - EXPECT_EQ(b.id, 1); - EXPECT_EQ(result.id, 4); -} - -} // namespace internal -} // namespace ceres
diff --git a/internal/ceres/expression.cc b/internal/ceres/expression.cc index ca92eaf..27e15cb 100644 --- a/internal/ceres/expression.cc +++ b/internal/ceres/expression.cc
@@ -142,8 +142,8 @@ } void Expression::MakeNop() { - type_ = ExpressionType::NOP; - arguments_.clear(); + // The default constructor creates a NOP expression! + *this = Expression(); } bool Expression::operator==(const Expression& other) const {
diff --git a/internal/ceres/expression_graph.cc b/internal/ceres/expression_graph.cc index c0761af..d9b33e9 100644 --- a/internal/ceres/expression_graph.cc +++ b/internal/ceres/expression_graph.cc
@@ -108,8 +108,17 @@ expressions_[id + 1] = expression; } - // Insert new expression at the correct place - expressions_[location] = expression; + if (expression.IsControlExpression() || + expression.lhs_id() != kInvalidExpressionId) { + // Insert new expression at the correct place + expressions_[location] = expression; + } else { + // Arithmetic expression with invalid lhs + // -> Set lhs to location + Expression copy = expression; + copy.set_lhs_id(location); + expressions_[location] = copy; + } } ExpressionId ExpressionGraph::InsertBack(const Expression& expression) {
diff --git a/internal/ceres/expression_graph_test.cc b/internal/ceres/expression_graph_test.cc deleted file mode 100644 index 65b019c..0000000 --- a/internal/ceres/expression_graph_test.cc +++ /dev/null
@@ -1,160 +0,0 @@ -// Ceres Solver - A fast non-linear least squares minimizer -// Copyright 2019 Google Inc. All rights reserved. -// http://code.google.com/p/ceres-solver/ -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// * Redistributions of source code must retain the above copyright notice, -// this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above copyright notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// * Neither the name of Google Inc. nor the names of its contributors may be -// used to endorse or promote products derived from this software without -// specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -// POSSIBILITY OF SUCH DAMAGE. -// -// Author: darius.rueckert@fau.de (Darius Rueckert) -// -// Test expression creation and logic. - -#include "ceres/codegen/internal/expression_graph.h" -#include "ceres/codegen/internal/expression_ref.h" - -#include "gtest/gtest.h" - -namespace ceres { -namespace internal { - -TEST(ExpressionGraph, Creation) { - using T = ExpressionRef; - ExpressionGraph graph; - - StartRecordingExpressions(); - graph = StopRecordingExpressions(); - EXPECT_EQ(graph.Size(), 0); - - StartRecordingExpressions(); - T a(1); - T b(2); - T c = a + b; - graph = StopRecordingExpressions(); - EXPECT_EQ(graph.Size(), 3); -} - -TEST(ExpressionGraph, Dependencies) { - using T = ExpressionRef; - - StartRecordingExpressions(); - - T unused(6); - T a(2), b(3); - T c = a + b; - T d = c + a; - - auto tree = StopRecordingExpressions(); - - // Recursive dependency check - ASSERT_TRUE(tree.DependsOn(d.id, c.id)); - ASSERT_TRUE(tree.DependsOn(d.id, a.id)); - ASSERT_TRUE(tree.DependsOn(d.id, b.id)); - ASSERT_FALSE(tree.DependsOn(d.id, unused.id)); -} - -TEST(ExpressionGraph, InsertExpression_UpdateReferences) { - // This test checks if references to shifted expressions are updated - // accordingly. - using T = ExpressionRef; - StartRecordingExpressions(); - T a(2); // 0 - T b(3); // 1 - T c = a + b; // 2 - auto graph = StopRecordingExpressions(); - - // Test if 'a' and 'c' are actually at location 0 and 2 - auto& a_expr = graph.ExpressionForId(0); - EXPECT_EQ(a_expr.type(), ExpressionType::COMPILE_TIME_CONSTANT); - EXPECT_EQ(a_expr.value(), 2); - - // At this point 'c' should have 0 and 1 as arguments. - auto& c_expr = graph.ExpressionForId(2); - EXPECT_EQ(c_expr.type(), ExpressionType::BINARY_ARITHMETIC); - EXPECT_EQ(c_expr.arguments()[0], 0); - EXPECT_EQ(c_expr.arguments()[1], 1); - - // We insert at the beginning, which shifts everything by one spot. - graph.Insert(0, {ExpressionType::COMPILE_TIME_CONSTANT, 0, {}, "", 10.2}); - - // Test if 'a' and 'c' are actually at location 1 and 3 - auto& a_expr2 = graph.ExpressionForId(1); - EXPECT_EQ(a_expr2.type(), ExpressionType::COMPILE_TIME_CONSTANT); - EXPECT_EQ(a_expr2.value(), 2); - - // At this point 'c' should have 1 and 2 as arguments. - auto& c_expr2 = graph.ExpressionForId(3); - EXPECT_EQ(c_expr2.type(), ExpressionType::BINARY_ARITHMETIC); - EXPECT_EQ(c_expr2.arguments()[0], 1); - EXPECT_EQ(c_expr2.arguments()[1], 2); -} - -TEST(ExpressionGraph, InsertExpression) { - using T = ExpressionRef; - - StartRecordingExpressions(); - - { - T a(2); // 0 - T b(3); // 1 - T five = 5; // 2 - T tmp = a + five; // 3 - a = tmp; // 4 - T c = a + b; // 5 - T d = a * b; // 6 - T e = c + d; // 7 - MakeOutput(e, "result"); // 8 - } - auto reference = StopRecordingExpressions(); - EXPECT_EQ(reference.Size(), 9); - - StartRecordingExpressions(); - - { - // The expressions 2,3,4 from above are missing. - T a(2); // 0 - T b(3); // 1 - T c = a + b; // 2 - T d = a * b; // 3 - T e = c + d; // 4 - MakeOutput(e, "result"); // 5 - } - - auto graph1 = StopRecordingExpressions(); - EXPECT_EQ(graph1.Size(), 6); - ASSERT_FALSE(reference == graph1); - - // We manually insert the 3 missing expressions - // clang-format off - graph1.Insert(2,{ ExpressionType::COMPILE_TIME_CONSTANT, 2, {}, "", 5}); - graph1.Insert(3,{ ExpressionType::BINARY_ARITHMETIC, 3, {0, 2}, "+", 0}); - graph1.Insert(4,{ ExpressionType::ASSIGNMENT, 0, {3}, "", 0}); - // clang-format on - - // Now the graphs are identical! - EXPECT_EQ(graph1.Size(), 9); - ASSERT_TRUE(reference == graph1); -} - -} // namespace internal -} // namespace ceres