Autodiff Codegen Part 2: Conditionals - Expression types for if/else blocks - CERES_IF/ELSE macros - Multiple assignment to the same variable Change-Id: If529516f243f31823d1ef7b8827bb6f2390e418d
diff --git a/include/ceres/internal/expression.h b/include/ceres/internal/expression.h index c990d93..5ab2d62 100644 --- a/include/ceres/internal/expression.h +++ b/include/ceres/internal/expression.h
@@ -28,17 +28,144 @@ // // Author: darius.rueckert@fau.de (Darius Rueckert) // +// During code generation, your cost functor is converted into a list of +// expressions stored in an expression graph. For each operator (+,-,=,...), +// function call (sin,cos,...), and special keyword (if,else,...) the +// appropriate ExpressionType is selected. On a high level all ExpressionTypes +// are grouped into two different classes: Arithmetic expressions and control +// expressions. // -// This file contains the basic expression type, which is used during code -// generation. Only assignment expressions of the following form are supported: +// Part 1: Arithmetic Expressions // -// result = [constant|binary_expr|functioncall] +// Arithmetic expression are the most basic and common types. They are all of +// the following form: // -// Examples: -// v_78 = v_28 / v_62; -// v_97 = exp(v_20); -// v_89 = 3.000000; +// <lhs> = <rhs> // +// <lhs> is the variable name on the left hand side of the assignment. <rhs> can +// be different depending on the ExpressionType. It must evaluate to a single +// scalar value though. Here are a few examples of arithmetic expressions (the +// ExpressionType is given on the right): +// +// v_0 = 3.1415; // COMPILE_TIME_CONSTANT +// v_1 = v_0; // ASSIGNMENT +// v_2 = v_0 + v_1; // PLUS +// v_3 = v_2 / v_0; // DIVISION +// v_4 = sin(v_3); // FUNCTION_CALL +// v_5 = v_4 < v_3; // BINARY_COMPARISON +// +// As you can see, the right hand side of each expression contains exactly one +// operator/value/function call. If you write long expressions like +// +// T c = a + b - T(3) * a; +// +// it will broken up into the individual expressions like so: +// +// v_0 = a + b; +// v_1 = 3; +// v_2 = v_1 * a; +// c = v_0 - v_2; +// +// All arithmetic expressions are generated by operator and function +// overloading. These overloads are defined in expression_ref.h. +// +// +// +// Part 2: Control Expressions +// +// Control expressions include special instructions that handle the control flow +// of a program. So far, only if/else is supported, but while/for might come in +// the future. +// +// Generating code for conditional jumps (if/else) is more complicated than +// for arithmetic expressions. Let's look at a small example to see the +// problems. After that we explain how these problems are solved in Ceres. +// +// 1 T a = parameters[0][0]; +// 2 T b = 1.0; +// 3 if (a < b) { +// 4 b = 3.0; +// 5 } else { +// 6 b = 4.0; +// 7 } +// 8 b += 1.0; +// 9 residuals[0] = b; +// +// Problem 1. +// We need to generate code for both branches. In C++ there is no way to execute +// both branches of an if, but we need to execute them to generate the code. +// +// Problem 2. +// The comparison a < b in line 3 is not convertible to bool. Since the value of +// a is not known during code generation, the expression a < b can not be +// evaluated. In fact, a < b will return an expression of type +// BINARY_COMPARISON. +// +// Problem 3. +// There is no way to record that an if was executed. "if" is a special operator +// which cannot be overloaded. Therefore we can't generate code that contains +// "if. +// +// Problem 4. +// We have no information about "blocks" or "scopes" during code generation. +// Even if we could overload the if-operator, there is now way to capture which +// expression was executed in which branches of the if. For example, we generate +// code for the else branch. How can we know that the else branch is finished? +// Is line 8 inside the else-block or already outside? +// +// Solution. +// Instead of using the keywords if/else we insert the macros +// CERES_IF, CERES_ELSE and CERES_ENDIF. These macros just map to a function, +// which inserts an expression into the graph. Here is how the example from +// above looks like with the expanded macros: +// +// 1 T a = parameters[0][0]; +// 2 T b = 1.0; +// 3 CreateIf(a < b); { +// 4 b = 3.0; +// 5 } CreateElse(); { +// 6 b = 4.0; +// 7 } CreateEndif(); +// 8 b += 1.0; +// 9 residuals[0] = b; +// +// Problem 1 solved. +// There are no branches during code generation, therefore both blocks are +// evaluated. +// +// Problem 2 solved. +// The function CreateIf(_) does not take a bool as argument, but an +// ComparisonExpression. Later during code generation an actual "if" is created +// with the condition as argument. +// +// Problem 3 solved. +// We replaced "if" by a function call so we can record it now. +// +// Problem 4 solved. +// Expressions are added into the graph in the correct order. That means, after +// seeing a CreateIf() we know that all following expressions until CreateElse() +// belong to the true-branch. Similar, all expression from CreateElse() to +// CreateEndif() belong to the false-branch. This also works recursively with +// nested ifs. +// +// If you want to use the AutoDiff code generation for your cost functors, you +// have to replace all if/else by the CERES_IF, CERES_ELSE and CERES_ENDIF +// macros. The example from above looks like this: +// +// 1 T a = parameters[0][0]; +// 2 T b = 1.0; +// 3 CERES_IF (a < b) { +// 4 b = 3.0; +// 5 } CERES_ELSE { +// 6 b = 4.0; +// 7 } CERES_ENDIF; +// 8 b += 1.0; +// 9 residuals[0] = b; +// +// These macros don't have a negative impact on performance, because they only +// expand to the CreateIf/.. functions in code generation mode. Otherwise they +// expand to the if/else keywords. See expression_ref.h for the exact +// definition. // #ifndef CERES_PUBLIC_EXPRESSION_H_ #define CERES_PUBLIC_EXPRESSION_H_ @@ -68,26 +195,25 @@ // residual[0] = v_51; OUTPUT_ASSIGNMENT, - // Trivial Assignment - // v_1 = v_0; + // Trivial assignment + // v_3 = v_1 ASSIGNMENT, // Binary Arithmetic Operations // v_2 = v_0 + v_1 - PLUS, - MINUS, - MULTIPLICATION, - DIVISION, + // The operator is stored in Expression::name_. + BINARY_ARITHMETIC, // Unary Arithmetic Operation // v_1 = -(v_0); // v_2 = +(v_1); - UNARY_MINUS, - UNARY_PLUS, + // The operator is stored in Expression::name_. + UNARY_ARITHMETIC, // Binary Comparison. (<,>,&&,...) // This is the only expressions which returns a 'bool'. - // const bool v_2 = v_0 < v_1 + // v_2 = v_0 < v_1 + // The operator is stored in Expression::name_. BINARY_COMPARISON, // The !-operator on logical expression. @@ -102,6 +228,12 @@ // v_3 = ternary(v_0,v_1,v_2); TERNARY, + // Conditional control expressions if/else/endif. + // These are special expressions, because they don't define a new variable. + IF, + ELSE, + ENDIF, + // No Operation. A placeholder for an 'empty' expressions which will be // optimized out during code generation. NOP @@ -129,11 +261,11 @@ static ExpressionId CreateParameter(const std::string& name); static ExpressionId CreateOutputAssignment(ExpressionId v, const std::string& name); - static ExpressionId CreateAssignment(ExpressionId v); - static ExpressionId CreateBinaryArithmetic(ExpressionType type, + static ExpressionId CreateAssignment(ExpressionId dst, ExpressionId src); + static ExpressionId CreateBinaryArithmetic(const std::string& op, ExpressionId l, ExpressionId r); - static ExpressionId CreateUnaryArithmetic(ExpressionType type, + static ExpressionId CreateUnaryArithmetic(const std::string& op, ExpressionId v); static ExpressionId CreateBinaryCompare(const std::string& name, ExpressionId l, @@ -145,9 +277,19 @@ ExpressionId if_true, ExpressionId if_false); - // Returns true if the expression type is one of the basic math-operators: - // +,-,*,/ - bool IsArithmetic() const; + // Conditional control expressions are inserted into the graph but can't be + // referenced by other expressions. Therefore they don't return an + // ExpressionId. + static void CreateIf(ExpressionId condition); + static void CreateElse(); + static void CreateEndIf(); + + // Returns true if this is an arithmetic expression. + // Arithmetic expressions must have a valid left hand side. + bool IsArithmeticExpression() const; + + // Returns true if this is a control expression. + bool IsControlExpression() const; // If this expression is the compile time constant with the given value. // Used during optimization to collapse zero/one arithmetic operations. @@ -170,16 +312,36 @@ // Converts this expression into a NOP void MakeNop(); + // Returns true if this expression has a valid lhs. + bool HasValidLhs() const { return lhs_id_ != kInvalidExpressionId; } + + ExpressionType type() const { return type_; } + ExpressionId lhs_id() const { return lhs_id_; } + double value() const { return value_; } + const std::string& name() const { return name_; } + const std::vector<ExpressionId>& arguments() const { return arguments_; } + private: // Only ExpressionGraph is allowed to call the constructor, because it manages // the memory and ids. friend class ExpressionGraph; // Private constructor. Use the "CreateXX" functions instead. - Expression(ExpressionType type, ExpressionId id); + Expression(ExpressionType type, ExpressionId lhs_id); ExpressionType type_ = ExpressionType::NOP; - const ExpressionId id_ = kInvalidExpressionId; + + // If lhs_id_ >= 0, then this expression is assigned to v_<lhs_id>. + // For example: + // v_1 = v_0 + v_0 (Type = PLUS) + // v_3 = sin(v_1) (Type = FUNCTION_CALL) + // ^ + // lhs_id_ + // + // If lhs_id_ == kInvalidExpressionId, then the expression type is not + // arithmetic. Currently, only the following types have lhs_id = invalid: + // IF,ELSE,ENDIF,NOP + const ExpressionId lhs_id_ = kInvalidExpressionId; // Expressions have different number of arguments. For example a binary "+" // has 2 parameters and a function call to "sin" has 1 parameter. Here, a
diff --git a/include/ceres/internal/expression_graph.h b/include/ceres/internal/expression_graph.h index 446fddb..308528f 100644 --- a/include/ceres/internal/expression_graph.h +++ b/include/ceres/internal/expression_graph.h
@@ -48,9 +48,25 @@ // A is parent of B <=> A has B as a parameter <=> A.DirectlyDependsOn(B) class ExpressionGraph { public: - // Creates an expression and adds it to expressions_. - // The returned reference will be invalid after this function is called again. - Expression& CreateExpression(ExpressionType type); + // Creates an arithmetic expression of the following form: + // <lhs> = <rhs>; + // + // For example: + // CreateArithmeticExpression(PLUS, 5) + // will generate: + // v_5 = __ + __; + // The place holders are then set by the CreateXX functions of Expression. + // + // If lhs_id == kInvalidExpressionId, then a new lhs_id will be generated and + // assigned to the created expression. + // Calling this function with a lhs_id that doesn't exist results in an + // error. + Expression& CreateArithmeticExpression(ExpressionType type, + ExpressionId lhs_id); + + // Control expression don't have a left hand side. + // Supported types: IF/ELSE/ENDIF/NOP + Expression& CreateControlExpression(ExpressionType type); // Checks if A depends on B. // -> B is a descendant of A
diff --git a/include/ceres/internal/expression_ref.h b/include/ceres/internal/expression_ref.h index 67ff227..f4b920f 100644 --- a/include/ceres/internal/expression_ref.h +++ b/include/ceres/internal/expression_ref.h
@@ -52,8 +52,26 @@ // it's automatically converted to the correct expression. explicit ExpressionRef(double compile_time_constant); - // Returns v_id - std::string ToString() const; + // Create an ASSIGNMENT expression from other to this. + // + // For example: + // a = b; // With a.id = 5 and b.id = 3 + // will generate the following assignment: + // v_5 = v_3; + // + // If this (lhs) ExpressionRef is currently not pointing to a variable + // (id==invalid), then we can eliminate the assignment by just letting "this" + // point to the same variable as "other". + // + // Example: + // a = b; // With a.id = invalid and b.id = 3 + // will generate NO expression, but after this line the following will be + // true: + // a.id == b.id == 3 + // + // If 'other' is not pointing to a variable (id==invalid), we found an + // uninitialized assignment, which is handled as an error. + ExpressionRef& operator=(const ExpressionRef& other); // Compound operators ExpressionRef& operator+=(ExpressionRef x); @@ -61,6 +79,8 @@ ExpressionRef& operator*=(ExpressionRef x); ExpressionRef& operator/=(ExpressionRef x); + bool IsInitialized() const { return id != kInvalidExpressionId; } + // The index into the ExpressionGraph data array. ExpressionId id = kInvalidExpressionId; @@ -152,6 +172,22 @@ #define CERES_EXPRESSION_RUNTIME_CONSTANT(_v) \ ceres::internal::MakeRuntimeConstant<T>(_v, #_v) + +// The CERES_CODEGEN macro is defined by the build system only during code +// generation. In all other cases the CERES_IF/ELSE macros just expand to the +// if/else keywords. +#ifdef CERES_CODEGEN +#define CERES_IF(condition_) Expression::CreateIf((condition_).id); +#define CERES_ELSE Expression::CreateElse(); +#define CERES_ENDIF Expression::CreateEndIf(); +#else +// clang-format off +#define CERES_IF(condition_) if (condition_) { +#define CERES_ELSE } else { +#define CERES_ENDIF } +// clang-format on +#endif + } // namespace internal // See jet.h for more info on this type.
diff --git a/internal/ceres/CMakeLists.txt b/internal/ceres/CMakeLists.txt index 63e8540..15770ad 100644 --- a/internal/ceres/CMakeLists.txt +++ b/internal/ceres/CMakeLists.txt
@@ -446,6 +446,7 @@ ceres_test(evaluation_callback) ceres_test(evaluator) ceres_test(expression) + ceres_test(conditional_expressions) ceres_test(expression_graph) ceres_test(fixed_array) ceres_test(gradient_checker)
diff --git a/internal/ceres/conditional_expressions_test.cc b/internal/ceres/conditional_expressions_test.cc new file mode 100644 index 0000000..72fe3ee --- /dev/null +++ b/internal/ceres/conditional_expressions_test.cc
@@ -0,0 +1,189 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2019 Google Inc. All rights reserved. +// http://code.google.com/p/ceres-solver/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Author: darius.rueckert@fau.de (Darius Rueckert) +// + +#define CERES_CODEGEN + +#include "expression_test.h" + +namespace ceres { +namespace internal { + +TEST(Expression, AssignmentElimination) { + using T = ExpressionRef; + + StartRecordingExpressions(); + T a(2); + T b; + b = a; + auto graph = StopRecordingExpressions(); + + // b is invalid during the assignment so we expect no expression to be + // generated. The only expression in the graph should be the constant + // assignment to a. + EXPECT_EQ(graph.Size(), 1); + + // Expected code + // v_0 = 2; + + // clang-format off + // Id Type Lhs Value Name Arguments + TE(0, COMPILE_TIME_CONSTANT, 0, 2, ""); + // clang-format on + + // Variables after execution: + // + // a <=> v_0 + // b <=> v_0 + EXPECT_EQ(a.id, 0); + EXPECT_EQ(b.id, 0); +} + +TEST(Expression, Assignment) { + using T = ExpressionRef; + + StartRecordingExpressions(); + T a(2); + T b(4); + b = a; + auto graph = StopRecordingExpressions(); + + // b is valid during the assignment so we expect an + // additional assignment expression. + EXPECT_EQ(graph.Size(), 3); + + // Expected code + // v_0 = 2; + // v_1 = 4; + // v_1 = v_0; + + // clang-format off + // Id, Type, Lhs, Value, Name, Arguments + TE( 0, COMPILE_TIME_CONSTANT, 0, 2, "", ); + TE( 1, COMPILE_TIME_CONSTANT, 1, 4, "", ); + TE( 2, ASSIGNMENT, 1, 0, "", 0); + // clang-format on + + // Variables after execution: + // + // a <=> v_0 + // b <=> v_1 + EXPECT_EQ(a.id, 0); + EXPECT_EQ(b.id, 1); +} + +TEST(Expression, ConditionalMinimal) { + using T = ExpressionRef; + + StartRecordingExpressions(); + T a(2); + T b(3); + auto c = a < b; + CERES_IF(c) {} + CERES_ELSE {} + CERES_ENDIF + auto graph = StopRecordingExpressions(); + + // Expected code + // v_0 = 2; + // v_1 = 3; + // v_2 = v_0 < v_1; + // if(v_2); + // else + // endif + + EXPECT_EQ(graph.Size(), 6); + + // clang-format off + // Id, Type, Lhs, Value, Name, Arguments... + TE( 0, COMPILE_TIME_CONSTANT, 0, 2, "", ); + TE( 1, COMPILE_TIME_CONSTANT, 1, 3, "", ); + TE( 2, BINARY_COMPARISON, 2, 0, "<", 0, 1); + TE( 3, IF, -1, 0, "", 2); + TE( 4, ELSE, -1, 0, "", ); + TE( 5, ENDIF, -1, 0, "", ); + // clang-format on +} + +TEST(Expression, ConditionalAssignment) { + using T = ExpressionRef; + + StartRecordingExpressions(); + + T result; + T a(2); + T b(3); + auto c = a < b; + CERES_IF(c) { result = a + b; } + CERES_ELSE { result = a - b; } + CERES_ENDIF + result += a; + auto graph = StopRecordingExpressions(); + + // Expected code + // v_0 = 2; + // v_1 = 3; + // v_2 = v_0 < v_1; + // if(v_2); + // v_4 = v_0 + v_1; + // else + // v_6 = v_0 - v_1; + // v_4 = v_6 + // endif + // v_9 = v_4 + v_0; + // v_4 = v_9; + + // clang-format off + // Id, Type, Lhs, Value, Name, Arguments... + TE( 0, COMPILE_TIME_CONSTANT, 0, 2, "", ); + TE( 1, COMPILE_TIME_CONSTANT, 1, 3, "", ); + TE( 2, BINARY_COMPARISON, 2, 0, "<", 0, 1); + TE( 3, IF, -1, 0, "", 2); + TE( 4, BINARY_ARITHMETIC, 4, 0, "+", 0, 1); + TE( 5, ELSE, -1, 0, "", ); + TE( 6, BINARY_ARITHMETIC, 6, 0, "-", 0, 1); + TE( 7, ASSIGNMENT, 4, 0, "", 6 ); + TE( 8, ENDIF, -1, 0, "", ); + TE( 9, BINARY_ARITHMETIC, 9, 0, "+", 4, 0); + TE( 10, ASSIGNMENT, 4, 0, "", 9 ); + // clang-format on + + // Variables after execution: + // + // a <=> v_0 + // b <=> v_1 + // result <=> v_4 + EXPECT_EQ(a.id, 0); + EXPECT_EQ(b.id, 1); + EXPECT_EQ(result.id, 4); +} + +} // namespace internal +} // namespace ceres
diff --git a/internal/ceres/expression.cc b/internal/ceres/expression.cc index 3edcc7d..4c8dd5c 100644 --- a/internal/ceres/expression.cc +++ b/internal/ceres/expression.cc
@@ -37,112 +37,127 @@ namespace ceres { namespace internal { -static Expression& MakeExpression(ExpressionType type) { +// Wrapper for ExpressionGraph::CreateArithmeticExpression, which checks if a +// graph is currently active. See that function for an explanation. +static Expression& MakeArithmeticExpression( + ExpressionType type, ExpressionId lhs_id = kInvalidExpressionId) { auto pool = GetCurrentExpressionGraph(); CHECK(pool) << "The ExpressionGraph has to be created before using Expressions. This " "is achieved by calling ceres::StartRecordingExpressions."; - return pool->CreateExpression(type); + return pool->CreateArithmeticExpression(type, lhs_id); +} + +// Wrapper for ExpressionGraph::CreateControlExpression. +static Expression& MakeControlExpression(ExpressionType type) { + auto pool = GetCurrentExpressionGraph(); + CHECK(pool) + << "The ExpressionGraph has to be created before using Expressions. This " + "is achieved by calling ceres::StartRecordingExpressions."; + return pool->CreateControlExpression(type); } ExpressionId Expression::CreateCompileTimeConstant(double v) { - auto& expr = MakeExpression(ExpressionType::COMPILE_TIME_CONSTANT); + auto& expr = MakeArithmeticExpression(ExpressionType::COMPILE_TIME_CONSTANT); expr.value_ = v; - return expr.id_; + return expr.lhs_id_; } ExpressionId Expression::CreateRuntimeConstant(const std::string& name) { - auto& expr = MakeExpression(ExpressionType::RUNTIME_CONSTANT); + auto& expr = MakeArithmeticExpression(ExpressionType::RUNTIME_CONSTANT); expr.name_ = name; - return expr.id_; + return expr.lhs_id_; } ExpressionId Expression::CreateParameter(const std::string& name) { - auto& expr = MakeExpression(ExpressionType::PARAMETER); + auto& expr = MakeArithmeticExpression(ExpressionType::PARAMETER); expr.name_ = name; - return expr.id_; + return expr.lhs_id_; } -ExpressionId Expression::CreateAssignment(ExpressionId v) { - auto& expr = MakeExpression(ExpressionType::ASSIGNMENT); - expr.arguments_.push_back(v); - return expr.id_; +ExpressionId Expression::CreateAssignment(ExpressionId dst, ExpressionId src) { + auto& expr = MakeArithmeticExpression(ExpressionType::ASSIGNMENT, dst); + + expr.arguments_.push_back(src); + return expr.lhs_id_; } -ExpressionId Expression::CreateUnaryArithmetic(ExpressionType type, +ExpressionId Expression::CreateUnaryArithmetic(const std::string& op, ExpressionId v) { - auto& expr = MakeExpression(type); + auto& expr = MakeArithmeticExpression(ExpressionType::UNARY_ARITHMETIC); + expr.name_ = op; expr.arguments_.push_back(v); - return expr.id_; + return expr.lhs_id_; } ExpressionId Expression::CreateOutputAssignment(ExpressionId v, const std::string& name) { - auto& expr = MakeExpression(ExpressionType::OUTPUT_ASSIGNMENT); + auto& expr = MakeArithmeticExpression(ExpressionType::OUTPUT_ASSIGNMENT); expr.arguments_.push_back(v); expr.name_ = name; - return expr.id_; + return expr.lhs_id_; } ExpressionId Expression::CreateFunctionCall( const std::string& name, const std::vector<ExpressionId>& params) { - auto& expr = MakeExpression(ExpressionType::FUNCTION_CALL); + auto& expr = MakeArithmeticExpression(ExpressionType::FUNCTION_CALL); expr.arguments_ = params; expr.name_ = name; - return expr.id_; + return expr.lhs_id_; } ExpressionId Expression::CreateTernary(ExpressionId condition, ExpressionId if_true, ExpressionId if_false) { - auto& expr = MakeExpression(ExpressionType::TERNARY); + auto& expr = MakeArithmeticExpression(ExpressionType::TERNARY); expr.arguments_.push_back(condition); expr.arguments_.push_back(if_true); expr.arguments_.push_back(if_false); - return expr.id_; + return expr.lhs_id_; } ExpressionId Expression::CreateBinaryCompare(const std::string& name, ExpressionId l, ExpressionId r) { - auto& expr = MakeExpression(ExpressionType::BINARY_COMPARISON); + auto& expr = MakeArithmeticExpression(ExpressionType::BINARY_COMPARISON); expr.arguments_.push_back(l); expr.arguments_.push_back(r); expr.name_ = name; - return expr.id_; + return expr.lhs_id_; } ExpressionId Expression::CreateLogicalNegation(ExpressionId v) { - auto& expr = MakeExpression(ExpressionType::LOGICAL_NEGATION); + auto& expr = MakeArithmeticExpression(ExpressionType::LOGICAL_NEGATION); expr.arguments_.push_back(v); - return expr.id_; + return expr.lhs_id_; } -ExpressionId Expression::CreateBinaryArithmetic(ExpressionType type, +ExpressionId Expression::CreateBinaryArithmetic(const std::string& op, ExpressionId l, ExpressionId r) { - auto& expr = MakeExpression(type); + auto& expr = MakeArithmeticExpression(ExpressionType::BINARY_ARITHMETIC); + expr.name_ = op; expr.arguments_.push_back(l); expr.arguments_.push_back(r); - return expr.id_; + return expr.lhs_id_; } -Expression::Expression(ExpressionType type, ExpressionId id) - : type_(type), id_(id) {} -bool Expression::IsArithmetic() const { - switch (type_) { - case ExpressionType::PLUS: - case ExpressionType::MULTIPLICATION: - case ExpressionType::DIVISION: - case ExpressionType::MINUS: - case ExpressionType::UNARY_MINUS: - case ExpressionType::UNARY_PLUS: - return true; - default: - return false; - } +void Expression::CreateIf(ExpressionId condition) { + auto& expr = MakeControlExpression(ExpressionType::IF); + expr.arguments_.push_back(condition); } +void Expression::CreateElse() { MakeControlExpression(ExpressionType::ELSE); } + +void Expression::CreateEndIf() { MakeControlExpression(ExpressionType::ENDIF); } + +Expression::Expression(ExpressionType type, ExpressionId id) + : type_(type), lhs_id_(id) {} + +bool Expression::IsArithmeticExpression() const { return HasValidLhs(); } + +bool Expression::IsControlExpression() const { return !HasValidLhs(); } + bool Expression::IsReplaceableBy(const Expression& other) const { // Check everything except the id. return (type_ == other.type_ && name_ == other.name_ && @@ -150,7 +165,7 @@ } void Expression::Replace(const Expression& other) { - if (other.id_ == id_) { + if (other.lhs_id_ == lhs_id_) { return; }
diff --git a/internal/ceres/expression_graph.cc b/internal/ceres/expression_graph.cc index 0757b97..5c2b84e 100644 --- a/internal/ceres/expression_graph.cc +++ b/internal/ceres/expression_graph.cc
@@ -55,9 +55,23 @@ ExpressionGraph* GetCurrentExpressionGraph() { return expression_pool; } -Expression& ExpressionGraph::CreateExpression(ExpressionType type) { - auto id = expressions_.size(); - Expression expr(type, id); +Expression& ExpressionGraph::CreateArithmeticExpression(ExpressionType type, + ExpressionId lhs_id) { + if (lhs_id == kInvalidExpressionId) { + // We are creating a new temporary variable. + // -> The new lhs_id is the index into the graph + lhs_id = static_cast<ExpressionId>(expressions_.size()); + } else { + // The left hand side already exists. + } + + Expression expr(type, lhs_id); + expressions_.push_back(expr); + return expressions_.back(); +} + +Expression& ExpressionGraph::CreateControlExpression(ExpressionType type) { + Expression expr(type, kInvalidExpressionId); expressions_.push_back(expr); return expressions_.back(); }
diff --git a/internal/ceres/expression_ref.cc b/internal/ceres/expression_ref.cc index 75a7723..224a3bf 100644 --- a/internal/ceres/expression_ref.cc +++ b/internal/ceres/expression_ref.cc
@@ -29,8 +29,7 @@ // Author: darius.rueckert@fau.de (Darius Rueckert) #include "ceres/internal/expression_ref.h" -#include "assert.h" -#include "ceres/internal/expression.h" +#include "glog/logging.h" namespace ceres { namespace internal { @@ -41,11 +40,24 @@ return ref; } -std::string ExpressionRef::ToString() const { return std::to_string(id); } - ExpressionRef::ExpressionRef(double compile_time_constant) { - (*this) = ExpressionRef::Create( - Expression::CreateCompileTimeConstant(compile_time_constant)); + id = Expression::CreateCompileTimeConstant(compile_time_constant); +} + +ExpressionRef& ExpressionRef::operator=(const ExpressionRef& other) { + // Assigning an uninitialized variable to another variable is an error. + CHECK(other.IsInitialized()) << "Uninitialized Assignment."; + + if (IsInitialized()) { + // Create assignment from other -> this + Expression::CreateAssignment(id, other.id); + } else { + // Special case: "this" expressionref is invalid + // -> Skip assignment + // -> Let "this" point to the same variable as other + id = other.id; + } + return *this; } // Compound operators @@ -71,33 +83,31 @@ // Arith. Operators ExpressionRef operator-(ExpressionRef x) { - return ExpressionRef::Create( - Expression::CreateUnaryArithmetic(ExpressionType::UNARY_MINUS, x.id)); + return ExpressionRef::Create(Expression::CreateUnaryArithmetic("-", x.id)); } ExpressionRef operator+(ExpressionRef x) { - return ExpressionRef::Create( - Expression::CreateUnaryArithmetic(ExpressionType::UNARY_PLUS, x.id)); + return ExpressionRef::Create(Expression::CreateUnaryArithmetic("+", x.id)); } ExpressionRef operator+(ExpressionRef x, ExpressionRef y) { return ExpressionRef::Create( - Expression::CreateBinaryArithmetic(ExpressionType::PLUS, x.id, y.id)); + Expression::CreateBinaryArithmetic("+", x.id, y.id)); } ExpressionRef operator-(ExpressionRef x, ExpressionRef y) { return ExpressionRef::Create( - Expression::CreateBinaryArithmetic(ExpressionType::MINUS, x.id, y.id)); + Expression::CreateBinaryArithmetic("-", x.id, y.id)); } ExpressionRef operator/(ExpressionRef x, ExpressionRef y) { return ExpressionRef::Create( - Expression::CreateBinaryArithmetic(ExpressionType::DIVISION, x.id, y.id)); + Expression::CreateBinaryArithmetic("/", x.id, y.id)); } ExpressionRef operator*(ExpressionRef x, ExpressionRef y) { - return ExpressionRef::Create(Expression::CreateBinaryArithmetic( - ExpressionType::MULTIPLICATION, x.id, y.id)); + return ExpressionRef::Create( + Expression::CreateBinaryArithmetic("*", x.id, y.id)); } // Functions
diff --git a/internal/ceres/expression_test.cc b/internal/ceres/expression_test.cc index 3e683c1..704a464 100644 --- a/internal/ceres/expression_test.cc +++ b/internal/ceres/expression_test.cc
@@ -29,6 +29,8 @@ // Author: darius.rueckert@fau.de (Darius Rueckert) // +#define CERES_CODEGEN + #include "ceres/internal/expression_graph.h" #include "ceres/internal/expression_ref.h" @@ -48,10 +50,10 @@ auto graph = StopRecordingExpressions(); - ASSERT_FALSE(graph.ExpressionForId(a.id).IsArithmetic()); - ASSERT_FALSE(graph.ExpressionForId(b.id).IsArithmetic()); - ASSERT_TRUE(graph.ExpressionForId(c.id).IsArithmetic()); - ASSERT_TRUE(graph.ExpressionForId(d.id).IsArithmetic()); + ASSERT_FALSE(graph.ExpressionForId(a.id).IsArithmeticExpression()); + ASSERT_FALSE(graph.ExpressionForId(b.id).IsArithmeticExpression()); + ASSERT_TRUE(graph.ExpressionForId(c.id).IsArithmeticExpression()); + ASSERT_TRUE(graph.ExpressionForId(d.id).IsArithmeticExpression()); } TEST(Expression, IsCompileTimeConstantAndEqualTo) {
diff --git a/internal/ceres/expression_test.h b/internal/ceres/expression_test.h new file mode 100644 index 0000000..de9e2c2 --- /dev/null +++ b/internal/ceres/expression_test.h
@@ -0,0 +1,69 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2019 Google Inc. All rights reserved. +// http://code.google.com/p/ceres-solver/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Author: darius.rueckert@fau.de (Darius Rueckert) +// + +#ifndef CERES_PUBLIC_EXPRESSION_TEST_H_ +#define CERES_PUBLIC_EXPRESSION_TEST_H_ + +#include "ceres/internal/expression_graph.h" +#include "ceres/internal/expression_ref.h" + +#include "gtest/gtest.h" + +// This file adds a few helper functions to test Expressions and +// ExpressionGraphs for correctness. +namespace ceres { +namespace internal { + +inline void TestExpression(const Expression& expr, + ExpressionType type, + ExpressionId lhs_id, + double value, + const std::string& name, + const std::vector<ExpressionId>& arguments) { + EXPECT_EQ(static_cast<int>(expr.type()), static_cast<int>(type)); + EXPECT_EQ(expr.lhs_id(), lhs_id); + EXPECT_EQ(expr.value(), value); + EXPECT_EQ(expr.name(), name); + EXPECT_EQ(expr.arguments(), arguments); +} + +#define TE(_id, _type, _lhs_id, _value, _name, ...) \ + TestExpression(graph.ExpressionForId(_id), \ + ExpressionType::_type, \ + _lhs_id, \ + _value, \ + _name, \ + {__VA_ARGS__}) + +} // namespace internal +} // namespace ceres + +#endif