MLIR  16.0.0git
Threading.h
Go to the documentation of this file.
1 //===- Threading.h - MLIR Threading Utilities -------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines various utilities for multithreaded processing within MLIR.
10 // These utilities automatically handle many of the necessary threading
11 // conditions, such as properly ordering diagnostics, observing if threading is
12 // disabled, etc. These utilities should be used over other threading utilities
13 // whenever feasible.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #ifndef MLIR_IR_THREADING_H
18 #define MLIR_IR_THREADING_H
19 
20 #include "mlir/IR/Diagnostics.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/Support/ThreadPool.h"
23 #include <atomic>
24 
25 namespace mlir {
26 
27 /// Invoke the given function on the elements between [begin, end)
28 /// asynchronously. If the given function returns a failure when processing any
29 /// of the elements, execution is stopped and a failure is returned from this
30 /// function. This means that in the case of failure, not all elements of the
31 /// range will be processed. Diagnostics emitted during processing are ordered
32 /// relative to the element's position within [begin, end). If the provided
33 /// context does not have multi-threading enabled, this function always
34 /// processes elements sequentially.
35 template <typename IteratorT, typename FuncT>
37  IteratorT end, FuncT &&func) {
38  unsigned numElements = static_cast<unsigned>(std::distance(begin, end));
39  if (numElements == 0)
40  return success();
41 
42  // If multithreading is disabled or there is a small number of elements,
43  // process the elements directly on this thread.
44  if (!context->isMultithreadingEnabled() || numElements <= 1) {
45  for (; begin != end; ++begin)
46  if (failed(func(*begin)))
47  return failure();
48  return success();
49  }
50 
51  // Build a wrapper processing function that properly initializes a parallel
52  // diagnostic handler.
53  ParallelDiagnosticHandler handler(context);
54  std::atomic<unsigned> curIndex(0);
55  std::atomic<bool> processingFailed(false);
56  auto processFn = [&] {
57  while (!processingFailed) {
58  unsigned index = curIndex++;
59  if (index >= numElements)
60  break;
61  handler.setOrderIDForThread(index);
62  if (failed(func(*std::next(begin, index))))
63  processingFailed = true;
64  handler.eraseOrderIDForThread();
65  }
66  };
67 
68  // Otherwise, process the elements in parallel.
69  llvm::ThreadPool &threadPool = context->getThreadPool();
70  llvm::ThreadPoolTaskGroup tasksGroup(threadPool);
71  size_t numActions = std::min(numElements, threadPool.getThreadCount());
72  for (unsigned i = 0; i < numActions; ++i)
73  tasksGroup.async(processFn);
74  // If the current thread is a worker thread from the pool, then waiting for
75  // the task group allows the current thread to also participate in processing
76  // tasks from the group, which avoid any deadlock/starvation.
77  tasksGroup.wait();
78  return failure(processingFailed);
79 }
80 
81 /// Invoke the given function on the elements in the provided range
82 /// asynchronously. If the given function returns a failure when processing any
83 /// of the elements, execution is stopped and a failure is returned from this
84 /// function. This means that in the case of failure, not all elements of the
85 /// range will be processed. Diagnostics emitted during processing are ordered
86 /// relative to the element's position within the range. If the provided context
87 /// does not have multi-threading enabled, this function always processes
88 /// elements sequentially.
89 template <typename RangeT, typename FuncT>
91  FuncT &&func) {
92  return failableParallelForEach(context, std::begin(range), std::end(range),
93  std::forward<FuncT>(func));
94 }
95 
96 /// Invoke the given function on the elements between [begin, end)
97 /// asynchronously. If the given function returns a failure when processing any
98 /// of the elements, execution is stopped and a failure is returned from this
99 /// function. This means that in the case of failure, not all elements of the
100 /// range will be processed. Diagnostics emitted during processing are ordered
101 /// relative to the element's position within [begin, end). If the provided
102 /// context does not have multi-threading enabled, this function always
103 /// processes elements sequentially.
104 template <typename FuncT>
106  size_t end, FuncT &&func) {
107  return failableParallelForEach(context, llvm::seq(begin, end),
108  std::forward<FuncT>(func));
109 }
110 
111 /// Invoke the given function on the elements between [begin, end)
112 /// asynchronously. Diagnostics emitted during processing are ordered relative
113 /// to the element's position within [begin, end). If the provided context does
114 /// not have multi-threading enabled, this function always processes elements
115 /// sequentially.
116 template <typename IteratorT, typename FuncT>
117 void parallelForEach(MLIRContext *context, IteratorT begin, IteratorT end,
118  FuncT &&func) {
119  (void)failableParallelForEach(context, begin, end, [&](auto &&value) {
120  return func(std::forward<decltype(value)>(value)), success();
121  });
122 }
123 
124 /// Invoke the given function on the elements in the provided range
125 /// asynchronously. Diagnostics emitted during processing are ordered relative
126 /// to the element's position within the range. If the provided context does not
127 /// have multi-threading enabled, this function always processes elements
128 /// sequentially.
129 template <typename RangeT, typename FuncT>
130 void parallelForEach(MLIRContext *context, RangeT &&range, FuncT &&func) {
131  parallelForEach(context, std::begin(range), std::end(range),
132  std::forward<FuncT>(func));
133 }
134 
135 /// Invoke the given function on the elements between [begin, end)
136 /// asynchronously. Diagnostics emitted during processing are ordered relative
137 /// to the element's position within [begin, end). If the provided context does
138 /// not have multi-threading enabled, this function always processes elements
139 /// sequentially.
140 template <typename FuncT>
141 void parallelFor(MLIRContext *context, size_t begin, size_t end, FuncT &&func) {
142  parallelForEach(context, llvm::seq(begin, end), std::forward<FuncT>(func));
143 }
144 
145 } // namespace mlir
146 
147 #endif // MLIR_IR_THREADING_H
void parallelForEach(MLIRContext *context, IteratorT begin, IteratorT end, FuncT &&func)
Invoke the given function on the elements between [begin, end) asynchronously.
Definition: Threading.h:117
Include the generated interface declarations.
void parallelFor(MLIRContext *context, size_t begin, size_t end, FuncT &&func)
Invoke the given function on the elements between [begin, end) asynchronously.
Definition: Threading.h:141
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
static constexpr const bool value
LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin, IteratorT end, FuncT &&func)
Invoke the given function on the elements between [begin, end) asynchronously.
Definition: Threading.h:36
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This class is a utility diagnostic handler for use when multi-threading some part of the compiler whe...
Definition: Diagnostics.h:648
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
void setOrderIDForThread(size_t orderID)
Set the order id for the current thread.
void eraseOrderIDForThread()
Remove the order id for the current thread.
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
LogicalResult failableParallelForEachN(MLIRContext *context, size_t begin, size_t end, FuncT &&func)
Invoke the given function on the elements between [begin, end) asynchronously.
Definition: Threading.h:105
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
llvm::ThreadPool & getThreadPool()
Return the thread pool used by this context.