Add ExpressionGraph::Erase(ExpressionId) Add the function ExpressionGraph::Erase and a test-case for it. Erase removes the given expression from the graph by shifting all later expressions to the front. Indices and references are updated accordingly. Change-Id: Ic0449ccf28b369600fd2959a7e2a919d47f4cbe3
diff --git a/include/ceres/codegen/internal/expression_graph.h b/include/ceres/codegen/internal/expression_graph.h index 992c23a..7c5df80 100644 --- a/include/ceres/codegen/internal/expression_graph.h +++ b/include/ceres/codegen/internal/expression_graph.h
@@ -64,6 +64,12 @@ int Size() const { return expressions_.size(); } + // Erases the expression at "location". All expression after "location" are + // moved by one element to the front. References to moved expressions are + // updated. Removing an expression that is still referenced somewhere is + // undefined behaviour. + void Erase(ExpressionId location); + // Insert a new expression at "location" into the graph. All expression // after "location" are moved by one element to the back. References to // moved expressions are updated.
diff --git a/internal/ceres/codegen/expression_graph_test.cc b/internal/ceres/codegen/expression_graph_test.cc index 420b05a..4f4f0ec 100644 --- a/internal/ceres/codegen/expression_graph_test.cc +++ b/internal/ceres/codegen/expression_graph_test.cc
@@ -250,5 +250,37 @@ EXPECT_EQ(graph, ref); } +TEST(ExpressionGraph, Erase) { + // This test checks if references to shifted expressions are updated + // accordingly. + ExpressionGraph graph; + graph.InsertBack(Expression::CreateCompileTimeConstant(42)); + graph.InsertBack(Expression::CreateCompileTimeConstant(10)); + graph.InsertBack(Expression::CreateCompileTimeConstant(3)); + graph.InsertBack(Expression::CreateBinaryArithmetic( + "+", ExpressionId(0), ExpressionId(2))); + // Code: + // v_0 = 42 + // v_1 = 10 + // v_2 = 3 + // v_3 = v_0 + v_2 + + // Erase the unused expression v_1 = 10 + graph.Erase(1); + // This should shift all indices like this: + // v_0 = 42 + // v_1 = 3 + // v_2 = v_0 + v_1 + + // Test by inserting it in the correct order + ExpressionGraph ref; + ref.InsertBack(Expression::CreateCompileTimeConstant(42)); + ref.InsertBack(Expression::CreateCompileTimeConstant(3)); + ref.InsertBack(Expression::CreateBinaryArithmetic( + "+", ExpressionId(0), ExpressionId(1))); + EXPECT_EQ(graph.Size(), ref.Size()); + EXPECT_EQ(graph, ref); +} + } // namespace internal } // namespace ceres
diff --git a/internal/ceres/expression_graph.cc b/internal/ceres/expression_graph.cc index e619e53..59f20ea 100644 --- a/internal/ceres/expression_graph.cc +++ b/internal/ceres/expression_graph.cc
@@ -87,8 +87,30 @@ return true; } +void ExpressionGraph::Erase(ExpressionId location) { + CHECK_GE(location, 0); + CHECK_LT(location, Size()); + // Move everything after id to the front and update references + for (ExpressionId id = location + 1; id < Size(); ++id) { + expressions_[id - 1] = expressions_[id]; + auto& expression = expressions_[id - 1]; + // Decrement reference if it points to a shifted variable. + if (expression.lhs_id() >= location) { + expression.set_lhs_id(expression.lhs_id() - 1); + } + for (auto& arg : *expression.mutable_arguments()) { + if (arg >= location) { + arg--; + } + } + } + expressions_.resize(Size() - 1); +} + void ExpressionGraph::Insert(ExpressionId location, const Expression& expression) { + CHECK_GE(location, 0); + CHECK_LE(location, Size()); ExpressionId last_expression_id = Size() - 1; // Increase size by adding a dummy expression. expressions_.push_back(Expression());