1 //===- ResolveShapedTypeResultDims.cpp - Resolve dim ops of result values -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass resolves `memref.dim` operations of result values in terms of
10 // shapes of their operands using the `InferShapedTypeOpInterface`.
11 //
12 //===----------------------------------------------------------------------===//
#include "mlir/Dialect/MemRef/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMS
#define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir
34 using namespace mlir;
36 namespace {
37 /// Fold dim of an operation that implements the InferShapedTypeOpInterface
38 template <typename OpTy>
39 struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> {
42  LogicalResult matchAndRewrite(OpTy dimOp,
43  PatternRewriter &rewriter) const override {
44  OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
45  if (!dimValue)
46  return failure();
47  auto shapedTypeOp =
48  dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
49  if (!shapedTypeOp)
50  return failure();
52  std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
53  if (!dimIndex)
54  return failure();
56  SmallVector<Value> reifiedResultShapes;
57  if (failed(shapedTypeOp.reifyReturnTypeShapes(
58  rewriter, shapedTypeOp->getOperands(), reifiedResultShapes)))
59  return failure();
61  if (reifiedResultShapes.size() != shapedTypeOp->getNumResults())
62  return failure();
64  Value resultShape = reifiedResultShapes[dimValue.getResultNumber()];
65  auto resultShapeType = dyn_cast<RankedTensorType>(resultShape.getType());
66  if (!resultShapeType || !isa<IndexType>(resultShapeType.getElementType()))
67  return failure();
69  Location loc = dimOp->getLoc();
70  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
71  dimOp, resultShape,
72  rewriter.create<arith::ConstantIndexOp>(loc, *dimIndex).getResult());
73  return success();
74  }
75 };
77 /// Fold dim of an operation that implements the InferShapedTypeOpInterface
78 template <typename OpTy>
79 struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
84  LogicalResult matchAndRewrite(OpTy dimOp,
85  PatternRewriter &rewriter) const override {
86  OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
87  if (!dimValue)
88  return failure();
89  std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
90  if (!dimIndex)
91  return failure();
93  ReifiedRankedShapedTypeDims reifiedResultShapes;
94  if (failed(reifyResultShapes(rewriter, dimValue.getOwner(),
95  reifiedResultShapes)))
96  return failure();
97  unsigned resultNumber = dimValue.getResultNumber();
98  // Do not apply pattern if the IR is invalid (dim out of bounds).
99  if ((size_t)(*dimIndex) >= reifiedResultShapes[resultNumber].size())
100  return rewriter.notifyMatchFailure(dimOp, "dimension is out of bounds");
102  rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]);
103  rewriter.replaceOp(dimOp, replacement);
104  return success();
105  }
106 };
108 /// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
109 ///
110 /// ```
111 /// %0 = ... : tensor<?x?xf32>
112 /// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
113 /// %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
114 /// ...
115 /// }
116 /// ```
117 ///
118 /// is folded to:
119 ///
120 /// ```
121 /// %0 = ... : tensor<?x?xf32>
122 /// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
123 /// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
124 /// ...
125 /// }
126 /// ```
127 struct IterArgsToInitArgs : public OpRewritePattern<tensor::DimOp> {
130  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
131  PatternRewriter &rewriter) const final {
132  auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
133  if (!blockArg)
134  return failure();
135  // TODO: Enable this for loopLikeInterface. Restricting for scf.for
136  // because the init args shape might change in the loop body.
137  // For e.g.:
138  // ```
139  // %0 = tensor.empty(%c1) : tensor<?xf32>
140  // %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %0) ->
141  // tensor<?xf32> {
142  // %1 = tensor.dim %arg0, %c0 : tensor<?xf32>
143  // %2 = arith.addi %c1, %1 : index
144  // %3 = tensor.empty(%2) : tensor<?xf32>
145  // scf.yield %3 : tensor<?xf32>
146  // }
147  //
148  // ```
149  auto forAllOp =
150  dyn_cast<scf::ForallOp>(blockArg.getParentBlock()->getParentOp());
151  if (!forAllOp)
152  return failure();
153  Value initArg = forAllOp.getTiedLoopInit(blockArg)->get();
154  rewriter.modifyOpInPlace(
155  dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
156  return success();
157  }
158 };
159 } // namespace
161 //===----------------------------------------------------------------------===//
162 // Pass registration
163 //===----------------------------------------------------------------------===//
165 namespace {
166 struct ResolveRankedShapeTypeResultDimsPass final
167  : public memref::impl::ResolveRankedShapeTypeResultDimsBase<
168  ResolveRankedShapeTypeResultDimsPass> {
169  void runOnOperation() override;
170 };
172 struct ResolveShapedTypeResultDimsPass final
173  : public memref::impl::ResolveShapedTypeResultDimsBase<
174  ResolveShapedTypeResultDimsPass> {
175  void runOnOperation() override;
176 };
178 } // namespace
181  RewritePatternSet &patterns) {
182  patterns.add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>,
183  DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>,
184  IterArgsToInitArgs>(patterns.getContext());
185 }
188  RewritePatternSet &patterns) {
189  // TODO: Move tensor::DimOp pattern to the Tensor dialect.
190  patterns.add<DimOfShapedTypeOpInterface<memref::DimOp>,
191  DimOfShapedTypeOpInterface<tensor::DimOp>>(
192  patterns.getContext());
193 }
195 void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
196  RewritePatternSet patterns(&getContext());
198  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
199  return signalPassFailure();
200 }
202 void ResolveShapedTypeResultDimsPass::runOnOperation() {
203  RewritePatternSet patterns(&getContext());
206  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
207  return signalPassFailure();
208 }
211  return std::make_unique<ResolveShapedTypeResultDimsPass>();
212 }
215  return std::make_unique<ResolveRankedShapeTypeResultDimsPass>();
216 }
