Change implementation of parallel for

Implemented templated invocation routines for ParallelFor backends
in order to improve loop body inlining.

Several modifications of ParallelFor implementation using CXX threads:
 - Index order changed from interleaved to sequential
 - Static task scheduling replaced with dynamic (controlled by
   kWorkBlocksPerThread)
 - Changed index retrieval to atomic

Modifications of OpenMP backend:
 - Changed loop scheduling to guided

Changing index order from interleaved to sequential in parallel seem
to significantly improve run-times of parallel loops, for example in
evaluation of jacobian and residuals.

Other modifications provide minor improvements for unbalanced
sub-problem lengths and parallel for loops with small number of
computation per operation.

Single-threaded performance was improved by avoiding costs of
wrapping parallel loop bodies in std::function.

On BAL dataset the following improvements in time consumed for
evaluation of residuals or jacobian and residuals were observed:

                                     OLD           NEW        OLD/NEW
                 dataset threads     r     J     r     J     r     J
problem-257-65132-pre.txt      1 0.025  0.079  0.025  0.074 1.016 1.056
problem-257-65132-pre.txt      2 0.030  0.062  0.022  0.050 1.333 1.246
problem-257-65132-pre.txt      4 0.023  0.052  0.014  0.034 1.592 1.515
problem-257-65132-pre.txt      8 0.015  0.035  0.010  0.025 1.477 1.401
problem-257-65132-pre.txt     16 0.011  0.027  0.008  0.019 1.365 1.377
problem-356-226730-pre.txt     1 0.150  0.442  0.147  0.412 1.017 1.070
problem-356-226730-pre.txt     2 0.155  0.322  0.100  0.281 1.542 1.145
problem-356-226730-pre.txt     4 0.129  0.291  0.089  0.196 1.439 1.485
problem-356-226730-pre.txt     8 0.091  0.184  0.066  0.139 1.381 1.319
problem-356-226730-pre.txt    16 0.070  0.148  0.055  0.110 1.272 1.340
problem-1723-156502-pre.txt    1 0.084  0.243  0.082  0.229 1.023 1.063
problem-1723-156502-pre.txt    2 0.088  0.188  0.055  0.154 1.589 1.222
problem-1723-156502-pre.txt    4 0.072  0.159  0.049  0.108 1.475 1.475
problem-1723-156502-pre.txt    8 0.050  0.105  0.037  0.077 1.348 1.368
problem-1723-156502-pre.txt   16 0.038  0.083  0.030  0.062 1.269 1.344
problem-1778-993923-pre.txt    1 0.621  1.777  0.609  1.667 1.018 1.065
problem-1778-993923-pre.txt    2 0.621  1.273  0.415  1.199 1.494 1.061
problem-1778-993923-pre.txt    4 0.514  1.140  0.361  0.786 1.421 1.449
problem-1778-993923-pre.txt    8 0.365  0.808  0.277  0.559 1.319 1.443
problem-1778-993923-pre.txt   16 0.279  0.608  0.223  0.441 1.252 1.379
problem-13682-4456117-pre.txt  1 3.877 10.726  3.738 10.082 1.037 1.063
problem-13682-4456117-pre.txt  2 3.310  7.170  2.423  6.448 1.366 1.111
problem-13682-4456117-pre.txt  4 3.070  6.344  2.064  4.474 1.486 1.417
problem-13682-4456117-pre.txt  8 2.051  4.612  1.527  3.133 1.343 1.472
problem-13682-4456117-pre.txt 16 1.549  3.453  1.218  2.488 1.271 1.387

Run time in seconds for a single evaluation, using evaluation_benchmark
numactl -N 0 -m 0 ./bin/evaluation_benchmark --bal_root ${path_to_BAL}
Evaluation was performed on 28-core CPU.

Note: performance when running across numa-nodes degrades in both old
and proposed implementations, thus the test was executed limiting memory
and compute resources allocation to a single numa-node.

Change-Id: Ia195580bdab9d05c95ac983bfe37b045eecfaf49
diff --git a/internal/ceres/parallel_for.h b/internal/ceres/parallel_for.h
index e5599cc..3c3d887 100644
--- a/internal/ceres/parallel_for.h
+++ b/internal/ceres/parallel_for.h
@@ -26,7 +26,8 @@
 // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 // POSSIBILITY OF SUCH DAMAGE.
 //
