MLIR  20.0.0git
ResolveShapedTypeResultDims.cpp
Go to the documentation of this file.
1 //===- ResolveShapedTypeResultDims.cpp - Resolve dim ops of result values -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass resolves `memref.dim` and `tensor.dim` operations of result values
10 // in terms of operand shapes, using `InferShapedTypeOpInterface` or `ReifyRankedShapedTypeOpInterface`.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
24 
25 namespace mlir {
26 namespace memref {
27 #define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMS
28 #define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMS
29 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
30 } // namespace memref
31 } // namespace mlir
32 
33 using namespace mlir;
34 
35 namespace {
36 /// Fold dim of an operation that implements the InferShapedTypeOpInterface
37 template <typename OpTy>
38 struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> {
40 
41  LogicalResult matchAndRewrite(OpTy dimOp,
42  PatternRewriter &rewriter) const override {
43  OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
44  if (!dimValue)
45  return failure();
46  auto shapedTypeOp =
47  dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
48  if (!shapedTypeOp)
49  return failure();
50 
51  std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
52  if (!dimIndex)
53  return failure();
54 
55  SmallVector<Value> reifiedResultShapes;
56  if (failed(shapedTypeOp.reifyReturnTypeShapes(
57  rewriter, shapedTypeOp->getOperands(), reifiedResultShapes)))
58  return failure();
59 
60  if (reifiedResultShapes.size() != shapedTypeOp->getNumResults())
61  return failure();
62 
63  Value resultShape = reifiedResultShapes[dimValue.getResultNumber()];
64  auto resultShapeType = dyn_cast<RankedTensorType>(resultShape.getType());
65  if (!resultShapeType || !isa<IndexType>(resultShapeType.getElementType()))
66  return failure();
67 
68  Location loc = dimOp->getLoc();
69  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
70  dimOp, resultShape,
71  rewriter.create<arith::ConstantIndexOp>(loc, *dimIndex).getResult());
72  return success();
73  }
74 };
75 
76 /// Fold dim of an operation that implements the InferShapedTypeOpInterface
77 template <typename OpTy>
78 struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
80 
82 
83  LogicalResult matchAndRewrite(OpTy dimOp,
84  PatternRewriter &rewriter) const override {
85  OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
86  if (!dimValue)
87  return failure();
88  std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
89  if (!dimIndex)
90  return failure();
91 
92  ReifiedRankedShapedTypeDims reifiedResultShapes;
93  if (failed(reifyResultShapes(rewriter, dimValue.getOwner(),
94  reifiedResultShapes)))
95  return failure();
96  unsigned resultNumber = dimValue.getResultNumber();
97  // Do not apply pattern if the IR is invalid (dim out of bounds).
98  if ((size_t)(*dimIndex) >= reifiedResultShapes[resultNumber].size())
99  return rewriter.notifyMatchFailure(dimOp, "dimension is out of bounds");
101  rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]);
102  rewriter.replaceOp(dimOp, replacement);
103  return success();
104  }
105 };
106 } // namespace
107 
108 //===----------------------------------------------------------------------===//
109 // Pass registration
110 //===----------------------------------------------------------------------===//
111 
112 namespace {
113 struct ResolveRankedShapeTypeResultDimsPass final
114  : public memref::impl::ResolveRankedShapeTypeResultDimsBase<
115  ResolveRankedShapeTypeResultDimsPass> {
116  void runOnOperation() override;
117 };
118 
119 struct ResolveShapedTypeResultDimsPass final
120  : public memref::impl::ResolveShapedTypeResultDimsBase<
121  ResolveShapedTypeResultDimsPass> {
122  void runOnOperation() override;
123 };
124 
125 } // namespace
126 
128  RewritePatternSet &patterns) {
// NOTE(review): the signature line (listing line 127) is missing from this
// dump; per the symbol index it is presumably
// memref::populateResolveRankedShapedTypeResultDimsPatterns — confirm.
// Registers the ranked-reification dim pattern for both `memref.dim` and
// `tensor.dim`, constructed with the pattern set's MLIRContext.
129  patterns.add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>,
130  DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>>(
131  patterns.getContext());
132 }
133 
135  RewritePatternSet &patterns) {
// NOTE(review): the signature line (listing line 134) is missing from this
// dump; per the symbol index it is presumably
// memref::populateResolveShapedTypeResultDimsPatterns — confirm.
// Registers the InferShapedTypeOpInterface-based dim pattern for both
// `memref.dim` and `tensor.dim`, constructed with the pattern set's context.
136  // TODO: Move tensor::DimOp pattern to the Tensor dialect.
137  patterns.add<DimOfShapedTypeOpInterface<memref::DimOp>,
138  DimOfShapedTypeOpInterface<tensor::DimOp>>(
139  patterns.getContext());
140 }
141 
// Collects dim-resolution patterns and applies them greedily to the pass's
// root operation.
142 void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
143  RewritePatternSet patterns(&getContext());
// NOTE(review): listing line 144 is missing from this dump; presumably it
// fills `patterns` (e.g. via the ranked populate function) — confirm.
// The pattern set is moved into the greedy driver; a reported failure from
// the driver is turned into a pass failure.
145  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
146  return signalPassFailure();
147 }
148 
// Collects dim-resolution patterns and applies them greedily to the pass's
// root operation.
149 void ResolveShapedTypeResultDimsPass::runOnOperation() {
150  RewritePatternSet patterns(&getContext());
// NOTE(review): listing lines 151-152 are missing from this dump; presumably
// they fill `patterns` with both populate functions — confirm.
// The pattern set is moved into the greedy driver; a reported failure from
// the driver is turned into a pass failure.
153  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
154  return signalPassFailure();
155 }
156 
// NOTE(review): the signature line (listing line 157) is missing from this
// dump; per the symbol index it is presumably
// `std::unique_ptr<Pass> createResolveShapedTypeResultDimsPass()` — confirm.
// Returns a newly allocated ResolveShapedTypeResultDimsPass, owned by caller.
158  return std::make_unique<ResolveShapedTypeResultDimsPass>();
159 }
160 
// NOTE(review): the signature line (listing line 161) is missing from this
// dump; per the symbol index it is presumably
// `std::unique_ptr<Pass> createResolveRankedShapeTypeResultDimsPass()` —
// confirm.
// Returns a newly allocated ResolveRankedShapeTypeResultDimsPass, owned by
// the caller.
162  return std::make_unique<ResolveRankedShapeTypeResultDimsPass>();
163 }
static MLIRContext * getContext(OpFoldResult val)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:476
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:466
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg=true)
Set the flag detailing if this pattern has bounded rewrite recursion or not.
Definition: PatternMatch.h:202
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
std::unique_ptr< Pass > createResolveShapedTypeResultDimsPass()
Creates an operation pass to resolve memref.dim operations with values that are defined by operations...
void populateResolveRankedShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
std::unique_ptr< Pass > createResolveRankedShapeTypeResultDimsPass()
Creates an operation pass to resolve memref.dim operations with values that are defined by operations...
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
Include the generated interface declarations.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358