27 LogicalResult matchAndRewrite(scf::WhileOp loop,
36 Block *beforeBody = loop.getBeforeBody();
40 auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->
front());
43 "Loop body must have single cmp op");
45 scf::ConditionOp beforeTerm = loop.getConditionOp();
46 if (!cmp->hasOneUse() || beforeTerm.getCondition() != cmp.getResult())
48 diag <<
"Expected single condition use: " << *cmp;
54 std::optional<SmallVector<unsigned>> argReorder;
58 auto getArgReordering =
68 for (
Value a : cond.getArgs()) {
71 if (!arg || arg.
getOwner() != beforeBody)
78 forwarded[idx] =
true;
87 argReorder = getArgReordering(beforeBody, beforeTerm);
92 using Pred = arith::CmpIPredicate;
93 Pred predicate = cmp.getPredicate();
94 if (predicate != Pred::slt && predicate != Pred::sgt)
96 diag <<
"Expected 'slt' or 'sgt' predicate: " << *cmp;
106 for (
bool reverse : {
false,
true}) {
107 auto expectedPred = reverse ? Pred::sgt : Pred::slt;
108 if (cmp.getPredicate() != expectedPred)
111 auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs();
112 auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs();
114 auto blockArg = dyn_cast<BlockArgument>(arg1);
115 if (!blockArg || blockArg.getOwner() != beforeBody)
121 inductionVar = blockArg;
128 diag <<
"Unrecognized cmp form: " << *cmp;
133 if (!llvm::hasNItems(inductionVar.
getUses(), 2))
135 diag <<
"Unrecognized induction var: " << inductionVar;
138 Block *afterBody = loop.getAfterBody();
139 scf::YieldOp afterTerm = loop.getYieldOp();
141 Value afterTermIndArg = afterTerm.getResults()[argNumber];
144 return std::distance(indices.begin(),
145 llvm::find_if(indices, [beforeArgNo](
unsigned n) {
146 return n == beforeArgNo;
150 argReorder ? findAfterArgNo(*argReorder, argNumber) : argNumber);
161 if (addOp.getLhs() == inductionVarAfter) {
162 step = addOp.getRhs();
163 }
else if (addOp.getRhs() == inductionVarAfter) {
164 step = addOp.getLhs();
170 Value lb = loop.getInits()[argNumber];
179 newArgs.reserve(loop.getInits().size());
184 newArgs.emplace_back(init);
194 rewriter.
create<scf::ForOp>(loc, lb, ub, step, newArgs, emptyBuilder);
196 Block *newBody = newLoop.getBody();
200 ValueRange newBodyArgs = newBody->getArguments();
201 for (
auto i : llvm::seq<size_t>(0, newBodyArgs.size())) {
203 newArgs.emplace_back(newBodyArgs[i + 1]);
204 }
else if (i == argNumber) {
205 newArgs.emplace_back(newBodyArgs.front());
207 newArgs.emplace_back(newBodyArgs[i]);
214 for (
unsigned order : *argReorder)
215 args.push_back(newArgs[order]);
222 auto term = cast<scf::YieldOp>(newBody->getTerminator());
230 newArgs.emplace_back(arg);
240 if (isa<IndexType>(step.getType())) {
241 one = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
243 one = rewriter.
create<arith::ConstantIntOp>(loc, 1, step.getType());
246 Value stepDec = rewriter.
create<arith::SubIOp>(loc, step, one);
247 Value len = rewriter.
create<arith::SubIOp>(loc, ub, lb);
248 len = rewriter.
create<arith::AddIOp>(loc, len, stepDec);
249 len = rewriter.
create<arith::DivSIOp>(loc, len, step);
250 len = rewriter.
create<arith::SubIOp>(loc, len, one);
251 Value res = rewriter.
create<arith::MulIOp>(loc, len, step);
252 res = rewriter.
create<arith::AddIOp>(loc, lb, res);
257 llvm::append_range(newArgs, newLoop.getResults());
258 newArgs.insert(newArgs.begin() + argNumber, res);
263 for (
unsigned order : *argReorder)
264 results.push_back(newArgs[order]);
static std::string diag(const llvm::Value &value)
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operations within this block, excluding the terminator operation at the end.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
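Return true if operation A properly dominates operation B, i.e. if A and B are in the same block and A properly dominates B within that block, or if the block that contains A properly dominates the block that contains B.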
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and to provide a callback that populates a diagnostic with the reason the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements); the original operation is then erased.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (the replacement); the original op is then erased.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< ForOp > upliftWhileToForLoop(RewriterBase &rewriter, WhileOp loop)
Try to uplift scf.while op to scf.for.
void populateUpliftWhileToForPatterns(RewritePatternSet &patterns)
Populate patterns to uplift scf.while ops to scf.for.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class, as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern match and a list of generated ops.