-// Author: vitus@google.com (Michael Vitus)
+// Authors: vitus@google.com (Michael Vitus),
+//          dmitriy.korchemkin@gmail.com (Dmitriy Korchemkin)
 
 #ifndef CERES_INTERNAL_PARALLEL_FOR_H_
 #define CERES_INTERNAL_PARALLEL_FOR_H_
@@ -36,6 +37,7 @@
 #include "ceres/context_impl.h"
 #include "ceres/internal/disable_warnings.h"
 #include "ceres/internal/export.h"
+#include "glog/logging.h"
 
 namespace ceres::internal {
 
@@ -44,28 +46,101 @@
 CERES_NO_EXPORT
 int MaxNumThreadsAvailable();
 
-// Execute the function for every element in the range [start, end) with at most
-// num_threads. It will execute all the work on the calling thread if
-// num_threads is 1.
-CERES_NO_EXPORT void ParallelFor(ContextImpl* context,
-                                 int start,
-                                 int end,
-                                 int num_threads,
-                                 const std::function<void(int)>& function);
+// Parallel for implementations share a common set of routines in order
+// to enforce inlining of loop bodies, ensuring that single-threaded
+// performance is equivalent to a simple for loop
+namespace parallel_for_details {
+// Get arguments of callable as a tuple
+template <typename F, typename... Args>
+std::tuple<Args...> args_of(void (F::*)(Args...) const);
+
+template <typename F>
+using args_of_t = decltype(args_of(&F::operator()));
+
+// Parallelizable functions might require passing thread_id as the first
+// argument. This class supplies thread_id argument to functions that
+// support it and ignores it otherwise.
+template <typename F, typename Args>
+struct InvokeImpl;
+
+// For parallel for iterations of type [](int i) -> void
+template <typename F>
+struct InvokeImpl<F, std::tuple<int>> {
+  static void Invoke(int thread_id, int i, const F& function) {
+    (void)thread_id;
+    function(i);
+  }
+};
+
+// For parallel for iterations of type [](int thread_id, int i) -> void
+template <typename F>
+struct InvokeImpl<F, std::tuple<int, int>> {
+  static void Invoke(int thread_id, int i, const F& function) {
+    function(thread_id, i);
+  }
+};
+
+// Invoke function passing thread_id only if required
+template <typename F>
+void Invoke(int thread_id, int i, const F& function) {
+  InvokeImpl<F, args_of_t<F>>::Invoke(thread_id, i, function);
+}
+}  // namespace parallel_for_details
+
+// Forward declaration of parallel invocation function that is to be
+// implemented by each threading backend
+template <typename F>
+void ParallelInvoke(ContextImpl* context,
+                    int i,
+                    int num_threads,
+                    const F& function);
 
 // Execute the function for every element in the range [start, end) with at most
 // num_threads. It will execute all the work on the calling thread if
-// num_threads is 1.  Each invocation of function() will be passed a thread_id
-// in [0, num_threads) that is guaranteed to be distinct from the value passed
-// to any concurrent execution of function().
-CERES_NO_EXPORT void ParallelFor(
-    ContextImpl* context,
-    int start,
-    int end,
-    int num_threads,
-    const std::function<void(int thread_id, int i)>& function);
+// num_threads or (end - start) is equal to 1.
+//
+// Functions with two arguments will be passed thread_id and loop index on each
+// invocation, functions with one argument will be invoked with loop index
+template <typename F>
+void ParallelFor(ContextImpl* context,
+                 int start,
+                 int end,
+                 int num_threads,
+                 const F& function) {
+  using namespace parallel_for_details;
+  CHECK_GT(num_threads, 0);
+  if (start >= end) {
+    return;
+  }
+
+  if (num_threads == 1 || end - start == 1) {
+    for (int i = start; i < end; ++i) {
+      Invoke<F>(0, i, function);
+    }
+    return;
+  }
+
+  CHECK(context != nullptr);
+  ParallelInvoke<F>(context, start, end, num_threads, function);
+}
 }  // namespace ceres::internal
 
