MLIR 23.0.0git
ResolveShapedTypeResultDims.cpp
Go to the documentation of this file.
1//===- ResolveShapedTypeResultDims.cpp - Resolve dim ops of result values -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This pass resolves `memref.dim` and `tensor.dim` operations of result values
// in terms of the shapes of their operands, using either the
// `InferShapedTypeOpInterface` or the `ReifyRankedShapedTypeOpInterface`.
11//
12//===----------------------------------------------------------------------===//
13
15
25
26namespace mlir {
27namespace memref {
28#define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMSPASS
29#define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMSPASS
30#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
31} // namespace memref
32} // namespace mlir
33
34using namespace mlir;
35
36namespace {
37/// Fold dim of an operation that implements the InferShapedTypeOpInterface
38template <typename OpTy>
39struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> {
40 using OpRewritePattern<OpTy>::OpRewritePattern;
41
42 LogicalResult matchAndRewrite(OpTy dimOp,
43 PatternRewriter &rewriter) const override {
44 OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
45 if (!dimValue)
46 return failure();
47 auto shapedTypeOp =
48 dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
49 if (!shapedTypeOp)
50 return failure();
51
52 std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
53 if (!dimIndex)
54 return failure();
55
56 SmallVector<Value> reifiedResultShapes;
57 if (failed(shapedTypeOp.reifyReturnTypeShapes(
58 rewriter, shapedTypeOp->getOperands(), reifiedResultShapes)))
59 return failure();
60
61 if (reifiedResultShapes.size() != shapedTypeOp->getNumResults())
62 return failure();
63
64 Value resultShape = reifiedResultShapes[dimValue.getResultNumber()];
65 auto resultShapeType = dyn_cast<RankedTensorType>(resultShape.getType());
66 if (!resultShapeType || !isa<IndexType>(resultShapeType.getElementType()))
67 return failure();
68
69 Location loc = dimOp->getLoc();
70 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
71 dimOp, resultShape,
72 arith::ConstantIndexOp::create(rewriter, loc, *dimIndex).getResult());
73 return success();
74 }
75};
76
77/// Fold dim of an operation that implements the InferShapedTypeOpInterface
78template <typename OpTy>
79struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
80 using OpRewritePattern<OpTy>::OpRewritePattern;
81
82 void initialize() { OpRewritePattern<OpTy>::setHasBoundedRewriteRecursion(); }
83
84 LogicalResult matchAndRewrite(OpTy dimOp,
85 PatternRewriter &rewriter) const override {
86 OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
87 if (!dimValue)
88 return failure();
89 std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
90 if (!dimIndex)
91 return failure();
92
93 // Save the op immediately before dimOp so we can identify and erase any
94 // ops inserted during the reification attempt if it fails. The
95 // pattern-rewrite invariant requires the IR to be unchanged on failure.
96 Operation *opBeforeReify = dimOp->getPrevNode();
97
98 // Erase any ops inserted between opBeforeReify and dimOp in reverse order
99 // to respect use-def chains within that range. Collect pointers first to
100 // avoid iterator invalidation: erasing a node in an ilist invalidates
101 // iterators to that node, and std::reverse_iterator stores the iterator to
102 // the *next* forward element, so make_early_inc_range(reverse(...)) would
103 // still dereference a stale iterator after erasure.
104 auto eraseInsertedOps = [&]() {
105 Block::iterator begin = opBeforeReify
106 ? std::next(opBeforeReify->getIterator())
107 : dimOp->getBlock()->begin();
108 SmallVector<Operation *> toErase;
109 for (Block::iterator it = begin; it != dimOp->getIterator(); ++it)
110 toErase.push_back(&*it);
111 for (Operation *op : llvm::reverse(toErase))
112 rewriter.eraseOp(op);
113 };
114
115 FailureOr<OpFoldResult> replacement = reifyDimOfResult(
116 rewriter, dimValue.getOwner(), dimValue.getResultNumber(), *dimIndex);
117 // An empty (or failed) OpFoldResult signals that this specific dimension
118 // cannot be reified. Some implementations materialize all dimensions at
119 // once (e.g. via reifyResultShapes) and may create ops for other dimensions
120 // before discovering that this dimension is not reifiable. Erase those
121 // stray ops before returning failure.
122 if (failed(replacement) || !replacement.value()) {
123 eraseInsertedOps();
124 return failure();
125 }
126 Value replacementVal = getValueOrCreateConstantIndexOp(
127 rewriter, dimOp.getLoc(), replacement.value());
128 rewriter.replaceOp(dimOp, replacementVal);
129 return success();
130 }
131};
132
133/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
134///
135/// ```
136/// %0 = ... : tensor<?x?xf32>
137/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
138/// %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
139/// ...
140/// }
141/// ```
142///
143/// is folded to:
144///
145/// ```
146/// %0 = ... : tensor<?x?xf32>
147/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
148/// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
149/// ...
150/// }
151/// ```
152struct IterArgsToInitArgs : public OpRewritePattern<tensor::DimOp> {
153 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
154
155 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
156 PatternRewriter &rewriter) const final {
157 auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
158 if (!blockArg)
159 return failure();
160 // TODO: Enable this for loopLikeInterface. Restricting for scf.for
161 // because the init args shape might change in the loop body.
162 // For e.g.:
163 // ```
164 // %0 = tensor.empty(%c1) : tensor<?xf32>
165 // %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %0) ->
166 // tensor<?xf32> {
167 // %1 = tensor.dim %arg0, %c0 : tensor<?xf32>
168 // %2 = arith.addi %c1, %1 : index
169 // %3 = tensor.empty(%2) : tensor<?xf32>
170 // scf.yield %3 : tensor<?xf32>
171 // }
172 //
173 // ```
174 auto forAllOp =
175 dyn_cast<scf::ForallOp>(blockArg.getParentBlock()->getParentOp());
176 if (!forAllOp)
177 return failure();
178 Value initArg = forAllOp.getTiedLoopInit(blockArg)->get();
179 rewriter.modifyOpInPlace(
180 dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
181 return success();
182 }
183};
184} // namespace
185
186//===----------------------------------------------------------------------===//
187// Pass registration
188//===----------------------------------------------------------------------===//
189
190namespace {
191struct ResolveRankedShapeTypeResultDimsPass final
193 ResolveRankedShapeTypeResultDimsPass> {
194 using Base::Base;
195 void runOnOperation() override;
196};
197
198struct ResolveShapedTypeResultDimsPass final
200 ResolveShapedTypeResultDimsPass> {
201 using Base::Base;
202 void runOnOperation() override;
203};
204
205} // namespace
206
208 RewritePatternSet &patterns) {
209 patterns.add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>,
210 DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>,
211 IterArgsToInitArgs>(patterns.getContext());
212}
213
215 RewritePatternSet &patterns) {
216 // TODO: Move tensor::DimOp pattern to the Tensor dialect.
217 patterns.add<DimOfShapedTypeOpInterface<memref::DimOp>,
218 DimOfShapedTypeOpInterface<tensor::DimOp>>(
219 patterns.getContext());
220}
221
222void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
223 RewritePatternSet patterns(&getContext());
225 auto result = applyPatternsGreedily(getOperation(), std::move(patterns));
226 if (errorOnPatternIterationLimit && failed(result)) {
227 getOperation()->emitOpError(
228 "dim operation resolution hit pattern iteration limit");
229 return signalPassFailure();
230 }
231}
232
233void ResolveShapedTypeResultDimsPass::runOnOperation() {
234 RewritePatternSet patterns(&getContext());
237 auto result = applyPatternsGreedily(getOperation(), std::move(patterns));
238 if (errorOnPatternIterationLimit && failed(result)) {
239 getOperation()->emitOpError(
240 "dim operation resolution hit pattern iteration limit");
241 return signalPassFailure();
242 }
243}
return success()
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
b getContext())
Returns failure if copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies, respectively, should be inserted. The output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`, the output argument `nEnd` is set to the new end.
OpListType::iterator iterator
Definition Block.h:150
Operation * getOwner() const
Returns the operation that owns this result.
Definition Value.h:463
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:466
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
void populateResolveRankedShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
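Appends patterns that resolve `memref.dim` and `tensor.dim` operations whose source values are defined by operations that implement the `ReifyRankedShapedTypeOpInterface`, in terms of the shapes of their operands (plus folding of `tensor.dim` on `scf.forall` shared_outs to their init values).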
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
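Appends patterns that resolve `memref.dim` and `tensor.dim` operations whose source values are defined by operations that implement the `InferShapedTypeOpInterface`, in terms of the shapes of their operands.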
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
FailureOr< OpFoldResult > reifyDimOfResult(OpBuilder &b, Operation *op, int resultIndex, int dim)
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...