15 #include "../PassDetail.h" 32 struct SCFToControlFlowPass
33 :
public SCFToControlFlowBase<SCFToControlFlowPass> {
34 void runOnOperation()
override;
208 LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
299 auto *endBlock = rewriter.
splitBlock(initBlock, initPosition);
305 auto *conditionBlock = &forOp.getRegion().
front();
306 auto *firstBodyBlock =
307 rewriter.
splitBlock(conditionBlock, conditionBlock->begin());
308 auto *lastBodyBlock = &forOp.getRegion().
back();
310 auto iv = conditionBlock->getArgument(0);
315 Operation *terminator = lastBodyBlock->getTerminator();
317 auto step = forOp.getStep();
318 auto stepped = rewriter.
create<arith::AddIOp>(loc, iv, step).getResult();
323 loopCarried.push_back(stepped);
325 rewriter.
create<cf::BranchOp>(loc, conditionBlock, loopCarried);
330 Value lowerBound = forOp.getLowerBound();
331 Value upperBound = forOp.getUpperBound();
332 if (!lowerBound || !upperBound)
338 destOperands.push_back(lowerBound);
339 auto iterOperands = forOp.getIterOperands();
340 destOperands.append(iterOperands.begin(), iterOperands.end());
341 rewriter.
create<cf::BranchOp>(loc, conditionBlock, destOperands);
345 auto comparison = rewriter.
create<arith::CmpIOp>(
346 loc, arith::CmpIPredicate::slt, iv, upperBound);
348 rewriter.
create<cf::CondBranchOp>(loc, comparison, firstBodyBlock,
353 rewriter.
replaceOp(forOp, conditionBlock->getArguments().drop_front());
359 auto loc = ifOp.getLoc();
366 auto *remainingOpsBlock = rewriter.
splitBlock(condBlock, opPosition);
367 Block *continueBlock;
368 if (ifOp.getNumResults() == 0) {
369 continueBlock = remainingOpsBlock;
372 rewriter.
createBlock(remainingOpsBlock, ifOp.getResultTypes(),
374 rewriter.
create<cf::BranchOp>(loc, remainingOpsBlock);
379 auto &thenRegion = ifOp.getThenRegion();
380 auto *thenBlock = &thenRegion.front();
381 Operation *thenTerminator = thenRegion.back().getTerminator();
384 rewriter.
create<cf::BranchOp>(loc, continueBlock, thenTerminatorOperands);
385 rewriter.
eraseOp(thenTerminator);
391 auto *elseBlock = continueBlock;
392 auto &elseRegion = ifOp.getElseRegion();
393 if (!elseRegion.empty()) {
394 elseBlock = &elseRegion.front();
395 Operation *elseTerminator = elseRegion.back().getTerminator();
398 rewriter.
create<cf::BranchOp>(loc, continueBlock, elseTerminatorOperands);
399 rewriter.
eraseOp(elseTerminator);
404 rewriter.
create<cf::CondBranchOp>(loc, ifOp.getCondition(), thenBlock,
409 rewriter.
replaceOp(ifOp, continueBlock->getArguments());
414 ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
416 auto loc = op.getLoc();
420 auto *remainingOpsBlock = rewriter.
splitBlock(condBlock, opPosition);
422 auto ®ion = op.getRegion();
424 rewriter.
create<cf::BranchOp>(loc, ®ion.front());
426 for (
Block &block : region) {
427 if (
auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
430 rewriter.
create<cf::BranchOp>(loc, remainingOpsBlock, terminatorOperands);
440 remainingOpsBlock->addArguments(op->getResultTypes(), argLocs))
447 ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
458 ivs.reserve(parallelOp.getNumLoops());
461 for (
auto [iv, lower, upper, step] :
462 llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(),
463 parallelOp.getUpperBound(), parallelOp.getStep())) {
464 ForOp forOp = rewriter.
create<ForOp>(loc, lower, upper, step, iterArgs);
465 ivs.push_back(forOp.getInductionVar());
466 auto iterRange = forOp.getRegionIterArgs();
467 iterArgs.assign(iterRange.begin(), iterRange.end());
472 loopResults.assign(forOp.result_begin(), forOp.result_end());
474 }
else if (!forOp.getResults().empty()) {
478 rewriter.
create<scf::YieldOp>(loc, forOp.getResults());
486 yieldOperands.reserve(parallelOp.getNumResults());
487 for (
auto &op : *parallelOp.getBody()) {
488 auto reduce = dyn_cast<ReduceOp>(op);
492 Block &reduceBlock = reduce.getReductionOperator().
front();
493 Value arg = iterArgs[yieldOperands.size()];
501 rewriter.
eraseOp(parallelOp.getBody()->getTerminator());
503 if (newBody->
empty())
504 rewriter.
mergeBlocks(parallelOp.getBody(), newBody, ivs);
511 if (!yieldOperands.empty()) {
513 rewriter.
create<scf::YieldOp>(loc, yieldOperands);
516 rewriter.
replaceOp(parallelOp, loopResults);
521 LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
528 Block *continuation =
533 Block *afterLast = &whileOp.getAfter().
back();
535 Block *beforeLast = &whileOp.getBefore().
back();
541 rewriter.
create<cf::BranchOp>(loc, before, whileOp.getInits());
547 auto condOp = cast<ConditionOp>(beforeLast->
getTerminator());
549 after, condOp.getArgs(),
553 auto yieldOp = cast<scf::YieldOp>(afterLast->
getTerminator());
555 yieldOp.getResults());
559 rewriter.
replaceOp(whileOp, condOp.getArgs());
565 DoWhileLowering::matchAndRewrite(WhileOp whileOp,
567 if (!llvm::hasSingleElement(whileOp.getAfter()))
569 "do-while simplification applicable to " 570 "single-block 'after' region only");
572 Block &afterBlock = whileOp.getAfter().
front();
573 if (!llvm::hasSingleElement(afterBlock))
575 "do-while simplification applicable " 576 "only if 'after' region has no payload");
578 auto yield = dyn_cast<scf::YieldOp>(&afterBlock.
front());
579 if (!yield || yield.getResults() != afterBlock.
getArguments())
581 "do-while simplification applicable " 582 "only to forwarding 'after' regions");
587 Block *continuation =
592 Block *beforeLast = &whileOp.getBefore().
back();
597 rewriter.
create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits());
601 auto condOp = cast<ConditionOp>(beforeLast->
getTerminator());
603 before, condOp.getArgs(),
608 rewriter.
replaceOp(whileOp, condOp.getArgs());
615 patterns.
add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
616 ExecuteRegionLowering>(patterns.
getContext());
620 void SCFToControlFlowPass::runOnOperation() {
626 target.
addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp,
627 scf::ExecuteRegionOp>();
635 return std::make_unique<SCFToControlFlowPass>();
Include the generated interface declarations.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Operation is a basic unit of execution within MLIR.
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::unique_ptr< Pass > createConvertSCFToCFPass()
Creates a pass to convert SCF operations to CFG branch-based operations in the ControlFlow dialect.
Block represents an ordered list of Operations.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Value getOperand(unsigned idx)
virtual Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block, and return it.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e. this operation is known to not be supported by this target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an efficient way to signal success or failure.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void mergeBlockBefore(Block *source, Operation *op, ValueRange argValues=llvm::None)
BlockArgListType getArguments()
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to CFG branch-based operations within the ControlFlow dialect.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Operation * getTerminator()
Get the terminator operation of this block.
operand_iterator operand_begin()
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
RAII guard to reset the insertion point of the builder when destroyed.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure, and to provide a callback that populates a diagnostic with the reason for the failure.
This class describes a specific conversion target.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=llvm::None, ArrayRef< Location > locs=llvm::None)
Add a new block with 'argTypes' arguments and set the insertion point to the end of it. The block is inserted at the provided insertion point of 'parent'.
operand_iterator operand_end()
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
virtual void mergeBlocks(Block *source, Block *dest, ValueRange argValues=llvm::None)
Merge the operations of block 'source' into the end of block 'dest'.
This class provides an abstraction over the different types of ranges over Values.
MLIRContext * getContext() const