31 #define GEN_PASS_DEF_ASYNCPARALLELFORPASS
32 #include "mlir/Dialect/Async/Passes.h.inc"
38 #define DEBUG_TYPE "async-parallel-for"
101 struct AsyncParallelForPass
102 :
public impl::AsyncParallelForPassBase<AsyncParallelForPass> {
105 void runOnOperation()
override;
110 AsyncParallelForRewrite(
111 MLIRContext *ctx,
bool asyncDispatch, int32_t numWorkerThreads,
114 numWorkerThreads(numWorkerThreads),
115 computeMinTaskSize(std::move(computeMinTaskSize)) {}
117 LogicalResult matchAndRewrite(scf::ParallelOp op,
122 int32_t numWorkerThreads;
126 struct ParallelComputeFunctionType {
132 struct ParallelComputeFunctionArgs {
144 struct ParallelComputeFunctionBounds {
151 struct ParallelComputeFunction {
159 BlockArgument ParallelComputeFunctionArgs::blockIndex() {
return args[0]; }
160 BlockArgument ParallelComputeFunctionArgs::blockSize() {
return args[1]; }
163 return args.drop_front(2).take_front(numLoops);
167 return args.drop_front(2 + 1 * numLoops).take_front(numLoops);
171 return args.drop_front(2 + 3 * numLoops).take_front(numLoops);
175 return args.drop_front(2 + 4 * numLoops);
178 template <
typename ValueRange>
181 for (
unsigned i = 0; i < values.size(); ++i)
191 assert(!tripCounts.empty() &&
"tripCounts must be not empty");
193 for (ssize_t i = tripCounts.size() - 1; i >= 0; --i) {
194 coords[i] = b.
create<arith::RemSIOp>(index, tripCounts[i]);
195 index = b.
create<arith::DivSIOp>(index, tripCounts[i]);
204 static ParallelComputeFunctionType
211 inputs.reserve(2 + 4 * op.getNumLoops() + captures.size());
216 inputs.push_back(indexTy);
217 inputs.push_back(indexTy);
220 for (
unsigned i = 0; i < op.getNumLoops(); ++i)
221 inputs.push_back(indexTy);
226 for (
unsigned i = 0; i < op.getNumLoops(); ++i) {
227 inputs.push_back(indexTy);
228 inputs.push_back(indexTy);
229 inputs.push_back(indexTy);
233 for (
Value capture : captures)
234 inputs.push_back(capture.getType());
243 scf::ParallelOp op,
const ParallelComputeFunctionBounds &bounds,
248 ModuleOp module = op->getParentOfType<ModuleOp>();
250 ParallelComputeFunctionType computeFuncType =
253 FunctionType type = computeFuncType.type;
254 func::FuncOp func = func::FuncOp::create(
256 numBlockAlignedInnerLoops > 0 ?
"parallel_compute_fn_with_aligned_loops"
257 :
"parallel_compute_fn",
268 b.
createBlock(&func.getBody(), func.begin(), type.getInputs(),
272 ParallelComputeFunctionArgs args = {op.getNumLoops(), func.getArguments()};
284 return llvm::to_vector(
285 llvm::map_range(llvm::zip(args, attrs), [&](
auto tuple) ->
Value {
286 if (IntegerAttr attr = std::get<1>(tuple))
287 return b.
create<arith::ConstantOp>(attr);
288 return std::get<0>(tuple);
293 auto tripCounts = values(args.tripCounts(), bounds.tripCounts);
296 auto lowerBounds = values(args.lowerBounds(), bounds.lowerBounds);
297 auto steps = values(args.steps(), bounds.steps);
304 Value tripCount = tripCounts[0];
305 for (
unsigned i = 1; i < tripCounts.size(); ++i)
306 tripCount = b.
create<arith::MulIOp>(tripCount, tripCounts[i]);
310 Value blockFirstIndex = b.
create<arith::MulIOp>(blockIndex, blockSize);
314 Value blockEnd0 = b.
create<arith::AddIOp>(blockFirstIndex, blockSize);
315 Value blockEnd1 = b.
create<arith::MinSIOp>(blockEnd0, tripCount);
316 Value blockLastIndex = b.
create<arith::SubIOp>(blockEnd1, c1);
319 auto blockFirstCoord =
delinearize(b, blockFirstIndex, tripCounts);
320 auto blockLastCoord =
delinearize(b, blockLastIndex, tripCounts);
328 for (
size_t i = 0; i < blockLastCoord.size(); ++i)
329 blockEndCoord[i] = b.
create<arith::AddIOp>(blockLastCoord[i], c1);
333 using LoopBodyBuilder =
335 using LoopNestBuilder = std::function<LoopBodyBuilder(
size_t loopIdx)>;
366 LoopNestBuilder workLoopBuilder = [&](
size_t loopIdx) -> LoopBodyBuilder {
372 computeBlockInductionVars[loopIdx] = b.
create<arith::AddIOp>(
373 lowerBounds[loopIdx], b.
create<arith::MulIOp>(iv, steps[loopIdx]));
376 isBlockFirstCoord[loopIdx] = b.
create<arith::CmpIOp>(
377 arith::CmpIPredicate::eq, iv, blockFirstCoord[loopIdx]);
378 isBlockLastCoord[loopIdx] = b.
create<arith::CmpIOp>(
379 arith::CmpIPredicate::eq, iv, blockLastCoord[loopIdx]);
383 isBlockFirstCoord[loopIdx] = b.
create<arith::AndIOp>(
384 isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]);
385 isBlockLastCoord[loopIdx] = b.
create<arith::AndIOp>(
386 isBlockLastCoord[loopIdx], isBlockLastCoord[loopIdx - 1]);
390 if (loopIdx < op.getNumLoops() - 1) {
391 if (loopIdx + 1 >= op.getNumLoops() - numBlockAlignedInnerLoops) {
395 workLoopBuilder(loopIdx + 1));
400 auto lb = b.
create<arith::SelectOp>(isBlockFirstCoord[loopIdx],
401 blockFirstCoord[loopIdx + 1], c0);
403 auto ub = b.
create<arith::SelectOp>(isBlockLastCoord[loopIdx],
404 blockEndCoord[loopIdx + 1],
405 tripCounts[loopIdx + 1]);
408 workLoopBuilder(loopIdx + 1));
411 b.
create<scf::YieldOp>(loc);
417 mapping.
map(op.getInductionVars(), computeBlockInductionVars);
418 mapping.
map(computeFuncType.captures, captures);
420 for (
auto &bodyOp : op.getRegion().front().without_terminator())
421 b.
clone(bodyOp, mapping);
422 b.
create<scf::YieldOp>(loc);
426 b.
create<scf::ForOp>(blockFirstCoord[0], blockEndCoord[0], c1,
ValueRange(),
430 return {op.getNumLoops(), func, std::move(computeFuncType.captures)};
456 Location loc = computeFunc.func.getLoc();
459 ModuleOp module = computeFunc.func->getParentOfType<ModuleOp>();
462 computeFunc.func.getFunctionType().getInputs();
471 inputTypes.append(computeFuncInputTypes.begin(), computeFuncInputTypes.end());
474 func::FuncOp func = func::FuncOp::create(loc,
"async_dispatch_fn", type);
483 Block *block = b.
createBlock(&func.getBody(), func.begin(), type.getInputs(),
504 scf::WhileOp whileOp = b.
create<scf::WhileOp>(types, operands);
512 Value start = before->getArgument(0);
513 Value end = before->getArgument(1);
514 Value distance = b.
create<arith::SubIOp>(end, start);
516 b.
create<arith::CmpIOp>(arith::CmpIPredicate::sgt, distance, c1);
517 b.
create<scf::ConditionOp>(dispatch, before->getArguments());
526 Value distance = b.
create<arith::SubIOp>(end, start);
527 Value halfDistance = b.
create<arith::DivSIOp>(distance, c2);
528 Value midIndex = b.
create<arith::AddIOp>(start, halfDistance);
531 auto executeBodyBuilder = [&](
OpBuilder &executeBuilder,
536 operands[1] = midIndex;
539 executeBuilder.
create<func::CallOp>(executeLoc, func.getSymName(),
540 func.getResultTypes(), operands);
547 b.
create<AddToGroupOp>(indexTy, execute.getToken(), group);
556 auto forwardedInputs = block->
getArguments().drop_front(3);
558 computeFuncOperands.append(forwardedInputs.begin(), forwardedInputs.end());
560 b.
create<func::CallOp>(computeFunc.func.getSymName(),
561 computeFunc.func.getResultTypes(),
562 computeFuncOperands);
570 ParallelComputeFunction ¶llelComputeFunction,
571 scf::ParallelOp op,
Value blockSize,
578 func::FuncOp asyncDispatchFunction =
587 operands.append(tripCounts);
588 operands.append(op.getLowerBound().begin(), op.getLowerBound().end());
589 operands.append(op.getUpperBound().begin(), op.getUpperBound().end());
590 operands.append(op.getStep().begin(), op.getStep().end());
591 operands.append(parallelComputeFunction.captures);
597 Value isSingleBlock =
598 b.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, blockCount, c1);
605 appendBlockComputeOperands(operands);
607 b.
create<func::CallOp>(parallelComputeFunction.func.getSymName(),
608 parallelComputeFunction.func.getResultTypes(),
619 Value groupSize = b.
create<arith::SubIOp>(blockCount, c1);
624 appendBlockComputeOperands(operands);
626 b.
create<func::CallOp>(asyncDispatchFunction.getSymName(),
627 asyncDispatchFunction.getResultTypes(), operands);
630 b.
create<AwaitAllOp>(group);
636 b.
create<scf::IfOp>(isSingleBlock, syncDispatch, asyncDispatch);
643 ParallelComputeFunction ¶llelComputeFunction,
644 scf::ParallelOp op,
Value blockSize,
Value blockCount,
648 func::FuncOp compute = parallelComputeFunction.func;
656 Value groupSize = b.
create<arith::SubIOp>(blockCount, c1);
660 using LoopBodyBuilder =
666 computeFuncOperands.append(tripCounts);
667 computeFuncOperands.append(op.getLowerBound().begin(),
668 op.getLowerBound().end());
669 computeFuncOperands.append(op.getUpperBound().begin(),
670 op.getUpperBound().end());
671 computeFuncOperands.append(op.getStep().begin(), op.getStep().end());
672 computeFuncOperands.append(parallelComputeFunction.captures);
673 return computeFuncOperands;
682 auto executeBodyBuilder = [&](
OpBuilder &executeBuilder,
684 executeBuilder.
create<func::CallOp>(executeLoc, compute.getSymName(),
685 compute.getResultTypes(),
686 computeFuncOperands(iv));
701 b.
create<func::CallOp>(compute.getSymName(), compute.getResultTypes(),
702 computeFuncOperands(c0));
705 b.
create<AwaitAllOp>(group);
709 AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
712 if (op.getNumReductions() != 0)
720 Value minTaskSize = computeMinTaskSize(b, op);
729 for (
size_t i = 0; i < op.getNumLoops(); ++i) {
730 auto lb = op.getLowerBound()[i];
731 auto ub = op.getUpperBound()[i];
732 auto step = op.getStep()[i];
733 auto range = b.createOrFold<arith::SubIOp>(ub, lb);
734 tripCounts[i] = b.createOrFold<arith::CeilDivSIOp>(range, step);
739 Value tripCount = tripCounts[0];
740 for (
size_t i = 1; i < tripCounts.size(); ++i)
741 tripCount = b.create<arith::MulIOp>(tripCount, tripCounts[i]);
745 Value c0 = b.create<arith::ConstantIndexOp>(0);
746 Value isZeroIterations =
747 b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, tripCount, c0);
751 nestedBuilder.
create<scf::YieldOp>(loc);
764 ParallelComputeFunctionBounds staticBounds = {
778 static constexpr int64_t maxUnrollableIterations = 512;
782 int numUnrollableLoops = 0;
784 auto getInt = [](IntegerAttr attr) {
return attr ? attr.getInt() : 0; };
787 numIterations.back() = getInt(staticBounds.tripCounts.back());
789 for (
int i = op.getNumLoops() - 2; i >= 0; --i) {
790 int64_t tripCount = getInt(staticBounds.tripCounts[i]);
791 int64_t innerIterations = numIterations[i + 1];
792 numIterations[i] = tripCount * innerIterations;
795 if (innerIterations > 0 && innerIterations <= maxUnrollableIterations)
796 numUnrollableLoops++;
799 Value numWorkerThreadsVal;
800 if (numWorkerThreads >= 0)
801 numWorkerThreadsVal = b.create<arith::ConstantIndexOp>(numWorkerThreads);
803 numWorkerThreadsVal = b.create<async::RuntimeNumWorkerThreadsOp>();
819 {4, 4.0f}, {8, 2.0f}, {16, 1.0f}, {32, 0.8f}, {64, 0.6f}};
820 const float initialOvershardingFactor = 8.0f;
822 Value scalingFactor = b.create<arith::ConstantFloatOp>(
823 llvm::APFloat(initialOvershardingFactor), b.getF32Type());
824 for (
const std::pair<int, float> &p : overshardingBrackets) {
825 Value bracketBegin = b.create<arith::ConstantIndexOp>(p.first);
826 Value inBracket = b.create<arith::CmpIOp>(
827 arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin);
828 Value bracketScalingFactor = b.create<arith::ConstantFloatOp>(
829 llvm::APFloat(p.second), b.getF32Type());
830 scalingFactor = b.create<arith::SelectOp>(inBracket, bracketScalingFactor,
833 Value numWorkersIndex =
834 b.create<arith::IndexCastOp>(b.getI32Type(), numWorkerThreadsVal);
835 Value numWorkersFloat =
836 b.create<arith::SIToFPOp>(b.getF32Type(), numWorkersIndex);
837 Value scaledNumWorkers =
838 b.create<arith::MulFOp>(scalingFactor, numWorkersFloat);
840 b.create<arith::FPToSIOp>(b.getI32Type(), scaledNumWorkers);
841 Value scaledWorkers =
842 b.create<arith::IndexCastOp>(b.getIndexType(), scaledNumInt);
844 Value maxComputeBlocks = b.create<arith::MaxSIOp>(
845 b.create<arith::ConstantIndexOp>(1), scaledWorkers);
851 Value bs0 = b.create<arith::CeilDivSIOp>(tripCount, maxComputeBlocks);
852 Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSize);
853 Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1);
863 Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize);
867 ParallelComputeFunction compute =
871 doDispatch(b, rewriter, compute, op, blockSize, blockCount, tripCounts);
872 b.create<scf::YieldOp>();
878 op, staticBounds, numUnrollableLoops, rewriter);
883 Value numIters = b.create<arith::ConstantIndexOp>(
884 numIterations[op.getNumLoops() - numUnrollableLoops]);
885 Value alignedBlockSize = b.create<arith::MulIOp>(
886 b.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
887 doDispatch(b, rewriter, compute, op, alignedBlockSize, blockCount,
889 b.create<scf::YieldOp>();
895 if (numUnrollableLoops > 0) {
896 Value numIters = b.create<arith::ConstantIndexOp>(
897 numIterations[op.getNumLoops() - numUnrollableLoops]);
898 Value useBlockAlignedComputeFn = b.create<arith::CmpIOp>(
899 arith::CmpIPredicate::sge, blockSize, numIters);
901 b.create<scf::IfOp>(useBlockAlignedComputeFn, dispatchBlockAligned,
903 b.create<scf::YieldOp>();
905 dispatchDefault(b, loc);
910 b.create<scf::IfOp>(isZeroIterations, noOp, dispatch);
918 void AsyncParallelForPass::runOnOperation() {
923 patterns, asyncDispatch, numWorkerThreads,
925 return builder.
create<arith::ConstantIndexOp>(minTaskSize);
935 patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
static ParallelComputeFunction createParallelComputeFunction(scf::ParallelOp op, const ParallelComputeFunctionBounds &bounds, unsigned numBlockAlignedInnerLoops, PatternRewriter &rewriter)
static func::FuncOp createAsyncDispatchFunction(ParallelComputeFunction &computeFunc, PatternRewriter &rewriter)
static void doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter, ParallelComputeFunction ¶llelComputeFunction, scf::ParallelOp op, Value blockSize, Value blockCount, const SmallVector< Value > &tripCounts)
static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter, ParallelComputeFunction ¶llelComputeFunction, scf::ParallelOp op, Value blockSize, Value blockCount, const SmallVector< Value > &tripCounts)
static ParallelComputeFunctionType getParallelComputeFunctionType(scf::ParallelOp op, PatternRewriter &rewriter)
static SmallVector< IntegerAttr > integerConstants(ValueRange values)
static MLIRContext * getContext(OpFoldResult val)
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
BlockArgListType getArguments()
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
std::function< Value(ImplicitLocOpBuilder, scf::ParallelOp)> AsyncMinTaskSizeComputationFunction
Emit the IR to compute the minimum number of iterations of scf.parallel body that would be viable for...
void populateAsyncParallelForPatterns(RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads, const AsyncMinTaskSizeComputationFunction &computeMinTaskSize)
Add a pattern to the given pattern list to lower scf.parallel to async operations.
void cloneConstantsIntoTheRegion(Region ®ion)
Clone ConstantLike operations that are defined above the given region and have users in the region in...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
const FrozenRewritePatternSet & patterns
void getUsedValuesDefinedAbove(Region ®ion, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...