Adds a ThreadPool and a thread-safe ConcurrentQueue. This is in preparation for adding support for a C++11 based parallel for implementation. The code is behind CERES_USE_CXX11_THREADS which is not exposed to the user yet. Tested by building with and without CERES_USE_CXX11_THREADS defined and the tests pass. Change-Id: I60f5730fa055feeb0ee0fa6c980633aebd8d87b4
diff --git a/internal/ceres/CMakeLists.txt b/internal/ceres/CMakeLists.txt index 91ffc11..4080ea7 100644 --- a/internal/ceres/CMakeLists.txt +++ b/internal/ceres/CMakeLists.txt
@@ -115,6 +115,7 @@ split.cc stringprintf.cc suitesparse.cc + thread_pool.cc thread_token_provider.cc triplet_sparse_matrix.cc trust_region_preprocessor.cc @@ -304,6 +305,7 @@ ceres_test(canonical_views_clustering) ceres_test(compressed_col_sparse_matrix_utils) ceres_test(compressed_row_sparse_matrix) + ceres_test(concurrent_queue) ceres_test(conditioned_cost_function) ceres_test(conjugate_gradients_solver) ceres_test(corrector) @@ -365,6 +367,7 @@ ceres_test(tiny_solver) ceres_test(tiny_solver_autodiff_function) ceres_test(tiny_solver_cost_function_adapter) + ceres_test(thread_pool) ceres_test(triplet_sparse_matrix) ceres_test(trust_region_minimizer) ceres_test(trust_region_preprocessor)
diff --git a/internal/ceres/concurrent_queue.h b/internal/ceres/concurrent_queue.h new file mode 100644 index 0000000..c4e076f --- /dev/null +++ b/internal/ceres/concurrent_queue.h
@@ -0,0 +1,159 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2018 Google Inc. All rights reserved. +// http://ceres-solver.org/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. 
+//
+// Author: vitus@google.com (Michael Vitus)
+
+#ifndef CERES_INTERNAL_CONCURRENT_QUEUE_H_
+#define CERES_INTERNAL_CONCURRENT_QUEUE_H_
+
+#include <condition_variable>
+#include <mutex>
+#include <queue>
+#include <thread>
+#include <utility>
+
+#include "glog/logging.h"
+
+namespace ceres {
+namespace internal {
+
+// A thread-safe multi-producer, multi-consumer queue for queueing items that
+// are typically handled asynchronously by multiple threads. The ConcurrentQueue
+// has two states which only affect the Wait call:
+//
+//  (1) Waiters have been enabled (enabled by default or calling
+//      EnableWaiters). The call to Wait will block until an item is available.
+//      Push and pop will operate as expected.
+//
+//  (2) StopWaiters has been called. All threads blocked in a Wait() call will
+//      be woken up and pop any available items from the queue. All future Wait
+//      requests will either return an element from the queue or return
+//      immediately if no element is present. Push and pop will operate as
+//      expected.
+//
+// A common use case is using the concurrent queue as an interface for
+// scheduling tasks for a set of thread workers:
+//
+//  ConcurrentQueue<Task> task_queue;
+//
+//  [Worker threads]:
+//    Task task;
+//    while(task_queue.Wait(&task)) {
+//      ...
+//    }
+//
+//  [Producers]:
+//    task_queue.Push(...);
+//    ..
+//    task_queue.Push(...);
+//    ...
+//    // Signal worker threads to stop blocking on Wait and terminate.
+//    task_queue.StopWaiters();
+//
+template <typename T>
+class ConcurrentQueue {
+ public:
+  // Defaults the queue to blocking on Wait calls.
+  ConcurrentQueue() : wait_(true) {}
+
+  // Atomically push an element onto the queue. If a thread was waiting for an
+  // element, wake it up.
+  void Push(const T& value) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    queue_.push(value);
+    work_pending_condition_.notify_one();
+  }
+
+  // Atomically pop an element from the queue. If an element is present, return
+  // true. If the queue was empty, return false. Never blocks waiting for an
+  // element, regardless of whether waiters are enabled.
+  bool Pop(T* value) {
+    CHECK(value != nullptr);
+
+    std::lock_guard<std::mutex> lock(mutex_);
+    return PopUnlocked(value);
+  }
+
+  // Atomically pop an element from the queue. Blocks until one is available or
+  // StopWaiters is called. Returns true if an element was successfully popped
+  // from the queue, otherwise returns false.
+  bool Wait(T* value) {
+    CHECK(value != nullptr);
+
+    // A unique_lock is required here because the condition variable releases
+    // and re-acquires the mutex while blocked.
+    std::unique_lock<std::mutex> lock(mutex_);
+    work_pending_condition_.wait(lock,
+                                 [&]() { return !(wait_ && queue_.empty()); });
+
+    return PopUnlocked(value);
+  }
+
+  // Unblock all threads waiting to pop a value from the queue. A woken thread
+  // pops an element if one is available, otherwise it exits Wait() without
+  // getting a value. All future Wait requests will return immediately if no
+  // element is present until EnableWaiters is called.
+  void StopWaiters() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    wait_ = false;
+    work_pending_condition_.notify_all();
+  }
+
+  // Enable threads to block on Wait calls.
+  void EnableWaiters() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    wait_ = true;
+  }
+
+ private:
+  // Pops an element from the queue. If an element is present, return
+  // true. If the queue was empty, return false. Not thread-safe. Must acquire
+  // the lock before calling.
+  bool PopUnlocked(T* value) {
+    if (queue_.empty()) {
+      return false;
+    }
+
+    // The front element is discarded by pop() immediately afterwards, so move
+    // it out instead of copying it.
+    *value = std::move(queue_.front());
+    queue_.pop();
+
+    return true;
+  }
+
+  // The mutex controls read and write access to the queue_ and wait_
+  // variables. It is also used to block the calling thread until an element is
+  // available to pop from the queue.
+  std::mutex mutex_;
+  std::condition_variable work_pending_condition_;
+
+  std::queue<T> queue_;
+  // If true, signals that callers of Wait will block waiting to pop an
+  // element off the queue.
+  bool wait_;
+};
+
+}  // namespace internal
+}  // namespace ceres
+
+#endif  // CERES_INTERNAL_CONCURRENT_QUEUE_H_
diff --git a/internal/ceres/concurrent_queue_test.cc b/internal/ceres/concurrent_queue_test.cc new file mode 100644 index 0000000..3b15c4b --- /dev/null +++ b/internal/ceres/concurrent_queue_test.cc
@@ -0,0 +1,307 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2018 Google Inc. All rights reserved. +// http://ceres-solver.org/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Author: vitus@google.com (Michael Vitus) + +// This include must come before any #ifndef check on Ceres compile options. 
+#include "ceres/internal/port.h"
+
+#ifdef CERES_USE_CXX11_THREADS
+
+#include <chrono>
+#include <mutex>
+#include <thread>
+
+#include "ceres/concurrent_queue.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace ceres {
+namespace internal {
+
+// A basic test of push and pop.
+TEST(ConcurrentQueue, PushPop) {
+  ConcurrentQueue<int> queue;
+
+  const int num_to_add = 10;
+  for (int i = 0; i < num_to_add; ++i) {
+    queue.Push(i);
+  }
+
+  // Elements must come back in the order they were pushed.
+  for (int i = 0; i < num_to_add; ++i) {
+    int value;
+    ASSERT_TRUE(queue.Pop(&value));
+    EXPECT_EQ(i, value);
+  }
+}
+
+// Push and pop elements from the queue after StopWaiters has been called.
+TEST(ConcurrentQueue, PushPopAfterStopWaiters) {
+  ConcurrentQueue<int> queue;
+
+  const int num_to_add = 10;
+  int value;
+
+  // Pop should return immediately with false with an empty queue.
+  ASSERT_FALSE(queue.Pop(&value));
+
+  for (int i = 0; i < num_to_add; ++i) {
+    queue.Push(i);
+  }
+
+  // Call stop waiters to ensure we can still Push and Pop from the queue.
+  queue.StopWaiters();
+
+  for (int i = 0; i < num_to_add; ++i) {
+    ASSERT_TRUE(queue.Pop(&value));
+    EXPECT_EQ(i, value);
+  }
+
+  // Pop should return immediately with false with an empty queue.
+  ASSERT_FALSE(queue.Pop(&value));
+
+  // Ensure we can still push onto the queue after StopWaiters has been called.
+  const int offset = 123;
+  for (int i = 0; i < num_to_add; ++i) {
+    queue.Push(i + offset);
+  }
+
+  // Reuse the outer value; a second declaration here would only shadow it.
+  for (int i = 0; i < num_to_add; ++i) {
+    ASSERT_TRUE(queue.Pop(&value));
+    EXPECT_EQ(i + offset, value);
+  }
+
+  // Pop should return immediately with false with an empty queue.
+  ASSERT_FALSE(queue.Pop(&value));
+
+  // Try calling StopWaiters again to ensure nothing changes.
+  queue.StopWaiters();
+
+  queue.Push(13456);
+  ASSERT_TRUE(queue.Pop(&value));
+  EXPECT_EQ(13456, value);
+}
+
+// Push and pop elements after StopWaiters and EnableWaiters has been called.
+TEST(ConcurrentQueue, PushPopStopAndStart) {
+  ConcurrentQueue<int> queue;
+
+  int value;
+
+  queue.Push(13456);
+  queue.Push(256);
+
+  queue.StopWaiters();
+
+  ASSERT_TRUE(queue.Pop(&value));
+  EXPECT_EQ(13456, value);
+
+  queue.EnableWaiters();
+
+  // Try adding another entry after enable has been called.
+  queue.Push(989);
+
+  // Ensure we can pop both elements off.
+  ASSERT_TRUE(queue.Pop(&value));
+  EXPECT_EQ(256, value);
+
+  ASSERT_TRUE(queue.Pop(&value));
+  EXPECT_EQ(989, value);
+
+  // Re-enable waiting.
+  queue.EnableWaiters();
+
+  // Pop should return immediately with false with an empty queue.
+  ASSERT_FALSE(queue.Pop(&value));
+}
+
+// A basic test for Wait.
+TEST(ConcurrentQueue, Wait) {
+  ConcurrentQueue<int> queue;
+
+  int value;
+
+  queue.Push(13456);
+
+  ASSERT_TRUE(queue.Wait(&value));
+  EXPECT_EQ(13456, value);
+
+  queue.StopWaiters();
+
+  // Ensure waiting returns immediately after StopWaiters.
+  EXPECT_FALSE(queue.Wait(&value));
+  EXPECT_FALSE(queue.Wait(&value));
+
+  EXPECT_FALSE(queue.Pop(&value));
+
+  // Calling StopWaiters multiple times does not change anything.
+  queue.StopWaiters();
+
+  EXPECT_FALSE(queue.Wait(&value));
+  EXPECT_FALSE(queue.Wait(&value));
+
+  queue.Push(989);
+  queue.Push(789);
+
+  ASSERT_TRUE(queue.Wait(&value));
+  EXPECT_EQ(989, value);
+
+  ASSERT_TRUE(queue.Wait(&value));
+  EXPECT_EQ(789, value);
+}
+
+// Ensure a thread calling Wait blocks until an element is pushed onto the
+// queue, and that it then receives that element.
+TEST(ConcurrentQueue, EnsureWaitBlocks) {
+  ConcurrentQueue<int> queue;
+
+  int value = 0;
+  bool valid_value = false;
+  bool waiting = false;
+  std::mutex mutex;
+
+  std::thread thread([&]() {
+    {
+      std::unique_lock<std::mutex> lock(mutex);
+      waiting = true;
+    }
+
+    int element = 87987;
+    bool valid = queue.Wait(&element);
+
+    {
+      std::unique_lock<std::mutex> lock(mutex);
+      waiting = false;
+      value = element;
+      valid_value = valid;
+    }
+  });
+
+  // Give the thread time to start and wait.
+  std::this_thread::sleep_for(std::chrono::milliseconds(500));
+
+  // Ensure nothing has been popped off the queue. EXPECT_* is used rather than
+  // ASSERT_* because an early return here would destroy a joinable
+  // std::thread, which calls std::terminate.
+  {
+    std::unique_lock<std::mutex> lock(mutex);
+    EXPECT_TRUE(waiting);
+    EXPECT_FALSE(valid_value);
+    EXPECT_EQ(0, value);
+  }
+
+  queue.Push(13456);
+
+  // Wait for the thread to pop the value.
+  thread.join();
+
+  EXPECT_TRUE(valid_value);
+  EXPECT_EQ(13456, value);
+}
+
+// Ensure StopWaiters unblocks a thread that is blocked in Wait without handing
+// it a value, that Wait does not block while waiters are stopped, and that
+// EnableWaiters makes Wait block again until an element is pushed.
+TEST(ConcurrentQueue, StopAndEnableWaiters) {
+  ConcurrentQueue<int> queue;
+
+  int value = 0;
+  bool valid_value = false;
+  bool waiting = false;
+  std::mutex mutex;
+
+  auto task = [&]() {
+    {
+      std::unique_lock<std::mutex> lock(mutex);
+      waiting = true;
+    }
+
+    int element = 87987;
+    bool valid = queue.Wait(&element);
+
+    {
+      std::unique_lock<std::mutex> lock(mutex);
+      waiting = false;
+      value = element;
+      valid_value = valid;
+    }
+  };
+
+  std::thread thread_1(task);
+
+  // Give the thread time to start and wait.
+  std::this_thread::sleep_for(std::chrono::milliseconds(500));
+
+  // Ensure the thread is waiting.
+  {
+    std::unique_lock<std::mutex> lock(mutex);
+    EXPECT_TRUE(waiting);
+  }
+
+  // Unblock the thread.
+  queue.StopWaiters();
+
+  thread_1.join();
+
+  // Ensure nothing has been popped off the queue.
+  EXPECT_FALSE(valid_value);
+  EXPECT_EQ(87987, value);
+
+  // Ensure another call to Wait returns immediately.
+  EXPECT_FALSE(queue.Wait(&value));
+
+  queue.EnableWaiters();
+
+  value = 0;
+  valid_value = false;
+  waiting = false;
+
+  // Start another task waiting for an element to be pushed.
+  std::thread thread_2(task);
+
+  // Give the thread time to start and wait.
+  std::this_thread::sleep_for(std::chrono::milliseconds(500));
+
+  // Ensure nothing is popped off the queue. EXPECT_* is used rather than
+  // ASSERT_* so that a failure cannot return while thread_2 is still joinable.
+  {
+    std::unique_lock<std::mutex> lock(mutex);
+    EXPECT_TRUE(waiting);
+    EXPECT_FALSE(valid_value);
+    EXPECT_EQ(0, value);
+  }
+
+  queue.Push(13456);
+
+  // Wait for the thread to pop the value.
+  thread_2.join();
+
+  EXPECT_TRUE(valid_value);
+  EXPECT_EQ(13456, value);
+}
+
+}  // namespace internal
+}  // namespace ceres
+
+#endif  // CERES_USE_CXX11_THREADS
diff --git a/internal/ceres/thread_pool.cc b/internal/ceres/thread_pool.cc new file mode 100644 index 0000000..9c7bb89 --- /dev/null +++ b/internal/ceres/thread_pool.cc
@@ -0,0 +1,113 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2018 Google Inc. All rights reserved. +// http://ceres-solver.org/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Author: vitus@google.com (Michael Vitus) + +// This include must come before any #ifndef check on Ceres compile options. 
+#include "ceres/internal/port.h"
+
+#ifdef CERES_USE_CXX11_THREADS
+
+#include "ceres/thread_pool.h"
+
+#include <algorithm>
+#include <cmath>
+
+namespace ceres {
+namespace internal {
+namespace {
+
+// Constrain the total number of threads to the amount the hardware can support.
+int GetNumAllowedThreads(int requested_num_threads) {
+  // hardware_concurrency() returns an unsigned value; convert it explicitly so
+  // the comparison below is between two ints.
+  const int num_hardware_threads =
+      static_cast<int>(std::thread::hardware_concurrency());
+  // hardware_concurrency() can return 0 if the value is not well defined or not
+  // computable.
+  if (num_hardware_threads == 0) {
+    return requested_num_threads;
+  }
+
+  return std::min(requested_num_threads, num_hardware_threads);
+}
+
+}  // namespace
+
+// Creates a pool with no worker threads.
+ThreadPool::ThreadPool() { }
+
+// Creates a pool and immediately grows it towards num_threads workers.
+ThreadPool::ThreadPool(int num_threads) {
+  Resize(num_threads);
+}
+
+ThreadPool::~ThreadPool() {
+  std::unique_lock<std::mutex> lock(thread_pool_mutex_);
+  // Signal the thread workers to stop and wait for them to finish all scheduled
+  // tasks.
+  Stop();
+  for (std::thread& thread : thread_pool_) {
+    thread.join();
+  }
+}
+
+// Grows the pool; a request that is not larger than the current size is a
+// no-op.
+void ThreadPool::Resize(int num_threads) {
+  std::unique_lock<std::mutex> lock(thread_pool_mutex_);
+
+  const int num_current_threads = static_cast<int>(thread_pool_.size());
+  if (num_current_threads >= num_threads) {
+    return;
+  }
+
+  // The hardware cap may leave nothing to create, in which case the loop below
+  // does not execute.
+  const int create_num_threads =
+      GetNumAllowedThreads(num_threads) - num_current_threads;
+
+  for (int i = 0; i < create_num_threads; ++i) {
+    thread_pool_.emplace_back(&ThreadPool::ThreadMainLoop, this);
+  }
+}
+
+// Queues func; the task queue provides its own synchronization.
+void ThreadPool::AddTask(const std::function<void()>& func) {
+  task_queue_.Push(func);
+}
+
+int ThreadPool::Size() {
+  std::unique_lock<std::mutex> lock(thread_pool_mutex_);
+  return static_cast<int>(thread_pool_.size());
+}
+
+// Runs queued tasks until Wait reports that no task was obtained.
+void ThreadPool::ThreadMainLoop() {
+  std::function<void()> task;
+  while (task_queue_.Wait(&task)) {
+    task();
+  }
+}
+
+void ThreadPool::Stop() {
+  task_queue_.StopWaiters();
+}
+
+}  // namespace internal
+}  // namespace ceres
+
+#endif  // CERES_USE_CXX11_THREADS
diff --git a/internal/ceres/thread_pool.h b/internal/ceres/thread_pool.h new file mode 100644 index 0000000..d596ecd --- /dev/null +++ b/internal/ceres/thread_pool.h
@@ -0,0 +1,116 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2018 Google Inc. All rights reserved. +// http://ceres-solver.org/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. 
+//
+// Author: vitus@google.com (Michael Vitus)
+
+#ifndef CERES_INTERNAL_THREAD_POOL_H_
+#define CERES_INTERNAL_THREAD_POOL_H_
+
+#include <functional>
+#include <mutex>
+#include <thread>
+#include <vector>
+
+#include "ceres/concurrent_queue.h"
+
+namespace ceres {
+namespace internal {
+
+// A thread-safe thread pool with an unbounded task queue and a resizable number
+// of workers. The size of the thread pool can be increased but never decreased
+// in order to support the largest number of threads requested. The ThreadPool
+// has three states:
+//
+//  (1) The thread pool size is zero. Tasks may be added to the thread pool via
+//  AddTask but they will not be executed until the thread pool is resized.
+//
+//  (2) The thread pool size is greater than zero. Tasks may be added to the
+//  thread pool and will be executed as soon as a worker is available. The
+//  thread pool may be resized while the thread pool is running.
+//
+//  (3) The thread pool is destructing. The thread pool will signal all the
+//  workers to stop. The workers will finish all of the tasks that have already
+//  been added to the thread pool.
+//
+class ThreadPool {
+ public:
+  // Default constructor with no active threads. We allow instantiating a
+  // thread pool with no threads to support the use case of single threaded
+  // Ceres where everything will be executed on the main thread. For single
+  // threaded execution this has two benefits: avoid any overhead as threads
+  // are expensive to create, and no unused threads shown in the debugger.
+  ThreadPool();
+
+  // Instantiates a thread pool with min(num_hardware_threads, num_threads)
+  // number of threads.
+  explicit ThreadPool(int num_threads);
+
+  // Signals the workers to stop and waits for them to finish any tasks that
+  // have been scheduled.
+  ~ThreadPool();
+
+  // Resizes the thread pool if it is currently less than the requested number
+  // of threads. The thread pool will be resized to min(num_hardware_threads,
+  // num_threads) number of threads. Resize does not support reducing the
+  // thread pool size. If a smaller number of threads is requested, the thread
+  // pool remains the same size. The thread pool is reused within Ceres with
+  // different number of threads, and we need to ensure we can support the
+  // largest number of threads requested. It is safe to resize the thread pool
+  // while the workers are executing tasks, and the resizing is guaranteed to
+  // complete upon return.
+  void Resize(int num_threads);
+
+  // Adds a task to the queue and wakes up a blocked thread. If the thread pool
+  // size is greater than zero, then the task will be executed by a currently
+  // idle thread or when a thread becomes available. If the thread pool has no
+  // threads, then the task will never be executed and the user should use
+  // Resize() to create a non-empty thread pool.
+  void AddTask(const std::function<void()>& func);
+
+  // Returns the current size of the thread pool.
+  int Size();
+
+ private:
+  // Main loop for the threads which blocks on the task queue until work becomes
+  // available. It will return if and only if Stop has been called.
+  void ThreadMainLoop();
+
+  // Signal all the threads to stop. It does not block until the threads are
+  // finished.
+  void Stop();
+
+  // The queue that stores the units of work available for the thread pool. The
+  // task queue maintains its own thread safety.
+  ConcurrentQueue<std::function<void()>> task_queue_;
+  // The worker threads. Access is guarded by thread_pool_mutex_.
+  std::vector<std::thread> thread_pool_;
+  std::mutex thread_pool_mutex_;
+};
+
+}  // namespace internal
+}  // namespace ceres
+
+#endif  // CERES_INTERNAL_THREAD_POOL_H_
diff --git a/internal/ceres/thread_pool_test.cc b/internal/ceres/thread_pool_test.cc new file mode 100644 index 0000000..1aa81f2 --- /dev/null +++ b/internal/ceres/thread_pool_test.cc
@@ -0,0 +1,197 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2018 Google Inc. All rights reserved. +// http://ceres-solver.org/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Author: vitus@google.com (Michael Vitus) + +// This include must come before any #ifndef check on Ceres compile options. 
+#include "ceres/internal/port.h"
+
+#ifdef CERES_USE_CXX11_THREADS
+
+#include "ceres/thread_pool.h"
+
+#include <chrono>
+#include <condition_variable>
+#include <mutex>
+#include <thread>
+
+#include "glog/logging.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace ceres {
+namespace internal {
+
+// Adds a number of tasks to the thread pool and ensures they all run.
+TEST(ThreadPool, AddTask) {
+  int value = 0;
+  const int num_tasks = 100;
+  {
+    // The synchronization objects are declared before the pool so that they
+    // are destroyed only after the pool's destructor has joined its workers.
+    std::condition_variable condition;
+    std::mutex mutex;
+
+    ThreadPool thread_pool(2);
+
+    for (int i = 0; i < num_tasks; ++i) {
+      thread_pool.AddTask([&]() {
+          std::unique_lock<std::mutex> lock(mutex);
+          ++value;
+          condition.notify_all();
+        });
+    }
+
+    std::unique_lock<std::mutex> lock(mutex);
+    condition.wait(lock, [&](){return value == num_tasks;});
+  }
+
+  EXPECT_EQ(num_tasks, value);
+}
+
+// Adds a number of tasks to the queue and resizes the thread pool while the
+// threads are executing their work.
+TEST(ThreadPool, ResizingDuringExecution) {
+  int value = 0;
+
+  const int num_tasks = 100;
+
+  // Run this test in a scope so that the thread pool is deleted and all of the
+  // threads are stopped before the final check.
+  {
+    // The synchronization objects are declared before the pool so that they
+    // are destroyed only after the pool's destructor has joined its workers.
+    std::condition_variable condition;
+    std::mutex mutex;
+
+    ThreadPool thread_pool(/*num_threads=*/2);
+
+    // Acquire a lock on the mutex to prevent the threads from finishing their
+    // execution so we can test resizing the thread pool while the workers are
+    // executing a task.
+    std::unique_lock<std::mutex> lock(mutex);
+
+    // The same task for all of the workers to execute.
+    auto task = [&]() {
+      // This will block until the mutex is released inside the condition
+      // variable.
+      std::unique_lock<std::mutex> task_lock(mutex);
+      ++value;
+      condition.notify_all();
+    };
+
+    // Add the initial set of tasks to run.
+    for (int i = 0; i < num_tasks / 2; ++i) {
+      thread_pool.AddTask(task);
+    }
+
+    // Resize the thread pool while tasks are executing.
+    thread_pool.Resize(/*num_threads=*/3);
+
+    // Add more tasks to the thread pool to guarantee these are also completed.
+    for (int i = 0; i < num_tasks / 2; ++i) {
+      thread_pool.AddTask(task);
+    }
+
+    // Unlock the mutex to unblock all of the threads and wait until all of the
+    // tasks are completed.
+    condition.wait(lock, [&](){return value == num_tasks;});
+  }
+
+  EXPECT_EQ(num_tasks, value);
+}
+
+// Tests the destructor will wait until all running tasks are finished before
+// destructing the thread pool.
+TEST(ThreadPool, Destructor) {
+  // Ensure the hardware supports more than 1 thread to ensure the test will
+  // pass.
+  const int num_hardware_threads = std::thread::hardware_concurrency();
+  if (num_hardware_threads <= 1) {
+    LOG(ERROR)
+        << "Test not supported, the hardware does not support threading.";
+    return;
+  }
+
+  std::condition_variable condition;
+  std::mutex mutex;
+  // Lock the mutex to ensure the tasks are blocked.
+  std::unique_lock<std::mutex> master_lock(mutex);
+  int value = 0;
+
+  // Create a thread that will instantiate and delete the thread pool.  This is
+  // required because we need to block on the thread pool being deleted and
+  // signal the tasks to finish.
+  std::thread thread([&]() {
+    ThreadPool thread_pool(/*num_threads=*/2);
+
+    for (int i = 0; i < 100; ++i) {
+      thread_pool.AddTask([&]() {
+        // This will block until the mutex is released inside the condition
+        // variable.
+        std::unique_lock<std::mutex> lock(mutex);
+        ++value;
+        condition.notify_all();
+      });
+    }
+    // The thread pool should be deleted.
+  });
+
+  // Give the thread pool time to start, add all the tasks, and then delete
+  // itself.
+  std::this_thread::sleep_for(std::chrono::milliseconds(500));
+
+  // Unlock the tasks.
+  master_lock.unlock();
+
+  // Wait for the thread to complete.
+  thread.join();
+
+  EXPECT_EQ(100, value);
+}
+
+// Tests that Resize grows the thread pool and that a request for fewer threads
+// leaves the size unchanged.
+TEST(ThreadPool, Resize) {
+  // Ensure the hardware supports more than 1 thread to ensure the test will
+  // pass.
+  const int num_hardware_threads = std::thread::hardware_concurrency();
+  if (num_hardware_threads <= 1) {
+    LOG(ERROR)
+        << "Test not supported, the hardware does not support threading.";
+    return;
+  }
+
+  ThreadPool thread_pool(1);
+
+  EXPECT_EQ(1, thread_pool.Size());
+
+  thread_pool.Resize(2);
+
+  EXPECT_EQ(2, thread_pool.Size());
+
+  // Try reducing the thread pool size and verify it stays the same size.
+  thread_pool.Resize(1);
+  EXPECT_EQ(2, thread_pool.Size());
+}
+
+}  // namespace internal
+}  // namespace ceres
+
+#endif  // CERES_USE_CXX11_THREADS