Add ExpressionGraph::Erase(ExpressionId)

Add the function ExpressionGraph::Erase and a test-case for it.
Erase removes the given expression from the graph by shifting
all later expressions to the front. Indices and references
are updated accordingly.

Change-Id: Ic0449ccf28b369600fd2959a7e2a919d47f4cbe3
diff --git a/include/ceres/codegen/internal/expression_graph.h b/include/ceres/codegen/internal/expression_graph.h
index 992c23a..7c5df80 100644
--- a/include/ceres/codegen/internal/expression_graph.h
+++ b/include/ceres/codegen/internal/expression_graph.h
@@ -64,6 +64,12 @@
 
   int Size() const { return expressions_.size(); }
 
+  // Erases the expression at "location". All expression after "location" are
+  // moved by one element to the front. References to moved expressions are
+  // updated. Removing an expression that is still referenced somewhere is
+  // undefined behaviour.
+  void Erase(ExpressionId location);
+
   // Insert a new expression at "location" into the graph. All expression
   // after "location" are moved by one element to the back. References to
   // moved expressions are updated.
diff --git a/internal/ceres/codegen/expression_graph_test.cc b/internal/ceres/codegen/expression_graph_test.cc
index 420b05a..4f4f0ec 100644
--- a/internal/ceres/codegen/expression_graph_test.cc
+++ b/internal/ceres/codegen/expression_graph_test.cc
@@ -250,5 +250,37 @@
   EXPECT_EQ(graph, ref);
 }
 
+TEST(ExpressionGraph, Erase) {
+  // This test checks if references to shifted expressions are updated
+  // accordingly.
+  ExpressionGraph graph;
+  graph.InsertBack(Expression::CreateCompileTimeConstant(42));
+  graph.InsertBack(Expression::CreateCompileTimeConstant(10));
+  graph.InsertBack(Expression::CreateCompileTimeConstant(3));
+  graph.InsertBack(Expression::CreateBinaryArithmetic(
+      "+", ExpressionId(0), ExpressionId(2)));
+  // Code:
+  // v_0 = 42
+  // v_1 = 10
+  // v_2 = 3
+  // v_3 = v_0 + v_2
+
+  // Erase the unused expression v_1 = 10
+  graph.Erase(1);
+  // This should shift all indices like this:
+  // v_0 = 42
+  // v_1 = 3
+  // v_2 = v_0 + v_1
+
+  // Test by inserting it in the correct order
+  ExpressionGraph ref;
+  ref.InsertBack(Expression::CreateCompileTimeConstant(42));
+  ref.InsertBack(Expression::CreateCompileTimeConstant(3));
+  ref.InsertBack(Expression::CreateBinaryArithmetic(
+      "+", ExpressionId(0), ExpressionId(1)));
+  EXPECT_EQ(graph.Size(), ref.Size());
+  EXPECT_EQ(graph, ref);
+}
+
 }  // namespace internal
 }  // namespace ceres
diff --git a/internal/ceres/expression_graph.cc b/internal/ceres/expression_graph.cc
index e619e53..59f20ea 100644
--- a/internal/ceres/expression_graph.cc
+++ b/internal/ceres/expression_graph.cc
@@ -87,8 +87,30 @@
   return true;
 }
 
+void ExpressionGraph::Erase(ExpressionId location) {
+  CHECK_GE(location, 0);
+  CHECK_LT(location, Size());
+  // Move everything after id to the front and update references
+  for (ExpressionId id = location + 1; id < Size(); ++id) {
+    expressions_[id - 1] = expressions_[id];
+    auto& expression = expressions_[id - 1];
+    // Decrement reference if it points to a shifted variable.
+    if (expression.lhs_id() >= location) {
+      expression.set_lhs_id(expression.lhs_id() - 1);
+    }
+    for (auto& arg : *expression.mutable_arguments()) {
+      if (arg >= location) {
+        arg--;
+      }
+    }
+  }
+  expressions_.resize(Size() - 1);
+}
+
 void ExpressionGraph::Insert(ExpressionId location,
                              const Expression& expression) {
+  CHECK_GE(location, 0);
+  CHECK_LE(location, Size());
   ExpressionId last_expression_id = Size() - 1;
   // Increase size by adding a dummy expression.
   expressions_.push_back(Expression());