Add functions to find the matching ELSE, ENDIF expressions ExpressionId FindMatchingEndif(ExpressionId id) - Finds the closing ENDIF expression for a given IF expression. ExpressionId FindMatchingElse(ExpressionId id) - Finds the ELSE expression for a given IF expression. - Returns -1, if this IF doesn't have an else This patch is a preparation to the code analyzing required for various optimization passes. Change-Id: I893102a757c6a0bcacbcc1190c2b8ef08314eb97
diff --git a/include/ceres/codegen/internal/expression_graph.h b/include/ceres/codegen/internal/expression_graph.h index 331a2d9..992c23a 100644 --- a/include/ceres/codegen/internal/expression_graph.h +++ b/include/ceres/codegen/internal/expression_graph.h
@@ -74,6 +74,35 @@ // can be used for further operations. ExpressionId InsertBack(const Expression& expression); + // Finds the closing ENDIF expression for a given IF expression. Calling this + // method is only valid on IF expressions. If no suitable ENDIF is found, + // kInvalidExpressionId is returned. Example: + // <id> <expr> FindMatchingEndif(id) + // 0 IF 7 + // 1 IF 3 + // 2 ELSE - + // 3 ENDIF - + // 4 ELSE - + // 5 IF 6 + // 6 ENDIF - + // 7 ENDIF - + ExpressionId FindMatchingEndif(ExpressionId id) const; + + // Similar to FindMatchingEndif, but returns the matching ELSE expression. If + // no suitable ELSE is found, kInvalidExpressionId is returned. + // FindMatchingElse does not throw an error is this case, because IF without + // ELSE is allowed. + // <id> <expr> FindMatchingEndif(id) + // 0 IF 4 + // 1 IF 2 + // 2 ELSE - + // 3 ENDIF - + // 4 ELSE - + // 5 IF kInvalidEpressionId + // 6 ENDIF - + // 7 ENDIF - + ExpressionId FindMatchingElse(ExpressionId id) const; + private: // All Expressions are referenced by an ExpressionId. The ExpressionId is // the index into this array. Each expression has a list of ExpressionId as
diff --git a/internal/ceres/codegen/expression_graph_test.cc b/internal/ceres/codegen/expression_graph_test.cc index b08a9b3..420b05a 100644 --- a/internal/ceres/codegen/expression_graph_test.cc +++ b/internal/ceres/codegen/expression_graph_test.cc
@@ -31,9 +31,9 @@ // This file tests the ExpressionGraph class. This test depends on the // correctness of Expression. // -#include "ceres/codegen/internal/expression.h" #include "ceres/codegen/internal/expression_graph.h" +#include "ceres/codegen/internal/expression.h" #include "gtest/gtest.h" namespace ceres { @@ -146,6 +146,78 @@ ASSERT_TRUE(graph.DependsOn(3, 1)); } +TEST(ExpressionGraph, FindMatchingEndif) { + ExpressionGraph graph; + graph.InsertBack(Expression::CreateCompileTimeConstant(1)); + graph.InsertBack(Expression::CreateCompileTimeConstant(2)); + graph.InsertBack(Expression::CreateBinaryCompare("<", 0, 1)); + graph.InsertBack(Expression::CreateIf(2)); + graph.InsertBack(Expression::CreateIf(2)); + graph.InsertBack(Expression::CreateElse()); + graph.InsertBack(Expression::CreateEndIf()); + graph.InsertBack(Expression::CreateElse()); + graph.InsertBack(Expression::CreateIf(2)); + graph.InsertBack(Expression::CreateEndIf()); + graph.InsertBack(Expression::CreateEndIf()); + graph.InsertBack(Expression::CreateIf(2)); // < if without matching endif + EXPECT_EQ(graph.Size(), 12); + + // Code <id> + // v_0 = 1 0 + // v_1 = 2 1 + // v_2 = v_0 < v_1 2 + // IF (v_2) 3 + // IF (v_2) 4 + // ELSE 5 + // ENDIF 6 + // ELSE 7 + // IF (v_2) 8 + // ENDIF 9 + // ENDIF 10 + // IF(v_2) 11 + + EXPECT_EQ(graph.FindMatchingEndif(3), 10); + EXPECT_EQ(graph.FindMatchingEndif(4), 6); + EXPECT_EQ(graph.FindMatchingEndif(8), 9); + EXPECT_EQ(graph.FindMatchingEndif(11), kInvalidExpressionId); +} + +TEST(ExpressionGraph, FindMatchingElse) { + ExpressionGraph graph; + graph.InsertBack(Expression::CreateCompileTimeConstant(1)); + graph.InsertBack(Expression::CreateCompileTimeConstant(2)); + graph.InsertBack(Expression::CreateBinaryCompare("<", 0, 1)); + graph.InsertBack(Expression::CreateIf(2)); + graph.InsertBack(Expression::CreateIf(2)); + graph.InsertBack(Expression::CreateElse()); + graph.InsertBack(Expression::CreateEndIf()); + graph.InsertBack(Expression::CreateElse()); + graph.InsertBack(Expression::CreateIf(2)); + graph.InsertBack(Expression::CreateEndIf()); + graph.InsertBack(Expression::CreateEndIf()); + graph.InsertBack(Expression::CreateIf(2)); // < if without matching endif + EXPECT_EQ(graph.Size(), 12); + + // Code <id> + // v_0 = 1 0 + // v_1 = 2 1 + // v_2 = v_0 < v_1 2 + // IF (v_2) 3 + // IF (v_2) 4 + // ELSE 5 + // ENDIF 6 + // ELSE 7 + // IF (v_2) 8 + // ENDIF 9 + // ENDIF 10 + // IF(v_2) 11 + + EXPECT_EQ(graph.FindMatchingElse(3), 7); + EXPECT_EQ(graph.FindMatchingElse(4), 5); + EXPECT_EQ(graph.FindMatchingElse(8), kInvalidExpressionId); + EXPECT_EQ(graph.FindMatchingEndif(11), kInvalidExpressionId); +} + TEST(ExpressionGraph, InsertExpression_UpdateReferences) { // This test checks if references to shifted expressions are updated // accordingly.
diff --git a/internal/ceres/expression_graph.cc b/internal/ceres/expression_graph.cc index 7bae8d2..e619e53 100644 --- a/internal/ceres/expression_graph.cc +++ b/internal/ceres/expression_graph.cc
@@ -147,5 +147,59 @@ return Size() - 1; } +ExpressionId ExpressionGraph::FindMatchingEndif(ExpressionId id) const { + CHECK(ExpressionForId(id).type() == ExpressionType::IF) + << "FindClosingControlExpression is only valid on IF " + "expressions."; + + // Traverse downwards + for (ExpressionId i = id + 1; i < Size(); ++i) { + const auto& expr = ExpressionForId(i); + if (expr.type() == ExpressionType::ENDIF) { + return i; + + } else if (expr.type() == ExpressionType::IF) { + // Found a nested IF. + // -> Jump over the block and continue behind it. + auto matching_endif = FindMatchingEndif(i); + if (matching_endif == kInvalidExpressionId) { + return kInvalidExpressionId; + } + i = matching_endif; + continue; + } + } + return kInvalidExpressionId; +} + +ExpressionId ExpressionGraph::FindMatchingElse(ExpressionId id) const { + CHECK(ExpressionForId(id).type() == ExpressionType::IF) + << "FindClosingControlExpression is only valid on IF " + "expressions."; + + // Traverse downwards + for (ExpressionId i = id + 1; i < Size(); ++i) { + const auto& expr = ExpressionForId(i); + if (expr.type() == ExpressionType::ELSE) { + // Found it! + return i; + } else if (expr.type() == ExpressionType::ENDIF) { + // Found an endif even though we were looking for an ELSE. + // -> Return invalidId + return kInvalidExpressionId; + } else if (expr.type() == ExpressionType::IF) { + // Found a nested IF. + // -> Jump over the block and continue behind it. + auto matching_endif = FindMatchingEndif(i); + if (matching_endif == kInvalidExpressionId) { + return kInvalidExpressionId; + } + i = matching_endif; + continue; + } + } + return kInvalidExpressionId; +} + } // namespace internal } // namespace ceres