Refactor PartitionedMatrixView to cache the partitions

The constructor now takes a LinearSolver::Options as input
and uses that to compute the partitioning once and uses it
for its lifetime.

Change-Id: I9ef30df0b60f8fa91c8b5601c397b2d9314a2cc7
diff --git a/internal/ceres/CMakeLists.txt b/internal/ceres/CMakeLists.txt
index 4a85e8b..06c0914 100644
--- a/internal/ceres/CMakeLists.txt
+++ b/internal/ceres/CMakeLists.txt
@@ -603,9 +603,6 @@
   add_executable(spmv_benchmark spmv_benchmark.cc)
   add_dependencies_to_benchmark(spmv_benchmark)
 
-  add_executable(partitioned_matrix_view_benchmark partitioned_matrix_view_benchmark.cc)
-  add_dependencies_to_benchmark(partitioned_matrix_view_benchmark)
-
   add_executable(block_jacobi_preconditioner_benchmark
     block_jacobi_preconditioner_benchmark.cc)
   add_dependencies_to_benchmark(block_jacobi_preconditioner_benchmark)
diff --git a/internal/ceres/evaluation_benchmark.cc b/internal/ceres/evaluation_benchmark.cc
index 2f9fff0..9849b5a 100644
--- a/internal/ceres/evaluation_benchmark.cc
+++ b/internal/ceres/evaluation_benchmark.cc
@@ -145,23 +145,20 @@
     return crs_jacobian.get();
   }
 
-  const PartitionedView* PartitionedMatrixViewJacobian(ContextImpl* context) {
-    if (!partitioned_view_jacobian) {
-      auto block_sparse = BlockSparseJacobian(context);
-      partitioned_view_jacobian = std::make_unique<PartitionedView>(
-          *block_sparse, bal_problem->num_points());
-    }
+  const PartitionedView* PartitionedMatrixViewJacobian(
+      const LinearSolver::Options& options) {
+    auto block_sparse = BlockSparseJacobian(options.context);
+    partitioned_view_jacobian =
+        std::make_unique<PartitionedView>(options, *block_sparse);
     return partitioned_view_jacobian.get();
   }
 
   const PartitionedView* PartitionedMatrixViewJacobianWithTranspose(
-      ContextImpl* context) {
-    if (!partitioned_view_jacobian_with_transpose) {
-      auto block_sparse_transpose = BlockSparseJacobianWithTranspose(context);
-      partitioned_view_jacobian_with_transpose =
-          std::make_unique<PartitionedView>(*block_sparse_transpose,
-                                            bal_problem->num_points());
-    }
+      const LinearSolver::Options& options) {
+    auto block_sparse_transpose =
+        BlockSparseJacobianWithTranspose(options.context);
+    partitioned_view_jacobian_with_transpose =
+        std::make_unique<PartitionedView>(options, *block_sparse_transpose);
     return partitioned_view_jacobian_with_transpose.get();
   }
 
@@ -243,16 +240,17 @@
 static void PMVRightMultiplyAndAccumulateF(benchmark::State& state,
                                            BALData* data,
                                            ContextImpl* context) {
-  const int num_threads = state.range(0);
-
-  auto jacobian = data->PartitionedMatrixViewJacobian(context);
+  LinearSolver::Options options;
+  options.num_threads = state.range(0);
+  options.elimination_groups.push_back(data->bal_problem->num_points());
+  options.context = context;
+  auto jacobian = data->PartitionedMatrixViewJacobian(options);
 
   Vector y = Vector::Zero(jacobian->num_rows());
   Vector x = Vector::Random(jacobian->num_cols_f());
 
   for (auto _ : state) {
-    jacobian->RightMultiplyAndAccumulateF(
-        x.data(), y.data(), context, num_threads);
+    jacobian->RightMultiplyAndAccumulateF(x.data(), y.data());
   }
   CHECK_GT(y.squaredNorm(), 0.);
 }
@@ -260,49 +258,35 @@
 static void PMVLeftMultiplyAndAccumulateF(benchmark::State& state,
                                           BALData* data,
                                           ContextImpl* context) {
-  const int num_threads = state.range(0);
-
-  auto jacobian = data->PartitionedMatrixViewJacobian(context);
+  LinearSolver::Options options;
+  options.num_threads = state.range(0);
+  options.elimination_groups.push_back(data->bal_problem->num_points());
+  options.context = context;
+  auto jacobian = data->PartitionedMatrixViewJacobianWithTranspose(options);
 
   Vector y = Vector::Zero(jacobian->num_cols_f());
   Vector x = Vector::Random(jacobian->num_rows());
 
   for (auto _ : state) {
-    jacobian->LeftMultiplyAndAccumulateF(
-        x.data(), y.data(), context, num_threads);
+    jacobian->LeftMultiplyAndAccumulateF(x.data(), y.data());
   }
   CHECK_GT(y.squaredNorm(), 0.);
 }
 
-static void PMVLeftMultiplyAndAccumulateWithTransposeF(benchmark::State& state,
-                                                       BALData* data,
-                                                       ContextImpl* context) {
-  const int num_threads = state.range(0);
-
-  auto jacobian = data->PartitionedMatrixViewJacobianWithTranspose(context);
-
-  Vector y = Vector::Zero(jacobian->num_cols_f());
-  Vector x = Vector::Random(jacobian->num_rows());
-
-  for (auto _ : state) {
-    jacobian->LeftMultiplyAndAccumulateF(
-        x.data(), y.data(), context, num_threads);
-  }
-  CHECK_GT(y.squaredNorm(), 0.);
-}
 static void PMVRightMultiplyAndAccumulateE(benchmark::State& state,
                                            BALData* data,
                                            ContextImpl* context) {
-  const int num_threads = state.range(0);
-
-  auto jacobian = data->PartitionedMatrixViewJacobian(context);
+  LinearSolver::Options options;
+  options.num_threads = state.range(0);
+  options.elimination_groups.push_back(data->bal_problem->num_points());
+  options.context = context;
+  auto jacobian = data->PartitionedMatrixViewJacobian(options);
 
   Vector y = Vector::Zero(jacobian->num_rows());
   Vector x = Vector::Random(jacobian->num_cols_e());
 
   for (auto _ : state) {
-    jacobian->RightMultiplyAndAccumulateE(
-        x.data(), y.data(), context, num_threads);
+    jacobian->RightMultiplyAndAccumulateE(x.data(), y.data());
   }
   CHECK_GT(y.squaredNorm(), 0.);
 }
@@ -310,33 +294,17 @@
 static void PMVLeftMultiplyAndAccumulateE(benchmark::State& state,
                                           BALData* data,
                                           ContextImpl* context) {
-  const int num_threads = state.range(0);
-
-  auto jacobian = data->PartitionedMatrixViewJacobian(context);
+  LinearSolver::Options options;
+  options.num_threads = state.range(0);
+  options.elimination_groups.push_back(data->bal_problem->num_points());
+  options.context = context;
+  auto jacobian = data->PartitionedMatrixViewJacobianWithTranspose(options);
 
   Vector y = Vector::Zero(jacobian->num_cols_e());
   Vector x = Vector::Random(jacobian->num_rows());
 
   for (auto _ : state) {
-    jacobian->LeftMultiplyAndAccumulateE(
-        x.data(), y.data(), context, num_threads);
-  }
-  CHECK_GT(y.squaredNorm(), 0.);
-}
-
-static void PMVLeftMultiplyAndAccumulateWithTransposeE(benchmark::State& state,
-                                                       BALData* data,
-                                                       ContextImpl* context) {
-  const int num_threads = state.range(0);
-
-  auto jacobian = data->PartitionedMatrixViewJacobianWithTranspose(context);
-
-  Vector y = Vector::Zero(jacobian->num_cols_e());
-  Vector x = Vector::Random(jacobian->num_rows());
-
-  for (auto _ : state) {
-    jacobian->LeftMultiplyAndAccumulateE(
-        x.data(), y.data(), context, num_threads);
+    jacobian->LeftMultiplyAndAccumulateE(x.data(), y.data());
   }
   CHECK_GT(y.squaredNorm(), 0.);
 }
@@ -363,22 +331,6 @@
                                               ContextImpl* context) {
   const int num_threads = state.range(0);
 
-  auto jacobian = data->BlockSparseJacobian(context);
-
-  Vector y = Vector::Zero(jacobian->num_cols());
-  Vector x = Vector::Random(jacobian->num_rows());
-
-  for (auto _ : state) {
-    jacobian->LeftMultiplyAndAccumulate(
-        x.data(), y.data(), context, num_threads);
-  }
-  CHECK_GT(y.squaredNorm(), 0.);
-}
-
-static void JacobianLeftMultiplyAndAccumulateWithTranspose(
-    benchmark::State& state, BALData* data, ContextImpl* context) {
-  const int num_threads = state.range(0);
-
   auto jacobian = data->BlockSparseJacobianWithTranspose(context);
 
   Vector y = Vector::Zero(jacobian->num_cols());
@@ -551,15 +503,6 @@
         ceres::internal::JacobianLeftMultiplyAndAccumulate,
         data,
         &context)
