28 #define GEN_PASS_DEF_SCFTOCONTROLFLOW
29 #include "mlir/Conversion/Passes.h.inc"
37 struct SCFToControlFlowPass
38 :
public impl::SCFToControlFlowBase<SCFToControlFlowPass> {
39 void runOnOperation()
override;
213 LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
324 auto *endBlock = rewriter.
splitBlock(initBlock, initPosition);
330 auto *conditionBlock = &forOp.getRegion().
front();
331 auto *firstBodyBlock =
332 rewriter.
splitBlock(conditionBlock, conditionBlock->begin());
333 auto *lastBodyBlock = &forOp.getRegion().
back();
335 auto iv = conditionBlock->getArgument(0);
340 Operation *terminator = lastBodyBlock->getTerminator();
342 auto step = forOp.getStep();
343 auto stepped = rewriter.
create<arith::AddIOp>(loc, iv, step).getResult();
348 loopCarried.push_back(stepped);
350 rewriter.
create<cf::BranchOp>(loc, conditionBlock, loopCarried);
355 Value lowerBound = forOp.getLowerBound();
356 Value upperBound = forOp.getUpperBound();
357 if (!lowerBound || !upperBound)
363 destOperands.push_back(lowerBound);
364 llvm::append_range(destOperands, forOp.getInitArgs());
365 rewriter.
create<cf::BranchOp>(loc, conditionBlock, destOperands);
369 auto comparison = rewriter.
create<arith::CmpIOp>(
370 loc, arith::CmpIPredicate::slt, iv, upperBound);
372 rewriter.
create<cf::CondBranchOp>(loc, comparison, firstBodyBlock,
377 rewriter.
replaceOp(forOp, conditionBlock->getArguments().drop_front());
383 auto loc = ifOp.getLoc();
390 auto *remainingOpsBlock = rewriter.
splitBlock(condBlock, opPosition);
391 Block *continueBlock;
392 if (ifOp.getNumResults() == 0) {
393 continueBlock = remainingOpsBlock;
396 rewriter.
createBlock(remainingOpsBlock, ifOp.getResultTypes(),
398 rewriter.
create<cf::BranchOp>(loc, remainingOpsBlock);
403 auto &thenRegion = ifOp.getThenRegion();
404 auto *thenBlock = &thenRegion.front();
405 Operation *thenTerminator = thenRegion.back().getTerminator();
408 rewriter.
create<cf::BranchOp>(loc, continueBlock, thenTerminatorOperands);
409 rewriter.
eraseOp(thenTerminator);
415 auto *elseBlock = continueBlock;
416 auto &elseRegion = ifOp.getElseRegion();
417 if (!elseRegion.empty()) {
418 elseBlock = &elseRegion.front();
419 Operation *elseTerminator = elseRegion.back().getTerminator();
422 rewriter.
create<cf::BranchOp>(loc, continueBlock, elseTerminatorOperands);
423 rewriter.
eraseOp(elseTerminator);
428 rewriter.
create<cf::CondBranchOp>(loc, ifOp.getCondition(), thenBlock,
433 rewriter.
replaceOp(ifOp, continueBlock->getArguments());
438 ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
444 auto *remainingOpsBlock = rewriter.
splitBlock(condBlock, opPosition);
448 rewriter.
create<cf::BranchOp>(loc, ®ion.front());
450 for (
Block &block : region) {
451 if (
auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
454 rewriter.
create<cf::BranchOp>(loc, remainingOpsBlock, terminatorOperands);
471 ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
474 auto reductionOp = cast<ReduceOp>(parallelOp.getBody()->getTerminator());
483 ivs.reserve(parallelOp.getNumLoops());
486 for (
auto [iv, lower, upper, step] :
487 llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(),
488 parallelOp.getUpperBound(), parallelOp.getStep())) {
489 ForOp forOp = rewriter.
create<ForOp>(loc, lower, upper, step, iterArgs);
490 ivs.push_back(forOp.getInductionVar());
491 auto iterRange = forOp.getRegionIterArgs();
492 iterArgs.assign(iterRange.begin(), iterRange.end());
497 loopResults.assign(forOp.result_begin(), forOp.result_end());
499 }
else if (!forOp.getResults().empty()) {
503 rewriter.
create<scf::YieldOp>(loc, forOp.getResults());
511 yieldOperands.reserve(parallelOp.getNumResults());
512 for (int64_t i = 0, e = parallelOp.getNumResults(); i < e; ++i) {
513 Block &reductionBody = reductionOp.getReductions()[i].
front();
514 Value arg = iterArgs[yieldOperands.size()];
515 yieldOperands.push_back(
516 cast<ReduceReturnOp>(reductionBody.
getTerminator()).getResult());
519 {arg, reductionOp.getOperands()[i]});
525 if (newBody->
empty())
526 rewriter.
mergeBlocks(parallelOp.getBody(), newBody, ivs);
533 if (!yieldOperands.empty()) {
535 rewriter.
create<scf::YieldOp>(loc, yieldOperands);
538 rewriter.
replaceOp(parallelOp, loopResults);
543 LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
550 Block *continuation =
554 Block *after = whileOp.getAfterBody();
555 Block *before = whileOp.getBeforeBody();
561 rewriter.
create<cf::BranchOp>(loc, before, whileOp.getInits());
567 auto condOp = cast<ConditionOp>(before->getTerminator());
569 after, condOp.getArgs(),
575 yieldOp.getResults());
579 rewriter.
replaceOp(whileOp, condOp.getArgs());
585 DoWhileLowering::matchAndRewrite(WhileOp whileOp,
587 Block &afterBlock = *whileOp.getAfterBody();
588 if (!llvm::hasSingleElement(afterBlock))
590 "do-while simplification applicable "
591 "only if 'after' region has no payload");
593 auto yield = dyn_cast<scf::YieldOp>(&afterBlock.
front());
594 if (!yield || yield.getResults() != afterBlock.
getArguments())
596 "do-while simplification applicable "
597 "only to forwarding 'after' regions");
602 Block *continuation =
606 Block *before = whileOp.getBeforeBody();
611 rewriter.
create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits());
617 before, condOp.getArgs(),
622 rewriter.
replaceOp(whileOp, condOp.getArgs());
628 IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
649 yield.getOperands());
659 caseSuccessors.reserve(op.getCases().size());
660 caseValues.reserve(op.getCases().size());
661 for (
auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) {
665 caseSuccessors.push_back(*block);
666 caseValues.push_back(value);
679 Value caseValue = rewriter.
create<arith::IndexCastOp>(
682 rewriter.
create<cf::SwitchOp>(
685 rewriter.
replaceOp(op, continueBlock->getArguments());
689 LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
692 if (!forallOp.getOutputs().empty())
695 "only fully bufferized scf.forall ops can be lowered to scf.parallel");
699 rewriter, loc, forallOp.getMixedLowerBound());
701 rewriter, loc, forallOp.getMixedUpperBound());
706 auto parallelOp = rewriter.
create<ParallelOp>(loc, lbs, ubs, steps);
707 rewriter.
eraseBlock(¶llelOp.getRegion().front());
709 parallelOp.getRegion().begin());
713 parallelOp.getRegion().front().getTerminator());
716 rewriter.
replaceOp(forallOp, parallelOp);
722 patterns.
add<ForallLowering, ForLowering, IfLowering, ParallelLowering,
723 WhileLowering, ExecuteRegionLowering, IndexSwitchLowering>(
728 void SCFToControlFlowPass::runOnOperation() {
734 target.addIllegalOp<scf::ForallOp, scf::ForOp, scf::IfOp, scf::IndexSwitchOp,
735 scf::ParallelOp, scf::WhileOp, scf::ExecuteRegionOp>();
736 target.markUnknownOpDynamicallyLegal([](
Operation *) {
return true; });
743 return std::make_unique<SCFToControlFlowPass>();
static MLIRContext * getContext(OpFoldResult val)
Block represents an ordered list of Operations.
OpListType::iterator iterator
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
This class describes a specific conversion target.
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Operation is the basic unit of execution within MLIR.
operand_iterator operand_begin()
Location getLoc()
The source location the operation was defined or derived from.
operand_iterator operand_end()
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
std::unique_ptr< Pass > createConvertSCFToCFPass()
Creates a pass to convert SCF operations to CFG branch-based operation in the ControlFlow dialect.
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to CFG branch-based operations within the Control...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...