Remove AutodiffCodegen Tests Change-Id: Icd194db7b22add518844f1b507d0fdd3e0fe17fe
diff --git a/internal/ceres/CMakeLists.txt b/internal/ceres/CMakeLists.txt index ea67cf2..95579da 100644 --- a/internal/ceres/CMakeLists.txt +++ b/internal/ceres/CMakeLists.txt
@@ -370,8 +370,7 @@ add_library(test_util evaluator_test_utils.cc numeric_diff_test_utils.cc - test_util.cc - codegen/test_utils.cc) + test_util.cc) target_include_directories(test_util PUBLIC ${Ceres_SOURCE_DIR}/internal) if (MINIGLOG) @@ -527,6 +526,3 @@ add_subdirectory(autodiff_benchmarks) endif (BUILD_BENCHMARKS) -if(CODE_GENERATION) - add_subdirectory(codegen) -endif()
diff --git a/internal/ceres/autodiff_benchmarks/autodiff_benchmarks.cc b/internal/ceres/autodiff_benchmarks/autodiff_benchmarks.cc index cf2d558..3d1699b 100644 --- a/internal/ceres/autodiff_benchmarks/autodiff_benchmarks.cc +++ b/internal/ceres/autodiff_benchmarks/autodiff_benchmarks.cc
@@ -39,10 +39,33 @@ #include "ceres/autodiff_benchmarks/relative_pose_error.h" #include "ceres/autodiff_benchmarks/snavely_reprojection_error.h" #include "ceres/ceres.h" -#include "ceres/codegen/test_utils.h" namespace ceres { +namespace internal { + +// If we want to use functors with both operator() and an Evaluate() method +// with AutoDiff then this wrapper class here has to be used. Autodiff doesn't +// support functors that have an Evaluate() function. +// +// CostFunctionToFunctor hides the Evaluate() function, because it doesn't +// derive from CostFunction. Autodiff sees it as a simple functor and will use +// the operator() as expected. +template <typename CostFunction> +struct CostFunctionToFunctor { + template <typename... _Args> + CostFunctionToFunctor(_Args&&... __args) + : cost_function(std::forward<_Args>(__args)...) {} + + template <typename... _Args> + bool operator()(_Args&&... __args) const { + return cost_function(std::forward<_Args>(__args)...); + } + + CostFunction cost_function; +}; +} + template <int kParameterBlockSize> static void BM_ConstantAnalytic(benchmark::State& state) { constexpr int num_residuals = 1;
diff --git a/internal/ceres/codegen/CMakeLists.txt b/internal/ceres/codegen/CMakeLists.txt deleted file mode 100644 index 4bea21f..0000000 --- a/internal/ceres/codegen/CMakeLists.txt +++ /dev/null
@@ -1,90 +0,0 @@ -# Ceres Solver - A fast non-linear least squares minimizer -# Copyright 2019 Google Inc. All rights reserved. -# http://ceres-solver.org/ -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# * Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# * Neither the name of Google Inc. nor the names of its contributors may be -# used to endorse or promote products derived from this software without -# specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -# POSSIBILITY OF SUCH DAMAGE. -# -# Author: darius.rueckert@fau.de (Darius Rueckert) -# -# Testing the AutoDiffCodegen system is more complicated, because function- and -# constructor calls have side-effects. In C++ the evaluation order and -# the elision of copies is implementation defined. Between different compilers, -# some expression might be evaluated in a different order or some copies might be -# removed. -# -# Therefore, we run tests that expect a particular compiler behaviour only on gcc. -# -# The semantic tests, which check the correctness by executing the generated code -# are still run on all platforms. -if (BUILD_TESTING AND GFLAGS) - if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - set(CXX_FLAGS_OLD ${CMAKE_CXX_FLAGS}) - # From the man page: - # The C++ standard allows an implementation to omit creating a - # temporary which is only used to initialize another object of the - # same type. Specifying -fno-elide-constructors disables that - # optimization, and forces G++ to call the copy constructor in all cases. - # We use this flag to get the same results between different versions of - # gcc and different optimization levels. Without this flag, testing would - # be very painfull and might break when a new compiler version is released. - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-elide-constructors") - ceres_test(expression_ref) - ceres_test(code_generator) - ceres_test(eliminate_nops) - set(CMAKE_CXX_FLAGS ${CXX_FLAGS_OLD}) - endif() - - ceres_test(expression) - ceres_test(expression_graph) - - macro (generate_test_functor FUNCTOR_NAME FUNCTOR_FILE) - ceres_generate_cost_function_implementation_for_functor( - NAME ${FUNCTOR_NAME} - INPUT_FILE ${FUNCTOR_FILE} - OUTPUT_DIRECTORY tests - NAMESPACE test - ) - endmacro() - - # Semantic tests should work on every platform - include(CeresCodeGeneration) - - - generate_test_functor(InputOutputAssignment autodiff_codegen_test.h) - generate_test_functor(CompileTimeConstants autodiff_codegen_test.h) - generate_test_functor(Assignments autodiff_codegen_test.h) - generate_test_functor(BinaryArithmetic autodiff_codegen_test.h) - generate_test_functor(UnaryArithmetic autodiff_codegen_test.h) - generate_test_functor(BinaryComparison autodiff_codegen_test.h) - generate_test_functor(LogicalOperators autodiff_codegen_test.h) - generate_test_functor(ScalarFunctions autodiff_codegen_test.h) - generate_test_functor(LogicalFunctions autodiff_codegen_test.h) - generate_test_functor(Branches autodiff_codegen_test.h) - ceres_test(autodiff_codegen) - target_link_libraries(autodiff_codegen_test PUBLIC - InputOutputAssignment CompileTimeConstants Assignments BinaryArithmetic - UnaryArithmetic BinaryComparison LogicalOperators ScalarFunctions - LogicalFunctions Branches) -endif()
diff --git a/internal/ceres/codegen/autodiff_codegen_test.cc b/internal/ceres/codegen/autodiff_codegen_test.cc deleted file mode 100644 index c6c5cbc..0000000 --- a/internal/ceres/codegen/autodiff_codegen_test.cc +++ /dev/null
@@ -1,117 +0,0 @@ -// Ceres Solver - A fast non-linear least squares minimizer -// Copyright 2019 Google Inc. All rights reserved. -// http://code.google.com/p/ceres-solver/ -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// * Redistributions of source code must retain the above copyright notice, -// this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above copyright notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// * Neither the name of Google Inc. nor the names of its contributors may be -// used to endorse or promote products derived from this software without -// specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -// POSSIBILITY OF SUCH DAMAGE. -// -// Author: darius.rueckert@fau.de (Darius Rueckert) -// -// This file tests the Expression class. For each member function one test is -// included here. -// -#include "autodiff_codegen_test.h" - -#include "ceres/autodiff_cost_function.h" -#include "codegen/test_utils.h" -#include "gtest/gtest.h" - -namespace ceres { -namespace internal { - -class AutoDiffCodegenTest : public testing::TestWithParam<double> { - public: - template <typename CostFunctionType, int kNumResiduals, int... Ns> - void TestCostFunction() { - using CostFunctorType = CostFunctionToFunctor<CostFunctionType>; - CostFunctionType generated_cost_function; - CostFunctorType cost_functor; - auto* cost_function_ad = - new AutoDiffCostFunction<CostFunctorType, kNumResiduals, Ns...>( - &cost_functor); - auto value = GetParam(); - CompareCostFunctions(&generated_cost_function, - cost_function_ad, - value, - kRelativeErrorThreshold); - } - static constexpr double kRelativeErrorThreshold = 0; -}; - -TEST_P(AutoDiffCodegenTest, InputOutputAssignment) { - TestCostFunction<test::InputOutputAssignment, 7, 4, 2, 1>(); -} - -TEST_P(AutoDiffCodegenTest, CompileTimeConstants) { - TestCostFunction<test::CompileTimeConstants, 7, 1>(); -} - -TEST_P(AutoDiffCodegenTest, Assignments) { - TestCostFunction<test::Assignments, 8, 2>(); -} - -TEST_P(AutoDiffCodegenTest, BinaryArithmetic) { - TestCostFunction<test::BinaryArithmetic, 9, 2>(); -} - -TEST_P(AutoDiffCodegenTest, UnaryArithmetic) { - TestCostFunction<test::UnaryArithmetic, 3, 1>(); -} - -TEST_P(AutoDiffCodegenTest, BinaryComparison) { - TestCostFunction<test::BinaryComparison, 12, 2>(); -} - -TEST_P(AutoDiffCodegenTest, LogicalOperators) { - TestCostFunction<test::LogicalOperators, 8, 3>(); -} - -TEST_P(AutoDiffCodegenTest, ScalarFunctions) { - TestCostFunction<test::ScalarFunctions, 20, 22>(); -} - -TEST_P(AutoDiffCodegenTest, LogicalFunctions) { - TestCostFunction<test::LogicalFunctions, 4, 4>(); -} - -TEST_P(AutoDiffCodegenTest, Branches) { - TestCostFunction<test::Branches, 4, 3>(); -} - -INSTANTIATE_TEST_SUITE_P( - AutoDiffCodegenTest, - AutoDiffCodegenTest, - testing::Values(0, - -1, - 1, - 0.5, - -0.5, - 10, - -10, - 1e20, - 1e-20, - std::numeric_limits<double>::infinity(), - std::numeric_limits<double>::quiet_NaN())); -} // namespace internal -} // namespace ceres
diff --git a/internal/ceres/codegen/autodiff_codegen_test.h b/internal/ceres/codegen/autodiff_codegen_test.h deleted file mode 100644 index f5a9f44..0000000 --- a/internal/ceres/codegen/autodiff_codegen_test.h +++ /dev/null
@@ -1,358 +0,0 @@ -// Ceres Solver - A fast non-linear least squares minimizer -// Copyright 2020 Google Inc. All rights reserved. -// http://code.google.com/p/ceres-solver/ -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// * Redistributions of source code must retain the above copyright notice, -// this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above copyright notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// * Neither the name of Google Inc. nor the names of its contributors may be -// used to endorse or promote products derived from this software without -// specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -// POSSIBILITY OF SUCH DAMAGE. -// -// Author: darius.rueckert@fau.de (Darius Rueckert) -// -// This file includes unit test functors for every supported expression type. -// This is similar to expression_ref_test and codegeneration_test, but for the -// complete pipeline including automatic differentation. For each of the structs -// below, the Evaluate function is generated using GenerateCodeForFunctor. After -// that this function is executed with random parameters. The result of the -// residuals and jacobians is then compared to AutoDiff (without code -// generation). Of course, the correctness of this module depends on the -// correctness of autodiff. -// -#include <cmath> -#include <limits> - -#include "ceres/codegen/codegen_cost_function.h" -namespace test { - -struct InputOutputAssignment : public ceres::CodegenCostFunction<7, 4, 2, 1> { - template <typename T> - bool operator()(const T* x0, const T* x1, const T* x2, T* y) const { - y[0] = x0[0]; - y[1] = x0[1]; - y[2] = x0[2]; - y[3] = x0[3]; - - y[4] = x1[0]; - y[5] = x1[1]; - - y[6] = x2[0]; - return true; - } -#include "tests/inputoutputassignment.h" -}; - -struct CompileTimeConstants : public ceres::CodegenCostFunction<7, 1> { - template <typename T> - bool operator()(const T* x0, T* y) const { - y[0] = T(0); - y[1] = T(1); - y[2] = T(-1); - y[3] = T(1e-10); - y[4] = T(1e10); - y[5] = T(std::numeric_limits<double>::infinity()); - y[6] = T(std::numeric_limits<double>::quiet_NaN()); - - return true; - } -#include "tests/compiletimeconstants.h" -}; - -struct Assignments : public ceres::CodegenCostFunction<8, 2> { - template <typename T> - bool operator()(const T* x0, T* y) const { - T a = x0[0]; - T b = x0[1]; - y[0] = a; - y[1] = b; - y[2] = y[3] = a; - - T c = a; - y[4] = c; - - T d(b); - y[5] = d; - - y[6] = std::move(c); - - y[7] = std::move(T(T(std::move(T(a))))); - return true; - } -#include "tests/assignments.h" -}; - -struct BinaryArithmetic : public ceres::CodegenCostFunction<9, 2> { - template <typename T> - bool operator()(const T* x0, T* y) const { - T a = x0[0]; - T b = x0[1]; - y[0] = a + b; - y[1] = a - b; - y[2] = a * b; - y[3] = a / b; - - y[4] = a; - y[4] += b; - y[5] = a; - y[5] -= b; - y[6] = a; - y[6] *= b; - y[7] = a; - y[7] /= b; - - y[8] = a + b * a / a - b + b / a; - return true; - } -#include "tests/binaryarithmetic.h" -}; - -struct UnaryArithmetic : public ceres::CodegenCostFunction<3, 1> { - template <typename T> - bool operator()(const T* x0, T* y) const { - T a = x0[0]; - y[0] = -a; - y[1] = +a; - y[2] = a; - return true; - } -#include "tests/unaryarithmetic.h" -}; - -struct BinaryComparison : public ceres::CodegenCostFunction<12, 2> { - template <typename T> - bool operator()(const T* x0, T* y) const { - T a = x0[0]; - T b = x0[1]; - - // For each operator we swap the inputs so both branches are evaluated once. - CERES_IF(a < b) { y[0] = T(0); } - CERES_ELSE { y[0] = T(1); } - CERES_ENDIF - CERES_IF(b < a) { y[1] = T(0); } - CERES_ELSE { y[1] = T(1); } - CERES_ENDIF - - CERES_IF(a > b) { y[2] = T(0); } - CERES_ELSE { y[2] = T(1); } - CERES_ENDIF - CERES_IF(b > a) { y[3] = T(0); } - CERES_ELSE { y[3] = T(1); } - CERES_ENDIF - - CERES_IF(a <= b) { y[4] = T(0); } - CERES_ELSE { y[4] = T(1); } - CERES_ENDIF - CERES_IF(b <= a) { y[5] = T(0); } - CERES_ELSE { y[5] = T(1); } - CERES_ENDIF - - CERES_IF(a >= b) { y[6] = T(0); } - CERES_ELSE { y[6] = T(1); } - CERES_ENDIF - CERES_IF(b >= a) { y[7] = T(0); } - CERES_ELSE { y[7] = T(1); } - CERES_ENDIF - - CERES_IF(a == b) { y[8] = T(0); } - CERES_ELSE { y[8] = T(1); } - CERES_ENDIF - CERES_IF(b == a) { y[9] = T(0); } - CERES_ELSE { y[9] = T(1); } - CERES_ENDIF - - CERES_IF(a != b) { y[10] = T(0); } - CERES_ELSE { y[10] = T(1); } - CERES_ENDIF - CERES_IF(b != a) { y[11] = T(0); } - CERES_ELSE { y[11] = T(1); } - CERES_ENDIF - - return true; - } -#include "tests/binarycomparison.h" -}; - -struct LogicalOperators : public ceres::CodegenCostFunction<8, 3> { - template <typename T> - bool operator()(const T* x0, T* y) const { - T a = x0[0]; - T b = x0[1]; - T c = x0[2]; - auto r1 = a < b; - auto r2 = a < c; - - CERES_IF(r1) { y[0] = T(0); } - CERES_ELSE { y[0] = T(1); } - CERES_ENDIF - CERES_IF(r2) { y[1] = T(0); } - CERES_ELSE { y[1] = T(1); } - CERES_ENDIF - CERES_IF(!r1) { y[2] = T(0); } - CERES_ELSE { y[2] = T(1); } - CERES_ENDIF - CERES_IF(!r2) { y[3] = T(0); } - CERES_ELSE { y[3] = T(1); } - CERES_ENDIF - - CERES_IF(r1 && r2) { y[4] = T(0); } - CERES_ELSE { y[4] = T(1); } - CERES_ENDIF - CERES_IF(!r1 && !r2) { y[5] = T(0); } - CERES_ELSE { y[5] = T(1); } - CERES_ENDIF - - CERES_IF(r1 || r2) { y[6] = T(0); } - CERES_ELSE { y[6] = T(1); } - CERES_ENDIF - CERES_IF(!r1 || !r2) { y[7] = T(0); } - CERES_ELSE { y[7] = T(1); } - CERES_ENDIF - - return true; - } -#include "tests/logicaloperators.h" -}; - -struct ScalarFunctions : public ceres::CodegenCostFunction<20, 22> { - template <typename T> - bool operator()(const T* x0, T* y) const { - y[0] = abs(x0[0]); - y[1] = acos(x0[1]); - y[2] = asin(x0[2]); - y[3] = atan(x0[3]); - y[4] = cbrt(x0[4]); - y[5] = ceil(x0[5]); - y[6] = cos(x0[6]); - y[7] = cosh(x0[7]); - y[8] = exp(x0[8]); - y[9] = exp2(x0[9]); - y[10] = floor(x0[10]); - y[11] = log(x0[11]); - y[12] = log2(x0[12]); - y[13] = sin(x0[13]); - y[14] = sinh(x0[14]); - y[15] = sqrt(x0[15]); - y[16] = tan(x0[16]); - y[17] = tanh(x0[17]); - y[18] = atan2(x0[18], x0[19]); - y[19] = pow(x0[20], x0[21]); - return true; - } -#include "tests/scalarfunctions.h" -}; - -struct LogicalFunctions : public ceres::CodegenCostFunction<4, 4> { - template <typename T> - bool operator()(const T* x0, T* y) const { - using std::isfinite; - using std::isinf; - using std::isnan; - using std::isnormal; - T a = x0[0]; - auto r1 = isfinite(a); - auto r2 = isinf(a); - auto r3 = isnan(a); - auto r4 = isnormal(a); - - CERES_IF(r1) { y[0] = T(0); } - CERES_ELSE { y[0] = T(1); } - CERES_ENDIF - CERES_IF(r2) { y[1] = T(0); } - CERES_ELSE { y[1] = T(1); } - CERES_ENDIF - CERES_IF(r3) { y[2] = T(0); } - CERES_ELSE { y[2] = T(1); } - CERES_ENDIF - CERES_IF(r4) { y[3] = T(0); } - CERES_ELSE { y[3] = T(1); } - CERES_ENDIF - - return true; - } -#include "tests/logicalfunctions.h" -}; - -struct Branches : public ceres::CodegenCostFunction<4, 3> { - template <typename T> - bool operator()(const T* x0, T* y) const { - T a = x0[0]; - T b = x0[1]; - T c = x0[2]; - auto r1 = a < b; - auto r2 = a < c; - auto r3 = b < c; - - // If without else - y[0] = T(0); - CERES_IF(r1) { y[0] += T(1); } - CERES_ENDIF - - // If else - y[1] = T(0); - CERES_IF(r1) { y[1] += T(-1); } - CERES_ELSE { y[1] += T(1); } - CERES_ENDIF - - // Nested if - y[2] = T(0); - CERES_IF(r1) { - y[2] += T(1); - CERES_IF(r2) { - y[2] += T(4); - CERES_IF(r2) { y[2] += T(8); } - CERES_ENDIF - } - CERES_ENDIF - } - CERES_ENDIF - - // Nested if-else - y[3] = T(0); - CERES_IF(r1) { - y[3] += T(1); - CERES_IF(r2) { - y[3] += T(2); - CERES_IF(r3) { y[3] += T(4); } - CERES_ELSE { y[3] += T(8); } - CERES_ENDIF - } - CERES_ELSE { - y[3] += T(16); - CERES_IF(r3) { y[3] += T(32); } - CERES_ELSE { y[3] += T(64); } - CERES_ENDIF - } - CERES_ENDIF - } - CERES_ELSE { - y[3] += T(128); - CERES_IF(r2) { y[3] += T(256); } - CERES_ELSE { y[3] += T(512); } - CERES_ENDIF - } - CERES_ENDIF - - return true; - } -#include "tests/branches.h" -}; - -} // namespace test
diff --git a/internal/ceres/codegen/code_generator_test.cc b/internal/ceres/codegen/code_generator_test.cc deleted file mode 100644 index f799a75..0000000 --- a/internal/ceres/codegen/code_generator_test.cc +++ /dev/null
@@ -1,535 +0,0 @@ -// Ceres Solver - A fast non-linear least squares minimizer -// Copyright 2019 Google Inc. All rights reserved. -// http://code.google.com/p/ceres-solver/ -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// * Redistributions of source code must retain the above copyright notice, -// this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above copyright notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// * Neither the name of Google Inc. nor the names of its contributors may be -// used to endorse or promote products derived from this software without -// specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -// POSSIBILITY OF SUCH DAMAGE. -// -// Author: darius.rueckert@fau.de (Darius Rueckert) -// -#define CERES_CODEGEN - -#include "ceres/codegen/internal/code_generator.h" - -#include "ceres/codegen/internal/expression_graph.h" -#include "ceres/codegen/internal/expression_ref.h" -#include "gtest/gtest.h" - -namespace ceres { -namespace internal { - -static void GenerateAndCheck(const ExpressionGraph& graph, - const std::vector<std::string>& reference) { - CodeGenerator::Options generator_options; - CodeGenerator gen(graph, generator_options); - auto code = gen.Generate(); - EXPECT_EQ(code.size(), reference.size()); - - for (int i = 0; i < std::min(code.size(), reference.size()); ++i) { - EXPECT_EQ(code[i], reference[i]) << "Invalid Line: " << (i + 1); - } -} - -using T = ExpressionRef; - -TEST(CodeGenerator, Empty) { - StartRecordingExpressions(); - auto graph = StopRecordingExpressions(); - std::vector<std::string> expected_code = {"{", "}"}; - GenerateAndCheck(graph, expected_code); -} - -// Now we add one TEST for each ExpressionType. -TEST(CodeGenerator, COMPILE_TIME_CONSTANT) { - StartRecordingExpressions(); - T a = T(0); - T b = T(123.5); - T c = T(1 + 1); - T d = T(std::numeric_limits<double>::infinity()); - T e = T(-std::numeric_limits<double>::infinity()); - T f = T(std::numeric_limits<double>::quiet_NaN()); - T g; // Uninitialized variables are 0 initialized. - auto graph = StopRecordingExpressions(); - std::vector<std::string> expected_code = { - "{", - " double v_0;", - " double v_1;", - " double v_2;", - " double v_3;", - " double v_4;", - " double v_5;", - " double v_6;", - " v_0 = 0;", - " v_1 = 123.5;", - " v_2 = 2;", - " v_3 = std::numeric_limits<double>::infinity();", - " v_4 = -std::numeric_limits<double>::infinity();", - " v_5 = std::numeric_limits<double>::quiet_NaN();", - " v_6 = 0;", - "}"}; - GenerateAndCheck(graph, expected_code); -} - -TEST(CodeGenerator, INPUT_ASSIGNMENT) { - double local_variable = 5.0; - StartRecordingExpressions(); - T a = CERES_LOCAL_VARIABLE(T, local_variable); - T b = MakeParameter("parameters[0][0]"); - T c = a + b; - auto graph = StopRecordingExpressions(); - std::vector<std::string> expected_code = {"{", - " double v_0;", - " double v_1;", - " double v_2;", - " v_0 = local_variable;", - " v_1 = parameters[0][0];", - " v_2 = v_0 + v_1;", - "}"}; - GenerateAndCheck(graph, expected_code); -} - -TEST(CodeGenerator, OUTPUT_ASSIGNMENT) { - StartRecordingExpressions(); - T a = 1; - T b = 0; - MakeOutput(a, "residual[0]"); - auto graph = StopRecordingExpressions(); - std::vector<std::string> expected_code = {"{", - " double v_0;", - " double v_1;", - " double v_2;", - " v_0 = 1;", - " v_1 = 0;", - " residual[0] = v_0;", - "}"}; - GenerateAndCheck(graph, expected_code); -} - -TEST(CodeGenerator, ASSIGNMENT) { - StartRecordingExpressions(); - T a = T(0); // 0 - T b = T(1); // 1 - T c = a; // 2 - a = b; // 3 - a = a + b; // 4 + 5 - auto graph = StopRecordingExpressions(); - std::vector<std::string> expected_code = {"{", - " double v_0;", - " double v_1;", - " double v_2;", - " double v_4;", - " v_0 = 0;", - " v_1 = 1;", - " v_2 = v_0;", - " v_0 = v_1;", - " v_4 = v_0 + v_1;", - " v_0 = v_4;", - "}"}; - GenerateAndCheck(graph, expected_code); -} - -TEST(CodeGenerator, BINARY_ARITHMETIC_SIMPLE) { - StartRecordingExpressions(); - T a = T(1); - T b = T(2); - T r1 = a + b; - T r2 = a - b; - T r3 = a * b; - T r4 = a / b; - auto graph = StopRecordingExpressions(); - std::vector<std::string> expected_code = {"{", - " double v_0;", - " double v_1;", - " double v_2;", - " double v_3;", - " double v_4;", - " double v_5;", - " v_0 = 1;", - " v_1 = 2;", - " v_2 = v_0 + v_1;", - " v_3 = v_0 - v_1;", - " v_4 = v_0 * v_1;", - " v_5 = v_0 / v_1;", - "}"}; - GenerateAndCheck(graph, expected_code); -} - -TEST(CodeGenerator, BINARY_ARITHMETIC_NESTED) { - StartRecordingExpressions(); - T a = T(1); - T b = T(2); - T r1 = b - a * (a + b) / a; - auto graph = StopRecordingExpressions(); - std::vector<std::string> expected_code = {"{", - " double v_0;", - " double v_1;", - " double v_2;", - " double v_3;", - " double v_4;", - " double v_5;", - " v_0 = 1;", - " v_1 = 2;", - " v_2 = v_0 + v_1;", - " v_3 = v_0 * v_2;", - " v_4 = v_3 / v_0;", - " v_5 = v_1 - v_4;", - "}"}; - GenerateAndCheck(graph, expected_code); -} - -TEST(CodeGenerator, BINARY_ARITHMETIC_COMPOUND) { - // For each binary compound arithmetic operation, two lines are generated: - // - The actual operation assigning to a new temporary variable - // - An assignment from the temporary to the lhs - StartRecordingExpressions(); - T a = T(1); - T b = T(2); - a += b; - a -= b; - a *= b; - a /= b; - auto graph = StopRecordingExpressions(); - std::vector<std::string> expected_code = {"{", - " double v_0;", - " double v_1;", - " double v_2;", - " double v_4;", - " double v_6;", - " double v_8;", - " v_0 = 1;", - " v_1 = 2;", - " v_2 = v_0 + v_1;", - " v_0 = v_2;", - " v_4 = v_0 - v_1;", - " v_0 = v_4;", - " v_6 = v_0 * v_1;", - " v_0 = v_6;", - " v_8 = v_0 / v_1;", - " v_0 = v_8;", - "}"}; - GenerateAndCheck(graph, expected_code); -} - -TEST(CodeGenerator, UNARY_ARITHMETIC) { - StartRecordingExpressions(); - T a = T(0); - T r1 = -a; - T r2 = +a; - auto graph = StopRecordingExpressions(); - std::vector<std::string> expected_code = {"{", - " double v_0;", - " double v_1;", - " double v_2;", - " v_0 = 0;", - " v_1 = -v_0;", - " v_2 = +v_0;", - "}"}; - GenerateAndCheck(graph, expected_code); -} - -TEST(CodeGenerator, BINARY_COMPARISON) { - StartRecordingExpressions(); - T a = T(0); - T b = T(1); - auto r1 = a < b; - auto r2 = a <= b; - auto r3 = a > b; - auto r4 = a >= b; - auto r5 = a == b; - auto r6 = a != b; - auto graph = StopRecordingExpressions(); - std::vector<std::string> expected_code = {"{", - " double v_0;", - " double v_1;", - " bool v_2;", - " bool v_3;", - " bool v_4;", - " bool v_5;", - " bool v_6;", - " bool v_7;", - " v_0 = 0;", - " v_1 = 1;", - " v_2 = v_0 < v_1;", - " v_3 = v_0 <= v_1;", - " v_4 = v_0 > v_1;", - " v_5 = v_0 >= v_1;", - " v_6 = v_0 == v_1;", - " v_7 = v_0 != v_1;", - "}"}; - GenerateAndCheck(graph, expected_code); -} - -TEST(CodeGenerator, LOGICAL_OPERATORS) { - // Tests binary logical operators &&, || and the unary logical operator ! - StartRecordingExpressions(); - T a = T(0); - T b = T(1); - auto r1 = a < b; - auto r2 = a <= b; - - auto r3 = r1 && r2; - auto r4 = r1 || r2; - auto r5 = !r1; - - auto graph = StopRecordingExpressions(); - std::vector<std::string> expected_code = {"{", - " double v_0;", - " double v_1;", - " bool v_2;", - " bool v_3;", - " bool v_4;", - " bool v_5;", - " bool v_6;", - " v_0 = 0;", - " v_1 = 1;", - " v_2 = v_0 < v_1;", - " v_3 = v_0 <= v_1;", - " v_4 = v_2 && v_3;", - " v_5 = v_2 || v_3;", - " v_6 = !v_2;", - "}"}; - GenerateAndCheck(graph, expected_code); -} - -TEST(CodeGenerator, FUNCTION_CALL) { - StartRecordingExpressions(); - T a = T(0); - T b = T(1); - - abs(a); - acos(a); - asin(a); - atan(a); - cbrt(a); - ceil(a); - cos(a); - cosh(a); - exp(a); - exp2(a); - floor(a); - log(a); - log2(a); - sin(a); - sinh(a); - sqrt(a); - tan(a); - tanh(a); - atan2(a, b); - pow(a, b); - - auto graph = StopRecordingExpressions(); - - std::vector<std::string> expected_code = {"{", - " double v_0;", - " double v_1;", - " double v_2;", - " double v_3;", - " double v_4;", - " double v_5;", - " double v_6;", - " double v_7;", - " double v_8;", - " double v_9;", - " double v_10;", - " double v_11;", - " double v_12;", - " double v_13;", - " double v_14;", - " double v_15;", - " double v_16;", - " double v_17;", - " double v_18;", - " double v_19;", - " double v_20;", - " double v_21;", - " v_0 = 0;", - " v_1 = 1;", - " v_2 = std::abs(v_0);", - " v_3 = std::acos(v_0);", - " v_4 = std::asin(v_0);", - " v_5 = std::atan(v_0);", - " v_6 = std::cbrt(v_0);", - " v_7 = std::ceil(v_0);", - " v_8 = std::cos(v_0);", - " v_9 = std::cosh(v_0);", - " v_10 = std::exp(v_0);", - " v_11 = std::exp2(v_0);", - " v_12 = std::floor(v_0);", - " v_13 = std::log(v_0);", - " v_14 = std::log2(v_0);", - " v_15 = std::sin(v_0);", - " v_16 = std::sinh(v_0);", - " v_17 = std::sqrt(v_0);", - " v_18 = std::tan(v_0);", - " v_19 = std::tanh(v_0);", - " v_20 = std::atan2(v_0, v_1);", - " v_21 = std::pow(v_0, v_1);", - "}"}; - GenerateAndCheck(graph, expected_code); -} - -TEST(CodeGenerator, LOGICAL_FUNCTION_CALL) { - StartRecordingExpressions(); - T a = T(1); - - isfinite(a); - isinf(a); - isnan(a); - isnormal(a); - - auto graph = StopRecordingExpressions(); - - std::vector<std::string> expected_code = {"{", - " double v_0;", - " bool v_1;", - " bool v_2;", - " bool v_3;", - " bool v_4;", - " v_0 = 1;", - " v_1 = std::isfinite(v_0);", - " v_2 = std::isinf(v_0);", - " v_3 = std::isnan(v_0);", - " v_4 = std::isnormal(v_0);", - "}"}; - GenerateAndCheck(graph, expected_code); -} - -TEST(CodeGenerator, IF_SIMPLE) { - StartRecordingExpressions(); - T a = T(0); - T b = T(1); - auto r1 = a < b; - CERES_IF(r1) {} - CERES_ELSE {} - CERES_ENDIF; - - auto graph = StopRecordingExpressions(); - std::vector<std::string> expected_code = {"{", - " double v_0;", - " double v_1;", - " bool v_2;", - " v_0 = 0;", - " v_1 = 1;", - " v_2 = v_0 < v_1;", - " if (v_2) {", - " } else {", - " }", - "}"}; - GenerateAndCheck(graph, expected_code); -} - -TEST(CodeGenerator, IF_ASSIGNMENT) { - StartRecordingExpressions(); - T a = T(0); - T b = T(1); - auto r1 = a < b; - - T result = 0; - CERES_IF(r1) { result = 5.0; } - CERES_ELSE { result = 6.0; } - CERES_ENDIF; - MakeOutput(result, "result"); - - auto graph = StopRecordingExpressions(); - std::vector<std::string> expected_code = {"{", - " double v_0;", - " double v_1;", - " bool v_2;", - " double v_3;", - " double v_5;", - " double v_8;", - " double v_11;", - " v_0 = 0;", - " v_1 = 1;", - " v_2 = v_0 < v_1;", - " v_3 = 0;", - " if (v_2) {", - " v_5 = 5;", - " v_3 = v_5;", - " } else {", - " v_8 = 6;", - " v_3 = v_8;", - " }", - " result = v_3;", - "}"}; - GenerateAndCheck(graph, expected_code); -} - -TEST(CodeGenerator, IF_NESTED_ASSIGNMENT) { - StartRecordingExpressions(); - T a = T(0); - T b = T(1); - - T result = 0; - CERES_IF(a <= b) { - result = 5.0; - CERES_IF(a == b) { result = 7.0; } - CERES_ENDIF; - } - CERES_ELSE { result = 6.0; } - CERES_ENDIF; - MakeOutput(result, "result"); - - auto graph = StopRecordingExpressions(); - std::vector<std::string> expected_code = {"{", - " double v_0;", - " double v_1;", - " double v_2;", - " bool v_3;", - " double v_5;", - " bool v_7;", - " double v_9;", - " double v_13;", - " double v_16;", - " v_0 = 0;", - " v_1 = 1;", - " v_2 = 0;", - " v_3 = v_0 <= v_1;", - " if (v_3) {", - " v_5 = 5;", - " v_2 = v_5;", - " v_7 = v_0 == v_1;", - " if (v_7) {", - " v_9 = 7;", - " v_2 = v_9;", - " }", - " } else {", - " v_13 = 6;", - " v_2 = v_13;", - " }", - " result = v_2;", - "}"}; - GenerateAndCheck(graph, expected_code); -} - -TEST(CodeGenerator, COMMENT) { - StartRecordingExpressions(); - CERES_COMMENT("Hello"); - auto graph = StopRecordingExpressions(); - std::vector<std::string> expected_code = {"{", " // Hello", "}"}; - GenerateAndCheck(graph, expected_code); -} - -} // namespace internal -} // namespace ceres
diff --git a/internal/ceres/codegen/eliminate_nops_test.cc b/internal/ceres/codegen/eliminate_nops_test.cc deleted file mode 100644 index 4b196b1..0000000 --- a/internal/ceres/codegen/eliminate_nops_test.cc +++ /dev/null
@@ -1,110 +0,0 @@ -// Ceres Solver - A fast non-linear least squares minimizer -// Copyright 2020 Google Inc. All rights reserved. -// http://code.google.com/p/ceres-solver/ -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// * Redistributions of source code must retain the above copyright notice, -// this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above copyright notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// * Neither the name of Google Inc. nor the names of its contributors may be -// used to endorse or promote products derived from this software without -// specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -// POSSIBILITY OF SUCH DAMAGE. -// -// Author: darius.rueckert@fau.de (Darius Rueckert) -// -#define CERES_CODEGEN - -#include "ceres/codegen/internal/eliminate_nops.h" - -#include "ceres/codegen/internal/code_generator.h" -#include "ceres/codegen/internal/expression_graph.h" -#include "ceres/codegen/internal/expression_ref.h" -#include "gtest/gtest.h" - -namespace ceres { -namespace internal { - -using T = ExpressionRef; - -TEST(EliminateNops, SimpleLinear) { - StartRecordingExpressions(); - { - T a = T(0); - // The Expression default constructor creates a NOP. - AddExpressionToGraph(Expression()); - AddExpressionToGraph(Expression()); - T b = T(2); - AddExpressionToGraph(Expression()); - MakeOutput(b, "residual[0]"); - AddExpressionToGraph(Expression()); - } - auto graph = StopRecordingExpressions(); - - StartRecordingExpressions(); - { - T a = T(0); - T b = T(2); - MakeOutput(b, "residual[0]"); - } - auto reference = StopRecordingExpressions(); - - auto summary = EliminateNops(&graph); - EXPECT_TRUE(summary.expression_graph_changed); - EXPECT_EQ(graph, reference); -} - -TEST(EliminateNops, Branches) { - StartRecordingExpressions(); - { - T a = T(0); - // The Expression default constructor creates a NOP. - AddExpressionToGraph(Expression()); - AddExpressionToGraph(Expression()); - T b = T(2); - CERES_IF(a < b) { - AddExpressionToGraph(Expression()); - T c = T(3); - } - CERES_ELSE { - AddExpressionToGraph(Expression()); - MakeOutput(b, "residual[0]"); - AddExpressionToGraph(Expression()); - } - CERES_ENDIF - AddExpressionToGraph(Expression()); - } - auto graph = StopRecordingExpressions(); - - StartRecordingExpressions(); - { - T a = T(0); - T b = T(2); - CERES_IF(a < b) { T c = T(3); } - CERES_ELSE { MakeOutput(b, "residual[0]"); } - CERES_ENDIF - } - auto reference = StopRecordingExpressions(); - - auto summary = EliminateNops(&graph); - EXPECT_TRUE(summary.expression_graph_changed); - EXPECT_EQ(graph, reference); -} - -} // namespace internal -} // namespace ceres
diff --git a/internal/ceres/codegen/expression_graph_test.cc b/internal/ceres/codegen/expression_graph_test.cc deleted file mode 100644 index 4f4f0ec..0000000 --- a/internal/ceres/codegen/expression_graph_test.cc +++ /dev/null
@@ -1,286 +0,0 @@ -// Ceres Solver - A fast non-linear least squares minimizer -// Copyright 2019 Google Inc. All rights reserved. -// http://code.google.com/p/ceres-solver/ -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// * Redistributions of source code must retain the above copyright notice, -// this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above copyright notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// * Neither the name of Google Inc. nor the names of its contributors may be -// used to endorse or promote products derived from this software without -// specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -// POSSIBILITY OF SUCH DAMAGE. -// -// Author: darius.rueckert@fau.de (Darius Rueckert) -// -// This file tests the ExpressionGraph class. This test depends on the -// correctness of Expression. -// -#include "ceres/codegen/internal/expression_graph.h" - -#include "ceres/codegen/internal/expression.h" -#include "gtest/gtest.h" - -namespace ceres { -namespace internal { - -TEST(ExpressionGraph, Size) { - ExpressionGraph graph; - EXPECT_EQ(graph.Size(), 0); - // Insert 3 NOPs - graph.InsertBack(Expression()); - graph.InsertBack(Expression()); - graph.InsertBack(Expression()); - EXPECT_EQ(graph.Size(), 3); -} - -TEST(ExpressionGraph, Recording) { - EXPECT_EQ(GetCurrentExpressionGraph(), nullptr); - StartRecordingExpressions(); - EXPECT_NE(GetCurrentExpressionGraph(), nullptr); - auto graph = StopRecordingExpressions(); - EXPECT_EQ(graph, ExpressionGraph()); - EXPECT_EQ(GetCurrentExpressionGraph(), nullptr); -} - -TEST(ExpressionGraph, InsertBackControl) { - // Control expression are inserted to the back without any modifications. - auto expr1 = Expression::CreateIf(ExpressionId(0)); - auto expr2 = Expression::CreateElse(); - auto expr3 = Expression::CreateEndIf(); - - ExpressionGraph graph; - graph.InsertBack(expr1); - graph.InsertBack(expr2); - graph.InsertBack(expr3); - - EXPECT_EQ(graph.Size(), 3); - EXPECT_EQ(graph.ExpressionForId(0), expr1); - EXPECT_EQ(graph.ExpressionForId(1), expr2); - EXPECT_EQ(graph.ExpressionForId(2), expr3); -} - -TEST(ExpressionGraph, InsertBackNewVariable) { - // If an arithmetic expression with lhs=kinvalidValue is inserted in the back, - // then a new variable name is created and set to the lhs_id. - auto expr1 = Expression::CreateCompileTimeConstant(42); - auto expr2 = Expression::CreateCompileTimeConstant(10); - auto expr3 = - Expression::CreateBinaryArithmetic("+", ExpressionId(0), ExpressionId(1)); - - ExpressionGraph graph; - graph.InsertBack(expr1); - graph.InsertBack(expr2); - graph.InsertBack(expr3); - EXPECT_EQ(graph.Size(), 3); - - // The ExpressionGraph has a copy of the inserted expression with the correct - // lhs_ids. We set them here manually for comparision. - expr1.set_lhs_id(0); - expr2.set_lhs_id(1); - expr3.set_lhs_id(2); - EXPECT_EQ(graph.ExpressionForId(0), expr1); - EXPECT_EQ(graph.ExpressionForId(1), expr2); - EXPECT_EQ(graph.ExpressionForId(2), expr3); -} - -TEST(ExpressionGraph, InsertBackExistingVariable) { - auto expr1 = Expression::CreateCompileTimeConstant(42); - auto expr2 = Expression::CreateCompileTimeConstant(10); - auto expr3 = Expression::CreateAssignment(1, 0); - - ExpressionGraph graph; - graph.InsertBack(expr1); - graph.InsertBack(expr2); - graph.InsertBack(expr3); - EXPECT_EQ(graph.Size(), 3); - - expr1.set_lhs_id(0); - expr2.set_lhs_id(1); - expr3.set_lhs_id(1); - EXPECT_EQ(graph.ExpressionForId(0), expr1); - EXPECT_EQ(graph.ExpressionForId(1), expr2); - EXPECT_EQ(graph.ExpressionForId(2), expr3); -} - -TEST(ExpressionGraph, DependsOn) { - ExpressionGraph graph; - graph.InsertBack(Expression::CreateCompileTimeConstant(42)); - graph.InsertBack(Expression::CreateCompileTimeConstant(10)); - graph.InsertBack(Expression::CreateBinaryArithmetic( - "+", ExpressionId(0), ExpressionId(1))); - graph.InsertBack(Expression::CreateBinaryArithmetic( - "+", ExpressionId(2), ExpressionId(0))); - - // Code: - // v_0 = 42 - // v_1 = 10 - // v_2 = v_0 + v_1 - // v_3 = v_2 + v_0 - - // Direct dependencies dependency check - ASSERT_TRUE(graph.DependsOn(2, 0)); - ASSERT_TRUE(graph.DependsOn(2, 1)); - ASSERT_TRUE(graph.DependsOn(3, 2)); - ASSERT_TRUE(graph.DependsOn(3, 0)); - ASSERT_FALSE(graph.DependsOn(1, 0)); - ASSERT_FALSE(graph.DependsOn(1, 1)); - ASSERT_FALSE(graph.DependsOn(2, 3)); - - // Recursive test - ASSERT_TRUE(graph.DependsOn(3, 1)); -} - -TEST(ExpressionGraph, FindMatchingEndif) { - ExpressionGraph graph; - graph.InsertBack(Expression::CreateCompileTimeConstant(1)); - graph.InsertBack(Expression::CreateCompileTimeConstant(2)); - graph.InsertBack(Expression::CreateBinaryCompare("<", 0, 1)); - graph.InsertBack(Expression::CreateIf(2)); - graph.InsertBack(Expression::CreateIf(2)); - graph.InsertBack(Expression::CreateElse()); - graph.InsertBack(Expression::CreateEndIf()); - graph.InsertBack(Expression::CreateElse()); - graph.InsertBack(Expression::CreateIf(2)); - graph.InsertBack(Expression::CreateEndIf()); - graph.InsertBack(Expression::CreateEndIf()); - graph.InsertBack(Expression::CreateIf(2)); // < if without matching endif - EXPECT_EQ(graph.Size(), 12); - - // Code <id> - // v_0 = 1 0 - // v_1 = 2 1 - // v_2 = v_0 < v_1 2 - // IF (v_2) 3 - // IF (v_2) 4 - // ELSE 5 - // ENDIF 6 - // ELSE 7 - // IF (v_2) 8 - // ENDIF 9 - // ENDIF 10 - // IF(v_2) 11 - - EXPECT_EQ(graph.FindMatchingEndif(3), 10); - EXPECT_EQ(graph.FindMatchingEndif(4), 6); - EXPECT_EQ(graph.FindMatchingEndif(8), 9); - EXPECT_EQ(graph.FindMatchingEndif(11), kInvalidExpressionId); -} - -TEST(ExpressionGraph, FindMatchingElse) { - ExpressionGraph graph; - graph.InsertBack(Expression::CreateCompileTimeConstant(1)); - graph.InsertBack(Expression::CreateCompileTimeConstant(2)); - graph.InsertBack(Expression::CreateBinaryCompare("<", 0, 1)); - graph.InsertBack(Expression::CreateIf(2)); - graph.InsertBack(Expression::CreateIf(2)); - graph.InsertBack(Expression::CreateElse()); - graph.InsertBack(Expression::CreateEndIf()); - graph.InsertBack(Expression::CreateElse()); - graph.InsertBack(Expression::CreateIf(2)); - graph.InsertBack(Expression::CreateEndIf()); - graph.InsertBack(Expression::CreateEndIf()); - graph.InsertBack(Expression::CreateIf(2)); // < if without matching endif - EXPECT_EQ(graph.Size(), 12); - - // Code <id> - // v_0 = 1 0 - // v_1 = 2 1 - // v_2 = v_0 < v_1 2 - // IF (v_2) 3 - // IF (v_2) 4 - // ELSE 5 - // ENDIF 6 - // ELSE 7 - // IF (v_2) 8 - // ENDIF 9 - // ENDIF 10 - // IF(v_2) 11 - - EXPECT_EQ(graph.FindMatchingElse(3), 7); - EXPECT_EQ(graph.FindMatchingElse(4), 5); - EXPECT_EQ(graph.FindMatchingElse(8), kInvalidExpressionId); - EXPECT_EQ(graph.FindMatchingEndif(11), kInvalidExpressionId); -} - -TEST(ExpressionGraph, InsertExpression_UpdateReferences) { - // This test checks if references to shifted expressions are updated - // accordingly. - ExpressionGraph graph; - graph.InsertBack(Expression::CreateCompileTimeConstant(42)); - graph.InsertBack(Expression::CreateCompileTimeConstant(10)); - graph.InsertBack(Expression::CreateBinaryArithmetic( - "+", ExpressionId(0), ExpressionId(1))); - // Code: - // v_0 = 42 - // v_1 = 10 - // v_2 = v_0 + v_1 - - // Insert another compile time constant at the beginning - graph.Insert(0, Expression::CreateCompileTimeConstant(5)); - // This should shift all indices like this: - // v_0 = 5 - // v_1 = 42 - // v_2 = 10 - // v_3 = v_1 + v_2 - - // Test by inserting it in the correct order - ExpressionGraph ref; - ref.InsertBack(Expression::CreateCompileTimeConstant(5)); - ref.InsertBack(Expression::CreateCompileTimeConstant(42)); - ref.InsertBack(Expression::CreateCompileTimeConstant(10)); - ref.InsertBack(Expression::CreateBinaryArithmetic( - "+", ExpressionId(1), ExpressionId(2))); - EXPECT_EQ(graph.Size(), ref.Size()); - EXPECT_EQ(graph, ref); -} - -TEST(ExpressionGraph, Erase) { - // This test checks if references to shifted expressions are updated - // accordingly. - ExpressionGraph graph; - graph.InsertBack(Expression::CreateCompileTimeConstant(42)); - graph.InsertBack(Expression::CreateCompileTimeConstant(10)); - graph.InsertBack(Expression::CreateCompileTimeConstant(3)); - graph.InsertBack(Expression::CreateBinaryArithmetic( - "+", ExpressionId(0), ExpressionId(2))); - // Code: - // v_0 = 42 - // v_1 = 10 - // v_2 = 3 - // v_3 = v_0 + v_2 - - // Erase the unused expression v_1 = 10 - graph.Erase(1); - // This should shift all indices like this: - // v_0 = 42 - // v_1 = 3 - // v_2 = v_0 + v_1 - - // Test by inserting it in the correct order - ExpressionGraph ref; - ref.InsertBack(Expression::CreateCompileTimeConstant(42)); - ref.InsertBack(Expression::CreateCompileTimeConstant(3)); - ref.InsertBack(Expression::CreateBinaryArithmetic( - "+", ExpressionId(0), ExpressionId(1))); - EXPECT_EQ(graph.Size(), ref.Size()); - EXPECT_EQ(graph, ref); -} - -} // namespace internal -} // namespace ceres
diff --git a/internal/ceres/codegen/expression_ref_test.cc b/internal/ceres/codegen/expression_ref_test.cc deleted file mode 100644 index c9a118b..0000000 --- a/internal/ceres/codegen/expression_ref_test.cc +++ /dev/null
@@ -1,411 +0,0 @@ -// Ceres Solver - A fast non-linear least squares minimizer -// Copyright 2019 Google Inc. All rights reserved. -// http://code.google.com/p/ceres-solver/ -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// * Redistributions of source code must retain the above copyright notice, -// this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above copyright notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// * Neither the name of Google Inc. nor the names of its contributors may be -// used to endorse or promote products derived from this software without -// specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -// POSSIBILITY OF SUCH DAMAGE. -// -// Author: darius.rueckert@fau.de (Darius Rueckert) -// -// This file tests the ExpressionRef class. This test depends on the -// correctness of Expression and ExpressionGraph. -// -#define CERES_CODEGEN - -#include "ceres/codegen/internal/expression_ref.h" - -#include "ceres/codegen/internal/expression_graph.h" -#include "gtest/gtest.h" - -namespace ceres { -namespace internal { - -using T = ExpressionRef; - -TEST(ExpressionRef, COMPILE_TIME_CONSTANT) { - StartRecordingExpressions(); - T a = T(0); - T b = T(123.5); - T c = T(1 + 1); - T d; // Uninitialized variables are also compile time constants - auto graph = StopRecordingExpressions(); - - ExpressionGraph reference; - reference.InsertBack(Expression::CreateCompileTimeConstant(0)); - reference.InsertBack(Expression::CreateCompileTimeConstant(123.5)); - reference.InsertBack(Expression::CreateCompileTimeConstant(2)); - reference.InsertBack(Expression::CreateCompileTimeConstant(0)); - EXPECT_EQ(reference, graph); -} - -TEST(ExpressionRef, INPUT_ASSIGNMENT) { - double local_variable = 5.0; - StartRecordingExpressions(); - T a = CERES_LOCAL_VARIABLE(T, local_variable); - T b = MakeParameter("parameters[0][0]"); - auto graph = StopRecordingExpressions(); - - ExpressionGraph reference; - reference.InsertBack(Expression::CreateInputAssignment("local_variable")); - reference.InsertBack(Expression::CreateInputAssignment("parameters[0][0]")); - EXPECT_EQ(reference, graph); -} - -TEST(ExpressionRef, OUTPUT_ASSIGNMENT) { - StartRecordingExpressions(); - T a = 1; - T b = 0; - MakeOutput(a, "residual[0]"); - auto graph = StopRecordingExpressions(); - - ExpressionGraph reference; - reference.InsertBack(Expression::CreateCompileTimeConstant(1)); - reference.InsertBack(Expression::CreateCompileTimeConstant(0)); - reference.InsertBack(Expression::CreateOutputAssignment(0, "residual[0]")); - EXPECT_EQ(reference, graph); -} - -TEST(ExpressionRef, Assignment) { - StartRecordingExpressions(); - T a = 1; - T b = 2; - b = a; - auto graph = StopRecordingExpressions(); - EXPECT_EQ(graph.Size(), 3); - - ExpressionGraph reference; - reference.InsertBack(Expression::CreateCompileTimeConstant(1)); - reference.InsertBack(Expression::CreateCompileTimeConstant(2)); - reference.InsertBack(Expression::CreateAssignment(1, 0)); - EXPECT_EQ(reference, graph); -} - -TEST(ExpressionRef, AssignmentCreate) { - StartRecordingExpressions(); - T a = 2; - T b = a; - auto graph = StopRecordingExpressions(); - EXPECT_EQ(graph.Size(), 2); - - ExpressionGraph reference; - reference.InsertBack(Expression::CreateCompileTimeConstant(2)); - reference.InsertBack(Expression::CreateAssignment(kInvalidExpressionId, 0)); - EXPECT_EQ(reference, graph); -} - -TEST(ExpressionRef, MoveAssignment) { - StartRecordingExpressions(); - T a = 1; - T b = 2; - b = std::move(a); - auto graph = StopRecordingExpressions(); - EXPECT_EQ(graph.Size(), 3); - - ExpressionGraph reference; - reference.InsertBack(Expression::CreateCompileTimeConstant(1)); - reference.InsertBack(Expression::CreateCompileTimeConstant(2)); - reference.InsertBack(Expression::CreateAssignment(1, 0)); - EXPECT_EQ(reference, graph); -} - -TEST(ExpressionRef, BINARY_ARITHMETIC_SIMPLE) { - StartRecordingExpressions(); - T a = T(1); - T b = T(2); - T r1 = a + b; - T r2 = a - b; - T r3 = a * b; - T r4 = a / b; - auto graph = StopRecordingExpressions(); - - ExpressionGraph reference; - reference.InsertBack(Expression::CreateCompileTimeConstant(1)); - reference.InsertBack(Expression::CreateCompileTimeConstant(2)); - reference.InsertBack(Expression::CreateBinaryArithmetic("+", 0, 1)); - reference.InsertBack(Expression::CreateBinaryArithmetic("-", 0, 1)); - reference.InsertBack(Expression::CreateBinaryArithmetic("*", 0, 1)); - reference.InsertBack(Expression::CreateBinaryArithmetic("/", 0, 1)); - EXPECT_EQ(reference, graph); -} - -TEST(ExpressionRef, BINARY_ARITHMETIC_NESTED) { - StartRecordingExpressions(); - T a = T(1); - T b = T(2); - T r1 = b - a * (a + b) / a; - auto graph = StopRecordingExpressions(); - - ExpressionGraph reference; - reference.InsertBack(Expression::CreateCompileTimeConstant(1)); - reference.InsertBack(Expression::CreateCompileTimeConstant(2)); - reference.InsertBack(Expression::CreateBinaryArithmetic("+", 0, 1)); - reference.InsertBack(Expression::CreateBinaryArithmetic("*", 0, 2)); - reference.InsertBack(Expression::CreateBinaryArithmetic("/", 3, 0)); - reference.InsertBack(Expression::CreateBinaryArithmetic("-", 1, 4)); - EXPECT_EQ(reference, graph); -} - -TEST(ExpressionRef, BINARY_ARITHMETIC_COMPOUND) { - // For each binary compound arithmetic operation, two lines are generated: - // - The actual operation assigning to a new temporary variable - // - An assignment from the temporary to the lhs - StartRecordingExpressions(); - T a = T(1); - T b = T(2); - a += b; - a -= b; - a *= b; - a /= b; - auto graph = StopRecordingExpressions(); - - ExpressionGraph reference; - reference.InsertBack(Expression::CreateCompileTimeConstant(1)); - reference.InsertBack(Expression::CreateCompileTimeConstant(2)); - reference.InsertBack(Expression::CreateBinaryArithmetic("+", 0, 1)); - reference.InsertBack(Expression::CreateAssignment(0, 2)); - reference.InsertBack(Expression::CreateBinaryArithmetic("-", 0, 1)); - reference.InsertBack(Expression::CreateAssignment(0, 4)); - reference.InsertBack(Expression::CreateBinaryArithmetic("*", 0, 1)); - reference.InsertBack(Expression::CreateAssignment(0, 6)); - reference.InsertBack(Expression::CreateBinaryArithmetic("/", 0, 1)); - reference.InsertBack(Expression::CreateAssignment(0, 8)); - EXPECT_EQ(reference, graph); -} - -TEST(ExpressionRef, UNARY_ARITHMETIC) { - StartRecordingExpressions(); - T a = T(1); - T r1 = -a; - T r2 = +a; - auto graph = StopRecordingExpressions(); - - ExpressionGraph reference; - reference.InsertBack(Expression::CreateCompileTimeConstant(1)); - reference.InsertBack(Expression::CreateUnaryArithmetic("-", 0)); - reference.InsertBack(Expression::CreateUnaryArithmetic("+", 0)); - EXPECT_EQ(reference, graph); -} - -TEST(ExpressionRef, BINARY_COMPARISON) { - using BOOL = ComparisonExpressionRef; - StartRecordingExpressions(); - T a = T(1); - T b = T(2); - BOOL r1 = a < b; - BOOL r2 = a <= b; - BOOL r3 = a > b; - BOOL r4 = a >= b; - BOOL r5 = a == b; - BOOL r6 = a != b; - auto graph = StopRecordingExpressions(); - - ExpressionGraph reference; - reference.InsertBack(Expression::CreateCompileTimeConstant(1)); - reference.InsertBack(Expression::CreateCompileTimeConstant(2)); - reference.InsertBack(Expression::CreateBinaryCompare("<", 0, 1)); - reference.InsertBack(Expression::CreateBinaryCompare("<=", 0, 1)); - reference.InsertBack(Expression::CreateBinaryCompare(">", 0, 1)); - reference.InsertBack(Expression::CreateBinaryCompare(">=", 0, 1)); - reference.InsertBack(Expression::CreateBinaryCompare("==", 0, 1)); - reference.InsertBack(Expression::CreateBinaryCompare("!=", 0, 1)); - EXPECT_EQ(reference, graph); -} - -TEST(ExpressionRef, LOGICAL_OPERATORS) { - using BOOL = ComparisonExpressionRef; - // Tests binary logical operators &&, || and the unary logical operator ! - StartRecordingExpressions(); - T a = T(1); - T b = T(2); - BOOL r1 = a < b; - BOOL r2 = a <= b; - BOOL r3 = r1 && r2; - BOOL r4 = r1 || r2; - BOOL r5 = !r1; - auto graph = StopRecordingExpressions(); - - ExpressionGraph reference; - reference.InsertBack(Expression::CreateCompileTimeConstant(1)); - reference.InsertBack(Expression::CreateCompileTimeConstant(2)); - reference.InsertBack(Expression::CreateBinaryCompare("<", 0, 1)); - reference.InsertBack(Expression::CreateBinaryCompare("<=", 0, 1)); - reference.InsertBack(Expression::CreateBinaryCompare("&&", 2, 3)); - reference.InsertBack(Expression::CreateBinaryCompare("||", 2, 3)); - reference.InsertBack(Expression::CreateLogicalNegation(2)); - EXPECT_EQ(reference, graph); -} - -TEST(ExpressionRef, SCALAR_FUNCTION_CALL) { - StartRecordingExpressions(); - T a = T(1); - T b = T(2); - abs(a); - acos(a); - asin(a); - atan(a); - cbrt(a); - ceil(a); - cos(a); - cosh(a); - exp(a); - exp2(a); - floor(a); - log(a); - log2(a); - sin(a); - sinh(a); - sqrt(a); - tan(a); - tanh(a); - atan2(a, b); - pow(a, b); - auto graph = StopRecordingExpressions(); - - ExpressionGraph reference; - reference.InsertBack(Expression::CreateCompileTimeConstant(1)); - reference.InsertBack(Expression::CreateCompileTimeConstant(2)); - reference.InsertBack(Expression::CreateScalarFunctionCall("std::abs", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("std::acos", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("std::asin", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("std::atan", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("std::cbrt", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("std::ceil", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("std::cos", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("std::cosh", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("std::exp", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("std::exp2", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("std::floor", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("std::log", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("std::log2", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("std::sin", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("std::sinh", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("std::sqrt", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("std::tan", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("std::tanh", {0})); - reference.InsertBack( - Expression::CreateScalarFunctionCall("std::atan2", {0, 1})); - reference.InsertBack( - Expression::CreateScalarFunctionCall("std::pow", {0, 1})); - EXPECT_EQ(reference, graph); -} - -TEST(ExpressionRef, LOGICAL_FUNCTION_CALL) { - StartRecordingExpressions(); - T a = T(1); - isfinite(a); - isinf(a); - isnan(a); - isnormal(a); - auto graph = StopRecordingExpressions(); - - ExpressionGraph reference; - reference.InsertBack(Expression::CreateCompileTimeConstant(1)); - reference.InsertBack( - Expression::CreateLogicalFunctionCall("std::isfinite", {0})); - reference.InsertBack( - Expression::CreateLogicalFunctionCall("std::isinf", {0})); - reference.InsertBack( - Expression::CreateLogicalFunctionCall("std::isnan", {0})); - reference.InsertBack( - Expression::CreateLogicalFunctionCall("std::isnormal", {0})); - EXPECT_EQ(reference, graph); -} - -TEST(ExpressionRef, IF) { - StartRecordingExpressions(); - T a = T(1); - T b = T(2); - auto r1 = a < b; - CERES_IF(r1) {} - CERES_ENDIF; - auto graph = StopRecordingExpressions(); - - ExpressionGraph reference; - reference.InsertBack(Expression::CreateCompileTimeConstant(1)); - reference.InsertBack(Expression::CreateCompileTimeConstant(2)); - reference.InsertBack(Expression::CreateBinaryCompare("<", 0, 1)); - reference.InsertBack(Expression::CreateIf(2)); - reference.InsertBack(Expression::CreateEndIf()); - EXPECT_EQ(reference, graph); -} - -TEST(ExpressionRef, IF_ELSE) { - StartRecordingExpressions(); - T a = T(1); - T b = T(2); - auto r1 = a < b; - CERES_IF(r1) {} - CERES_ELSE {} - CERES_ENDIF; - auto graph = StopRecordingExpressions(); - - ExpressionGraph reference; - reference.InsertBack(Expression::CreateCompileTimeConstant(1)); - reference.InsertBack(Expression::CreateCompileTimeConstant(2)); - reference.InsertBack(Expression::CreateBinaryCompare("<", 0, 1)); - reference.InsertBack(Expression::CreateIf(2)); - reference.InsertBack(Expression::CreateElse()); - reference.InsertBack(Expression::CreateEndIf()); - EXPECT_EQ(reference, graph); -} - -TEST(ExpressionRef, IF_NESTED) { - StartRecordingExpressions(); - T a = T(1); - T b = T(2); - auto r1 = a < b; - auto r2 = a == b; - CERES_IF(r1) { - CERES_IF(r2) {} - CERES_ENDIF; - } - CERES_ELSE {} - CERES_ENDIF; - auto graph = StopRecordingExpressions(); - - ExpressionGraph reference; - reference.InsertBack(Expression::CreateCompileTimeConstant(1)); - reference.InsertBack(Expression::CreateCompileTimeConstant(2)); - reference.InsertBack(Expression::CreateBinaryCompare("<", 0, 1)); - reference.InsertBack(Expression::CreateBinaryCompare("==", 0, 1)); - reference.InsertBack(Expression::CreateIf(2)); - reference.InsertBack(Expression::CreateIf(3)); - reference.InsertBack(Expression::CreateEndIf()); - reference.InsertBack(Expression::CreateElse()); - reference.InsertBack(Expression::CreateEndIf()); - EXPECT_EQ(reference, graph); -} - -TEST(ExpressionRef, COMMENT) { - StartRecordingExpressions(); - CERES_COMMENT("This is a comment"); - auto graph = StopRecordingExpressions(); - - ExpressionGraph reference; - reference.InsertBack(Expression::CreateComment("This is a comment")); - EXPECT_EQ(reference, graph); -} - -} // namespace internal -} // namespace ceres
diff --git a/internal/ceres/codegen/expression_test.cc b/internal/ceres/codegen/expression_test.cc deleted file mode 100644 index 65daf69..0000000 --- a/internal/ceres/codegen/expression_test.cc +++ /dev/null
@@ -1,325 +0,0 @@ -// Ceres Solver - A fast non-linear least squares minimizer -// Copyright 2019 Google Inc. All rights reserved. -// http://code.google.com/p/ceres-solver/ -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// * Redistributions of source code must retain the above copyright notice, -// this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above copyright notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// * Neither the name of Google Inc. nor the names of its contributors may be -// used to endorse or promote products derived from this software without -// specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -// POSSIBILITY OF SUCH DAMAGE. -// -// Author: darius.rueckert@fau.de (Darius Rueckert) -// -// This file tests the Expression class. For each member function one test is -// included here. -// -#include "ceres/codegen/internal/expression.h" -#include "gtest/gtest.h" - -namespace ceres { -namespace internal { - -TEST(Expression, ConstructorAndAccessors) { - Expression expr(ExpressionType::LOGICAL_NEGATION, - ExpressionReturnType::BOOLEAN, - 12345, - {1, 5, 8, 10}, - "TestConstructor", - 57.25); - EXPECT_EQ(expr.type(), ExpressionType::LOGICAL_NEGATION); - EXPECT_EQ(expr.return_type(), ExpressionReturnType::BOOLEAN); - EXPECT_EQ(expr.lhs_id(), 12345); - EXPECT_EQ(expr.arguments(), std::vector<ExpressionId>({1, 5, 8, 10})); - EXPECT_EQ(expr.name(), "TestConstructor"); - EXPECT_EQ(expr.value(), 57.25); -} - -TEST(Expression, CreateFunctions) { - // The default constructor creates a NOP! - EXPECT_EQ(Expression(), - Expression(ExpressionType::NOP, - ExpressionReturnType::VOID, - kInvalidExpressionId, - {}, - "", - 0)); - - EXPECT_EQ(Expression::CreateCompileTimeConstant(72), - Expression(ExpressionType::COMPILE_TIME_CONSTANT, - ExpressionReturnType::SCALAR, - kInvalidExpressionId, - {}, - "", - 72)); - - EXPECT_EQ(Expression::CreateInputAssignment("arguments[0][0]"), - Expression(ExpressionType::INPUT_ASSIGNMENT, - ExpressionReturnType::SCALAR, - kInvalidExpressionId, - {}, - "arguments[0][0]", - 0)); - - EXPECT_EQ(Expression::CreateOutputAssignment(ExpressionId(5), "residuals[3]"), - Expression(ExpressionType::OUTPUT_ASSIGNMENT, - ExpressionReturnType::SCALAR, - kInvalidExpressionId, - {5}, - "residuals[3]", - 0)); - - EXPECT_EQ(Expression::CreateAssignment(ExpressionId(3), ExpressionId(5)), - Expression(ExpressionType::ASSIGNMENT, - ExpressionReturnType::SCALAR, - 3, - {5}, - "", - 0)); - - EXPECT_EQ( - Expression::CreateBinaryArithmetic("+", ExpressionId(3), ExpressionId(5)), - Expression(ExpressionType::BINARY_ARITHMETIC, - ExpressionReturnType::SCALAR, - kInvalidExpressionId, - {3, 5}, - "+", - 0)); - - EXPECT_EQ(Expression::CreateUnaryArithmetic("-", ExpressionId(5)), - Expression(ExpressionType::UNARY_ARITHMETIC, - ExpressionReturnType::SCALAR, - kInvalidExpressionId, - {5}, - "-", - 0)); - - EXPECT_EQ( - Expression::CreateBinaryCompare("<", ExpressionId(3), ExpressionId(5)), - Expression(ExpressionType::BINARY_COMPARISON, - ExpressionReturnType::BOOLEAN, - kInvalidExpressionId, - {3, 5}, - "<", - 0)); - - EXPECT_EQ(Expression::CreateLogicalNegation(ExpressionId(5)), - Expression(ExpressionType::LOGICAL_NEGATION, - ExpressionReturnType::BOOLEAN, - kInvalidExpressionId, - {5}, - "", - 0)); - - EXPECT_EQ(Expression::CreateScalarFunctionCall( - "pow", {ExpressionId(3), ExpressionId(5)}), - Expression(ExpressionType::FUNCTION_CALL, - ExpressionReturnType::SCALAR, - kInvalidExpressionId, - {3, 5}, - "pow", - 0)); - - EXPECT_EQ( - Expression::CreateLogicalFunctionCall("isfinite", {ExpressionId(3)}), - Expression(ExpressionType::FUNCTION_CALL, - ExpressionReturnType::BOOLEAN, - kInvalidExpressionId, - {3}, - "isfinite", - 0)); - - EXPECT_EQ(Expression::CreateIf(ExpressionId(5)), - Expression(ExpressionType::IF, - ExpressionReturnType::VOID, - kInvalidExpressionId, - {5}, - "", - 0)); - - EXPECT_EQ(Expression::CreateElse(), - Expression(ExpressionType::ELSE, - ExpressionReturnType::VOID, - kInvalidExpressionId, - {}, - "", - 0)); - - EXPECT_EQ(Expression::CreateEndIf(), - Expression(ExpressionType::ENDIF, - ExpressionReturnType::VOID, - kInvalidExpressionId, - {}, - "", - 0)); - - EXPECT_EQ(Expression::CreateComment("Test"), - Expression(ExpressionType::COMMENT, - ExpressionReturnType::VOID, - kInvalidExpressionId, - {}, - "Test", - 0)); -} - -TEST(Expression, IsArithmeticExpression) { - ASSERT_TRUE( - Expression::CreateCompileTimeConstant(5).IsArithmeticExpression()); - ASSERT_TRUE(Expression::CreateScalarFunctionCall("pow", {3, 5}) - .IsArithmeticExpression()); - // Logical expression are also arithmetic! - ASSERT_TRUE( - Expression::CreateBinaryCompare("<", 3, 5).IsArithmeticExpression()); - ASSERT_FALSE(Expression::CreateIf(5).IsArithmeticExpression()); - ASSERT_FALSE(Expression::CreateEndIf().IsArithmeticExpression()); - ASSERT_FALSE(Expression().IsArithmeticExpression()); -} - -TEST(Expression, IsControlExpression) { - // In the current implementation this is the exact opposite of - // IsArithmeticExpression. - ASSERT_FALSE(Expression::CreateCompileTimeConstant(5).IsControlExpression()); - ASSERT_FALSE(Expression::CreateScalarFunctionCall("pow", {3, 5}) - .IsControlExpression()); - ASSERT_FALSE( - Expression::CreateBinaryCompare("<", 3, 5).IsControlExpression()); - ASSERT_TRUE(Expression::CreateIf(5).IsControlExpression()); - ASSERT_TRUE(Expression::CreateEndIf().IsControlExpression()); - ASSERT_TRUE(Expression::CreateComment("Test").IsControlExpression()); - ASSERT_TRUE(Expression().IsControlExpression()); -} - -TEST(Expression, IsCompileTimeConstantAndEqualTo) { - ASSERT_TRUE( - Expression::CreateCompileTimeConstant(5).IsCompileTimeConstantAndEqualTo( - 5)); - ASSERT_FALSE( - Expression::CreateCompileTimeConstant(3).IsCompileTimeConstantAndEqualTo( - 5)); - ASSERT_FALSE(Expression::CreateBinaryCompare("<", 3, 5) - .IsCompileTimeConstantAndEqualTo(5)); -} - -TEST(Expression, IsReplaceableBy) { - // Create 2 identical expression - auto expr1 = - Expression::CreateBinaryArithmetic("+", ExpressionId(3), ExpressionId(5)); - - auto expr2 = - Expression::CreateBinaryArithmetic("+", ExpressionId(3), ExpressionId(5)); - - // They are idendical and of course replaceable - ASSERT_EQ(expr1, expr2); - ASSERT_EQ(expr2, expr1); - ASSERT_TRUE(expr1.IsReplaceableBy(expr2)); - ASSERT_TRUE(expr2.IsReplaceableBy(expr1)); - - // Give them different left hand sides - expr1.set_lhs_id(72); - expr2.set_lhs_id(42); - - // v_72 = v_3 + v_5 - // v_42 = v_3 + v_5 - // -> They should be replaceable by each other - - ASSERT_NE(expr1, expr2); - ASSERT_NE(expr2, expr1); - - ASSERT_TRUE(expr1.IsReplaceableBy(expr2)); - ASSERT_TRUE(expr2.IsReplaceableBy(expr1)); - - // A slightly differnt expression with the argument flipped - auto expr3 = - Expression::CreateBinaryArithmetic("+", ExpressionId(5), ExpressionId(3)); - - ASSERT_NE(expr1, expr3); - ASSERT_FALSE(expr1.IsReplaceableBy(expr3)); -} - -TEST(Expression, Replace) { - auto expr1 = - Expression::CreateBinaryArithmetic("+", ExpressionId(3), ExpressionId(5)); - expr1.set_lhs_id(13); - - auto expr2 = - Expression::CreateAssignment(kInvalidExpressionId, ExpressionId(7)); - - // We replace the arithmetic expr1 by an assignment from the variable 7. This - // is the typical usecase in subexpression elimination. - expr1.Replace(expr2); - - // expr1 should now be an assignment from 7 to 13 - EXPECT_EQ(expr1, - Expression(ExpressionType::ASSIGNMENT, - ExpressionReturnType::SCALAR, - 13, - {7}, - "", - 0)); -} - -TEST(Expression, DirectlyDependsOn) { - auto expr1 = - Expression::CreateBinaryArithmetic("+", ExpressionId(3), ExpressionId(5)); - - ASSERT_TRUE(expr1.DirectlyDependsOn(ExpressionId(3))); - ASSERT_TRUE(expr1.DirectlyDependsOn(ExpressionId(5))); - ASSERT_FALSE(expr1.DirectlyDependsOn(ExpressionId(kInvalidExpressionId))); - ASSERT_FALSE(expr1.DirectlyDependsOn(ExpressionId(42))); -} -TEST(Expression, MakeNop) { - auto expr1 = - Expression::CreateBinaryArithmetic("+", ExpressionId(3), ExpressionId(5)); - - expr1.MakeNop(); - - EXPECT_EQ(expr1, - Expression(ExpressionType::NOP, - ExpressionReturnType::VOID, - kInvalidExpressionId, - {}, - "", - 0)); -} - -TEST(Expression, IsSemanticallyEquivalentTo) { - // Create 2 identical expression - auto expr1 = - Expression::CreateBinaryArithmetic("+", ExpressionId(3), ExpressionId(5)); - - auto expr2 = - Expression::CreateBinaryArithmetic("+", ExpressionId(3), ExpressionId(5)); - - ASSERT_TRUE(expr1.IsSemanticallyEquivalentTo(expr1)); - ASSERT_TRUE(expr1.IsSemanticallyEquivalentTo(expr2)); - - auto expr3 = - Expression::CreateBinaryArithmetic("+", ExpressionId(3), ExpressionId(8)); - - ASSERT_TRUE(expr1.IsSemanticallyEquivalentTo(expr3)); - - auto expr4 = - Expression::CreateBinaryArithmetic("-", ExpressionId(3), ExpressionId(5)); - - ASSERT_FALSE(expr1.IsSemanticallyEquivalentTo(expr4)); -} - -} // namespace internal -} // namespace ceres
diff --git a/internal/ceres/codegen/test_utils.cc b/internal/ceres/codegen/test_utils.cc deleted file mode 100644 index 303c0db..0000000 --- a/internal/ceres/codegen/test_utils.cc +++ /dev/null
@@ -1,88 +0,0 @@ -// Ceres Solver - A fast non-linear least squares minimizer -// Copyright 2020 Google Inc. All rights reserved. -// http://code.google.com/p/ceres-solver/ -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// * Redistributions of source code must retain the above copyright notice, -// this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above copyright notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// * Neither the name of Google Inc. nor the names of its contributors may be -// used to endorse or promote products derived from this software without -// specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -// POSSIBILITY OF SUCH DAMAGE. -// -// Author: darius.rueckert@fau.de (Darius Rueckert) - -#include "ceres/codegen/test_utils.h" -#include "ceres/test_util.h" - -namespace ceres { -namespace internal { - -std::pair<std::vector<double>, std::vector<double> > EvaluateCostFunction( - CostFunction* cost_function, double value) { - auto num_residuals = cost_function->num_residuals(); - auto parameter_block_sizes = cost_function->parameter_block_sizes(); - auto num_parameter_blocks = parameter_block_sizes.size(); - - int total_num_parameters = 0; - for (auto block_size : parameter_block_sizes) { - total_num_parameters += block_size; - } - - std::vector<double> params_array(total_num_parameters, value); - std::vector<double*> params(num_parameter_blocks); - std::vector<double> residuals(num_residuals, 0); - std::vector<double> jacobians_array(num_residuals * total_num_parameters, 0); - std::vector<double*> jacobians(num_parameter_blocks); - - for (int i = 0, k = 0; i < num_parameter_blocks; - k += parameter_block_sizes[i], ++i) { - params[i] = ¶ms_array[k]; - } - - for (int i = 0, k = 0; i < num_parameter_blocks; - k += parameter_block_sizes[i], ++i) { - jacobians[i] = &jacobians_array[k * num_residuals]; - } - - cost_function->Evaluate(params.data(), residuals.data(), jacobians.data()); - - return std::make_pair(residuals, jacobians_array); -} - -void CompareCostFunctions(CostFunction* cost_function1, - CostFunction* cost_function2, - - double value, - double tol) { - auto residuals_and_jacobians_1 = EvaluateCostFunction(cost_function1, value); - auto residuals_and_jacobians_2 = EvaluateCostFunction(cost_function2, value); - - ExpectArraysClose(residuals_and_jacobians_1.first.size(), - residuals_and_jacobians_1.first.data(), - residuals_and_jacobians_2.first.data(), - tol); - ExpectArraysClose(residuals_and_jacobians_1.second.size(), - residuals_and_jacobians_1.second.data(), - residuals_and_jacobians_2.second.data(), - tol); -} - -} // namespace internal -} // namespace ceres
diff --git a/internal/ceres/codegen/test_utils.h b/internal/ceres/codegen/test_utils.h deleted file mode 100644 index eaaf438..0000000 --- a/internal/ceres/codegen/test_utils.h +++ /dev/null
@@ -1,82 +0,0 @@ -// Ceres Solver - A fast non-linear least squares minimizer -// Copyright 2020 Google Inc. All rights reserved. -// http://code.google.com/p/ceres-solver/ -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// * Redistributions of source code must retain the above copyright notice, -// this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above copyright notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// * Neither the name of Google Inc. nor the names of its contributors may be -// used to endorse or promote products derived from this software without -// specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -// POSSIBILITY OF SUCH DAMAGE. -// -// Author: darius.rueckert@fau.de (Darius Rueckert) - -#ifndef CERES_INTERNAL_CODEGEN_TEST_UTILS_H_ -#define CERES_INTERNAL_CODEGEN_TEST_UTILS_H_ - -#include "ceres/internal/autodiff.h" -#include "ceres/random.h" -#include "ceres/sized_cost_function.h" - -namespace ceres { -namespace internal { - -// CodegenCostFunctions have both, an templated operator() and the Evaluate() -// function. The operator() is used during code generation and Evaluate() is -// used during execution. -// -// If we want to use the operator() during execution (with autodiff) this -// wrapper class here has to be used. Autodiff doesn't support functors that -// have an Evaluate() function. -// -// CostFunctionToFunctor hides the Evaluate() function, because it doesn't -// derive from CostFunction. Autodiff sees it as a simple functor and will use -// the operator() as expected. -template <typename CostFunction> -struct CostFunctionToFunctor { - template <typename... _Args> - CostFunctionToFunctor(_Args&&... __args) - : cost_function(std::forward<_Args>(__args)...) {} - - template <typename... _Args> - bool operator()(_Args&&... __args) const { - return cost_function(std::forward<_Args>(__args)...); - } - - CostFunction cost_function; -}; - -// Evaluate a cost function and return the residuals and jacobians. -// All parameters are set to 'value'. -std::pair<std::vector<double>, std::vector<double>> EvaluateCostFunction( - CostFunction* cost_function, double value); - -// Evaluates the two cost functions using the method above and then compares the -// result. The comparison uses GTEST expect macros so this function should be -// called from a test enviroment. -void CompareCostFunctions(CostFunction* cost_function1, - CostFunction* cost_function2, - double value, - double tol); - -} // namespace internal -} // namespace ceres - -#endif