Add functions to find the matching ELSE, ENDIF expressions

ExpressionId FindMatchingEndif(ExpressionId id)
   - Finds the closing ENDIF expression for a given IF expression.

ExpressionId FindMatchingElse(ExpressionId id)
   - Finds the ELSE expression for a given IF expression.
   - Returns -1, if this IF doesn't have an else

This patch is a preparation to the code analyzing required
for various optimization passes.

Change-Id: I893102a757c6a0bcacbcc1190c2b8ef08314eb97
diff --git a/include/ceres/codegen/internal/expression_graph.h b/include/ceres/codegen/internal/expression_graph.h
index 331a2d9..992c23a 100644
--- a/include/ceres/codegen/internal/expression_graph.h
+++ b/include/ceres/codegen/internal/expression_graph.h
@@ -74,6 +74,35 @@
   // can be used for further operations.
   ExpressionId InsertBack(const Expression& expression);
 
+  // Finds the closing ENDIF expression for a given IF expression. Calling this
+  // method is only valid on IF expressions. If no suitable ENDIF is found,
+  // kInvalidExpressionId is returned. Example:
+  // <id> <expr>    FindMatchingEndif(id)
+  //  0  IF         7
+  //  1    IF       3
+  //  2    ELSE     -
+  //  3    ENDIF    -
+  //  4  ELSE       -
+  //  5    IF       6
+  //  6    ENDIF    -
+  //  7  ENDIF      -
+  ExpressionId FindMatchingEndif(ExpressionId id) const;
+
+  // Similar to FindMatchingEndif, but returns the matching ELSE expression. If
+  // no suitable ELSE is found, kInvalidExpressionId is returned.
+  // FindMatchingElse does not throw an error is this case, because IF without
+  // ELSE is allowed.
+  // <id> <expr>    FindMatchingEndif(id)
+  //  0  IF         4
+  //  1    IF       2
+  //  2    ELSE     -
+  //  3    ENDIF    -
+  //  4  ELSE       -
+  //  5    IF       kInvalidEpressionId
+  //  6    ENDIF    -
+  //  7  ENDIF      -
+  ExpressionId FindMatchingElse(ExpressionId id) const;
+
  private:
   // All Expressions are referenced by an ExpressionId. The ExpressionId is
   // the index into this array. Each expression has a list of ExpressionId as
diff --git a/internal/ceres/codegen/expression_graph_test.cc b/internal/ceres/codegen/expression_graph_test.cc
index b08a9b3..420b05a 100644
--- a/internal/ceres/codegen/expression_graph_test.cc
+++ b/internal/ceres/codegen/expression_graph_test.cc
@@ -31,9 +31,9 @@
 // This file tests the ExpressionGraph class. This test depends on the
 // correctness of Expression.
 //
-#include "ceres/codegen/internal/expression.h"
 #include "ceres/codegen/internal/expression_graph.h"
 
+#include "ceres/codegen/internal/expression.h"
 #include "gtest/gtest.h"
 
 namespace ceres {
@@ -146,6 +146,78 @@
   ASSERT_TRUE(graph.DependsOn(3, 1));
 }
 
