28#define GEN_PASS_DEF_SCFTOCONTROLFLOWPASS
29#include "mlir/Conversion/Passes.h.inc"
37struct SCFToControlFlowPass
40 void runOnOperation()
override;
105 using OpRewritePattern<ForOp>::OpRewritePattern;
107 LogicalResult matchAndRewrite(ForOp forOp,
108 PatternRewriter &rewriter)
const override;
198 using OpRewritePattern<IfOp>::OpRewritePattern;
200 LogicalResult matchAndRewrite(IfOp ifOp,
201 PatternRewriter &rewriter)
const override;
205 using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
207 LogicalResult matchAndRewrite(ExecuteRegionOp op,
208 PatternRewriter &rewriter)
const override;
212 using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
214 LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
215 PatternRewriter &rewriter)
const override;
281 PatternRewriter &rewriter)
const override;
289 using OpRewritePattern<WhileOp>::OpRewritePattern;
291 LogicalResult matchAndRewrite(WhileOp whileOp,
292 PatternRewriter &rewriter)
const override;
299 LogicalResult matchAndRewrite(IndexSwitchOp op,
300 PatternRewriter &rewriter)
const override;
308 using OpRewritePattern<mlir::scf::ForallOp>::OpRewritePattern;
310 LogicalResult matchAndRewrite(mlir::scf::ForallOp forallOp,
311 PatternRewriter &rewriter)
const override;
322 llvm::copy_if(scfOp->
getAttrs(), std::back_inserter(llvmAttrs),
324 return isa<LLVM::LLVMDialect>(attr.getValue().getDialect());
329LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
331 Location loc = forOp.getLoc();
338 auto *endBlock = rewriter.
splitBlock(initBlock, initPosition);
344 auto *conditionBlock = &forOp.getRegion().
front();
345 auto *firstBodyBlock =
346 rewriter.
splitBlock(conditionBlock, conditionBlock->begin());
347 auto *lastBodyBlock = &forOp.getRegion().
back();
349 auto iv = conditionBlock->getArgument(0);
354 Operation *terminator = lastBodyBlock->getTerminator();
356 auto step = forOp.getStep();
357 auto stepped = arith::AddIOp::create(rewriter, loc, iv, step).getResult();
361 SmallVector<Value, 8> loopCarried;
362 loopCarried.push_back(stepped);
365 cf::BranchOp::create(rewriter, loc, conditionBlock, loopCarried);
372 Value lowerBound = forOp.getLowerBound();
373 Value upperBound = forOp.getUpperBound();
374 if (!lowerBound || !upperBound)
379 SmallVector<Value, 8> destOperands;
380 destOperands.push_back(lowerBound);
381 llvm::append_range(destOperands, forOp.getInitArgs());
382 cf::BranchOp::create(rewriter, loc, conditionBlock, destOperands);
386 arith::CmpIPredicate predicate = forOp.getUnsignedCmp()
387 ? arith::CmpIPredicate::ult
388 : arith::CmpIPredicate::slt;
390 arith::CmpIOp::create(rewriter, loc, predicate, iv, upperBound);
392 cf::CondBranchOp::create(rewriter, loc, comparison, firstBodyBlock,
393 ArrayRef<Value>(), endBlock, ArrayRef<Value>());
397 rewriter.
replaceOp(forOp, conditionBlock->getArguments().drop_front());
401LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
402 PatternRewriter &rewriter)
const {
403 auto loc = ifOp.getLoc();
410 auto *remainingOpsBlock = rewriter.
splitBlock(condBlock, opPosition);
411 Block *continueBlock;
412 if (ifOp.getNumResults() == 0) {
413 continueBlock = remainingOpsBlock;
416 rewriter.
createBlock(remainingOpsBlock, ifOp.getResultTypes(),
417 SmallVector<Location>(ifOp.getNumResults(), loc));
418 cf::BranchOp::create(rewriter, loc, remainingOpsBlock);
423 auto &thenRegion = ifOp.getThenRegion();
424 auto *thenBlock = &thenRegion.
front();
425 Operation *thenTerminator = thenRegion.back().getTerminator();
428 cf::BranchOp::create(rewriter, loc, continueBlock, thenTerminatorOperands);
429 rewriter.
eraseOp(thenTerminator);
435 auto *elseBlock = continueBlock;
436 auto &elseRegion = ifOp.getElseRegion();
437 if (!elseRegion.empty()) {
438 elseBlock = &elseRegion.front();
439 Operation *elseTerminator = elseRegion.back().getTerminator();
442 cf::BranchOp::create(rewriter, loc, continueBlock, elseTerminatorOperands);
443 rewriter.
eraseOp(elseTerminator);
448 cf::CondBranchOp::create(rewriter, loc, ifOp.getCondition(), thenBlock,
449 ArrayRef<Value>(), elseBlock,
458ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
459 PatternRewriter &rewriter)
const {
460 auto loc = op.getLoc();
464 auto *remainingOpsBlock = rewriter.
splitBlock(condBlock, opPosition);
466 auto ®ion = op.getRegion();
468 cf::BranchOp::create(rewriter, loc, ®ion.front());
470 for (
Block &block : region) {
471 if (
auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
474 cf::BranchOp::create(rewriter, loc, remainingOpsBlock,
482 SmallVector<Value> vals;
483 SmallVector<Location> argLocs(op.getNumResults(), op->getLoc());
485 remainingOpsBlock->addArguments(op->getResultTypes(), argLocs))
492ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
493 PatternRewriter &rewriter)
const {
494 Location loc = parallelOp.getLoc();
495 auto reductionOp = dyn_cast<ReduceOp>(parallelOp.getBody()->getTerminator());
505 SmallVector<Value, 4> iterArgs = llvm::to_vector<4>(parallelOp.getInitVals());
506 SmallVector<Value, 4> ivs;
507 ivs.reserve(parallelOp.getNumLoops());
509 SmallVector<Value, 4> loopResults(iterArgs);
510 for (
auto [iv, lower, upper, step] :
511 llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(),
512 parallelOp.getUpperBound(), parallelOp.getStep())) {
513 ForOp forOp = ForOp::create(rewriter, loc, lower, upper, step, iterArgs);
514 ivs.push_back(forOp.getInductionVar());
515 auto iterRange = forOp.getRegionIterArgs();
516 iterArgs.assign(iterRange.begin(), iterRange.end());
521 loopResults.assign(forOp.result_begin(), forOp.result_end());
523 }
else if (!forOp.getResults().empty()) {
527 scf::YieldOp::create(rewriter, loc, forOp.getResults());
534 SmallVector<Value> yieldOperands;
535 yieldOperands.reserve(parallelOp.getNumResults());
536 for (int64_t i = 0, e = parallelOp.getNumResults(); i < e; ++i) {
537 Block &reductionBody = reductionOp.getReductions()[i].front();
538 Value arg = iterArgs[yieldOperands.size()];
539 yieldOperands.push_back(
540 cast<ReduceReturnOp>(reductionBody.
getTerminator()).getResult());
543 {arg, reductionOp.getOperands()[i]});
549 if (newBody->
empty())
550 rewriter.
mergeBlocks(parallelOp.getBody(), newBody, ivs);
557 if (!yieldOperands.empty()) {
559 scf::YieldOp::create(rewriter, loc, yieldOperands);
562 rewriter.
replaceOp(parallelOp, loopResults);
568 PatternRewriter &rewriter)
const {
569 OpBuilder::InsertionGuard guard(rewriter);
570 Location loc = whileOp.getLoc();
574 Block *continuation =
578 Block *after = whileOp.getAfterBody();
579 Block *before = whileOp.getBeforeBody();
585 cf::BranchOp::create(rewriter, loc, before, whileOp.getInits());
592 SmallVector<Value> args = llvm::to_vector(condOp.getArgs());
594 after, condOp.getArgs(),
600 yieldOp.getResults());
611DoWhileLowering::matchAndRewrite(WhileOp whileOp,
612 PatternRewriter &rewriter)
const {
613 Block &afterBlock = *whileOp.getAfterBody();
614 if (!llvm::hasSingleElement(afterBlock))
616 "do-while simplification applicable "
617 "only if 'after' region has no payload");
619 auto yield = dyn_cast<scf::YieldOp>(&afterBlock.
front());
620 if (!yield || yield.getResults() != afterBlock.
getArguments())
622 "do-while simplification applicable "
623 "only to forwarding 'after' regions");
626 OpBuilder::InsertionGuard guard(rewriter);
628 Block *continuation =
632 Block *before = whileOp.getBeforeBody();
637 cf::BranchOp::create(rewriter, whileOp.getLoc(), before, whileOp.getInits());
642 auto latch = cf::CondBranchOp::create(
643 rewriter, condOp.getLoc(), condOp.getCondition(), before,
644 condOp.getArgs(), continuation,
ValueRange());
649 rewriter.
replaceOp(whileOp, condOp.getArgs());
657IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
658 PatternRewriter &rewriter)
const {
665 SmallVector<Value> results;
666 results.reserve(op.getNumResults());
667 for (Type resultType : op.getResultTypes())
668 results.push_back(continueBlock->
addArgument(resultType, op.getLoc()));
671 auto convertRegion = [&](Region ®ion) -> FailureOr<Block *> {
672 Block *block = ®ion.front();
678 yield.getOperands());
686 SmallVector<Block *> caseSuccessors;
687 SmallVector<int32_t> caseValues;
688 caseSuccessors.reserve(op.getCases().size());
689 caseValues.reserve(op.getCases().size());
690 for (
auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) {
691 FailureOr<Block *> block = convertRegion(region);
694 caseSuccessors.push_back(*block);
695 caseValues.push_back(value);
699 FailureOr<Block *> defaultBlock = convertRegion(op.getDefaultRegion());
705 SmallVector<ValueRange> caseOperands(caseSuccessors.size(), {});
708 Value caseValue = arith::IndexCastOp::create(
709 rewriter, op.getLoc(), rewriter.
getI32Type(), op.getArg());
711 cf::SwitchOp::create(rewriter, op.getLoc(), caseValue, *defaultBlock,
713 caseSuccessors, caseOperands);
714 rewriter.
replaceOp(op, continueBlock->getArguments());
718LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
719 PatternRewriter &rewriter)
const {
725 patterns.add<ForallLowering, ForLowering, IfLowering, ParallelLowering,
731void SCFToControlFlowPass::runOnOperation() {
737 target.addIllegalOp<scf::ForallOp, scf::ForOp, scf::IfOp, scf::IndexSwitchOp,
738 scf::ParallelOp, scf::WhileOp, scf::ExecuteRegionOp>();
739 target.markUnknownOpDynamicallyLegal([](
Operation *) {
return true; });
741 config.allowPatternRollback = allowPatternRollback;
static void propagateLoopAttrs(Operation *scfOp, Operation *brOp)
OpListType::iterator iterator
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
operand_iterator operand_begin()
operand_iterator operand_end()
operand_range getOperands()
Returns an iterator on the underlying Value's.
void setDiscardableAttrs(DictionaryAttr newAttrs)
Set the discardable attribute dictionary on this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp, ParallelOp *result=nullptr)
Try converting scf.forall into an scf.parallel loop.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
const FrozenRewritePatternSet & patterns
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to CFG branch-based operations within the Control...
LogicalResult matchAndRewrite(WhileOp whileOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})