Add functions to find the matching ELSE, ENDIF expressions
ExpressionId FindMatchingEndif(ExpressionId id)
- Finds the closing ENDIF expression for a given IF expression.
ExpressionId FindMatchingElse(ExpressionId id)
- Finds the ELSE expression for a given IF expression.
- Returns -1, if this IF doesn't have an else
This patch is a preparation to the code analyzing required
for various optimization passes.
Change-Id: I893102a757c6a0bcacbcc1190c2b8ef08314eb97
diff --git a/include/ceres/codegen/internal/expression_graph.h b/include/ceres/codegen/internal/expression_graph.h
index 331a2d9..992c23a 100644
--- a/include/ceres/codegen/internal/expression_graph.h
+++ b/include/ceres/codegen/internal/expression_graph.h
@@ -74,6 +74,35 @@
// can be used for further operations.
ExpressionId InsertBack(const Expression& expression);
+ // Finds the closing ENDIF expression for a given IF expression. Calling this
+ // method is only valid on IF expressions. If no suitable ENDIF is found,
+ // kInvalidExpressionId is returned. Example:
+ // <id> <expr> FindMatchingEndif(id)
+ // 0 IF 7
+ // 1 IF 3
+ // 2 ELSE -
+ // 3 ENDIF -
+ // 4 ELSE -
+ // 5 IF 6
+ // 6 ENDIF -
+ // 7 ENDIF -
+ ExpressionId FindMatchingEndif(ExpressionId id) const;
+
+ // Similar to FindMatchingEndif, but returns the matching ELSE expression. If
+ // no suitable ELSE is found, kInvalidExpressionId is returned.
+ // FindMatchingElse does not throw an error is this case, because IF without
+ // ELSE is allowed.
+ // <id> <expr> FindMatchingEndif(id)
+ // 0 IF 4
+ // 1 IF 2
+ // 2 ELSE -
+ // 3 ENDIF -
+ // 4 ELSE -
+ // 5 IF kInvalidEpressionId
+ // 6 ENDIF -
+ // 7 ENDIF -
+ ExpressionId FindMatchingElse(ExpressionId id) const;
+
private:
// All Expressions are referenced by an ExpressionId. The ExpressionId is
// the index into this array. Each expression has a list of ExpressionId as
diff --git a/internal/ceres/codegen/expression_graph_test.cc b/internal/ceres/codegen/expression_graph_test.cc
index b08a9b3..420b05a 100644
--- a/internal/ceres/codegen/expression_graph_test.cc
+++ b/internal/ceres/codegen/expression_graph_test.cc
@@ -31,9 +31,9 @@
// This file tests the ExpressionGraph class. This test depends on the
// correctness of Expression.
//
-#include "ceres/codegen/internal/expression.h"
#include "ceres/codegen/internal/expression_graph.h"
+#include "ceres/codegen/internal/expression.h"
#include "gtest/gtest.h"
namespace ceres {
@@ -146,6 +146,78 @@
ASSERT_TRUE(graph.DependsOn(3, 1));
}
+TEST(ExpressionGraph, FindMatchingEndif) {
+ ExpressionGraph graph;
+ graph.InsertBack(Expression::CreateCompileTimeConstant(1));
+ graph.InsertBack(Expression::CreateCompileTimeConstant(2));
+ graph.InsertBack(Expression::CreateBinaryCompare("<", 0, 1));
+ graph.InsertBack(Expression::CreateIf(2));
+ graph.InsertBack(Expression::CreateIf(2));
+ graph.InsertBack(Expression::CreateElse());
+ graph.InsertBack(Expression::CreateEndIf());
+ graph.InsertBack(Expression::CreateElse());
+ graph.InsertBack(Expression::CreateIf(2));
+ graph.InsertBack(Expression::CreateEndIf());
+ graph.InsertBack(Expression::CreateEndIf());
+ graph.InsertBack(Expression::CreateIf(2)); // < if without matching endif
+ EXPECT_EQ(graph.Size(), 12);
+
+ // Code <id>
+ // v_0 = 1 0
+ // v_1 = 2 1
+ // v_2 = v_0 < v_1 2
+ // IF (v_2) 3
+ // IF (v_2) 4
+ // ELSE 5
+ // ENDIF 6
+ // ELSE 7
+ // IF (v_2) 8
+ // ENDIF 9
+ // ENDIF 10
+ // IF(v_2) 11
+
+ EXPECT_EQ(graph.FindMatchingEndif(3), 10);
+ EXPECT_EQ(graph.FindMatchingEndif(4), 6);
+ EXPECT_EQ(graph.FindMatchingEndif(8), 9);
+ EXPECT_EQ(graph.FindMatchingEndif(11), kInvalidExpressionId);
+}
+
+TEST(ExpressionGraph, FindMatchingElse) {
+ ExpressionGraph graph;
+ graph.InsertBack(Expression::CreateCompileTimeConstant(1));
+ graph.InsertBack(Expression::CreateCompileTimeConstant(2));
+ graph.InsertBack(Expression::CreateBinaryCompare("<", 0, 1));
+ graph.InsertBack(Expression::CreateIf(2));
+ graph.InsertBack(Expression::CreateIf(2));
+ graph.InsertBack(Expression::CreateElse());
+ graph.InsertBack(Expression::CreateEndIf());
+ graph.InsertBack(Expression::CreateElse());
+ graph.InsertBack(Expression::CreateIf(2));
+ graph.InsertBack(Expression::CreateEndIf());
+ graph.InsertBack(Expression::CreateEndIf());
+ graph.InsertBack(Expression::CreateIf(2)); // < if without matching endif
+ EXPECT_EQ(graph.Size(), 12);
+
+ // Code <id>
+ // v_0 = 1 0
+ // v_1 = 2 1
+ // v_2 = v_0 < v_1 2
+ // IF (v_2) 3
+ // IF (v_2) 4
+ // ELSE 5
+ // ENDIF 6
+ // ELSE 7
+ // IF (v_2) 8
+ // ENDIF 9
+ // ENDIF 10
+ // IF(v_2) 11
+
+ EXPECT_EQ(graph.FindMatchingElse(3), 7);
+ EXPECT_EQ(graph.FindMatchingElse(4), 5);
+ EXPECT_EQ(graph.FindMatchingElse(8), kInvalidExpressionId);
+ EXPECT_EQ(graph.FindMatchingEndif(11), kInvalidExpressionId);
+}
+
TEST(ExpressionGraph, InsertExpression_UpdateReferences) {
// This test checks if references to shifted expressions are updated
// accordingly.
diff --git a/internal/ceres/expression_graph.cc b/internal/ceres/expression_graph.cc
index 7bae8d2..e619e53 100644
--- a/internal/ceres/expression_graph.cc
+++ b/internal/ceres/expression_graph.cc
@@ -147,5 +147,59 @@
return Size() - 1;
}
+ExpressionId ExpressionGraph::FindMatchingEndif(ExpressionId id) const {
+ CHECK(ExpressionForId(id).type() == ExpressionType::IF)
+ << "FindClosingControlExpression is only valid on IF "
+ "expressions.";
+
+ // Traverse downwards
+ for (ExpressionId i = id + 1; i < Size(); ++i) {
+ const auto& expr = ExpressionForId(i);
+ if (expr.type() == ExpressionType::ENDIF) {
+ return i;
+
+ } else if (expr.type() == ExpressionType::IF) {
+ // Found a nested IF.
+ // -> Jump over the block and continue behind it.
+ auto matching_endif = FindMatchingEndif(i);
+ if (matching_endif == kInvalidExpressionId) {
+ return kInvalidExpressionId;
+ }
+ i = matching_endif;
+ continue;
+ }
+ }
+ return kInvalidExpressionId;
+}
+
+ExpressionId ExpressionGraph::FindMatchingElse(ExpressionId id) const {
+ CHECK(ExpressionForId(id).type() == ExpressionType::IF)
+ << "FindClosingControlExpression is only valid on IF "
+ "expressions.";
+
+ // Traverse downwards
+ for (ExpressionId i = id + 1; i < Size(); ++i) {
+ const auto& expr = ExpressionForId(i);
+ if (expr.type() == ExpressionType::ELSE) {
+ // Found it!
+ return i;
+ } else if (expr.type() == ExpressionType::ENDIF) {
+ // Found an endif even though we were looking for an ELSE.
+ // -> Return invalidId
+ return kInvalidExpressionId;
+ } else if (expr.type() == ExpressionType::IF) {
+ // Found a nested IF.
+ // -> Jump over the block and continue behind it.
+ auto matching_endif = FindMatchingEndif(i);
+ if (matching_endif == kInvalidExpressionId) {
+ return kInvalidExpressionId;
+ }
+ i = matching_endif;
+ continue;
+ }
+ }
+ return kInvalidExpressionId;
+}
+
} // namespace internal
} // namespace ceres