+TEST(ExpressionGraph, FindMatchingEndif) {
+  ExpressionGraph graph;
+  graph.InsertBack(Expression::CreateCompileTimeConstant(1));
+  graph.InsertBack(Expression::CreateCompileTimeConstant(2));
+  graph.InsertBack(Expression::CreateBinaryCompare("<", 0, 1));
+  graph.InsertBack(Expression::CreateIf(2));
+  graph.InsertBack(Expression::CreateIf(2));
+  graph.InsertBack(Expression::CreateElse());
+  graph.InsertBack(Expression::CreateEndIf());
+  graph.InsertBack(Expression::CreateElse());
+  graph.InsertBack(Expression::CreateIf(2));
+  graph.InsertBack(Expression::CreateEndIf());
+  graph.InsertBack(Expression::CreateEndIf());
+  graph.InsertBack(Expression::CreateIf(2));  // < if without matching endif
+  EXPECT_EQ(graph.Size(), 12);
+
+  // Code              <id>
+  // v_0 = 1            0
+  // v_1 = 2            1
+  // v_2 = v_0 < v_1    2
+  // IF (v_2)           3
+  //   IF (v_2)         4
+  //   ELSE             5
+  //   ENDIF            6
+  // ELSE               7
+  //   IF (v_2)         8
+  //   ENDIF            9
+  // ENDIF              10
+  // IF(v_2)            11
+
+  EXPECT_EQ(graph.FindMatchingEndif(3), 10);
+  EXPECT_EQ(graph.FindMatchingEndif(4), 6);
+  EXPECT_EQ(graph.FindMatchingEndif(8), 9);
+  EXPECT_EQ(graph.FindMatchingEndif(11), kInvalidExpressionId);
+}
+
+TEST(ExpressionGraph, FindMatchingElse) {
+  ExpressionGraph graph;
+  graph.InsertBack(Expression::CreateCompileTimeConstant(1));
+  graph.InsertBack(Expression::CreateCompileTimeConstant(2));
+  graph.InsertBack(Expression::CreateBinaryCompare("<", 0, 1));
+  graph.InsertBack(Expression::CreateIf(2));
+  graph.InsertBack(Expression::CreateIf(2));
+  graph.InsertBack(Expression::CreateElse());
+  graph.InsertBack(Expression::CreateEndIf());
+  graph.InsertBack(Expression::CreateElse());
+  graph.InsertBack(Expression::CreateIf(2));
+  graph.InsertBack(Expression::CreateEndIf());
+  graph.InsertBack(Expression::CreateEndIf());
+  graph.InsertBack(Expression::CreateIf(2));  // < if without matching endif
+  EXPECT_EQ(graph.Size(), 12);
+
+  // Code              <id>
+  // v_0 = 1            0
+  // v_1 = 2            1
+  // v_2 = v_0 < v_1    2
+  // IF (v_2)           3
+  //   IF (v_2)         4
+  //   ELSE             5
+  //   ENDIF            6
+  // ELSE               7
+  //   IF (v_2)         8
+  //   ENDIF            9
+  // ENDIF              10
+  // IF(v_2)            11
+
+  EXPECT_EQ(graph.FindMatchingElse(3), 7);
+  EXPECT_EQ(graph.FindMatchingElse(4), 5);
+  EXPECT_EQ(graph.FindMatchingElse(8), kInvalidExpressionId);
+  EXPECT_EQ(graph.FindMatchingEndif(11), kInvalidExpressionId);
+}
+
 TEST(ExpressionGraph, InsertExpression_UpdateReferences) {
   // This test checks if references to shifted expressions are updated
   // accordingly.
diff --git a/internal/ceres/expression_graph.cc b/internal/ceres/expression_graph.cc
index 7bae8d2..e619e53 100644
--- a/internal/ceres/expression_graph.cc
+++ b/internal/ceres/expression_graph.cc
@@ -147,5 +147,59 @@
   return Size() - 1;
 }
 
+ExpressionId ExpressionGraph::FindMatchingEndif(ExpressionId id) const {
+  CHECK(ExpressionForId(id).type() == ExpressionType::IF)
+      << "FindClosingControlExpression is only valid on IF "
+         "expressions.";
+
+  // Traverse downwards
+  for (ExpressionId i = id + 1; i < Size(); ++i) {
+    const auto& expr = ExpressionForId(i);
+    if (expr.type() == ExpressionType::ENDIF) {
+      return i;
+
+    } else if (expr.type() == ExpressionType::IF) {
+      // Found a nested IF.
+      // -> Jump over the block and continue behind it.
+      auto matching_endif = FindMatchingEndif(i);
+      if (matching_endif == kInvalidExpressionId) {
+        return kInvalidExpressionId;
+      }
+      i = matching_endif;
+      continue;
+    }
+  }
+  return kInvalidExpressionId;
+}
+
+ExpressionId ExpressionGraph::FindMatchingElse(ExpressionId id) const {
+  CHECK(ExpressionForId(id).type() == ExpressionType::IF)
+      << "FindClosingControlExpression is only valid on IF "
+         "expressions.";
+
+  // Traverse downwards
+  for (ExpressionId i = id + 1; i < Size(); ++i) {
+    const auto& expr = ExpressionForId(i);
+    if (expr.type() == ExpressionType::ELSE) {
+      // Found it!
+      return i;
+    } else if (expr.type() == ExpressionType::ENDIF) {
+      // Found an endif even though we were looking for an ELSE.
+      // -> Return invalidId
+      return kInvalidExpressionId;
+    } else if (expr.type() == ExpressionType::IF) {
+      // Found a nested IF.
+      // -> Jump over the block and continue behind it.
+      auto matching_endif = FindMatchingEndif(i);
+      if (matching_endif == kInvalidExpressionId) {
+        return kInvalidExpressionId;
+      }
+      i = matching_endif;
+      continue;
+    }
+  }
+  return kInvalidExpressionId;
+}
+
 }  // namespace internal
 }  // namespace ceres