28#define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMSPASS
29#define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMSPASS
30#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
38template <
typename OpTy>
40 using OpRewritePattern<OpTy>::OpRewritePattern;
42 LogicalResult matchAndRewrite(OpTy dimOp,
43 PatternRewriter &rewriter)
const override {
44 OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
48 dyn_cast<InferShapedTypeOpInterface>(dimValue.
getOwner());
52 std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
56 SmallVector<Value> reifiedResultShapes;
57 if (
failed(shapedTypeOp.reifyReturnTypeShapes(
58 rewriter, shapedTypeOp->getOperands(), reifiedResultShapes)))
61 if (reifiedResultShapes.size() != shapedTypeOp->getNumResults())
65 auto resultShapeType = dyn_cast<RankedTensorType>(resultShape.
getType());
66 if (!resultShapeType || !isa<IndexType>(resultShapeType.getElementType()))
69 Location loc = dimOp->getLoc();
78template <
typename OpTy>
79struct DimOfReifyRankedShapedTypeOpInterface :
public OpRewritePattern<OpTy> {
80 using OpRewritePattern<OpTy>::OpRewritePattern;
82 void initialize() { OpRewritePattern<OpTy>::setHasBoundedRewriteRecursion(); }
84 LogicalResult matchAndRewrite(OpTy dimOp,
85 PatternRewriter &rewriter)
const override {
86 OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
89 std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
102 rewriter.
replaceOp(dimOp, replacementVal);
127 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
129 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
130 PatternRewriter &rewriter)
const final {
131 auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
149 dyn_cast<scf::ForallOp>(blockArg.getParentBlock()->getParentOp());
152 Value initArg = forAllOp.getTiedLoopInit(blockArg)->get();
154 dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
165struct ResolveRankedShapeTypeResultDimsPass final
167 ResolveRankedShapeTypeResultDimsPass> {
169 void runOnOperation()
override;
172struct ResolveShapedTypeResultDimsPass final
174 ResolveShapedTypeResultDimsPass> {
176 void runOnOperation()
override;
183 patterns.add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>,
184 DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>,
185 IterArgsToInitArgs>(
patterns.getContext());
191 patterns.add<DimOfShapedTypeOpInterface<memref::DimOp>,
192 DimOfShapedTypeOpInterface<tensor::DimOp>>(
196void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
200 if (errorOnPatternIterationLimit && failed(
result)) {
201 getOperation()->emitOpError(
202 "dim operation resolution hit pattern iteration limit");
203 return signalPassFailure();
207void ResolveShapedTypeResultDimsPass::runOnOperation() {
213 getOperation()->emitOpError(
214 "dim operation resolution hit pattern iteration limit");
215 return signalPassFailure();
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
…if copies could not be generated due to yet-unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be placed. […] The output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`, …
Operation * getOwner() const
Returns the operation that owns this result.
unsigned getResultNumber() const
Returns the number of this result.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
void populateResolveRankedShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
FailureOr< OpFoldResult > reifyDimOfResult(OpBuilder &b, Operation *op, int resultIndex, int dim)
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...