Add ExpressionGraph::InsertExpression

This functions allows the insertion of new expression anywhere into
the graph. All references are updated. This will be used by semantic
expression testing and by the optimizer.

Change-Id: Ieafcfec7672f6328106dc0511c85d6fb5bd64d97
diff --git a/include/ceres/internal/expression.h b/include/ceres/internal/expression.h
index ce470b2..09d4ceb 100644
--- a/include/ceres/internal/expression.h
+++ b/include/ceres/internal/expression.h
@@ -349,7 +349,7 @@
   // If lhs_id_ == kInvalidExpressionId, then the expression type is not
   // arithmetic. Currently, only the following types have lhs_id = invalid:
   // IF,ELSE,ENDIF,NOP
-  const ExpressionId lhs_id_ = kInvalidExpressionId;
+  ExpressionId lhs_id_ = kInvalidExpressionId;
 
   // Expressions have different number of arguments. For example a binary "+"
   // has 2 parameters and a function call to "sin" has 1 parameter. Here, a
diff --git a/include/ceres/internal/expression_graph.h b/include/ceres/internal/expression_graph.h
index 0250d19..644af0b 100644
--- a/include/ceres/internal/expression_graph.h
+++ b/include/ceres/internal/expression_graph.h
@@ -81,6 +81,16 @@
 
   int Size() const { return expressions_.size(); }
 
+  // Insert a new expression at "location" into the graph. All expression
+  // after "location" are moved by one element to the back. References to moved
+  // expression are updated.
+  void InsertExpression(ExpressionId location,
+                        ExpressionType type,
+                        ExpressionId lhs_id,
+                        const std::vector<ExpressionId>& arguments,
+                        const std::string& name,
+                        double value);
+
  private:
   // All Expressions are referenced by an ExpressionId. The ExpressionId is the
   // index into this array. Each expression has a list of ExpressionId as
diff --git a/internal/ceres/expression_graph.cc b/internal/ceres/expression_graph.cc
index b7e7888..511a98b 100644
--- a/internal/ceres/expression_graph.cc
+++ b/internal/ceres/expression_graph.cc
@@ -108,5 +108,39 @@
   return true;
 }
 
+void ExpressionGraph::InsertExpression(
+    ExpressionId location,
+    ExpressionType type,
+    ExpressionId lhs_id,
+    const std::vector<ExpressionId>& arguments,
+    const std::string& name,
+    double value) {
+  ExpressionId last_expression_id = Size() - 1;
+  // Increase size by adding a dummy expression.
+  expressions_.push_back(Expression(ExpressionType::NOP, kInvalidExpressionId));
+
+  // Move everything after id back and update references
+  for (ExpressionId id = last_expression_id; id >= location; --id) {
+    auto& expression = expressions_[id];
+    // Increment reference if it points to a shifted variable.
+    if (expression.lhs_id_ >= location) {
+      expression.lhs_id_++;
+    }
+    for (auto& arg : expression.arguments_) {
+      if (arg >= location) {
+        arg++;
+      }
+    }
+    expressions_[id + 1] = expression;
+  }
+
+  // Insert new expression at the correct place
+  Expression expr(type, lhs_id);
+  expr.arguments_ = arguments;
+  expr.name_ = name;
+  expr.value_ = value;
+  expressions_[location] = expr;
+}
+
 }  // namespace internal
 }  // namespace ceres
diff --git a/internal/ceres/expression_graph_test.cc b/internal/ceres/expression_graph_test.cc
index ee06f1a..faaa10c 100644
--- a/internal/ceres/expression_graph_test.cc
+++ b/internal/ceres/expression_graph_test.cc
@@ -73,5 +73,89 @@
   ASSERT_FALSE(tree.DependsOn(d.id, unused.id));
 }
 
+TEST(ExpressionGraph, InsertExpression_UpdateReferences) {
+  // This test checks if references to shifted expressions are updated
+  // accordingly.
+  using T = ExpressionRef;
+  StartRecordingExpressions();
+  T a(2);       // 0
+  T b(3);       // 1
+  T c = a + b;  // 2
+  auto graph = StopRecordingExpressions();
+
+  // Test if 'a' and 'c' are actually at location 0 and 2
+  auto& a_expr = graph.ExpressionForId(0);
+  EXPECT_EQ(a_expr.type(), ExpressionType::COMPILE_TIME_CONSTANT);
+  EXPECT_EQ(a_expr.value(), 2);
+
+  // At this point 'c' should have 0 and 1 as arguments.
+  auto& c_expr = graph.ExpressionForId(2);
+  EXPECT_EQ(c_expr.type(), ExpressionType::BINARY_ARITHMETIC);
+  EXPECT_EQ(c_expr.arguments()[0], 0);
+  EXPECT_EQ(c_expr.arguments()[1], 1);
+
+  // We insert at the beginning, which shifts everything by one spot.
+  graph.InsertExpression(
+      0, ExpressionType::COMPILE_TIME_CONSTANT, 0, {}, "", 10.2);
+
+  // Test if 'a' and 'c' are actually at location 1 and 3
+  auto& a_expr2 = graph.ExpressionForId(1);
+  EXPECT_EQ(a_expr2.type(), ExpressionType::COMPILE_TIME_CONSTANT);
+  EXPECT_EQ(a_expr2.value(), 2);
+
+  // At this point 'c' should have 1 and 2 as arguments.
+  auto& c_expr2 = graph.ExpressionForId(3);
+  EXPECT_EQ(c_expr2.type(), ExpressionType::BINARY_ARITHMETIC);
+  EXPECT_EQ(c_expr2.arguments()[0], 1);
+  EXPECT_EQ(c_expr2.arguments()[1], 2);
+}
+
+TEST(ExpressionGraph, InsertExpression) {
+  using T = ExpressionRef;
+
+  StartRecordingExpressions();
+
+  {
+    T a(2);                   // 0
+    T b(3);                   // 1
+    T five = 5;               // 2
+    T tmp = a + five;         // 3
+    a = tmp;                  // 4
+    T c = a + b;              // 5
+    T d = a * b;              // 6
+    T e = c + d;              // 7
+    MakeOutput(e, "result");  // 8
+  }
+  auto reference = StopRecordingExpressions();
+  EXPECT_EQ(reference.Size(), 9);
+
+  StartRecordingExpressions();
+
+  {
+    // The expressions 2,3,4 from above are missing.
+    T a(2);                   // 0
+    T b(3);                   // 1
+    T c = a + b;              // 2
+    T d = a * b;              // 3
+    T e = c + d;              // 4
+    MakeOutput(e, "result");  // 5
+  }
+
+  auto graph1 = StopRecordingExpressions();
+  EXPECT_EQ(graph1.Size(), 6);
+  ASSERT_FALSE(reference == graph1);
+
+  // We manually insert the 3 missing expressions
+  // clang-format off
+  graph1.InsertExpression(2, ExpressionType::COMPILE_TIME_CONSTANT, 2,     {},   "",  5);
+  graph1.InsertExpression(3,     ExpressionType::BINARY_ARITHMETIC, 3, {0, 2},  "+",  0);
+  graph1.InsertExpression(4,            ExpressionType::ASSIGNMENT, 0,    {3},   "",  0);
+  // clang-format on
+
+  // Now the graphs are identical!
+  EXPECT_EQ(graph1.Size(), 9);
+  ASSERT_TRUE(reference == graph1);
+}
+
 }  // namespace internal
 }  // namespace ceres