+// Backend-specific implementations of ParallelInvoke
+#include "ceres/parallel_for_cxx.h"
+#include "ceres/parallel_for_openmp.h"
+#ifdef CERES_NO_THREADS
+namespace ceres::internal {
+template <typename F>
+void ParallelInvoke(ContextImpl* context,
+                    int start,
+                    int end,
+                    int num_threads,
+                    const F& function) {
+  ParallelFor(context, start, end, 1, function);
+}
+}  // namespace ceres::internal
+#endif
+
 #include "ceres/internal/disable_warnings.h"
 
 #endif  // CERES_INTERNAL_PARALLEL_FOR_H_
diff --git a/internal/ceres/parallel_for_cxx.cc b/internal/ceres/parallel_for_cxx.cc
index df2f619..13cabf9 100644
--- a/internal/ceres/parallel_for_cxx.cc
+++ b/internal/ceres/parallel_for_cxx.cc
@@ -33,211 +33,51 @@
 
 #ifdef CERES_USE_CXX_THREADS
 
+#include <atomic>
 #include <cmath>
 #include <condition_variable>
 #include <memory>
 #include <mutex>
 
-#include "ceres/concurrent_queue.h"
 #include "ceres/parallel_for.h"
-#include "ceres/scoped_thread_token.h"
-#include "ceres/thread_token_provider.h"
 #include "glog/logging.h"
 
 namespace ceres::internal {
-namespace {
-// This class creates a thread safe barrier which will block until a
-// pre-specified number of threads call Finished.  This allows us to block the
-// main thread until all the parallel threads are finished processing all the
-// work.
-class BlockUntilFinished {
- public:
-  explicit BlockUntilFinished(int num_total)
-      : num_finished_(0), num_total_(num_total) {}
 
-  // Increment the number of jobs that have finished and signal the blocking
-  // thread if all jobs have finished.
-  void Finished() {
-    std::lock_guard<std::mutex> lock(mutex_);
-    ++num_finished_;
-    CHECK_LE(num_finished_, num_total_);
-    if (num_finished_ == num_total_) {
-      condition_.notify_one();
-    }
+BlockUntilFinished::BlockUntilFinished(int num_total_jobs)
+    : num_total_jobs_finished_(0), num_total_jobs_(num_total_jobs) {}
+
+void BlockUntilFinished::Finished(int num_jobs_finished) {
+  if (num_jobs_finished == 0) return;
+  std::lock_guard<std::mutex> lock(mutex_);
+  num_total_jobs_finished_ += num_jobs_finished;
+  CHECK_LE(num_total_jobs_finished_, num_total_jobs_);
+  if (num_total_jobs_finished_ == num_total_jobs_) {
+    condition_.notify_one();
   }
+}
 
-  // Block until all threads have signaled they are finished.
-  void Block() {
-    std::unique_lock<std::mutex> lock(mutex_);
-    condition_.wait(lock, [&]() { return num_finished_ == num_total_; });
-  }
+void BlockUntilFinished::Block() {
+  std::unique_lock<std::mutex> lock(mutex_);
+  condition_.wait(
+      lock, [&]() { return num_total_jobs_finished_ == num_total_jobs_; });
+}
 
- private:
-  std::mutex mutex_;
-  std::condition_variable condition_;
-  // The current number of jobs finished.
-  int num_finished_;
-  // The total number of jobs.
-  int num_total_;
-};
-
-// Shared state between the parallel tasks. Each thread will use this
-// information to get the next block of work to be performed.
-struct SharedState {
-  SharedState(int start, int end, int num_work_items)
-      : start(start),
-        end(end),
-        num_work_items(num_work_items),
-        i(0),
-        thread_token_provider(num_work_items),
-        block_until_finished(num_work_items) {}
-
-  // The start and end index of the for loop.
-  const int start;
-  const int end;
-  // The number of blocks that need to be processed.
-  const int num_work_items;
-
-  // The next block of work to be assigned to a worker.  The parallel for loop
-  // range is split into num_work_items blocks of work, i.e. a single block of
-  // work is:
-  //  for (int j = start + i; j < end; j += num_work_items) { ... }.
-  int i;
-  std::mutex mutex_i;
-
-  // Provides a unique thread ID among all active threads working on the same
-  // group of tasks.  Thread-safe.
-  ThreadTokenProvider thread_token_provider;
-
-  // Used to signal when all the work has been completed.  Thread safe.
-  BlockUntilFinished block_until_finished;
-};
-
-}  // namespace
+ThreadPoolState::ThreadPoolState(int start,
+                                 int end,
+                                 int num_work_blocks,
+                                 int num_workers)
+    : start(start),
+      end(end),
+      num_work_blocks(num_work_blocks),
+      base_block_size((end - start) / num_work_blocks),
+      num_base_p1_sized_blocks((end - start) % num_work_blocks),
+      block_id(0),
+      thread_id(0),
+      block_until_finished(num_work_blocks) {}
 
 int MaxNumThreadsAvailable() { return ThreadPool::MaxNumThreadsAvailable(); }
 
-// See ParallelFor (below) for more details.
-void ParallelFor(ContextImpl* context,
-                 int start,
-                 int end,
-                 int num_threads,
-                 const std::function<void(int)>& function) {
-  CHECK_GT(num_threads, 0);
-  CHECK(context != nullptr);
-  if (end <= start) {
-    return;
-  }
-
-  // Fast path for when it is single threaded.
-  if (num_threads == 1) {
-    for (int i = start; i < end; ++i) {
-      function(i);
-    }
-    return;
-  }
-
-  ParallelFor(
-      context, start, end, num_threads, [&function](int /*thread_id*/, int i) {
-        function(i);
-      });
-}
-
-// This implementation uses a fixed size max worker pool with a shared task
-// queue. The problem of executing the function for the interval of [start, end)
-// is broken up into at most num_threads blocks and added to the thread pool. To
-// avoid deadlocks, the calling thread is allowed to steal work from the worker
-// pool. This is implemented via a shared state between the tasks. In order for
-// the calling thread or thread pool to get a block of work, it will query the
-// shared state for the next block of work to be done. If there is nothing left,
-// it will return. We will exit the ParallelFor call when all of the work has
-// been done, not when all of the tasks have been popped off the task queue.
-//
-// A unique thread ID among all active tasks will be acquired once for each
-// block of work.  This avoids the significant performance penalty for acquiring
-// it on every iteration of the for loop. The thread ID is guaranteed to be in
-// [0, num_threads).
-//
-// A performance analysis has shown this implementation is on par with OpenMP
-// and TBB.
-void ParallelFor(ContextImpl* context,
-                 int start,
-                 int end,
-                 int num_threads,
-                 const std::function<void(int thread_id, int i)>& function) {
-  CHECK_GT(num_threads, 0);
-  CHECK(context != nullptr);
-  if (end <= start) {
-    return;
-  }
-
-  // Fast path for when it is single threaded.
-  if (num_threads == 1) {
-    // Even though we only have one thread, use the thread token provider to
-    // guarantee the exact same behavior when running with multiple threads.
-    ThreadTokenProvider thread_token_provider(num_threads);
-    const ScopedThreadToken scoped_thread_token(&thread_token_provider);
-    const int thread_id = scoped_thread_token.token();
-    for (int i = start; i < end; ++i) {
-      function(thread_id, i);
-    }
-    return;
-  }
-
-  // We use a std::shared_ptr because the main thread can finish all
-  // the work before the tasks have been popped off the queue.  So the
-  // shared state needs to exist for the duration of all the tasks.
-  const int num_work_items = std::min((end - start), num_threads);
-  std::shared_ptr<SharedState> shared_state(
-      new SharedState(start, end, num_work_items));
-
-  // A function which tries to perform a chunk of work. This returns false if
-  // there is no work to be done.
-  auto task_function = [shared_state, &function]() {
-    int i = 0;
-    {
-      // Get the next available chunk of work to be performed. If there is no
-      // work, return false.
-      std::lock_guard<std::mutex> lock(shared_state->mutex_i);
-      if (shared_state->i >= shared_state->num_work_items) {
-        return false;
-      }
-      i = shared_state->i;
-      ++shared_state->i;
-    }
-
-    const ScopedThreadToken scoped_thread_token(
-        &shared_state->thread_token_provider);
-    const int thread_id = scoped_thread_token.token();
-
-    // Perform each task.
-    for (int j = shared_state->start + i; j < shared_state->end;
-         j += shared_state->num_work_items) {
-      function(thread_id, j);
-    }
-    shared_state->block_until_finished.Finished();
-    return true;
-  };
-
-  // Add all the tasks to the thread pool.
-  for (int i = 0; i < num_work_items; ++i) {
-    // Note we are taking the task_function as value so the shared_state
-    // shared pointer is copied and the ref count is increased. This is to
-    // prevent it from being deleted when the main thread finishes all the
-    // work and exits before the threads finish.
-    context->thread_pool.AddTask([task_function]() { task_function(); });
-  }
-
-  // Try to do any available work on the main thread. This may steal work from
-  // the thread pool, but when there is no work left the thread pool tasks
-  // will be no-ops.
-  while (task_function()) {
-  }
-
-  // Wait until all tasks have finished.
-  shared_state->block_until_finished.Block();
-}
-
 }  // namespace ceres::internal
 
 #endif  // CERES_USE_CXX_THREADS
diff --git a/internal/ceres/parallel_for_cxx.h b/internal/ceres/parallel_for_cxx.h
new file mode 100644
index 0000000..90edc07
--- /dev/null
+++ b/internal/ceres/parallel_for_cxx.h
@@ -0,0 +1,249 @@
+// Ceres Solver - A fast non-linear least squares minimizer
+// Copyright 2018 Google Inc. All rights reserved.
+// http://ceres-solver.org/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// * Redistributions of source code must retain the above copyright notice,
+//   this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright notice,
+//   this list of conditions and the following disclaimer in the documentation
+//   and/or other materials provided with the distribution.
+// * Neither the name of Google Inc. nor the names of its contributors may be
+//   used to endorse or promote products derived from this software without
+//   specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//
+// Authors: vitus@google.com (Michael Vitus),
+//          dmitriy.korchemkin@gmail.com (Dmitriy Korchemkin)
+
+// This include must come before any #ifndef check on Ceres compile options.
+#ifndef CERES_INTERNAL_PARALLEL_FOR_CXX_H_
+#define CERES_INTERNAL_PARALLEL_FOR_CXX_H_
+
+// This include must come before any #ifndef check on Ceres compile options.
+#include "ceres/internal/config.h"
+
+#ifdef CERES_USE_CXX_THREADS
+
+#include <atomic>
+#include <cmath>
+#include <condition_variable>
+#include <memory>
+#include <mutex>
+
+#include "glog/logging.h"
+
+namespace ceres::internal {
+// This class creates a thread safe barrier which will block until a
+// pre-specified number of threads call Finished.  This allows us to block the
+// main thread until all the parallel threads are finished processing all the
+// work.
+class BlockUntilFinished {
+ public:
+  explicit BlockUntilFinished(int num_total_jobs);
+
+  // Increment the number of jobs that have been processed by the number of
+  // jobs processed by caller and signal the blocking thread if all jobs
+  // have finished.
+  void Finished(int num_jobs_finished);
+
+  // Block until receiving confirmation of all jobs being finished.
+  void Block();
+
+ private:
+  std::mutex mutex_;
+  std::condition_variable condition_;
+  int num_total_jobs_finished_;
+  int num_total_jobs_;
+};
+
+// Shared state between the parallel tasks. Each thread will use this
+// information to get the next block of work to be performed.
+struct ThreadPoolState {
+  // The entire range [start, end) is split into num_work_blocks contiguous
+  // disjoint intervals (blocks), which are as equal as possible given
+  // total index count and requested number of  blocks.
+  //
+  // Those num_work_blocks blocks are then processed by num_workers
+  // workers
+  //
+  // Total number of integer indices in interval [start, end) is
+  // end - start, and when splitting them into num_work_blocks blocks
+  // we can either
+  //  - Split into equal blocks when (end - start) is divisible by
+  //    num_work_blocks
+  //  - Split into blocks with size difference at most 1:
+  //     - Size of the smallest block(s) is (end - start) / num_work_blocks
+  //     - (end - start) % num_work_blocks will need to be 1 index larger
+  //
+  // Note that this splitting is optimal in the sense of maximal difference
+  // between block sizes, since splitting into equal blocks is possible
+  // if and only if number of indices is divisible by number of blocks.
+  ThreadPoolState(int start, int end, int num_work_blocks, int num_workers);
+
+  // The start and end index of the for loop.
+  const int start;
+  const int end;
+  // The number of blocks that need to be processed.
+  const int num_work_blocks;
+  // Size of the smallest block
+  const int base_block_size;
+  // Number of blocks of size base_block_size + 1
+  const int num_base_p1_sized_blocks;
+
+  // The next block of work to be assigned to a worker.  The parallel for loop
+  // range is split into num_work_blocks blocks of work, with a single block of
+  // work being of size
+  //  - base_block_size + 1 for the first num_base_p1_sized_blocks blocks
+  //  - base_block_size for the rest of the blocks
+  //  blocks of indices are contiguous and disjoint
+  std::atomic<int> block_id;
+
+  // Provides a unique thread ID among all active threads
+  // We do not schedule more than num_threads threads via thread pool
+  // and caller thread might steal one ID
+  std::atomic<int> thread_id;
+
+  // Used to signal when all the work has been completed.  Thread safe.
+  BlockUntilFinished block_until_finished;
+};
+
+// This implementation uses a fixed size max worker pool with a shared task
+// queue. The problem of executing the function for the interval of [start, end)
+// is broken up into at most num_threads * kWorkBlocksPerThread blocks
+// and added to the thread pool. To avoid deadlocks, the calling thread is
+// allowed to steal work from the worker pool.
+// This is implemented via a shared state between the tasks. In order for
+// the calling thread or thread pool to get a block of work, it will query the
+// shared state for the next block of work to be done. If there is nothing left,
+// it will return. We will exit the ParallelFor call when all of the work has
+// been done, not when all of the tasks have been popped off the task queue.
+//
+// A unique thread ID among all active tasks will be acquired once for each
+// block of work.  This avoids the significant performance penalty for acquiring
+// it on every iteration of the for loop. The thread ID is guaranteed to be in
+// [0, num_threads).
+//
+// A performance analysis has shown this implementation is on par with OpenMP
+// and TBB.
+template <typename F>
+void ParallelInvoke(ContextImpl* context,
+                    int start,
+                    int end,
+                    int num_threads,
+                    const F& function) {
+  using namespace parallel_for_details;
+  CHECK(context != nullptr);
+
+  // Maximal number of work items scheduled for a single thread
+  //  - Lower number of work items results in larger runtimes on unequal tasks
+  //  - Higher number of work items results in larger losses for synchronization
+  constexpr int kWorkBlocksPerThread = 4;
+
+  // Interval [start, end) is being split into
+  // num_threads * kWorkBlocksPerThread contiguous disjoint blocks.
+  //
+  // In order to avoid creating empty blocks of work, we need to limit
+  // number of work blocks by a total number of indices.
+  const int num_work_blocks =
+      std::min((end - start), num_threads * kWorkBlocksPerThread);
+
+  // We use a std::shared_ptr because the main thread can finish all
+  // the work before the tasks have been popped off the queue.  So the
+  // shared state needs to exist for the duration of all the tasks.
+  std::shared_ptr<ThreadPoolState> shared_state(
+      new ThreadPoolState(start, end, num_work_blocks, num_threads));
+
+  // A function which tries to perform several chunks of work.
+  auto task = [shared_state, num_threads, &function]() {
+    int num_jobs_finished = 0;
+    const int thread_id = shared_state->thread_id.fetch_add(1);
+    // In order to avoid dead-locks in nested parallel for loops, task() will be
+    // invoked num_threads + 1 times:
+    //  - num_threads times via enqueueing task into thread pool
+    //  - one more time in the main thread
+    //  Tasks enqueued to thread pool might take some time before execution, and
+    //  the last task being executed will be terminated here in order to avoid
+    //  having more than num_threads active threads
+    if (thread_id >= num_threads) return;
+
+    const int start = shared_state->start;
+    const int end = shared_state->end;
+    const int base_block_size = shared_state->base_block_size;
+    const int num_base_p1_sized_blocks = shared_state->num_base_p1_sized_blocks;
+    const int num_work_blocks = shared_state->num_work_blocks;
+
+    while (true) {
+      // Get the next available chunk of work to be performed. If there is no
+      // work, return.
+      int block_id = shared_state->block_id.fetch_add(1);
+      if (block_id >= num_work_blocks) {
+        break;
+      }
+      ++num_jobs_finished;
+
+      // For-loop interval [start, end) was split into num_work_blocks,
+      // with num_base_p1_sized_blocks of size base_block_size + 1 and remaining
+      // num_work_blocks - num_base_p1_sized_blocks of size base_block_size
+      //
+      // Then, start index of the block #block_id is given by a total
+      // length of preceeding blocks:
+      //  * Total length of preceeding blocks of size base_block_size + 1:
+      //     min(block_id, num_base_p1_sized_blocks) * (base_block_size + 1)
+      //
+      //  * Total length of preceeding blocks of size base_block_size:
+      //     (block_id - min(block_id, num_base_p1_sized_blocks)) *
+      //     base_block_size
+      //
+      // Simplifying sum of those quantities yields a following
+      // expression for start index of the block #block_id
+      const int curr_start = start + block_id * base_block_size +
+                             std::min(block_id, num_base_p1_sized_blocks);
+      // First num_base_p1_sized_blocks have size base_block_size + 1
+      //
+      // Note that it is guaranteed that all blocks are within
+      // [start, end) interval
+      const int curr_end = curr_start + base_block_size +
+                           (block_id < num_base_p1_sized_blocks ? 1 : 0);
+      // Perform each task in current block
+      for (int i = curr_start; i < curr_end; ++i) {
+        Invoke<F>(thread_id, i, function);
+      }
+    }
+    shared_state->block_until_finished.Finished(num_jobs_finished);
+  };
+
+  // Add all the tasks to the thread pool.
+  for (int i = 0; i < num_threads; ++i) {
+    // Note we are taking the task as value so the copy of shared_state shared
+    // pointer (captured by value at declaration of task lambda-function) is
+    // copied and the ref count is increased. This is to prevent it from being
+    // deleted when the main thread finishes all the work and exits before the
+    // threads finish.
+    context->thread_pool.AddTask([task]() { task(); });
+  }
+
+  // Try to do any available work on the main thread. This may steal work from
+  // the thread pool, but when there is no work left the thread pool tasks
+  // will be no-ops.
+  task();
+
+  // Wait until all tasks have finished.
+  shared_state->block_until_finished.Block();
+}
+}  // namespace ceres::internal
+#endif
+#endif
diff --git a/internal/ceres/parallel_for_nothreads.cc b/internal/ceres/parallel_for_nothreads.cc
index 1c18716..8d3611d 100644
--- a/internal/ceres/parallel_for_nothreads.cc
+++ b/internal/ceres/parallel_for_nothreads.cc
@@ -41,37 +41,6 @@
 
 int MaxNumThreadsAvailable() { return 1; }
 
-void ParallelFor(ContextImpl* context,
-                 int start,
-                 int end,
-                 int num_threads,
-                 const std::function<void(int)>& function) {
-  CHECK_GT(num_threads, 0);
-  CHECK(context != nullptr);
-  if (end <= start) {
-    return;
-  }
-  for (int i = start; i < end; ++i) {
-    function(i);
-  }
-}
-
-void ParallelFor(ContextImpl* context,
-                 int start,
-                 int end,
-                 int num_threads,
-                 const std::function<void(int thread_id, int i)>& function) {
-  CHECK_GT(num_threads, 0);
-  CHECK(context != nullptr);
-  if (end <= start) {
-    return;
-  }
-  const int thread_id = 0;
-  for (int i = start; i < end; ++i) {
-    function(thread_id, i);
-  }
-}
-
 }  // namespace internal
 }  // namespace ceres
 
