31 #define GEN_PASS_DEF_ASYNCPARALLELFOR
32 #include "mlir/Dialect/Async/Passes.h.inc"
38 #define DEBUG_TYPE "async-parallel-for"
101 struct AsyncParallelForPass
102 :
public impl::AsyncParallelForBase<AsyncParallelForPass> {
103 AsyncParallelForPass() =
default;
105 AsyncParallelForPass(
bool asyncDispatch, int32_t numWorkerThreads,
106 int32_t minTaskSize) {
107 this->asyncDispatch = asyncDispatch;
108 this->numWorkerThreads = numWorkerThreads;
109 this->minTaskSize = minTaskSize;
112 void runOnOperation()
override;
117 AsyncParallelForRewrite(
118 MLIRContext *ctx,
bool asyncDispatch, int32_t numWorkerThreads,
121 numWorkerThreads(numWorkerThreads),
122 computeMinTaskSize(std::move(computeMinTaskSize)) {}
129 int32_t numWorkerThreads;
133 struct ParallelComputeFunctionType {
139 struct ParallelComputeFunctionArgs {
152 struct ParallelComputeFunctionBounds {
159 struct ParallelComputeFunction {
167 BlockArgument ParallelComputeFunctionArgs::blockIndex() {
return args[0]; }
168 BlockArgument ParallelComputeFunctionArgs::blockSize() {
return args[1]; }
171 return args.drop_front(2).take_front(numLoops);
175 return args.drop_front(2 + 1 * numLoops).take_front(numLoops);
179 return args.drop_front(2 + 2 * numLoops).take_front(numLoops);
183 return args.drop_front(2 + 3 * numLoops).take_front(numLoops);
187 return args.drop_front(2 + 4 * numLoops);
190 template <
typename ValueRange>
193 for (
unsigned i = 0; i < values.size(); ++i)
203 assert(!tripCounts.empty() &&
"tripCounts must be not empty");
205 for (ssize_t i = tripCounts.size() - 1; i >= 0; --i) {
206 coords[i] = b.
create<arith::RemSIOp>(index, tripCounts[i]);
207 index = b.
create<arith::DivSIOp>(index, tripCounts[i]);
216 static ParallelComputeFunctionType
223 inputs.reserve(2 + 4 * op.getNumLoops() + captures.size());
228 inputs.push_back(indexTy);
229 inputs.push_back(indexTy);
232 for (
unsigned i = 0; i < op.getNumLoops(); ++i)
233 inputs.push_back(indexTy);
238 for (
unsigned i = 0; i < op.getNumLoops(); ++i) {
239 inputs.push_back(indexTy);
240 inputs.push_back(indexTy);
241 inputs.push_back(indexTy);
245 for (
Value capture : captures)
246 inputs.push_back(capture.getType());
255 scf::ParallelOp op,
const ParallelComputeFunctionBounds &bounds,
262 ParallelComputeFunctionType computeFuncType =
265 FunctionType type = computeFuncType.type;
266 func::FuncOp func = func::FuncOp::create(
268 numBlockAlignedInnerLoops > 0 ?
"parallel_compute_fn_with_aligned_loops"
269 :
"parallel_compute_fn",
280 b.
createBlock(&func.getBody(), func.begin(), type.getInputs(),
284 ParallelComputeFunctionArgs args = {op.getNumLoops(), func.getArguments()};
296 return llvm::to_vector(
297 llvm::map_range(llvm::zip(args, attrs), [&](
auto tuple) ->
Value {
298 if (IntegerAttr attr = std::get<1>(tuple))
299 return b.
create<arith::ConstantOp>(attr);
300 return std::get<0>(tuple);
305 auto tripCounts = values(args.tripCounts(), bounds.tripCounts);
308 auto lowerBounds = values(args.lowerBounds(), bounds.lowerBounds);
309 auto steps = values(args.steps(), bounds.steps);
316 Value tripCount = tripCounts[0];
317 for (
unsigned i = 1; i < tripCounts.size(); ++i)
318 tripCount = b.
create<arith::MulIOp>(tripCount, tripCounts[i]);
322 Value blockFirstIndex = b.
create<arith::MulIOp>(blockIndex, blockSize);
326 Value blockEnd0 = b.
create<arith::AddIOp>(blockFirstIndex, blockSize);
327 Value blockEnd1 = b.
create<arith::MinSIOp>(blockEnd0, tripCount);
328 Value blockLastIndex = b.
create<arith::SubIOp>(blockEnd1, c1);
331 auto blockFirstCoord =
delinearize(b, blockFirstIndex, tripCounts);
332 auto blockLastCoord =
delinearize(b, blockLastIndex, tripCounts);
340 for (
size_t i = 0; i < blockLastCoord.size(); ++i)
341 blockEndCoord[i] = b.
create<arith::AddIOp>(blockLastCoord[i], c1);
345 using LoopBodyBuilder =
347 using LoopNestBuilder = std::function<LoopBodyBuilder(
size_t loopIdx)>;
378 LoopNestBuilder workLoopBuilder = [&](
size_t loopIdx) -> LoopBodyBuilder {
384 computeBlockInductionVars[loopIdx] = b.
create<arith::AddIOp>(
385 lowerBounds[loopIdx], b.
create<arith::MulIOp>(iv, steps[loopIdx]));
388 isBlockFirstCoord[loopIdx] = b.
create<arith::CmpIOp>(
389 arith::CmpIPredicate::eq, iv, blockFirstCoord[loopIdx]);
390 isBlockLastCoord[loopIdx] = b.
create<arith::CmpIOp>(
391 arith::CmpIPredicate::eq, iv, blockLastCoord[loopIdx]);
395 isBlockFirstCoord[loopIdx] = b.
create<arith::AndIOp>(
396 isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]);
397 isBlockLastCoord[loopIdx] = b.
create<arith::AndIOp>(
398 isBlockLastCoord[loopIdx], isBlockLastCoord[loopIdx - 1]);
402 if (loopIdx < op.getNumLoops() - 1) {
403 if (loopIdx + 1 >= op.getNumLoops() - numBlockAlignedInnerLoops) {
407 workLoopBuilder(loopIdx + 1));
412 auto lb = b.
create<arith::SelectOp>(isBlockFirstCoord[loopIdx],
413 blockFirstCoord[loopIdx + 1], c0);
415 auto ub = b.
create<arith::SelectOp>(isBlockLastCoord[loopIdx],
416 blockEndCoord[loopIdx + 1],
417 tripCounts[loopIdx + 1]);
420 workLoopBuilder(loopIdx + 1));
423 b.
create<scf::YieldOp>(loc);
429 mapping.
map(op.getInductionVars(), computeBlockInductionVars);
430 mapping.
map(computeFuncType.captures, captures);
433 b.
clone(bodyOp, mapping);
434 b.
create<scf::YieldOp>(loc);
438 b.
create<scf::ForOp>(blockFirstCoord[0], blockEndCoord[0], c1,
ValueRange(),
442 return {op.getNumLoops(), func, std::move(computeFuncType.captures)};
468 Location loc = computeFunc.func.getLoc();
471 ModuleOp module = computeFunc.func->getParentOfType<ModuleOp>();
474 computeFunc.func.getFunctionType().getInputs();
483 inputTypes.append(computeFuncInputTypes.begin(), computeFuncInputTypes.end());
486 func::FuncOp func = func::FuncOp::create(loc,
"async_dispatch_fn", type);
495 Block *block = b.
createBlock(&func.getBody(), func.begin(), type.getInputs(),
516 scf::WhileOp whileOp = b.
create<scf::WhileOp>(types, operands);
524 Value start = before->getArgument(0);
525 Value end = before->getArgument(1);
526 Value distance = b.
create<arith::SubIOp>(end, start);
528 b.
create<arith::CmpIOp>(arith::CmpIPredicate::sgt, distance, c1);
529 b.
create<scf::ConditionOp>(dispatch, before->getArguments());
538 Value distance = b.
create<arith::SubIOp>(end, start);
539 Value halfDistance = b.
create<arith::DivSIOp>(distance, c2);
540 Value midIndex = b.
create<arith::AddIOp>(start, halfDistance);
543 auto executeBodyBuilder = [&](
OpBuilder &executeBuilder,
548 operands[1] = midIndex;
551 executeBuilder.
create<func::CallOp>(executeLoc, func.getSymName(),
552 func.getResultTypes(), operands);
559 b.
create<AddToGroupOp>(indexTy, execute.getToken(), group);
568 auto forwardedInputs = block->
getArguments().drop_front(3);
570 computeFuncOperands.append(forwardedInputs.begin(), forwardedInputs.end());
572 b.
create<func::CallOp>(computeFunc.func.getSymName(),
573 computeFunc.func.getResultTypes(),
574 computeFuncOperands);
582 ParallelComputeFunction ¶llelComputeFunction,
583 scf::ParallelOp op,
Value blockSize,
590 func::FuncOp asyncDispatchFunction =
599 operands.append(tripCounts);
600 operands.append(op.getLowerBound().begin(), op.getLowerBound().end());
601 operands.append(op.getUpperBound().begin(), op.getUpperBound().end());
602 operands.append(op.getStep().begin(), op.getStep().end());
603 operands.append(parallelComputeFunction.captures);
609 Value isSingleBlock =
610 b.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, blockCount, c1);
617 appendBlockComputeOperands(operands);
619 b.
create<func::CallOp>(parallelComputeFunction.func.getSymName(),
620 parallelComputeFunction.func.getResultTypes(),
631 Value groupSize = b.
create<arith::SubIOp>(blockCount, c1);
636 appendBlockComputeOperands(operands);
638 b.
create<func::CallOp>(asyncDispatchFunction.getSymName(),
639 asyncDispatchFunction.getResultTypes(), operands);
642 b.
create<AwaitAllOp>(group);
648 b.
create<scf::IfOp>(isSingleBlock, syncDispatch, asyncDispatch);
655 ParallelComputeFunction ¶llelComputeFunction,
656 scf::ParallelOp op,
Value blockSize,
Value blockCount,
660 func::FuncOp compute = parallelComputeFunction.func;
668 Value groupSize = b.
create<arith::SubIOp>(blockCount, c1);
672 using LoopBodyBuilder =
678 computeFuncOperands.append(tripCounts);
679 computeFuncOperands.append(op.getLowerBound().begin(),
680 op.getLowerBound().end());
681 computeFuncOperands.append(op.getUpperBound().begin(),
682 op.getUpperBound().end());
683 computeFuncOperands.append(op.getStep().begin(), op.getStep().end());
684 computeFuncOperands.append(parallelComputeFunction.captures);
685 return computeFuncOperands;
694 auto executeBodyBuilder = [&](
OpBuilder &executeBuilder,
696 executeBuilder.
create<func::CallOp>(executeLoc, compute.getSymName(),
697 compute.getResultTypes(),
698 computeFuncOperands(iv));
713 b.
create<func::CallOp>(compute.getSymName(), compute.getResultTypes(),
714 computeFuncOperands(c0));
717 b.
create<AwaitAllOp>(group);
721 AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
724 if (op.getNumReductions() != 0)
732 Value minTaskSize = computeMinTaskSize(b, op);
741 for (
size_t i = 0; i < op.getNumLoops(); ++i) {
742 auto lb = op.getLowerBound()[i];
743 auto ub = op.getUpperBound()[i];
744 auto step = op.getStep()[i];
745 auto range = b.createOrFold<arith::SubIOp>(ub, lb);
746 tripCounts[i] = b.createOrFold<arith::CeilDivSIOp>(range, step);
751 Value tripCount = tripCounts[0];
752 for (
size_t i = 1; i < tripCounts.size(); ++i)
753 tripCount = b.create<arith::MulIOp>(tripCount, tripCounts[i]);
757 Value c0 = b.create<arith::ConstantIndexOp>(0);
758 Value isZeroIterations =
759 b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, tripCount, c0);
763 nestedBuilder.
create<scf::YieldOp>(loc);
776 ParallelComputeFunctionBounds staticBounds = {
790 static constexpr int64_t maxUnrollableIterations = 512;
794 int numUnrollableLoops = 0;
796 auto getInt = [](IntegerAttr attr) {
return attr ? attr.getInt() : 0; };
799 numIterations.back() = getInt(staticBounds.tripCounts.back());
801 for (
int i = op.getNumLoops() - 2; i >= 0; --i) {
802 int64_t tripCount = getInt(staticBounds.tripCounts[i]);
803 int64_t innerIterations = numIterations[i + 1];
804 numIterations[i] = tripCount * innerIterations;
807 if (innerIterations > 0 && innerIterations <= maxUnrollableIterations)
808 numUnrollableLoops++;
811 Value numWorkerThreadsVal;
812 if (numWorkerThreads >= 0)
813 numWorkerThreadsVal = b.create<arith::ConstantIndexOp>(numWorkerThreads);
815 numWorkerThreadsVal = b.create<async::RuntimeNumWorkerThreadsOp>();
831 {4, 4.0f}, {8, 2.0f}, {16, 1.0f}, {32, 0.8f}, {64, 0.6f}};
832 const float initialOvershardingFactor = 8.0f;
834 Value scalingFactor = b.create<arith::ConstantFloatOp>(
835 llvm::APFloat(initialOvershardingFactor), b.getF32Type());
836 for (
const std::pair<int, float> &p : overshardingBrackets) {
837 Value bracketBegin = b.create<arith::ConstantIndexOp>(p.first);
838 Value inBracket = b.create<arith::CmpIOp>(
839 arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin);
840 Value bracketScalingFactor = b.create<arith::ConstantFloatOp>(
841 llvm::APFloat(p.second), b.getF32Type());
842 scalingFactor = b.create<arith::SelectOp>(inBracket, bracketScalingFactor,
845 Value numWorkersIndex =
846 b.create<arith::IndexCastOp>(b.getI32Type(), numWorkerThreadsVal);
847 Value numWorkersFloat =
848 b.create<arith::SIToFPOp>(b.getF32Type(), numWorkersIndex);
849 Value scaledNumWorkers =
850 b.create<arith::MulFOp>(scalingFactor, numWorkersFloat);
852 b.create<arith::FPToSIOp>(b.getI32Type(), scaledNumWorkers);
853 Value scaledWorkers =
854 b.create<arith::IndexCastOp>(b.getIndexType(), scaledNumInt);
856 Value maxComputeBlocks = b.create<arith::MaxSIOp>(
857 b.create<arith::ConstantIndexOp>(1), scaledWorkers);
863 Value bs0 = b.create<arith::CeilDivSIOp>(tripCount, maxComputeBlocks);
864 Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSize);
865 Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1);
875 Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize);
879 ParallelComputeFunction compute =
883 doDispatch(b, rewriter, compute, op, blockSize, blockCount, tripCounts);
884 b.create<scf::YieldOp>();
890 op, staticBounds, numUnrollableLoops, rewriter);
895 Value numIters = b.create<arith::ConstantIndexOp>(
896 numIterations[op.getNumLoops() - numUnrollableLoops]);
897 Value alignedBlockSize = b.create<arith::MulIOp>(
898 b.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
899 doDispatch(b, rewriter, compute, op, alignedBlockSize, blockCount,
901 b.create<scf::YieldOp>();
907 if (numUnrollableLoops > 0) {
908 Value numIters = b.create<arith::ConstantIndexOp>(
909 numIterations[op.getNumLoops() - numUnrollableLoops]);
910 Value useBlockAlignedComputeFn = b.create<arith::CmpIOp>(
911 arith::CmpIPredicate::sge, blockSize, numIters);
913 b.create<scf::IfOp>(useBlockAlignedComputeFn, dispatchBlockAligned,
915 b.create<scf::YieldOp>();
917 dispatchDefault(b, loc);
922 b.create<scf::IfOp>(isZeroIterations, noOp, dispatch);
930 void AsyncParallelForPass::runOnOperation() {
935 patterns, asyncDispatch, numWorkerThreads,
937 return builder.
create<arith::ConstantIndexOp>(minTaskSize);
944 return std::make_unique<AsyncParallelForPass>();
948 int32_t numWorkerThreads,
949 int32_t minTaskSize) {
950 return std::make_unique<AsyncParallelForPass>(asyncDispatch, numWorkerThreads,
958 patterns.
add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
static ParallelComputeFunction createParallelComputeFunction(scf::ParallelOp op, const ParallelComputeFunctionBounds &bounds, unsigned numBlockAlignedInnerLoops, PatternRewriter &rewriter)
static func::FuncOp createAsyncDispatchFunction(ParallelComputeFunction &computeFunc, PatternRewriter &rewriter)
static void doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter, ParallelComputeFunction &parallelComputeFunction, scf::ParallelOp op, Value blockSize, Value blockCount, const SmallVector< Value > &tripCounts)
static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter, ParallelComputeFunction &parallelComputeFunction, scf::ParallelOp op, Value blockSize, Value blockCount, const SmallVector< Value > &tripCounts)
static ParallelComputeFunctionType getParallelComputeFunctionType(scf::ParallelOp op, PatternRewriter &rewriter)
static SmallVector< IntegerAttr > integerConstants(ValueRange values)
static MLIRContext * getContext(OpFoldResult val)
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
std::function< Value(ImplicitLocOpBuilder, scf::ParallelOp)> AsyncMinTaskSizeComputationFunction
Emit the IR to compute the minimum number of iterations of scf.parallel body that would be viable for...
void populateAsyncParallelForPatterns(RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads, const AsyncMinTaskSizeComputationFunction &computeMinTaskSize)
Add a pattern to the given pattern list to lower scf.parallel to async operations.
void cloneConstantsIntoTheRegion(Region &region)
Clone ConstantLike operations that are defined above the given region and have users in the region in...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
std::unique_ptr< Pass > createAsyncParallelForPass()
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...