Adds a C++11 parallel-for implementation. Implements ParallelFor using the C++11-based ThreadPool. The C++11 parallel for is 50-70% faster than single-threaded execution and 20-30% slower than TBB. Tested by compiling with OpenMP, TBB, and C++11 threading support and running the unit tests. Bazel was run as well. Change-Id: I7fd6c9037ff9f200ce6999b5f39918995bb6b8ea
diff --git a/CMakeLists.txt b/CMakeLists.txt index 88410f2..f0a7495 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt
@@ -108,6 +108,8 @@ option(OPENMP "Enable threaded solving in Ceres (requires OpenMP)" ON) # Multithreading using TBB option(TBB "Enable threaded solving in Ceres with TBB (requires TBB and C++11)" OFF) +# Multithreading using C++11 primitives. +option(CXX11_THREADS "Enable threaded solving in Ceres with C++11 primitives" OFF) # Enable the use of Eigen as a sparse linear algebra library for # solving the nonlinear least squares problems. option(EIGENSPARSE "Enable Eigen as a sparse linear algebra library." ON) @@ -444,10 +446,35 @@ endif (TBB_FOUND) endif (TBB) -if (NOT OPENMP AND NOT TBB) - message("-- Neither OpenMP or TBB is enabled, disabling multithreading.") +if (CXX11_THREADS) + # OpenMP and C++11 threads are mutually exclusive options. Fail with an error + # if the user requested both. + if (OPENMP) + message(FATAL_ERROR "OpenMP and C++11 threading support are both enabled " + "but they are mutually exclusive. OpenMP is enabled by default. Please " + "disable one of them.") + endif (OPENMP) + + # C++11 threads and TBB are mutually exclusive options. Fail with an error if + # the user requested both. + if (TBB) + message(FATAL_ERROR "Intel TBB and C++11 threading support are both " + "enabled but they are mutually exclusive. Please disable one of them.") + endif (TBB) + + if (NOT CXX11) + message(FATAL_ERROR "C++11 threading support requires C++11. Please " + "enable C++11 to enable.") + endif (NOT CXX11) + + list(APPEND CERES_COMPILE_OPTIONS CERES_USE_CXX11_THREADS) +endif (CXX11_THREADS) + +if (NOT OPENMP AND NOT TBB AND NOT CXX11_THREADS) + message("-- Neither OpenMP, TBB or C++11 threads is enabled, " + "disabling multithreading.") list(APPEND CERES_COMPILE_OPTIONS CERES_NO_THREADS) -else (NOT OPENMP AND NOT TBB) +else (NOT OPENMP AND NOT TBB AND NOT CXX11_THREADS) if (UNIX) # At least on Linux, we need pthreads to be enabled for mutex to # compile. This may not work on Windows or Android.
@@ -455,7 +482,7 @@ list(APPEND CERES_COMPILE_OPTIONS CERES_HAVE_PTHREAD) list(APPEND CERES_COMPILE_OPTIONS CERES_HAVE_RWLOCK) endif (UNIX) -endif (NOT OPENMP AND NOT TBB) +endif (NOT OPENMP AND NOT TBB AND NOT CXX11_THREADS) # Initialise CMAKE_REQUIRED_FLAGS used by CheckCXXSourceCompiles with the # contents of CMAKE_CXX_FLAGS such that if the user has passed extra flags @@ -513,6 +540,10 @@ message("-- Failed to find C++11 components in C++11 locations & " "namespaces, disabling CXX11.") update_cache_variable(CXX11 OFF) + if (CXX11_THREADS) + message(FATAL_ERROR "C++11 threading requires C++11 components, which we " + "failed to find. Please disable C++11 threading to continue.") + endif (CXX11_THREADS) else() message(" ==============================================================") message(" Compiling Ceres using C++11. This will result in a version ")
diff --git a/bazel/ceres.bzl b/bazel/ceres.bzl index f0660f1..090707d 100644 --- a/bazel/ceres.bzl +++ b/bazel/ceres.bzl
@@ -89,6 +89,7 @@ "low_rank_inverse_hessian.cc", "minimizer.cc", "normal_prior.cc", + "parallel_for_cxx.cc", "parallel_for_tbb.cc", "parameter_block_ordering.cc", "partitioned_matrix_view.cc",
diff --git a/cmake/config.h.in b/cmake/config.h.in index 792e315..032db42 100644 --- a/cmake/config.h.in +++ b/cmake/config.h.in
@@ -69,6 +69,8 @@ @CERES_USE_OPENMP@ // If defined Ceres was compiled with TBB multithreading support. @CERES_USE_TBB@ +// If defined Ceres was compiled with C++11 thread support. +@CERES_USE_CXX11_THREADS@ // Additionally defined on *nix if Ceres was compiled with OpenMP or TBB // support, as in this case pthreads is also required. @CERES_HAVE_PTHREAD@
diff --git a/include/ceres/internal/port.h b/include/ceres/internal/port.h index 652f6fb..b567fd5 100644 --- a/include/ceres/internal/port.h +++ b/include/ceres/internal/port.h
@@ -35,6 +35,11 @@ #ifdef __cplusplus #include <cstddef> #include "ceres/internal/config.h" + +#if !(defined(CERES_USE_OPENMP) ^ defined(CERES_USE_TBB) ^ defined(CERES_USE_CXX11_THREADS) ^ defined(CERES_NO_THREADS)) +#error CERES_USE_OPENMP, CERES_USE_TBB, CERES_USE_CXX11_THREADS, and CERES_NO_THREADS are mutually exclusive, but multiple are defined. +#endif + #if defined(CERES_TR1_MEMORY_HEADER) #include <tr1/memory> #else
diff --git a/internal/ceres/CMakeLists.txt b/internal/ceres/CMakeLists.txt index c13e041..92d8cd3 100644 --- a/internal/ceres/CMakeLists.txt +++ b/internal/ceres/CMakeLists.txt
@@ -90,6 +90,7 @@ low_rank_inverse_hessian.cc minimizer.cc normal_prior.cc + parallel_for_cxx.cc parallel_for_tbb.cc parameter_block_ordering.cc partitioned_matrix_view.cc
diff --git a/internal/ceres/parallel_for_cxx.cc b/internal/ceres/parallel_for_cxx.cc new file mode 100644 index 0000000..f03ddb2 --- /dev/null +++ b/internal/ceres/parallel_for_cxx.cc
@@ -0,0 +1,206 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2018 Google Inc. All rights reserved. +// http://ceres-solver.org/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Author: vitus@google.com (Michael Vitus) + +// This include must come before any #ifndef check on Ceres compile options.
+#include "ceres/internal/port.h" + +#ifdef CERES_USE_CXX11_THREADS + +#include "ceres/parallel_for.h" + +#include <algorithm> +#include <cmath> +#include <condition_variable> +#include <memory> +#include <mutex> + +#include "ceres/concurrent_queue.h" +#include "ceres/thread_pool.h" +#include "glog/logging.h" + +namespace ceres { +namespace internal { +namespace { +// This class creates a thread safe barrier which will block until a +// pre-specified number of threads call Finished. This allows us to block the +// main thread until all the parallel threads are finished processing all the +// work. +class BlockUntilFinished { + public: + explicit BlockUntilFinished(int num_total) + : num_finished_(0), num_total_(num_total) {} + + // Increment the number of jobs that have finished and signal the blocking + // thread if all jobs have finished. + void Finished() { + std::unique_lock<std::mutex> lock(mutex_); + ++num_finished_; + CHECK_LE(num_finished_, num_total_); + if (num_finished_ == num_total_) { + condition_.notify_one(); + } + } + + // Block until all threads have signaled they are finished. + void Block() { + std::unique_lock<std::mutex> lock(mutex_); + condition_.wait(lock, [&]() { return num_finished_ == num_total_; }); + } + + private: + std::mutex mutex_; + std::condition_variable condition_; + // The current number of jobs finished. + int num_finished_; + // The total number of jobs. + int num_total_; +}; + +// Shared state between the parallel tasks. Each thread will use this +// information to get the next block of work to be performed. +struct SharedState { + SharedState(int start, int end, int num_work_items) + : start(start), + end(end), + num_work_items(num_work_items), + i(0), + block_until_finished(num_work_items) {} + + // The start and end index of the for loop. + const int start; + const int end; + // The number of blocks that need to be processed. + const int num_work_items; + + // The next block of work to be assigned to a worker.
The parallel for loop + // range is split into num_work_items blocks of work, i.e. a single block of + // work is: + // for (int j = start + i; j < end; j += num_work_items) { ... }. + int i; + std::mutex mutex_i; + + // Used to signal when all the work has been completed. + BlockUntilFinished block_until_finished; +}; + +} // namespace + +// This implementation uses a fixed size max worker pool with a shared task +// queue. The problem of executing the function for the interval of [start, end) +// is broken up into at most num_threads blocks and added to the thread pool. To +// avoid deadlocks, the calling thread is allowed to steal work from the worker +// pool. This is implemented via a shared state between the tasks. In order for +// the calling thread or thread pool to get a block of work, it will query the +// shared state for the next block of work to be done. If there is nothing left, +// it will return. We will exit the ParallelFor call when all of the work has +// been done, not when all of the tasks have been popped off the task queue. +// +// A performance analysis has shown this implementation is about 20% slower +// than OpenMP or TBB. This native implementation is a fix for platforms that do +// not have access to OpenMP or TBB. The gain in enabling multi-threaded Ceres +// is much more significant so we decided to not chase the performance of these +// two libraries. +void ParallelFor(ContextImpl* context, + int start, + int end, + int num_threads, + const std::function<void(int)>& function) { + CHECK_GT(num_threads, 0); + CHECK(context != nullptr); + if (end <= start) { + return; + } + + // Fast path for when it is single threaded. + if (num_threads == 1) { + for (int i = start; i < end; ++i) { + function(i); + } + return; + } + + // We use a shared_ptr because the main thread can finish all the work before + // the tasks have been popped off the queue. So the shared state needs to + // exist for the duration of all the tasks.
+ const int num_work_items = std::min((end - start), num_threads); + shared_ptr<SharedState> shared_state( + new SharedState(start, end, num_work_items)); + + // A function which tries to perform a chunk of work. This returns false if + // there is no work to be done. + // `function` is captured by reference. This is safe because ParallelFor only + // returns after every work item has finished, and a task that runs later + // finds no work left and returns before touching `function`. + auto task_function = [shared_state, &function]() { + int i = 0; + { + // Get the next available chunk of work to be performed. If there is no + // work, return false. + std::unique_lock<std::mutex> lock(shared_state->mutex_i); + if (shared_state->i >= shared_state->num_work_items) { + return false; + } + i = shared_state->i; + ++shared_state->i; + } + + // Perform each task. + for (int j = shared_state->start + i; + j < shared_state->end; + j += shared_state->num_work_items) { + function(j); + } + shared_state->block_until_finished.Finished(); + return true; + }; + + // Add all the tasks to the thread pool. + for (int i = 0; i < num_work_items; ++i) { + // Note we are taking the task_function as value so the shared_state + // shared pointer is copied and the ref count is increased. This is to + // prevent it from being deleted when the main thread finishes all the + // work and exits before the threads finish. + context->thread_pool.AddTask([task_function]() { task_function(); }); + } + + // Try to do any available work on the main thread. This may steal work from + // the thread pool, but when there is no work left the thread pool tasks + // will be no-ops. + while (task_function()) { + } + + // Wait until all tasks have finished. + shared_state->block_until_finished.Block(); +} + +} // namespace internal +} // namespace ceres + +#endif // CERES_USE_CXX11_THREADS
diff --git a/internal/ceres/thread_token_provider.cc b/internal/ceres/thread_token_provider.cc index 5648d4d..61a16e0 100644 --- a/internal/ceres/thread_token_provider.cc +++ b/internal/ceres/thread_token_provider.cc
@@ -45,6 +45,13 @@ pool_.push(i); } #endif + +#ifdef CERES_USE_CXX11_THREADS + for (int i = 0; i < num_threads; i++) { + pool_.Push(i); + } +#endif + } int ThreadTokenProvider::Acquire() { @@ -61,6 +68,13 @@ pool_.pop(thread_id); return thread_id; #endif + +#ifdef CERES_USE_CXX11_THREADS + int thread_id; + CHECK(pool_.Wait(&thread_id)); + return thread_id; +#endif + } void ThreadTokenProvider::Release(int thread_id) { @@ -68,6 +82,11 @@ #ifdef CERES_USE_TBB pool_.push(thread_id); #endif + +#ifdef CERES_USE_CXX11_THREADS + pool_.Push(thread_id); +#endif + } } // namespace internal
diff --git a/internal/ceres/thread_token_provider.h b/internal/ceres/thread_token_provider.h index 209445d..841d312 100644 --- a/internal/ceres/thread_token_provider.h +++ b/internal/ceres/thread_token_provider.h
@@ -32,35 +32,25 @@ #define CERES_INTERNAL_THREAD_TOKEN_PROVIDER_H_ #include "ceres/internal/config.h" - -#if defined(CERES_USE_OPENMP) -# if defined(CERES_USE_TBB) || defined(CERES_NO_THREADS) -# error CERES_USE_OPENMP is mutually exclusive to CERES_USE_TBB and CERES_NO_THREADS -# endif -#elif defined(CERES_USE_TBB) -# if defined(CERES_USE_OPENMP) || defined(CERES_NO_THREADS) -# error CERES_USE_TBB is mutually exclusive to CERES_USE_OPENMP and CERES_NO_THREADS -# endif -#elif defined(CERES_NO_THREADS) -# if defined(CERES_USE_OPENMP) || defined(CERES_USE_TBB) -# error CERES_NO_THREADS is mutually exclusive to CERES_USE_OPENMP and CERES_USE_TBB -# endif -#else -# error One of CERES_USE_OPENMP, CERES_USE_TBB or CERES_NO_THREADS must be defined. -#endif +#include "ceres/internal/port.h" #ifdef CERES_USE_TBB #include <tbb/concurrent_queue.h> #endif +#ifdef CERES_USE_CXX11_THREADS +#include "ceres/concurrent_queue.h" +#endif + namespace ceres { namespace internal { -// Helper for TBB thread number identification that is similar to -// omp_get_thread_num() behaviour. This is necessary to support TBB threading -// with a sequential thread id. This is used to access preallocated resources in -// the parallelized code parts. The sequence of tokens varies from 0 to -// num_threads - 1 that can be acquired to identify the thread in a thread pool. +// Helper for TBB and C++11 thread number identification that is similar to +// omp_get_thread_num() behaviour. This is necessary to support TBB and C++11 +// threading with a sequential thread id. This is used to access preallocated +// resources in the parallelized code parts. The sequence of tokens varies from +// 0 to num_threads - 1 that can be acquired to identify the thread in a thread +// pool. // // If CERES_NO_THREADS is defined, Acquire() always returns 0 and Release() // takes no action. 
@@ -102,6 +92,22 @@ tbb::concurrent_bounded_queue<int> pool_; #endif +#ifdef CERES_USE_CXX11_THREADS + // This queue initially holds a sequence from 0..num_threads-1. On every + // Acquire() call the first number is removed from here. When the token is not + // needed anymore it shall be given back with the corresponding Release() call. + // + // The thread number is acquired on every for loop iteration. The + // ConcurrentQueue uses a mutex to enable thread safety, however, this can + // lead to a large amount of contention between the threads which can cause a + // loss in performance. This is noticeable for problems with inexpensive + // residual computations. + // + // TODO(vitus): We should either implement a more performant queue (such as + // lock free), or get the thread ID from the shared state. + ConcurrentQueue<int> pool_; +#endif + ThreadTokenProvider(ThreadTokenProvider&); ThreadTokenProvider& operator=(ThreadTokenProvider&); };