diff --git a/internal/ceres/parallel_for_openmp.cc b/internal/ceres/parallel_for_openmp.cc
index 1d44bf9..02690f3 100644
--- a/internal/ceres/parallel_for_openmp.cc
+++ b/internal/ceres/parallel_for_openmp.cc
@@ -34,8 +34,6 @@
 #if defined(CERES_USE_OPENMP)
 
 #include "ceres/parallel_for.h"
-#include "ceres/scoped_thread_token.h"
-#include "ceres/thread_token_provider.h"
 #include "glog/logging.h"
 #include "omp.h"
 
@@ -44,41 +42,6 @@
 
 int MaxNumThreadsAvailable() { return omp_get_max_threads(); }
 
-void ParallelFor(ContextImpl* context,
-                 int start,
-                 int end,
-                 int num_threads,
-                 const std::function<void(int)>& function) {
-  CHECK_GT(num_threads, 0);
-  CHECK(context != nullptr);
-  if (end <= start) {
-    return;
-  }
-
-#ifdef CERES_USE_OPENMP
-#pragma omp parallel for num_threads(num_threads) \
-    schedule(dynamic) if (num_threads > 1)
-#endif  // CERES_USE_OPENMP
-  for (int i = start; i < end; ++i) {
-    function(i);
-  }
-}
-
-void ParallelFor(ContextImpl* context,
-                 int start,
-                 int end,
-                 int num_threads,
-                 const std::function<void(int thread_id, int i)>& function) {
-  CHECK(context != nullptr);
-
-  ThreadTokenProvider thread_token_provider(num_threads);
-  ParallelFor(context, start, end, num_threads, [&](int i) {
-    const ScopedThreadToken scoped_thread_token(&thread_token_provider);
-    const int thread_id = scoped_thread_token.token();
-    function(thread_id, i);
-  });
-}
-
 }  // namespace internal
 }  // namespace ceres
 
