Autodiff Codegen Part 2: Conditionals
- Expression types for if/else blocks
- CERES_IF/ELSE macros
- Multiple assignment to the same variable
Change-Id: If529516f243f31823d1ef7b8827bb6f2390e418d
diff --git a/include/ceres/internal/expression.h b/include/ceres/internal/expression.h
index c990d93..5ab2d62 100644
--- a/include/ceres/internal/expression.h
+++ b/include/ceres/internal/expression.h
@@ -28,17 +28,144 @@
//
// Author: darius.rueckert@fau.de (Darius Rueckert)
//
+// During code generation, your cost functor is converted into a list of
+// expressions stored in an expression graph. For each operator (+,-,=,...),
+// function call (sin,cos,...), and special keyword (if,else,...) the
+// appropriate ExpressionType is selected. On a high level all ExpressionTypes
+// are grouped into two different classes: Arithmetic expressions and control
+// expressions.
//
-// This file contains the basic expression type, which is used during code
-// generation. Only assignment expressions of the following form are supported:
+// Part 1: Arithmetic Expressions
//
-// result = [constant|binary_expr|functioncall]
+// Arithmetic expression are the most basic and common types. They are all of
+// the following form:
//
-// Examples:
-// v_78 = v_28 / v_62;
-// v_97 = exp(v_20);
-// v_89 = 3.000000;
+// <lhs> = <rhs>
//
+// <lhs> is the variable name on the left hand side of the assignment. <rhs> can
+// be different depending on the ExpressionType. It must evaluate to a single
+// scalar value though. Here are a few examples of arithmetic expressions (the
+// ExpressionType is given on the right):
+//
+// v_0 = 3.1415; // COMPILE_TIME_CONSTANT
+// v_1 = v_0; // ASSIGNMENT
+// v_2 = v_0 + v_1; // PLUS
+// v_3 = v_2 / v_0; // DIVISION
+// v_4 = sin(v_3); // FUNCTION_CALL
+// v_5 = v_4 < v_3; // BINARY_COMPARISON
+//
+// As you can see, the right hand side of each expression contains exactly one
+// operator/value/function call. If you write long expressions like
+//
+// T c = a + b - T(3) * a;
+//
+// it will broken up into the individual expressions like so:
+//
+// v_0 = a + b;
+// v_1 = 3;
+// v_2 = v_1 * a;
+// c = v_0 - v_2;
+//
+// All arithmetic expressions are generated by operator and function
+// overloading. These overloads are defined in expression_ref.h.
+//
+//
+//
+// Part 2: Control Expressions
+//
+// Control expressions include special instructions that handle the control flow
+// of a program. So far, only if/else is supported, but while/for might come in
+// the future.
+//
+// Generating code for conditional jumps (if/else) is more complicated than
+// for arithmetic expressions. Let's look at a small example to see the
+// problems. After that we explain how these problems are solved in Ceres.
+//
+// 1 T a = parameters[0][0];
+// 2 T b = 1.0;
+// 3 if (a < b) {
+// 4 b = 3.0;
+// 5 } else {
+// 6 b = 4.0;
+// 7 }
+// 8 b += 1.0;
+// 9 residuals[0] = b;
+//
+// Problem 1.
+// We need to generate code for both branches. In C++ there is no way to execute
+// both branches of an if, but we need to execute them to generate the code.
+//
+// Problem 2.
+// The comparison a < b in line 3 is not convertible to bool. Since the value of
+// a is not known during code generation, the expression a < b can not be
+// evaluated. In fact, a < b will return an expression of type
+// BINARY_COMPARISON.
+//
+// Problem 3.
+// There is no way to record that an if was executed. "if" is a special operator
+// which cannot be overloaded. Therefore we can't generate code that contains
+// "if.
+//
+// Problem 4.
+// We have no information about "blocks" or "scopes" during code generation.
+// Even if we could overload the if-operator, there is now way to capture which
+// expression was executed in which branches of the if. For example, we generate
+// code for the else branch. How can we know that the else branch is finished?
+// Is line 8 inside the else-block or already outside?
+//
+// Solution.
+// Instead of using the keywords if/else we insert the macros
+// CERES_IF, CERES_ELSE and CERES_ENDIF. These macros just map to a function,
+// which inserts an expression into the graph. Here is how the example from
+// above looks like with the expanded macros:
+//
+// 1 T a = parameters[0][0];
+// 2 T b = 1.0;
+// 3 CreateIf(a < b); {
+// 4 b = 3.0;
+// 5 } CreateElse(); {
+// 6 b = 4.0;
+// 7 } CreateEndif();
+// 8 b += 1.0;
+// 9 residuals[0] = b;
+//
+// Problem 1 solved.
+// There are no branches during code generation, therefore both blocks are
+// evaluated.
+//
+// Problem 2 solved.
+// The function CreateIf(_) does not take a bool as argument, but an
+// ComparisonExpression. Later during code generation an actual "if" is created
+// with the condition as argument.
+//
+// Problem 3 solved.
+// We replaced "if" by a function call so we can record it now.
+//
+// Problem 4 solved.
+// Expressions are added into the graph in the correct order. That means, after
+// seeing a CreateIf() we know that all following expressions until CreateElse()
+// belong to the true-branch. Similar, all expression from CreateElse() to
+// CreateEndif() belong to the false-branch. This also works recursively with
+// nested ifs.
+//
+// If you want to use the AutoDiff code generation for your cost functors, you
+// have to replace all if/else by the CERES_IF, CERES_ELSE and CERES_ENDIF
+// macros. The example from above looks like this:
+//
+// 1 T a = parameters[0][0];
+// 2 T b = 1.0;
+// 3 CERES_IF (a < b) {
+// 4 b = 3.0;
+// 5 } CERES_ELSE {
+// 6 b = 4.0;
+// 7 } CERES_ENDIF;
+// 8 b += 1.0;
+// 9 residuals[0] = b;
+//
+// These macros don't have a negative impact on performance, because they only
+// expand to the CreateIf/.. functions in code generation mode. Otherwise they
+// expand to the if/else keywords. See expression_ref.h for the exact
+// definition.
//
#ifndef CERES_PUBLIC_EXPRESSION_H_
#define CERES_PUBLIC_EXPRESSION_H_
@@ -68,26 +195,25 @@
// residual[0] = v_51;
OUTPUT_ASSIGNMENT,
- // Trivial Assignment
- // v_1 = v_0;
+ // Trivial assignment
+ // v_3 = v_1
ASSIGNMENT,
// Binary Arithmetic Operations
// v_2 = v_0 + v_1
- PLUS,
- MINUS,
- MULTIPLICATION,
- DIVISION,
+ // The operator is stored in Expression::name_.
+ BINARY_ARITHMETIC,
// Unary Arithmetic Operation
// v_1 = -(v_0);
// v_2 = +(v_1);
- UNARY_MINUS,
- UNARY_PLUS,
+ // The operator is stored in Expression::name_.
+ UNARY_ARITHMETIC,
// Binary Comparison. (<,>,&&,...)
// This is the only expressions which returns a 'bool'.
- // const bool v_2 = v_0 < v_1
+ // v_2 = v_0 < v_1
+ // The operator is stored in Expression::name_.
BINARY_COMPARISON,
// The !-operator on logical expression.
@@ -102,6 +228,12 @@
// v_3 = ternary(v_0,v_1,v_2);
TERNARY,
+ // Conditional control expressions if/else/endif.
+ // These are special expressions, because they don't define a new variable.
+ IF,
+ ELSE,
+ ENDIF,
+
// No Operation. A placeholder for an 'empty' expressions which will be
// optimized out during code generation.
NOP
@@ -129,11 +261,11 @@
static ExpressionId CreateParameter(const std::string& name);
static ExpressionId CreateOutputAssignment(ExpressionId v,
const std::string& name);
- static ExpressionId CreateAssignment(ExpressionId v);
- static ExpressionId CreateBinaryArithmetic(ExpressionType type,
+ static ExpressionId CreateAssignment(ExpressionId dst, ExpressionId src);
+ static ExpressionId CreateBinaryArithmetic(const std::string& op,
ExpressionId l,
ExpressionId r);
- static ExpressionId CreateUnaryArithmetic(ExpressionType type,
+ static ExpressionId CreateUnaryArithmetic(const std::string& op,
ExpressionId v);
static ExpressionId CreateBinaryCompare(const std::string& name,
ExpressionId l,
@@ -145,9 +277,19 @@
ExpressionId if_true,
ExpressionId if_false);
- // Returns true if the expression type is one of the basic math-operators:
- // +,-,*,/
- bool IsArithmetic() const;
+ // Conditional control expressions are inserted into the graph but can't be
+ // referenced by other expressions. Therefore they don't return an
+ // ExpressionId.
+ static void CreateIf(ExpressionId condition);
+ static void CreateElse();
+ static void CreateEndIf();
+
+ // Returns true if this is an arithmetic expression.
+ // Arithmetic expressions must have a valid left hand side.
+ bool IsArithmeticExpression() const;
+
+ // Returns true if this is a control expression.
+ bool IsControlExpression() const;
// If this expression is the compile time constant with the given value.
// Used during optimization to collapse zero/one arithmetic operations.
@@ -170,16 +312,36 @@
// Converts this expression into a NOP
void MakeNop();
+ // Returns true if this expression has a valid lhs.
+ bool HasValidLhs() const { return lhs_id_ != kInvalidExpressionId; }
+
+ ExpressionType type() const { return type_; }
+ ExpressionId lhs_id() const { return lhs_id_; }
+ double value() const { return value_; }
+ const std::string& name() const { return name_; }
+ const std::vector<ExpressionId>& arguments() const { return arguments_; }
+
private:
// Only ExpressionGraph is allowed to call the constructor, because it manages
// the memory and ids.
friend class ExpressionGraph;
// Private constructor. Use the "CreateXX" functions instead.
- Expression(ExpressionType type, ExpressionId id);
+ Expression(ExpressionType type, ExpressionId lhs_id);
ExpressionType type_ = ExpressionType::NOP;
- const ExpressionId id_ = kInvalidExpressionId;
+
+ // If lhs_id_ >= 0, then this expression is assigned to v_<lhs_id>.
+ // For example:
+ // v_1 = v_0 + v_0 (Type = PLUS)
+ // v_3 = sin(v_1) (Type = FUNCTION_CALL)
+ // ^
+ // lhs_id_
+ //
+ // If lhs_id_ == kInvalidExpressionId, then the expression type is not
+ // arithmetic. Currently, only the following types have lhs_id = invalid:
+ // IF,ELSE,ENDIF,NOP
+ const ExpressionId lhs_id_ = kInvalidExpressionId;
// Expressions have different number of arguments. For example a binary "+"
// has 2 parameters and a function call to "sin" has 1 parameter. Here, a
diff --git a/include/ceres/internal/expression_graph.h b/include/ceres/internal/expression_graph.h
index 446fddb..308528f 100644
--- a/include/ceres/internal/expression_graph.h
+++ b/include/ceres/internal/expression_graph.h
@@ -48,9 +48,25 @@
// A is parent of B <=> A has B as a parameter <=> A.DirectlyDependsOn(B)
class ExpressionGraph {
public:
- // Creates an expression and adds it to expressions_.
- // The returned reference will be invalid after this function is called again.
- Expression& CreateExpression(ExpressionType type);
+ // Creates an arithmetic expression of the following form:
+ // <lhs> = <rhs>;
+ //
+ // For example:
+ // CreateArithmeticExpression(PLUS, 5)
+ // will generate:
+ // v_5 = __ + __;
+ // The place holders are then set by the CreateXX functions of Expression.
+ //
+ // If lhs_id == kInvalidExpressionId, then a new lhs_id will be generated and
+ // assigned to the created expression.
+ // Calling this function with a lhs_id that doesn't exist results in an
+ // error.
+ Expression& CreateArithmeticExpression(ExpressionType type,
+ ExpressionId lhs_id);
+
+ // Control expression don't have a left hand side.
+ // Supported types: IF/ELSE/ENDIF/NOP
+ Expression& CreateControlExpression(ExpressionType type);
// Checks if A depends on B.
// -> B is a descendant of A
diff --git a/include/ceres/internal/expression_ref.h b/include/ceres/internal/expression_ref.h
index 67ff227..f4b920f 100644
--- a/include/ceres/internal/expression_ref.h
+++ b/include/ceres/internal/expression_ref.h
@@ -52,8 +52,26 @@
// it's automatically converted to the correct expression.
explicit ExpressionRef(double compile_time_constant);
- // Returns v_id
- std::string ToString() const;
+ // Create an ASSIGNMENT expression from other to this.
+ //
+ // For example:
+ // a = b; // With a.id = 5 and b.id = 3
+ // will generate the following assignment:
+ // v_5 = v_3;
+ //
+ // If this (lhs) ExpressionRef is currently not pointing to a variable
+ // (id==invalid), then we can eliminate the assignment by just letting "this"
+ // point to the same variable as "other".
+ //
+ // Example:
+ // a = b; // With a.id = invalid and b.id = 3
+ // will generate NO expression, but after this line the following will be
+ // true:
+ // a.id == b.id == 3
+ //
+ // If 'other' is not pointing to a variable (id==invalid), we found an
+ // uninitialized assignment, which is handled as an error.
+ ExpressionRef& operator=(const ExpressionRef& other);
// Compound operators
ExpressionRef& operator+=(ExpressionRef x);
@@ -61,6 +79,8 @@
ExpressionRef& operator*=(ExpressionRef x);
ExpressionRef& operator/=(ExpressionRef x);
+ bool IsInitialized() const { return id != kInvalidExpressionId; }
+
// The index into the ExpressionGraph data array.
ExpressionId id = kInvalidExpressionId;
@@ -152,6 +172,22 @@
#define CERES_EXPRESSION_RUNTIME_CONSTANT(_v) \
ceres::internal::MakeRuntimeConstant<T>(_v, #_v)
+
+// The CERES_CODEGEN macro is defined by the build system only during code
+// generation. In all other cases the CERES_IF/ELSE macros just expand to the
+// if/else keywords.
+#ifdef CERES_CODEGEN
+#define CERES_IF(condition_) Expression::CreateIf((condition_).id);
+#define CERES_ELSE Expression::CreateElse();
+#define CERES_ENDIF Expression::CreateEndIf();
+#else
+// clang-format off
+#define CERES_IF(condition_) if (condition_) {
+#define CERES_ELSE } else {
+#define CERES_ENDIF }
+// clang-format on
+#endif
+
} // namespace internal
// See jet.h for more info on this type.
diff --git a/internal/ceres/CMakeLists.txt b/internal/ceres/CMakeLists.txt
index 63e8540..15770ad 100644
--- a/internal/ceres/CMakeLists.txt
+++ b/internal/ceres/CMakeLists.txt
@@ -446,6 +446,7 @@
ceres_test(evaluation_callback)
ceres_test(evaluator)
ceres_test(expression)
+ ceres_test(conditional_expressions)
ceres_test(expression_graph)
ceres_test(fixed_array)
ceres_test(gradient_checker)
diff --git a/internal/ceres/conditional_expressions_test.cc b/internal/ceres/conditional_expressions_test.cc
new file mode 100644
index 0000000..72fe3ee
--- /dev/null
+++ b/internal/ceres/conditional_expressions_test.cc
@@ -0,0 +1,189 @@
+// Ceres Solver - A fast non-linear least squares minimizer
+// Copyright 2019 Google Inc. All rights reserved.
+// http://code.google.com/p/ceres-solver/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// * Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// * Neither the name of Google Inc. nor the names of its contributors may be
+// used to endorse or promote products derived from this software without
+// specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//
+// Author: darius.rueckert@fau.de (Darius Rueckert)
+//
+
+#define CERES_CODEGEN
+
+#include "expression_test.h"
+
+namespace ceres {
+namespace internal {
+
+TEST(Expression, AssignmentElimination) {
+ using T = ExpressionRef;
+
+ StartRecordingExpressions();
+ T a(2);
+ T b;
+ b = a;
+ auto graph = StopRecordingExpressions();
+
+ // b is invalid during the assignment so we expect no expression to be
+ // generated. The only expression in the graph should be the constant
+ // assignment to a.
+ EXPECT_EQ(graph.Size(), 1);
+
+ // Expected code
+ // v_0 = 2;
+
+ // clang-format off
+ // Id Type Lhs Value Name Arguments
+ TE(0, COMPILE_TIME_CONSTANT, 0, 2, "");
+ // clang-format on
+
+ // Variables after execution:
+ //
+ // a <=> v_0
+ // b <=> v_0
+ EXPECT_EQ(a.id, 0);
+ EXPECT_EQ(b.id, 0);
+}
+
+TEST(Expression, Assignment) {
+ using T = ExpressionRef;
+
+ StartRecordingExpressions();
+ T a(2);
+ T b(4);
+ b = a;
+ auto graph = StopRecordingExpressions();
+
+ // b is valid during the assignment so we expect an
+ // additional assignment expression.
+ EXPECT_EQ(graph.Size(), 3);
+
+ // Expected code
+ // v_0 = 2;
+ // v_1 = 4;
+ // v_1 = v_0;
+
+ // clang-format off
+ // Id, Type, Lhs, Value, Name, Arguments
+ TE( 0, COMPILE_TIME_CONSTANT, 0, 2, "", );
+ TE( 1, COMPILE_TIME_CONSTANT, 1, 4, "", );
+ TE( 2, ASSIGNMENT, 1, 0, "", 0);
+ // clang-format on
+
+ // Variables after execution:
+ //
+ // a <=> v_0
+ // b <=> v_1
+ EXPECT_EQ(a.id, 0);
+ EXPECT_EQ(b.id, 1);
+}
+
+TEST(Expression, ConditionalMinimal) {
+ using T = ExpressionRef;
+
+ StartRecordingExpressions();
+ T a(2);
+ T b(3);
+ auto c = a < b;
+ CERES_IF(c) {}
+ CERES_ELSE {}
+ CERES_ENDIF
+ auto graph = StopRecordingExpressions();
+
+ // Expected code
+ // v_0 = 2;
+ // v_1 = 3;
+ // v_2 = v_0 < v_1;
+ // if(v_2);
+ // else
+ // endif
+
+ EXPECT_EQ(graph.Size(), 6);
+
+ // clang-format off
+ // Id, Type, Lhs, Value, Name, Arguments...
+ TE( 0, COMPILE_TIME_CONSTANT, 0, 2, "", );
+ TE( 1, COMPILE_TIME_CONSTANT, 1, 3, "", );
+ TE( 2, BINARY_COMPARISON, 2, 0, "<", 0, 1);
+ TE( 3, IF, -1, 0, "", 2);
+ TE( 4, ELSE, -1, 0, "", );
+ TE( 5, ENDIF, -1, 0, "", );
+ // clang-format on
+}
+
+TEST(Expression, ConditionalAssignment) {
+ using T = ExpressionRef;
+
+ StartRecordingExpressions();
+
+ T result;
+ T a(2);
+ T b(3);
+ auto c = a < b;
+ CERES_IF(c) { result = a + b; }
+ CERES_ELSE { result = a - b; }
+ CERES_ENDIF
+ result += a;
+ auto graph = StopRecordingExpressions();
+
+ // Expected code
+ // v_0 = 2;
+ // v_1 = 3;
+ // v_2 = v_0 < v_1;
+ // if(v_2);
+ // v_4 = v_0 + v_1;
+ // else
+ // v_6 = v_0 - v_1;
+ // v_4 = v_6
+ // endif
+ // v_9 = v_4 + v_0;
+ // v_4 = v_9;
+
+ // clang-format off
+ // Id, Type, Lhs, Value, Name, Arguments...
+ TE( 0, COMPILE_TIME_CONSTANT, 0, 2, "", );
+ TE( 1, COMPILE_TIME_CONSTANT, 1, 3, "", );
+ TE( 2, BINARY_COMPARISON, 2, 0, "<", 0, 1);
+ TE( 3, IF, -1, 0, "", 2);
+ TE( 4, BINARY_ARITHMETIC, 4, 0, "+", 0, 1);
+ TE( 5, ELSE, -1, 0, "", );
+ TE( 6, BINARY_ARITHMETIC, 6, 0, "-", 0, 1);
+ TE( 7, ASSIGNMENT, 4, 0, "", 6 );
+ TE( 8, ENDIF, -1, 0, "", );
+ TE( 9, BINARY_ARITHMETIC, 9, 0, "+", 4, 0);
+ TE( 10, ASSIGNMENT, 4, 0, "", 9 );
+ // clang-format on
+
+ // Variables after execution:
+ //
+ // a <=> v_0
+ // b <=> v_1
+ // result <=> v_4
+ EXPECT_EQ(a.id, 0);
+ EXPECT_EQ(b.id, 1);
+ EXPECT_EQ(result.id, 4);
+}
+
+} // namespace internal
+} // namespace ceres
diff --git a/internal/ceres/expression.cc b/internal/ceres/expression.cc
index 3edcc7d..4c8dd5c 100644
--- a/internal/ceres/expression.cc
+++ b/internal/ceres/expression.cc
@@ -37,112 +37,127 @@
namespace ceres {
namespace internal {
-static Expression& MakeExpression(ExpressionType type) {
+// Wrapper for ExpressionGraph::CreateArithmeticExpression, which checks if a
+// graph is currently active. See that function for an explanation.
+static Expression& MakeArithmeticExpression(
+ ExpressionType type, ExpressionId lhs_id = kInvalidExpressionId) {
auto pool = GetCurrentExpressionGraph();
CHECK(pool)
<< "The ExpressionGraph has to be created before using Expressions. This "
"is achieved by calling ceres::StartRecordingExpressions.";
- return pool->CreateExpression(type);
+ return pool->CreateArithmeticExpression(type, lhs_id);
+}
+
+// Wrapper for ExpressionGraph::CreateControlExpression.
+static Expression& MakeControlExpression(ExpressionType type) {
+ auto pool = GetCurrentExpressionGraph();
+ CHECK(pool)
+ << "The ExpressionGraph has to be created before using Expressions. This "
+ "is achieved by calling ceres::StartRecordingExpressions.";
+ return pool->CreateControlExpression(type);
}
ExpressionId Expression::CreateCompileTimeConstant(double v) {
- auto& expr = MakeExpression(ExpressionType::COMPILE_TIME_CONSTANT);
+ auto& expr = MakeArithmeticExpression(ExpressionType::COMPILE_TIME_CONSTANT);
expr.value_ = v;
- return expr.id_;
+ return expr.lhs_id_;
}
ExpressionId Expression::CreateRuntimeConstant(const std::string& name) {
- auto& expr = MakeExpression(ExpressionType::RUNTIME_CONSTANT);
+ auto& expr = MakeArithmeticExpression(ExpressionType::RUNTIME_CONSTANT);
expr.name_ = name;
- return expr.id_;
+ return expr.lhs_id_;
}
ExpressionId Expression::CreateParameter(const std::string& name) {
- auto& expr = MakeExpression(ExpressionType::PARAMETER);
+ auto& expr = MakeArithmeticExpression(ExpressionType::PARAMETER);
expr.name_ = name;
- return expr.id_;
+ return expr.lhs_id_;
}
-ExpressionId Expression::CreateAssignment(ExpressionId v) {
- auto& expr = MakeExpression(ExpressionType::ASSIGNMENT);
- expr.arguments_.push_back(v);
- return expr.id_;
+ExpressionId Expression::CreateAssignment(ExpressionId dst, ExpressionId src) {
+ auto& expr = MakeArithmeticExpression(ExpressionType::ASSIGNMENT, dst);
+
+ expr.arguments_.push_back(src);
+ return expr.lhs_id_;
}
-ExpressionId Expression::CreateUnaryArithmetic(ExpressionType type,
+ExpressionId Expression::CreateUnaryArithmetic(const std::string& op,
ExpressionId v) {
- auto& expr = MakeExpression(type);
+ auto& expr = MakeArithmeticExpression(ExpressionType::UNARY_ARITHMETIC);
+ expr.name_ = op;
expr.arguments_.push_back(v);
- return expr.id_;
+ return expr.lhs_id_;
}
ExpressionId Expression::CreateOutputAssignment(ExpressionId v,
const std::string& name) {
- auto& expr = MakeExpression(ExpressionType::OUTPUT_ASSIGNMENT);
+ auto& expr = MakeArithmeticExpression(ExpressionType::OUTPUT_ASSIGNMENT);
expr.arguments_.push_back(v);
expr.name_ = name;
- return expr.id_;
+ return expr.lhs_id_;
}
ExpressionId Expression::CreateFunctionCall(
const std::string& name, const std::vector<ExpressionId>& params) {
- auto& expr = MakeExpression(ExpressionType::FUNCTION_CALL);
+ auto& expr = MakeArithmeticExpression(ExpressionType::FUNCTION_CALL);
expr.arguments_ = params;
expr.name_ = name;
- return expr.id_;
+ return expr.lhs_id_;
}
ExpressionId Expression::CreateTernary(ExpressionId condition,
ExpressionId if_true,
ExpressionId if_false) {
- auto& expr = MakeExpression(ExpressionType::TERNARY);
+ auto& expr = MakeArithmeticExpression(ExpressionType::TERNARY);
expr.arguments_.push_back(condition);
expr.arguments_.push_back(if_true);
expr.arguments_.push_back(if_false);
- return expr.id_;
+ return expr.lhs_id_;
}
ExpressionId Expression::CreateBinaryCompare(const std::string& name,
ExpressionId l,
ExpressionId r) {
- auto& expr = MakeExpression(ExpressionType::BINARY_COMPARISON);
+ auto& expr = MakeArithmeticExpression(ExpressionType::BINARY_COMPARISON);
expr.arguments_.push_back(l);
expr.arguments_.push_back(r);
expr.name_ = name;
- return expr.id_;
+ return expr.lhs_id_;
}
ExpressionId Expression::CreateLogicalNegation(ExpressionId v) {
- auto& expr = MakeExpression(ExpressionType::LOGICAL_NEGATION);
+ auto& expr = MakeArithmeticExpression(ExpressionType::LOGICAL_NEGATION);
expr.arguments_.push_back(v);
- return expr.id_;
+ return expr.lhs_id_;
}
-ExpressionId Expression::CreateBinaryArithmetic(ExpressionType type,
+ExpressionId Expression::CreateBinaryArithmetic(const std::string& op,
ExpressionId l,
ExpressionId r) {
- auto& expr = MakeExpression(type);
+ auto& expr = MakeArithmeticExpression(ExpressionType::BINARY_ARITHMETIC);
+ expr.name_ = op;
expr.arguments_.push_back(l);
expr.arguments_.push_back(r);
- return expr.id_;
+ return expr.lhs_id_;
}
-Expression::Expression(ExpressionType type, ExpressionId id)
- : type_(type), id_(id) {}
-bool Expression::IsArithmetic() const {
- switch (type_) {
- case ExpressionType::PLUS:
- case ExpressionType::MULTIPLICATION:
- case ExpressionType::DIVISION:
- case ExpressionType::MINUS:
- case ExpressionType::UNARY_MINUS:
- case ExpressionType::UNARY_PLUS:
- return true;
- default:
- return false;
- }
+void Expression::CreateIf(ExpressionId condition) {
+ auto& expr = MakeControlExpression(ExpressionType::IF);
+ expr.arguments_.push_back(condition);
}
+void Expression::CreateElse() { MakeControlExpression(ExpressionType::ELSE); }
+
+void Expression::CreateEndIf() { MakeControlExpression(ExpressionType::ENDIF); }
+
+Expression::Expression(ExpressionType type, ExpressionId id)
+ : type_(type), lhs_id_(id) {}
+
+bool Expression::IsArithmeticExpression() const { return HasValidLhs(); }
+
+bool Expression::IsControlExpression() const { return !HasValidLhs(); }
+
bool Expression::IsReplaceableBy(const Expression& other) const {
// Check everything except the id.
return (type_ == other.type_ && name_ == other.name_ &&
@@ -150,7 +165,7 @@
}
void Expression::Replace(const Expression& other) {
- if (other.id_ == id_) {
+ if (other.lhs_id_ == lhs_id_) {
return;
}
diff --git a/internal/ceres/expression_graph.cc b/internal/ceres/expression_graph.cc
index 0757b97..5c2b84e 100644
--- a/internal/ceres/expression_graph.cc
+++ b/internal/ceres/expression_graph.cc
@@ -55,9 +55,23 @@
ExpressionGraph* GetCurrentExpressionGraph() { return expression_pool; }
-Expression& ExpressionGraph::CreateExpression(ExpressionType type) {
- auto id = expressions_.size();
- Expression expr(type, id);
+Expression& ExpressionGraph::CreateArithmeticExpression(ExpressionType type,
+ ExpressionId lhs_id) {
+ if (lhs_id == kInvalidExpressionId) {
+ // We are creating a new temporary variable.
+ // -> The new lhs_id is the index into the graph
+ lhs_id = static_cast<ExpressionId>(expressions_.size());
+ } else {
+ // The left hand side already exists.
+ }
+
+ Expression expr(type, lhs_id);
+ expressions_.push_back(expr);
+ return expressions_.back();
+}
+
+Expression& ExpressionGraph::CreateControlExpression(ExpressionType type) {
+ Expression expr(type, kInvalidExpressionId);
expressions_.push_back(expr);
return expressions_.back();
}
diff --git a/internal/ceres/expression_ref.cc b/internal/ceres/expression_ref.cc
index 75a7723..224a3bf 100644
--- a/internal/ceres/expression_ref.cc
+++ b/internal/ceres/expression_ref.cc
@@ -29,8 +29,7 @@
// Author: darius.rueckert@fau.de (Darius Rueckert)
#include "ceres/internal/expression_ref.h"
-#include "assert.h"
-#include "ceres/internal/expression.h"
+#include "glog/logging.h"
namespace ceres {
namespace internal {
@@ -41,11 +40,24 @@
return ref;
}
-std::string ExpressionRef::ToString() const { return std::to_string(id); }
-
ExpressionRef::ExpressionRef(double compile_time_constant) {
- (*this) = ExpressionRef::Create(
- Expression::CreateCompileTimeConstant(compile_time_constant));
+ id = Expression::CreateCompileTimeConstant(compile_time_constant);
+}
+
+ExpressionRef& ExpressionRef::operator=(const ExpressionRef& other) {
+ // Assigning an uninitialized variable to another variable is an error.
+ CHECK(other.IsInitialized()) << "Uninitialized Assignment.";
+
+ if (IsInitialized()) {
+ // Create assignment from other -> this
+ Expression::CreateAssignment(id, other.id);
+ } else {
+ // Special case: "this" expressionref is invalid
+ // -> Skip assignment
+ // -> Let "this" point to the same variable as other
+ id = other.id;
+ }
+ return *this;
}
// Compound operators
@@ -71,33 +83,31 @@
// Arith. Operators
ExpressionRef operator-(ExpressionRef x) {
- return ExpressionRef::Create(
- Expression::CreateUnaryArithmetic(ExpressionType::UNARY_MINUS, x.id));
+ return ExpressionRef::Create(Expression::CreateUnaryArithmetic("-", x.id));
}
ExpressionRef operator+(ExpressionRef x) {
- return ExpressionRef::Create(
- Expression::CreateUnaryArithmetic(ExpressionType::UNARY_PLUS, x.id));
+ return ExpressionRef::Create(Expression::CreateUnaryArithmetic("+", x.id));
}
ExpressionRef operator+(ExpressionRef x, ExpressionRef y) {
return ExpressionRef::Create(
- Expression::CreateBinaryArithmetic(ExpressionType::PLUS, x.id, y.id));
+ Expression::CreateBinaryArithmetic("+", x.id, y.id));
}
ExpressionRef operator-(ExpressionRef x, ExpressionRef y) {
return ExpressionRef::Create(
- Expression::CreateBinaryArithmetic(ExpressionType::MINUS, x.id, y.id));
+ Expression::CreateBinaryArithmetic("-", x.id, y.id));
}
ExpressionRef operator/(ExpressionRef x, ExpressionRef y) {
return ExpressionRef::Create(
- Expression::CreateBinaryArithmetic(ExpressionType::DIVISION, x.id, y.id));
+ Expression::CreateBinaryArithmetic("/", x.id, y.id));
}
ExpressionRef operator*(ExpressionRef x, ExpressionRef y) {
- return ExpressionRef::Create(Expression::CreateBinaryArithmetic(
- ExpressionType::MULTIPLICATION, x.id, y.id));
+ return ExpressionRef::Create(
+ Expression::CreateBinaryArithmetic("*", x.id, y.id));
}
// Functions
diff --git a/internal/ceres/expression_test.cc b/internal/ceres/expression_test.cc
index 3e683c1..704a464 100644
--- a/internal/ceres/expression_test.cc
+++ b/internal/ceres/expression_test.cc
@@ -29,6 +29,8 @@
// Author: darius.rueckert@fau.de (Darius Rueckert)
//
+#define CERES_CODEGEN
+
#include "ceres/internal/expression_graph.h"
#include "ceres/internal/expression_ref.h"
@@ -48,10 +50,10 @@
auto graph = StopRecordingExpressions();
- ASSERT_FALSE(graph.ExpressionForId(a.id).IsArithmetic());
- ASSERT_FALSE(graph.ExpressionForId(b.id).IsArithmetic());
- ASSERT_TRUE(graph.ExpressionForId(c.id).IsArithmetic());
- ASSERT_TRUE(graph.ExpressionForId(d.id).IsArithmetic());
+ ASSERT_FALSE(graph.ExpressionForId(a.id).IsArithmeticExpression());
+ ASSERT_FALSE(graph.ExpressionForId(b.id).IsArithmeticExpression());
+ ASSERT_TRUE(graph.ExpressionForId(c.id).IsArithmeticExpression());
+ ASSERT_TRUE(graph.ExpressionForId(d.id).IsArithmeticExpression());
}
TEST(Expression, IsCompileTimeConstantAndEqualTo) {
diff --git a/internal/ceres/expression_test.h b/internal/ceres/expression_test.h
new file mode 100644
index 0000000..de9e2c2
--- /dev/null
+++ b/internal/ceres/expression_test.h
@@ -0,0 +1,69 @@
+// Ceres Solver - A fast non-linear least squares minimizer
+// Copyright 2019 Google Inc. All rights reserved.
+// http://code.google.com/p/ceres-solver/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// * Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// * Neither the name of Google Inc. nor the names of its contributors may be
+// used to endorse or promote products derived from this software without
+// specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//
+// Author: darius.rueckert@fau.de (Darius Rueckert)
+//
+
+#ifndef CERES_PUBLIC_EXPRESSION_TEST_H_
+#define CERES_PUBLIC_EXPRESSION_TEST_H_
+
+#include "ceres/internal/expression_graph.h"
+#include "ceres/internal/expression_ref.h"
+
+#include "gtest/gtest.h"
+
+// This file adds a few helper functions to test Expressions and
+// ExpressionGraphs for correctness.
+namespace ceres {
+namespace internal {
+
+inline void TestExpression(const Expression& expr,
+ ExpressionType type,
+ ExpressionId lhs_id,
+ double value,
+ const std::string& name,
+ const std::vector<ExpressionId>& arguments) {
+ EXPECT_EQ(static_cast<int>(expr.type()), static_cast<int>(type));
+ EXPECT_EQ(expr.lhs_id(), lhs_id);
+ EXPECT_EQ(expr.value(), value);
+ EXPECT_EQ(expr.name(), name);
+ EXPECT_EQ(expr.arguments(), arguments);
+}
+
+#define TE(_id, _type, _lhs_id, _value, _name, ...) \
+ TestExpression(graph.ExpressionForId(_id), \
+ ExpressionType::_type, \
+ _lhs_id, \
+ _value, \
+ _name, \
+ {__VA_ARGS__})
+
+} // namespace internal
+} // namespace ceres
+
+#endif