116 LogicalResult matchAndRewrite(scf::ParallelOp op,
125struct ParallelComputeFunctionType {
131struct ParallelComputeFunctionArgs {
143struct ParallelComputeFunctionBounds {
150struct ParallelComputeFunction {
158BlockArgument ParallelComputeFunctionArgs::blockIndex() {
return args[0]; }
159BlockArgument ParallelComputeFunctionArgs::blockSize() {
return args[1]; }
161ArrayRef<BlockArgument> ParallelComputeFunctionArgs::tripCounts() {
162 return args.drop_front(2).take_front(numLoops);
166 return args.drop_front(2 + 1 * numLoops).take_front(numLoops);
170 return args.drop_front(2 + 3 * numLoops).take_front(numLoops);
174 return args.drop_front(2 + 4 * numLoops);
177template <
typename ValueRange>
180 for (
unsigned i = 0; i < values.size(); ++i)
190 assert(!tripCounts.empty() &&
"tripCounts must be not empty");
192 for (ssize_t i = tripCounts.size() - 1; i >= 0; --i) {
193 coords[i] = arith::RemSIOp::create(
b,
index, tripCounts[i]);
194 index = arith::DivSIOp::create(
b,
index, tripCounts[i]);
203static ParallelComputeFunctionType
210 inputs.reserve(2 + 4 * op.getNumLoops() + captures.size());
215 inputs.push_back(indexTy);
216 inputs.push_back(indexTy);
219 for (
unsigned i = 0; i < op.getNumLoops(); ++i)
220 inputs.push_back(indexTy);
225 for (
unsigned i = 0; i < op.getNumLoops(); ++i) {
226 inputs.push_back(indexTy);
227 inputs.push_back(indexTy);
228 inputs.push_back(indexTy);
232 for (
Value capture : captures)
233 inputs.push_back(capture.getType());
242 scf::ParallelOp op,
const ParallelComputeFunctionBounds &bounds,
247 ModuleOp module = op->getParentOfType<ModuleOp>();
249 ParallelComputeFunctionType computeFuncType =
252 FunctionType type = computeFuncType.type;
253 func::FuncOp
func = func::FuncOp::create(
255 numBlockAlignedInnerLoops > 0 ?
"parallel_compute_fn_with_aligned_loops"
256 :
"parallel_compute_fn",
267 b.createBlock(&
func.getBody(),
func.begin(), type.getInputs(),
269 b.setInsertionPointToEnd(block);
271 ParallelComputeFunctionArgs args = {op.getNumLoops(),
func.getArguments()};
283 return llvm::to_vector(
284 llvm::map_range(llvm::zip(args, attrs), [&](
auto tuple) ->
Value {
285 if (IntegerAttr attr = std::get<1>(tuple))
286 return arith::ConstantOp::create(
b, attr);
287 return std::get<0>(tuple);
292 auto tripCounts = values(args.tripCounts(), bounds.tripCounts);
295 auto lowerBounds = values(args.lowerBounds(), bounds.lowerBounds);
296 auto steps = values(args.steps(), bounds.steps);
303 Value tripCount = tripCounts[0];
304 for (
unsigned i = 1; i < tripCounts.size(); ++i)
305 tripCount = arith::MulIOp::create(
b, tripCount, tripCounts[i]);
309 Value blockFirstIndex = arith::MulIOp::create(
b, blockIndex, blockSize);
313 Value blockEnd0 = arith::AddIOp::create(
b, blockFirstIndex, blockSize);
314 Value blockEnd1 = arith::MinSIOp::create(
b, blockEnd0, tripCount);
315 Value blockLastIndex = arith::SubIOp::create(
b, blockEnd1, c1);
318 auto blockFirstCoord =
delinearize(
b, blockFirstIndex, tripCounts);
319 auto blockLastCoord =
delinearize(
b, blockLastIndex, tripCounts);
327 for (
size_t i = 0; i < blockLastCoord.size(); ++i)
328 blockEndCoord[i] = arith::AddIOp::create(
b, blockLastCoord[i], c1);
332 using LoopBodyBuilder =
334 using LoopNestBuilder = std::function<LoopBodyBuilder(
size_t loopIdx)>;
365 LoopNestBuilder workLoopBuilder = [&](
size_t loopIdx) -> LoopBodyBuilder {
371 computeBlockInductionVars[loopIdx] =
372 arith::AddIOp::create(
b, lowerBounds[loopIdx],
373 arith::MulIOp::create(
b, iv, steps[loopIdx]));
376 isBlockFirstCoord[loopIdx] = arith::CmpIOp::create(
377 b, arith::CmpIPredicate::eq, iv, blockFirstCoord[loopIdx]);
378 isBlockLastCoord[loopIdx] = arith::CmpIOp::create(
379 b, arith::CmpIPredicate::eq, iv, blockLastCoord[loopIdx]);
383 isBlockFirstCoord[loopIdx] = arith::AndIOp::create(
384 b, isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]);
385 isBlockLastCoord[loopIdx] = arith::AndIOp::create(
386 b, isBlockLastCoord[loopIdx], isBlockLastCoord[loopIdx - 1]);
390 if (loopIdx < op.getNumLoops() - 1) {
391 if (loopIdx + 1 >= op.getNumLoops() - numBlockAlignedInnerLoops) {
394 scf::ForOp::create(
b, c0, tripCounts[loopIdx + 1], c1,
ValueRange(),
395 workLoopBuilder(loopIdx + 1));
400 auto lb = arith::SelectOp::create(
b, isBlockFirstCoord[loopIdx],
401 blockFirstCoord[loopIdx + 1], c0);
403 auto ub = arith::SelectOp::create(
b, isBlockLastCoord[loopIdx],
404 blockEndCoord[loopIdx + 1],
405 tripCounts[loopIdx + 1]);
408 workLoopBuilder(loopIdx + 1));
411 scf::YieldOp::create(
b, loc);
417 mapping.
map(op.getInductionVars(), computeBlockInductionVars);
418 mapping.
map(computeFuncType.captures, captures);
420 for (
auto &bodyOp : op.getRegion().front().without_terminator())
421 b.clone(bodyOp, mapping);
422 scf::YieldOp::create(
b, loc);
426 scf::ForOp::create(
b, blockFirstCoord[0], blockEndCoord[0], c1,
ValueRange(),
430 return {op.getNumLoops(),
func, std::move(computeFuncType.captures)};
456 Location loc = computeFunc.func.getLoc();
459 ModuleOp module = computeFunc.func->getParentOfType<ModuleOp>();
462 computeFunc.func.getFunctionType().getInputs();
469 inputTypes.push_back(async::GroupType::get(rewriter.
getContext()));
471 inputTypes.append(computeFuncInputTypes.begin(), computeFuncInputTypes.end());
474 func::FuncOp
func = func::FuncOp::create(loc,
"async_dispatch_fn", type);
483 Block *block =
b.createBlock(&
func.getBody(),
func.begin(), type.getInputs(),
485 b.setInsertionPointToEnd(block);
487 Type indexTy =
b.getIndexType();
504 scf::WhileOp whileOp = scf::WhileOp::create(
b, types, operands);
505 Block *before =
b.createBlock(&whileOp.getBefore(), {}, types, locations);
506 Block *after =
b.createBlock(&whileOp.getAfter(), {}, types, locations);
511 b.setInsertionPointToEnd(before);
514 Value distance = arith::SubIOp::create(
b, end, start);
516 arith::CmpIOp::create(
b, arith::CmpIPredicate::sgt, distance, c1);
517 scf::ConditionOp::create(
b, dispatch, before->
getArguments());
523 b.setInsertionPointToEnd(after);
526 Value distance = arith::SubIOp::create(
b, end, start);
527 Value halfDistance = arith::DivSIOp::create(
b, distance, c2);
528 Value midIndex = arith::AddIOp::create(
b, start, halfDistance);
531 auto executeBodyBuilder = [&](
OpBuilder &executeBuilder,
536 operands[1] = midIndex;
539 func::CallOp::create(executeBuilder, executeLoc,
func.getSymName(),
540 func.getResultTypes(), operands);
541 async::YieldOp::create(executeBuilder, executeLoc,
ValueRange());
547 AddToGroupOp::create(
b, indexTy, execute.getToken(), group);
548 scf::YieldOp::create(
b,
ValueRange({start, midIndex}));
553 b.setInsertionPointAfter(whileOp);
556 auto forwardedInputs = block->
getArguments().drop_front(3);
558 computeFuncOperands.append(forwardedInputs.begin(), forwardedInputs.end());
560 func::CallOp::create(
b, computeFunc.func.getSymName(),
561 computeFunc.func.getResultTypes(), computeFuncOperands);
569 ParallelComputeFunction ¶llelComputeFunction,
570 scf::ParallelOp op,
Value blockSize,
577 func::FuncOp asyncDispatchFunction =
586 operands.append(tripCounts);
587 operands.append(op.getLowerBound().begin(), op.getLowerBound().end());
588 operands.append(op.getUpperBound().begin(), op.getUpperBound().end());
589 operands.append(op.getStep().begin(), op.getStep().end());
590 operands.append(parallelComputeFunction.captures);
596 Value isSingleBlock =
597 arith::CmpIOp::create(
b, arith::CmpIPredicate::eq, blockCount, c1);
604 appendBlockComputeOperands(operands);
606 func::CallOp::create(
b, parallelComputeFunction.func.getSymName(),
607 parallelComputeFunction.func.getResultTypes(),
609 scf::YieldOp::create(
b);
618 Value groupSize = arith::SubIOp::create(
b, blockCount, c1);
619 Value group = CreateGroupOp::create(
b, GroupType::get(ctx), groupSize);
623 appendBlockComputeOperands(operands);
625 func::CallOp::create(
b, asyncDispatchFunction.getSymName(),
626 asyncDispatchFunction.getResultTypes(), operands);
629 AwaitAllOp::create(
b, group);
631 scf::YieldOp::create(
b);
635 scf::IfOp::create(
b, isSingleBlock, syncDispatch, asyncDispatch);
642 ParallelComputeFunction ¶llelComputeFunction,
643 scf::ParallelOp op,
Value blockSize,
Value blockCount,
647 func::FuncOp compute = parallelComputeFunction.func;
655 Value groupSize = arith::SubIOp::create(
b, blockCount, c1);
656 Value group = CreateGroupOp::create(
b, GroupType::get(ctx), groupSize);
659 using LoopBodyBuilder =
665 computeFuncOperands.append(tripCounts);
666 computeFuncOperands.append(op.getLowerBound().begin(),
667 op.getLowerBound().end());
668 computeFuncOperands.append(op.getUpperBound().begin(),
669 op.getUpperBound().end());
670 computeFuncOperands.append(op.getStep().begin(), op.getStep().end());
671 computeFuncOperands.append(parallelComputeFunction.captures);
672 return computeFuncOperands;
681 auto executeBodyBuilder = [&](
OpBuilder &executeBuilder,
683 func::CallOp::create(executeBuilder, executeLoc, compute.getSymName(),
684 compute.getResultTypes(), computeFuncOperands(iv));
685 async::YieldOp::create(executeBuilder, executeLoc,
ValueRange());
691 AddToGroupOp::create(
b, rewriter.
getIndexType(), execute.getToken(), group);
692 scf::YieldOp::create(
b);
696 scf::ForOp::create(
b, c1, blockCount, c1,
ValueRange(), loopBuilder);
699 func::CallOp::create(
b, compute.getSymName(), compute.getResultTypes(),
700 computeFuncOperands(c0));
703 AwaitAllOp::create(
b, group);
707AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
708 PatternRewriter &rewriter)
const {
710 if (op.getNumReductions() != 0)
713 ImplicitLocOpBuilder
b(op.getLoc(), rewriter);
718 Value minTaskSize = computeMinTaskSize(
b, op);
726 SmallVector<Value> tripCounts(op.getNumLoops());
727 for (
size_t i = 0; i < op.getNumLoops(); ++i) {
728 auto lb = op.getLowerBound()[i];
729 auto ub = op.getUpperBound()[i];
730 auto step = op.getStep()[i];
731 auto range =
b.createOrFold<arith::SubIOp>(ub, lb);
732 tripCounts[i] =
b.createOrFold<arith::CeilDivSIOp>(range, step);
737 Value tripCount = tripCounts[0];
738 for (
size_t i = 1; i < tripCounts.size(); ++i)
739 tripCount = arith::MulIOp::create(
b, tripCount, tripCounts[i]);
744 Value isZeroIterations =
745 arith::CmpIOp::create(
b, arith::CmpIPredicate::eq, tripCount, c0);
748 auto noOp = [&](OpBuilder &nestedBuilder, Location loc) {
749 scf::YieldOp::create(nestedBuilder, loc);
754 auto dispatch = [&](OpBuilder &nestedBuilder, Location loc) {
755 ImplicitLocOpBuilder
b(loc, nestedBuilder);
762 ParallelComputeFunctionBounds staticBounds = {
776 static constexpr int64_t maxUnrollableIterations = 512;
780 int numUnrollableLoops = 0;
782 auto getInt = [](IntegerAttr attr) {
return attr ? attr.getInt() : 0; };
784 SmallVector<int64_t> numIterations(op.getNumLoops());
785 numIterations.back() = getInt(staticBounds.tripCounts.back());
787 for (
int i = op.getNumLoops() - 2; i >= 0; --i) {
788 int64_t tripCount = getInt(staticBounds.tripCounts[i]);
789 int64_t innerIterations = numIterations[i + 1];
790 numIterations[i] = tripCount * innerIterations;
793 if (innerIterations > 0 && innerIterations <= maxUnrollableIterations)
794 numUnrollableLoops++;
797 Value numWorkerThreadsVal;
798 if (numWorkerThreads >= 0)
801 numWorkerThreadsVal = async::RuntimeNumWorkerThreadsOp::create(
b);
816 const SmallVector<std::pair<int, float>> overshardingBrackets = {
817 {4, 4.0f}, {8, 2.0f}, {16, 1.0f}, {32, 0.8f}, {64, 0.6f}};
818 const float initialOvershardingFactor = 8.0f;
821 b,
b.getF32Type(), llvm::APFloat(initialOvershardingFactor));
822 for (
const std::pair<int, float> &p : overshardingBrackets) {
824 Value inBracket = arith::CmpIOp::create(
825 b, arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin);
827 b,
b.getF32Type(), llvm::APFloat(p.second));
828 scalingFactor = arith::SelectOp::create(
829 b, inBracket, bracketScalingFactor, scalingFactor);
831 Value numWorkersIndex =
832 arith::IndexCastOp::create(
b,
b.getI32Type(), numWorkerThreadsVal);
833 Value numWorkersFloat =
834 arith::SIToFPOp::create(
b,
b.getF32Type(), numWorkersIndex);
835 Value scaledNumWorkers =
836 arith::MulFOp::create(
b, scalingFactor, numWorkersFloat);
838 arith::FPToSIOp::create(
b,
b.getI32Type(), scaledNumWorkers);
839 Value scaledWorkers =
840 arith::IndexCastOp::create(
b,
b.getIndexType(), scaledNumInt);
842 Value maxComputeBlocks = arith::MaxSIOp::create(
849 Value bs0 = arith::CeilDivSIOp::create(
b, tripCount, maxComputeBlocks);
850 Value bs1 = arith::MaxSIOp::create(
b, bs0, minTaskSize);
851 Value blockSize = arith::MinSIOp::create(
b, tripCount, bs1);
861 Value blockCount = arith::CeilDivSIOp::create(
b, tripCount, blockSize);
864 auto dispatchDefault = [&](OpBuilder &nestedBuilder, Location loc) {
865 ParallelComputeFunction compute =
868 ImplicitLocOpBuilder
b(loc, nestedBuilder);
869 doDispatch(
b, rewriter, compute, op, blockSize, blockCount, tripCounts);
870 scf::YieldOp::create(
b);
874 auto dispatchBlockAligned = [&](OpBuilder &nestedBuilder, Location loc) {
876 op, staticBounds, numUnrollableLoops, rewriter);
878 ImplicitLocOpBuilder
b(loc, nestedBuilder);
882 b, numIterations[op.getNumLoops() - numUnrollableLoops]);
883 Value alignedBlockSize = arith::MulIOp::create(
884 b, arith::CeilDivSIOp::create(
b, blockSize, numIters), numIters);
885 doDispatch(
b, rewriter, compute, op, alignedBlockSize, blockCount,
887 scf::YieldOp::create(
b);
893 if (numUnrollableLoops > 0) {
895 b, numIterations[op.getNumLoops() - numUnrollableLoops]);
896 Value useBlockAlignedComputeFn = arith::CmpIOp::create(
897 b, arith::CmpIPredicate::sge, blockSize, numIters);
899 scf::IfOp::create(
b, useBlockAlignedComputeFn, dispatchBlockAligned,
901 scf::YieldOp::create(
b);
903 dispatchDefault(
b, loc);
908 scf::IfOp::create(
b, isZeroIterations, noOp, dispatch);
916void AsyncParallelForPass::runOnOperation() {
921 patterns, asyncDispatch, numWorkerThreads,
922 [&](ImplicitLocOpBuilder builder, scf::ParallelOp op) {
933 patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,