27 #include "llvm/ADT/DenseMap.h"
30 #define GEN_PASS_DEF_SCFFORLOOPPEELING
31 #define GEN_PASS_DEF_SCFFORLOOPSPECIALIZATION
32 #define GEN_PASS_DEF_SCFPARALLELLOOPSPECIALIZATION
33 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
38 using scf::ParallelOp;
46 constantIndices.reserve(op.getUpperBound().size());
47 for (
auto bound : op.getUpperBound()) {
48 auto minOp = bound.getDefiningOp<AffineMinOp>();
52 for (
AffineExpr expr : minOp.getMap().getResults()) {
58 constantIndices.push_back(minConstant);
64 for (
auto bound : llvm::zip(op.getUpperBound(), constantIndices)) {
66 b.
create<arith::ConstantIndexOp>(op.getLoc(), std::get<1>(bound));
67 Value cmp = b.
create<arith::CmpIOp>(op.getLoc(), arith::CmpIPredicate::eq,
68 std::get<0>(bound), constant);
69 cond = cond ? b.
create<arith::AndIOp>(op.getLoc(), cond, cmp) : cmp;
70 map.
map(std::get<0>(bound), constant);
72 auto ifOp = b.
create<scf::IfOp>(op.getLoc(), cond,
true);
73 ifOp.getThenBodyBuilder().
clone(*op.getOperation(), map);
74 ifOp.getElseBodyBuilder().
clone(*op.getOperation());
83 auto bound = op.getUpperBound();
84 auto minOp = bound.getDefiningOp<AffineMinOp>();
88 for (
AffineExpr expr : minOp.getMap().getResults()) {
97 Value constant = b.
create<arith::ConstantIndexOp>(op.getLoc(), minConstant);
98 Value cond = b.
create<arith::CmpIOp>(op.getLoc(), arith::CmpIPredicate::eq,
100 map.
map(bound, constant);
101 auto ifOp = b.
create<scf::IfOp>(op.getLoc(), cond,
true);
102 ifOp.getThenBodyBuilder().
clone(*op.getOperation(), map);
103 ifOp.getElseBodyBuilder().
clone(*op.getOperation());
119 ForOp &partialIteration,
Value &splitBound) {
120 RewriterBase::InsertionGuard guard(b);
126 if (lbInt && ubInt && stepInt && (*ubInt - *lbInt) % *stepInt == 0)
129 if (stepInt ==
static_cast<int64_t
>(1))
132 auto loc = forOp.getLoc();
136 auto modMap =
AffineMap::get(0, 3, {sym1 - ((sym1 - sym0) % sym2)});
140 forOp.getUpperBound(),
145 partialIteration = cast<ForOp>(b.
clone(*forOp.getOperation()));
146 partialIteration.getLowerBoundMutable().assign(splitBound);
148 partialIteration.getInitArgsMutable().assign(forOp->getResults());
152 forOp, [&]() { forOp.getUpperBoundMutable().assign(splitBound); });
158 ForOp partialIteration,
160 Value mainIv = forOp.getInductionVar();
161 Value partialIv = partialIteration.getInductionVar();
162 assert(forOp.getStep() == partialIteration.getStep() &&
163 "expected same step in main and partial loop");
164 Value step = forOp.getStep();
167 if (!isa<AffineMinOp, AffineMaxOp>(affineOp))
174 partialIteration.walk([&](
Operation *affineOp) {
175 if (!isa<AffineMinOp, AffineMaxOp>(affineOp))
185 ForOp &partialIteration) {
186 Value previousUb = forOp.getUpperBound();
202 ForLoopPeelingPattern(
MLIRContext *ctx,
bool skipPartial)
220 scf::ForOp partialIteration;
244 struct ParallelLoopSpecialization
245 :
public impl::SCFParallelLoopSpecializationBase<
246 ParallelLoopSpecialization> {
247 void runOnOperation()
override {
248 getOperation()->walk(
253 struct ForLoopSpecialization
254 :
public impl::SCFForLoopSpecializationBase<ForLoopSpecialization> {
255 void runOnOperation()
override {
260 struct ForLoopPeeling :
public impl::SCFForLoopPeelingBase<ForLoopPeeling> {
261 void runOnOperation()
override {
262 auto *parentOp = getOperation();
265 patterns.add<ForLoopPeelingPattern>(ctx, skipPartial);
278 return std::make_unique<ParallelLoopSpecialization>();
282 return std::make_unique<ForLoopSpecialization>();
286 return std::make_unique<ForLoopPeeling>();
static void specializeForLoopForUnrolling(ForOp op)
Rewrite a for loop with bounds defined by an affine.min with a constant into 2 loops after checking i...
static void specializeParallelLoopForUnrolling(ParallelOp op)
Rewrite a parallel loop with bounds defined by an affine.min with a constant into 2 loops after check...
static constexpr char kPeeledLoopLabel[]
static void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, ForOp forOp, ForOp partialIteration, Value previousUb)
static LogicalResult peelForLoop(RewriterBase &b, ForOp forOp, ForOp &partialIteration, Value &splitBound)
Rewrite a for loop with bounds/step that potentially do not divide evenly into a for loop where the s...
static constexpr char kPartialIterationLabel[]
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
An integer constant appearing in an affine expression.
Base type for affine expression.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
void erase()
Remove this operation from its parent block and delete it.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static WalkResult advance()
LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Rewrite a for loop with bounds/step that potentially do not divide evenly into a for loop where the s...
LogicalResult rewritePeeledMinMaxOp(RewriterBase &rewriter, Operation *op, Value iv, Value ub, Value step, bool insideLoop)
Try to simplify the given affine.min/max operation op after loop peeling.
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig())
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
std::unique_ptr< Pass > createParallelLoopSpecializationPass()
Creates a pass that specializes parallel loop for unrolling and vectorization.
std::unique_ptr< Pass > createForLoopSpecializationPass()
Creates a pass that specializes for loop for unrolling and vectorization.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
std::unique_ptr< Pass > createForLoopPeelingPass()
Creates a pass that peels for loops at their upper bounds for better vectorization.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...