-        ->Arg(1);
-
-    const std::string name_left_product_transpose =
-        "JacobianLeftMultiplyAndAccumulateWithTranspose<" + path + ">";
-    ::benchmark::RegisterBenchmark(
-        name_left_product_transpose.c_str(),
-        ceres::internal::JacobianLeftMultiplyAndAccumulateWithTranspose,
-        data,
-        &context)
         ->Arg(1)
         ->Arg(2)
         ->Arg(4)
@@ -573,14 +516,6 @@
         ceres::internal::PMVLeftMultiplyAndAccumulateF,
         data,
         &context)
-        ->Arg(1);
-    const std::string name_left_product_partitioned_transpose_f =
-        "PMVLeftMultiplyAndAccumulateWithTransposeF<" + path + ">";
-    ::benchmark::RegisterBenchmark(
-        name_left_product_partitioned_transpose_f.c_str(),
-        ceres::internal::PMVLeftMultiplyAndAccumulateWithTransposeF,
-        data,
-        &context)
         ->Arg(1)
         ->Arg(2)
         ->Arg(4)
@@ -594,15 +529,6 @@
         ceres::internal::PMVLeftMultiplyAndAccumulateE,
         data,
         &context)
-        ->Arg(1);
-
-    const std::string name_left_product_partitioned_transpose_e =
-        "PMVLeftMultiplyAndAccumulateWithTransposeE<" + path + ">";
-    ::benchmark::RegisterBenchmark(
-        name_left_product_partitioned_transpose_e.c_str(),
-        ceres::internal::PMVLeftMultiplyAndAccumulateWithTransposeE,
-        data,
-        &context)
         ->Arg(1)
         ->Arg(2)
         ->Arg(4)
diff --git a/internal/ceres/fake_bundle_adjustment_jacobian.cc b/internal/ceres/fake_bundle_adjustment_jacobian.cc
index af1d03c..d8f3c01 100644
--- a/internal/ceres/fake_bundle_adjustment_jacobian.cc
+++ b/internal/ceres/fake_bundle_adjustment_jacobian.cc
@@ -109,8 +109,10 @@
       PartitionedMatrixView<2, Eigen::Dynamic, Eigen::Dynamic>;
   auto block_sparse_matrix = CreateFakeBundleAdjustmentJacobian(
       num_cameras, num_points, camera_size, landmark_size, visibility, rng);
+  LinearSolver::Options options;
+  options.elimination_groups.push_back(num_points);
   auto partitioned_view =
-      std::make_unique<PartitionedView>(*block_sparse_matrix, num_points);
+      std::make_unique<PartitionedView>(options, *block_sparse_matrix);
   return std::make_pair(std::move(partitioned_view),
                         std::move(block_sparse_matrix));
 }
diff --git a/internal/ceres/implicit_schur_complement.cc b/internal/ceres/implicit_schur_complement.cc
index 6249c8c..4e36048 100644
--- a/internal/ceres/implicit_schur_complement.cc
+++ b/internal/ceres/implicit_schur_complement.cc
@@ -105,15 +105,11 @@
                                                          double* y) const {
   // y1 = F x
   tmp_rows_.setZero();
-  A_->RightMultiplyAndAccumulateF(
-      x, tmp_rows_.data(), options_.context, options_.num_threads);
+  A_->RightMultiplyAndAccumulateF(x, tmp_rows_.data());
 
   // y2 = E' y1
   tmp_e_cols_.setZero();
-  A_->LeftMultiplyAndAccumulateE(tmp_rows_.data(),
-                                 tmp_e_cols_.data(),
-                                 options_.context,
-                                 options_.num_threads);
+  A_->LeftMultiplyAndAccumulateE(tmp_rows_.data(), tmp_e_cols_.data());
 
   // y3 = -(E'E)^-1 y2
   tmp_e_cols_2_.setZero();
@@ -121,13 +117,11 @@
                                                           tmp_e_cols_2_.data(),
                                                           options_.context,
                                                           options_.num_threads);
+
   tmp_e_cols_2_ *= -1.0;
 
   // y1 = y1 + E y3
-  A_->RightMultiplyAndAccumulateE(tmp_e_cols_2_.data(),
-                                  tmp_rows_.data(),
-                                  options_.context,
-                                  options_.num_threads);
+  A_->RightMultiplyAndAccumulateE(tmp_e_cols_2_.data(), tmp_rows_.data());
 
   // y5 = D * x
   if (D_ != nullptr) {
@@ -140,8 +134,7 @@
   }
 
   // y = y5 + F' y1
-  A_->LeftMultiplyAndAccumulateF(
-      tmp_rows_.data(), y, options_.context, options_.num_threads);
+  A_->LeftMultiplyAndAccumulateF(tmp_rows_.data(), y);
 }
 
 void ImplicitSchurComplement::InversePowerSeriesOperatorRightMultiplyAccumulate(
@@ -149,15 +142,11 @@
   CHECK(compute_ftf_inverse_);
   // y1 = F x
   tmp_rows_.setZero();
-  A_->RightMultiplyAndAccumulateF(
-      x, tmp_rows_.data(), options_.context, options_.num_threads);
+  A_->RightMultiplyAndAccumulateF(x, tmp_rows_.data());
 
   // y2 = E' y1
   tmp_e_cols_.setZero();
-  A_->LeftMultiplyAndAccumulateE(tmp_rows_.data(),
-                                 tmp_e_cols_.data(),
-                                 options_.context,
-                                 options_.num_threads);
+  A_->LeftMultiplyAndAccumulateE(tmp_rows_.data(), tmp_e_cols_.data());
 
   // y3 = (E'E)^-1 y2
   tmp_e_cols_2_.setZero();
@@ -167,17 +156,11 @@
                                                           options_.num_threads);
   // y1 = E y3
   tmp_rows_.setZero();
-  A_->RightMultiplyAndAccumulateE(tmp_e_cols_2_.data(),
-                                  tmp_rows_.data(),
-                                  options_.context,
-                                  options_.num_threads);
+  A_->RightMultiplyAndAccumulateE(tmp_e_cols_2_.data(), tmp_rows_.data());
 
   // y4 = F' y1
   tmp_f_cols_.setZero();
