31 #define GEN_PASS_DEF_ASYNCPARALLELFOR
32 #include "mlir/Dialect/Async/Passes.h.inc"
38 #define DEBUG_TYPE "async-parallel-for"
101 struct AsyncParallelForPass
102 :
public impl::AsyncParallelForBase<AsyncParallelForPass> {
103 AsyncParallelForPass() =
default;
105 AsyncParallelForPass(
bool asyncDispatch, int32_t numWorkerThreads,
106 int32_t minTaskSize) {
107 this->asyncDispatch = asyncDispatch;
108 this->numWorkerThreads = numWorkerThreads;
109 this->minTaskSize = minTaskSize;
112 void runOnOperation()
override;
117 AsyncParallelForRewrite(
118 MLIRContext *ctx,
bool asyncDispatch, int32_t numWorkerThreads,
121 numWorkerThreads(numWorkerThreads),
122 computeMinTaskSize(std::move(computeMinTaskSize)) {}
124 LogicalResult matchAndRewrite(scf::ParallelOp op,
129 int32_t numWorkerThreads;
133 struct ParallelComputeFunctionType {
139 struct ParallelComputeFunctionArgs {
151 struct ParallelComputeFunctionBounds {
158 struct ParallelComputeFunction {
166 BlockArgument ParallelComputeFunctionArgs::blockIndex() {
return args[0]; }
167 BlockArgument ParallelComputeFunctionArgs::blockSize() {
return args[1]; }
170 return args.drop_front(2).take_front(numLoops);
174 return args.drop_front(2 + 1 * numLoops).take_front(numLoops);
178 return args.drop_front(2 + 3 * numLoops).take_front(numLoops);
182 return args.drop_front(2 + 4 * numLoops);
185 template <
typename ValueRange>
188 for (
unsigned i = 0; i < values.size(); ++i)
198 assert(!tripCounts.empty() &&
"tripCounts must be not empty");
200 for (ssize_t i = tripCounts.size() - 1; i >= 0; --i) {
201 coords[i] = b.
create<arith::RemSIOp>(index, tripCounts[i]);
202 index = b.
create<arith::DivSIOp>(index, tripCounts[i]);
211 static ParallelComputeFunctionType
218 inputs.reserve(2 + 4 * op.getNumLoops() + captures.size());
223 inputs.push_back(indexTy);
224 inputs.push_back(indexTy);
227 for (
unsigned i = 0; i < op.getNumLoops(); ++i)
228 inputs.push_back(indexTy);
233 for (
unsigned i = 0; i < op.getNumLoops(); ++i) {
234 inputs.push_back(indexTy);
235 inputs.push_back(indexTy);
236 inputs.push_back(indexTy);
240 for (
Value capture : captures)
241 inputs.push_back(capture.getType());
250 scf::ParallelOp op,
const ParallelComputeFunctionBounds &bounds,
255 ModuleOp module = op->getParentOfType<ModuleOp>();
257 ParallelComputeFunctionType computeFuncType =
260 FunctionType type = computeFuncType.type;
261 func::FuncOp func = func::FuncOp::create(
263 numBlockAlignedInnerLoops > 0 ?
"parallel_compute_fn_with_aligned_loops"
264 :
"parallel_compute_fn",
275 b.
createBlock(&func.getBody(), func.begin(), type.getInputs(),
279 ParallelComputeFunctionArgs args = {op.getNumLoops(), func.getArguments()};
291 return llvm::to_vector(
292 llvm::map_range(llvm::zip(args, attrs), [&](
auto tuple) ->
Value {
293 if (IntegerAttr attr = std::get<1>(tuple))
294 return b.
create<arith::ConstantOp>(attr);
295 return std::get<0>(tuple);
300 auto tripCounts = values(args.tripCounts(), bounds.tripCounts);
303 auto lowerBounds = values(args.lowerBounds(), bounds.lowerBounds);
304 auto steps = values(args.steps(), bounds.steps);
311 Value tripCount = tripCounts[0];
312 for (
unsigned i = 1; i < tripCounts.size(); ++i)
313 tripCount = b.
create<arith::MulIOp>(tripCount, tripCounts[i]);
317 Value blockFirstIndex = b.
create<arith::MulIOp>(blockIndex, blockSize);
321 Value blockEnd0 = b.
create<arith::AddIOp>(blockFirstIndex, blockSize);
322 Value blockEnd1 = b.
create<arith::MinSIOp>(blockEnd0, tripCount);
323 Value blockLastIndex = b.
create<arith::SubIOp>(blockEnd1, c1);
326 auto blockFirstCoord =
delinearize(b, blockFirstIndex, tripCounts);
327 auto blockLastCoord =
delinearize(b, blockLastIndex, tripCounts);
335 for (
size_t i = 0; i < blockLastCoord.size(); ++i)
336 blockEndCoord[i] = b.
create<arith::AddIOp>(blockLastCoord[i], c1);
340 using LoopBodyBuilder =
342 using LoopNestBuilder = std::function<LoopBodyBuilder(
size_t loopIdx)>;
373 LoopNestBuilder workLoopBuilder = [&](
size_t loopIdx) -> LoopBodyBuilder {
379 computeBlockInductionVars[loopIdx] = b.
create<arith::AddIOp>(
380 lowerBounds[loopIdx], b.
create<arith::MulIOp>(iv, steps[loopIdx]));
383 isBlockFirstCoord[loopIdx] = b.
create<arith::CmpIOp>(
384 arith::CmpIPredicate::eq, iv, blockFirstCoord[loopIdx]);
385 isBlockLastCoord[loopIdx] = b.
create<arith::CmpIOp>(
386 arith::CmpIPredicate::eq, iv, blockLastCoord[loopIdx]);
390 isBlockFirstCoord[loopIdx] = b.
create<arith::AndIOp>(
391 isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]);
392 isBlockLastCoord[loopIdx] = b.
create<arith::AndIOp>(
393 isBlockLastCoord[loopIdx], isBlockLastCoord[loopIdx - 1]);
397 if (loopIdx < op.getNumLoops() - 1) {
398 if (loopIdx + 1 >= op.getNumLoops() - numBlockAlignedInnerLoops) {
402 workLoopBuilder(loopIdx + 1));
407 auto lb = b.
create<arith::SelectOp>(isBlockFirstCoord[loopIdx],
408 blockFirstCoord[loopIdx + 1], c0);
410 auto ub = b.
create<arith::SelectOp>(isBlockLastCoord[loopIdx],
411 blockEndCoord[loopIdx + 1],
412 tripCounts[loopIdx + 1]);
415 workLoopBuilder(loopIdx + 1));
418 b.
create<scf::YieldOp>(loc);
424 mapping.
map(op.getInductionVars(), computeBlockInductionVars);
425 mapping.
map(computeFuncType.captures, captures);
427 for (
auto &bodyOp : op.getRegion().front().without_terminator())
428 b.
clone(bodyOp, mapping);
429 b.
create<scf::YieldOp>(loc);
433 b.
create<scf::ForOp>(blockFirstCoord[0], blockEndCoord[0], c1,
ValueRange(),
437 return {op.getNumLoops(), func, std::move(computeFuncType.captures)};
463 Location loc = computeFunc.func.getLoc();
466 ModuleOp module = computeFunc.func->getParentOfType<ModuleOp>();
469 computeFunc.func.getFunctionType().getInputs();
478 inputTypes.append(computeFuncInputTypes.begin(), computeFuncInputTypes.end());
481 func::FuncOp func = func::FuncOp::create(loc,
"async_dispatch_fn", type);
490 Block *block = b.
createBlock(&func.getBody(), func.begin(), type.getInputs(),
511 scf::WhileOp whileOp = b.
create<scf::WhileOp>(types, operands);
519 Value start = before->getArgument(0);
520 Value end = before->getArgument(1);
521 Value distance = b.
create<arith::SubIOp>(end, start);
523 b.
create<arith::CmpIOp>(arith::CmpIPredicate::sgt, distance, c1);
524 b.
create<scf::ConditionOp>(dispatch, before->getArguments());
533 Value distance = b.
create<arith::SubIOp>(end, start);
534 Value halfDistance = b.
create<arith::DivSIOp>(distance, c2);
535 Value midIndex = b.
create<arith::AddIOp>(start, halfDistance);
538 auto executeBodyBuilder = [&](
OpBuilder &executeBuilder,
543 operands[1] = midIndex;
546 executeBuilder.
create<func::CallOp>(executeLoc, func.getSymName(),
547 func.getResultTypes(), operands);
554 b.
create<AddToGroupOp>(indexTy, execute.getToken(), group);
563 auto forwardedInputs = block->
getArguments().drop_front(3);
565 computeFuncOperands.append(forwardedInputs.begin(), forwardedInputs.end());
567 b.
create<func::CallOp>(computeFunc.func.getSymName(),
568 computeFunc.func.getResultTypes(),
569 computeFuncOperands);
577 ParallelComputeFunction ¶llelComputeFunction,
578 scf::ParallelOp op,
Value blockSize,
585 func::FuncOp asyncDispatchFunction =
594 operands.append(tripCounts);
595 operands.append(op.getLowerBound().begin(), op.getLowerBound().end());
596 operands.append(op.getUpperBound().begin(), op.getUpperBound().end());
597 operands.append(op.getStep().begin(), op.getStep().end());
598 operands.append(parallelComputeFunction.captures);
604 Value isSingleBlock =
605 b.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, blockCount, c1);
612 appendBlockComputeOperands(operands);
614 b.
create<func::CallOp>(parallelComputeFunction.func.getSymName(),
615 parallelComputeFunction.func.getResultTypes(),
626 Value groupSize = b.
create<arith::SubIOp>(blockCount, c1);
631 appendBlockComputeOperands(operands);
633 b.
create<func::CallOp>(asyncDispatchFunction.getSymName(),
634 asyncDispatchFunction.getResultTypes(), operands);
637 b.
create<AwaitAllOp>(group);
643 b.
create<scf::IfOp>(isSingleBlock, syncDispatch, asyncDispatch);
650 ParallelComputeFunction ¶llelComputeFunction,
651 scf::ParallelOp op,
Value blockSize,
Value blockCount,
655 func::FuncOp compute = parallelComputeFunction.func;
663 Value groupSize = b.
create<arith::SubIOp>(blockCount, c1);
667 using LoopBodyBuilder =
673 computeFuncOperands.append(tripCounts);
674 computeFuncOperands.append(op.getLowerBound().begin(),
675 op.getLowerBound().end());
676 computeFuncOperands.append(op.getUpperBound().begin(),
677 op.getUpperBound().end());
678 computeFuncOperands.append(op.getStep().begin(), op.getStep().end());
679 computeFuncOperands.append(parallelComputeFunction.captures);
680 return computeFuncOperands;
689 auto executeBodyBuilder = [&](
OpBuilder &executeBuilder,
691 executeBuilder.
create<func::CallOp>(executeLoc, compute.getSymName(),
692 compute.getResultTypes(),
693 computeFuncOperands(iv));
708 b.
create<func::CallOp>(compute.getSymName(), compute.getResultTypes(),
709 computeFuncOperands(c0));
712 b.
create<AwaitAllOp>(group);
716 AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
719 if (op.getNumReductions() != 0)
727 Value minTaskSize = computeMinTaskSize(b, op);
736 for (
size_t i = 0; i < op.getNumLoops(); ++i) {
737 auto lb = op.getLowerBound()[i];
738 auto ub = op.getUpperBound()[i];
739 auto step = op.getStep()[i];
740 auto range = b.createOrFold<arith::SubIOp>(ub, lb);
741 tripCounts[i] = b.createOrFold<arith::CeilDivSIOp>(range, step);
746 Value tripCount = tripCounts[0];
747 for (
size_t i = 1; i < tripCounts.size(); ++i)
748 tripCount = b.create<arith::MulIOp>(tripCount, tripCounts[i]);
752 Value c0 = b.create<arith::ConstantIndexOp>(0);
753 Value isZeroIterations =
754 b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, tripCount, c0);
758 nestedBuilder.
create<scf::YieldOp>(loc);
771 ParallelComputeFunctionBounds staticBounds = {
785 static constexpr int64_t maxUnrollableIterations = 512;
789 int numUnrollableLoops = 0;
791 auto getInt = [](IntegerAttr attr) {
return attr ? attr.getInt() : 0; };
794 numIterations.back() = getInt(staticBounds.tripCounts.back());
796 for (
int i = op.getNumLoops() - 2; i >= 0; --i) {
797 int64_t tripCount = getInt(staticBounds.tripCounts[i]);
798 int64_t innerIterations = numIterations[i + 1];
799 numIterations[i] = tripCount * innerIterations;
802 if (innerIterations > 0 && innerIterations <= maxUnrollableIterations)
803 numUnrollableLoops++;
806 Value numWorkerThreadsVal;
807 if (numWorkerThreads >= 0)
808 numWorkerThreadsVal = b.create<arith::ConstantIndexOp>(numWorkerThreads);
810 numWorkerThreadsVal = b.create<async::RuntimeNumWorkerThreadsOp>();
826 {4, 4.0f}, {8, 2.0f}, {16, 1.0f}, {32, 0.8f}, {64, 0.6f}};
827 const float initialOvershardingFactor = 8.0f;
829 Value scalingFactor = b.create<arith::ConstantFloatOp>(
830 llvm::APFloat(initialOvershardingFactor), b.getF32Type());
831 for (
const std::pair<int, float> &p : overshardingBrackets) {
832 Value bracketBegin = b.create<arith::ConstantIndexOp>(p.first);
833 Value inBracket = b.create<arith::CmpIOp>(
834 arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin);
835 Value bracketScalingFactor = b.create<arith::ConstantFloatOp>(
836 llvm::APFloat(p.second), b.getF32Type());
837 scalingFactor = b.create<arith::SelectOp>(inBracket, bracketScalingFactor,
840 Value numWorkersIndex =
841 b.create<arith::IndexCastOp>(b.getI32Type(), numWorkerThreadsVal);
842 Value numWorkersFloat =
843 b.create<arith::SIToFPOp>(b.getF32Type(), numWorkersIndex);
844 Value scaledNumWorkers =
845 b.create<arith::MulFOp>(scalingFactor, numWorkersFloat);
847 b.create<arith::FPToSIOp>(b.getI32Type(), scaledNumWorkers);
848 Value scaledWorkers =
849 b.create<arith::IndexCastOp>(b.getIndexType(), scaledNumInt);
851 Value maxComputeBlocks = b.create<arith::MaxSIOp>(
852 b.create<arith::ConstantIndexOp>(1), scaledWorkers);
858 Value bs0 = b.create<arith::CeilDivSIOp>(tripCount, maxComputeBlocks);
859 Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSize);
860 Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1);
870 Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize);
874 ParallelComputeFunction compute =
878 doDispatch(b, rewriter, compute, op, blockSize, blockCount, tripCounts);
879 b.create<scf::YieldOp>();
885 op, staticBounds, numUnrollableLoops, rewriter);
890 Value numIters = b.create<arith::ConstantIndexOp>(
891 numIterations[op.getNumLoops() - numUnrollableLoops]);
892 Value alignedBlockSize = b.create<arith::MulIOp>(
893 b.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
894 doDispatch(b, rewriter, compute, op, alignedBlockSize, blockCount,
896 b.create<scf::YieldOp>();
902 if (numUnrollableLoops > 0) {
903 Value numIters = b.create<arith::ConstantIndexOp>(
904 numIterations[op.getNumLoops() - numUnrollableLoops]);
905 Value useBlockAlignedComputeFn = b.create<arith::CmpIOp>(
906 arith::CmpIPredicate::sge, blockSize, numIters);
908 b.create<scf::IfOp>(useBlockAlignedComputeFn, dispatchBlockAligned,
910 b.create<scf::YieldOp>();
912 dispatchDefault(b, loc);
917 b.create<scf::IfOp>(isZeroIterations, noOp, dispatch);
925 void AsyncParallelForPass::runOnOperation() {
930 patterns, asyncDispatch, numWorkerThreads,
932 return builder.
create<arith::ConstantIndexOp>(minTaskSize);
939 return std::make_unique<AsyncParallelForPass>();
943 int32_t numWorkerThreads,
944 int32_t minTaskSize) {
945 return std::make_unique<AsyncParallelForPass>(asyncDispatch, numWorkerThreads,
953 patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
static ParallelComputeFunction createParallelComputeFunction(scf::ParallelOp op, const ParallelComputeFunctionBounds &bounds, unsigned numBlockAlignedInnerLoops, PatternRewriter &rewriter)
static func::FuncOp createAsyncDispatchFunction(ParallelComputeFunction &computeFunc, PatternRewriter &rewriter)
static void doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter, ParallelComputeFunction ¶llelComputeFunction, scf::ParallelOp op, Value blockSize, Value blockCount, const SmallVector< Value > &tripCounts)
static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter, ParallelComputeFunction ¶llelComputeFunction, scf::ParallelOp op, Value blockSize, Value blockCount, const SmallVector< Value > &tripCounts)
static ParallelComputeFunctionType getParallelComputeFunctionType(scf::ParallelOp op, PatternRewriter &rewriter)
static SmallVector< IntegerAttr > integerConstants(ValueRange values)
static MLIRContext * getContext(OpFoldResult val)
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
BlockArgListType getArguments()
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
std::function< Value(ImplicitLocOpBuilder, scf::ParallelOp)> AsyncMinTaskSizeComputationFunction
Emit the IR to compute the minimum number of iterations of scf.parallel body that would be viable for...
void populateAsyncParallelForPatterns(RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads, const AsyncMinTaskSizeComputationFunction &computeMinTaskSize)
Add a pattern to the given pattern list to lower scf.parallel to async operations.
void cloneConstantsIntoTheRegion(Region ®ion)
Clone ConstantLike operations that are defined above the given region and have users in the region in...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
std::unique_ptr< Pass > createAsyncParallelForPass()
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
const FrozenRewritePatternSet & patterns
void getUsedValuesDefinedAbove(Region ®ion, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...