24#include "llvm/ADT/TypeSwitch.h"
27#define GEN_PASS_DEF_SCFFORLOOPCANONICALIZATION
28#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
39 assert(arg <
static_cast<int64_t>(forOp.getNumResults()) &&
40 "arg is out of bounds");
41 Value value = forOp.getYieldedValues()[arg];
43 if (value == forOp.getRegionIterArgs()[arg])
45 OpResult opResult = dyn_cast<OpResult>(value);
49 using tensor::InsertSliceOp;
51 .Case([&](InsertSliceOp op) {
return op.getDest(); })
52 .Case([&](ForOp forOp) {
85template <
typename OpTy>
89 LogicalResult matchAndRewrite(OpTy dimOp,
91 auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
94 auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp());
100 Value initArg = forOp.getTiedLoopInit(blockArg)->get();
102 dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
130template <
typename OpTy>
134 LogicalResult matchAndRewrite(OpTy dimOp,
136 auto forOp = dimOp.getSource().template getDefiningOp<scf::ForOp>();
139 auto opResult = cast<OpResult>(dimOp.getSource());
144 dimOp.getSourceMutable().assign(forOp.getInitArgs()[resultNumber]);
152template <
typename OpTy>
154 using OpRewritePattern<OpTy>::OpRewritePattern;
156 LogicalResult matchAndRewrite(OpTy op,
157 PatternRewriter &rewriter)
const override {
162struct SCFForLoopCanonicalization
164 void runOnOperation()
override {
165 auto *parentOp = getOperation();
166 MLIRContext *ctx = parentOp->getContext();
179 .add<AffineOpSCFCanonicalizationPattern<affine::AffineMinOp>,
180 AffineOpSCFCanonicalizationPattern<affine::AffineMaxOp>,
181 DimOfIterArgFolder<tensor::DimOp>, DimOfIterArgFolder<memref::DimOp>,
182 DimOfLoopResultFolder<tensor::DimOp>,
183 DimOfLoopResultFolder<memref::DimOp>>(ctx);
187 return std::make_unique<SCFForLoopCanonicalization>();
static bool isShapePreserving(ForOp forOp, int64_t arg)
A simple, conservative analysis to determine if the loop is shape conserving.
MLIRContext is the top-level object for a collection of MLIR operations.
This is a value defined by a result of an operation.
Operation * getOwner() const
Returns the operation that owns this result.
unsigned getResultNumber() const
Returns the number of this result.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
LogicalResult matchForLikeLoop(Value iv, OpFoldResult &lb, OpFoldResult &ub, OpFoldResult &step)
Match "for loop"-like operations from the SCF dialect.
LogicalResult canonicalizeMinMaxOpInLoop(RewriterBase &rewriter, Operation *op, LoopMatcherFn loopMatcher)
Try to canonicalize the given affine.min/max operation in the context of for loops with a known range...
void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns)
Populate patterns for canonicalizing operations inside SCF loop bodies.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
std::unique_ptr< Pass > createSCFForLoopCanonicalizationPass()
Creates a pass that canonicalizes affine.min and affine.max operations inside of scf....
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...