blob: 11db1fbc4886c5146bf2c45dea71039a6c7556cd [file] [log] [blame]
Mike Vitusdc5ea0e2018-01-24 15:53:19 -08001// Ceres Solver - A fast non-linear least squares minimizer
Sameer Agarwal5a30cae2023-09-19 15:29:34 -07002// Copyright 2023 Google Inc. All rights reserved.
Mike Vitusdc5ea0e2018-01-24 15:53:19 -08003// http://ceres-solver.org/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7//
8// * Redistributions of source code must retain the above copyright notice,
9// this list of conditions and the following disclaimer.
10// * Redistributions in binary form must reproduce the above copyright notice,
11// this list of conditions and the following disclaimer in the documentation
12// and/or other materials provided with the distribution.
13// * Neither the name of Google Inc. nor the names of its contributors may be
14// used to endorse or promote products derived from this software without
15// specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
28//
Dmitriy Korchemkinc0c4f932022-08-18 22:10:17 +030029// Authors: vitus@google.com (Michael Vitus),
30// dmitriy.korchemkin@gmail.com (Dmitriy Korchemkin)
Mike Vitusdc5ea0e2018-01-24 15:53:19 -080031
Sameer Agarwal47051592022-03-12 15:22:19 -080032#ifndef CERES_INTERNAL_PARALLEL_FOR_H_
33#define CERES_INTERNAL_PARALLEL_FOR_H_
Mike Vitusdc5ea0e2018-01-24 15:53:19 -080034
Sameer Agarwal9a289472022-09-20 09:50:10 -070035#include <mutex>
Dmitriy Korchemkin54ad3dd2022-12-19 18:24:54 +030036#include <vector>
Mike Vitusdc5ea0e2018-01-24 15:53:19 -080037
Mike Vitusf408f892018-02-22 10:28:39 -080038#include "ceres/context_impl.h"
Dmitriy Korchemkinb1585152022-11-27 21:35:44 +030039#include "ceres/internal/eigen.h"
Sergiu Deitschf90833f2022-02-07 23:43:19 +010040#include "ceres/internal/export.h"
Dmitriy Korchemkin54ad3dd2022-12-19 18:24:54 +030041#include "ceres/parallel_invoke.h"
42#include "ceres/partition_range_for_parallel_for.h"
Dmitriy Korchemkinc0c4f932022-08-18 22:10:17 +030043#include "glog/logging.h"
Mike Vitusf408f892018-02-22 10:28:39 -080044
Sameer Agarwalcaf614a2022-04-21 17:41:10 -070045namespace ceres::internal {
Mike Vitusdc5ea0e2018-01-24 15:53:19 -080046
// Returns a std::unique_lock over mutex m, except in the single-threaded
// case (num_threads == 1) where no synchronization is needed and a
// default-constructed (mutex-free, unlocked) lock object is returned
// instead.
inline decltype(auto) MakeConditionalLock(const int num_threads,
                                          std::mutex& m) {
  if (num_threads == 1) {
    // No concurrent access is possible; skip locking entirely.
    return std::unique_lock<std::mutex>{};
  }
  return std::unique_lock<std::mutex>{m};
}
53
Mike Vitusf0c3b232018-02-28 13:08:48 -080054// Execute the function for every element in the range [start, end) with at most
55// num_threads. It will execute all the work on the calling thread if
Dmitriy Korchemkinc0c4f932022-08-18 22:10:17 +030056// num_threads or (end - start) is equal to 1.
Dmitriy Korchemkin54ad3dd2022-12-19 18:24:54 +030057// Depending on function signature, it will be supplied with either loop index
58// or a range of loop indicies; function can also be supplied with thread_id.
59// The following function signatures are supported:
60// - Functions accepting a single loop index:
61// - [](int index) { ... }
62// - [](int thread_id, int index) { ... }
63// - Functions accepting a range of loop index:
64// - [](std::tuple<int, int> index) { ... }
65// - [](int thread_id, std::tuple<int, int> index) { ... }
Dmitriy Korchemkinc0c4f932022-08-18 22:10:17 +030066//
Dmitriy Korchemkin54ad3dd2022-12-19 18:24:54 +030067// When distributing workload between threads, it is assumed that each loop
68// iteration takes approximately equal time to complete.
Dmitriy Korchemkinc0c4f932022-08-18 22:10:17 +030069template <typename F>
Dmitriy Korchemkindc7a8592023-10-06 15:55:21 +000070void ParallelFor(ContextImpl* context,
71 int start,
72 int end,
73 int num_threads,
74 F&& function,
75 int min_block_size = 1) {
Dmitriy Korchemkinc0c4f932022-08-18 22:10:17 +030076 CHECK_GT(num_threads, 0);
77 if (start >= end) {
78 return;
79 }
80
Dmitriy Korchemkindc7a8592023-10-06 15:55:21 +000081 if (num_threads == 1 || end - start < min_block_size * 2) {
Dmitriy Korchemkin54ad3dd2022-12-19 18:24:54 +030082 InvokeOnSegment(0, std::make_tuple(start, end), std::forward<F>(function));
Dmitriy Korchemkinc0c4f932022-08-18 22:10:17 +030083 return;
84 }
85
86 CHECK(context != nullptr);
Dmitriy Korchemkindc7a8592023-10-06 15:55:21 +000087 ParallelInvoke(context,
88 start,
89 end,
90 num_threads,
91 std::forward<F>(function),
92 min_block_size);
Dmitriy Korchemkin54ad3dd2022-12-19 18:24:54 +030093}
94
95// Execute function for every element in the range [start, end) with at most
96// num_threads, using user-provided partitions array.
97// When distributing workload between threads, it is assumed that each segment
98// bounded by adjacent elements of partitions array takes approximately equal
99// time to process.
100template <typename F>
101void ParallelFor(ContextImpl* context,
102 int start,
103 int end,
104 int num_threads,
105 F&& function,
106 const std::vector<int>& partitions) {
107 CHECK_GT(num_threads, 0);
108 if (start >= end) {
109 return;
110 }
111 CHECK_EQ(partitions.front(), start);
112 CHECK_EQ(partitions.back(), end);
113 if (num_threads == 1 || end - start <= num_threads) {
114 ParallelFor(context, start, end, num_threads, std::forward<F>(function));
115 return;
116 }
117 CHECK_GT(partitions.size(), 1);
118 const int num_partitions = partitions.size() - 1;
119 ParallelFor(context,
120 0,
121 num_partitions,
122 num_threads,
123 [&function, &partitions](int thread_id,
124 std::tuple<int, int> partition_ids) {
125 // partition_ids is a range of partition indices
126 const auto [partition_start, partition_end] = partition_ids;
127 // Execution over several adjacent segments is equivalent
128 // to execution over union of those segments (which is also a
129 // contiguous segment)
130 const int range_start = partitions[partition_start];
131 const int range_end = partitions[partition_end];
132 // Range of original loop indices
133 const auto range = std::make_tuple(range_start, range_end);
134 InvokeOnSegment(thread_id, range, function);
135 });
Dmitriy Korchemkinc0c4f932022-08-18 22:10:17 +0300136}
Dmitriy Korchemkin5d53d1e2022-11-02 16:06:48 +0300137
138// Execute function for every element in the range [start, end) with at most
139// num_threads, taking into account user-provided integer cumulative costs of
140// iterations. Cumulative costs of iteration for indices in range [0, end) are
141// stored in objects from cumulative_cost_data. User-provided
142// cumulative_cost_fun returns non-decreasing integer values corresponding to
143// inclusive cumulative cost of loop iterations, provided with a reference to
144// user-defined object. Only indices from [start, end) will be referenced. This
145// routine assumes that cumulative_cost_fun is non-decreasing (in other words,
146// all costs are non-negative);
Dmitriy Korchemkin54ad3dd2022-12-19 18:24:54 +0300147// When distributing workload between threads, input range of loop indices will
148// be partitioned into disjoint contiguous intervals, with the maximal cost
149// being minimized.
150// For example, with iteration costs of [1, 1, 5, 3, 1, 4] cumulative_cost_fun
151// should return [1, 2, 7, 10, 11, 15], and with num_threads = 4 this range
152// will be split into segments [0, 2) [2, 3) [3, 5) [5, 6) with costs
153// [2, 5, 4, 4].
Dmitriy Korchemkin5d53d1e2022-11-02 16:06:48 +0300154template <typename F, typename CumulativeCostData, typename CumulativeCostFun>
155void ParallelFor(ContextImpl* context,
156 int start,
157 int end,
158 int num_threads,
Dmitriy Korchemkin54ad3dd2022-12-19 18:24:54 +0300159 F&& function,
Dmitriy Korchemkin5d53d1e2022-11-02 16:06:48 +0300160 const CumulativeCostData* cumulative_cost_data,
Dmitriy Korchemkin54ad3dd2022-12-19 18:24:54 +0300161 CumulativeCostFun&& cumulative_cost_fun) {
Dmitriy Korchemkin5d53d1e2022-11-02 16:06:48 +0300162 CHECK_GT(num_threads, 0);
163 if (start >= end) {
164 return;
165 }
166 if (num_threads == 1 || end - start <= num_threads) {
Dmitriy Korchemkin54ad3dd2022-12-19 18:24:54 +0300167 ParallelFor(context, start, end, num_threads, std::forward<F>(function));
Dmitriy Korchemkin5d53d1e2022-11-02 16:06:48 +0300168 return;
169 }
170 // Creating several partitions allows us to tolerate imperfections of
171 // partitioning and user-supplied iteration costs up to a certain extent
Dmitriy Korchemkin54ad3dd2022-12-19 18:24:54 +0300172 constexpr int kNumPartitionsPerThread = 4;
Dmitriy Korchemkin5d53d1e2022-11-02 16:06:48 +0300173 const int kMaxPartitions = num_threads * kNumPartitionsPerThread;
Dmitriy Korchemkin54ad3dd2022-12-19 18:24:54 +0300174 const auto& partitions = PartitionRangeForParallelFor(
175 start,
176 end,
177 kMaxPartitions,
178 cumulative_cost_data,
179 std::forward<CumulativeCostFun>(cumulative_cost_fun));
Dmitriy Korchemkin5d53d1e2022-11-02 16:06:48 +0300180 CHECK_GT(partitions.size(), 1);
Dmitriy Korchemkin54ad3dd2022-12-19 18:24:54 +0300181 ParallelFor(
182 context, start, end, num_threads, std::forward<F>(function), partitions);
Dmitriy Korchemkin5d53d1e2022-11-02 16:06:48 +0300183}
Sameer Agarwalcaf614a2022-04-21 17:41:10 -0700184} // namespace ceres::internal
Mike Vitusdc5ea0e2018-01-24 15:53:19 -0800185
Mike Vitusdc5ea0e2018-01-24 15:53:19 -0800186#endif // CERES_INTERNAL_PARALLEL_FOR_H_