Autodiff Codegen Part 1: Expressions This patch adds the 'Expression' class, which is a fundamental building block of automatic code generation. The expressions can be used as scalar types for cost functors as well as Jets. Dynamic branching is not yet supported. Change-Id: I8c61bee5c307e0eec20fd39382683ea90f720dff
diff --git a/include/ceres/internal/expression.h b/include/ceres/internal/expression.h new file mode 100644 index 0000000..c990d93 --- /dev/null +++ b/include/ceres/internal/expression.h
@@ -0,0 +1,203 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2019 Google Inc. All rights reserved. +// http://code.google.com/p/ceres-solver/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Author: darius.rueckert@fau.de (Darius Rueckert) +// +// +// This file contains the basic expression type, which is used during code +// generation. Only assignment expressions of the following form are supported: +// +// result = [constant|binary_expr|functioncall] +// +// Examples: +// v_78 = v_28 / v_62; +// v_97 = exp(v_20); +// v_89 = 3.000000; +// +// +#ifndef CERES_PUBLIC_EXPRESSION_H_ +#define CERES_PUBLIC_EXPRESSION_H_ + +#include <string> +#include <vector> + +namespace ceres { +namespace internal { + +using ExpressionId = int; +static constexpr ExpressionId kInvalidExpressionId = -1; + +enum class ExpressionType { + // v_0 = 3.1415; + COMPILE_TIME_CONSTANT, + + // For example a local member of the cost-functor. + // v_0 = _observed_point_x; + RUNTIME_CONSTANT, + + // Input parameter + // v_0 = parameters[1][5]; + PARAMETER, + + // Output Variable Assignemnt + // residual[0] = v_51; + OUTPUT_ASSIGNMENT, + + // Trivial Assignment + // v_1 = v_0; + ASSIGNMENT, + + // Binary Arithmetic Operations + // v_2 = v_0 + v_1 + PLUS, + MINUS, + MULTIPLICATION, + DIVISION, + + // Unary Arithmetic Operation + // v_1 = -(v_0); + // v_2 = +(v_1); + UNARY_MINUS, + UNARY_PLUS, + + // Binary Comparison. (<,>,&&,...) + // This is the only expressions which returns a 'bool'. + // const bool v_2 = v_0 < v_1 + BINARY_COMPARISON, + + // The !-operator on logical expression. + LOGICAL_NEGATION, + + // General Function Call. + // v_5 = f(v_0,v_1,...) + FUNCTION_CALL, + + // The ternary ?-operator. Separated from the general function call for easier + // access. + // v_3 = ternary(v_0,v_1,v_2); + TERNARY, + + // No Operation. A placeholder for an 'empty' expressions which will be + // optimized out during code generation. + NOP +}; + +// This class contains all data that is required to generate one line of code. +// Each line has the following form: +// +// lhs = rhs; +// +// The left hand side is the variable name given by its own id. The right hand +// side depends on the ExpressionType. For example, a COMPILE_TIME_CONSTANT +// expressions with id 4 generates the following line: +// v_4 = 3.1415; +// +// Objects of this class are created indirectly using the static CreateXX +// methods. During creation, the Expression objects are added to the +// ExpressionGraph (see expression_graph.h). +class Expression { + public: + // These functions create the corresponding expression, add them to an + // internal vector and return a reference to them. + static ExpressionId CreateCompileTimeConstant(double v); + static ExpressionId CreateRuntimeConstant(const std::string& name); + static ExpressionId CreateParameter(const std::string& name); + static ExpressionId CreateOutputAssignment(ExpressionId v, + const std::string& name); + static ExpressionId CreateAssignment(ExpressionId v); + static ExpressionId CreateBinaryArithmetic(ExpressionType type, + ExpressionId l, + ExpressionId r); + static ExpressionId CreateUnaryArithmetic(ExpressionType type, + ExpressionId v); + static ExpressionId CreateBinaryCompare(const std::string& name, + ExpressionId l, + ExpressionId r); + static ExpressionId CreateLogicalNegation(ExpressionId v); + static ExpressionId CreateFunctionCall( + const std::string& name, const std::vector<ExpressionId>& params); + static ExpressionId CreateTernary(ExpressionId condition, + ExpressionId if_true, + ExpressionId if_false); + + // Returns true if the expression type is one of the basic math-operators: + // +,-,*,/ + bool IsArithmetic() const; + + // If this expression is the compile time constant with the given value. + // Used during optimization to collapse zero/one arithmetic operations. + // b = a + 0; -> b = a; + bool IsCompileTimeConstantAndEqualTo(double constant) const; + + // Checks if "other" is identical to "this" so that one of the epxressions can + // be replaced by a trivial assignment. Used during common subexpression + // elimination. + bool IsReplaceableBy(const Expression& other) const; + + // Replace this expression by 'other'. + // The current id will be not replaced. That means other experssions + // referencing this one stay valid. + void Replace(const Expression& other); + + // If this expression has 'other' as an argument. + bool DirectlyDependsOn(ExpressionId other) const; + + // Converts this expression into a NOP + void MakeNop(); + + private: + // Only ExpressionGraph is allowed to call the constructor, because it manages + // the memory and ids. + friend class ExpressionGraph; + + // Private constructor. Use the "CreateXX" functions instead. + Expression(ExpressionType type, ExpressionId id); + + ExpressionType type_ = ExpressionType::NOP; + const ExpressionId id_ = kInvalidExpressionId; + + // Expressions have different number of arguments. For example a binary "+" + // has 2 parameters and a function call to "sin" has 1 parameter. Here, a + // reference to these paratmers is stored. Note: The order matters! + std::vector<ExpressionId> arguments_; + + // Depending on the type this name is one of the following: + // (type == FUNCTION_CALL) -> the function name + // (type == PARAMETER) -> the parameter name + // (type == OUTPUT_ASSIGN) -> the output variable name + // (type == BINARY_COMPARE)-> the comparison symbol "<","&&",... + // else -> unused + std::string name_; + + // Only valid if type == COMPILE_TIME_CONSTANT + double value_ = 0; +}; + +} // namespace internal +} // namespace ceres +#endif
diff --git a/include/ceres/internal/expression_graph.h b/include/ceres/internal/expression_graph.h new file mode 100644 index 0000000..446fddb --- /dev/null +++ b/include/ceres/internal/expression_graph.h
@@ -0,0 +1,92 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2019 Google Inc. All rights reserved. +// http://code.google.com/p/ceres-solver/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Author: darius.rueckert@fau.de (Darius Rueckert) + +#ifndef CERES_PUBLIC_EXPRESSION_TREE_H_ +#define CERES_PUBLIC_EXPRESSION_TREE_H_ + +#include <vector> + +#include "expression.h" + +namespace ceres { +namespace internal { + +// A directed, acyclic, unconnected graph containing all expressions of a +// program. +// +// The expression graph is stored linear in the expressions_ array. The order is +// identical to the execution order. Each expression can have multiple children +// and multiple parents. +// A is child of B <=> B has A as a parameter <=> B.DirectlyDependsOn(A) +// A is parent of B <=> A has B as a parameter <=> A.DirectlyDependsOn(B) +class ExpressionGraph { + public: + // Creates an expression and adds it to expressions_. + // The returned reference will be invalid after this function is called again. + Expression& CreateExpression(ExpressionType type); + + // Checks if A depends on B. + // -> B is a descendant of A + bool DependsOn(ExpressionId A, ExpressionId B) const; + + Expression& ExpressionForId(ExpressionId id) { return expressions_[id]; } + const Expression& ExpressionForId(ExpressionId id) const { + return expressions_[id]; + } + + int Size() const { return expressions_.size(); } + + private: + // All Expressions are referenced by an ExpressionId. The ExpressionId is the + // index into this array. Each expression has a list of ExpressionId as + // arguments. These references form the graph. + std::vector<Expression> expressions_; +}; + +// After calling this method, all operations on 'ExpressionRef' objects will be +// recorded into an ExpressionGraph. You can obtain this graph by calling +// StopRecordingExpressions. +// +// Performing expression operations before calling StartRecordingExpressions or +// calling StartRecodring. twice is an error. +void StartRecordingExpressions(); + +// Stops recording and returns all expressions that have been executed since the +// call to StartRecordingExpressions. The internal ExpressionGraph will be +// invalidated and a second consecutive call to this method results in an error. +ExpressionGraph StopRecordingExpressions(); + +// Returns a pointer to the active expression tree. +// Normal users should not use this functions. +ExpressionGraph* GetCurrentExpressionGraph(); + +} // namespace internal +} // namespace ceres +#endif
diff --git a/include/ceres/internal/expression_ref.h b/include/ceres/internal/expression_ref.h new file mode 100644 index 0000000..67ff227 --- /dev/null +++ b/include/ceres/internal/expression_ref.h
@@ -0,0 +1,164 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2019 Google Inc. All rights reserved. +// http://code.google.com/p/ceres-solver/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Author: darius.rueckert@fau.de (Darius Rueckert) +// +// TODO: Documentation +#ifndef CERES_PUBLIC_EXPRESSION_REF_H_ +#define CERES_PUBLIC_EXPRESSION_REF_H_ + +#include <string> +#include "ceres/jet.h" +#include "expression.h" + +namespace ceres { +namespace internal { + +// This class represents a scalar value that creates new expressions during +// evaluation. ExpressionRef can be used as template parameter for cost functors +// and Jets. +// +// ExpressionRef should be passed by value. +struct ExpressionRef { + ExpressionRef() = default; + + // Create a compile time constant expression directly from a double value. + // This is important so that we can write T(3.14) in our code and + // it's automatically converted to the correct expression. + explicit ExpressionRef(double compile_time_constant); + + // Returns v_id + std::string ToString() const; + + // Compound operators + ExpressionRef& operator+=(ExpressionRef x); + ExpressionRef& operator-=(ExpressionRef x); + ExpressionRef& operator*=(ExpressionRef x); + ExpressionRef& operator/=(ExpressionRef x); + + // The index into the ExpressionGraph data array. + ExpressionId id = kInvalidExpressionId; + + static ExpressionRef Create(ExpressionId id); +}; + +// Arithmetic Operators +ExpressionRef operator-(ExpressionRef x); +ExpressionRef operator+(ExpressionRef x); +ExpressionRef operator+(ExpressionRef x, ExpressionRef y); +ExpressionRef operator-(ExpressionRef x, ExpressionRef y); +ExpressionRef operator*(ExpressionRef x, ExpressionRef y); +ExpressionRef operator/(ExpressionRef x, ExpressionRef y); + +// Functions +// TODO: Add all function supported by Jet. +ExpressionRef sin(ExpressionRef x); + +// This additonal type is required, so that we can detect invalid conditions +// during compile time. For example, the following should create a compile time +// error: +// +// ExpressionRef a(5); +// CERES_IF(a){ // Error: Invalid conversion +// ... +// +// Following will work: +// +// ExpressionRef a(5), b(7); +// ComparisonExpressionRef c = a < b; +// CERES_IF(c){ +// ... +struct ComparisonExpressionRef { + ExpressionId id; + explicit ComparisonExpressionRef(ExpressionRef ref) : id(ref.id) {} +}; + +ExpressionRef Ternary(ComparisonExpressionRef c, + ExpressionRef a, + ExpressionRef b); + +// Comparison operators +ComparisonExpressionRef operator<(ExpressionRef a, ExpressionRef b); +ComparisonExpressionRef operator<=(ExpressionRef a, ExpressionRef b); +ComparisonExpressionRef operator>(ExpressionRef a, ExpressionRef b); +ComparisonExpressionRef operator>=(ExpressionRef a, ExpressionRef b); +ComparisonExpressionRef operator==(ExpressionRef a, ExpressionRef b); +ComparisonExpressionRef operator!=(ExpressionRef a, ExpressionRef b); + +// Logical Operators +ComparisonExpressionRef operator&&(ComparisonExpressionRef a, + ComparisonExpressionRef b); +ComparisonExpressionRef operator||(ComparisonExpressionRef a, + ComparisonExpressionRef b); +ComparisonExpressionRef operator!(ComparisonExpressionRef a); + +// This struct is used to mark numbers which are constant over +// multiple invocations but can differ between instances. +template <typename T> +struct RuntimeConstant { + using ReturnType = T; + static inline ReturnType Get(double v, const char* name) { return v; } +}; + +template <typename G, int N> +struct RuntimeConstant<Jet<G, N>> { + using ReturnType = Jet<G, N>; + static inline Jet<G, N> Get(double v, const char* name) { + return Jet<G, N>(v); + } +}; + +template <int N> +struct RuntimeConstant<Jet<ExpressionRef, N>> { + using ReturnType = Jet<ExpressionRef, N>; + static inline ReturnType Get(double v, const char* name) { + // Note: The scalar value of v will be thrown away, because we don't need it + // during code generation. + (void)v; + return Jet<ExpressionRef, N>(Expression::CreateRuntimeConstant(name)); + } +}; + +template <typename T> +inline typename RuntimeConstant<T>::ReturnType MakeRuntimeConstant( + double v, const char* name) { + return RuntimeConstant<T>::Get(v, name); +} + +#define CERES_EXPRESSION_RUNTIME_CONSTANT(_v) \ + ceres::internal::MakeRuntimeConstant<T>(_v, #_v) +} // namespace internal + +// See jet.h for more info on this type. +template <> +struct ComparisonReturnType<internal::ExpressionRef> { + using type = internal::ComparisonExpressionRef; +}; + +} // namespace ceres +#endif
diff --git a/include/ceres/jet.h b/include/ceres/jet.h index 8d83563..25d11e6 100644 --- a/include/ceres/jet.h +++ b/include/ceres/jet.h
@@ -168,6 +168,20 @@ namespace ceres { +// The return type of a Jet comparison, for example from <, &&, ==. +// +// In the context of traditional Ceres Jet operations, this would +// always be a bool. However, in the autodiff code generation context, +// the return is always an expression, and so a different type must be +// used as a return from comparisons. +// +// In the autodiff codegen context, this function is overloaded so that 'type' +// is one of the autodiff code generation expression types. +template <typename T> +struct ComparisonReturnType { + using type = bool; +}; + template <typename T, int N> struct Jet { enum { DIMENSION = N }; @@ -353,18 +367,21 @@ } // Binary comparison operators for both scalars and jets. -#define CERES_DEFINE_JET_COMPARISON_OPERATOR(op) \ - template <typename T, int N> \ - inline bool operator op(const Jet<T, N>& f, const Jet<T, N>& g) { \ - return f.a op g.a; \ - } \ - template <typename T, int N> \ - inline bool operator op(const T& s, const Jet<T, N>& g) { \ - return s op g.a; \ - } \ - template <typename T, int N> \ - inline bool operator op(const Jet<T, N>& f, const T& s) { \ - return f.a op s; \ +#define CERES_DEFINE_JET_COMPARISON_OPERATOR(op) \ + template <typename T, int N> \ + inline typename ComparisonReturnType<T>::type operator op( \ + const Jet<T, N>& f, const Jet<T, N>& g) { \ + return f.a op g.a; \ + } \ + template <typename T, int N> \ + inline typename ComparisonReturnType<T>::type operator op( \ + const T& s, const Jet<T, N>& g) { \ + return s op g.a; \ + } \ + template <typename T, int N> \ + inline typename ComparisonReturnType<T>::type operator op( \ + const Jet<T, N>& f, const T& s) { \ + return f.a op s; \ } CERES_DEFINE_JET_COMPARISON_OPERATOR(<) // NOLINT CERES_DEFINE_JET_COMPARISON_OPERATOR(<=) // NOLINT @@ -374,6 +391,26 @@ CERES_DEFINE_JET_COMPARISON_OPERATOR(!=) // NOLINT #undef CERES_DEFINE_JET_COMPARISON_OPERATOR +// A function equivalent to the ternary ?-operator. +// This function is required, because in the context of code generation a +// comparison returns an expression type which is not convertible to bool. +template <typename T> +inline T Ternary(bool c, T a, T b) { + return c ? a : b; +} + +template <typename T, int N> +inline Jet<T, N> Ternary(typename ComparisonReturnType<T>::type c, + const Jet<T, N>& f, + const Jet<T, N>& g) { + Jet<T, N> r; + r.a = Ternary(c, f.a, g.a); + for (int i = 0; i < N; ++i) { + r.v[i] = Ternary(c, f.v[i], g.v[i]); + } + return r; +} + // Pull some functions from namespace std. // // This is necessary because we want to use the same name (e.g. 'sqrt') for
diff --git a/internal/ceres/CMakeLists.txt b/internal/ceres/CMakeLists.txt index e452f48..01c23ca 100644 --- a/internal/ceres/CMakeLists.txt +++ b/internal/ceres/CMakeLists.txt
@@ -82,6 +82,9 @@ dynamic_sparse_normal_cholesky_solver.cc evaluator.cc eigensparse.cc + expression.cc + expression_graph.cc + expression_ref.cc file.cc float_suitesparse.cc float_cxsparse.cc @@ -433,6 +436,8 @@ ceres_test(dynamic_sparsity) ceres_test(evaluation_callback) ceres_test(evaluator) + ceres_test(expression) + ceres_test(expression_graph) ceres_test(fixed_array) ceres_test(gradient_checker) ceres_test(gradient_checking_cost_function)
diff --git a/internal/ceres/expression.cc b/internal/ceres/expression.cc new file mode 100644 index 0000000..3edcc7d --- /dev/null +++ b/internal/ceres/expression.cc
@@ -0,0 +1,178 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2019 Google Inc. All rights reserved. +// http://code.google.com/p/ceres-solver/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Author: darius.rueckert@fau.de (Darius Rueckert) + +#include "ceres/internal/expression.h" +#include <algorithm> + +#include "ceres/internal/expression_graph.h" +#include "glog/logging.h" + +namespace ceres { +namespace internal { + +static Expression& MakeExpression(ExpressionType type) { + auto pool = GetCurrentExpressionGraph(); + CHECK(pool) + << "The ExpressionGraph has to be created before using Expressions. This " + "is achieved by calling ceres::StartRecordingExpressions."; + return pool->CreateExpression(type); +} + +ExpressionId Expression::CreateCompileTimeConstant(double v) { + auto& expr = MakeExpression(ExpressionType::COMPILE_TIME_CONSTANT); + expr.value_ = v; + return expr.id_; +} + +ExpressionId Expression::CreateRuntimeConstant(const std::string& name) { + auto& expr = MakeExpression(ExpressionType::RUNTIME_CONSTANT); + expr.name_ = name; + return expr.id_; +} + +ExpressionId Expression::CreateParameter(const std::string& name) { + auto& expr = MakeExpression(ExpressionType::PARAMETER); + expr.name_ = name; + return expr.id_; +} + +ExpressionId Expression::CreateAssignment(ExpressionId v) { + auto& expr = MakeExpression(ExpressionType::ASSIGNMENT); + expr.arguments_.push_back(v); + return expr.id_; +} + +ExpressionId Expression::CreateUnaryArithmetic(ExpressionType type, + ExpressionId v) { + auto& expr = MakeExpression(type); + expr.arguments_.push_back(v); + return expr.id_; +} + +ExpressionId Expression::CreateOutputAssignment(ExpressionId v, + const std::string& name) { + auto& expr = MakeExpression(ExpressionType::OUTPUT_ASSIGNMENT); + expr.arguments_.push_back(v); + expr.name_ = name; + return expr.id_; +} + +ExpressionId Expression::CreateFunctionCall( + const std::string& name, const std::vector<ExpressionId>& params) { + auto& expr = MakeExpression(ExpressionType::FUNCTION_CALL); + expr.arguments_ = params; + expr.name_ = name; + return expr.id_; +} + +ExpressionId Expression::CreateTernary(ExpressionId condition, + ExpressionId if_true, + ExpressionId if_false) { + auto& expr = MakeExpression(ExpressionType::TERNARY); + expr.arguments_.push_back(condition); + expr.arguments_.push_back(if_true); + expr.arguments_.push_back(if_false); + return expr.id_; +} + +ExpressionId Expression::CreateBinaryCompare(const std::string& name, + ExpressionId l, + ExpressionId r) { + auto& expr = MakeExpression(ExpressionType::BINARY_COMPARISON); + expr.arguments_.push_back(l); + expr.arguments_.push_back(r); + expr.name_ = name; + return expr.id_; +} + +ExpressionId Expression::CreateLogicalNegation(ExpressionId v) { + auto& expr = MakeExpression(ExpressionType::LOGICAL_NEGATION); + expr.arguments_.push_back(v); + return expr.id_; +} + +ExpressionId Expression::CreateBinaryArithmetic(ExpressionType type, + ExpressionId l, + ExpressionId r) { + auto& expr = MakeExpression(type); + expr.arguments_.push_back(l); + expr.arguments_.push_back(r); + return expr.id_; +} +Expression::Expression(ExpressionType type, ExpressionId id) + : type_(type), id_(id) {} + +bool Expression::IsArithmetic() const { + switch (type_) { + case ExpressionType::PLUS: + case ExpressionType::MULTIPLICATION: + case ExpressionType::DIVISION: + case ExpressionType::MINUS: + case ExpressionType::UNARY_MINUS: + case ExpressionType::UNARY_PLUS: + return true; + default: + return false; + } +} + +bool Expression::IsReplaceableBy(const Expression& other) const { + // Check everything except the id. + return (type_ == other.type_ && name_ == other.name_ && + value_ == other.value_ && arguments_ == other.arguments_); +} + +void Expression::Replace(const Expression& other) { + if (other.id_ == id_) { + return; + } + + type_ = other.type_; + arguments_ = other.arguments_; + name_ = other.name_; + value_ = other.value_; +} + +bool Expression::DirectlyDependsOn(ExpressionId other) const { + return (std::find(arguments_.begin(), arguments_.end(), other) != + arguments_.end()); +} + +bool Expression::IsCompileTimeConstantAndEqualTo(double constant) const { + return type_ == ExpressionType::COMPILE_TIME_CONSTANT && value_ == constant; +} + +void Expression::MakeNop() { + type_ = ExpressionType::NOP; + arguments_.clear(); +} + +} // namespace internal +} // namespace ceres
diff --git a/internal/ceres/expression_graph.cc b/internal/ceres/expression_graph.cc new file mode 100644 index 0000000..0757b97 --- /dev/null +++ b/internal/ceres/expression_graph.cc
@@ -0,0 +1,85 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2019 Google Inc. All rights reserved. +// http://code.google.com/p/ceres-solver/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Author: darius.rueckert@fau.de (Darius Rueckert) + +#include "ceres/internal/expression_graph.h" + +#include "glog/logging.h" +namespace ceres { +namespace internal { + +static ExpressionGraph* expression_pool = nullptr; + +void StartRecordingExpressions() { + CHECK(expression_pool == nullptr) + << "Expression recording must be stopped before calling " + "StartRecordingExpressions again."; + expression_pool = new ExpressionGraph; +} + +ExpressionGraph StopRecordingExpressions() { + CHECK(expression_pool) + << "Expression recording hasn't started yet or you tried " + "to stop it twice."; + ExpressionGraph result = std::move(*expression_pool); + delete expression_pool; + expression_pool = nullptr; + return result; +} + +ExpressionGraph* GetCurrentExpressionGraph() { return expression_pool; } + +Expression& ExpressionGraph::CreateExpression(ExpressionType type) { + auto id = expressions_.size(); + Expression expr(type, id); + expressions_.push_back(expr); + return expressions_.back(); +} + +bool ExpressionGraph::DependsOn(ExpressionId A, ExpressionId B) const { + // Depth first search on the expression graph + // Equivalent Recursive Implementation: + // if (A.DirectlyDependsOn(B)) return true; + // for (auto p : A.params_) { + // if (pool[p.id].DependsOn(B, pool)) return true; + // } + std::vector<ExpressionId> stack = ExpressionForId(A).arguments_; + while (!stack.empty()) { + auto top = stack.back(); + stack.pop_back(); + if (top == B) { + return true; + } + auto& expr = ExpressionForId(top); + stack.insert(stack.end(), expr.arguments_.begin(), expr.arguments_.end()); + } + return false; +} +} // namespace internal +} // namespace ceres
diff --git a/internal/ceres/expression_graph_test.cc b/internal/ceres/expression_graph_test.cc new file mode 100644 index 0000000..ee06f1a --- /dev/null +++ b/internal/ceres/expression_graph_test.cc
@@ -0,0 +1,77 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2019 Google Inc. All rights reserved. +// http://code.google.com/p/ceres-solver/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Author: darius.rueckert@fau.de (Darius Rueckert) +// +// Test expression creation and logic. + +#include "ceres/internal/expression_graph.h" +#include "ceres/internal/expression_ref.h" + +#include "gtest/gtest.h" + +namespace ceres { +namespace internal { + +TEST(ExpressionGraph, Creation) { + using T = ExpressionRef; + ExpressionGraph graph; + + StartRecordingExpressions(); + graph = StopRecordingExpressions(); + EXPECT_EQ(graph.Size(), 0); + + StartRecordingExpressions(); + T a(1); + T b(2); + T c = a + b; + graph = StopRecordingExpressions(); + EXPECT_EQ(graph.Size(), 3); +} + +TEST(ExpressionGraph, Dependencies) { + using T = ExpressionRef; + + StartRecordingExpressions(); + + T unused(6); + T a(2), b(3); + T c = a + b; + T d = c + a; + + auto tree = StopRecordingExpressions(); + + // Recursive dependency check + ASSERT_TRUE(tree.DependsOn(d.id, c.id)); + ASSERT_TRUE(tree.DependsOn(d.id, a.id)); + ASSERT_TRUE(tree.DependsOn(d.id, b.id)); + ASSERT_FALSE(tree.DependsOn(d.id, unused.id)); +} + +} // namespace internal +} // namespace ceres
diff --git a/internal/ceres/expression_ref.cc b/internal/ceres/expression_ref.cc new file mode 100644 index 0000000..75a7723 --- /dev/null +++ b/internal/ceres/expression_ref.cc
@@ -0,0 +1,144 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2019 Google Inc. All rights reserved. +// http://code.google.com/p/ceres-solver/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Author: darius.rueckert@fau.de (Darius Rueckert) + +#include "ceres/internal/expression_ref.h" +#include "assert.h" +#include "ceres/internal/expression.h" + +namespace ceres { +namespace internal { + +ExpressionRef ExpressionRef::Create(ExpressionId id) { + ExpressionRef ref; + ref.id = id; + return ref; +} + +std::string ExpressionRef::ToString() const { return std::to_string(id); } + +ExpressionRef::ExpressionRef(double compile_time_constant) { + (*this) = ExpressionRef::Create( + Expression::CreateCompileTimeConstant(compile_time_constant)); +} + +// Compound operators +ExpressionRef& ExpressionRef::operator+=(ExpressionRef y) { + *this = *this + y; + return *this; +} + +ExpressionRef& ExpressionRef::operator-=(ExpressionRef y) { + *this = *this - y; + return *this; +} + +ExpressionRef& ExpressionRef::operator*=(ExpressionRef y) { + *this = *this * y; + return *this; +} + +ExpressionRef& ExpressionRef::operator/=(ExpressionRef y) { + *this = *this / y; + return *this; +} + +// Arith. Operators +ExpressionRef operator-(ExpressionRef x) { + return ExpressionRef::Create( + Expression::CreateUnaryArithmetic(ExpressionType::UNARY_MINUS, x.id)); +} + +ExpressionRef operator+(ExpressionRef x) { + return ExpressionRef::Create( + Expression::CreateUnaryArithmetic(ExpressionType::UNARY_PLUS, x.id)); +} + +ExpressionRef operator+(ExpressionRef x, ExpressionRef y) { + return ExpressionRef::Create( + Expression::CreateBinaryArithmetic(ExpressionType::PLUS, x.id, y.id)); +} + +ExpressionRef operator-(ExpressionRef x, ExpressionRef y) { + return ExpressionRef::Create( + Expression::CreateBinaryArithmetic(ExpressionType::MINUS, x.id, y.id)); +} + +ExpressionRef operator/(ExpressionRef x, ExpressionRef y) { + return ExpressionRef::Create( + Expression::CreateBinaryArithmetic(ExpressionType::DIVISION, x.id, y.id)); +} + +ExpressionRef operator*(ExpressionRef x, ExpressionRef y) { + return ExpressionRef::Create(Expression::CreateBinaryArithmetic( + ExpressionType::MULTIPLICATION, x.id, y.id)); +} + +// Functions +ExpressionRef sin(ExpressionRef x) { + return ExpressionRef::Create(Expression::CreateFunctionCall("sin", {x.id})); +} + +ExpressionRef Ternary(ComparisonExpressionRef c, + ExpressionRef a, + ExpressionRef b) { + return ExpressionRef::Create(Expression::CreateTernary(c.id, a.id, b.id)); +} + +#define CERES_DEFINE_EXPRESSION_COMPARISON_OPERATOR(op) \ + ComparisonExpressionRef operator op(ExpressionRef a, ExpressionRef b) { \ + return ComparisonExpressionRef(ExpressionRef::Create( \ + Expression::CreateBinaryCompare(#op, a.id, b.id))); \ + } + +#define CERES_DEFINE_EXPRESSION_LOGICAL_OPERATOR(op) \ + ComparisonExpressionRef operator op(ComparisonExpressionRef a, \ + ComparisonExpressionRef b) { \ + return ComparisonExpressionRef(ExpressionRef::Create( \ + Expression::CreateBinaryCompare(#op, a.id, b.id))); \ + } + +CERES_DEFINE_EXPRESSION_COMPARISON_OPERATOR(<) +CERES_DEFINE_EXPRESSION_COMPARISON_OPERATOR(<=) +CERES_DEFINE_EXPRESSION_COMPARISON_OPERATOR(>) +CERES_DEFINE_EXPRESSION_COMPARISON_OPERATOR(>=) +CERES_DEFINE_EXPRESSION_COMPARISON_OPERATOR(==) +CERES_DEFINE_EXPRESSION_COMPARISON_OPERATOR(!=) +CERES_DEFINE_EXPRESSION_LOGICAL_OPERATOR(&&) +CERES_DEFINE_EXPRESSION_LOGICAL_OPERATOR(||) +#undef CERES_DEFINE_EXPRESSION_COMPARISON_OPERATOR +#undef CERES_DEFINE_EXPRESSION_LOGICAL_OPERATOR + +ComparisonExpressionRef operator!(ComparisonExpressionRef a) { + return ComparisonExpressionRef( + ExpressionRef::Create(Expression::CreateLogicalNegation(a.id))); +} + +} // namespace internal +} // namespace ceres
diff --git a/internal/ceres/expression_test.cc b/internal/ceres/expression_test.cc new file mode 100644 index 0000000..3e683c1 --- /dev/null +++ b/internal/ceres/expression_test.cc
@@ -0,0 +1,119 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2019 Google Inc. All rights reserved. +// http://code.google.com/p/ceres-solver/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Author: darius.rueckert@fau.de (Darius Rueckert) +// + +#include "ceres/internal/expression_graph.h" +#include "ceres/internal/expression_ref.h" + +#include "gtest/gtest.h" + +namespace ceres { +namespace internal { + +TEST(Expression, IsArithmetic) { + using T = ExpressionRef; + + StartRecordingExpressions(); + + T a(2), b(3); + T c = a + b; + T d = c + a; + + auto graph = StopRecordingExpressions(); + + ASSERT_FALSE(graph.ExpressionForId(a.id).IsArithmetic()); + ASSERT_FALSE(graph.ExpressionForId(b.id).IsArithmetic()); + ASSERT_TRUE(graph.ExpressionForId(c.id).IsArithmetic()); + ASSERT_TRUE(graph.ExpressionForId(d.id).IsArithmetic()); +} + +TEST(Expression, IsCompileTimeConstantAndEqualTo) { + using T = ExpressionRef; + + StartRecordingExpressions(); + + T a(2), b(3); + T c = a + b; + + auto graph = StopRecordingExpressions(); + + ASSERT_TRUE(graph.ExpressionForId(a.id).IsCompileTimeConstantAndEqualTo(2)); + ASSERT_FALSE(graph.ExpressionForId(a.id).IsCompileTimeConstantAndEqualTo(0)); + ASSERT_TRUE(graph.ExpressionForId(b.id).IsCompileTimeConstantAndEqualTo(3)); + ASSERT_FALSE(graph.ExpressionForId(c.id).IsCompileTimeConstantAndEqualTo(0)); +} + +TEST(Expression, IsReplaceableBy) { + using T = ExpressionRef; + + StartRecordingExpressions(); + + // a2 should be replaceable by a + T a(2), b(3), a2(2); + + // two redundant expressions + // -> d should be replaceable by c + T c = a + b; + T d = a + b; + + auto graph = StopRecordingExpressions(); + + ASSERT_TRUE(graph.ExpressionForId(a2.id).IsReplaceableBy( + graph.ExpressionForId(a.id))); + ASSERT_TRUE( + graph.ExpressionForId(d.id).IsReplaceableBy(graph.ExpressionForId(c.id))); + ASSERT_FALSE(graph.ExpressionForId(d.id).IsReplaceableBy( + graph.ExpressionForId(a2.id))); +} + +TEST(Expression, DirectlyDependsOn) { + using T = ExpressionRef; + + StartRecordingExpressions(); + + T unused(6); + T a(2), b(3); + T c = a + b; + T d = c + a; + + auto graph = StopRecordingExpressions(); + + ASSERT_FALSE(graph.ExpressionForId(a.id).DirectlyDependsOn(unused.id)); + ASSERT_TRUE(graph.ExpressionForId(c.id).DirectlyDependsOn(a.id)); + ASSERT_TRUE(graph.ExpressionForId(c.id).DirectlyDependsOn(b.id)); + ASSERT_TRUE(graph.ExpressionForId(d.id).DirectlyDependsOn(a.id)); + ASSERT_FALSE(graph.ExpressionForId(d.id).DirectlyDependsOn(b.id)); + ASSERT_TRUE(graph.ExpressionForId(d.id).DirectlyDependsOn(c.id)); +} + +// Todo: remaining functions of Expression + +} // namespace internal +} // namespace ceres