| // Ceres Solver - A fast non-linear least squares minimizer |
| // Copyright 2018 Google Inc. All rights reserved. |
| // http://ceres-solver.org/ |
| // |
| // Redistribution and use in source and binary forms, with or without |
| // modification, are permitted provided that the following conditions are met: |
| // |
| // * Redistributions of source code must retain the above copyright notice, |
| // this list of conditions and the following disclaimer. |
| // * Redistributions in binary form must reproduce the above copyright notice, |
| // this list of conditions and the following disclaimer in the documentation |
| // and/or other materials provided with the distribution. |
| // * Neither the name of Google Inc. nor the names of its contributors may be |
| // used to endorse or promote products derived from this software without |
| // specific prior written permission. |
| // |
| // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" |
| // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE |
| // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE |
| // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE |
| // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR |
| // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF |
| // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS |
| // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN |
| // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) |
| // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE |
| // POSSIBILITY OF SUCH DAMAGE. |
| // |
| // Author: vitus@google.com (Michael Vitus) |
| |
| // This include must come before any #ifndef check on Ceres compile options. |
| #include "ceres/internal/port.h" |
| |
| #ifdef CERES_USE_CXX11_THREADS |
| |
| #include "ceres/parallel_for.h" |
| |
| #include <cmath> |
| #include <condition_variable> |
| #include <memory> |
| #include <mutex> |
| |
| #include "ceres/concurrent_queue.h" |
| #include "ceres/scoped_thread_token.h" |
| #include "ceres/thread_token_provider.h" |
| #include "glog/logging.h" |
| |
| namespace ceres { |
| namespace internal { |
| namespace { |
| // This class creates a thread safe barrier which will block until a |
| // pre-specified number of threads call Finished. This allows us to block the |
| // main thread until all the parallel threads are finished processing all the |
| // work. |
| class BlockUntilFinished { |
| public: |
| explicit BlockUntilFinished(int num_total) |
| : num_finished_(0), num_total_(num_total) {} |
| |
| // Increment the number of jobs that have finished and signal the blocking |
| // thread if all jobs have finished. |
| void Finished() { |
| std::unique_lock<std::mutex> lock(mutex_); |
| ++num_finished_; |
| CHECK_LE(num_finished_, num_total_); |
| if (num_finished_ == num_total_) { |
| condition_.notify_one(); |
| } |
| } |
| |
| // Block until all threads have signaled they are finished. |
| void Block() { |
| std::unique_lock<std::mutex> lock(mutex_); |
| condition_.wait(lock, [&]() { return num_finished_ == num_total_; }); |
| } |
| |
| private: |
| std::mutex mutex_; |
| std::condition_variable condition_; |
| // The current number of jobs finished. |
| int num_finished_; |
| // The total number of jobs. |
| int num_total_; |
| }; |
| |
| // Shared state between the parallel tasks. Each thread will use this |
| // information to get the next block of work to be performed. |
| struct SharedState { |
| SharedState(int start, int end, int num_work_items) |
| : start(start), |
| end(end), |
| num_work_items(num_work_items), |
| i(0), |
| thread_token_provider(num_work_items), |
| block_until_finished(num_work_items) {} |
| |
| // The start and end index of the for loop. |
| const int start; |
| const int end; |
| // The number of blocks that need to be processed. |
| const int num_work_items; |
| |
| // The next block of work to be assigned to a worker. The parallel for loop |
| // range is split into num_work_items blocks of work, i.e. a single block of |
| // work is: |
| // for (int j = start + i; j < end; j += num_work_items) { ... }. |
| int i; |
| std::mutex mutex_i; |
| |
| // Provides a unique thread ID among all active threads working on the same |
| // group of tasks. Thread-safe. |
| ThreadTokenProvider thread_token_provider; |
| |
| // Used to signal when all the work has been completed. Thread safe. |
| BlockUntilFinished block_until_finished; |
| }; |
| |
| } // namespace |
| |
| // See ParallelFor (below) for more details. |
| void ParallelFor(ContextImpl* context, |
| int start, |
| int end, |
| int num_threads, |
| const std::function<void(int)>& function) { |
| CHECK_GT(num_threads, 0); |
| CHECK(context != NULL); |
| if (end <= start) { |
| return; |
| } |
| |
| // Fast path for when it is single threaded. |
| if (num_threads == 1) { |
| for (int i = start; i < end; ++i) { |
| function(i); |
| } |
| return; |
| } |
| |
| ParallelFor(context, start, end, num_threads, |
| [&function](int /*thread_id*/, int i) { function(i); }); |
| } |
| |
| // This implementation uses a fixed size max worker pool with a shared task |
| // queue. The problem of executing the function for the interval of [start, end) |
| // is broken up into at most num_threads blocks and added to the thread pool. To |
| // avoid deadlocks, the calling thread is allowed to steal work from the worker |
| // pool. This is implemented via a shared state between the tasks. In order for |
| // the calling thread or thread pool to get a block of work, it will query the |
| // shared state for the next block of work to be done. If there is nothing left, |
| // it will return. We will exit the ParallelFor call when all of the work has |
| // been done, not when all of the tasks have been popped off the task queue. |
| // |
| // A unique thread ID among all active tasks will be acquired once for each |
| // block of work. This avoids the significant performance penalty for acquiring |
| // it on every iteration of the for loop. The thread ID is guaranteed to be in |
| // [0, num_threads). |
| // |
| // A performance analysis has shown this implementation is onpar with OpenMP and |
| // TBB. |
| void ParallelFor(ContextImpl* context, |
| int start, |
| int end, |
| int num_threads, |
| const std::function<void(int thread_id, int i)>& function) { |
| CHECK_GT(num_threads, 0); |
| CHECK(context != NULL); |
| if (end <= start) { |
| return; |
| } |
| |
| // Fast path for when it is single threaded. |
| if (num_threads == 1) { |
| // Even though we only have one thread, use the thread token provider to |
| // guarantee the exact same behavior when running with multiple threads. |
| ThreadTokenProvider thread_token_provider(num_threads); |
| const ScopedThreadToken scoped_thread_token(&thread_token_provider); |
| const int thread_id = scoped_thread_token.token(); |
| for (int i = start; i < end; ++i) { |
| function(thread_id, i); |
| } |
| return; |
| } |
| |
| // We use a shared_ptr because the main thread can finish all the work before |
| // the tasks have been popped off the queue. So the shared state needs to |
| // exist for the duration of all the tasks. |
| const int num_work_items = std::min((end - start), num_threads); |
| shared_ptr<SharedState> shared_state( |
| new SharedState(start, end, num_work_items)); |
| |
| // A function which tries to perform a chunk of work. This returns false if |
| // there is no work to be done. |
| auto task_function = [shared_state, &function]() { |
| int i = 0; |
| { |
| // Get the next available chunk of work to be performed. If there is no |
| // work, return false. |
| std::unique_lock<std::mutex> lock(shared_state->mutex_i); |
| if (shared_state->i >= shared_state->num_work_items) { |
| return false; |
| } |
| i = shared_state->i; |
| ++shared_state->i; |
| } |
| |
| const ScopedThreadToken scoped_thread_token( |
| &shared_state->thread_token_provider); |
| const int thread_id = scoped_thread_token.token(); |
| |
| // Perform each task. |
| for (int j = shared_state->start + i; |
| j < shared_state->end; |
| j += shared_state->num_work_items) { |
| function(thread_id, j); |
| } |
| shared_state->block_until_finished.Finished(); |
| return true; |
| }; |
| |
| // Add all the tasks to the thread pool. |
| for (int i = 0; i < num_work_items; ++i) { |
| // Note we are taking the task_function as value so the shared_state |
| // shared pointer is copied and the ref count is increased. This is to |
| // prevent it from being deleted when the main thread finishes all the |
| // work and exits before the threads finish. |
| context->thread_pool.AddTask([task_function]() { task_function(); }); |
| } |
| |
| // Try to do any available work on the main thread. This may steal work from |
| // the thread pool, but when there is no work left the thread pool tasks |
| // will be no-ops. |
| while (task_function()) { |
| } |
| |
| // Wait until all tasks have finished. |
| shared_state->block_until_finished.Block(); |
| } |
| |
| } // namespace internal |
| } // namespace ceres |
| |
| #endif // CERES_USE_CXX11_THREADS |