MLIR  20.0.0git
ResolveShapedTypeResultDims.cpp
Go to the documentation of this file.
1 //===- ResolveShapedTypeResultDims.cpp - Resolve dim ops of result values -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass resolves `memref.dim` operations of result values in terms of
10 // shapes of their operands using the `InferShapedTypeOpInterface`.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
25 
26 namespace mlir {
27 namespace memref {
28 #define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMS
29 #define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMS
30 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
31 } // namespace memref
32 } // namespace mlir
33 
34 using namespace mlir;
35 
36 namespace {
37 /// Fold dim of an operation that implements the InferShapedTypeOpInterface
38 template <typename OpTy>
39 struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> {
41 
42  LogicalResult matchAndRewrite(OpTy dimOp,
43  PatternRewriter &rewriter) const override {
44  OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
45  if (!dimValue)
46  return failure();
47  auto shapedTypeOp =
48  dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
49  if (!shapedTypeOp)
50  return failure();
51 
52  std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
53  if (!dimIndex)
54  return failure();
55 
56  SmallVector<Value> reifiedResultShapes;
57  if (failed(shapedTypeOp.reifyReturnTypeShapes(
58  rewriter, shapedTypeOp->getOperands(), reifiedResultShapes)))
59  return failure();
60 
61  if (reifiedResultShapes.size() != shapedTypeOp->getNumResults())
62  return failure();
63 
64  Value resultShape = reifiedResultShapes[dimValue.getResultNumber()];
65  auto resultShapeType = dyn_cast<RankedTensorType>(resultShape.getType());
66  if (!resultShapeType || !isa<IndexType>(resultShapeType.getElementType()))
67  return failure();
68 
69  Location loc = dimOp->getLoc();
70  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
71  dimOp, resultShape,
72  rewriter.create<arith::ConstantIndexOp>(loc, *dimIndex).getResult());
73  return success();
74  }
75 };
76 
77 /// Fold dim of an operation that implements the InferShapedTypeOpInterface
78 template <typename OpTy>
79 struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
81 
83 
84  LogicalResult matchAndRewrite(OpTy dimOp,
85  PatternRewriter &rewriter) const override {
86  OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
87  if (!dimValue)
88  return failure();
89  std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
90  if (!dimIndex)
91  return failure();
92 
93  ReifiedRankedShapedTypeDims reifiedResultShapes;
94  if (failed(reifyResultShapes(rewriter, dimValue.getOwner(),
95  reifiedResultShapes)))
96  return failure();
97  unsigned resultNumber = dimValue.getResultNumber();
98  // Do not apply pattern if the IR is invalid (dim out of bounds).
99  if ((size_t)(*dimIndex) >= reifiedResultShapes[resultNumber].size())
100  return rewriter.notifyMatchFailure(dimOp, "dimension is out of bounds");
102  rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]);
103  rewriter.replaceOp(dimOp, replacement);
104  return success();
105  }
106 };
107 
108 /// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
109 ///
110 /// ```
111 /// %0 = ... : tensor<?x?xf32>
112 /// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
113 /// %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
114 /// ...
115 /// }
116 /// ```
117 ///
118 /// is folded to:
119 ///
120 /// ```
121 /// %0 = ... : tensor<?x?xf32>
122 /// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
123 /// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
124 /// ...
125 /// }
126 /// ```
127 struct IterArgsToInitArgs : public OpRewritePattern<tensor::DimOp> {
129 
130  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
131  PatternRewriter &rewriter) const final {
132  auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
133  if (!blockArg)
134  return failure();
135  // TODO: Enable this for loopLikeInterface. Restricting for scf.for
136  // because the init args shape might change in the loop body.
137  // For e.g.:
138  // ```
139  // %0 = tensor.empty(%c1) : tensor<?xf32>
140  // %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %0) ->
141  // tensor<?xf32> {
142  // %1 = tensor.dim %arg0, %c0 : tensor<?xf32>
143  // %2 = arith.addi %c1, %1 : index
144  // %3 = tensor.empty(%2) : tensor<?xf32>
145  // scf.yield %3 : tensor<?xf32>
146  // }
147  //
148  // ```
149  auto forAllOp =
150  dyn_cast<scf::ForallOp>(blockArg.getParentBlock()->getParentOp());
151  if (!forAllOp)
152  return failure();
153  Value initArg = forAllOp.getTiedLoopInit(blockArg)->get();
154  rewriter.modifyOpInPlace(
155  dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
156  return success();
157  }
158 };
159 } // namespace
160 
161 //===----------------------------------------------------------------------===//
162 // Pass registration
163 //===----------------------------------------------------------------------===//
164 
165 namespace {
166 struct ResolveRankedShapeTypeResultDimsPass final
167  : public memref::impl::ResolveRankedShapeTypeResultDimsBase<
168  ResolveRankedShapeTypeResultDimsPass> {
169  void runOnOperation() override;
170 };
171 
172 struct ResolveShapedTypeResultDimsPass final
173  : public memref::impl::ResolveShapedTypeResultDimsBase<
174  ResolveShapedTypeResultDimsPass> {
175  void runOnOperation() override;
176 };
177 
178 } // namespace
179 
182  patterns.add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>,
183  DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>,
184  IterArgsToInitArgs>(patterns.getContext());
185 }
186 
189  // TODO: Move tensor::DimOp pattern to the Tensor dialect.
190  patterns.add<DimOfShapedTypeOpInterface<memref::DimOp>,
191  DimOfShapedTypeOpInterface<tensor::DimOp>>(
192  patterns.getContext());
193 }
194 
195 void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
198  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
199  return signalPassFailure();
200 }
201 
202 void ResolveShapedTypeResultDimsPass::runOnOperation() {
206  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
207  return signalPassFailure();
208 }
209 
211  return std::make_unique<ResolveShapedTypeResultDimsPass>();
212 }
213 
215  return std::make_unique<ResolveRankedShapeTypeResultDimsPass>();
216 }
static MLIRContext * getContext(OpFoldResult val)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:466
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg=true)
Set the flag detailing if this pattern has bounded rewrite recursion or not.
Definition: PatternMatch.h:202
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
std::unique_ptr< Pass > createResolveShapedTypeResultDimsPass()
Creates an operation pass to resolve memref.dim operations with values that are defined by operations...
void populateResolveRankedShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
std::unique_ptr< Pass > createResolveRankedShapeTypeResultDimsPass()
Creates an operation pass to resolve memref.dim operations with values that are defined by operations...
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
Include the generated interface declarations.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358