MLIR 22.0.0git
ResolveShapedTypeResultDims.cpp
Go to the documentation of this file.
1//===- ResolveShapedTypeResultDims.cpp - Resolve dim ops of result values -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
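// This pass resolves `memref.dim` and `tensor.dim` operations of result values
// in terms of the shapes of their operands, using either the
// `ReifyRankedShapedTypeOpInterface` or the `InferShapedTypeOpInterface`. It
// also folds dims of `scf.forall` shared_outs to dims of their init values.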
11//
12//===----------------------------------------------------------------------===//
13
15
25
26namespace mlir {
27namespace memref {
28#define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMSPASS
29#define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMSPASS
30#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
31} // namespace memref
32} // namespace mlir
33
34using namespace mlir;
35
36namespace {
37/// Fold dim of an operation that implements the InferShapedTypeOpInterface
38template <typename OpTy>
39struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> {
40 using OpRewritePattern<OpTy>::OpRewritePattern;
41
42 LogicalResult matchAndRewrite(OpTy dimOp,
43 PatternRewriter &rewriter) const override {
44 OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
45 if (!dimValue)
46 return failure();
47 auto shapedTypeOp =
48 dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
49 if (!shapedTypeOp)
50 return failure();
51
52 std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
53 if (!dimIndex)
54 return failure();
55
56 SmallVector<Value> reifiedResultShapes;
57 if (failed(shapedTypeOp.reifyReturnTypeShapes(
58 rewriter, shapedTypeOp->getOperands(), reifiedResultShapes)))
59 return failure();
60
61 if (reifiedResultShapes.size() != shapedTypeOp->getNumResults())
62 return failure();
63
64 Value resultShape = reifiedResultShapes[dimValue.getResultNumber()];
65 auto resultShapeType = dyn_cast<RankedTensorType>(resultShape.getType());
66 if (!resultShapeType || !isa<IndexType>(resultShapeType.getElementType()))
67 return failure();
68
69 Location loc = dimOp->getLoc();
70 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
71 dimOp, resultShape,
72 arith::ConstantIndexOp::create(rewriter, loc, *dimIndex).getResult());
73 return success();
74 }
75};
76
77/// Fold dim of an operation that implements the InferShapedTypeOpInterface
78template <typename OpTy>
79struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
80 using OpRewritePattern<OpTy>::OpRewritePattern;
81
82 void initialize() { OpRewritePattern<OpTy>::setHasBoundedRewriteRecursion(); }
83
84 LogicalResult matchAndRewrite(OpTy dimOp,
85 PatternRewriter &rewriter) const override {
86 OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
87 if (!dimValue)
88 return failure();
89 std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
90 if (!dimIndex)
91 return failure();
92
93 FailureOr<OpFoldResult> replacement = reifyDimOfResult(
94 rewriter, dimValue.getOwner(), dimValue.getResultNumber(), *dimIndex);
96 return failure();
97 // Check if the OpFoldResult is empty (unreifiable dimension).
98 if (!replacement.value())
99 return failure();
100 Value replacementVal = getValueOrCreateConstantIndexOp(
101 rewriter, dimOp.getLoc(), replacement.value());
102 rewriter.replaceOp(dimOp, replacementVal);
103 return success();
104 }
105};
106
107/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
108///
109/// ```
110/// %0 = ... : tensor<?x?xf32>
111/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
112/// %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
113/// ...
114/// }
115/// ```
116///
117/// is folded to:
118///
119/// ```
120/// %0 = ... : tensor<?x?xf32>
121/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
122/// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
123/// ...
124/// }
125/// ```
126struct IterArgsToInitArgs : public OpRewritePattern<tensor::DimOp> {
127 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
128
129 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
130 PatternRewriter &rewriter) const final {
131 auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
132 if (!blockArg)
133 return failure();
134 // TODO: Enable this for loopLikeInterface. Restricting for scf.for
135 // because the init args shape might change in the loop body.
136 // For e.g.:
137 // ```
138 // %0 = tensor.empty(%c1) : tensor<?xf32>
139 // %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %0) ->
140 // tensor<?xf32> {
141 // %1 = tensor.dim %arg0, %c0 : tensor<?xf32>
142 // %2 = arith.addi %c1, %1 : index
143 // %3 = tensor.empty(%2) : tensor<?xf32>
144 // scf.yield %3 : tensor<?xf32>
145 // }
146 //
147 // ```
148 auto forAllOp =
149 dyn_cast<scf::ForallOp>(blockArg.getParentBlock()->getParentOp());
150 if (!forAllOp)
151 return failure();
152 Value initArg = forAllOp.getTiedLoopInit(blockArg)->get();
153 rewriter.modifyOpInPlace(
154 dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
155 return success();
156 }
157};
158} // namespace
159
160//===----------------------------------------------------------------------===//
161// Pass registration
162//===----------------------------------------------------------------------===//
163
164namespace {
165struct ResolveRankedShapeTypeResultDimsPass final
167 ResolveRankedShapeTypeResultDimsPass> {
168 using Base::Base;
169 void runOnOperation() override;
170};
171
172struct ResolveShapedTypeResultDimsPass final
174 ResolveShapedTypeResultDimsPass> {
175 using Base::Base;
176 void runOnOperation() override;
177};
178
179} // namespace
180
183 patterns.add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>,
184 DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>,
185 IterArgsToInitArgs>(patterns.getContext());
186}
187
190 // TODO: Move tensor::DimOp pattern to the Tensor dialect.
191 patterns.add<DimOfShapedTypeOpInterface<memref::DimOp>,
192 DimOfShapedTypeOpInterface<tensor::DimOp>>(
193 patterns.getContext());
194}
195
196void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
199 auto result = applyPatternsGreedily(getOperation(), std::move(patterns));
200 if (errorOnPatternIterationLimit && failed(result)) {
201 getOperation()->emitOpError(
202 "dim operation resolution hit pattern iteration limit");
203 return signalPassFailure();
204 }
205}
206
207void ResolveShapedTypeResultDimsPass::runOnOperation() {
208 RewritePatternSet patterns(&getContext());
211 auto result = applyPatternsGreedily(getOperation(), std::move(patterns));
212 if (errorOnPatternIterationLimit && failed(result)) {
213 getOperation()->emitOpError(
214 "dim operation resolution hit pattern iteration limit");
215 return signalPassFailure();
216 }
217}
return success()
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
Operation * getOwner() const
Returns the operation that owns this result.
Definition Value.h:466
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:469
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
void populateResolveRankedShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
FailureOr< OpFoldResult > reifyDimOfResult(OpBuilder &b, Operation *op, int resultIndex, int dim)
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...