24 #include "llvm/ADT/TypeSwitch.h"
27 #define GEN_PASS_DEF_SCFFORLOOPCANONICALIZATION
28 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
39 assert(arg <
static_cast<int64_t
>(forOp.getNumResults()) &&
40 "arg is out of bounds");
41 Value value = forOp.getYieldedValues()[arg];
43 if (value == forOp.getRegionIterArgs()[arg])
45 OpResult opResult = dyn_cast<OpResult>(value);
49 using tensor::InsertSliceOp;
51 .template Case<InsertSliceOp>(
52 [&](InsertSliceOp op) {
return op.getDest(); })
53 .
template Case<ForOp>([&](ForOp forOp) {
58 .Default([&](
auto op) {
return Value(); });
86 template <
typename OpTy>
90 LogicalResult matchAndRewrite(OpTy dimOp,
92 auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
95 auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp());
101 Value initArg = forOp.getTiedLoopInit(blockArg)->get();
103 dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
131 template <
typename OpTy>
135 LogicalResult matchAndRewrite(OpTy dimOp,
137 auto forOp = dimOp.getSource().template getDefiningOp<scf::ForOp>();
140 auto opResult = cast<OpResult>(dimOp.getSource());
141 unsigned resultNumber = opResult.getResultNumber();
145 dimOp.getSourceMutable().assign(forOp.getInitArgs()[resultNumber]);
153 template <
typename OpTy>
157 LogicalResult matchAndRewrite(OpTy op,
163 struct SCFForLoopCanonicalization
164 :
public impl::SCFForLoopCanonicalizationBase<SCFForLoopCanonicalization> {
165 void runOnOperation()
override {
166 auto *parentOp = getOperation();
180 .
add<AffineOpSCFCanonicalizationPattern<affine::AffineMinOp>,
181 AffineOpSCFCanonicalizationPattern<affine::AffineMaxOp>,
182 DimOfIterArgFolder<tensor::DimOp>, DimOfIterArgFolder<memref::DimOp>,
183 DimOfLoopResultFolder<tensor::DimOp>,
184 DimOfLoopResultFolder<memref::DimOp>>(ctx);
188 return std::make_unique<SCFForLoopCanonicalization>();
static bool isShapePreserving(ForOp forOp, int64_t arg)
A simple, conservative analysis to determine if the loop is shape conserving.
MLIRContext is the top-level object for a collection of MLIR operations.
This is a value defined by a result of an operation.
Operation * getOwner() const
Returns the operation that owns this result.
unsigned getResultNumber() const
Returns the number of this result.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
LogicalResult matchForLikeLoop(Value iv, OpFoldResult &lb, OpFoldResult &ub, OpFoldResult &step)
Match "for loop"-like operations from the SCF dialect.
LogicalResult canonicalizeMinMaxOpInLoop(RewriterBase &rewriter, Operation *op, LoopMatcherFn loopMatcher)
Try to canonicalize the given affine.min/max operation in the context of for loops with a known range...
void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns)
Populate patterns for canonicalizing operations inside SCF loop bodies.
Include the generated interface declarations.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
std::unique_ptr< Pass > createSCFForLoopCanonicalizationPass()
Creates a pass that canonicalizes affine.min and affine.max operations inside of scf....
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...