Adds C++11 parallel for implementation.
Implements ParallelFor using the C++11 based ThreadPool. The C++11
parallel for is 50-70% faster than single threaded, and 20-30% slower
than TBB.
Tested by compiling with OpenMP, TBB, and C++11 Threading support and
ran the unit tests. Ran bazel as well.
Change-Id: I7fd6c9037ff9f200ce6999b5f39918995bb6b8ea
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 88410f2..f0a7495 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -108,6 +108,8 @@
option(OPENMP "Enable threaded solving in Ceres (requires OpenMP)" ON)
# Multithreading using TBB
option(TBB "Enable threaded solving in Ceres with TBB (requires TBB and C++11)" OFF)
+# Multithreading using C++11 primitives.
+option(CXX11_THREADS "Enable threaded solving in Ceres with C++11 primitives" OFF)
# Enable the use of Eigen as a sparse linear algebra library for
# solving the nonlinear least squares problems.
option(EIGENSPARSE "Enable Eigen as a sparse linear algebra library." ON)
@@ -444,10 +446,35 @@
endif (TBB_FOUND)
endif (TBB)
-if (NOT OPENMP AND NOT TBB)
- message("-- Neither OpenMP or TBB is enabled, disabling multithreading.")
+if (CXX11_THREADS)
+ # OpenMP and C++11 threads are mutually exclusive options. Fail with an error
+ # if they user requested both.
+ if (OPENMP)
+ message(FATAL_ERROR "OpenMP and C++11 threading support are both enabled "
+ "but they are mutally exclusive. OpenMP is enabled by default. Please "
+ "disable one of them.")
+ endif (OPENMP)
+
+ # C++11 threads and TBB are mutually exclusive options. Fail with an error if
+ # the user requested both.
+ if (TBB)
+ message(FATAL_ERROR "Intel TBB and C++11 threading support are both "
+ "enabled but they are mutally exclusive. Please disable one of them.")
+ endif (TBB)
+
+ if (NOT CXX11)
+ message(FATAL_ERROR "C++11 threading support requires C++11. Please "
+ "enable C++11 to enable.")
+ endif (NOT CXX11)
+
+ list(APPEND CERES_COMPILE_OPTIONS CERES_USE_CXX11_THREADS)
+endif (CXX11_THREADS)
+
+if (NOT OPENMP AND NOT TBB AND NOT CXX11_THREADS)
+ message("-- Neither OpenMP, TBB or C++11 threads is enabled, "
+ "disabling multithreading.")
list(APPEND CERES_COMPILE_OPTIONS CERES_NO_THREADS)
-else (NOT OPENMP AND NOT TBB)
+else (NOT OPENMP AND NOT TBB AND NOT CXX11_THREADS)
if (UNIX)
# At least on Linux, we need pthreads to be enabled for mutex to
# compile. This may not work on Windows or Android.
@@ -455,7 +482,7 @@
list(APPEND CERES_COMPILE_OPTIONS CERES_HAVE_PTHREAD)
list(APPEND CERES_COMPILE_OPTIONS CERES_HAVE_RWLOCK)
endif (UNIX)
-endif (NOT OPENMP AND NOT TBB)
+endif (NOT OPENMP AND NOT TBB AND NOT CXX11_THREADS)
# Initialise CMAKE_REQUIRED_FLAGS used by CheckCXXSourceCompiles with the
# contents of CMAKE_CXX_FLAGS such that if the user has passed extra flags
@@ -513,6 +540,10 @@
message("-- Failed to find C++11 components in C++11 locations & "
"namespaces, disabling CXX11.")
update_cache_variable(CXX11 OFF)
+ if (CXX11_THREADS)
+ message(FATAL_ERROR "C++11 threading requires C++11 components, which we "
+ "failed to find. Please disable C++11 threading to continue.")
+ endif (CXX11_THREADS)
else()
message(" ==============================================================")
message(" Compiling Ceres using C++11. This will result in a version ")
diff --git a/bazel/ceres.bzl b/bazel/ceres.bzl
index f0660f1..090707d 100644
--- a/bazel/ceres.bzl
+++ b/bazel/ceres.bzl
@@ -89,6 +89,7 @@
"low_rank_inverse_hessian.cc",
"minimizer.cc",
"normal_prior.cc",
+ "parallel_for_cxx.cc",
"parallel_for_tbb.cc",
"parameter_block_ordering.cc",
"partitioned_matrix_view.cc",
diff --git a/cmake/config.h.in b/cmake/config.h.in
index 792e315..032db42 100644
--- a/cmake/config.h.in
+++ b/cmake/config.h.in
@@ -69,6 +69,8 @@
@CERES_USE_OPENMP@
// If defined Ceres was compiled with TBB multithreading support.
@CERES_USE_TBB@
+// If defined Ceres was compiled with C++11 thread support.
+@CERES_USE_CXX11_THREADS@
// Additionally defined on *nix if Ceres was compiled with OpenMP or TBB
// support, as in this case pthreads is also required.
@CERES_HAVE_PTHREAD@
diff --git a/include/ceres/internal/port.h b/include/ceres/internal/port.h
index 652f6fb..b567fd5 100644
--- a/include/ceres/internal/port.h
+++ b/include/ceres/internal/port.h
@@ -35,6 +35,11 @@
#ifdef __cplusplus
#include <cstddef>
#include "ceres/internal/config.h"
+
+#if !(defined(CERES_USE_OPENMP) ^ defined(CERES_USE_TBB) ^ defined(CERES_USE_CXX11_THREADS) ^ defined(CERES_NO_THREADS))
+#error CERES_USE_OPENMP, CERES_USE_TBB, CERES_USE_CXX11_THREADS, and CERES_NO_THREADS are mutually exclusive, but multiple are defined.
+#endif
+
#if defined(CERES_TR1_MEMORY_HEADER)
#include <tr1/memory>
#else
diff --git a/internal/ceres/CMakeLists.txt b/internal/ceres/CMakeLists.txt
index c13e041..92d8cd3 100644
--- a/internal/ceres/CMakeLists.txt
+++ b/internal/ceres/CMakeLists.txt
@@ -90,6 +90,7 @@
low_rank_inverse_hessian.cc
minimizer.cc
normal_prior.cc
+ parallel_for_cxx.cc
parallel_for_tbb.cc
parameter_block_ordering.cc
partitioned_matrix_view.cc
diff --git a/internal/ceres/parallel_for_cxx.cc b/internal/ceres/parallel_for_cxx.cc
new file mode 100644
index 0000000..f03ddb2
--- /dev/null
+++ b/internal/ceres/parallel_for_cxx.cc
@@ -0,0 +1,202 @@
+// Ceres Solver - A fast non-linear least squares minimizer
+// Copyright 2018 Google Inc. All rights reserved.
+// http://ceres-solver.org/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// * Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// * Neither the name of Google Inc. nor the names of its contributors may be
+// used to endorse or promote products derived from this software without
+// specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//
+// Author: vitus@google.com (Michael Vitus)
+
+// This include must come before any #ifndef check on Ceres compile options.
+#include "ceres/internal/port.h"
+
+#ifdef CERES_USE_CXX11_THREADS
+
+#include "ceres/parallel_for.h"
+
+#include <cmath>
+#include <condition_variable>
+#include <memory>
+#include <mutex>
+
+#include "ceres/concurrent_queue.h"
+#include "ceres/thread_pool.h"
+#include "glog/logging.h"
+
+namespace ceres {
+namespace internal {
+namespace {
+// This class creates a thread safe barrier which will block until a
+// pre-specified number of threads call Finished. This allows us to block the
+// main thread until all the parallel threads are finished processing all the
+// work.
+class BlockUntilFinished {
+ public:
+ explicit BlockUntilFinished(int num_total)
+ : num_finished_(0), num_total_(num_total) {}
+
+ // Increment the number of jobs that have finished and signal the blocking
+ // thread if all jobs have finished.
+ void Finished() {
+ std::unique_lock<std::mutex> lock(mutex_);
+ ++num_finished_;
+ CHECK_LE(num_finished_, num_total_);
+ if (num_finished_ == num_total_) {
+ condition_.notify_one();
+ }
+ }
+
+ // Block until all threads have signaled they are finished.
+ void Block() {
+ std::unique_lock<std::mutex> lock(mutex_);
+ condition_.wait(lock, [&]() { return num_finished_ == num_total_; });
+ }
+
+ private:
+ std::mutex mutex_;
+ std::condition_variable condition_;
+ // The current number of jobs finished.
+ int num_finished_;
+ // The total number of jobs.
+ int num_total_;
+};
+
+// Shared state between the parallel tasks. Each thread will use this
+// information to get the next block of work to be performed.
+struct SharedState {
+ SharedState(int start, int end, int num_work_items)
+ : start(start),
+ end(end),
+ num_work_items(num_work_items),
+ i(0),
+ block_until_finished(num_work_items) {}
+
+ // The start and end index of the for loop.
+ const int start;
+ const int end;
+ // The number of blocks that need to be processed.
+ const int num_work_items;
+
+ // The next block of work to be assigned to a worker. The parallel for loop
+ // range is split into num_work_items blocks of work, i.e. a single block of
+ // work is:
+ // for (int j = start + i; j < end; j += num_work_items) { ... }.
+ int i;
+ std::mutex mutex_i;
+
+ // Used to signal when all the work has been completed.
+ BlockUntilFinished block_until_finished;
+};
+
+} // namespace
+
+// This implementation uses a fixed size max worker pool with a shared task
+// queue. The problem of executing the function for the interval of [start, end)
+// is broken up into at most num_threads blocks and added to the thread pool. To
+// avoid deadlocks, the calling thread is allowed to steal work from the worker
+// pool. This is implemented via a shared state between the tasks. In order for
+// the calling thread or thread pool to get a block of work, it will query the
+// shared state for the next block of work to be done. If there is nothing left,
+// it will return. We will exit the ParallelFor call when all of the work has
+// been done, not when all of the tasks have been popped off the task queue.
+//
+// A performance analysis has shown this implementation is about ~20% slower
+// than OpenMP or TBB. This native implementation is a fix for platforms that do
+// not have access to OpenMP or TBB. The gain in enabling multi-threaded Ceres
+// is much more significant so we decided to not chase the performance of these
+// two libraries.
+void ParallelFor(ContextImpl* context,
+ int start,
+ int end,
+ int num_threads,
+ const std::function<void(int)>& function) {
+ CHECK_GT(num_threads, 0);
+ CHECK(context != NULL);
+ if (end <= start) {
+ return;
+ }
+
+ // Fast path for when it is single threaded.
+ if (num_threads == 1) {
+ for (int i = start; i < end; ++i) {
+ function(i);
+ }
+ return;
+ }
+
+ // We use a shared_ptr because the main thread can finish all the work before
+ // the tasks have been popped off the queue. So the shared state needs to
+ // exist for the duration of all the tasks.
+ const int num_work_items = std::min((end - start), num_threads);
+ shared_ptr<SharedState> shared_state(
+ new SharedState(start, end, num_work_items));
+
+ // A function which tries to perform a chunk of work. This returns false if
+ // there is no work to be done.
+ auto task_function = [shared_state, &function]() {
+ int i = 0;
+ {
+ // Get the next available chunk of work to be performed. If there is no
+ // work, return false.
+ std::unique_lock<std::mutex> lock(shared_state->mutex_i);
+ if (shared_state->i >= shared_state->num_work_items) {
+ return false;
+ }
+ i = shared_state->i;
+ ++shared_state->i;
+ }
+
+ // Perform each task.
+ for (int j = shared_state->start + i;
+ j < shared_state->end;
+ j += shared_state->num_work_items) {
+ function(j);
+ }
+ shared_state->block_until_finished.Finished();
+ return true;
+ };
+
+ // Add all the tasks to the thread pool.
+ for (int i = 0; i < num_work_items; ++i) {
+ // Note we are taking the task_function as value so the shared_state
+ // shared pointer is copied and the ref count is increased. This is to
+ // prevent it from being deleted when the main thread finishes all the
+ // work and exits before the threads finish.
+ context->thread_pool.AddTask([task_function]() { task_function(); });
+ }
+
+ // Try to do any available work on the main thread. This may steal work from
+ // the thread pool, but when there is no work left the thread pool tasks
+ // will be no-ops.
+ while (task_function()) {
+ }
+
+ // Wait until all tasks have finished.
+ shared_state->block_until_finished.Block();
+}
+
+} // namespace internal
+} // namespace ceres
+
+#endif // CERES_USE_CXX11_THREADS
diff --git a/internal/ceres/thread_token_provider.cc b/internal/ceres/thread_token_provider.cc
index 5648d4d..61a16e0 100644
--- a/internal/ceres/thread_token_provider.cc
+++ b/internal/ceres/thread_token_provider.cc
@@ -45,6 +45,13 @@
pool_.push(i);
}
#endif
+
+#ifdef CERES_USE_CXX11_THREADS
+ for (int i = 0; i < num_threads; i++) {
+ pool_.Push(i);
+ }
+#endif
+
}
int ThreadTokenProvider::Acquire() {
@@ -61,6 +68,13 @@
pool_.pop(thread_id);
return thread_id;
#endif
+
+#ifdef CERES_USE_CXX11_THREADS
+ int thread_id;
+ CHECK(pool_.Wait(&thread_id));
+ return thread_id;
+#endif
+
}
void ThreadTokenProvider::Release(int thread_id) {
@@ -68,6 +82,11 @@
#ifdef CERES_USE_TBB
pool_.push(thread_id);
#endif
+
+#ifdef CERES_USE_CXX11_THREADS
+ pool_.Push(thread_id);
+#endif
+
}
} // namespace internal
diff --git a/internal/ceres/thread_token_provider.h b/internal/ceres/thread_token_provider.h
index 209445d..841d312 100644
--- a/internal/ceres/thread_token_provider.h
+++ b/internal/ceres/thread_token_provider.h
@@ -32,35 +32,25 @@
#define CERES_INTERNAL_THREAD_TOKEN_PROVIDER_H_
#include "ceres/internal/config.h"
-
-#if defined(CERES_USE_OPENMP)
-# if defined(CERES_USE_TBB) || defined(CERES_NO_THREADS)
-# error CERES_USE_OPENMP is mutually exclusive to CERES_USE_TBB and CERES_NO_THREADS
-# endif
-#elif defined(CERES_USE_TBB)
-# if defined(CERES_USE_OPENMP) || defined(CERES_NO_THREADS)
-# error CERES_USE_TBB is mutually exclusive to CERES_USE_OPENMP and CERES_NO_THREADS
-# endif
-#elif defined(CERES_NO_THREADS)
-# if defined(CERES_USE_OPENMP) || defined(CERES_USE_TBB)
-# error CERES_NO_THREADS is mutually exclusive to CERES_USE_OPENMP and CERES_USE_TBB
-# endif
-#else
-# error One of CERES_USE_OPENMP, CERES_USE_TBB or CERES_NO_THREADS must be defined.
-#endif
+#include "ceres/internal/port.h"
#ifdef CERES_USE_TBB
#include <tbb/concurrent_queue.h>
#endif
+#ifdef CERES_USE_CXX11_THREADS
+#include "ceres/concurrent_queue.h"
+#endif
+
namespace ceres {
namespace internal {
-// Helper for TBB thread number identification that is similar to
-// omp_get_thread_num() behaviour. This is necessary to support TBB threading
-// with a sequential thread id. This is used to access preallocated resources in
-// the parallelized code parts. The sequence of tokens varies from 0 to
-// num_threads - 1 that can be acquired to identify the thread in a thread pool.
+// Helper for TBB and C++11 thread number identification that is similar to
+// omp_get_thread_num() behaviour. This is necessary to support TBB and C++11
+// threading with a sequential thread id. This is used to access preallocated
+// resources in the parallelized code parts. The sequence of tokens varies from
+// 0 to num_threads - 1 that can be acquired to identify the thread in a thread
+// pool.
//
// If CERES_NO_THREADS is defined, Acquire() always returns 0 and Release()
// takes no action.
@@ -102,6 +92,22 @@
tbb::concurrent_bounded_queue<int> pool_;
#endif
+#ifdef CERES_USE_CXX11_THREADS
+ // This queue initially holds a sequence from 0..num_threads-1. Every
+ // Acquire() call the first number is removed from here. When the token is not
+ // needed anymore it shall be given back with corresponding Release() call.
+ //
+ // The thread number is acquired on every for loop iteration. The
+ // ConcurrentQueue uses a mutex to enable thread safety, however, this can
+ // lead to a large amount of contention between the threads which can cause a
+ // loss in performance. This is noticable for problems with inexpensive
+ // residual computations.
+ //
+ // TODO(vitus): We should either implement a more performant queue (such as
+ // lock free), or get the thread ID from the shared state.
+ ConcurrentQueue<int> pool_;
+#endif
+
ThreadTokenProvider(ThreadTokenProvider&);
ThreadTokenProvider& operator=(ThreadTokenProvider&);
};