30 #define GEN_PASS_DEF_SCFTOCONTROLFLOW
31 #include "mlir/Conversion/Passes.h.inc"
39 struct SCFToControlFlowPass
40 :
public impl::SCFToControlFlowBase<SCFToControlFlowPass> {
41 void runOnOperation()
override;
108 LogicalResult matchAndRewrite(ForOp forOp,
201 LogicalResult matchAndRewrite(IfOp ifOp,
208 LogicalResult matchAndRewrite(ExecuteRegionOp op,
215 LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
281 LogicalResult matchAndRewrite(WhileOp whileOp,
292 LogicalResult matchAndRewrite(WhileOp whileOp,
300 LogicalResult matchAndRewrite(IndexSwitchOp op,
311 LogicalResult matchAndRewrite(mlir::scf::ForallOp forallOp,
317 LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
326 auto *endBlock = rewriter.
splitBlock(initBlock, initPosition);
332 auto *conditionBlock = &forOp.getRegion().
front();
333 auto *firstBodyBlock =
334 rewriter.
splitBlock(conditionBlock, conditionBlock->begin());
335 auto *lastBodyBlock = &forOp.getRegion().
back();
337 auto iv = conditionBlock->getArgument(0);
342 Operation *terminator = lastBodyBlock->getTerminator();
344 auto step = forOp.getStep();
345 auto stepped = rewriter.
create<arith::AddIOp>(loc, iv, step).getResult();
350 loopCarried.push_back(stepped);
352 rewriter.
create<cf::BranchOp>(loc, conditionBlock, loopCarried);
357 Value lowerBound = forOp.getLowerBound();
358 Value upperBound = forOp.getUpperBound();
359 if (!lowerBound || !upperBound)
365 destOperands.push_back(lowerBound);
366 llvm::append_range(destOperands, forOp.getInitArgs());
367 rewriter.
create<cf::BranchOp>(loc, conditionBlock, destOperands);
371 auto comparison = rewriter.
create<arith::CmpIOp>(
372 loc, arith::CmpIPredicate::slt, iv, upperBound);
374 auto condBranchOp = rewriter.
create<cf::CondBranchOp>(
381 llvm::copy_if(forOp->getAttrs(), std::back_inserter(llvmAttrs),
383 return isa<LLVM::LLVMDialect>(attr.getValue().getDialect());
385 condBranchOp->setDiscardableAttrs(llvmAttrs);
388 rewriter.
replaceOp(forOp, conditionBlock->getArguments().drop_front());
392 LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
394 auto loc = ifOp.getLoc();
401 auto *remainingOpsBlock = rewriter.
splitBlock(condBlock, opPosition);
402 Block *continueBlock;
403 if (ifOp.getNumResults() == 0) {
404 continueBlock = remainingOpsBlock;
407 rewriter.
createBlock(remainingOpsBlock, ifOp.getResultTypes(),
409 rewriter.
create<cf::BranchOp>(loc, remainingOpsBlock);
414 auto &thenRegion = ifOp.getThenRegion();
415 auto *thenBlock = &thenRegion.front();
416 Operation *thenTerminator = thenRegion.back().getTerminator();
419 rewriter.
create<cf::BranchOp>(loc, continueBlock, thenTerminatorOperands);
420 rewriter.
eraseOp(thenTerminator);
426 auto *elseBlock = continueBlock;
427 auto &elseRegion = ifOp.getElseRegion();
428 if (!elseRegion.empty()) {
429 elseBlock = &elseRegion.front();
430 Operation *elseTerminator = elseRegion.back().getTerminator();
433 rewriter.
create<cf::BranchOp>(loc, continueBlock, elseTerminatorOperands);
434 rewriter.
eraseOp(elseTerminator);
439 rewriter.
create<cf::CondBranchOp>(loc, ifOp.getCondition(), thenBlock,
444 rewriter.
replaceOp(ifOp, continueBlock->getArguments());
449 ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
451 auto loc = op.getLoc();
455 auto *remainingOpsBlock = rewriter.
splitBlock(condBlock, opPosition);
457 auto ®ion = op.getRegion();
459 rewriter.
create<cf::BranchOp>(loc, ®ion.front());
461 for (
Block &block : region) {
462 if (
auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
465 rewriter.
create<cf::BranchOp>(loc, remainingOpsBlock, terminatorOperands);
475 remainingOpsBlock->addArguments(op->getResultTypes(), argLocs))
482 ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
485 auto reductionOp = dyn_cast<ReduceOp>(parallelOp.getBody()->getTerminator());
497 ivs.reserve(parallelOp.getNumLoops());
500 for (
auto [iv, lower, upper, step] :
501 llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(),
502 parallelOp.getUpperBound(), parallelOp.getStep())) {
503 ForOp forOp = rewriter.
create<ForOp>(loc, lower, upper, step, iterArgs);
504 ivs.push_back(forOp.getInductionVar());
505 auto iterRange = forOp.getRegionIterArgs();
506 iterArgs.assign(iterRange.begin(), iterRange.end());
511 loopResults.assign(forOp.result_begin(), forOp.result_end());
513 }
else if (!forOp.getResults().empty()) {
517 rewriter.
create<scf::YieldOp>(loc, forOp.getResults());
525 yieldOperands.reserve(parallelOp.getNumResults());
526 for (int64_t i = 0, e = parallelOp.getNumResults(); i < e; ++i) {
527 Block &reductionBody = reductionOp.getReductions()[i].
front();
528 Value arg = iterArgs[yieldOperands.size()];
529 yieldOperands.push_back(
530 cast<ReduceReturnOp>(reductionBody.
getTerminator()).getResult());
533 {arg, reductionOp.getOperands()[i]});
539 if (newBody->
empty())
540 rewriter.
mergeBlocks(parallelOp.getBody(), newBody, ivs);
547 if (!yieldOperands.empty()) {
549 rewriter.
create<scf::YieldOp>(loc, yieldOperands);
552 rewriter.
replaceOp(parallelOp, loopResults);
557 LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
564 Block *continuation =
568 Block *after = whileOp.getAfterBody();
569 Block *before = whileOp.getBeforeBody();
575 rewriter.
create<cf::BranchOp>(loc, before, whileOp.getInits());
581 auto condOp = cast<ConditionOp>(before->getTerminator());
583 after, condOp.getArgs(),
589 yieldOp.getResults());
593 rewriter.
replaceOp(whileOp, condOp.getArgs());
599 DoWhileLowering::matchAndRewrite(WhileOp whileOp,
601 Block &afterBlock = *whileOp.getAfterBody();
602 if (!llvm::hasSingleElement(afterBlock))
604 "do-while simplification applicable "
605 "only if 'after' region has no payload");
607 auto yield = dyn_cast<scf::YieldOp>(&afterBlock.
front());
608 if (!yield || yield.getResults() != afterBlock.
getArguments())
610 "do-while simplification applicable "
611 "only to forwarding 'after' regions");
616 Block *continuation =
620 Block *before = whileOp.getBeforeBody();
625 rewriter.
create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits());
631 before, condOp.getArgs(),
636 rewriter.
replaceOp(whileOp, condOp.getArgs());
642 IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
651 results.reserve(op.getNumResults());
652 for (
Type resultType : op.getResultTypes())
653 results.push_back(continueBlock->
addArgument(resultType, op.getLoc()));
656 auto convertRegion = [&](
Region ®ion) -> FailureOr<Block *> {
663 yield.getOperands());
673 caseSuccessors.reserve(op.getCases().size());
674 caseValues.reserve(op.getCases().size());
675 for (
auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) {
676 FailureOr<Block *> block = convertRegion(region);
679 caseSuccessors.push_back(*block);
680 caseValues.push_back(value);
684 FailureOr<Block *> defaultBlock = convertRegion(op.getDefaultRegion());
685 if (failed(defaultBlock))
693 Value caseValue = rewriter.
create<arith::IndexCastOp>(
694 op.getLoc(), rewriter.
getI32Type(), op.getArg());
696 rewriter.
create<cf::SwitchOp>(
697 op.getLoc(), caseValue, *defaultBlock,
ValueRange(),
699 rewriter.
replaceOp(op, continueBlock->getArguments());
703 LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
710 patterns.add<ForallLowering, ForLowering, IfLowering, ParallelLowering,
711 WhileLowering, ExecuteRegionLowering, IndexSwitchLowering>(
716 void SCFToControlFlowPass::runOnOperation() {
722 target.addIllegalOp<scf::ForallOp, scf::ForOp, scf::IfOp, scf::IndexSwitchOp,
723 scf::ParallelOp, scf::WhileOp, scf::ExecuteRegionOp>();
724 target.markUnknownOpDynamicallyLegal([](
Operation *) {
return true; });
731 return std::make_unique<SCFToControlFlowPass>();
static MLIRContext * getContext(OpFoldResult val)
Block represents an ordered list of Operations.
OpListType::iterator iterator
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Operation is the basic unit of execution within MLIR.
operand_iterator operand_begin()
operand_iterator operand_end()
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp, ParallelOp *result=nullptr)
Try converting scf.forall into an scf.parallel loop.
Include the generated interface declarations.
std::unique_ptr< Pass > createConvertSCFToCFPass()
Creates a pass to convert SCF operations to CFG branch-based operation in the ControlFlow dialect.
const FrozenRewritePatternSet & patterns
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to CFG branch-based operations within the Control...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...