27 LogicalResult matchAndRewrite(scf::WhileOp loop,
36 Block *beforeBody = loop.getBeforeBody();
40 auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->
front());
43 "Loop body must have single cmp op");
45 scf::ConditionOp beforeTerm = loop.getConditionOp();
46 if (!cmp->hasOneUse() || beforeTerm.getCondition() != cmp.getResult())
48 diag <<
"Expected single condition use: " << *cmp;
56 using Pred = arith::CmpIPredicate;
57 Pred predicate = cmp.getPredicate();
58 if (predicate != Pred::slt && predicate != Pred::sgt)
60 diag <<
"Expected 'slt' or 'sgt' predicate: " << *cmp;
70 for (
bool reverse : {
false,
true}) {
71 auto expectedPred = reverse ? Pred::sgt : Pred::slt;
72 if (cmp.getPredicate() != expectedPred)
75 auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs();
76 auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs();
78 auto blockArg = dyn_cast<BlockArgument>(arg1);
79 if (!blockArg || blockArg.getOwner() != beforeBody)
85 inductionVar = blockArg;
92 diag <<
"Unrecognized cmp form: " << *cmp;
97 if (!llvm::hasNItems(inductionVar.
getUses(), 2))
99 diag <<
"Unrecognized induction var: " << inductionVar;
102 Block *afterBody = loop.getAfterBody();
103 scf::YieldOp afterTerm = loop.getYieldOp();
105 Value afterTermIndArg = afterTerm.getResults()[argNumber];
118 if (addOp.getLhs() == inductionVarAfter) {
119 step = addOp.getRhs();
120 }
else if (addOp.getRhs() == inductionVarAfter) {
121 step = addOp.getLhs();
127 Value lb = loop.getInits()[argNumber];
136 newArgs.reserve(loop.getInits().size());
141 newArgs.emplace_back(init);
151 rewriter.
create<scf::ForOp>(loc, lb, ub, step, newArgs, emptyBuilder);
153 Block *newBody = newLoop.getBody();
157 ValueRange newBodyArgs = newBody->getArguments();
158 for (
auto i : llvm::seq<size_t>(0, newBodyArgs.size())) {
160 newArgs.emplace_back(newBodyArgs[i + 1]);
161 }
else if (i == argNumber) {
162 newArgs.emplace_back(newBodyArgs.front());
164 newArgs.emplace_back(newBodyArgs[i]);
171 auto term = cast<scf::YieldOp>(newBody->getTerminator());
179 newArgs.emplace_back(arg);
189 if (isa<IndexType>(step.getType())) {
190 one = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
192 one = rewriter.
create<arith::ConstantIntOp>(loc, 1, step.getType());
195 Value stepDec = rewriter.
create<arith::SubIOp>(loc, step, one);
196 Value len = rewriter.
create<arith::SubIOp>(loc, ub, lb);
197 len = rewriter.
create<arith::AddIOp>(loc, len, stepDec);
198 len = rewriter.
create<arith::DivSIOp>(loc, len, step);
199 len = rewriter.
create<arith::SubIOp>(loc, len, one);
200 Value res = rewriter.
create<arith::MulIOp>(loc, len, step);
201 res = rewriter.
create<arith::AddIOp>(loc, lb, res);
206 llvm::append_range(newArgs, newLoop.getResults());
207 newArgs.insert(newArgs.begin() + argNumber, res);
static std::string diag(const llvm::Value &value)
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< ForOp > upliftWhileToForLoop(RewriterBase &rewriter, WhileOp loop)
Try to uplift scf.while op to scf.for.
void populateUpliftWhileToForPatterns(RewritePatternSet &patterns)
Populate patterns to uplift scf.while ops to scf.for.
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...