17#ifndef MLIR_IR_THREADING_H
18#define MLIR_IR_THREADING_H
21#include "llvm/ADT/Sequence.h"
22#include "llvm/Support/ThreadPool.h"
35template <
typename IteratorT,
typename FuncT>
37 IteratorT end, FuncT &&
func) {
38 unsigned numElements =
static_cast<unsigned>(std::distance(begin, end));
45 for (; begin != end; ++begin)
46 if (failed(
func(*begin)))
54 std::atomic<unsigned> curIndex(0);
55 std::atomic<bool> processingFailed(
false);
56 auto processFn = [&] {
57 while (!processingFailed) {
58 unsigned index = curIndex++;
59 if (
index >= numElements)
62 if (failed(
func(*std::next(begin,
index))))
63 processingFailed =
true;
69 llvm::ThreadPoolInterface &threadPool = context->
getThreadPool();
70 llvm::ThreadPoolTaskGroup tasksGroup(threadPool);
71 size_t numActions = std::min(numElements, threadPool.getMaxConcurrency());
72 for (
unsigned i = 0; i < numActions; ++i)
73 tasksGroup.async(processFn);
78 return failure(processingFailed);
89template <
typename RangeT,
typename FuncT>
93 std::forward<FuncT>(
func));
104template <
typename FuncT>
106 size_t end, FuncT &&
func) {
108 std::forward<FuncT>(
func));
116template <
typename IteratorT,
typename FuncT>
120 return func(std::forward<
decltype(value)>(value)),
success();
129template <
typename RangeT,
typename FuncT>
132 std::forward<FuncT>(
func));
140template <
typename FuncT>
MLIRContext is the top-level object for a collection of MLIR operations.
llvm::ThreadPoolInterface & getThreadPool()
Return the thread pool used by this context.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class is a utility diagnostic handler for use when multi-threading some part of the compiler whe...
void eraseOrderIDForThread()
Remove the order id for the current thread.
void setOrderIDForThread(size_t orderID)
Set the order id for the current thread.
Include the generated interface declarations.
LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin, IteratorT end, FuncT &&func)
Invoke the given function on the elements between [begin, end) asynchronously.
void parallelForEach(MLIRContext *context, IteratorT begin, IteratorT end, FuncT &&func)
Invoke the given function on the elements between [begin, end) asynchronously.
void parallelFor(MLIRContext *context, size_t begin, size_t end, FuncT &&func)
Invoke the given function on the elements between [begin, end) asynchronously.
LogicalResult failableParallelForEachN(MLIRContext *context, size_t begin, size_t end, FuncT &&func)
Invoke the given function on the elements between [begin, end) asynchronously.