25 LogicalResult matchAndRewrite(scf::WhileOp loop,
34 Block *beforeBody = loop.getBeforeBody();
38 auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->
front());
41 "Loop body must have single cmp op");
43 scf::ConditionOp beforeTerm = loop.getConditionOp();
44 if (!cmp->hasOneUse() || beforeTerm.getCondition() != cmp.getResult())
46 diag <<
"Expected single condition use: " << *cmp;
52 std::optional<SmallVector<unsigned>> argReorder;
56 auto getArgReordering =
66 for (
Value a : cond.getArgs()) {
69 if (!arg || arg.
getOwner() != beforeBody)
76 forwarded[idx] =
true;
85 argReorder = getArgReordering(beforeBody, beforeTerm);
90 using Pred = arith::CmpIPredicate;
91 Pred predicate = cmp.getPredicate();
92 if (predicate != Pred::slt && predicate != Pred::sgt)
94 diag <<
"Expected 'slt' or 'sgt' predicate: " << *cmp;
104 for (
bool reverse : {
false,
true}) {
105 auto expectedPred = reverse ? Pred::sgt : Pred::slt;
106 if (cmp.getPredicate() != expectedPred)
109 auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs();
110 auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs();
112 auto blockArg = dyn_cast<BlockArgument>(arg1);
113 if (!blockArg || blockArg.getOwner() != beforeBody)
119 inductionVar = blockArg;
126 diag <<
"Unrecognized cmp form: " << *cmp;
131 if (!llvm::hasNItems(inductionVar.
getUses(), 2))
133 diag <<
"Unrecognized induction var: " << inductionVar;
136 Block *afterBody = loop.getAfterBody();
137 scf::YieldOp afterTerm = loop.getYieldOp();
139 Value afterTermIndArg = afterTerm.getResults()[argNumber];
142 return std::distance(indices.begin(),
143 llvm::find_if(indices, [beforeArgNo](
unsigned n) {
144 return n == beforeArgNo;
148 argReorder ? findAfterArgNo(*argReorder, argNumber) : argNumber);
159 if (addOp.getLhs() == inductionVarAfter) {
160 step = addOp.getRhs();
161 }
else if (addOp.getRhs() == inductionVarAfter) {
162 step = addOp.getLhs();
168 Value lb = loop.getInits()[argNumber];
177 newArgs.reserve(loop.getInits().size());
182 newArgs.emplace_back(init);
192 scf::ForOp::create(rewriter, loc, lb, ub, step, newArgs, emptyBuilder);
194 Block *newBody = newLoop.getBody();
199 for (
auto i : llvm::seq<size_t>(0, newBodyArgs.size())) {
201 newArgs.emplace_back(newBodyArgs[i + 1]);
202 }
else if (i == argNumber) {
203 newArgs.emplace_back(newBodyArgs.front());
205 newArgs.emplace_back(newBodyArgs[i]);
212 for (
unsigned order : *argReorder)
213 args.push_back(newArgs[order]);
228 newArgs.emplace_back(arg);
238 if (isa<IndexType>(step.
getType())) {
244 Value stepDec = arith::SubIOp::create(rewriter, loc, step, one);
245 Value len = arith::SubIOp::create(rewriter, loc, ub, lb);
246 len = arith::AddIOp::create(rewriter, loc, len, stepDec);
247 len = arith::DivSIOp::create(rewriter, loc, len, step);
248 len = arith::SubIOp::create(rewriter, loc, len, one);
249 Value res = arith::MulIOp::create(rewriter, loc, len, step);
250 res = arith::AddIOp::create(rewriter, loc, lb, res);
255 llvm::append_range(newArgs, newLoop.getResults());
256 newArgs.insert(newArgs.begin() + argNumber, res);
261 for (
unsigned order : *argReorder)
262 results.push_back(newArgs[order]);
static std::string diag(const llvm::Value &value)
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< ForOp > upliftWhileToForLoop(RewriterBase &rewriter, WhileOp loop)
Try to uplift scf.while op to scf.for.
void populateUpliftWhileToForPatterns(RewritePatternSet &patterns)
Populate patterns to uplift scf.while ops to scf.for.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...