Add ExpressionGraph::InsertExpression
This functions allows the insertion of new expression anywhere into
the graph. All references are updated. This will be used by semantic
expression testing and by the optimizer.
Change-Id: Ieafcfec7672f6328106dc0511c85d6fb5bd64d97
diff --git a/include/ceres/internal/expression.h b/include/ceres/internal/expression.h
index ce470b2..09d4ceb 100644
--- a/include/ceres/internal/expression.h
+++ b/include/ceres/internal/expression.h
@@ -349,7 +349,7 @@
// If lhs_id_ == kInvalidExpressionId, then the expression type is not
// arithmetic. Currently, only the following types have lhs_id = invalid:
// IF,ELSE,ENDIF,NOP
- const ExpressionId lhs_id_ = kInvalidExpressionId;
+ ExpressionId lhs_id_ = kInvalidExpressionId;
// Expressions have different number of arguments. For example a binary "+"
// has 2 parameters and a function call to "sin" has 1 parameter. Here, a
diff --git a/include/ceres/internal/expression_graph.h b/include/ceres/internal/expression_graph.h
index 0250d19..644af0b 100644
--- a/include/ceres/internal/expression_graph.h
+++ b/include/ceres/internal/expression_graph.h
@@ -81,6 +81,16 @@
int Size() const { return expressions_.size(); }
+ // Insert a new expression at "location" into the graph. All expression
+ // after "location" are moved by one element to the back. References to moved
+ // expression are updated.
+ void InsertExpression(ExpressionId location,
+ ExpressionType type,
+ ExpressionId lhs_id,
+ const std::vector<ExpressionId>& arguments,
+ const std::string& name,
+ double value);
+
private:
// All Expressions are referenced by an ExpressionId. The ExpressionId is the
// index into this array. Each expression has a list of ExpressionId as
diff --git a/internal/ceres/expression_graph.cc b/internal/ceres/expression_graph.cc
index b7e7888..511a98b 100644
--- a/internal/ceres/expression_graph.cc
+++ b/internal/ceres/expression_graph.cc
@@ -108,5 +108,39 @@
return true;
}
+void ExpressionGraph::InsertExpression(
+ ExpressionId location,
+ ExpressionType type,
+ ExpressionId lhs_id,
+ const std::vector<ExpressionId>& arguments,
+ const std::string& name,
+ double value) {
+ ExpressionId last_expression_id = Size() - 1;
+ // Increase size by adding a dummy expression.
+ expressions_.push_back(Expression(ExpressionType::NOP, kInvalidExpressionId));
+
+ // Move everything after id back and update references
+ for (ExpressionId id = last_expression_id; id >= location; --id) {
+ auto& expression = expressions_[id];
+ // Increment reference if it points to a shifted variable.
+ if (expression.lhs_id_ >= location) {
+ expression.lhs_id_++;
+ }
+ for (auto& arg : expression.arguments_) {
+ if (arg >= location) {
+ arg++;
+ }
+ }
+ expressions_[id + 1] = expression;
+ }
+
+ // Insert new expression at the correct place
+ Expression expr(type, lhs_id);
+ expr.arguments_ = arguments;
+ expr.name_ = name;
+ expr.value_ = value;
+ expressions_[location] = expr;
+}
+
} // namespace internal
} // namespace ceres
diff --git a/internal/ceres/expression_graph_test.cc b/internal/ceres/expression_graph_test.cc
index ee06f1a..faaa10c 100644
--- a/internal/ceres/expression_graph_test.cc
+++ b/internal/ceres/expression_graph_test.cc
@@ -73,5 +73,89 @@
ASSERT_FALSE(tree.DependsOn(d.id, unused.id));
}
+TEST(ExpressionGraph, InsertExpression_UpdateReferences) {
+ // This test checks if references to shifted expressions are updated
+ // accordingly.
+ using T = ExpressionRef;
+ StartRecordingExpressions();
+ T a(2); // 0
+ T b(3); // 1
+ T c = a + b; // 2
+ auto graph = StopRecordingExpressions();
+
+ // Test if 'a' and 'c' are actually at location 0 and 2
+ auto& a_expr = graph.ExpressionForId(0);
+ EXPECT_EQ(a_expr.type(), ExpressionType::COMPILE_TIME_CONSTANT);
+ EXPECT_EQ(a_expr.value(), 2);
+
+ // At this point 'c' should have 0 and 1 as arguments.
+ auto& c_expr = graph.ExpressionForId(2);
+ EXPECT_EQ(c_expr.type(), ExpressionType::BINARY_ARITHMETIC);
+ EXPECT_EQ(c_expr.arguments()[0], 0);
+ EXPECT_EQ(c_expr.arguments()[1], 1);
+
+ // We insert at the beginning, which shifts everything by one spot.
+ graph.InsertExpression(
+ 0, ExpressionType::COMPILE_TIME_CONSTANT, 0, {}, "", 10.2);
+
+ // Test if 'a' and 'c' are actually at location 1 and 3
+ auto& a_expr2 = graph.ExpressionForId(1);
+ EXPECT_EQ(a_expr2.type(), ExpressionType::COMPILE_TIME_CONSTANT);
+ EXPECT_EQ(a_expr2.value(), 2);
+
+ // At this point 'c' should have 1 and 2 as arguments.
+ auto& c_expr2 = graph.ExpressionForId(3);
+ EXPECT_EQ(c_expr2.type(), ExpressionType::BINARY_ARITHMETIC);
+ EXPECT_EQ(c_expr2.arguments()[0], 1);
+ EXPECT_EQ(c_expr2.arguments()[1], 2);
+}
+
+TEST(ExpressionGraph, InsertExpression) {
+ using T = ExpressionRef;
+
+ StartRecordingExpressions();
+
+ {
+ T a(2); // 0
+ T b(3); // 1
+ T five = 5; // 2
+ T tmp = a + five; // 3
+ a = tmp; // 4
+ T c = a + b; // 5
+ T d = a * b; // 6
+ T e = c + d; // 7
+ MakeOutput(e, "result"); // 8
+ }
+ auto reference = StopRecordingExpressions();
+ EXPECT_EQ(reference.Size(), 9);
+
+ StartRecordingExpressions();
+
+ {
+ // The expressions 2,3,4 from above are missing.
+ T a(2); // 0
+ T b(3); // 1
+ T c = a + b; // 2
+ T d = a * b; // 3
+ T e = c + d; // 4
+ MakeOutput(e, "result"); // 5
+ }
+
+ auto graph1 = StopRecordingExpressions();
+ EXPECT_EQ(graph1.Size(), 6);
+ ASSERT_FALSE(reference == graph1);
+
+ // We manually insert the 3 missing expressions
+ // clang-format off
+ graph1.InsertExpression(2, ExpressionType::COMPILE_TIME_CONSTANT, 2, {}, "", 5);
+ graph1.InsertExpression(3, ExpressionType::BINARY_ARITHMETIC, 3, {0, 2}, "+", 0);
+ graph1.InsertExpression(4, ExpressionType::ASSIGNMENT, 0, {3}, "", 0);
+ // clang-format on
+
+ // Now the graphs are identical!
+ EXPECT_EQ(graph1.Size(), 9);
+ ASSERT_TRUE(reference == graph1);
+}
+
} // namespace internal
} // namespace ceres