MLIR  16.0.0git
ResolveShapedTypeResultDims.cpp
Go to the documentation of this file.
1 //===- ResolveShapedTypeResultDims.cpp - Resolve dim ops of result values -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass resolves `memref.dim` and `tensor.dim` operations on op results
10 // by replacing them with shape values reified through the producing op's `InferShapedTypeOpInterface` or `ReifyRankedShapedTypeOpInterface`.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
22 
23 namespace mlir {
24 namespace memref {
25 #define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMS
26 #define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMS
27 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
28 } // namespace memref
29 } // namespace mlir
30 
31 using namespace mlir;
32 
33 namespace {
34 /// Fold dim of an operation that implements the InferShapedTypeOpInterface
35 template <typename OpTy>
36 struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> {
38 
39  LogicalResult matchAndRewrite(OpTy dimOp,
40  PatternRewriter &rewriter) const override {
41  OpResult dimValue = dimOp.getSource().template dyn_cast<OpResult>();
42  if (!dimValue)
43  return failure();
44  auto shapedTypeOp =
45  dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
46  if (!shapedTypeOp)
47  return failure();
48 
49  Optional<int64_t> dimIndex = dimOp.getConstantIndex();
50  if (!dimIndex)
51  return failure();
52 
53  SmallVector<Value> reifiedResultShapes;
54  if (failed(shapedTypeOp.reifyReturnTypeShapes(
55  rewriter, shapedTypeOp->getOperands(), reifiedResultShapes)))
56  return failure();
57 
58  if (reifiedResultShapes.size() != shapedTypeOp->getNumResults())
59  return failure();
60 
61  Value resultShape = reifiedResultShapes[dimValue.getResultNumber()];
62  auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
63  if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
64  return failure();
65 
66  Location loc = dimOp->getLoc();
67  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
68  dimOp, resultShape,
69  rewriter.createOrFold<arith::ConstantIndexOp>(loc, *dimIndex));
70  return success();
71  }
72 };
73 
74 /// Fold dim of an operation that implements the InferShapedTypeOpInterface
75 template <typename OpTy>
76 struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
78 
79  LogicalResult matchAndRewrite(OpTy dimOp,
80  PatternRewriter &rewriter) const override {
81  OpResult dimValue = dimOp.getSource().template dyn_cast<OpResult>();
82  if (!dimValue)
83  return failure();
84  auto rankedShapeTypeOp =
85  dyn_cast<ReifyRankedShapedTypeOpInterface>(dimValue.getOwner());
86  if (!rankedShapeTypeOp)
87  return failure();
88 
89  Optional<int64_t> dimIndex = dimOp.getConstantIndex();
90  if (!dimIndex)
91  return failure();
92 
93  SmallVector<SmallVector<Value>> reifiedResultShapes;
94  if (failed(
95  rankedShapeTypeOp.reifyResultShapes(rewriter, reifiedResultShapes)))
96  return failure();
97 
98  if (reifiedResultShapes.size() != rankedShapeTypeOp->getNumResults())
99  return failure();
100 
101  unsigned resultNumber = dimValue.getResultNumber();
102  auto sourceType = dimValue.getType().dyn_cast<RankedTensorType>();
103  if (reifiedResultShapes[resultNumber].size() !=
104  static_cast<size_t>(sourceType.getRank()))
105  return failure();
106 
107  rewriter.replaceOp(dimOp, reifiedResultShapes[resultNumber][*dimIndex]);
108  return success();
109  }
110 };
111 } // namespace
112 
113 //===----------------------------------------------------------------------===//
114 // Pass registration
115 //===----------------------------------------------------------------------===//
116 
117 namespace {
118 struct ResolveRankedShapeTypeResultDimsPass final
119  : public memref::impl::ResolveRankedShapeTypeResultDimsBase<
120  ResolveRankedShapeTypeResultDimsPass> {
121  void runOnOperation() override;
122 };
123 
124 struct ResolveShapedTypeResultDimsPass final
125  : public memref::impl::ResolveShapedTypeResultDimsBase<
126  ResolveShapedTypeResultDimsPass> {
127  void runOnOperation() override;
128 };
129 
130 } // namespace
131 
133  RewritePatternSet &patterns) {
134  patterns.add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>,
135  DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>>(
136  patterns.getContext());
137 }
138 
140  RewritePatternSet &patterns) {
141  // TODO: Move tensor::DimOp pattern to the Tensor dialect.
142  patterns.add<DimOfShapedTypeOpInterface<memref::DimOp>,
143  DimOfShapedTypeOpInterface<tensor::DimOp>>(
144  patterns.getContext());
145 }
146 
147 void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
148  RewritePatternSet patterns(&getContext());
150  if (failed(applyPatternsAndFoldGreedily(getOperation()->getRegions(),
151  std::move(patterns))))
152  return signalPassFailure();
153 }
154 
155 void ResolveShapedTypeResultDimsPass::runOnOperation() {
156  RewritePatternSet patterns(&getContext());
159  if (failed(applyPatternsAndFoldGreedily(getOperation()->getRegions(),
160  std::move(patterns))))
161  return signalPassFailure();
162 }
163 
165  return std::make_unique<ResolveShapedTypeResultDimsPass>();
166 }
167 
169  return std::make_unique<ResolveRankedShapeTypeResultDimsPass>();
170 }
Include the generated interface declarations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Definition: Builders.h:470
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:600
This is a value defined by a result of an operation.
Definition: Value.h:446
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:455
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
U dyn_cast() const
Definition: Types.h:268
void populateResolveRankedShapeTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:458
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
std::unique_ptr< Pass > createResolveRankedShapeTypeResultDimsPass()
Creates an operation pass to resolve memref.dim operations with values that are defined by operations...
Type getType() const
Return the type of this value.
Definition: Value.h:118
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:80
LogicalResult applyPatternsAndFoldGreedily(MutableArrayRef< Region > regions, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig())
Rewrite the regions of the specified operation, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy work-list driven manner.
MLIRContext * getContext() const
std::unique_ptr< Pass > createResolveShapedTypeResultDimsPass()
Creates an operation pass to resolve memref.dim operations with values that are defined by operations...