-  A_->LeftMultiplyAndAccumulateF(tmp_rows_.data(),
-                                 tmp_f_cols_.data(),
-                                 options_.context,
-                                 options_.num_threads);
+  A_->LeftMultiplyAndAccumulateF(tmp_rows_.data(), tmp_f_cols_.data());
 
   // y += (F'F)^-1 y4
   block_diagonal_FtF_inverse_->RightMultiplyAndAccumulate(
@@ -219,18 +202,14 @@
 
   // y1 = F x
   tmp_rows_.setZero();
-  A_->RightMultiplyAndAccumulateF(
-      x, tmp_rows_.data(), options_.context, options_.num_threads);
+  A_->RightMultiplyAndAccumulateF(x, tmp_rows_.data());
 
   // y2 = b - y1
   tmp_rows_ = ConstVectorRef(b_, num_rows) - tmp_rows_;
 
   // y3 = E' y2
   tmp_e_cols_.setZero();
-  A_->LeftMultiplyAndAccumulateE(tmp_rows_.data(),
-                                 tmp_e_cols_.data(),
-                                 options_.context,
-                                 options_.num_threads);
+  A_->LeftMultiplyAndAccumulateE(tmp_rows_.data(), tmp_e_cols_.data());
 
   // y = (E'E)^-1 y3
   VectorRef(y, num_cols).setZero();
@@ -254,8 +233,7 @@
 void ImplicitSchurComplement::UpdateRhs() {
   // y1 = E'b
   tmp_e_cols_.setZero();
-  A_->LeftMultiplyAndAccumulateE(
-      b_, tmp_e_cols_.data(), options_.context, options_.num_threads);
+  A_->LeftMultiplyAndAccumulateE(b_, tmp_e_cols_.data());
 
   // y2 = (E'E)^-1 y1
   Vector y2 = Vector::Zero(A_->num_cols_e());
@@ -264,16 +242,14 @@
 
   // y3 = E y2
   tmp_rows_.setZero();
-  A_->RightMultiplyAndAccumulateE(
-      y2.data(), tmp_rows_.data(), options_.context, options_.num_threads);
+  A_->RightMultiplyAndAccumulateE(y2.data(), tmp_rows_.data());
 
   // y3 = b - y3
   tmp_rows_ = ConstVectorRef(b_, A_->num_rows()) - tmp_rows_;
 
   // rhs = F' y3
   rhs_.setZero();
-  A_->LeftMultiplyAndAccumulateF(
-      tmp_rows_.data(), rhs_.data(), options_.context, options_.num_threads);
+  A_->LeftMultiplyAndAccumulateF(tmp_rows_.data(), rhs_.data());
 }
 
 }  // namespace ceres::internal
diff --git a/internal/ceres/parallel_for.h b/internal/ceres/parallel_for.h
index 107d19c..fe5d242 100644
--- a/internal/ceres/parallel_for.h
+++ b/internal/ceres/parallel_for.h
@@ -299,6 +299,42 @@
               });
 }
 
+// Execute function for every element in the range [start, end) with at most
+// num_threads, using the user provided partitioning. taking into account
+// user-provided integer cumulative costs of iterations.
+template <typename F>
+void ParallelFor(ContextImpl* context,
+                 int start,
+                 int end,
+                 int num_threads,
+                 const F& function,
+                 const std::vector<int>& partitions) {
+  using namespace parallel_for_details;
+  CHECK_GT(num_threads, 0);
+  if (start >= end) {
+    return;
+  }
+  CHECK_EQ(partitions.front(), start);
+  CHECK_EQ(partitions.back(), end);
+  if (num_threads == 1 || end - start <= num_threads) {
+    ParallelFor(context, start, end, num_threads, function);
+    return;
+  }
+  CHECK_GT(partitions.size(), 1);
+  const int num_partitions = partitions.size() - 1;
+  ParallelFor(context,
+              0,
+              num_partitions,
+              num_threads,
+              [&function, &partitions](int thread_id, int partition_id) {
+                const int partition_start = partitions[partition_id];
+                const int partition_end = partitions[partition_id + 1];
+
+                for (int i = partition_start; i < partition_end; ++i) {
+                  Invoke<F>(thread_id, i, function);
+                }
+              });
+}
 }  // namespace ceres::internal
 
 // Backend-specific implementations of ParallelInvoke
diff --git a/internal/ceres/partitioned_matrix_view.cc b/internal/ceres/partitioned_matrix_view.cc
index 2710366..d952a2a 100644
--- a/internal/ceres/partitioned_matrix_view.cc
+++ b/internal/ceres/partitioned_matrix_view.cc
@@ -55,121 +55,121 @@
      (options.e_block_size == 2) &&
      (options.f_block_size == 2)) {
     return std::make_unique<PartitionedMatrixView<2,2, 2>>(
-                   matrix, options.elimination_groups[0]);
+                   options, matrix);
   }
   if ((options.row_block_size == 2) &&
      (options.e_block_size == 2) &&
      (options.f_block_size == 3)) {
     return std::make_unique<PartitionedMatrixView<2,2, 3>>(
-                   matrix, options.elimination_groups[0]);
+                   options, matrix);
   }
   if ((options.row_block_size == 2) &&
      (options.e_block_size == 2) &&
      (options.f_block_size == 4)) {
     return std::make_unique<PartitionedMatrixView<2,2, 4>>(
-                   matrix, options.elimination_groups[0]);
+                   options, matrix);
   }
   if ((options.row_block_size == 2) &&
      (options.e_block_size == 2)) {
     return std::make_unique<PartitionedMatrixView<2,2, Eigen::Dynamic>>(
-                   matrix, options.elimination_groups[0]);
+                   options, matrix);
   }
   if ((options.row_block_size == 2) &&
      (options.e_block_size == 3) &&
      (options.f_block_size == 3)) {
     return std::make_unique<PartitionedMatrixView<2,3, 3>>(
-                   matrix, options.elimination_groups[0]);
+                   options, matrix);
   }
   if ((options.row_block_size == 2) &&
      (options.e_block_size == 3) &&
      (options.f_block_size == 4)) {
     return std::make_unique<PartitionedMatrixView<2,3, 4>>(
-                   matrix, options.elimination_groups[0]);
+                   options, matrix);
   }
   if ((options.row_block_size == 2) &&
      (options.e_block_size == 3) &&
      (options.f_block_size == 6)) {
     return std::make_unique<PartitionedMatrixView<2,3, 6>>(
-                   matrix, options.elimination_groups[0]);
+                   options, matrix);
   }
   if ((options.row_block_size == 2) &&
      (options.e_block_size == 3) &&
      (options.f_block_size == 9)) {
     return std::make_unique<PartitionedMatrixView<2,3, 9>>(
-                   matrix, options.elimination_groups[0]);
+                   options, matrix);
   }
   if ((options.row_block_size == 2) &&
      (options.e_block_size == 3)) {
     return std::make_unique<PartitionedMatrixView<2,3, Eigen::Dynamic>>(
-                   matrix, options.elimination_groups[0]);
+                   options, matrix);
   }
   if ((options.row_block_size == 2) &&
      (options.e_block_size == 4) &&
      (options.f_block_size == 3)) {
     return std::make_unique<PartitionedMatrixView<2,4, 3>>(
-                   matrix, options.elimination_groups[0]);
+                   options, matrix);
   }
   if ((options.row_block_size == 2) &&
      (options.e_block_size == 4) &&
      (options.f_block_size == 4)) {
     return std::make_unique<PartitionedMatrixView<2,4, 4>>(
-                   matrix, options.elimination_groups[0]);
+                   options, matrix);
   }
   if ((options.row_block_size == 2) &&
      (options.e_block_size == 4) &&
      (options.f_block_size == 6)) {
     return std::make_unique<PartitionedMatrixView<2,4, 6>>(
-                   matrix, options.elimination_groups[0]);
+                   options, matrix);
   }
   if ((options.row_block_size == 2) &&
      (options.e_block_size == 4) &&
      (options.f_block_size == 8)) {
     return std::make_unique<PartitionedMatrixView<2,4, 8>>(
-                   matrix, options.elimination_groups[0]);
+                   options, matrix);
   }
   if ((options.row_block_size == 2) &&
      (options.e_block_size == 4) &&
      (options.f_block_size == 9)) {
     return std::make_unique<PartitionedMatrixView<2,4, 9>>(
-                   matrix, options.elimination_groups[0]);
+                   options, matrix);
   }
   if ((options.row_block_size == 2) &&
      (options.e_block_size == 4)) {
     return std::make_unique<PartitionedMatrixView<2,4, Eigen::Dynamic>>(
-                   matrix, options.elimination_groups[0]);
+                   options, matrix);
   }
   if (options.row_block_size == 2) {
     return std::make_unique<PartitionedMatrixView<2,Eigen::Dynamic, Eigen::Dynamic>>(
-                   matrix, options.elimination_groups[0]);
+                   options, matrix);
   }
   if ((options.row_block_size == 3) &&
      (options.e_block_size == 3) &&
      (options.f_block_size == 3)) {
     return std::make_unique<PartitionedMatrixView<3,3, 3>>(
-                   matrix, options.elimination_groups[0]);
+                   options, matrix);
   }
   if ((options.row_block_size == 4) &&
      (options.e_block_size == 4) &&
      (options.f_block_size == 2)) {
     return std::make_unique<PartitionedMatrixView<4,4, 2>>(
-                   matrix, options.elimination_groups[0]);
+                   options, matrix);
   }
   if ((options.row_block_size == 4) &&
      (options.e_block_size == 4) &&
      (options.f_block_size == 3)) {
     return std::make_unique<PartitionedMatrixView<4,4, 3>>(
-                   matrix, options.elimination_groups[0]);
+                   options, matrix);
   }
   if ((options.row_block_size == 4) &&
      (options.e_block_size == 4) &&
      (options.f_block_size == 4)) {
     return std::make_unique<PartitionedMatrixView<4,4, 4>>(
-                   matrix, options.elimination_groups[0]);
+                   options, matrix);
   }
   if ((options.row_block_size == 4) &&
      (options.e_block_size == 4)) {
     return std::make_unique<PartitionedMatrixView<4,4, Eigen::Dynamic>>(
-                   matrix, options.elimination_groups[0]);
+                   options, matrix);
   }
 
 #endif
