Add ExpressionGraph::Erase(ExpressionId)
Add the function ExpressionGraph::Erase and a test-case for it.
Erase removes the given expression from the graph by shifting
all later expressions to the front. Indices and references
are updated accordingly.
Change-Id: Ic0449ccf28b369600fd2959a7e2a919d47f4cbe3
diff --git a/include/ceres/codegen/internal/expression_graph.h b/include/ceres/codegen/internal/expression_graph.h
index 992c23a..7c5df80 100644
--- a/include/ceres/codegen/internal/expression_graph.h
+++ b/include/ceres/codegen/internal/expression_graph.h
@@ -64,6 +64,12 @@
int Size() const { return expressions_.size(); }
+ // Erases the expression at "location". All expression after "location" are
+ // moved by one element to the front. References to moved expressions are
+ // updated. Removing an expression that is still referenced somewhere is
+ // undefined behaviour.
+ void Erase(ExpressionId location);
+
// Insert a new expression at "location" into the graph. All expression
// after "location" are moved by one element to the back. References to
// moved expressions are updated.
diff --git a/internal/ceres/codegen/expression_graph_test.cc b/internal/ceres/codegen/expression_graph_test.cc
index 420b05a..4f4f0ec 100644
--- a/internal/ceres/codegen/expression_graph_test.cc
+++ b/internal/ceres/codegen/expression_graph_test.cc
@@ -250,5 +250,37 @@
EXPECT_EQ(graph, ref);
}
+TEST(ExpressionGraph, Erase) {
+ // This test checks if references to shifted expressions are updated
+ // accordingly.
+ ExpressionGraph graph;
+ graph.InsertBack(Expression::CreateCompileTimeConstant(42));
+ graph.InsertBack(Expression::CreateCompileTimeConstant(10));
+ graph.InsertBack(Expression::CreateCompileTimeConstant(3));
+ graph.InsertBack(Expression::CreateBinaryArithmetic(
+ "+", ExpressionId(0), ExpressionId(2)));
+ // Code:
+ // v_0 = 42
+ // v_1 = 10
+ // v_2 = 3
+ // v_3 = v_0 + v_2
+
+ // Erase the unused expression v_1 = 10
+ graph.Erase(1);
+ // This should shift all indices like this:
+ // v_0 = 42
+ // v_1 = 3
+ // v_2 = v_0 + v_1
+
+ // Test by inserting it in the correct order
+ ExpressionGraph ref;
+ ref.InsertBack(Expression::CreateCompileTimeConstant(42));
+ ref.InsertBack(Expression::CreateCompileTimeConstant(3));
+ ref.InsertBack(Expression::CreateBinaryArithmetic(
+ "+", ExpressionId(0), ExpressionId(1)));
+ EXPECT_EQ(graph.Size(), ref.Size());
+ EXPECT_EQ(graph, ref);
+}
+
} // namespace internal
} // namespace ceres
diff --git a/internal/ceres/expression_graph.cc b/internal/ceres/expression_graph.cc
index e619e53..59f20ea 100644
--- a/internal/ceres/expression_graph.cc
+++ b/internal/ceres/expression_graph.cc
@@ -87,8 +87,30 @@
return true;
}
+void ExpressionGraph::Erase(ExpressionId location) {
+ CHECK_GE(location, 0);
+ CHECK_LT(location, Size());
+ // Move everything after id to the front and update references
+ for (ExpressionId id = location + 1; id < Size(); ++id) {
+ expressions_[id - 1] = expressions_[id];
+ auto& expression = expressions_[id - 1];
+ // Decrement reference if it points to a shifted variable.
+ if (expression.lhs_id() >= location) {
+ expression.set_lhs_id(expression.lhs_id() - 1);
+ }
+ for (auto& arg : *expression.mutable_arguments()) {
+ if (arg >= location) {
+ arg--;
+ }
+ }
+ }
+ expressions_.resize(Size() - 1);
+}
+
void ExpressionGraph::Insert(ExpressionId location,
const Expression& expression) {
+ CHECK_GE(location, 0);
+ CHECK_LE(location, Size());
ExpressionId last_expression_id = Size() - 1;
// Increase size by adding a dummy expression.
expressions_.push_back(Expression());