28 #define GEN_PASS_DEF_SCFTOCONTROLFLOWPASS
29 #include "mlir/Conversion/Passes.h.inc"
37 struct SCFToControlFlowPass
38 :
public impl::SCFToControlFlowPassBase<SCFToControlFlowPass> {
39 void runOnOperation()
override;
106 LogicalResult matchAndRewrite(ForOp forOp,
199 LogicalResult matchAndRewrite(IfOp ifOp,
206 LogicalResult matchAndRewrite(ExecuteRegionOp op,
213 LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
279 LogicalResult matchAndRewrite(WhileOp whileOp,
290 LogicalResult matchAndRewrite(WhileOp whileOp,
298 LogicalResult matchAndRewrite(IndexSwitchOp op,
309 LogicalResult matchAndRewrite(mlir::scf::ForallOp forallOp,
321 llvm::copy_if(scfOp->
getAttrs(), std::back_inserter(llvmAttrs),
323 return isa<LLVM::LLVMDialect>(attr.getValue().getDialect());
328 LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
337 auto *endBlock = rewriter.
splitBlock(initBlock, initPosition);
343 auto *conditionBlock = &forOp.getRegion().
front();
344 auto *firstBodyBlock =
345 rewriter.
splitBlock(conditionBlock, conditionBlock->begin());
346 auto *lastBodyBlock = &forOp.getRegion().
back();
348 auto iv = conditionBlock->getArgument(0);
353 Operation *terminator = lastBodyBlock->getTerminator();
355 auto step = forOp.getStep();
356 auto stepped = arith::AddIOp::create(rewriter, loc, iv, step).getResult();
361 loopCarried.push_back(stepped);
364 cf::BranchOp::create(rewriter, loc, conditionBlock, loopCarried);
371 Value lowerBound = forOp.getLowerBound();
372 Value upperBound = forOp.getUpperBound();
373 if (!lowerBound || !upperBound)
379 destOperands.push_back(lowerBound);
380 llvm::append_range(destOperands, forOp.getInitArgs());
381 cf::BranchOp::create(rewriter, loc, conditionBlock, destOperands);
385 arith::CmpIPredicate predicate = forOp.getUnsignedCmp()
386 ? arith::CmpIPredicate::ult
387 : arith::CmpIPredicate::slt;
389 arith::CmpIOp::create(rewriter, loc, predicate, iv, upperBound);
391 cf::CondBranchOp::create(rewriter, loc, comparison, firstBodyBlock,
396 rewriter.
replaceOp(forOp, conditionBlock->getArguments().drop_front());
400 LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
402 auto loc = ifOp.getLoc();
409 auto *remainingOpsBlock = rewriter.
splitBlock(condBlock, opPosition);
410 Block *continueBlock;
411 if (ifOp.getNumResults() == 0) {
412 continueBlock = remainingOpsBlock;
415 rewriter.
createBlock(remainingOpsBlock, ifOp.getResultTypes(),
417 cf::BranchOp::create(rewriter, loc, remainingOpsBlock);
422 auto &thenRegion = ifOp.getThenRegion();
423 auto *thenBlock = &thenRegion.front();
424 Operation *thenTerminator = thenRegion.back().getTerminator();
427 cf::BranchOp::create(rewriter, loc, continueBlock, thenTerminatorOperands);
428 rewriter.
eraseOp(thenTerminator);
434 auto *elseBlock = continueBlock;
435 auto &elseRegion = ifOp.getElseRegion();
436 if (!elseRegion.empty()) {
437 elseBlock = &elseRegion.
front();
438 Operation *elseTerminator = elseRegion.back().getTerminator();
441 cf::BranchOp::create(rewriter, loc, continueBlock, elseTerminatorOperands);
442 rewriter.
eraseOp(elseTerminator);
447 cf::CondBranchOp::create(rewriter, loc, ifOp.getCondition(), thenBlock,
457 ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
459 auto loc = op.getLoc();
463 auto *remainingOpsBlock = rewriter.
splitBlock(condBlock, opPosition);
465 auto &region = op.getRegion();
467 cf::BranchOp::create(rewriter, loc, &region.front());
469 for (
Block &block : region) {
470 if (
auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
473 cf::BranchOp::create(rewriter, loc, remainingOpsBlock,
484 remainingOpsBlock->addArguments(op->getResultTypes(), argLocs))
491 ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
494 auto reductionOp = dyn_cast<ReduceOp>(parallelOp.getBody()->getTerminator());
506 ivs.reserve(parallelOp.getNumLoops());
509 for (
auto [iv, lower, upper, step] :
510 llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(),
511 parallelOp.getUpperBound(), parallelOp.getStep())) {
512 ForOp forOp = ForOp::create(rewriter, loc, lower, upper, step, iterArgs);
513 ivs.push_back(forOp.getInductionVar());
514 auto iterRange = forOp.getRegionIterArgs();
515 iterArgs.assign(iterRange.begin(), iterRange.end());
520 loopResults.assign(forOp.result_begin(), forOp.result_end());
522 }
else if (!forOp.getResults().empty()) {
526 scf::YieldOp::create(rewriter, loc, forOp.getResults());
534 yieldOperands.reserve(parallelOp.getNumResults());
535 for (int64_t i = 0, e = parallelOp.getNumResults(); i < e; ++i) {
536 Block &reductionBody = reductionOp.getReductions()[i].
front();
537 Value arg = iterArgs[yieldOperands.size()];
538 yieldOperands.push_back(
539 cast<ReduceReturnOp>(reductionBody.
getTerminator()).getResult());
542 {arg, reductionOp.getOperands()[i]});
548 if (newBody->
empty())
549 rewriter.
mergeBlocks(parallelOp.getBody(), newBody, ivs);
556 if (!yieldOperands.empty()) {
558 scf::YieldOp::create(rewriter, loc, yieldOperands);
561 rewriter.
replaceOp(parallelOp, loopResults);
566 LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
573 Block *continuation =
577 Block *after = whileOp.getAfterBody();
578 Block *before = whileOp.getBeforeBody();
584 cf::BranchOp::create(rewriter, loc, before, whileOp.getInits());
593 after, condOp.getArgs(),
599 yieldOp.getResults());
610 DoWhileLowering::matchAndRewrite(WhileOp whileOp,
612 Block &afterBlock = *whileOp.getAfterBody();
613 if (!llvm::hasSingleElement(afterBlock))
615 "do-while simplification applicable "
616 "only if 'after' region has no payload");
618 auto yield = dyn_cast<scf::YieldOp>(&afterBlock.
front());
619 if (!yield || yield.getResults() != afterBlock.
getArguments())
621 "do-while simplification applicable "
622 "only to forwarding 'after' regions");
627 Block *continuation =
631 Block *before = whileOp.getBeforeBody();
636 cf::BranchOp::create(rewriter, whileOp.getLoc(), before, whileOp.getInits());
641 auto latch = cf::CondBranchOp::create(
642 rewriter, condOp.getLoc(), condOp.getCondition(), before,
643 condOp.getArgs(), continuation,
ValueRange());
648 rewriter.
replaceOp(whileOp, condOp.getArgs());
656 IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
665 results.reserve(op.getNumResults());
666 for (
Type resultType : op.getResultTypes())
667 results.push_back(continueBlock->
addArgument(resultType, op.getLoc()));
670 auto convertRegion = [&](
Region &region) -> FailureOr<Block *> {
677 yield.getOperands());
687 caseSuccessors.reserve(op.getCases().size());
688 caseValues.reserve(op.getCases().size());
689 for (
auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) {
690 FailureOr<Block *> block = convertRegion(region);
693 caseSuccessors.push_back(*block);
694 caseValues.push_back(value);
698 FailureOr<Block *> defaultBlock = convertRegion(op.getDefaultRegion());
707 Value caseValue = arith::IndexCastOp::create(
708 rewriter, op.getLoc(), rewriter.
getI32Type(), op.getArg());
710 cf::SwitchOp::create(rewriter, op.getLoc(), caseValue, *defaultBlock,
712 caseSuccessors, caseOperands);
713 rewriter.
replaceOp(op, continueBlock->getArguments());
717 LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
724 patterns.add<ForallLowering, ForLowering, IfLowering, ParallelLowering,
725 WhileLowering, ExecuteRegionLowering, IndexSwitchLowering>(
730 void SCFToControlFlowPass::runOnOperation() {
736 target.addIllegalOp<scf::ForallOp, scf::ForOp, scf::IfOp, scf::IndexSwitchOp,
737 scf::ParallelOp, scf::WhileOp, scf::ExecuteRegionOp>();
738 target.markUnknownOpDynamicallyLegal([](
Operation *) {
return true; });
static MLIRContext * getContext(OpFoldResult val)
static void propagateLoopAttrs(Operation *scfOp, Operation *brOp)
Block represents an ordered list of Operations.
OpListType::iterator iterator
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Operation is the basic unit of execution within MLIR.
operand_iterator operand_begin()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
operand_iterator operand_end()
operand_range getOperands()
Returns an iterator on the underlying Value's.
void setDiscardableAttrs(DictionaryAttr newAttrs)
Set the discardable attribute dictionary on this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp, ParallelOp *result=nullptr)
Try converting scf.forall into an scf.parallel loop.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to CFG branch-based operations within the Control...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...