@@ -179,7 +179,7 @@
   return std::make_unique<PartitionedMatrixView<Eigen::Dynamic,
                                                 Eigen::Dynamic,
                                                 Eigen::Dynamic>>(
-      matrix, options.elimination_groups[0]);
+      options, matrix);
 };
 
 }  // namespace ceres::internal
diff --git a/internal/ceres/partitioned_matrix_view.h b/internal/ceres/partitioned_matrix_view.h
index 533a0e5..f34fd4a 100644
--- a/internal/ceres/partitioned_matrix_view.h
+++ b/internal/ceres/partitioned_matrix_view.h
@@ -70,33 +70,25 @@
 
   // y += E'x
   virtual void LeftMultiplyAndAccumulateE(const double* x, double* y) const = 0;
-  virtual void LeftMultiplyAndAccumulateE(const double* x,
-                                          double* y,
-                                          ContextImpl* context,
-                                          int num_threads) const = 0;
+  virtual void LeftMultiplyAndAccumulateESingleThreaded(const double* x,
+                                                        double* y) const = 0;
+  virtual void LeftMultiplyAndAccumulateEMultiThreaded(const double* x,
+                                                       double* y) const = 0;
 
   // y += F'x
   virtual void LeftMultiplyAndAccumulateF(const double* x, double* y) const = 0;
-  virtual void LeftMultiplyAndAccumulateF(const double* x,
-                                          double* y,
-                                          ContextImpl* context,
-                                          int num_threads) const = 0;
+  virtual void LeftMultiplyAndAccumulateFSingleThreaded(const double* x,
+                                                        double* y) const = 0;
+  virtual void LeftMultiplyAndAccumulateFMultiThreaded(const double* x,
+                                                       double* y) const = 0;
 
   // y += Ex
   virtual void RightMultiplyAndAccumulateE(const double* x,
                                            double* y) const = 0;
-  virtual void RightMultiplyAndAccumulateE(const double* x,
-                                           double* y,
-                                           ContextImpl* context,
-                                           int num_threads) const = 0;
 
   // y += Fx
   virtual void RightMultiplyAndAccumulateF(const double* x,
                                            double* y) const = 0;
-  virtual void RightMultiplyAndAccumulateF(const double* x,
-                                           double* y,
-                                           ContextImpl* context,
-                                           int num_threads) const = 0;
 
   // Create and return the block diagonal of the matrix E'E.
   virtual std::unique_ptr<BlockSparseMatrix> CreateBlockDiagonalEtE() const = 0;
@@ -128,6 +120,8 @@
   virtual int num_cols_f()       const = 0;
   virtual int num_rows()         const = 0;
   virtual int num_cols()         const = 0;
+  virtual const std::vector<int>& e_cols_partition() const = 0;
+  virtual const std::vector<int>& f_cols_partition() const = 0;
   // clang-format on
 
   static std::unique_ptr<PartitionedMatrixViewBase> Create(
@@ -141,29 +135,34 @@
     : public PartitionedMatrixViewBase {
  public:
   // matrix = [E F], where the matrix E contains the first
-  // num_col_blocks_a column blocks.
-  PartitionedMatrixView(const BlockSparseMatrix& matrix, int num_col_blocks_e);
+  // options.elimination_groups[0] column blocks.
+  PartitionedMatrixView(const LinearSolver::Options& options,
+                        const BlockSparseMatrix& matrix);
 
-  void LeftMultiplyAndAccumulateE(const double* x, double* y) const final;
-  void LeftMultiplyAndAccumulateF(const double* x, double* y) const final;
-  void LeftMultiplyAndAccumulateE(const double* x,
-                                  double* y,
-                                  ContextImpl* context,
-                                  int num_threads) const final;
-  void LeftMultiplyAndAccumulateF(const double* x,
-                                  double* y,
-                                  ContextImpl* context,
-                                  int num_threads) const final;
-  void RightMultiplyAndAccumulateE(const double* x, double* y) const final;
-  void RightMultiplyAndAccumulateF(const double* x, double* y) const final;
-  void RightMultiplyAndAccumulateE(const double* x,
-                                   double* y,
-                                   ContextImpl* context,
-                                   int num_threads) const final;
-  void RightMultiplyAndAccumulateF(const double* x,
-                                   double* y,
-                                   ContextImpl* context,
-                                   int num_threads) const final;
+  // y += E'x
+  virtual void LeftMultiplyAndAccumulateE(const double* x,
+                                          double* y) const final;
+  virtual void LeftMultiplyAndAccumulateESingleThreaded(const double* x,
+                                                        double* y) const final;
+  virtual void LeftMultiplyAndAccumulateEMultiThreaded(const double* x,
+                                                       double* y) const final;
+
+  // y += F'x
+  virtual void LeftMultiplyAndAccumulateF(const double* x,
+                                          double* y) const final;
+  virtual void LeftMultiplyAndAccumulateFSingleThreaded(const double* x,
+                                                        double* y) const final;
+  virtual void LeftMultiplyAndAccumulateFMultiThreaded(const double* x,
+                                                       double* y) const final;
+
+  // y += Ex
+  virtual void RightMultiplyAndAccumulateE(const double* x,
+                                           double* y) const final;
+
+  // y += Fx
+  virtual void RightMultiplyAndAccumulateF(const double* x,
+                                           double* y) const final;
+
   std::unique_ptr<BlockSparseMatrix> CreateBlockDiagonalEtE() const final;
   std::unique_ptr<BlockSparseMatrix> CreateBlockDiagonalFtF() const final;
   void UpdateBlockDiagonalEtE(BlockSparseMatrix* block_diagonal) const final;
@@ -176,17 +175,26 @@
   int num_rows()         const final { return matrix_.num_rows(); }
   int num_cols()         const final { return matrix_.num_cols(); }
   // clang-format on
+  const std::vector<int>& e_cols_partition() const final {
+    return e_cols_partition_;
+  }
+  const std::vector<int>& f_cols_partition() const final {
+    return f_cols_partition_;
+  }
 
  private:
   std::unique_ptr<BlockSparseMatrix> CreateBlockDiagonalMatrixLayout(
       int start_col_block, int end_col_block) const;
 
+  const LinearSolver::Options options_;
   const BlockSparseMatrix& matrix_;
   int num_row_blocks_e_;
   int num_col_blocks_e_;
   int num_col_blocks_f_;
   int num_cols_e_;
   int num_cols_f_;
+  std::vector<int> e_cols_partition_;
+  std::vector<int> f_cols_partition_;
 };
 
 }  // namespace ceres::internal
