24 #include "llvm/ADT/TypeSwitch.h" 34 auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator());
35 assert(arg < static_cast<int64_t>(yieldOp.getResults().size()) &&
36 "arg is out of bounds");
39 if (value == forOp.getRegionIterArgs()[arg])
45 using tensor::InsertSliceOp;
48 .
template Case<InsertSliceOp>(
49 [&](InsertSliceOp op) {
return op.getDest(); })
50 .
template Case<ForOp>([&](ForOp forOp) {
55 .Default([&](
auto op) {
return Value(); });
83 template <
typename OpTy>
89 auto blockArg = dimOp.getSource().template dyn_cast<
BlockArgument>();
92 auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp());
98 Value initArg = forOp.getOpOperandForRegionIterArg(blockArg).get();
100 dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
128 template <
typename OpTy>
134 auto forOp = dimOp.getSource().template getDefiningOp<scf::ForOp>();
137 auto opResult = dimOp.getSource().template cast<OpResult>();
138 unsigned resultNumber = opResult.getResultNumber();
142 dimOp.getSourceMutable().assign(forOp.getIterOperands()[resultNumber]);
150 template <
typename OpTy,
bool IsMin>
159 lb = forOp.getLowerBound();
160 ub = forOp.getUpperBound();
161 step = forOp.getStep();
165 for (
unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) {
166 if (parOp.getInductionVars()[idx] == iv) {
167 lb = parOp.getLowerBound()[idx];
168 ub = parOp.getUpperBound()[idx];
169 step = parOp.getStep()[idx];
175 if (scf::ForeachThreadOp foreachThreadOp =
177 for (int64_t idx = 0; idx < foreachThreadOp.getRank(); ++idx) {
178 if (foreachThreadOp.getThreadIndices()[idx] == iv) {
180 ub = foreachThreadOp.getNumThreads()[idx];
191 op.operands(), IsMin, loopMatcher);
195 struct SCFForLoopCanonicalization
196 :
public SCFForLoopCanonicalizationBase<SCFForLoopCanonicalization> {
197 void runOnOperation()
override {
198 auto *parentOp = getOperation();
212 .
add<AffineOpSCFCanonicalizationPattern<AffineMinOp,
true>,
213 AffineOpSCFCanonicalizationPattern<AffineMaxOp,
false>,
214 DimOfIterArgFolder<tensor::DimOp>, DimOfIterArgFolder<memref::DimOp>,
215 DimOfLoopResultFolder<tensor::DimOp>,
216 DimOfLoopResultFolder<memref::DimOp>>(ctx);
220 return std::make_unique<SCFForLoopCanonicalization>();
Include the generated interface declarations.
std::unique_ptr< Pass > createSCFForLoopCanonicalizationPass()
Creates a pass that canonicalizes affine.min and affine.max operations inside of scf.for loops with known lower and upper bounds.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This is a value defined by a result of an operation.
This class represents a single result from folding an operation.
LogicalResult canonicalizeMinMaxOpInLoop(RewriterBase &rewriter, Operation *op, AffineMap map, ValueRange operands, bool isMin, LoopMatcherFn loopMatcher)
Try to canonicalize min/max operations in the context of for loops with a known range...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
static bool isShapePreserving(ForOp forOp, int64_t arg)
A simple, conservative analysis to determine if the loop is shape conserving.
static constexpr const bool value
Operation * getOwner() const
Returns the operation that owns this result.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
This class represents an efficient way to signal success or failure.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
unsigned getResultNumber() const
Returns the number of this result.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
ForeachThreadOp getForeachThreadOpThreadIndexOwner(Value val)
Returns the ForeachThreadOp parent of a thread index variable.
This class represents an argument of a Block.
void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns)
Populate patterns for canonicalizing operations inside SCF loop bodies.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
MLIRContext is the top-level object for a collection of MLIR operations.
LogicalResult applyPatternsAndFoldGreedily(MutableArrayRef< Region > regions, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig())
Rewrite the regions of the specified operation, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy work-list driven manner.
This class helps build Operations.
MLIRContext * getContext() const
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.