Add namespaces to generated functions and constants The generated function names now include the containing namespace. For example: std::abs(...) std::sin(...) ceres::Ternary(...) This patch also fixes the generation of inf/nan compile time constants, using std::numeric_limits. Change-Id: I4a36b09c68dd2adabed49fd4f7f37c8229ab7377
diff --git a/include/ceres/codegen/internal/expression_ref.h b/include/ceres/codegen/internal/expression_ref.h index 5a13d76..d73d477 100644 --- a/include/ceres/codegen/internal/expression_ref.h +++ b/include/ceres/codegen/internal/expression_ref.h
@@ -115,37 +115,37 @@ ExpressionRef operator/(const ExpressionRef& x, const ExpressionRef& y); // Functions -#define CERES_DEFINE_UNARY_FUNCTION_CALL(name) \ - inline ExpressionRef name(const ExpressionRef& x) { \ - return AddExpressionToGraph( \ - Expression::CreateScalarFunctionCall(#name, {x.id})); \ +#define CERES_DEFINE_UNARY_FUNCTION_CALL(ns, name) \ + inline ExpressionRef name(const ExpressionRef& x) { \ + return AddExpressionToGraph( \ + Expression::CreateScalarFunctionCall(#ns "::" #name, {x.id})); \ } -#define CERES_DEFINE_BINARY_FUNCTION_CALL(name) \ +#define CERES_DEFINE_BINARY_FUNCTION_CALL(ns, name) \ inline ExpressionRef name(const ExpressionRef& x, const ExpressionRef& y) { \ return AddExpressionToGraph( \ - Expression::CreateScalarFunctionCall(#name, {x.id, y.id})); \ + Expression::CreateScalarFunctionCall(#ns "::" #name, {x.id, y.id})); \ } -CERES_DEFINE_UNARY_FUNCTION_CALL(abs); -CERES_DEFINE_UNARY_FUNCTION_CALL(acos); -CERES_DEFINE_UNARY_FUNCTION_CALL(asin); -CERES_DEFINE_UNARY_FUNCTION_CALL(atan); -CERES_DEFINE_UNARY_FUNCTION_CALL(cbrt); -CERES_DEFINE_UNARY_FUNCTION_CALL(ceil); -CERES_DEFINE_UNARY_FUNCTION_CALL(cos); -CERES_DEFINE_UNARY_FUNCTION_CALL(cosh); -CERES_DEFINE_UNARY_FUNCTION_CALL(exp); -CERES_DEFINE_UNARY_FUNCTION_CALL(exp2); -CERES_DEFINE_UNARY_FUNCTION_CALL(floor); -CERES_DEFINE_UNARY_FUNCTION_CALL(log); -CERES_DEFINE_UNARY_FUNCTION_CALL(log2); -CERES_DEFINE_UNARY_FUNCTION_CALL(sin); -CERES_DEFINE_UNARY_FUNCTION_CALL(sinh); -CERES_DEFINE_UNARY_FUNCTION_CALL(sqrt); -CERES_DEFINE_UNARY_FUNCTION_CALL(tan); -CERES_DEFINE_UNARY_FUNCTION_CALL(tanh); +CERES_DEFINE_UNARY_FUNCTION_CALL(std, abs); +CERES_DEFINE_UNARY_FUNCTION_CALL(std, acos); +CERES_DEFINE_UNARY_FUNCTION_CALL(std, asin); +CERES_DEFINE_UNARY_FUNCTION_CALL(std, atan); +CERES_DEFINE_UNARY_FUNCTION_CALL(std, cbrt); +CERES_DEFINE_UNARY_FUNCTION_CALL(std, ceil); +CERES_DEFINE_UNARY_FUNCTION_CALL(std, cos); +CERES_DEFINE_UNARY_FUNCTION_CALL(std, cosh); +CERES_DEFINE_UNARY_FUNCTION_CALL(std, exp); +CERES_DEFINE_UNARY_FUNCTION_CALL(std, exp2); +CERES_DEFINE_UNARY_FUNCTION_CALL(std, floor); +CERES_DEFINE_UNARY_FUNCTION_CALL(std, log); +CERES_DEFINE_UNARY_FUNCTION_CALL(std, log2); +CERES_DEFINE_UNARY_FUNCTION_CALL(std, sin); +CERES_DEFINE_UNARY_FUNCTION_CALL(std, sinh); +CERES_DEFINE_UNARY_FUNCTION_CALL(std, sqrt); +CERES_DEFINE_UNARY_FUNCTION_CALL(std, tan); +CERES_DEFINE_UNARY_FUNCTION_CALL(std, tanh); -CERES_DEFINE_BINARY_FUNCTION_CALL(atan2); -CERES_DEFINE_BINARY_FUNCTION_CALL(pow); +CERES_DEFINE_BINARY_FUNCTION_CALL(std, atan2); +CERES_DEFINE_BINARY_FUNCTION_CALL(std, pow); #undef CERES_DEFINE_UNARY_FUNCTION_CALL #undef CERES_DEFINE_BINARY_FUNCTION_CALL @@ -198,16 +198,16 @@ const ComparisonExpressionRef& y); ComparisonExpressionRef operator!(const ComparisonExpressionRef& x); -#define CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL(name) \ - inline ComparisonExpressionRef name(const ExpressionRef& x) { \ - return ComparisonExpressionRef(AddExpressionToGraph( \ - Expression::CreateLogicalFunctionCall(#name, {x.id}))); \ +#define CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL(ns, name) \ + inline ComparisonExpressionRef name(const ExpressionRef& x) { \ + return ComparisonExpressionRef(AddExpressionToGraph( \ + Expression::CreateLogicalFunctionCall(#ns "::" #name, {x.id}))); \ } -CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL(isfinite); -CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL(isinf); -CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL(isnan); -CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL(isnormal); +CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL(std, isfinite); +CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL(std, isinf); +CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL(std, isnan); +CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL(std, isnormal); #undef CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL
diff --git a/include/ceres/codegen/macros.h b/include/ceres/codegen/macros.h index ee5068a..d3a43a6 100644 --- a/include/ceres/codegen/macros.h +++ b/include/ceres/codegen/macros.h
@@ -100,7 +100,7 @@ // The CERES_CODEGEN macro is defined by the build system only during code // generation. #ifndef CERES_CODEGEN -#define CERES_LOCAL_VARIABLE(_template_type, _local_variable) (_local_variable) +#define CERES_LOCAL_VARIABLE(type, local_variable) type(local_variable) #define CERES_IF(condition_) if (condition_) #define CERES_ELSE else #define CERES_ENDIF
diff --git a/internal/ceres/code_generator.cc b/internal/ceres/code_generator.cc index 76ab48f..7b5fd61 100644 --- a/internal/ceres/code_generator.cc +++ b/internal/ceres/code_generator.cc
@@ -29,10 +29,13 @@ // Author: darius.rueckert@fau.de (Darius Rueckert) #include "ceres/codegen/internal/code_generator.h" + +#include <cmath> +#include <limits> #include <sstream> + #include "assert.h" #include "glog/logging.h" - namespace ceres { namespace internal { @@ -113,7 +116,23 @@ // Format: <lhs_id> = <value>; // Example: v_0 = 3.1415; // - result << indentation_ << lhs << " = " << value << ";"; + result << indentation_ << lhs << " = "; + + // Putting an inf or nan double into std::stringstream will just print the + // strings "inf" and "nan". This is not valid C++ code so we have to check + // for this here. + if (std::isinf(value)) { + if (value > 0) { + result << "std::numeric_limits<double>::infinity()"; + } else { + result << "-std::numeric_limits<double>::infinity()"; + } + } else if (std::isnan(value)) { + result << "std::numeric_limits<double>::quiet_NaN()"; + } else { + result << value; + } + result << ";"; break; } case ExpressionType::INPUT_ASSIGNMENT: {
diff --git a/internal/ceres/codegen/code_generator_test.cc b/internal/ceres/codegen/code_generator_test.cc index 17246a8..8b4aafc 100644 --- a/internal/ceres/codegen/code_generator_test.cc +++ b/internal/ceres/codegen/code_generator_test.cc
@@ -31,6 +31,7 @@ #define CERES_CODEGEN #include "ceres/codegen/internal/code_generator.h" + #include "ceres/codegen/internal/expression_graph.h" #include "ceres/codegen/internal/expression_ref.h" #include "gtest/gtest.h" @@ -65,16 +66,26 @@ T a = T(0); T b = T(123.5); T c = T(1 + 1); - T d; // Uninitialized variables should not generate code! + T d = T(std::numeric_limits<double>::infinity()); + T e = T(-std::numeric_limits<double>::infinity()); + T f = T(std::numeric_limits<double>::quiet_NaN()); + T g; // Uninitialized variables should not generate code! auto graph = StopRecordingExpressions(); - std::vector<std::string> expected_code = {"{", - " double v_0;", - " double v_1;", - " double v_2;", - " v_0 = 0;", - " v_1 = 123.5;", - " v_2 = 2;", - "}"}; + std::vector<std::string> expected_code = { + "{", + " double v_0;", + " double v_1;", + " double v_2;", + " double v_3;", + " double v_4;", + " double v_5;", + " v_0 = 0;", + " v_1 = 123.5;", + " v_2 = 2;", + " v_3 = std::numeric_limits<double>::infinity();", + " v_4 = -std::numeric_limits<double>::infinity();", + " v_5 = std::numeric_limits<double>::quiet_NaN();", + "}"}; GenerateAndCheck(graph, expected_code); } @@ -352,26 +363,26 @@ " double v_21;", " v_0 = 0;", " v_1 = 1;", - " v_2 = abs(v_0);", - " v_3 = acos(v_0);", - " v_4 = asin(v_0);", - " v_5 = atan(v_0);", - " v_6 = cbrt(v_0);", - " v_7 = ceil(v_0);", - " v_8 = cos(v_0);", - " v_9 = cosh(v_0);", - " v_10 = exp(v_0);", - " v_11 = exp2(v_0);", - " v_12 = floor(v_0);", - " v_13 = log(v_0);", - " v_14 = log2(v_0);", - " v_15 = sin(v_0);", - " v_16 = sinh(v_0);", - " v_17 = sqrt(v_0);", - " v_18 = tan(v_0);", - " v_19 = tanh(v_0);", - " v_20 = atan2(v_0, v_1);", - " v_21 = pow(v_0, v_1);", + " v_2 = std::abs(v_0);", + " v_3 = std::acos(v_0);", + " v_4 = std::asin(v_0);", + " v_5 = std::atan(v_0);", + " v_6 = std::cbrt(v_0);", + " v_7 = std::ceil(v_0);", + " v_8 = std::cos(v_0);", + " v_9 = std::cosh(v_0);", + " v_10 = std::exp(v_0);", + " v_11 = std::exp2(v_0);", + " v_12 = std::floor(v_0);", + " v_13 = std::log(v_0);", + " v_14 = std::log2(v_0);", + " v_15 = std::sin(v_0);", + " v_16 = std::sinh(v_0);", + " v_17 = std::sqrt(v_0);", + " v_18 = std::tan(v_0);", + " v_19 = std::tanh(v_0);", + " v_20 = std::atan2(v_0, v_1);", + " v_21 = std::pow(v_0, v_1);", "}"}; GenerateAndCheck(graph, expected_code); } @@ -394,10 +405,10 @@ " bool v_3;", " bool v_4;", " v_0 = 1;", - " v_1 = isfinite(v_0);", - " v_2 = isinf(v_0);", - " v_3 = isnan(v_0);", - " v_4 = isnormal(v_0);", + " v_1 = std::isfinite(v_0);", + " v_2 = std::isinf(v_0);", + " v_3 = std::isnan(v_0);", + " v_4 = std::isnormal(v_0);", "}"}; GenerateAndCheck(graph, expected_code); }
diff --git a/internal/ceres/codegen/expression_ref_test.cc b/internal/ceres/codegen/expression_ref_test.cc index 0a70d42..280043b 100644 --- a/internal/ceres/codegen/expression_ref_test.cc +++ b/internal/ceres/codegen/expression_ref_test.cc
@@ -284,26 +284,28 @@ ExpressionGraph reference; reference.InsertBack(Expression::CreateCompileTimeConstant(1)); reference.InsertBack(Expression::CreateCompileTimeConstant(2)); - reference.InsertBack(Expression::CreateScalarFunctionCall("abs", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("acos", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("asin", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("atan", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("cbrt", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("ceil", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("cos", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("cosh", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("exp", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("exp2", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("floor", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("log", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("log2", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("sin", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("sinh", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("sqrt", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("tan", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("tanh", {0})); - reference.InsertBack(Expression::CreateScalarFunctionCall("atan2", {0, 1})); - reference.InsertBack(Expression::CreateScalarFunctionCall("pow", {0, 1})); + reference.InsertBack(Expression::CreateScalarFunctionCall("std::abs", {0})); + reference.InsertBack(Expression::CreateScalarFunctionCall("std::acos", {0})); + reference.InsertBack(Expression::CreateScalarFunctionCall("std::asin", {0})); + reference.InsertBack(Expression::CreateScalarFunctionCall("std::atan", {0})); + reference.InsertBack(Expression::CreateScalarFunctionCall("std::cbrt", {0})); + reference.InsertBack(Expression::CreateScalarFunctionCall("std::ceil", {0})); + reference.InsertBack(Expression::CreateScalarFunctionCall("std::cos", {0})); + reference.InsertBack(Expression::CreateScalarFunctionCall("std::cosh", {0})); + reference.InsertBack(Expression::CreateScalarFunctionCall("std::exp", {0})); + reference.InsertBack(Expression::CreateScalarFunctionCall("std::exp2", {0})); + reference.InsertBack(Expression::CreateScalarFunctionCall("std::floor", {0})); + reference.InsertBack(Expression::CreateScalarFunctionCall("std::log", {0})); + reference.InsertBack(Expression::CreateScalarFunctionCall("std::log2", {0})); + reference.InsertBack(Expression::CreateScalarFunctionCall("std::sin", {0})); + reference.InsertBack(Expression::CreateScalarFunctionCall("std::sinh", {0})); + reference.InsertBack(Expression::CreateScalarFunctionCall("std::sqrt", {0})); + reference.InsertBack(Expression::CreateScalarFunctionCall("std::tan", {0})); + reference.InsertBack(Expression::CreateScalarFunctionCall("std::tanh", {0})); + reference.InsertBack( + Expression::CreateScalarFunctionCall("std::atan2", {0, 1})); + reference.InsertBack( + Expression::CreateScalarFunctionCall("std::pow", {0, 1})); EXPECT_EQ(reference, graph); } @@ -318,10 +320,14 @@ ExpressionGraph reference; reference.InsertBack(Expression::CreateCompileTimeConstant(1)); - reference.InsertBack(Expression::CreateLogicalFunctionCall("isfinite", {0})); - reference.InsertBack(Expression::CreateLogicalFunctionCall("isinf", {0})); - reference.InsertBack(Expression::CreateLogicalFunctionCall("isnan", {0})); - reference.InsertBack(Expression::CreateLogicalFunctionCall("isnormal", {0})); + reference.InsertBack( + Expression::CreateLogicalFunctionCall("std::isfinite", {0})); + reference.InsertBack( + Expression::CreateLogicalFunctionCall("std::isinf", {0})); + reference.InsertBack( + Expression::CreateLogicalFunctionCall("std::isnan", {0})); + reference.InsertBack( + Expression::CreateLogicalFunctionCall("std::isnormal", {0})); EXPECT_EQ(reference, graph); }
diff --git a/internal/ceres/expression_ref.cc b/internal/ceres/expression_ref.cc index dda412a..883d4ea 100644 --- a/internal/ceres/expression_ref.cc +++ b/internal/ceres/expression_ref.cc
@@ -129,8 +129,8 @@ ExpressionRef Ternary(const ComparisonExpressionRef& c, const ExpressionRef& x, const ExpressionRef& y) { - return AddExpressionToGraph( - Expression::CreateScalarFunctionCall("Ternary", {c.id, x.id, y.id})); + return AddExpressionToGraph(Expression::CreateScalarFunctionCall( + "ceres::Ternary", {c.id, x.id, y.id})); } #define CERES_DEFINE_EXPRESSION_COMPARISON_OPERATOR(op) \