30 #define GEN_PASS_DEF_ASYNCPARALLELFORPASS
31 #include "mlir/Dialect/Async/Passes.h.inc"
37 #define DEBUG_TYPE "async-parallel-for"
100 struct AsyncParallelForPass
101 :
public impl::AsyncParallelForPassBase<AsyncParallelForPass> {
104 void runOnOperation()
override;
109 AsyncParallelForRewrite(
110 MLIRContext *ctx,
bool asyncDispatch, int32_t numWorkerThreads,
113 numWorkerThreads(numWorkerThreads),
114 computeMinTaskSize(std::move(computeMinTaskSize)) {}
116 LogicalResult matchAndRewrite(scf::ParallelOp op,
121 int32_t numWorkerThreads;
125 struct ParallelComputeFunctionType {
131 struct ParallelComputeFunctionArgs {
143 struct ParallelComputeFunctionBounds {
150 struct ParallelComputeFunction {
158 BlockArgument ParallelComputeFunctionArgs::blockIndex() {
return args[0]; }
159 BlockArgument ParallelComputeFunctionArgs::blockSize() {
return args[1]; }
162 return args.drop_front(2).take_front(numLoops);
166 return args.drop_front(2 + 1 * numLoops).take_front(numLoops);
170 return args.drop_front(2 + 3 * numLoops).take_front(numLoops);
174 return args.drop_front(2 + 4 * numLoops);
177 template <
typename ValueRange>
180 for (
unsigned i = 0; i < values.size(); ++i)
190 assert(!tripCounts.empty() &&
"tripCounts must be not empty");
192 for (ssize_t i = tripCounts.size() - 1; i >= 0; --i) {
193 coords[i] = arith::RemSIOp::create(b, index, tripCounts[i]);
194 index = arith::DivSIOp::create(b, index, tripCounts[i]);
203 static ParallelComputeFunctionType
210 inputs.reserve(2 + 4 * op.getNumLoops() + captures.size());
215 inputs.push_back(indexTy);
216 inputs.push_back(indexTy);
219 for (
unsigned i = 0; i < op.getNumLoops(); ++i)
220 inputs.push_back(indexTy);
225 for (
unsigned i = 0; i < op.getNumLoops(); ++i) {
226 inputs.push_back(indexTy);
227 inputs.push_back(indexTy);
228 inputs.push_back(indexTy);
232 for (
Value capture : captures)
233 inputs.push_back(capture.getType());
242 scf::ParallelOp op,
const ParallelComputeFunctionBounds &bounds,
247 ModuleOp module = op->getParentOfType<ModuleOp>();
249 ParallelComputeFunctionType computeFuncType =
252 FunctionType type = computeFuncType.type;
253 func::FuncOp func = func::FuncOp::create(
255 numBlockAlignedInnerLoops > 0 ?
"parallel_compute_fn_with_aligned_loops"
256 :
"parallel_compute_fn",
267 b.
createBlock(&func.getBody(), func.begin(), type.getInputs(),
271 ParallelComputeFunctionArgs args = {op.getNumLoops(), func.getArguments()};
283 return llvm::to_vector(
284 llvm::map_range(llvm::zip(args, attrs), [&](
auto tuple) ->
Value {
285 if (IntegerAttr attr = std::get<1>(tuple))
286 return arith::ConstantOp::create(b, attr);
287 return std::get<0>(tuple);
292 auto tripCounts = values(args.tripCounts(), bounds.tripCounts);
295 auto lowerBounds = values(args.lowerBounds(), bounds.lowerBounds);
296 auto steps = values(args.steps(), bounds.steps);
303 Value tripCount = tripCounts[0];
304 for (
unsigned i = 1; i < tripCounts.size(); ++i)
305 tripCount = arith::MulIOp::create(b, tripCount, tripCounts[i]);
309 Value blockFirstIndex = arith::MulIOp::create(b, blockIndex, blockSize);
313 Value blockEnd0 = arith::AddIOp::create(b, blockFirstIndex, blockSize);
314 Value blockEnd1 = arith::MinSIOp::create(b, blockEnd0, tripCount);
315 Value blockLastIndex = arith::SubIOp::create(b, blockEnd1, c1);
318 auto blockFirstCoord =
delinearize(b, blockFirstIndex, tripCounts);
319 auto blockLastCoord =
delinearize(b, blockLastIndex, tripCounts);
327 for (
size_t i = 0; i < blockLastCoord.size(); ++i)
328 blockEndCoord[i] = arith::AddIOp::create(b, blockLastCoord[i], c1);
332 using LoopBodyBuilder =
334 using LoopNestBuilder = std::function<LoopBodyBuilder(
size_t loopIdx)>;
365 LoopNestBuilder workLoopBuilder = [&](
size_t loopIdx) -> LoopBodyBuilder {
371 computeBlockInductionVars[loopIdx] =
372 arith::AddIOp::create(b, lowerBounds[loopIdx],
373 arith::MulIOp::create(b, iv, steps[loopIdx]));
376 isBlockFirstCoord[loopIdx] = arith::CmpIOp::create(
377 b, arith::CmpIPredicate::eq, iv, blockFirstCoord[loopIdx]);
378 isBlockLastCoord[loopIdx] = arith::CmpIOp::create(
379 b, arith::CmpIPredicate::eq, iv, blockLastCoord[loopIdx]);
383 isBlockFirstCoord[loopIdx] = arith::AndIOp::create(
384 b, isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]);
385 isBlockLastCoord[loopIdx] = arith::AndIOp::create(
386 b, isBlockLastCoord[loopIdx], isBlockLastCoord[loopIdx - 1]);
390 if (loopIdx < op.getNumLoops() - 1) {
391 if (loopIdx + 1 >= op.getNumLoops() - numBlockAlignedInnerLoops) {
394 scf::ForOp::create(b, c0, tripCounts[loopIdx + 1], c1,
ValueRange(),
395 workLoopBuilder(loopIdx + 1));
400 auto lb = arith::SelectOp::create(b, isBlockFirstCoord[loopIdx],
401 blockFirstCoord[loopIdx + 1], c0);
403 auto ub = arith::SelectOp::create(b, isBlockLastCoord[loopIdx],
404 blockEndCoord[loopIdx + 1],
405 tripCounts[loopIdx + 1]);
407 scf::ForOp::create(b, lb, ub, c1,
ValueRange(),
408 workLoopBuilder(loopIdx + 1));
411 scf::YieldOp::create(b, loc);
417 mapping.
map(op.getInductionVars(), computeBlockInductionVars);
418 mapping.
map(computeFuncType.captures, captures);
420 for (
auto &bodyOp : op.getRegion().front().without_terminator())
421 b.
clone(bodyOp, mapping);
422 scf::YieldOp::create(b, loc);
426 scf::ForOp::create(b, blockFirstCoord[0], blockEndCoord[0], c1,
ValueRange(),
430 return {op.getNumLoops(), func, std::move(computeFuncType.captures)};
456 Location loc = computeFunc.func.getLoc();
459 ModuleOp module = computeFunc.func->getParentOfType<ModuleOp>();
462 computeFunc.func.getFunctionType().getInputs();
471 inputTypes.append(computeFuncInputTypes.begin(), computeFuncInputTypes.end());
474 func::FuncOp func = func::FuncOp::create(loc,
"async_dispatch_fn", type);
483 Block *block = b.
createBlock(&func.getBody(), func.begin(), type.getInputs(),
504 scf::WhileOp whileOp = scf::WhileOp::create(b, types, operands);
514 Value distance = arith::SubIOp::create(b, end, start);
516 arith::CmpIOp::create(b, arith::CmpIPredicate::sgt, distance, c1);
517 scf::ConditionOp::create(b, dispatch, before->
getArguments());
526 Value distance = arith::SubIOp::create(b, end, start);
527 Value halfDistance = arith::DivSIOp::create(b, distance, c2);
528 Value midIndex = arith::AddIOp::create(b, start, halfDistance);
531 auto executeBodyBuilder = [&](
OpBuilder &executeBuilder,
536 operands[1] = midIndex;
539 func::CallOp::create(executeBuilder, executeLoc, func.getSymName(),
540 func.getResultTypes(), operands);
541 async::YieldOp::create(executeBuilder, executeLoc,
ValueRange());
547 AddToGroupOp::create(b, indexTy, execute.getToken(), group);
548 scf::YieldOp::create(b,
ValueRange({start, midIndex}));
556 auto forwardedInputs = block->
getArguments().drop_front(3);
558 computeFuncOperands.append(forwardedInputs.begin(), forwardedInputs.end());
560 func::CallOp::create(b, computeFunc.func.getSymName(),
561 computeFunc.func.getResultTypes(), computeFuncOperands);
569 ParallelComputeFunction ¶llelComputeFunction,
570 scf::ParallelOp op,
Value blockSize,
577 func::FuncOp asyncDispatchFunction =
586 operands.append(tripCounts);
587 operands.append(op.getLowerBound().begin(), op.getLowerBound().end());
588 operands.append(op.getUpperBound().begin(), op.getUpperBound().end());
589 operands.append(op.getStep().begin(), op.getStep().end());
590 operands.append(parallelComputeFunction.captures);
596 Value isSingleBlock =
597 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, blockCount, c1);
604 appendBlockComputeOperands(operands);
606 func::CallOp::create(b, parallelComputeFunction.func.getSymName(),
607 parallelComputeFunction.func.getResultTypes(),
609 scf::YieldOp::create(b);
618 Value groupSize = arith::SubIOp::create(b, blockCount, c1);
623 appendBlockComputeOperands(operands);
625 func::CallOp::create(b, asyncDispatchFunction.getSymName(),
626 asyncDispatchFunction.getResultTypes(), operands);
629 AwaitAllOp::create(b, group);
631 scf::YieldOp::create(b);
635 scf::IfOp::create(b, isSingleBlock, syncDispatch, asyncDispatch);
642 ParallelComputeFunction ¶llelComputeFunction,
643 scf::ParallelOp op,
Value blockSize,
Value blockCount,
647 func::FuncOp compute = parallelComputeFunction.func;
655 Value groupSize = arith::SubIOp::create(b, blockCount, c1);
659 using LoopBodyBuilder =
665 computeFuncOperands.append(tripCounts);
666 computeFuncOperands.append(op.getLowerBound().begin(),
667 op.getLowerBound().end());
668 computeFuncOperands.append(op.getUpperBound().begin(),
669 op.getUpperBound().end());
670 computeFuncOperands.append(op.getStep().begin(), op.getStep().end());
671 computeFuncOperands.append(parallelComputeFunction.captures);
672 return computeFuncOperands;
681 auto executeBodyBuilder = [&](
OpBuilder &executeBuilder,
683 func::CallOp::create(executeBuilder, executeLoc, compute.getSymName(),
684 compute.getResultTypes(), computeFuncOperands(iv));
685 async::YieldOp::create(executeBuilder, executeLoc,
ValueRange());
691 AddToGroupOp::create(b, rewriter.
getIndexType(), execute.getToken(), group);
692 scf::YieldOp::create(b);
696 scf::ForOp::create(b, c1, blockCount, c1,
ValueRange(), loopBuilder);
699 func::CallOp::create(b, compute.getSymName(), compute.getResultTypes(),
700 computeFuncOperands(c0));
703 AwaitAllOp::create(b, group);
707 AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
710 if (op.getNumReductions() != 0)
718 Value minTaskSize = computeMinTaskSize(b, op);
727 for (
size_t i = 0; i < op.getNumLoops(); ++i) {
728 auto lb = op.getLowerBound()[i];
729 auto ub = op.getUpperBound()[i];
730 auto step = op.getStep()[i];
731 auto range = b.createOrFold<arith::SubIOp>(ub, lb);
732 tripCounts[i] = b.createOrFold<arith::CeilDivSIOp>(range, step);
737 Value tripCount = tripCounts[0];
738 for (
size_t i = 1; i < tripCounts.size(); ++i)
739 tripCount = arith::MulIOp::create(b, tripCount, tripCounts[i]);
744 Value isZeroIterations =
745 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, tripCount, c0);
749 scf::YieldOp::create(nestedBuilder, loc);
762 ParallelComputeFunctionBounds staticBounds = {
776 static constexpr int64_t maxUnrollableIterations = 512;
780 int numUnrollableLoops = 0;
782 auto getInt = [](IntegerAttr attr) {
return attr ? attr.getInt() : 0; };
785 numIterations.back() = getInt(staticBounds.tripCounts.back());
787 for (
int i = op.getNumLoops() - 2; i >= 0; --i) {
788 int64_t tripCount = getInt(staticBounds.tripCounts[i]);
789 int64_t innerIterations = numIterations[i + 1];
790 numIterations[i] = tripCount * innerIterations;
793 if (innerIterations > 0 && innerIterations <= maxUnrollableIterations)
794 numUnrollableLoops++;
797 Value numWorkerThreadsVal;
798 if (numWorkerThreads >= 0)
801 numWorkerThreadsVal = async::RuntimeNumWorkerThreadsOp::create(b);
817 {4, 4.0f}, {8, 2.0f}, {16, 1.0f}, {32, 0.8f}, {64, 0.6f}};
818 const float initialOvershardingFactor = 8.0f;
821 b, b.getF32Type(), llvm::APFloat(initialOvershardingFactor));
822 for (
const std::pair<int, float> &p : overshardingBrackets) {
824 Value inBracket = arith::CmpIOp::create(
825 b, arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin);
827 b, b.getF32Type(), llvm::APFloat(p.second));
828 scalingFactor = arith::SelectOp::create(
829 b, inBracket, bracketScalingFactor, scalingFactor);
831 Value numWorkersIndex =
832 arith::IndexCastOp::create(b, b.getI32Type(), numWorkerThreadsVal);
833 Value numWorkersFloat =
834 arith::SIToFPOp::create(b, b.getF32Type(), numWorkersIndex);
835 Value scaledNumWorkers =
836 arith::MulFOp::create(b, scalingFactor, numWorkersFloat);
838 arith::FPToSIOp::create(b, b.getI32Type(), scaledNumWorkers);
839 Value scaledWorkers =
840 arith::IndexCastOp::create(b, b.getIndexType(), scaledNumInt);
842 Value maxComputeBlocks = arith::MaxSIOp::create(
849 Value bs0 = arith::CeilDivSIOp::create(b, tripCount, maxComputeBlocks);
850 Value bs1 = arith::MaxSIOp::create(b, bs0, minTaskSize);
851 Value blockSize = arith::MinSIOp::create(b, tripCount, bs1);
861 Value blockCount = arith::CeilDivSIOp::create(b, tripCount, blockSize);
865 ParallelComputeFunction compute =
869 doDispatch(b, rewriter, compute, op, blockSize, blockCount, tripCounts);
870 scf::YieldOp::create(b);
876 op, staticBounds, numUnrollableLoops, rewriter);
882 b, numIterations[op.getNumLoops() - numUnrollableLoops]);
883 Value alignedBlockSize = arith::MulIOp::create(
884 b, arith::CeilDivSIOp::create(b, blockSize, numIters), numIters);
885 doDispatch(b, rewriter, compute, op, alignedBlockSize, blockCount,
887 scf::YieldOp::create(b);
893 if (numUnrollableLoops > 0) {
895 b, numIterations[op.getNumLoops() - numUnrollableLoops]);
896 Value useBlockAlignedComputeFn = arith::CmpIOp::create(
897 b, arith::CmpIPredicate::sge, blockSize, numIters);
899 scf::IfOp::create(b, useBlockAlignedComputeFn, dispatchBlockAligned,
901 scf::YieldOp::create(b);
903 dispatchDefault(b, loc);
908 scf::IfOp::create(b, isZeroIterations, noOp, dispatch);
916 void AsyncParallelForPass::runOnOperation() {
921 patterns, asyncDispatch, numWorkerThreads,
933 patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
static ParallelComputeFunction createParallelComputeFunction(scf::ParallelOp op, const ParallelComputeFunctionBounds &bounds, unsigned numBlockAlignedInnerLoops, PatternRewriter &rewriter)
static func::FuncOp createAsyncDispatchFunction(ParallelComputeFunction &computeFunc, PatternRewriter &rewriter)
static void doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter, ParallelComputeFunction ¶llelComputeFunction, scf::ParallelOp op, Value blockSize, Value blockCount, const SmallVector< Value > &tripCounts)
static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter, ParallelComputeFunction ¶llelComputeFunction, scf::ParallelOp op, Value blockSize, Value blockCount, const SmallVector< Value > &tripCounts)
static ParallelComputeFunctionType getParallelComputeFunctionType(scf::ParallelOp op, PatternRewriter &rewriter)
static SmallVector< IntegerAttr > integerConstants(ValueRange values)
static MLIRContext * getContext(OpFoldResult val)
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
BlockArgListType getArguments()
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation via the map that is provided (leaving them alone if no entry is present).
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
std::function< Value(ImplicitLocOpBuilder, scf::ParallelOp)> AsyncMinTaskSizeComputationFunction
Emit the IR to compute the minimum number of iterations of scf.parallel body that would be viable for a single parallel task.
void populateAsyncParallelForPatterns(RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads, const AsyncMinTaskSizeComputationFunction &computeMinTaskSize)
Add a pattern to the given pattern list to lower scf.parallel to async operations.
void cloneConstantsIntoTheRegion(Region ®ion)
Clone ConstantLike operations that are defined above the given region and have users in the region in...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixed point is reached.
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offsets in each dimension for a de-linearized index.
const FrozenRewritePatternSet & patterns
void getUsedValuesDefinedAbove(Region ®ion, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region or its descendants.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...