diff --git a/internal/ceres/parallel_for_openmp.h b/internal/ceres/parallel_for_openmp.h
new file mode 100644
index 0000000..94254c4
--- /dev/null
+++ b/internal/ceres/parallel_for_openmp.h
@@ -0,0 +1,70 @@
+// Ceres Solver - A fast non-linear least squares minimizer
+// Copyright 2018 Google Inc. All rights reserved.
+// http://ceres-solver.org/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// * Redistributions of source code must retain the above copyright notice,
+//   this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright notice,
+//   this list of conditions and the following disclaimer in the documentation
+//   and/or other materials provided with the distribution.
+// * Neither the name of Google Inc. nor the names of its contributors may be
+//   used to endorse or promote products derived from this software without
+//   specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//
+// Authors: vitus@google.com (Michael Vitus),
+//          dmitriy.korchemkin@gmail.com (Dmitriy Korchemkin)
+
+// This include must come before any #ifndef check on Ceres compile options.
+#ifndef CERES_INTERNAL_PARALLEL_FOR_OPENMP_H_
+#define CERES_INTERNAL_PARALLEL_FOR_OPENMP_H_
+
+#include "ceres/internal/config.h"
+
+#if defined(CERES_USE_OPENMP)
+
+#include "ceres/parallel_for.h"
+#include "ceres/scoped_thread_token.h"
+#include "ceres/thread_token_provider.h"
+#include "glog/logging.h"
+#include "omp.h"
+
+namespace ceres::internal {
+
+template <typename F>
+void ParallelInvoke(ContextImpl* context,
+                    int start,
+                    int end,
+                    int num_threads,
+                    const F& function) {
+  using namespace parallel_for_details;
+  ThreadTokenProvider token_provider(num_threads);
+#pragma omp parallel num_threads(num_threads)
+  {
+    const ScopedThreadToken scoped_thread_token(&token_provider);
+    const int thread_id = scoped_thread_token.token();
+#pragma omp for schedule(guided)
+    for (int i = start; i < end; ++i) {
+      Invoke<F>(thread_id, i, function);
+    }
+  }
+}
+
+}  // namespace ceres::internal
+
+#endif
+#endif