diff --git a/internal/ceres/partitioned_matrix_view_benchmark.cc b/internal/ceres/partitioned_matrix_view_benchmark.cc
deleted file mode 100644
index 63d0e34..0000000
--- a/internal/ceres/partitioned_matrix_view_benchmark.cc
+++ /dev/null
@@ -1,147 +0,0 @@
-// Ceres Solver - A fast non-linear least squares minimizer
-// Copyright 2022 Google Inc. All rights reserved.
-// http://ceres-solver.org/
-//
-// Redistribution and use in source and binary forms, with or without
-// modification, are permitted provided that the following conditions are met:
-//
-// * Redistributions of source code must retain the above copyright notice,
-//   this list of conditions and the following disclaimer.
-// * Redistributions in binary form must reproduce the above copyright notice,
-//   this list of conditions and the following disclaimer in the documentation
-//   and/or other materials provided with the distribution.
-// * Neither the name of Google Inc. nor the names of its contributors may be
-//   used to endorse or promote products derived from this software without
-//   specific prior written permission.
-//
-// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
-// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
-// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
-// POSSIBILITY OF SUCH DAMAGE.
-
-#include <memory>
-#include <random>
-
-#include "benchmark/benchmark.h"
-#include "ceres/context_impl.h"
-#include "ceres/fake_bundle_adjustment_jacobian.h"
-#include "ceres/partitioned_matrix_view.h"
-
-constexpr int kNumCameras = 1000;
-constexpr int kNumPoints = 10000;
-constexpr int kCameraSize = 6;
-constexpr int kPointSize = 3;
-constexpr double kVisibility = 0.1;
-
-namespace ceres::internal {
-
-static void BM_PatitionedViewRightMultiplyAndAccumulateE_Static(
-    benchmark::State& state) {
-  const int num_threads = state.range(0);
-  std::mt19937 prng;
-  auto [partitioned_view, jacobian] =
-      CreateFakeBundleAdjustmentPartitionedJacobian<kPointSize, kCameraSize>(
-          kNumCameras, kNumPoints, kVisibility, prng);
-
-  ContextImpl context;
-  context.EnsureMinimumThreads(num_threads);
-
-  Vector x(partitioned_view->num_cols_e());
-  Vector y(partitioned_view->num_rows());
-  x.setRandom();
-  y.setRandom();
-  double sum = 0;
-  for (auto _ : state) {
-    partitioned_view->RightMultiplyAndAccumulateE(
-        x.data(), y.data(), &context, num_threads);
-    sum += y.norm();
-  }
-  CHECK_NE(sum, 0.0);
-}
-BENCHMARK(BM_PatitionedViewRightMultiplyAndAccumulateE_Static)
-    ->Arg(1)
-    ->Arg(2)
-    ->Arg(4)
-    ->Arg(8)
-    ->Arg(16);
-
-static void BM_PatitionedViewRightMultiplyAndAccumulateE_Dynamic(
-    benchmark::State& state) {
-  std::mt19937 prng;
-  auto [partitioned_view, jacobian] =
-      CreateFakeBundleAdjustmentPartitionedJacobian(
-          kNumCameras, kNumPoints, kCameraSize, kPointSize, kVisibility, prng);
-
-  Vector x(partitioned_view->num_cols_e());
-  Vector y(partitioned_view->num_rows());
-  x.setRandom();
-  y.setRandom();
-  double sum = 0;
-  for (auto _ : state) {
-    partitioned_view->RightMultiplyAndAccumulateE(x.data(), y.data());
-    sum += y.norm();
-  }
-  CHECK_NE(sum, 0.0);
-}
-BENCHMARK(BM_PatitionedViewRightMultiplyAndAccumulateE_Dynamic);
-
-static void BM_PatitionedViewRightMultiplyAndAccumulateF_Static(
-    benchmark::State& state) {
-  const int num_threads = state.range(0);
-  std::mt19937 prng;
-  auto [partitioned_view, jacobian] =
-      CreateFakeBundleAdjustmentPartitionedJacobian<kPointSize, kCameraSize>(
-          kNumCameras, kNumPoints, kVisibility, prng);
-
-  ContextImpl context;
-  context.EnsureMinimumThreads(num_threads);
-
-  Vector x(partitioned_view->num_cols_f());
-  Vector y(partitioned_view->num_rows());
-  x.setRandom();
-  y.setRandom();
-  double sum = 0;
-  for (auto _ : state) {
-    partitioned_view->RightMultiplyAndAccumulateF(
-        x.data(), y.data(), &context, num_threads);
-    sum += y.norm();
-  }
-  CHECK_NE(sum, 0.0);
-}
-BENCHMARK(BM_PatitionedViewRightMultiplyAndAccumulateF_Static)
-    ->Arg(1)
-    ->Arg(2)
-    ->Arg(4)
-    ->Arg(8)
-    ->Arg(16);
-
-static void BM_PatitionedViewRightMultiplyAndAccumulateF_Dynamic(
-    benchmark::State& state) {
-  std::mt19937 prng;
-  auto [partitioned_view, jacobian] =
-      CreateFakeBundleAdjustmentPartitionedJacobian(
-          kNumCameras, kNumPoints, kCameraSize, kPointSize, kVisibility, prng);
-
-  Vector x(partitioned_view->num_cols_f());
-  Vector y(partitioned_view->num_rows());
-  x.setRandom();
-  y.setRandom();
-  double sum = 0;
-  for (auto _ : state) {
-    partitioned_view->RightMultiplyAndAccumulateF(x.data(), y.data());
-    sum += y.norm();
-  }
-  CHECK_NE(sum, 0.0);
-}
-BENCHMARK(BM_PatitionedViewRightMultiplyAndAccumulateF_Dynamic);
-
-}  // namespace ceres::internal
-
-BENCHMARK_MAIN();
diff --git a/internal/ceres/partitioned_matrix_view_impl.h b/internal/ceres/partitioned_matrix_view_impl.h
index 689f09e..1ca2e14 100644
--- a/internal/ceres/partitioned_matrix_view_impl.h
+++ b/internal/ceres/partitioned_matrix_view_impl.h
@@ -45,11 +45,14 @@
 
 template <int kRowBlockSize, int kEBlockSize, int kFBlockSize>
 PartitionedMatrixView<kRowBlockSize, kEBlockSize, kFBlockSize>::
-    PartitionedMatrixView(const BlockSparseMatrix& matrix, int num_col_blocks_e)
-    : matrix_(matrix), num_col_blocks_e_(num_col_blocks_e) {
+    PartitionedMatrixView(const LinearSolver::Options& options,
+                          const BlockSparseMatrix& matrix)
+
+    : options_(options), matrix_(matrix) {
   const CompressedRowBlockStructure* bs = matrix_.block_structure();
   CHECK(bs != nullptr);
 
+  num_col_blocks_e_ = options_.elimination_groups[0];
   num_col_blocks_f_ = bs->cols.size() - num_col_blocks_e_;
 
   // Compute the number of row blocks in E. The number of row blocks
@@ -79,6 +82,25 @@
   }
 
   CHECK_EQ(num_cols_e_ + num_cols_f_, matrix_.num_cols());
+
+  auto transpose_bs = matrix_.transpose_block_structure();
+  const int num_threads = options_.num_threads;
+  if (transpose_bs != nullptr && num_threads > 1) {
+    int kMaxPartitions = num_threads * 4;
+    e_cols_partition_ = parallel_for_details::ComputePartition(
+        0,
+        num_col_blocks_e_,
+        kMaxPartitions,
+        transpose_bs->rows.data(),
+        [](const CompressedRow& row) { return row.cumulative_nnz; });
+
+    f_cols_partition_ = parallel_for_details::ComputePartition(
+        num_col_blocks_e_,
+        num_col_blocks_e_ + num_col_blocks_f_,
+        kMaxPartitions,
+        transpose_bs->rows.data(),
+        [](const CompressedRow& row) { return row.cumulative_nnz; });
+  }
 }
 
 // The next four methods don't seem to be particularly cache
