28 #define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMS
29 #define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMS
30 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
38 template <
typename OpTy>
42 LogicalResult matchAndRewrite(OpTy dimOp,
44 OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
48 dyn_cast<InferShapedTypeOpInterface>(dimValue.
getOwner());
52 std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
57 if (failed(shapedTypeOp.reifyReturnTypeShapes(
58 rewriter, shapedTypeOp->getOperands(), reifiedResultShapes)))
61 if (reifiedResultShapes.size() != shapedTypeOp->getNumResults())
65 auto resultShapeType = dyn_cast<RankedTensorType>(resultShape.
getType());
66 if (!resultShapeType || !isa<IndexType>(resultShapeType.getElementType()))
72 rewriter.
create<arith::ConstantIndexOp>(loc, *dimIndex).getResult());
78 template <
typename OpTy>
79 struct DimOfReifyRankedShapedTypeOpInterface :
public OpRewritePattern<OpTy> {
84 LogicalResult matchAndRewrite(OpTy dimOp,
86 OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
89 std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
95 reifiedResultShapes)))
99 if ((
size_t)(*dimIndex) >= reifiedResultShapes[resultNumber].size())
102 rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]);
130 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
132 auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
150 dyn_cast<scf::ForallOp>(blockArg.getParentBlock()->getParentOp());
153 Value initArg = forAllOp.getTiedLoopInit(blockArg)->get();
155 dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
166 struct ResolveRankedShapeTypeResultDimsPass final
167 :
public memref::impl::ResolveRankedShapeTypeResultDimsBase<
168 ResolveRankedShapeTypeResultDimsPass> {
169 void runOnOperation()
override;
172 struct ResolveShapedTypeResultDimsPass final
173 :
public memref::impl::ResolveShapedTypeResultDimsBase<
174 ResolveShapedTypeResultDimsPass> {
175 void runOnOperation()
override;
182 patterns.
add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>,
183 DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>,
190 patterns.
add<DimOfShapedTypeOpInterface<memref::DimOp>,
191 DimOfShapedTypeOpInterface<tensor::DimOp>>(
195 void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
199 return signalPassFailure();
202 void ResolveShapedTypeResultDimsPass::runOnOperation() {
207 return signalPassFailure();
211 return std::make_unique<ResolveShapedTypeResultDimsPass>();
215 return std::make_unique<ResolveRankedShapeTypeResultDimsPass>();
static MLIRContext * getContext(OpFoldResult val)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This is a value defined by a result of an operation.
Operation * getOwner() const
Returns the operation that owns this result.
unsigned getResultNumber() const
Returns the number of this result.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg=true)
Set the flag detailing if this pattern has bounded rewrite recursion or not.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
std::unique_ptr< Pass > createResolveShapedTypeResultDimsPass()
Creates an operation pass to resolve memref.dim operations with values that are defined by operations...
void populateResolveRankedShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
std::unique_ptr< Pass > createResolveRankedShapeTypeResultDimsPass()
Creates an operation pass to resolve memref.dim operations with values that are defined by operations...
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
Include the generated interface declarations.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...