27#include "llvm/ADT/SmallVectorExtras.h"
31#define GEN_PASS_DEF_ASYNCPARALLELFORPASS
32#include "mlir/Dialect/Async/Passes.h.inc"
38#define DEBUG_TYPE "async-parallel-for"
struct AsyncParallelForPass
    : public impl::AsyncParallelForPassBase<AsyncParallelForPass> {
  void runOnOperation() override;
};
  AsyncParallelForRewrite(MLIRContext *ctx, bool asyncDispatch,
                          int32_t numWorkerThreads,
                          AsyncMinTaskSizeComputationFunction computeMinTaskSize)
      : OpRewritePattern(ctx), asyncDispatch(asyncDispatch),
        numWorkerThreads(numWorkerThreads),
        computeMinTaskSize(std::move(computeMinTaskSize)) {}

  LogicalResult matchAndRewrite(scf::ParallelOp op,
                                PatternRewriter &rewriter) const override;
struct ParallelComputeFunctionType {
  FunctionType type;
  SmallVector<Value> captures;
};

// Helper struct to parse the packed parallel compute function argument list.
struct ParallelComputeFunctionArgs {
  BlockArgument blockIndex();
  BlockArgument blockSize();
  ArrayRef<BlockArgument> tripCounts();
  ArrayRef<BlockArgument> lowerBounds();
  ArrayRef<BlockArgument> steps();
  ArrayRef<BlockArgument> captures();

  unsigned numLoops;
  ArrayRef<BlockArgument> args;
};

// Statically known bounds of the parallel loop nest (attributes are null for
// values that are not compile-time constants).
struct ParallelComputeFunctionBounds {
  SmallVector<IntegerAttr> tripCounts;
  SmallVector<IntegerAttr> lowerBounds;
  SmallVector<IntegerAttr> upperBounds;
  SmallVector<IntegerAttr> steps;
};

struct ParallelComputeFunction {
  unsigned numLoops;
  func::FuncOp func;
  SmallVector<Value> captures;
};
BlockArgument ParallelComputeFunctionArgs::blockIndex() { return args[0]; }
BlockArgument ParallelComputeFunctionArgs::blockSize() { return args[1]; }

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::tripCounts() {
  return args.drop_front(2).take_front(numLoops);
}

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::lowerBounds() {
  return args.drop_front(2 + 1 * numLoops).take_front(numLoops);
}

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::steps() {
  return args.drop_front(2 + 3 * numLoops).take_front(numLoops);
}

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::captures() {
  return args.drop_front(2 + 4 * numLoops);
}
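// The argument list is packed as:
//   (blockIndex, blockSize,
//    tripCounts[numLoops], lowerBounds[numLoops], upperBounds[numLoops],
//    steps[numLoops], captures...)
// hence the drop_front offsets above. Upper bounds occupy the
// 2 + 2 * numLoops slot but have no accessor: the compute function body only
// needs trip counts.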
template <typename ValueRange>
static SmallVector<IntegerAttr> integerConstants(ValueRange values) {
  SmallVector<IntegerAttr> attrs(values.size());
  for (unsigned i = 0; i < values.size(); ++i)
    matchPattern(values[i], m_Constant(&attrs[i]));
  return attrs;
}
191 assert(!tripCounts.empty() &&
"tripCounts must be not empty");
193 for (ssize_t i = tripCounts.size() - 1; i >= 0; --i) {
194 coords[i] = arith::RemSIOp::create(
b,
index, tripCounts[i]);
195 index = arith::DivSIOp::create(
b,
index, tripCounts[i]);
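// For example (illustration only): with tripCounts = [2, 3, 4], a linear
// index of 10 delinearizes back-to-front as
//   coords[2] = 10 % 4 = 2, index = 10 / 4 = 2
//   coords[1] = 2 % 3 = 2,  index = 2 / 3 = 0
//   coords[0] = 0 % 2 = 0
// i.e. coords = [0, 2, 2], and indeed 0*12 + 2*4 + 2 == 10.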
static ParallelComputeFunctionType
getParallelComputeFunctionType(scf::ParallelOp op, PatternRewriter &rewriter) {
  // Values implicitly captured by the parallel operation.
  llvm::SetVector<Value> captures;
  getUsedValuesDefinedAbove(op.getRegion(), op.getRegion(), captures);

  SmallVector<Type> inputs;
  inputs.reserve(2 + 4 * op.getNumLoops() + captures.size());

  Type indexTy = rewriter.getIndexType();

  // One-dimensional iteration space defined by the block index and size.
  inputs.push_back(indexTy); // blockIndex
  inputs.push_back(indexTy); // blockSize

  // Multi-dimensional parallel iteration space defined by the loop trip
  // counts.
  for (unsigned i = 0; i < op.getNumLoops(); ++i)
    inputs.push_back(indexTy); // loop tripCount

  // Parallel operation lower bound, upper bound and step for each loop.
  for (unsigned i = 0; i < op.getNumLoops(); ++i) {
    inputs.push_back(indexTy); // lower bound
    inputs.push_back(indexTy); // upper bound
    inputs.push_back(indexTy); // step
  }

  // Types of the implicitly captured values.
  for (Value capture : captures)
    inputs.push_back(capture.getType());

  // Convert captures to a vector for later convenience.
  SmallVector<Value> capturesVector(captures.begin(), captures.end());
  return {rewriter.getFunctionType(inputs, TypeRange()), capturesVector};
}
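// For a two-dimensional scf.parallel with one captured value the resulting
// compute function type is (illustration only):
//   (index blockIndex, index blockSize,
//    index tripCount0, index tripCount1,
//    index lb0, index lb1, index ub0, index ub1, index step0, index step1,
//    <capture type>) -> ()
// where lower bounds, upper bounds, and steps are passed as contiguous
// argument groups.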
static ParallelComputeFunction createParallelComputeFunction(
    scf::ParallelOp op, const ParallelComputeFunctionBounds &bounds,
    unsigned numBlockAlignedInnerLoops, PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard guard(rewriter);
  ImplicitLocOpBuilder b(op.getLoc(), rewriter);

  ModuleOp module = op->getParentOfType<ModuleOp>();

  ParallelComputeFunctionType computeFuncType =
      getParallelComputeFunctionType(op, rewriter);

  FunctionType type = computeFuncType.type;
  func::FuncOp func = func::FuncOp::create(
      op.getLoc(),
      numBlockAlignedInnerLoops > 0 ? "parallel_compute_fn_with_aligned_loops"
                                    : "parallel_compute_fn",
      type);
  func.setPrivate();

  // Insert the function into the module symbol table and rename it if needed
  // to keep the symbol unique.
  SymbolTable symbolTable(module);
  symbolTable.insert(func);
  rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});

  // Create the function entry block.
  Block *block =
      b.createBlock(&func.getBody(), func.begin(), type.getInputs(),
                    SmallVector<Location>(type.getNumInputs(), op.getLoc()));
  b.setInsertionPointToEnd(block);

  ParallelComputeFunctionArgs args = {op.getNumLoops(), func.getArguments()};
  // Block iteration position defined by the block index and size.
  BlockArgument blockIndex = args.blockIndex();
  BlockArgument blockSize = args.blockSize();

  // Constants used below.
  Value c0 = arith::ConstantIndexOp::create(b, 0);
  Value c1 = arith::ConstantIndexOp::create(b, 1);

  // Materialize statically known loop bounds as constants in the function
  // body; fall back to the block arguments otherwise.
  auto values = [&](ArrayRef<BlockArgument> args, ArrayRef<IntegerAttr> attrs) {
    return llvm::map_to_vector(llvm::zip(args, attrs),
                               [&](auto tuple) -> Value {
                                 if (IntegerAttr attr = std::get<1>(tuple))
                                   return arith::ConstantOp::create(b, attr);
                                 return std::get<0>(tuple);
                               });
  };

  auto tripCounts = values(args.tripCounts(), bounds.tripCounts);
  auto lowerBounds = values(args.lowerBounds(), bounds.lowerBounds);
  auto steps = values(args.steps(), bounds.steps);
  ArrayRef<BlockArgument> captures = args.captures();
  // Compute a product of trip counts to get the size of the flattened
  // one-dimensional iteration space.
  Value tripCount = tripCounts[0];
  for (unsigned i = 1; i < tripCounts.size(); ++i)
    tripCount = arith::MulIOp::create(b, tripCount, tripCounts[i]);
  // The first linear index in the block:
  //   blockFirstIndex = blockIndex * blockSize
  Value blockFirstIndex = arith::MulIOp::create(b, blockIndex, blockSize);

  // The last linear index in the block (inclusive), clamped to the trip count:
  //   blockLastIndex = min(blockFirstIndex + blockSize, tripCount) - 1
  Value blockEnd0 = arith::AddIOp::create(b, blockFirstIndex, blockSize);
  Value blockEnd1 = arith::MinSIOp::create(b, blockEnd0, tripCount);
  Value blockLastIndex = arith::SubIOp::create(b, blockEnd1, c1);

  // Convert linear indices to multi-dimensional coordinates.
  auto blockFirstCoord = delinearize(b, blockFirstIndex, tripCounts);
  auto blockLastCoord = delinearize(b, blockLastIndex, tripCounts);

  // Compute loop upper bounds derived from the block last coordinates:
  //   blockEndCoord[i] = blockLastCoord[i] + 1
  SmallVector<Value> blockEndCoord(op.getNumLoops());
  for (size_t i = 0; i < blockLastCoord.size(); ++i)
    blockEndCoord[i] = arith::AddIOp::create(b, blockLastCoord[i], c1);
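  // Worked example (illustration only): with tripCounts = [8, 8], a block
  // covering linear indices [2, 18] delinearizes to blockFirstCoord = [0, 2]
  // and blockLastCoord = [2, 2], so blockEndCoord = [3, 3]. The loop nest
  // built below therefore has to trim the inner loop bounds only on the first
  // and last iterations of the outer loop.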
  using LoopBodyBuilder =
      std::function<void(OpBuilder &, Location, Value, ValueRange)>;
  using LoopNestBuilder = std::function<LoopBodyBuilder(size_t loopIdx)>;

  // Induction variables and first/last iteration flags for each loop level.
  SmallVector<Value> computeBlockInductionVars(op.getNumLoops());
  SmallVector<Value> isBlockFirstCoord(op.getNumLoops());
  SmallVector<Value> isBlockLastCoord(op.getNumLoops());
366 LoopNestBuilder workLoopBuilder = [&](
size_t loopIdx) -> LoopBodyBuilder {
372 computeBlockInductionVars[loopIdx] =
373 arith::AddIOp::create(
b, lowerBounds[loopIdx],
374 arith::MulIOp::create(
b, iv, steps[loopIdx]));
377 isBlockFirstCoord[loopIdx] = arith::CmpIOp::create(
378 b, arith::CmpIPredicate::eq, iv, blockFirstCoord[loopIdx]);
379 isBlockLastCoord[loopIdx] = arith::CmpIOp::create(
380 b, arith::CmpIPredicate::eq, iv, blockLastCoord[loopIdx]);
384 isBlockFirstCoord[loopIdx] = arith::AndIOp::create(
385 b, isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]);
386 isBlockLastCoord[loopIdx] = arith::AndIOp::create(
387 b, isBlockLastCoord[loopIdx], isBlockLastCoord[loopIdx - 1]);
391 if (loopIdx < op.getNumLoops() - 1) {
392 if (loopIdx + 1 >= op.getNumLoops() - numBlockAlignedInnerLoops) {
395 scf::ForOp::create(
b, c0, tripCounts[loopIdx + 1], c1,
ValueRange(),
396 workLoopBuilder(loopIdx + 1));
401 auto lb = arith::SelectOp::create(
b, isBlockFirstCoord[loopIdx],
402 blockFirstCoord[loopIdx + 1], c0);
404 auto ub = arith::SelectOp::create(
b, isBlockLastCoord[loopIdx],
405 blockEndCoord[loopIdx + 1],
406 tripCounts[loopIdx + 1]);
409 workLoopBuilder(loopIdx + 1));
412 scf::YieldOp::create(
b, loc);
418 mapping.
map(op.getInductionVars(), computeBlockInductionVars);
419 mapping.
map(computeFuncType.captures, captures);
421 for (
auto &bodyOp : op.getRegion().front().without_terminator())
422 b.clone(bodyOp, mapping);
423 scf::YieldOp::create(
b, loc);
  scf::ForOp::create(b, blockFirstCoord[0], blockEndCoord[0], c1, ValueRange(),
                     workLoopBuilder(0));
  func::ReturnOp::create(b, ValueRange());

  return {op.getNumLoops(), func, std::move(computeFuncType.captures)};
}
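// The generated parallel_compute_fn thus walks only the slice of the
// iteration space assigned to its block: the outermost scf.for runs over
// [blockFirstCoord[0], blockEndCoord[0]), and each inner loop tightens its
// bounds only on the boundary iterations identified by the
// isBlockFirstCoord/isBlockLastCoord flags.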
// Creates a recursive async dispatch function that splits a block range in
// half and asynchronously dispatches the second half, until the range is down
// to a single block.
static func::FuncOp
createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
                            PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard guard(rewriter);
  Location loc = computeFunc.func.getLoc();
  ImplicitLocOpBuilder b(loc, rewriter);

  ModuleOp module = computeFunc.func->getParentOfType<ModuleOp>();

  ArrayRef<Type> computeFuncInputTypes =
      computeFunc.func.getFunctionType().getInputs();

  // Compared to the parallel compute function, the async dispatch function
  // takes an additional !async.group argument, and instead of a single
  // `blockIndex` it takes a [blockStart, blockEnd) range of blocks (the
  // compute function's leading blockIndex slot is reused as blockEnd).
  SmallVector<Type> inputTypes;
  inputTypes.push_back(async::GroupType::get(rewriter.getContext()));
  inputTypes.push_back(rewriter.getIndexType()); // blockStart
  inputTypes.append(computeFuncInputTypes.begin(), computeFuncInputTypes.end());

  FunctionType type = rewriter.getFunctionType(inputTypes, TypeRange());
  func::FuncOp func = func::FuncOp::create(loc, "async_dispatch_fn", type);
  func.setPrivate();

  // Insert the function into the module symbol table and rename it if needed
  // to keep the symbol unique.
  SymbolTable symbolTable(module);
  symbolTable.insert(func);
  rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});

  // Create the function entry block.
  Block *block =
      b.createBlock(&func.getBody(), func.begin(), type.getInputs(),
                    SmallVector<Location>(type.getNumInputs(), loc));
  b.setInsertionPointToEnd(block);

  Type indexTy = b.getIndexType();
  Value c1 = arith::ConstantIndexOp::create(b, 1);
  Value c2 = arith::ConstantIndexOp::create(b, 2);

  // Get the async group that tracks async dispatch completion.
  Value group = block->getArgument(0);

  // Get the block iteration range: [blockStart, blockEnd).
  Value blockStart = block->getArgument(1);
  Value blockEnd = block->getArgument(2);

  // Create a work-splitting while loop for the [blockStart, blockEnd) range.
  SmallVector<Type> types = {indexTy, indexTy};
  SmallVector<Value> operands = {blockStart, blockEnd};
  SmallVector<Location> locations = {loc, loc};

  scf::WhileOp whileOp = scf::WhileOp::create(b, types, operands);
  Block *before = b.createBlock(&whileOp.getBefore(), {}, types, locations);
  Block *after = b.createBlock(&whileOp.getAfter(), {}, types, locations);

  // Setup the dispatch loop condition block: decide if we need to go into the
  // `after` block and launch one more async dispatch.
  {
    b.setInsertionPointToEnd(before);
    Value start = before->getArgument(0);
    Value end = before->getArgument(1);
    Value distance = arith::SubIOp::create(b, end, start);
    Value dispatch =
        arith::CmpIOp::create(b, arith::CmpIPredicate::sgt, distance, c1);
    scf::ConditionOp::create(b, dispatch, before->getArguments());
  }
  // Setup the dispatch loop body: recursively call the dispatch function for
  // the second half of the current range and continue with the first half.
  {
    b.setInsertionPointToEnd(after);
    Value start = after->getArgument(0);
    Value end = after->getArgument(1);
    Value distance = arith::SubIOp::create(b, end, start);
    Value halfDistance = arith::DivSIOp::create(b, distance, c2);
    Value midIndex = arith::AddIOp::create(b, start, halfDistance);

    // Call the dispatch function for the second half of the range inside an
    // async.execute region.
    auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
                                  Location executeLoc, ValueRange executeArgs) {
      // Update the original `blockStart` and `blockEnd` with the new range.
      SmallVector<Value> operands{block->getArguments().begin(),
                                  block->getArguments().end()};
      operands[1] = midIndex;
      operands[2] = end;

      func::CallOp::create(executeBuilder, executeLoc, func.getSymName(),
                           func.getResultTypes(), operands);
      async::YieldOp::create(executeBuilder, executeLoc, ValueRange());
    };

    // Create an async.execute operation to dispatch half of the block range.
    auto execute = ExecuteOp::create(b, TypeRange(), ValueRange(), ValueRange(),
                                     executeBodyBuilder);
    AddToGroupOp::create(b, indexTy, execute.getToken(), group);
    scf::YieldOp::create(b, ValueRange({start, midIndex}));
  }
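  // The splitting loop keeps the first half of the range and dispatches the
  // second half asynchronously: for a range [0, 8) the caller thread
  // dispatches [4, 8), then [2, 4), then [1, 2), and finally falls through to
  // compute block 0 itself. Each dispatched range splits itself the same way
  // on a worker thread, fanning out with logarithmic depth.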
  // After dispatching async operations for the tail of the block range, call
  // the parallel compute function for the first block in the range.
  b.setInsertionPointAfter(whileOp);

  // Drop the async-dispatch-specific arguments: async group, block start and
  // block end.
  auto forwardedInputs = block->getArguments().drop_front(3);
  SmallVector<Value> computeFuncOperands = {blockStart};
  computeFuncOperands.append(forwardedInputs.begin(), forwardedInputs.end());

  func::CallOp::create(b, computeFunc.func.getSymName(),
                       computeFunc.func.getResultTypes(), computeFuncOperands);
  func::ReturnOp::create(b, ValueRange());

  return func;
}
// Launch the parallel compute function by dispatching the block range through
// the recursive async dispatch function.
static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
                            ParallelComputeFunction &parallelComputeFunction,
                            scf::ParallelOp op, Value blockSize,
                            Value blockCount,
                            const SmallVector<Value> &tripCounts) {
  MLIRContext *ctx = op->getContext();

  // Add one more level of indirection to dispatch parallel compute functions
  // asynchronously.
  func::FuncOp asyncDispatchFunction =
      createAsyncDispatchFunction(parallelComputeFunction, rewriter);

  Value c0 = arith::ConstantIndexOp::create(b, 0);
  Value c1 = arith::ConstantIndexOp::create(b, 1);

  // Appends the operands shared by the async dispatch and parallel compute
  // functions to the given operands vector.
  auto appendBlockComputeOperands = [&](SmallVector<Value> &operands) {
    operands.append(tripCounts);
    operands.append(op.getLowerBound().begin(), op.getLowerBound().end());
    operands.append(op.getUpperBound().begin(), op.getUpperBound().end());
    operands.append(op.getStep().begin(), op.getStep().end());
    operands.append(parallelComputeFunction.captures);
  };

  // If there is a single block, skip the async dispatch completely.
  Value isSingleBlock =
      arith::CmpIOp::create(b, arith::CmpIPredicate::eq, blockCount, c1);

  auto syncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
    ImplicitLocOpBuilder b(loc, nestedBuilder);

    // Call the parallel compute function for the single block.
    SmallVector<Value> operands = {c0, blockSize};
    appendBlockComputeOperands(operands);

    func::CallOp::create(b, parallelComputeFunction.func.getSymName(),
                         parallelComputeFunction.func.getResultTypes(),
                         operands);
    scf::YieldOp::create(b);
  };

  auto asyncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
    ImplicitLocOpBuilder b(loc, nestedBuilder);

    // Create an async.group to wait on all async tokens from the concurrent
    // execution of multiple parallel compute functions. The first block is
    // executed synchronously in the caller thread.
    Value groupSize = arith::SubIOp::create(b, blockCount, c1);
    Value group = CreateGroupOp::create(b, GroupType::get(ctx), groupSize);

    // Launch the async dispatch function for the [0, blockCount) range.
    SmallVector<Value> operands = {group, c0, blockCount, blockSize};
    appendBlockComputeOperands(operands);

    func::CallOp::create(b, asyncDispatchFunction.getSymName(),
                         asyncDispatchFunction.getResultTypes(), operands);

    // Wait for the completion of all parallel compute operations.
    AwaitAllOp::create(b, group);

    scf::YieldOp::create(b);
  };

  // Dispatch either the single-block compute function, or launch the async
  // dispatch.
  scf::IfOp::create(b, isSingleBlock, syncDispatch, asyncDispatch);
}
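// Note that the async group above is sized to blockCount - 1: block zero is
// always executed synchronously in the caller thread and never contributes a
// token to the group. The sequential dispatch below follows the same
// convention.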
// Dispatch parallel compute functions by submitting all async compute tasks
// from a simple for loop in the caller thread.
static void
doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
                     ParallelComputeFunction &parallelComputeFunction,
                     scf::ParallelOp op, Value blockSize, Value blockCount,
                     const SmallVector<Value> &tripCounts) {
  MLIRContext *ctx = op->getContext();

  func::FuncOp compute = parallelComputeFunction.func;

  Value c0 = arith::ConstantIndexOp::create(b, 0);
  Value c1 = arith::ConstantIndexOp::create(b, 1);

  // Create an async.group to wait on all async tokens from the concurrent
  // execution of multiple parallel compute functions. The first block is
  // executed synchronously in the caller thread.
  Value groupSize = arith::SubIOp::create(b, blockCount, c1);
  Value group = CreateGroupOp::create(b, GroupType::get(ctx), groupSize);

  using LoopBodyBuilder =
      std::function<void(OpBuilder &, Location, Value, ValueRange)>;

  // Returns the parallel compute function operands for the given block.
  auto computeFuncOperands = [&](Value blockIndex) -> SmallVector<Value> {
    SmallVector<Value> computeFuncOperands = {blockIndex, blockSize};
    computeFuncOperands.append(tripCounts);
    computeFuncOperands.append(op.getLowerBound().begin(),
                               op.getLowerBound().end());
    computeFuncOperands.append(op.getUpperBound().begin(),
                               op.getUpperBound().end());
    computeFuncOperands.append(op.getStep().begin(), op.getStep().end());
    computeFuncOperands.append(parallelComputeFunction.captures);
    return computeFuncOperands;
  };

  // Induction variable is the index of the block: [0, blockCount).
  LoopBodyBuilder loopBuilder = [&](OpBuilder &nestedBuilder, Location loc,
                                    Value iv, ValueRange args) {
    ImplicitLocOpBuilder b(loc, nestedBuilder);

    // Call the parallel compute function inside the async.execute region.
    auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
                                  Location executeLoc, ValueRange executeArgs) {
      func::CallOp::create(executeBuilder, executeLoc, compute.getSymName(),
                           compute.getResultTypes(), computeFuncOperands(iv));
      async::YieldOp::create(executeBuilder, executeLoc, ValueRange());
    };

    // Create an async.execute operation to launch the parallel compute
    // function.
    auto execute = ExecuteOp::create(b, TypeRange(), ValueRange(), ValueRange(),
                                     executeBodyBuilder);
    AddToGroupOp::create(b, rewriter.getIndexType(), execute.getToken(), group);
    scf::YieldOp::create(b);
  };

  // Iterate over all compute blocks and launch parallel compute operations.
  scf::ForOp::create(b, c1, blockCount, c1, ValueRange(), loopBuilder);

  // Call the parallel compute function for the first block in the caller
  // thread.
  func::CallOp::create(b, compute.getSymName(), compute.getResultTypes(),
                       computeFuncOperands(c0));

  // Wait for the completion of all async compute operations.
  AwaitAllOp::create(b, group);
}
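// Unlike doAsyncDispatch, this strategy submits every async.execute task from
// the caller thread; it trades dispatch parallelism for simpler IR and is
// used when the asyncDispatch option is disabled.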
LogicalResult
AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
                                         PatternRewriter &rewriter) const {
  // We do not currently support the rewrite of parallel ops with reductions.
  if (op.getNumReductions() != 0)
    return failure();

  ImplicitLocOpBuilder b(op.getLoc(), rewriter);

  // Dispatch blocks asynchronously with recursive work splitting, or submit
  // compute tasks sequentially from the caller thread.
  auto doDispatch = asyncDispatch ? doAsyncDispatch : doSequentialDispatch;

  // Compute the minimum task size with the user-provided callback.
  Value minTaskSize = computeMinTaskSize(b, op);
727 SmallVector<Value> tripCounts(op.getNumLoops());
728 for (
size_t i = 0; i < op.getNumLoops(); ++i) {
729 auto lb = op.getLowerBound()[i];
730 auto ub = op.getUpperBound()[i];
731 auto step = op.getStep()[i];
732 auto range =
b.createOrFold<arith::SubIOp>(ub, lb);
733 tripCounts[i] =
b.createOrFold<arith::CeilDivSIOp>(range, step);
  // Compute a product of trip counts to get the size of the flattened
  // one-dimensional iteration space.
  Value tripCount = tripCounts[0];
  for (size_t i = 1; i < tripCounts.size(); ++i)
    tripCount = arith::MulIOp::create(b, tripCount, tripCounts[i]);

  // Short-circuit parallel loops with zero iterations.
  Value c0 = arith::ConstantIndexOp::create(b, 0);
  Value isZeroIterations =
      arith::CmpIOp::create(b, arith::CmpIPredicate::eq, tripCount, c0);
749 auto noOp = [&](OpBuilder &nestedBuilder, Location loc) {
750 scf::YieldOp::create(nestedBuilder, loc);
755 auto dispatch = [&](OpBuilder &nestedBuilder, Location loc) {
756 ImplicitLocOpBuilder
b(loc, nestedBuilder);
    // Collect statically known constants defining the loop nest.
    ParallelComputeFunctionBounds staticBounds = {
        integerConstants(tripCounts), integerConstants(op.getLowerBound()),
        integerConstants(op.getUpperBound()), integerConstants(op.getStep())};

    // Align the parallel compute block size to the product of statically
    // known inner dimensions, so the inner loops run from 0 to their full
    // trip counts with no dynamic boundaries, giving LLVM an opportunity to
    // unroll them. The constant is arbitrary; it should track how many
    // iterations LLVM typically decides to unroll.
    static constexpr int64_t maxUnrollableIterations = 512;

    // The number of inner loops whose statically known iteration product does
    // not exceed `maxUnrollableIterations`.
    int numUnrollableLoops = 0;

    auto getInt = [](IntegerAttr attr) { return attr ? attr.getInt() : 0; };
785 SmallVector<int64_t> numIterations(op.getNumLoops());
786 numIterations.back() = getInt(staticBounds.tripCounts.back());
788 for (
int i = op.getNumLoops() - 2; i >= 0; --i) {
789 int64_t tripCount = getInt(staticBounds.tripCounts[i]);
790 int64_t innerIterations = numIterations[i + 1];
791 numIterations[i] = tripCount * innerIterations;
794 if (innerIterations > 0 && innerIterations <= maxUnrollableIterations)
795 numUnrollableLoops++;
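    // For example (illustration only): a 3-d loop with static trip counts
    // [128, 16, 8] gives numIterations = [16384, 128, 8]. The two innermost
    // levels have iteration products 8 and 128 (both <= 512), so
    // numUnrollableLoops == 2 and block sizes are aligned to 128 iterations.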
    Value numWorkerThreadsVal;
    if (numWorkerThreads >= 0)
      numWorkerThreadsVal = arith::ConstantIndexOp::create(b, numWorkerThreads);
    else
      numWorkerThreadsVal = async::RuntimeNumWorkerThreadsOp::create(b);
817 const SmallVector<std::pair<int, float>> overshardingBrackets = {
818 {4, 4.0f}, {8, 2.0f}, {16, 1.0f}, {32, 0.8f}, {64, 0.6f}};
819 const float initialOvershardingFactor = 8.0f;
822 b,
b.getF32Type(), llvm::APFloat(initialOvershardingFactor));
    for (const std::pair<int, float> &p : overshardingBrackets) {
      Value bracketBegin = arith::ConstantIndexOp::create(b, p.first);
      Value inBracket = arith::CmpIOp::create(
          b, arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin);
      Value bracketScalingFactor = arith::ConstantFloatOp::create(
          b, b.getF32Type(), llvm::APFloat(p.second));
      scalingFactor = arith::SelectOp::create(b, inBracket,
                                              bracketScalingFactor,
                                              scalingFactor);
    }
    Value numWorkersIndex =
        arith::IndexCastOp::create(b, b.getI32Type(), numWorkerThreadsVal);
    Value numWorkersFloat =
        arith::SIToFPOp::create(b, b.getF32Type(), numWorkersIndex);
    Value scaledNumWorkers =
        arith::MulFOp::create(b, scalingFactor, numWorkersFloat);
    Value scaledNumInt =
        arith::FPToSIOp::create(b, b.getI32Type(), scaledNumWorkers);
    Value scaledWorkers =
        arith::IndexCastOp::create(b, b.getIndexType(), scaledNumInt);

    Value maxComputeBlocks = arith::MaxSIOp::create(
        b, arith::ConstantIndexOp::create(b, 1), scaledWorkers);

    // Compute the parallel block size from the parallel problem size:
    //   blockSize = min(tripCount,
    //                   max(ceil_div(tripCount, maxComputeBlocks),
    //                       minTaskSize))
    Value bs0 = arith::CeilDivSIOp::create(b, tripCount, maxComputeBlocks);
    Value bs1 = arith::MaxSIOp::create(b, bs0, minTaskSize);
    Value blockSize = arith::MinSIOp::create(b, tripCount, bs1);

    // Compute the number of parallel compute blocks.
    Value blockCount = arith::CeilDivSIOp::create(b, tripCount, blockSize);
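    // Worked example (illustration only): with 16 worker threads the bracket
    // scan leaves scalingFactor == 2.0 (16 > 4 and 16 > 8, but not > 16), so
    // maxComputeBlocks == 32; for tripCount == 10000 and minTaskSize == 1
    // this yields blockSize == ceil_div(10000, 32) == 313 and
    // blockCount == ceil_div(10000, 313) == 32.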
865 auto dispatchDefault = [&](OpBuilder &nestedBuilder, Location loc) {
866 ParallelComputeFunction compute =
869 ImplicitLocOpBuilder
b(loc, nestedBuilder);
870 doDispatch(
b, rewriter, compute, op, blockSize, blockCount, tripCounts);
871 scf::YieldOp::create(
b);
875 auto dispatchBlockAligned = [&](OpBuilder &nestedBuilder, Location loc) {
877 op, staticBounds, numUnrollableLoops, rewriter);
879 ImplicitLocOpBuilder
b(loc, nestedBuilder);
883 b, numIterations[op.getNumLoops() - numUnrollableLoops]);
884 Value alignedBlockSize = arith::MulIOp::create(
885 b, arith::CeilDivSIOp::create(
b, blockSize, numIters), numIters);
886 doDispatch(
b, rewriter, compute, op, alignedBlockSize, blockCount,
888 scf::YieldOp::create(
b);
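    // E.g. (illustration only) with blockSize == 313 and numIters == 128 the
    // aligned block size is ceil_div(313, 128) * 128 == 384, so every block
    // except possibly the last covers whole inner-loop iteration ranges.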
894 if (numUnrollableLoops > 0) {
896 b, numIterations[op.getNumLoops() - numUnrollableLoops]);
897 Value useBlockAlignedComputeFn = arith::CmpIOp::create(
898 b, arith::CmpIPredicate::sge, blockSize, numIters);
900 scf::IfOp::create(
b, useBlockAlignedComputeFn, dispatchBlockAligned,
902 scf::YieldOp::create(
b);
904 dispatchDefault(
b, loc);
909 scf::IfOp::create(
b, isZeroIterations, noOp, dispatch);
void AsyncParallelForPass::runOnOperation() {
  MLIRContext *ctx = &getContext();

  RewritePatternSet patterns(ctx);
  populateAsyncParallelForPatterns(
      patterns, asyncDispatch, numWorkerThreads,
      [&](ImplicitLocOpBuilder builder, scf::ParallelOp op) {
        return arith::ConstantIndexOp::create(builder, minTaskSize);
      });

  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}

void mlir::async::populateAsyncParallelForPatterns(
    RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads,
    const AsyncMinTaskSizeComputationFunction &computeMinTaskSize) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
                                        computeMinTaskSize);
}
static SmallVector< IntegerAttr > integerConstants(ValueRange values)
static ParallelComputeFunction createParallelComputeFunction(scf::ParallelOp op, const ParallelComputeFunctionBounds &bounds, unsigned numBlockAlignedInnerLoops, PatternRewriter &rewriter)
static func::FuncOp createAsyncDispatchFunction(ParallelComputeFunction &computeFunc, PatternRewriter &rewriter)
static void doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter, ParallelComputeFunction ¶llelComputeFunction, scf::ParallelOp op, Value blockSize, Value blockCount, const SmallVector< Value > &tripCounts)
static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter, ParallelComputeFunction ¶llelComputeFunction, scf::ParallelOp op, Value blockSize, Value blockCount, const SmallVector< Value > &tripCounts)
static ParallelComputeFunctionType getParallelComputeFunctionType(scf::ParallelOp op, PatternRewriter &rewriter)
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
BlockArgListType getArguments()
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
::mlir::Pass::Option< int32_t > numWorkerThreads
::mlir::Pass::Option< bool > asyncDispatch
void populateAsyncParallelForPatterns(RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads, const AsyncMinTaskSizeComputationFunction &computeMinTaskSize)
Add a pattern to the given pattern list to lower scf.parallel to async operations.
void cloneConstantsIntoTheRegion(Region ®ion)
Clone ConstantLike operations that are defined above the given region and have users in the region in...
std::function< Value(ImplicitLocOpBuilder, scf::ParallelOp)> AsyncMinTaskSizeComputationFunction
Emit the IR to compute the minimum number of iterations of scf.parallel body that would be viable for...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
const FrozenRewritePatternSet & patterns
void getUsedValuesDefinedAbove(Region ®ion, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...