22 #define GEN_PASS_DEF_SCFFORTOWHILELOOP
23 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
// Rewrites an scf.for into an equivalent scf.while. The induction variable
// becomes the first loop-carried value of the while loop.
// NOTE(review): several source lines of this function are missing from this
// listing; comments describe only the visible statements.
36 LogicalResult matchAndRewrite(ForOp forOp,
// Collect types/locations of the while loop's loop-carried values: the
// induction variable first, then one entry per for-op init arg.
42 lcvTypes.push_back(forOp.getInductionVar().getType());
43 lcvLocs.push_back(forOp.getInductionVar().getLoc());
44 for (
Value value : forOp.getInitArgs()) {
45 lcvTypes.push_back(value.getType());
46 lcvLocs.push_back(value.getLoc());
// Initial while operands mirror that order: the lower bound seeds the
// induction variable, followed by the for-op's init args.
51 initArgs.push_back(forOp.getLowerBound());
52 llvm::append_range(initArgs, forOp.getInitArgs());
53 auto whileOp = WhileOp::create(rewriter, forOp.getLoc(), lcvTypes, initArgs,
// "before" region block: one argument per loop-carried value.
59 &whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs);
// Loop condition: iv < upper bound. The comparison is unsigned (ult) when
// forOp.getUnsignedCmp() is set and signed (slt) otherwise.
61 arith::CmpIPredicate predicate = forOp.getUnsignedCmp()
62 ? arith::CmpIPredicate::ult
63 : arith::CmpIPredicate::slt;
64 auto cmpOp = arith::CmpIOp::create(rewriter, whileOp.getLoc(), predicate,
65 beforeBlock->getArgument(0),
66 forOp.getUpperBound());
// Forward every before-block argument unchanged to the "after" region.
67 scf::ConditionOp::create(rewriter, whileOp.getLoc(), cmpOp.getResult(),
68 beforeBlock->getArguments());
// "after" region block: same argument types/locations as the before block.
74 &whileOp.getAfter(), whileOp.getAfter().begin(), lcvTypes, lcvLocs);
// Advance the induction variable (after-block argument 0) by the for-op step.
79 arith::AddIOp::create(rewriter, whileOp.getLoc(),
80 afterBlock->getArgument(0), forOp.getStep());
// Map each for-body block argument to the after-block argument at the same
// index. Both lists start with the induction variable.
84 for (
const auto &barg :
enumerate(forOp.getBody(0)->getArguments()))
86 afterBlock->getArgument(barg.index()));
// Move every op of the for body to the end of the after block.
// make_early_inc_range keeps the iteration valid while ops are unlinked.
89 for (
auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
90 rewriter.
moveOpBefore(&arg, afterBlock, afterBlock->end());
// Each moved scf.yield must also carry the incremented induction variable.
// It is prepended so the operand order matches the loop-carried value order;
// the yield is updated in place.
93 for (
auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
95 yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult());
97 [&]() { yieldOp->setOperands(yieldOperands); });
// For-op result i corresponds to while result i + 1, because while result 0
// is the induction variable.
106 whileOp.getResult(arg.index() + 1));
// Pass that rewrites scf.for loops into scf.while loops by registering
// ForLoopLoweringPattern.
// NOTE(review): the pattern-driver call is not visible in this listing
// (presumably a greedy rewrite over the parent op's regions; confirm).
113 struct ForToWhileLoop :
public impl::SCFForToWhileLoopBase<ForToWhileLoop> {
114 void runOnOperation()
override {
// Operation this pass instance is running on.
115 auto *parentOp = getOperation();
118 patterns.add<ForLoopLoweringPattern>(ctx);
// Factory body: returns a newly constructed, uniquely owned ForToWhileLoop pass.
125 return std::make_unique<ForToWhileLoop>();
MLIRContext is the top-level object for a collection of MLIR operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
The OpAsmOpInterface, see OpAsmInterface.td for more details.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
std::unique_ptr< Pass > createForToWhileLoopPass()
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...