Add namespaces to generated functions and constants
The generated function names now include the containing namespace.
For example:
std::abs(...)
std::sin(...)
ceres::Ternary(...)
This patch also fixes the generation of inf/nan compile time constants,
using std::numeric_limits.
Change-Id: I4a36b09c68dd2adabed49fd4f7f37c8229ab7377
diff --git a/include/ceres/codegen/internal/expression_ref.h b/include/ceres/codegen/internal/expression_ref.h
index 5a13d76..d73d477 100644
--- a/include/ceres/codegen/internal/expression_ref.h
+++ b/include/ceres/codegen/internal/expression_ref.h
@@ -115,37 +115,37 @@
ExpressionRef operator/(const ExpressionRef& x, const ExpressionRef& y);
// Functions
-#define CERES_DEFINE_UNARY_FUNCTION_CALL(name) \
- inline ExpressionRef name(const ExpressionRef& x) { \
- return AddExpressionToGraph( \
- Expression::CreateScalarFunctionCall(#name, {x.id})); \
+#define CERES_DEFINE_UNARY_FUNCTION_CALL(ns, name) \
+ inline ExpressionRef name(const ExpressionRef& x) { \
+ return AddExpressionToGraph( \
+ Expression::CreateScalarFunctionCall(#ns "::" #name, {x.id})); \
}
-#define CERES_DEFINE_BINARY_FUNCTION_CALL(name) \
+#define CERES_DEFINE_BINARY_FUNCTION_CALL(ns, name) \
inline ExpressionRef name(const ExpressionRef& x, const ExpressionRef& y) { \
return AddExpressionToGraph( \
- Expression::CreateScalarFunctionCall(#name, {x.id, y.id})); \
+ Expression::CreateScalarFunctionCall(#ns "::" #name, {x.id, y.id})); \
}
-CERES_DEFINE_UNARY_FUNCTION_CALL(abs);
-CERES_DEFINE_UNARY_FUNCTION_CALL(acos);
-CERES_DEFINE_UNARY_FUNCTION_CALL(asin);
-CERES_DEFINE_UNARY_FUNCTION_CALL(atan);
-CERES_DEFINE_UNARY_FUNCTION_CALL(cbrt);
-CERES_DEFINE_UNARY_FUNCTION_CALL(ceil);
-CERES_DEFINE_UNARY_FUNCTION_CALL(cos);
-CERES_DEFINE_UNARY_FUNCTION_CALL(cosh);
-CERES_DEFINE_UNARY_FUNCTION_CALL(exp);
-CERES_DEFINE_UNARY_FUNCTION_CALL(exp2);
-CERES_DEFINE_UNARY_FUNCTION_CALL(floor);
-CERES_DEFINE_UNARY_FUNCTION_CALL(log);
-CERES_DEFINE_UNARY_FUNCTION_CALL(log2);
-CERES_DEFINE_UNARY_FUNCTION_CALL(sin);
-CERES_DEFINE_UNARY_FUNCTION_CALL(sinh);
-CERES_DEFINE_UNARY_FUNCTION_CALL(sqrt);
-CERES_DEFINE_UNARY_FUNCTION_CALL(tan);
-CERES_DEFINE_UNARY_FUNCTION_CALL(tanh);
+CERES_DEFINE_UNARY_FUNCTION_CALL(std, abs);
+CERES_DEFINE_UNARY_FUNCTION_CALL(std, acos);
+CERES_DEFINE_UNARY_FUNCTION_CALL(std, asin);
+CERES_DEFINE_UNARY_FUNCTION_CALL(std, atan);
+CERES_DEFINE_UNARY_FUNCTION_CALL(std, cbrt);
+CERES_DEFINE_UNARY_FUNCTION_CALL(std, ceil);
+CERES_DEFINE_UNARY_FUNCTION_CALL(std, cos);
+CERES_DEFINE_UNARY_FUNCTION_CALL(std, cosh);
+CERES_DEFINE_UNARY_FUNCTION_CALL(std, exp);
+CERES_DEFINE_UNARY_FUNCTION_CALL(std, exp2);
+CERES_DEFINE_UNARY_FUNCTION_CALL(std, floor);
+CERES_DEFINE_UNARY_FUNCTION_CALL(std, log);
+CERES_DEFINE_UNARY_FUNCTION_CALL(std, log2);
+CERES_DEFINE_UNARY_FUNCTION_CALL(std, sin);
+CERES_DEFINE_UNARY_FUNCTION_CALL(std, sinh);
+CERES_DEFINE_UNARY_FUNCTION_CALL(std, sqrt);
+CERES_DEFINE_UNARY_FUNCTION_CALL(std, tan);
+CERES_DEFINE_UNARY_FUNCTION_CALL(std, tanh);
-CERES_DEFINE_BINARY_FUNCTION_CALL(atan2);
-CERES_DEFINE_BINARY_FUNCTION_CALL(pow);
+CERES_DEFINE_BINARY_FUNCTION_CALL(std, atan2);
+CERES_DEFINE_BINARY_FUNCTION_CALL(std, pow);
#undef CERES_DEFINE_UNARY_FUNCTION_CALL
#undef CERES_DEFINE_BINARY_FUNCTION_CALL
@@ -198,16 +198,16 @@
const ComparisonExpressionRef& y);
ComparisonExpressionRef operator!(const ComparisonExpressionRef& x);
-#define CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL(name) \
- inline ComparisonExpressionRef name(const ExpressionRef& x) { \
- return ComparisonExpressionRef(AddExpressionToGraph( \
- Expression::CreateLogicalFunctionCall(#name, {x.id}))); \
+#define CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL(ns, name) \
+ inline ComparisonExpressionRef name(const ExpressionRef& x) { \
+ return ComparisonExpressionRef(AddExpressionToGraph( \
+ Expression::CreateLogicalFunctionCall(#ns "::" #name, {x.id}))); \
}
-CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL(isfinite);
-CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL(isinf);
-CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL(isnan);
-CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL(isnormal);
+CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL(std, isfinite);
+CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL(std, isinf);
+CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL(std, isnan);
+CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL(std, isnormal);
#undef CERES_DEFINE_UNARY_LOGICAL_FUNCTION_CALL
diff --git a/include/ceres/codegen/macros.h b/include/ceres/codegen/macros.h
index ee5068a..d3a43a6 100644
--- a/include/ceres/codegen/macros.h
+++ b/include/ceres/codegen/macros.h
@@ -100,7 +100,7 @@
// The CERES_CODEGEN macro is defined by the build system only during code
// generation.
#ifndef CERES_CODEGEN
-#define CERES_LOCAL_VARIABLE(_template_type, _local_variable) (_local_variable)
+#define CERES_LOCAL_VARIABLE(type, local_variable) type(local_variable)
#define CERES_IF(condition_) if (condition_)
#define CERES_ELSE else
#define CERES_ENDIF
diff --git a/internal/ceres/code_generator.cc b/internal/ceres/code_generator.cc
index 76ab48f..7b5fd61 100644
--- a/internal/ceres/code_generator.cc
+++ b/internal/ceres/code_generator.cc
@@ -29,10 +29,13 @@
// Author: darius.rueckert@fau.de (Darius Rueckert)
#include "ceres/codegen/internal/code_generator.h"
+
+#include <cmath>
+#include <limits>
#include <sstream>
+
#include "assert.h"
#include "glog/logging.h"
-
namespace ceres {
namespace internal {
@@ -113,7 +116,23 @@
// Format: <lhs_id> = <value>;
// Example: v_0 = 3.1415;
//
- result << indentation_ << lhs << " = " << value << ";";
+ result << indentation_ << lhs << " = ";
+
+ // Putting an inf or nan double into std::stringstream will just print the
+ // strings "inf" and "nan". This is not valid C++ code so we have to check
+ // for this here.
+ if (std::isinf(value)) {
+ if (value > 0) {
+ result << "std::numeric_limits<double>::infinity()";
+ } else {
+ result << "-std::numeric_limits<double>::infinity()";
+ }
+ } else if (std::isnan(value)) {
+ result << "std::numeric_limits<double>::quiet_NaN()";
+ } else {
+ result << value;
+ }
+ result << ";";
break;
}
case ExpressionType::INPUT_ASSIGNMENT: {
diff --git a/internal/ceres/codegen/code_generator_test.cc b/internal/ceres/codegen/code_generator_test.cc
index 17246a8..8b4aafc 100644
--- a/internal/ceres/codegen/code_generator_test.cc
+++ b/internal/ceres/codegen/code_generator_test.cc
@@ -31,6 +31,7 @@
#define CERES_CODEGEN
#include "ceres/codegen/internal/code_generator.h"
+
#include "ceres/codegen/internal/expression_graph.h"
#include "ceres/codegen/internal/expression_ref.h"
#include "gtest/gtest.h"
@@ -65,16 +66,26 @@
T a = T(0);
T b = T(123.5);
T c = T(1 + 1);
- T d; // Uninitialized variables should not generate code!
+ T d = T(std::numeric_limits<double>::infinity());
+ T e = T(-std::numeric_limits<double>::infinity());
+ T f = T(std::numeric_limits<double>::quiet_NaN());
+ T g; // Uninitialized variables should not generate code!
auto graph = StopRecordingExpressions();
- std::vector<std::string> expected_code = {"{",
- " double v_0;",
- " double v_1;",
- " double v_2;",
- " v_0 = 0;",
- " v_1 = 123.5;",
- " v_2 = 2;",
- "}"};
+ std::vector<std::string> expected_code = {
+ "{",
+ " double v_0;",
+ " double v_1;",
+ " double v_2;",
+ " double v_3;",
+ " double v_4;",
+ " double v_5;",
+ " v_0 = 0;",
+ " v_1 = 123.5;",
+ " v_2 = 2;",
+ " v_3 = std::numeric_limits<double>::infinity();",
+ " v_4 = -std::numeric_limits<double>::infinity();",
+ " v_5 = std::numeric_limits<double>::quiet_NaN();",
+ "}"};
GenerateAndCheck(graph, expected_code);
}
@@ -352,26 +363,26 @@
" double v_21;",
" v_0 = 0;",
" v_1 = 1;",
- " v_2 = abs(v_0);",
- " v_3 = acos(v_0);",
- " v_4 = asin(v_0);",
- " v_5 = atan(v_0);",
- " v_6 = cbrt(v_0);",
- " v_7 = ceil(v_0);",
- " v_8 = cos(v_0);",
- " v_9 = cosh(v_0);",
- " v_10 = exp(v_0);",
- " v_11 = exp2(v_0);",
- " v_12 = floor(v_0);",
- " v_13 = log(v_0);",
- " v_14 = log2(v_0);",
- " v_15 = sin(v_0);",
- " v_16 = sinh(v_0);",
- " v_17 = sqrt(v_0);",
- " v_18 = tan(v_0);",
- " v_19 = tanh(v_0);",
- " v_20 = atan2(v_0, v_1);",
- " v_21 = pow(v_0, v_1);",
+ " v_2 = std::abs(v_0);",
+ " v_3 = std::acos(v_0);",
+ " v_4 = std::asin(v_0);",
+ " v_5 = std::atan(v_0);",
+ " v_6 = std::cbrt(v_0);",
+ " v_7 = std::ceil(v_0);",
+ " v_8 = std::cos(v_0);",
+ " v_9 = std::cosh(v_0);",
+ " v_10 = std::exp(v_0);",
+ " v_11 = std::exp2(v_0);",
+ " v_12 = std::floor(v_0);",
+ " v_13 = std::log(v_0);",
+ " v_14 = std::log2(v_0);",
+ " v_15 = std::sin(v_0);",
+ " v_16 = std::sinh(v_0);",
+ " v_17 = std::sqrt(v_0);",
+ " v_18 = std::tan(v_0);",
+ " v_19 = std::tanh(v_0);",
+ " v_20 = std::atan2(v_0, v_1);",
+ " v_21 = std::pow(v_0, v_1);",
"}"};
GenerateAndCheck(graph, expected_code);
}
@@ -394,10 +405,10 @@
" bool v_3;",
" bool v_4;",
" v_0 = 1;",
- " v_1 = isfinite(v_0);",
- " v_2 = isinf(v_0);",
- " v_3 = isnan(v_0);",
- " v_4 = isnormal(v_0);",
+ " v_1 = std::isfinite(v_0);",
+ " v_2 = std::isinf(v_0);",
+ " v_3 = std::isnan(v_0);",
+ " v_4 = std::isnormal(v_0);",
"}"};
GenerateAndCheck(graph, expected_code);
}
diff --git a/internal/ceres/codegen/expression_ref_test.cc b/internal/ceres/codegen/expression_ref_test.cc
index 0a70d42..280043b 100644
--- a/internal/ceres/codegen/expression_ref_test.cc
+++ b/internal/ceres/codegen/expression_ref_test.cc
@@ -284,26 +284,28 @@
ExpressionGraph reference;
reference.InsertBack(Expression::CreateCompileTimeConstant(1));
reference.InsertBack(Expression::CreateCompileTimeConstant(2));
- reference.InsertBack(Expression::CreateScalarFunctionCall("abs", {0}));
- reference.InsertBack(Expression::CreateScalarFunctionCall("acos", {0}));
- reference.InsertBack(Expression::CreateScalarFunctionCall("asin", {0}));
- reference.InsertBack(Expression::CreateScalarFunctionCall("atan", {0}));
- reference.InsertBack(Expression::CreateScalarFunctionCall("cbrt", {0}));
- reference.InsertBack(Expression::CreateScalarFunctionCall("ceil", {0}));
- reference.InsertBack(Expression::CreateScalarFunctionCall("cos", {0}));
- reference.InsertBack(Expression::CreateScalarFunctionCall("cosh", {0}));
- reference.InsertBack(Expression::CreateScalarFunctionCall("exp", {0}));
- reference.InsertBack(Expression::CreateScalarFunctionCall("exp2", {0}));
- reference.InsertBack(Expression::CreateScalarFunctionCall("floor", {0}));
- reference.InsertBack(Expression::CreateScalarFunctionCall("log", {0}));
- reference.InsertBack(Expression::CreateScalarFunctionCall("log2", {0}));
- reference.InsertBack(Expression::CreateScalarFunctionCall("sin", {0}));
- reference.InsertBack(Expression::CreateScalarFunctionCall("sinh", {0}));
- reference.InsertBack(Expression::CreateScalarFunctionCall("sqrt", {0}));
- reference.InsertBack(Expression::CreateScalarFunctionCall("tan", {0}));
- reference.InsertBack(Expression::CreateScalarFunctionCall("tanh", {0}));
- reference.InsertBack(Expression::CreateScalarFunctionCall("atan2", {0, 1}));
- reference.InsertBack(Expression::CreateScalarFunctionCall("pow", {0, 1}));
+ reference.InsertBack(Expression::CreateScalarFunctionCall("std::abs", {0}));
+ reference.InsertBack(Expression::CreateScalarFunctionCall("std::acos", {0}));
+ reference.InsertBack(Expression::CreateScalarFunctionCall("std::asin", {0}));
+ reference.InsertBack(Expression::CreateScalarFunctionCall("std::atan", {0}));
+ reference.InsertBack(Expression::CreateScalarFunctionCall("std::cbrt", {0}));
+ reference.InsertBack(Expression::CreateScalarFunctionCall("std::ceil", {0}));
+ reference.InsertBack(Expression::CreateScalarFunctionCall("std::cos", {0}));
+ reference.InsertBack(Expression::CreateScalarFunctionCall("std::cosh", {0}));
+ reference.InsertBack(Expression::CreateScalarFunctionCall("std::exp", {0}));
+ reference.InsertBack(Expression::CreateScalarFunctionCall("std::exp2", {0}));
+ reference.InsertBack(Expression::CreateScalarFunctionCall("std::floor", {0}));
+ reference.InsertBack(Expression::CreateScalarFunctionCall("std::log", {0}));
+ reference.InsertBack(Expression::CreateScalarFunctionCall("std::log2", {0}));
+ reference.InsertBack(Expression::CreateScalarFunctionCall("std::sin", {0}));
+ reference.InsertBack(Expression::CreateScalarFunctionCall("std::sinh", {0}));
+ reference.InsertBack(Expression::CreateScalarFunctionCall("std::sqrt", {0}));
+ reference.InsertBack(Expression::CreateScalarFunctionCall("std::tan", {0}));
+ reference.InsertBack(Expression::CreateScalarFunctionCall("std::tanh", {0}));
+ reference.InsertBack(
+ Expression::CreateScalarFunctionCall("std::atan2", {0, 1}));
+ reference.InsertBack(
+ Expression::CreateScalarFunctionCall("std::pow", {0, 1}));
EXPECT_EQ(reference, graph);
}
@@ -318,10 +320,14 @@
ExpressionGraph reference;
reference.InsertBack(Expression::CreateCompileTimeConstant(1));
- reference.InsertBack(Expression::CreateLogicalFunctionCall("isfinite", {0}));
- reference.InsertBack(Expression::CreateLogicalFunctionCall("isinf", {0}));
- reference.InsertBack(Expression::CreateLogicalFunctionCall("isnan", {0}));
- reference.InsertBack(Expression::CreateLogicalFunctionCall("isnormal", {0}));
+ reference.InsertBack(
+ Expression::CreateLogicalFunctionCall("std::isfinite", {0}));
+ reference.InsertBack(
+ Expression::CreateLogicalFunctionCall("std::isinf", {0}));
+ reference.InsertBack(
+ Expression::CreateLogicalFunctionCall("std::isnan", {0}));
+ reference.InsertBack(
+ Expression::CreateLogicalFunctionCall("std::isnormal", {0}));
EXPECT_EQ(reference, graph);
}
diff --git a/internal/ceres/expression_ref.cc b/internal/ceres/expression_ref.cc
index dda412a..883d4ea 100644
--- a/internal/ceres/expression_ref.cc
+++ b/internal/ceres/expression_ref.cc
@@ -129,8 +129,8 @@
ExpressionRef Ternary(const ComparisonExpressionRef& c,
const ExpressionRef& x,
const ExpressionRef& y) {
- return AddExpressionToGraph(
- Expression::CreateScalarFunctionCall("Ternary", {c.id, x.id, y.id}));
+ return AddExpressionToGraph(Expression::CreateScalarFunctionCall(
+ "ceres::Ternary", {c.id, x.id, y.id}));
}
#define CERES_DEFINE_EXPRESSION_COMPARISON_OPERATOR(op) \