#define GEN_PASS_DEF_ASYNCPARALLELFOR
#include "mlir/Dialect/Async/Passes.h.inc"

#define DEBUG_TYPE "async-parallel-for"

using namespace mlir;
using namespace mlir::async;
struct AsyncParallelForPass
    : public impl::AsyncParallelForBase<AsyncParallelForPass> {
  AsyncParallelForPass() = default;

  AsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads,
                       int32_t minTaskSize) {
    this->asyncDispatch = asyncDispatch;
    this->numWorkerThreads = numWorkerThreads;
    this->minTaskSize = minTaskSize;
  }

  void runOnOperation() override;
};
struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
  AsyncParallelForRewrite(
      MLIRContext *ctx, bool asyncDispatch, int32_t numWorkerThreads,
      AsyncMinTaskSizeComputationFunction computeMinTaskSize)
      : OpRewritePattern(ctx), asyncDispatch(asyncDispatch),
        numWorkerThreads(numWorkerThreads),
        computeMinTaskSize(std::move(computeMinTaskSize)) {}

  LogicalResult matchAndRewrite(scf::ParallelOp op,
                                PatternRewriter &rewriter) const override;

  bool asyncDispatch;
  int32_t numWorkerThreads;
  AsyncMinTaskSizeComputationFunction computeMinTaskSize;
};
// Parallel compute function type together with the values captured from above
// the scf.parallel operation.
struct ParallelComputeFunctionType {
  FunctionType type;
  llvm::SmallVector<Value> captures;
};

// Helper struct to parse the parallel compute function argument list.
struct ParallelComputeFunctionArgs {
  unsigned numLoops;
  ArrayRef<BlockArgument> args;
  // Accessors defined below: blockIndex(), blockSize(), tripCounts(), ...
};

// Subset of the scf.parallel bounds that are known to be constants (a null
// attribute means the corresponding value is not a constant).
struct ParallelComputeFunctionBounds {
  SmallVector<IntegerAttr> tripCounts, lowerBounds, upperBounds, steps;
};

// The outlined parallel compute function together with its captured values.
struct ParallelComputeFunction {
  unsigned numLoops;
  func::FuncOp func;
  llvm::SmallVector<Value> captures;
};
BlockArgument ParallelComputeFunctionArgs::blockIndex() { return args[0]; }
BlockArgument ParallelComputeFunctionArgs::blockSize() { return args[1]; }

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::tripCounts() {
  return args.drop_front(2).take_front(numLoops);
}

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::lowerBounds() {
  return args.drop_front(2 + 1 * numLoops).take_front(numLoops);
}

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::upperBounds() {
  return args.drop_front(2 + 2 * numLoops).take_front(numLoops);
}

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::steps() {
  return args.drop_front(2 + 3 * numLoops).take_front(numLoops);
}

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::captures() {
  return args.drop_front(2 + 4 * numLoops);
}
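// For illustration only (not part of the original source): with the packing
// implied by the accessors above, a parallel compute function for a 2-d
// scf.parallel with a single captured value receives its arguments as
//
//   args[0] = blockIndex       args[1] = blockSize
//   args[2], args[3]           = tripCount0, tripCount1
//   args[4], args[5]           = lowerBound0, lowerBound1
//   args[6], args[7]           = upperBound0, upperBound1
//   args[8], args[9]           = step0, step1
//   args[10]                   = the captured value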
// Converts a range of values into integer attributes: a null attribute is
// stored for every value that is not defined by a constant operation.
template <typename ValueRange>
static SmallVector<IntegerAttr> integerConstants(ValueRange values) {
  SmallVector<IntegerAttr> attrs(values.size());
  for (unsigned i = 0; i < values.size(); ++i)
    matchPattern(values[i], m_Constant(&attrs[i]));
  return attrs;
}
// Converts a one-dimensional iteration index in the [0, tripCount) interval
// into a multi-dimensional iteration coordinate.
static SmallVector<Value> delinearize(ImplicitLocOpBuilder &b, Value index,
                                      ArrayRef<Value> tripCounts) {
  SmallVector<Value> coords(tripCounts.size());
  assert(!tripCounts.empty() && "tripCounts must be not empty");

  for (ssize_t i = tripCounts.size() - 1; i >= 0; --i) {
    coords[i] = b.create<arith::RemSIOp>(index, tripCounts[i]);
    index = b.create<arith::DivSIOp>(index, tripCounts[i]);
  }

  return coords;
}
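// Worked example (added for clarity, not in the original source): with
// tripCounts = [4, 8] and a linear index of 13 the loop above produces
//   i = 1: coords[1] = 13 % 8 = 5, index = 13 / 8 = 1
//   i = 0: coords[0] =  1 % 4 = 1, index =  1 / 4 = 0
// i.e. coords = [1, 5], and indeed 1 * 8 + 5 == 13.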
// Returns the parallel compute function type together with the values that
// must be captured from above the scf.parallel operation.
static ParallelComputeFunctionType
getParallelComputeFunctionType(scf::ParallelOp op, PatternRewriter &rewriter) {
  // Values implicitly captured by the parallel operation.
  llvm::SetVector<Value> captures;
  getUsedValuesDefinedAbove(op.getRegion(), op.getRegion(), captures);

  SmallVector<Type> inputs;
  inputs.reserve(2 + 4 * op.getNumLoops() + captures.size());

  Type indexTy = rewriter.getIndexType();

  // One-dimensional iteration space defined by the block index and size.
  inputs.push_back(indexTy); // blockIndex
  inputs.push_back(indexTy); // blockSize

  // Multi-dimensional iteration space defined by the loop trip counts.
  for (unsigned i = 0; i < op.getNumLoops(); ++i)
    inputs.push_back(indexTy); // loop tripCount

  // Parallel operation lower bound, upper bound and step for each loop.
  for (unsigned i = 0; i < op.getNumLoops(); ++i) {
    inputs.push_back(indexTy); // lower bound
    inputs.push_back(indexTy); // upper bound
    inputs.push_back(indexTy); // step
  }

  // Types of the implicitly captured values.
  for (Value capture : captures)
    inputs.push_back(capture.getType());

  return {rewriter.getFunctionType(inputs, TypeRange()),
          SmallVector<Value>(captures.begin(), captures.end())};
}
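// Example (an illustrative sketch, not emitted by the code above): for a 2-d
// scf.parallel that implicitly captures one memref<?xf32> value the compute
// function type has 2 + 4 * 2 = 10 leading index arguments followed by the
// capture types:
//   (index, index,               // blockIndex, blockSize
//    index, index,               // trip counts
//    index, index, index,        // loop 0 lower/upper bound and step
//    index, index, index,        // loop 1 lower/upper bound and step
//    memref<?xf32>) -> ()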
// Create a parallel compute function from the parallel operation.
static ParallelComputeFunction createParallelComputeFunction(
    scf::ParallelOp op, const ParallelComputeFunctionBounds &bounds,
    unsigned numBlockAlignedInnerLoops, PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard guard(rewriter);
  ImplicitLocOpBuilder b(op.getLoc(), rewriter);

  ParallelComputeFunctionType computeFuncType =
      getParallelComputeFunctionType(op, rewriter);

  FunctionType type = computeFuncType.type;
  func::FuncOp func = func::FuncOp::create(
      op.getLoc(),
      numBlockAlignedInnerLoops > 0 ? "parallel_compute_fn_with_aligned_loops"
                                    : "parallel_compute_fn",
      type);
  // The function is inserted into the module symbol table with a unique name
  // (elided in this excerpt).

  // Create the function entry block and parse its arguments.
  Block *block =
      b.createBlock(&func.getBody(), func.begin(), type.getInputs(),
                    SmallVector<Location>(type.getNumInputs(), op.getLoc()));
  b.setInsertionPointToEnd(block);

  ParallelComputeFunctionArgs args = {op.getNumLoops(), func.getArguments()};
  // Block iteration position defined by the block index and size.
  BlockArgument blockIndex = args.blockIndex();
  BlockArgument blockSize = args.blockSize();

  // Constants used below.
  Value c0 = b.create<arith::ConstantIndexOp>(0);
  Value c1 = b.create<arith::ConstantIndexOp>(1);

  // Materialize known constants as constant operations in the function body.
  auto values = [&](ArrayRef<BlockArgument> args, ArrayRef<IntegerAttr> attrs) {
    return llvm::to_vector(
        llvm::map_range(llvm::zip(args, attrs), [&](auto tuple) -> Value {
          if (IntegerAttr attr = std::get<1>(tuple))
            return b.create<arith::ConstantOp>(attr);
          return std::get<0>(tuple);
        }));
  };

  // Multi-dimensional iteration space defined by the loop trip counts.
  auto tripCounts = values(args.tripCounts(), bounds.tripCounts);

  // Parallel operation lower bounds and steps.
  auto lowerBounds = values(args.lowerBounds(), bounds.lowerBounds);
  auto steps = values(args.steps(), bounds.steps);
  // Compute the product of trip counts to get the size of the flattened
  // one-dimensional iteration space.
  Value tripCount = tripCounts[0];
  for (unsigned i = 1; i < tripCounts.size(); ++i)
    tripCount = b.create<arith::MulIOp>(tripCount, tripCounts[i]);

  // One-dimensional iteration bounds: [blockFirstIndex, blockLastIndex]:
  //   blockFirstIndex = blockIndex * blockSize
  //   blockLastIndex  = min(blockFirstIndex + blockSize, tripCount) - 1
  Value blockFirstIndex = b.create<arith::MulIOp>(blockIndex, blockSize);
  Value blockEnd0 = b.create<arith::AddIOp>(blockFirstIndex, blockSize);
  Value blockEnd1 = b.create<arith::MinSIOp>(blockEnd0, tripCount);
  Value blockLastIndex = b.create<arith::SubIOp>(blockEnd1, c1);

  // Convert the one-dimensional indices to multi-dimensional coordinates.
  auto blockFirstCoord = delinearize(b, blockFirstIndex, tripCounts);
  auto blockLastCoord = delinearize(b, blockLastIndex, tripCounts);

  // Compute the loop upper bounds derived from the block last coordinates:
  //   blockEndCoord[i] = blockLastCoord[i] + 1
  SmallVector<Value> blockEndCoord(op.getNumLoops());
  for (size_t i = 0; i < blockLastCoord.size(); ++i)
    blockEndCoord[i] = b.create<arith::AddIOp>(blockLastCoord[i], c1);
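  // Worked example (added for clarity, not in the original source): for
  // tripCounts = [10, 10] (tripCount = 100), blockSize = 13, blockIndex = 3:
  //   blockFirstIndex = 3 * 13 = 39     blockLastIndex = min(52, 100) - 1 = 51
  //   blockFirstCoord = [3, 9]          blockLastCoord = [5, 1]
  //   blockEndCoord   = [6, 2]
  // so the outer loop runs coordinate 0 over [3, 6), and the inner loop bounds
  // are selected per iteration by workLoopBuilder below.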
  // Builders for the nested loops that process the parallel compute block.
  using LoopBodyBuilder =
      std::function<void(OpBuilder &, Location, Value, ValueRange)>;
  using LoopNestBuilder = std::function<LoopBodyBuilder(size_t loopIdx)>;
  // Parallel loop induction variables computed from the block coordinates, and
  // per-loop flags telling whether all outer loops are in their first/last
  // block iteration.
  SmallVector<Value> computeBlockInductionVars(op.getNumLoops());
  SmallVector<Value> isBlockFirstCoord(op.getNumLoops());
  SmallVector<Value> isBlockLastCoord(op.getNumLoops());

  // Values implicitly captured by the parallel operation body.
  auto captures = args.captures();

  // Builds a loop nest of scf.for operations that processes the parallel
  // compute block defined above.
  LoopNestBuilder workLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
    return [&, loopIdx](OpBuilder &nestedBuilder, Location loc, Value iv,
                        ValueRange args) {
      ImplicitLocOpBuilder b(loc, nestedBuilder);

      // Compute the induction variable for `loopIdx`.
      computeBlockInductionVars[loopIdx] = b.create<arith::AddIOp>(
          lowerBounds[loopIdx], b.create<arith::MulIOp>(iv, steps[loopIdx]));

      // Check if we are inside the first or last iteration of the loop.
      isBlockFirstCoord[loopIdx] = b.create<arith::CmpIOp>(
          arith::CmpIPredicate::eq, iv, blockFirstCoord[loopIdx]);
      isBlockLastCoord[loopIdx] = b.create<arith::CmpIOp>(
          arith::CmpIPredicate::eq, iv, blockLastCoord[loopIdx]);

      // Mix in the first/last iteration checks of the outer loops.
      if (loopIdx > 0) {
        isBlockFirstCoord[loopIdx] = b.create<arith::AndIOp>(
            isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]);
        isBlockLastCoord[loopIdx] = b.create<arith::AndIOp>(
            isBlockLastCoord[loopIdx], isBlockLastCoord[loopIdx - 1]);
      }

      // Keep building the loop nest.
      if (loopIdx < op.getNumLoops() - 1) {
        if (loopIdx + 1 >= op.getNumLoops() - numBlockAlignedInnerLoops) {
          // Block-aligned loops always iterate from 0 to the loop trip count.
          b.create<scf::ForOp>(c0, tripCounts[loopIdx + 1], c1, ValueRange(),
                               workLoopBuilder(loopIdx + 1));
        } else {
          // Select the nested loop bounds depending on the position in the
          // multi-dimensional iteration space.
          auto lb = b.create<arith::SelectOp>(isBlockFirstCoord[loopIdx],
                                              blockFirstCoord[loopIdx + 1], c0);
          auto ub = b.create<arith::SelectOp>(isBlockLastCoord[loopIdx],
                                              blockEndCoord[loopIdx + 1],
                                              tripCounts[loopIdx + 1]);
          b.create<scf::ForOp>(lb, ub, c1, ValueRange(),
                               workLoopBuilder(loopIdx + 1));
        }
        b.create<scf::YieldOp>(loc);
        return;
      }

      // Copy the body of the parallel operation into the inner-most loop.
      IRMapping mapping;
      mapping.map(op.getInductionVars(), computeBlockInductionVars);
      mapping.map(computeFuncType.captures, captures);
      for (Operation &bodyOp : op.getRegion().getOps())
        b.clone(bodyOp, mapping);
    };
  };

  b.create<scf::ForOp>(blockFirstCoord[0], blockEndCoord[0], c1, ValueRange(),
                       workLoopBuilder(0));
  b.create<func::ReturnOp>(ValueRange());

  return {op.getNumLoops(), func, std::move(computeFuncType.captures)};
}
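// Rough shape of the generated function (a pseudo-IR sketch under the
// assumptions above, not verbatim pass output), here for a 1-d scf.parallel:
//
//   func.func private @parallel_compute_fn(%block_index: index,
//                                          %block_size: index,
//                                          %tripCount: index, %lb: index,
//                                          %ub: index, %step: index, ...) {
//     %first = arith.muli %block_index, %block_size   // first flat index
//     %end0  = arith.addi %first, %block_size
//     %end   = arith.minsi %end0, %tripCount          // end of this block
//     scf.for %i = %first to %end step 1 {
//       // %iv = %lb + %i * %step, then the cloned scf.parallel body
//     }
//     return
//   }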
// Create a dispatch function that implements a simple fork-join concurrency
// model: recursively split the [blockStart, blockEnd) range in half until it
// is small enough to call the parallel compute function directly.
static func::FuncOp createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
                                                PatternRewriter &rewriter) {
  Location loc = computeFunc.func.getLoc();
  ImplicitLocOpBuilder b(loc, rewriter);

  ModuleOp module = computeFunc.func->getParentOfType<ModuleOp>();

  ArrayRef<Type> computeFuncInputTypes =
      computeFunc.func.getFunctionType().getInputs();

  // Dispatch function arguments: async group, block start, block end, and then
  // all parallel compute function arguments.
  Type indexTy = rewriter.getIndexType();
  SmallVector<Type> inputTypes = {GroupType::get(rewriter.getContext()),
                                  indexTy, indexTy};
  inputTypes.append(computeFuncInputTypes.begin(), computeFuncInputTypes.end());

  FunctionType type = rewriter.getFunctionType(inputTypes, TypeRange());
  func::FuncOp func = func::FuncOp::create(loc, "async_dispatch_fn", type);

  // Create the function entry block (the function itself is inserted into the
  // `module` symbol table; elided in this excerpt).
  Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs(),
                               SmallVector<Location>(type.getNumInputs(), loc));
  b.setInsertionPointToEnd(block);
  // Constants and dispatch function arguments used below.
  Value c1 = b.create<arith::ConstantIndexOp>(1);
  Value c2 = b.create<arith::ConstantIndexOp>(2);
  Value group = block->getArgument(0);      // !async.group tracking dispatches
  Value blockStart = block->getArgument(1);
  Value blockEnd = block->getArgument(2);

  // Create a work-splitting while loop for the [blockStart, blockEnd) range.
  SmallVector<Type> types = {indexTy, indexTy};
  SmallVector<Value> operands = {blockStart, blockEnd};
  SmallVector<Location> locations = {loc, loc};

  scf::WhileOp whileOp = b.create<scf::WhileOp>(types, operands);
  Block *before = b.createBlock(&whileOp.getBefore(), {}, types, locations);
  Block *after = b.createBlock(&whileOp.getAfter(), {}, types, locations);

  // Condition block: keep splitting while the range has more than one block.
  {
    b.setInsertionPointToEnd(before);
    Value start = before->getArgument(0);
    Value end = before->getArgument(1);
    Value distance = b.create<arith::SubIOp>(end, start);
    Value dispatch =
        b.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, distance, c1);
    b.create<scf::ConditionOp>(dispatch, before->getArguments());
  }

  // Body block: asynchronously dispatch the upper half of the range and keep
  // the lower half for the next iteration.
  {
    b.setInsertionPointToEnd(after);
    Value start = after->getArgument(0);
    Value end = after->getArgument(1);
    Value distance = b.create<arith::SubIOp>(end, start);
    Value halfDistance = b.create<arith::DivSIOp>(distance, c2);
    Value midIndex = b.create<arith::AddIOp>(start, halfDistance);

    // Call this dispatch function recursively inside an async.execute region.
    auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
                                  Location executeLoc, ValueRange executeArgs) {
      // Update the original blockStart/blockEnd to the [midIndex, end) range.
      SmallVector<Value> operands{block->getArguments().begin(),
                                  block->getArguments().end()};
      operands[1] = midIndex;
      operands[2] = end;

      executeBuilder.create<func::CallOp>(executeLoc, func.getSymName(),
                                          func.getResultTypes(), operands);
      executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
    };

    auto execute = b.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(),
                                       executeBodyBuilder);
    b.create<AddToGroupOp>(indexTy, execute.getToken(), group);
    b.create<scf::YieldOp>(ValueRange({start, midIndex}));
  }

  // After dispatching the tail of the block range, call the parallel compute
  // function for the first block of the range in the caller thread.
  b.setInsertionPointAfter(whileOp);

  // Drop the async-dispatch-specific arguments: group, block start and end.
  auto forwardedInputs = block->getArguments().drop_front(3);
  SmallVector<Value> computeFuncOperands = {blockStart};
  computeFuncOperands.append(forwardedInputs.begin(), forwardedInputs.end());

  b.create<func::CallOp>(computeFunc.func.getSymName(),
                         computeFunc.func.getResultTypes(),
                         computeFuncOperands);
  b.create<func::ReturnOp>(ValueRange());

  return func;
}
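// Illustration (added, not in the original source) of the recursive work
// splitting performed by @async_dispatch_fn for blockCount = 8:
//
//   [0, 8) -> async dispatch [4, 8), keep [0, 4)
//   [0, 4) -> async dispatch [2, 4), keep [0, 2)
//   [0, 2) -> async dispatch [1, 2), keep [0, 1)
//   [0, 1) -> call @parallel_compute_fn for block 0 in the current thread
//
// Every async dispatch re-enters @async_dispatch_fn for its own sub-range, so
// the dispatch tree has logarithmic depth in the number of blocks.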
// Launch the parallel compute function using async recursive work splitting.
static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
                            ParallelComputeFunction &parallelComputeFunction,
                            scf::ParallelOp op, Value blockSize,
                            Value blockCount,
                            const SmallVector<Value> &tripCounts) {
  MLIRContext *ctx = op->getContext();

  // Add one more level of indirection to dispatch parallel compute functions
  // using async operations and recursive work splitting.
  func::FuncOp asyncDispatchFunction =
      createAsyncDispatchFunction(parallelComputeFunction, rewriter);

  Value c0 = b.create<arith::ConstantIndexOp>(0);
  Value c1 = b.create<arith::ConstantIndexOp>(1);

  // Append operands shared by the async dispatch and parallel compute
  // functions to the given operands vector.
  auto appendBlockComputeOperands = [&](SmallVector<Value> &operands) {
    operands.append(tripCounts);
    operands.append(op.getLowerBound().begin(), op.getLowerBound().end());
    operands.append(op.getUpperBound().begin(), op.getUpperBound().end());
    operands.append(op.getStep().begin(), op.getStep().end());
    operands.append(parallelComputeFunction.captures);
  };

  // If there is a single block, skip the async dispatch completely.
  Value isSingleBlock =
      b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, blockCount, c1);

  auto syncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
    ImplicitLocOpBuilder b(loc, nestedBuilder);

    // Call the parallel compute function for the single block.
    SmallVector<Value> operands = {c0, blockSize};
    appendBlockComputeOperands(operands);

    b.create<func::CallOp>(parallelComputeFunction.func.getSymName(),
                           parallelComputeFunction.func.getResultTypes(),
                           operands);
    b.create<scf::YieldOp>();
  };

  auto asyncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
    ImplicitLocOpBuilder b(loc, nestedBuilder);

    // Create an async.group to wait on all async tokens from the concurrent
    // execution of the parallel compute blocks. The first block is executed
    // synchronously in the caller thread.
    Value groupSize = b.create<arith::SubIOp>(blockCount, c1);
    Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);

    // Launch the async dispatch function for the [0, blockCount) range.
    SmallVector<Value> operands = {group, c0, blockCount, blockSize};
    appendBlockComputeOperands(operands);

    b.create<func::CallOp>(asyncDispatchFunction.getSymName(),
                           asyncDispatchFunction.getResultTypes(), operands);

    // Wait for the completion of all parallel compute operations.
    b.create<AwaitAllOp>(group);
    b.create<scf::YieldOp>();
  };

  // Dispatch either the single-block compute function, or the async dispatch.
  b.create<scf::IfOp>(isSingleBlock, syncDispatch, asyncDispatch);
}
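// The resulting dispatch code is roughly (pseudo-IR sketch, added for
// clarity, assuming more than one block):
//   %group = async.create_group (%blockCount - 1)
//   call @async_dispatch_fn(%group, %c0, %blockCount, %blockSize, ...)
//   async.await_all %group
// guarded by an scf.if that calls @parallel_compute_fn directly when
// blockCount == 1.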
// Dispatch parallel compute functions by submitting all async compute tasks
// from a simple for loop in the caller thread.
static void
doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
                     ParallelComputeFunction &parallelComputeFunction,
                     scf::ParallelOp op, Value blockSize, Value blockCount,
                     const SmallVector<Value> &tripCounts) {
  MLIRContext *ctx = op->getContext();

  func::FuncOp compute = parallelComputeFunction.func;

  Value c0 = b.create<arith::ConstantIndexOp>(0);
  Value c1 = b.create<arith::ConstantIndexOp>(1);

  // Create an async.group to wait on all async tokens from the concurrent
  // execution of the parallel compute blocks. The first block is executed
  // synchronously in the caller thread.
  Value groupSize = b.create<arith::SubIOp>(blockCount, c1);
  Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);

  using LoopBodyBuilder =
      std::function<void(OpBuilder &, Location, Value, ValueRange)>;

  // Returns the parallel compute function operands for the given block index.
  auto computeFuncOperands = [&](Value blockIndex) -> SmallVector<Value> {
    SmallVector<Value> computeFuncOperands = {blockIndex, blockSize};
    computeFuncOperands.append(tripCounts);
    computeFuncOperands.append(op.getLowerBound().begin(),
                               op.getLowerBound().end());
    computeFuncOperands.append(op.getUpperBound().begin(),
                               op.getUpperBound().end());
    computeFuncOperands.append(op.getStep().begin(), op.getStep().end());
    computeFuncOperands.append(parallelComputeFunction.captures);
    return computeFuncOperands;
  };

  // The loop induction variable is the index of the block: [0, blockCount).
  LoopBodyBuilder loopBuilder = [&](OpBuilder &loopBuilder, Location loc,
                                    Value iv, ValueRange args) {
    ImplicitLocOpBuilder b(loc, loopBuilder);

    // Call the parallel compute function inside the async.execute region.
    auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
                                  Location executeLoc, ValueRange executeArgs) {
      executeBuilder.create<func::CallOp>(executeLoc, compute.getSymName(),
                                          compute.getResultTypes(),
                                          computeFuncOperands(iv));
      executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
    };

    // Create an async.execute operation to launch the compute function.
    auto execute = b.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(),
                                       executeBodyBuilder);
    b.create<AddToGroupOp>(rewriter.getIndexType(), execute.getToken(), group);
    b.create<scf::YieldOp>(loc, ValueRange());
  };

  // Launch async dispatch for all blocks except the first one.
  b.create<scf::ForOp>(c1, blockCount, c1, ValueRange(), loopBuilder);

  // Call the parallel compute function for the first block in the caller
  // thread.
  b.create<func::CallOp>(compute.getSymName(), compute.getResultTypes(),
                         computeFuncOperands(c0));

  // Wait for the completion of all async compute operations.
  b.create<AwaitAllOp>(group);
}
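// Pseudo-IR sketch of the sequential dispatch (added for clarity, not
// verbatim pass output):
//   %group = async.create_group (%blockCount - 1)
//   scf.for %bi = %c1 to %blockCount step %c1 {
//     %token = async.execute { call @parallel_compute_fn(%bi, ...) }
//     async.add_to_group %token, %group
//   }
//   call @parallel_compute_fn(%c0, ...)   // first block in the caller thread
//   async.await_all %group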
LogicalResult
AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
                                         PatternRewriter &rewriter) const {
  // Rewriting parallel operations with reductions is not supported.
  if (op.getNumReductions() != 0)
    return failure();

  ImplicitLocOpBuilder b(op.getLoc(), rewriter);

  // Emit the IR computing the minimum number of iterations that is viable for
  // a single parallel task, under the current insertion point.
  Value minTaskSize = computeMinTaskSize(b, op);

  // Make sure that all constants are inside the parallel operation body to
  // reduce the number of parallel compute function arguments.
  cloneConstantsIntoTheRegion(op.getRegion(), rewriter);

  // Compute the trip count for each loop induction variable:
  //   tripCount = ceil_div(upperBound - lowerBound, step)
  SmallVector<Value> tripCounts(op.getNumLoops());
  for (size_t i = 0; i < op.getNumLoops(); ++i) {
    auto lb = op.getLowerBound()[i];
    auto ub = op.getUpperBound()[i];
    auto step = op.getStep()[i];
    auto range = b.createOrFold<arith::SubIOp>(ub, lb);
    tripCounts[i] = b.createOrFold<arith::CeilDivSIOp>(range, step);
  }

  // Compute the product of trip counts to get the one-dimensional iteration
  // space for the scf.parallel operation.
  Value tripCount = tripCounts[0];
  for (size_t i = 1; i < tripCounts.size(); ++i)
    tripCount = b.create<arith::MulIOp>(tripCount, tripCounts[i]);

  // Short-circuit no-op parallel loops (zero iterations) that can arise from
  // memrefs with a dynamic dimension equal to zero.
  Value c0 = b.create<arith::ConstantIndexOp>(0);
  Value isZeroIterations =
      b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, tripCount, c0);
  // Do absolutely nothing if the trip count is zero.
  auto noOp = [&](OpBuilder &nestedBuilder, Location loc) {
    nestedBuilder.create<scf::YieldOp>(loc);
  };

  // Compute the parallel block size and dispatch concurrent tasks computing
  // the results for each block.
  auto dispatch = [&](OpBuilder &nestedBuilder, Location loc) {
    ImplicitLocOpBuilder b(loc, nestedBuilder);

    // Statically known constants defining the loop nest, so that they do not
    // have to be passed as parallel compute function arguments.
    ParallelComputeFunctionBounds staticBounds = {
        integerConstants(tripCounts), integerConstants(op.getLowerBound()),
        integerConstants(op.getUpperBound()), integerConstants(op.getStep())};

    // Align the compute block size by the product of statically known inner
    // trip counts (up to 512 iterations) to give LLVM a chance to unroll the
    // inner loops.
    static constexpr int64_t maxUnrollableIterations = 512;
    int numUnrollableLoops = 0;

    auto getInt = [](IntegerAttr attr) { return attr ? attr.getInt() : 0; };

    SmallVector<int64_t> numIterations(op.getNumLoops());
    numIterations.back() = getInt(staticBounds.tripCounts.back());

    for (int i = op.getNumLoops() - 2; i >= 0; --i) {
      int64_t tripCount = getInt(staticBounds.tripCounts[i]);
      int64_t innerIterations = numIterations[i + 1];
      numIterations[i] = tripCount * innerIterations;

      if (innerIterations > 0 && innerIterations <= maxUnrollableIterations)
        numUnrollableLoops++;
    }
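    // Worked example (added for clarity): for a 3-d loop nest with static trip
    // counts [dynamic, 16, 8] the loop above computes numIterations =
    // [0, 128, 8]; both inner dimensions have a known product <= 512, so
    // numUnrollableLoops == 2 and the aligned block size computed below is
    // rounded up to a multiple of 128.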
    Value numWorkerThreadsVal;
    if (numWorkerThreads >= 0)
      numWorkerThreadsVal = b.create<arith::ConstantIndexOp>(numWorkerThreads);
    else
      numWorkerThreadsVal = b.create<async::RuntimeNumWorkerThreadsOp>();

    // With a large number of threads the value of creating many compute blocks
    // is reduced because the problem typically becomes memory bound. Scale the
    // number of workers by an oversharding factor that shrinks as the worker
    // count grows. Each pair is the exclusive lower end of a bracket and the
    // factor applied when the worker count falls into that bracket.
    const SmallVector<std::pair<int, float>> overshardingBrackets = {
        {4, 4.0f}, {8, 2.0f}, {16, 1.0f}, {32, 0.8f}, {64, 0.6f}};
    const float initialOvershardingFactor = 8.0f;

    Value scalingFactor = b.create<arith::ConstantFloatOp>(
        llvm::APFloat(initialOvershardingFactor), b.getF32Type());
    for (const std::pair<int, float> &p : overshardingBrackets) {
      Value bracketBegin = b.create<arith::ConstantIndexOp>(p.first);
      Value inBracket = b.create<arith::CmpIOp>(
          arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin);
      Value bracketScalingFactor = b.create<arith::ConstantFloatOp>(
          llvm::APFloat(p.second), b.getF32Type());
      scalingFactor = b.create<arith::SelectOp>(inBracket, bracketScalingFactor,
                                                scalingFactor);
    }

    Value numWorkersIndex =
        b.create<arith::IndexCastOp>(b.getI32Type(), numWorkerThreadsVal);
    Value numWorkersFloat =
        b.create<arith::SIToFPOp>(b.getF32Type(), numWorkersIndex);
    Value scaledNumWorkers =
        b.create<arith::MulFOp>(scalingFactor, numWorkersFloat);
    Value scaledNumInt =
        b.create<arith::FPToSIOp>(b.getI32Type(), scaledNumWorkers);
    Value scaledWorkers =
        b.create<arith::IndexCastOp>(b.getIndexType(), scaledNumInt);

    Value maxComputeBlocks = b.create<arith::MaxSIOp>(
        b.create<arith::ConstantIndexOp>(1), scaledWorkers);

    // Compute the parallel block size from the parallel problem size:
    //   blockSize = min(tripCount,
    //                   max(ceil_div(tripCount, maxComputeBlocks),
    //                       minTaskSize))
    Value bs0 = b.create<arith::CeilDivSIOp>(tripCount, maxComputeBlocks);
    Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSize);
    Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1);
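    // Worked example (added for clarity): with numWorkerThreads = 16 the
    // bracket loop above leaves scalingFactor = 2.0, so maxComputeBlocks =
    // max(1, 2.0 * 16) = 32. For tripCount = 1000000 and minTaskSize = 1000:
    //   blockSize = min(1000000, max(ceil_div(1000000, 32), 1000)) = 31250.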
    Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize);

    // Dispatch the parallel compute function using async recursive work
    // splitting, or by submitting compute tasks sequentially from the caller
    // thread.
    auto doDispatch = asyncDispatch ? doAsyncDispatch : doSequentialDispatch;

    // Dispatch the compute function without hints to unroll inner loops.
    auto dispatchDefault = [&](OpBuilder &nestedBuilder, Location loc) {
      ParallelComputeFunction compute =
          createParallelComputeFunction(op, staticBounds, 0, rewriter);

      ImplicitLocOpBuilder b(loc, nestedBuilder);
      doDispatch(b, rewriter, compute, op, blockSize, blockCount, tripCounts);
      b.create<scf::YieldOp>();
    };

    // Dispatch the compute function with hints for unrollable inner loops:
    // the block size is aligned to a multiple of the statically known number
    // of iterations in those loops.
    auto dispatchBlockAligned = [&](OpBuilder &nestedBuilder, Location loc) {
      ParallelComputeFunction compute = createParallelComputeFunction(
          op, staticBounds, numUnrollableLoops, rewriter);

      ImplicitLocOpBuilder b(loc, nestedBuilder);
      Value numIters = b.create<arith::ConstantIndexOp>(
          numIterations[op.getNumLoops() - numUnrollableLoops]);
      Value alignedBlockSize = b.create<arith::MulIOp>(
          b.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
      doDispatch(b, rewriter, compute, op, alignedBlockSize, blockCount,
                 tripCounts);
      b.create<scf::YieldOp>();
    };

    // Use the block-aligned compute function only if the computed block size
    // is larger than the number of iterations in the unrollable inner loops.
    if (numUnrollableLoops > 0) {
      Value numIters = b.create<arith::ConstantIndexOp>(
          numIterations[op.getNumLoops() - numUnrollableLoops]);
      Value useBlockAlignedComputeFn = b.create<arith::CmpIOp>(
          arith::CmpIPredicate::sge, blockSize, numIters);

      b.create<scf::IfOp>(useBlockAlignedComputeFn, dispatchBlockAligned,
                          dispatchDefault);
      b.create<scf::YieldOp>();
    } else {
      dispatchDefault(b, loc);
    }
  };

  // Replace the scf.parallel operation: skip the dispatch entirely when the
  // trip count is zero.
  b.create<scf::IfOp>(isZeroIterations, noOp, dispatch);

  // The parallel operation was replaced with a block iteration loop.
  rewriter.eraseOp(op);

  return success();
}
void AsyncParallelForPass::runOnOperation() {
  MLIRContext *ctx = &getContext();

  RewritePatternSet patterns(ctx);
  populateAsyncParallelForPatterns(
      patterns, asyncDispatch, numWorkerThreads,
      [&](ImplicitLocOpBuilder builder, scf::ParallelOp op) {
        return builder.create<arith::ConstantIndexOp>(minTaskSize);
      });
  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}

std::unique_ptr<Pass> mlir::createAsyncParallelForPass() {
  return std::make_unique<AsyncParallelForPass>();
}
std::unique_ptr<Pass>
mlir::createAsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads,
                                 int32_t minTaskSize) {
  return std::make_unique<AsyncParallelForPass>(asyncDispatch, numWorkerThreads,
                                                minTaskSize);
}

// Add a pattern to the given pattern list to lower scf.parallel to async
// operations.
void mlir::async::populateAsyncParallelForPatterns(
    RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads,
    const AsyncMinTaskSizeComputationFunction &computeMinTaskSize) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
                                        computeMinTaskSize);
}
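// Hypothetical usage sketch (not part of this file): another pass can reuse
// the pattern with its own cost model by supplying a min-task-size callback;
// a negative numWorkerThreads falls back to the runtime worker count.
//
//   RewritePatternSet patterns(ctx);
//   async::populateAsyncParallelForPatterns(
//       patterns, /*asyncDispatch=*/true, /*numWorkerThreads=*/-1,
//       [](ImplicitLocOpBuilder b, scf::ParallelOp op) -> Value {
//         // e.g. derive the minimum viable task size from the op's body.
//         return b.create<arith::ConstantIndexOp>(512);
//       });
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));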