Add ExpressionGraph::InsertExpression This functions allows the insertion of new expression anywhere into the graph. All references are updated. This will be used by semantic expression testing and by the optimizer. Change-Id: Ieafcfec7672f6328106dc0511c85d6fb5bd64d97
diff --git a/include/ceres/internal/expression.h b/include/ceres/internal/expression.h index ce470b2..09d4ceb 100644 --- a/include/ceres/internal/expression.h +++ b/include/ceres/internal/expression.h
@@ -349,7 +349,7 @@ // If lhs_id_ == kInvalidExpressionId, then the expression type is not // arithmetic. Currently, only the following types have lhs_id = invalid: // IF,ELSE,ENDIF,NOP - const ExpressionId lhs_id_ = kInvalidExpressionId; + ExpressionId lhs_id_ = kInvalidExpressionId; // Expressions have different number of arguments. For example a binary "+" // has 2 parameters and a function call to "sin" has 1 parameter. Here, a
diff --git a/include/ceres/internal/expression_graph.h b/include/ceres/internal/expression_graph.h index 0250d19..644af0b 100644 --- a/include/ceres/internal/expression_graph.h +++ b/include/ceres/internal/expression_graph.h
@@ -81,6 +81,16 @@ int Size() const { return expressions_.size(); } + // Insert a new expression at "location" into the graph. All expression + // after "location" are moved by one element to the back. References to moved + // expression are updated. + void InsertExpression(ExpressionId location, + ExpressionType type, + ExpressionId lhs_id, + const std::vector<ExpressionId>& arguments, + const std::string& name, + double value); + private: // All Expressions are referenced by an ExpressionId. The ExpressionId is the // index into this array. Each expression has a list of ExpressionId as
diff --git a/internal/ceres/expression_graph.cc b/internal/ceres/expression_graph.cc index b7e7888..511a98b 100644 --- a/internal/ceres/expression_graph.cc +++ b/internal/ceres/expression_graph.cc
@@ -108,5 +108,39 @@ return true; } +void ExpressionGraph::InsertExpression( + ExpressionId location, + ExpressionType type, + ExpressionId lhs_id, + const std::vector<ExpressionId>& arguments, + const std::string& name, + double value) { + ExpressionId last_expression_id = Size() - 1; + // Increase size by adding a dummy expression. + expressions_.push_back(Expression(ExpressionType::NOP, kInvalidExpressionId)); + + // Move everything after id back and update references + for (ExpressionId id = last_expression_id; id >= location; --id) { + auto& expression = expressions_[id]; + // Increment reference if it points to a shifted variable. + if (expression.lhs_id_ >= location) { + expression.lhs_id_++; + } + for (auto& arg : expression.arguments_) { + if (arg >= location) { + arg++; + } + } + expressions_[id + 1] = expression; + } + + // Insert new expression at the correct place + Expression expr(type, lhs_id); + expr.arguments_ = arguments; + expr.name_ = name; + expr.value_ = value; + expressions_[location] = expr; +} + } // namespace internal } // namespace ceres
diff --git a/internal/ceres/expression_graph_test.cc b/internal/ceres/expression_graph_test.cc index ee06f1a..faaa10c 100644 --- a/internal/ceres/expression_graph_test.cc +++ b/internal/ceres/expression_graph_test.cc
@@ -73,5 +73,89 @@ ASSERT_FALSE(tree.DependsOn(d.id, unused.id)); } +TEST(ExpressionGraph, InsertExpression_UpdateReferences) { + // This test checks if references to shifted expressions are updated + // accordingly. + using T = ExpressionRef; + StartRecordingExpressions(); + T a(2); // 0 + T b(3); // 1 + T c = a + b; // 2 + auto graph = StopRecordingExpressions(); + + // Test if 'a' and 'c' are actually at location 0 and 2 + auto& a_expr = graph.ExpressionForId(0); + EXPECT_EQ(a_expr.type(), ExpressionType::COMPILE_TIME_CONSTANT); + EXPECT_EQ(a_expr.value(), 2); + + // At this point 'c' should have 0 and 1 as arguments. + auto& c_expr = graph.ExpressionForId(2); + EXPECT_EQ(c_expr.type(), ExpressionType::BINARY_ARITHMETIC); + EXPECT_EQ(c_expr.arguments()[0], 0); + EXPECT_EQ(c_expr.arguments()[1], 1); + + // We insert at the beginning, which shifts everything by one spot. + graph.InsertExpression( + 0, ExpressionType::COMPILE_TIME_CONSTANT, 0, {}, "", 10.2); + + // Test if 'a' and 'c' are actually at location 1 and 3 + auto& a_expr2 = graph.ExpressionForId(1); + EXPECT_EQ(a_expr2.type(), ExpressionType::COMPILE_TIME_CONSTANT); + EXPECT_EQ(a_expr2.value(), 2); + + // At this point 'c' should have 1 and 2 as arguments. + auto& c_expr2 = graph.ExpressionForId(3); + EXPECT_EQ(c_expr2.type(), ExpressionType::BINARY_ARITHMETIC); + EXPECT_EQ(c_expr2.arguments()[0], 1); + EXPECT_EQ(c_expr2.arguments()[1], 2); +} + +TEST(ExpressionGraph, InsertExpression) { + using T = ExpressionRef; + + StartRecordingExpressions(); + + { + T a(2); // 0 + T b(3); // 1 + T five = 5; // 2 + T tmp = a + five; // 3 + a = tmp; // 4 + T c = a + b; // 5 + T d = a * b; // 6 + T e = c + d; // 7 + MakeOutput(e, "result"); // 8 + } + auto reference = StopRecordingExpressions(); + EXPECT_EQ(reference.Size(), 9); + + StartRecordingExpressions(); + + { + // The expressions 2,3,4 from above are missing. + T a(2); // 0 + T b(3); // 1 + T c = a + b; // 2 + T d = a * b; // 3 + T e = c + d; // 4 + MakeOutput(e, "result"); // 5 + } + + auto graph1 = StopRecordingExpressions(); + EXPECT_EQ(graph1.Size(), 6); + ASSERT_FALSE(reference == graph1); + + // We manually insert the 3 missing expressions + // clang-format off + graph1.InsertExpression(2, ExpressionType::COMPILE_TIME_CONSTANT, 2, {}, "", 5); + graph1.InsertExpression(3, ExpressionType::BINARY_ARITHMETIC, 3, {0, 2}, "+", 0); + graph1.InsertExpression(4, ExpressionType::ASSIGNMENT, 0, {3}, "", 0); + // clang-format on + + // Now the graphs are identical! + EXPECT_EQ(graph1.Size(), 9); + ASSERT_TRUE(reference == graph1); +} + } // namespace internal } // namespace ceres