27 #define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMS
28 #define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMS
29 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
37 template <
typename OpTy>
41 LogicalResult matchAndRewrite(OpTy dimOp,
43 OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
47 dyn_cast<InferShapedTypeOpInterface>(dimValue.
getOwner());
51 std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
56 if (failed(shapedTypeOp.reifyReturnTypeShapes(
57 rewriter, shapedTypeOp->getOperands(), reifiedResultShapes)))
60 if (reifiedResultShapes.size() != shapedTypeOp->getNumResults())
64 auto resultShapeType = dyn_cast<RankedTensorType>(resultShape.
getType());
65 if (!resultShapeType || !isa<IndexType>(resultShapeType.getElementType()))
71 rewriter.
create<arith::ConstantIndexOp>(loc, *dimIndex).getResult());
77 template <
typename OpTy>
78 struct DimOfReifyRankedShapedTypeOpInterface :
public OpRewritePattern<OpTy> {
83 LogicalResult matchAndRewrite(OpTy dimOp,
85 OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
88 std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
94 reifiedResultShapes)))
98 if ((
size_t)(*dimIndex) >= reifiedResultShapes[resultNumber].size())
101 rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]);
113 struct ResolveRankedShapeTypeResultDimsPass final
114 :
public memref::impl::ResolveRankedShapeTypeResultDimsBase<
115 ResolveRankedShapeTypeResultDimsPass> {
116 void runOnOperation()
override;
119 struct ResolveShapedTypeResultDimsPass final
120 :
public memref::impl::ResolveShapedTypeResultDimsBase<
121 ResolveShapedTypeResultDimsPass> {
122 void runOnOperation()
override;
129 patterns.
add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>,
130 DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>>(
137 patterns.
add<DimOfShapedTypeOpInterface<memref::DimOp>,
138 DimOfShapedTypeOpInterface<tensor::DimOp>>(
142 void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
146 return signalPassFailure();
149 void ResolveShapedTypeResultDimsPass::runOnOperation() {
154 return signalPassFailure();
158 return std::make_unique<ResolveShapedTypeResultDimsPass>();
162 return std::make_unique<ResolveRankedShapeTypeResultDimsPass>();
static MLIRContext * getContext(OpFoldResult val)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This is a value defined by a result of an operation.
Operation * getOwner() const
Returns the operation that owns this result.
unsigned getResultNumber() const
Returns the number of this result.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg=true)
Set the flag detailing if this pattern has bounded rewrite recursion or not.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
std::unique_ptr< Pass > createResolveShapedTypeResultDimsPass()
Creates an operation pass to resolve memref.dim operations with values that are defined by operations...
void populateResolveRankedShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
std::unique_ptr< Pass > createResolveRankedShapeTypeResultDimsPass()
Creates an operation pass to resolve memref.dim operations with values that are defined by operations...
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
Include the generated interface declarations.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...