@@ -89,23 +111,14 @@
 template <int kRowBlockSize, int kEBlockSize, int kFBlockSize>
 void PartitionedMatrixView<kRowBlockSize, kEBlockSize, kFBlockSize>::
     RightMultiplyAndAccumulateE(const double* x, double* y) const {
-  RightMultiplyAndAccumulateE(x, y, nullptr, 1);
-}
-
-template <int kRowBlockSize, int kEBlockSize, int kFBlockSize>
-void PartitionedMatrixView<kRowBlockSize, kEBlockSize, kFBlockSize>::
-    RightMultiplyAndAccumulateE(const double* x,
-                                double* y,
-                                ContextImpl* context,
-                                int num_threads) const {
   // Iterate over the first num_row_blocks_e_ row blocks, and multiply
   // by the first cell in each row block.
   auto bs = matrix_.block_structure();
   const double* values = matrix_.values();
-  ParallelFor(context,
+  ParallelFor(options_.context,
               0,
               num_row_blocks_e_,
-              num_threads,
+              options_.num_threads,
               [values, bs, x, y](int row_block_id) {
                 const Cell& cell = bs->rows[row_block_id].cells[0];
                 const int row_block_pos = bs->rows[row_block_id].block.position;
@@ -125,15 +138,6 @@
 template <int kRowBlockSize, int kEBlockSize, int kFBlockSize>
 void PartitionedMatrixView<kRowBlockSize, kEBlockSize, kFBlockSize>::
     RightMultiplyAndAccumulateF(const double* x, double* y) const {
-  RightMultiplyAndAccumulateF(x, y, nullptr, 1);
-}
-
-template <int kRowBlockSize, int kEBlockSize, int kFBlockSize>
-void PartitionedMatrixView<kRowBlockSize, kEBlockSize, kFBlockSize>::
-    RightMultiplyAndAccumulateF(const double* x,
-                                double* y,
-                                ContextImpl* context,
-                                int num_threads) const {
   // Iterate over row blocks, and if the row block is in E, then
   // multiply by all the cells except the first one which is of type
   // E. If the row block is not in E (i.e its in the bottom
@@ -143,10 +147,10 @@
   const int num_row_blocks = bs->rows.size();
   const int num_cols_e = num_cols_e_;
   const double* values = matrix_.values();
-  ParallelFor(context,
+  ParallelFor(options_.context,
               0,
               num_row_blocks_e_,
-              num_threads,
+              options_.num_threads,
               [values, bs, num_cols_e, x, y](int row_block_id) {
                 const int row_block_pos = bs->rows[row_block_id].block.position;
                 const int row_block_size = bs->rows[row_block_id].block.size;
@@ -163,10 +167,10 @@
                   // clang-format on
                 }
               });
-  ParallelFor(context,
+  ParallelFor(options_.context,
               num_row_blocks_e_,
               num_row_blocks,
-              num_threads,
+              options_.num_threads,
               [values, bs, num_cols_e, x, y](int row_block_id) {
                 const int row_block_pos = bs->rows[row_block_id].block.position;
                 const int row_block_size = bs->rows[row_block_id].block.size;
@@ -190,7 +194,17 @@
     LeftMultiplyAndAccumulateE(const double* x, double* y) const {
   if (!num_col_blocks_e_) return;
   if (!num_row_blocks_e_) return;
+  if (options_.num_threads == 1) {
+    LeftMultiplyAndAccumulateESingleThreaded(x, y);
+  } else {
+    CHECK(options_.context != nullptr);
+    LeftMultiplyAndAccumulateEMultiThreaded(x, y);
+  }
+}
 
+template <int kRowBlockSize, int kEBlockSize, int kFBlockSize>
+void PartitionedMatrixView<kRowBlockSize, kEBlockSize, kFBlockSize>::
+    LeftMultiplyAndAccumulateESingleThreaded(const double* x, double* y) const {
   const CompressedRowBlockStructure* bs = matrix_.block_structure();
 
   // Iterate over the first num_row_blocks_e_ row blocks, and multiply
@@ -214,28 +228,19 @@
 
 template <int kRowBlockSize, int kEBlockSize, int kFBlockSize>
 void PartitionedMatrixView<kRowBlockSize, kEBlockSize, kFBlockSize>::
-    LeftMultiplyAndAccumulateE(const double* x,
-                               double* y,
-                               ContextImpl* context,
-                               int num_threads) const {
-  if (!num_col_blocks_e_) return;
-  if (!num_row_blocks_e_) return;
-
+    LeftMultiplyAndAccumulateEMultiThreaded(const double* x, double* y) const {
   auto transpose_bs = matrix_.transpose_block_structure();
-  if (transpose_bs == nullptr || num_threads == 1) {
-    LeftMultiplyAndAccumulateE(x, y);
-    return;
-  }
+  CHECK(transpose_bs != nullptr);
 
   // Local copies of class members in order to avoid capturing pointer to the
   // whole object in lambda function
   auto values = matrix_.values();
   const int num_row_blocks_e = num_row_blocks_e_;
   ParallelFor(
-      context,
+      options_.context,
       0,
       num_col_blocks_e_,
-      num_threads,
+      options_.num_threads,
       [values, transpose_bs, num_row_blocks_e, x, y](int row_block_id) {
         int row_block_pos = transpose_bs->rows[row_block_id].block.position;
         int row_block_size = transpose_bs->rows[row_block_id].block.size;
@@ -254,14 +259,24 @@
               y + row_block_pos);
         }
       },
-      transpose_bs->rows.data(),
-      [](const CompressedRow& row) { return row.cumulative_nnz; });
+      e_cols_partition());
 }
 
 template <int kRowBlockSize, int kEBlockSize, int kFBlockSize>
 void PartitionedMatrixView<kRowBlockSize, kEBlockSize, kFBlockSize>::
     LeftMultiplyAndAccumulateF(const double* x, double* y) const {
   if (!num_col_blocks_f_) return;
+  if (options_.num_threads == 1) {
+    LeftMultiplyAndAccumulateFSingleThreaded(x, y);
+  } else {
+    CHECK(options_.context != nullptr);
+    LeftMultiplyAndAccumulateFMultiThreaded(x, y);
+  }
+}
+
+template <int kRowBlockSize, int kEBlockSize, int kFBlockSize>
+void PartitionedMatrixView<kRowBlockSize, kEBlockSize, kFBlockSize>::
+    LeftMultiplyAndAccumulateFSingleThreaded(const double* x, double* y) const {
   const CompressedRowBlockStructure* bs = matrix_.block_structure();
 
   // Iterate over row blocks, and if the row block is in E, then
@@ -307,26 +322,19 @@
 
 template <int kRowBlockSize, int kEBlockSize, int kFBlockSize>
 void PartitionedMatrixView<kRowBlockSize, kEBlockSize, kFBlockSize>::
-    LeftMultiplyAndAccumulateF(const double* x,
-                               double* y,
-                               ContextImpl* context,
-                               int num_threads) const {
-  if (!num_col_blocks_f_) return;
+    LeftMultiplyAndAccumulateFMultiThreaded(const double* x, double* y) const {
   auto transpose_bs = matrix_.transpose_block_structure();
-  if (transpose_bs == nullptr || num_threads == 1) {
-    LeftMultiplyAndAccumulateF(x, y);
-    return;
-  }
+  CHECK(transpose_bs != nullptr);
   // Local copies of class members  in order to avoid capturing pointer to the
   // whole object in lambda function
   auto values = matrix_.values();
   const int num_row_blocks_e = num_row_blocks_e_;
   const int num_cols_e = num_cols_e_;
   ParallelFor(
-      context,
+      options_.context,
       num_col_blocks_e_,
       num_col_blocks_e_ + num_col_blocks_f_,
-      num_threads,
+      options_.num_threads,
       [values, transpose_bs, num_row_blocks_e, num_cols_e, x, y](
           int row_block_id) {
         int row_block_pos = transpose_bs->rows[row_block_id].block.position;
@@ -362,8 +370,7 @@
               y + row_block_pos - num_cols_e);
         }
       },
-      transpose_bs->rows.data(),
-      [](const CompressedRow& row) { return row.cumulative_nnz; });
+      f_cols_partition());
 }
 
 // Given a range of columns blocks of a matrix m, compute the block
diff --git a/internal/ceres/partitioned_matrix_view_template.py b/internal/ceres/partitioned_matrix_view_template.py
index 27f368f..8978d63 100644
--- a/internal/ceres/partitioned_matrix_view_template.py
+++ b/internal/ceres/partitioned_matrix_view_template.py
@@ -132,7 +132,7 @@
 #ifndef CERES_RESTRICT_SCHUR_SPECIALIZATION
 """
 FACTORY = """  return std::make_unique<PartitionedMatrixView<%s,%s, %s>>(
-                   matrix, options.elimination_groups[0]);"""
+                   options, matrix);"""
 
 FACTORY_FOOTER = """
 #endif
@@ -142,7 +142,7 @@
   return std::make_unique<PartitionedMatrixView<Eigen::Dynamic,
                                                 Eigen::Dynamic,
                                                 Eigen::Dynamic>>(
-      matrix, options.elimination_groups[0]);
+      options, matrix);
 };
 
 }  // namespace ceres::internal
diff --git a/internal/ceres/partitioned_matrix_view_test.cc b/internal/ceres/partitioned_matrix_view_test.cc
index 4653e06..fc1ebdd 100644
--- a/internal/ceres/partitioned_matrix_view_test.cc
+++ b/internal/ceres/partitioned_matrix_view_test.cc
@@ -76,76 +76,6 @@
       std::uniform_real_distribution<double>(0.0, 1.0);
 };
 
-TEST_F(PartitionedMatrixViewTest, DimensionsTest) {
-  EXPECT_EQ(pmv_->num_col_blocks_e(), num_eliminate_blocks_);
-  EXPECT_EQ(pmv_->num_col_blocks_f(), num_cols_ - num_eliminate_blocks_);
-  EXPECT_EQ(pmv_->num_cols_e(), num_eliminate_blocks_);
-  EXPECT_EQ(pmv_->num_cols_f(), num_cols_ - num_eliminate_blocks_);
-  EXPECT_EQ(pmv_->num_cols(), A_->num_cols());
-  EXPECT_EQ(pmv_->num_rows(), A_->num_rows());
-}
-
-TEST_F(PartitionedMatrixViewTest, RightMultiplyAndAccumulateE) {
-  Vector x1(pmv_->num_cols_e());
-  Vector x2(pmv_->num_cols());
-  x2.setZero();
-
-  for (int i = 0; i < pmv_->num_cols_e(); ++i) {
-    x1(i) = x2(i) = RandDouble();
-  }
-
-  Vector y1 = Vector::Zero(pmv_->num_rows());
-  pmv_->RightMultiplyAndAccumulateE(x1.data(), y1.data());
-
-  Vector y2 = Vector::Zero(pmv_->num_rows());
-  A_->RightMultiplyAndAccumulate(x2.data(), y2.data());
-
-  for (int i = 0; i < pmv_->num_rows(); ++i) {
-    EXPECT_NEAR(y1(i), y2(i), kEpsilon);
-  }
-}
-
-TEST_F(PartitionedMatrixViewTest, RightMultiplyAndAccumulateF) {
-  Vector x1(pmv_->num_cols_f());
-  Vector x2 = Vector::Zero(pmv_->num_cols());
-
-  for (int i = 0; i < pmv_->num_cols_f(); ++i) {
-    x1(i) = RandDouble();
-    x2(i + pmv_->num_cols_e()) = x1(i);
-  }
-
-  Vector y1 = Vector::Zero(pmv_->num_rows());
-  pmv_->RightMultiplyAndAccumulateF(x1.data(), y1.data());
-
-  Vector y2 = Vector::Zero(pmv_->num_rows());
-  A_->RightMultiplyAndAccumulate(x2.data(), y2.data());
-
-  for (int i = 0; i < pmv_->num_rows(); ++i) {
-    EXPECT_NEAR(y1(i), y2(i), kEpsilon);
-  }
-}
-
-TEST_F(PartitionedMatrixViewTest, LeftMultiplyAndAccumulate) {
-  Vector x = Vector::Zero(pmv_->num_rows());
-  for (int i = 0; i < pmv_->num_rows(); ++i) {
-    x(i) = RandDouble();
-  }
-
-  Vector y = Vector::Zero(pmv_->num_cols());
-  Vector y1 = Vector::Zero(pmv_->num_cols_e());
-  Vector y2 = Vector::Zero(pmv_->num_cols_f());
-
-  A_->LeftMultiplyAndAccumulate(x.data(), y.data());
-  pmv_->LeftMultiplyAndAccumulateE(x.data(), y1.data());
-  pmv_->LeftMultiplyAndAccumulateF(x.data(), y2.data());
-
-  for (int i = 0; i < pmv_->num_cols(); ++i) {
-    EXPECT_NEAR(y(i),
-                (i < pmv_->num_cols_e()) ? y1(i) : y2(i - pmv_->num_cols_e()),
-                kEpsilon);
-  }
-}
-
 TEST_F(PartitionedMatrixViewTest, BlockDiagonalEtE) {
   std::unique_ptr<BlockSparseMatrix> block_diagonal_ee(
       pmv_->CreateBlockDiagonalEtE());
@@ -174,134 +104,122 @@
   EXPECT_NEAR(block_diagonal_ff->values()[2], 37.0, kEpsilon);
 }
 
-const int kMaxNumThreads = 8;
-class PartitionedMatrixViewParallelTest : public ::testing::TestWithParam<int> {
+// Param = <problem_id, num_threads>
+using Param = ::testing::tuple<int, int>;
+
+static std::string ParamInfoToString(testing::TestParamInfo<Param> info) {
+  Param param = info.param;
+  std::stringstream ss;
+  ss << ::testing::get<0>(param) << "_" << ::testing::get<1>(param);
+  return ss.str();
+}
+
+class PartitionedMatrixViewSpMVTest : public ::testing::TestWithParam<Param> {
  protected:
-  static const int kNumProblems = 3;
   void SetUp() final {
-    int problem_ids[kNumProblems] = {2, 4, 6};
-    for (int i = 0; i < kNumProblems; ++i) {
-      auto problem = CreateLinearLeastSquaresProblemFromId(problem_ids[i]);
-      CHECK(problem != nullptr);
-      SetUpMatrix(i, problem.get());
-    }
-
-    context_.EnsureMinimumThreads(kMaxNumThreads);
-  }
-
-  void SetUpMatrix(int id, LinearLeastSquaresProblem* problem) {
-    A_[id] = std::move(problem->A);
-    auto& A = A_[id];
-    auto block_sparse = down_cast<BlockSparseMatrix*>(A.get());
+    const int problem_id = ::testing::get<0>(GetParam());
+    const int num_threads = ::testing::get<1>(GetParam());
+    auto problem = CreateLinearLeastSquaresProblemFromId(problem_id);
+    CHECK(problem != nullptr);
+    A_ = std::move(problem->A);
+    auto block_sparse = down_cast<BlockSparseMatrix*>(A_.get());
     block_sparse->AddTransposeBlockStructure();
 
-    num_cols_[id] = A->num_cols();
-    num_rows_[id] = A->num_rows();
-    num_eliminate_blocks_[id] = problem->num_eliminate_blocks;
-    LinearSolver::Options options;
-    options.elimination_groups.push_back(num_eliminate_blocks_[id]);
-    pmv_[id] = PartitionedMatrixViewBase::Create(options, *block_sparse);
+    options_.num_threads = num_threads;
+    options_.context = &context_;
+    options_.elimination_groups.push_back(problem->num_eliminate_blocks);
+    pmv_ = PartitionedMatrixViewBase::Create(options_, *block_sparse);
+
+    EXPECT_EQ(pmv_->num_col_blocks_e(), problem->num_eliminate_blocks);
+    EXPECT_EQ(pmv_->num_col_blocks_f(),
+              block_sparse->block_structure()->cols.size() -
+                  problem->num_eliminate_blocks);
+    EXPECT_EQ(pmv_->num_cols(), A_->num_cols());
+    EXPECT_EQ(pmv_->num_rows(), A_->num_rows());
   }
 
   double RandDouble() { return distribution_(prng_); }
 
+  LinearSolver::Options options_;
   ContextImpl context_;
-  int num_rows_[kNumProblems];
-  int num_cols_[kNumProblems];
-  int num_eliminate_blocks_[kNumProblems];
-  std::unique_ptr<SparseMatrix> A_[kNumProblems];
-  std::unique_ptr<PartitionedMatrixViewBase> pmv_[kNumProblems];
+  std::unique_ptr<LinearLeastSquaresProblem> problem_;
+  std::unique_ptr<SparseMatrix> A_;
+  std::unique_ptr<PartitionedMatrixViewBase> pmv_;
+  int num_cols_e;
   std::mt19937 prng_;
   std::uniform_real_distribution<double> distribution_ =
       std::uniform_real_distribution<double>(0.0, 1.0);
 };
 
-TEST_P(PartitionedMatrixViewParallelTest, RightMultiplyAndAccumulateEParallel) {
-  const int kNumThreads = GetParam();
-  for (int p = 0; p < kNumProblems; ++p) {
-    auto& pmv = pmv_[p];
-    auto& A = A_[p];
+TEST_P(PartitionedMatrixViewSpMVTest, RightMultiplyAndAccumulateE) {
+  Vector x1(pmv_->num_cols_e());
+  Vector x2(pmv_->num_cols());
+  x2.setZero();
 
-    Vector x1(pmv->num_cols_e());
-    Vector x2(pmv->num_cols());
-    x2.setZero();
+  for (int i = 0; i < pmv_->num_cols_e(); ++i) {
+    x1(i) = x2(i) = RandDouble();
+  }
 
-    for (int i = 0; i < pmv->num_cols_e(); ++i) {
-      x1(i) = x2(i) = RandDouble();
-    }
+  Vector expected = Vector::Zero(pmv_->num_rows());
+  A_->RightMultiplyAndAccumulate(x2.data(), expected.data());
 
-    Vector y1 = Vector::Zero(pmv->num_rows());
-    pmv->RightMultiplyAndAccumulateE(
-        x1.data(), y1.data(), &context_, kNumThreads);
+  Vector actual = Vector::Zero(pmv_->num_rows());
+  pmv_->RightMultiplyAndAccumulateE(x1.data(), actual.data());
 
-    Vector y2 = Vector::Zero(pmv->num_rows());
-    A->RightMultiplyAndAccumulate(x2.data(), y2.data());
-
-    for (int i = 0; i < pmv->num_rows(); ++i) {
-      EXPECT_NEAR(y1(i), y2(i), kEpsilon);
-    }
+  for (int i = 0; i < pmv_->num_rows(); ++i) {
+    EXPECT_NEAR(actual(i), expected(i), kEpsilon);
   }
 }
 
-TEST_P(PartitionedMatrixViewParallelTest, RightMultiplyAndAccumulateFParallel) {
-  const int kNumThreads = GetParam();
-  for (int p = 0; p < kNumProblems; ++p) {
-    auto& pmv = pmv_[p];
-    auto& A = A_[p];
-    Vector x1(pmv->num_cols_f());
-    Vector x2(pmv->num_cols());
-    x2.setZero();
+TEST_P(PartitionedMatrixViewSpMVTest, RightMultiplyAndAccumulateF) {
+  Vector x1(pmv_->num_cols_f());
+  Vector x2(pmv_->num_cols());
+  x2.setZero();
 
-    for (int i = 0; i < pmv->num_cols_f(); ++i) {
-      x1(i) = x2(i + pmv->num_cols_e()) = RandDouble();
-    }
+  for (int i = 0; i < pmv_->num_cols_f(); ++i) {
+    x1(i) = x2(i + pmv_->num_cols_e()) = RandDouble();
+  }
 
-    Vector y1 = Vector::Zero(pmv->num_rows());
-    pmv->RightMultiplyAndAccumulateF(
-        x1.data(), y1.data(), &context_, kNumThreads);
+  Vector actual = Vector::Zero(pmv_->num_rows());
+  pmv_->RightMultiplyAndAccumulateF(x1.data(), actual.data());
 
-    Vector y2 = Vector::Zero(pmv->num_rows());
-    A->RightMultiplyAndAccumulate(x2.data(), y2.data());
+  Vector expected = Vector::Zero(pmv_->num_rows());
+  A_->RightMultiplyAndAccumulate(x2.data(), expected.data());
 
-    for (int i = 0; i < pmv->num_rows(); ++i) {
-      EXPECT_NEAR(y1(i), y2(i), kEpsilon);
-    }
+  for (int i = 0; i < pmv_->num_rows(); ++i) {
+    EXPECT_NEAR(actual(i), expected(i), kEpsilon);
   }
 }
 
-TEST_P(PartitionedMatrixViewParallelTest, LeftMultiplyAndAccumulateParallel) {
-  const int kNumThreads = GetParam();
-  for (int p = 0; p < kNumProblems; ++p) {
-    auto& pmv = pmv_[p];
-    auto& A = A_[p];
-    Vector x = Vector::Zero(pmv->num_rows());
-    for (int i = 0; i < pmv->num_rows(); ++i) {
-      x(i) = RandDouble();
-    }
-    Vector x_pre = x;
+TEST_P(PartitionedMatrixViewSpMVTest, LeftMultiplyAndAccumulate) {
+  Vector x = Vector::Zero(pmv_->num_rows());
+  for (int i = 0; i < pmv_->num_rows(); ++i) {
+    x(i) = RandDouble();
+  }
+  Vector x_pre = x;
 
-    Vector y = Vector::Zero(pmv->num_cols());
-    Vector y1 = Vector::Zero(pmv->num_cols_e());
-    Vector y2 = Vector::Zero(pmv->num_cols_f());
+  Vector expected = Vector::Zero(pmv_->num_cols());
+  Vector e_actual = Vector::Zero(pmv_->num_cols_e());
+  Vector f_actual = Vector::Zero(pmv_->num_cols_f());
 
-    A->LeftMultiplyAndAccumulate(x.data(), y.data());
-    pmv->LeftMultiplyAndAccumulateE(
-        x.data(), y1.data(), &context_, kNumThreads);
-    pmv->LeftMultiplyAndAccumulateF(
-        x.data(), y2.data(), &context_, kNumThreads);
+  A_->LeftMultiplyAndAccumulate(x.data(), expected.data());
+  pmv_->LeftMultiplyAndAccumulateE(x.data(), e_actual.data());
+  pmv_->LeftMultiplyAndAccumulateF(x.data(), f_actual.data());
 
-    for (int i = 0; i < pmv->num_cols(); ++i) {
-      EXPECT_NEAR(y(i),
-                  (i < pmv->num_cols_e()) ? y1(i) : y2(i - pmv->num_cols_e()),
-                  kEpsilon);
-    }
+  for (int i = 0; i < pmv_->num_cols(); ++i) {
+    EXPECT_NEAR(expected(i),
+                (i < pmv_->num_cols_e()) ? e_actual(i)
+                                         : f_actual(i - pmv_->num_cols_e()),
+                kEpsilon);
   }
 }
 
-INSTANTIATE_TEST_SUITE_P(ParallelProducts,
-                         PartitionedMatrixViewParallelTest,
-                         ::testing::Values(1, 2, 4, 8),
-                         ::testing::PrintToStringParamName());
+INSTANTIATE_TEST_SUITE_P(
+    ParallelProducts,
+    PartitionedMatrixViewSpMVTest,
+    ::testing::Combine(::testing::Values(2, 4, 6),
+                       ::testing::Values(1, 2, 3, 4, 5, 6, 7, 8)),
+    ParamInfoToString);
 
 }  // namespace internal
 }  // namespace ceres