MLIR  16.0.0git
LoopCanonicalization.cpp
Go to the documentation of this file.
1 //===- LoopCanonicalization.cpp - Cross-dialect canonicalization patterns -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains cross-dialect canonicalization patterns that cannot be
10 // actual canonicalization patterns due to undesired additional dependencies.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
22 #include "mlir/IR/PatternMatch.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 
26 using namespace mlir;
27 using namespace mlir::scf;
28 
29 /// A simple, conservative analysis to determine if the loop is shape
30 /// conserving. I.e., the type of the arg-th yielded value is the same as the
31 /// type of the corresponding basic block argument of the loop.
32 /// Note: This function handles only simple cases. Expand as needed.
33 static bool isShapePreserving(ForOp forOp, int64_t arg) {
34  auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator());
35  assert(arg < static_cast<int64_t>(yieldOp.getResults().size()) &&
36  "arg is out of bounds");
37  Value value = yieldOp.getResults()[arg];
38  while (value) {
39  if (value == forOp.getRegionIterArgs()[arg])
40  return true;
41  OpResult opResult = value.dyn_cast<OpResult>();
42  if (!opResult)
43  return false;
44 
45  using tensor::InsertSliceOp;
46  value =
48  .template Case<InsertSliceOp>(
49  [&](InsertSliceOp op) { return op.getDest(); })
50  .template Case<ForOp>([&](ForOp forOp) {
51  return isShapePreserving(forOp, opResult.getResultNumber())
52  ? forOp.getIterOperands()[opResult.getResultNumber()]
53  : Value();
54  })
55  .Default([&](auto op) { return Value(); });
56  }
57  return false;
58 }
59 
60 namespace {
61 /// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
62 ///
63 /// ```
64 /// %0 = ... : tensor<?x?xf32>
65 /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
66 /// %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
67 /// ...
68 /// }
69 /// ```
70 ///
71 /// is folded to:
72 ///
73 /// ```
74 /// %0 = ... : tensor<?x?xf32>
75 /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
76 /// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
77 /// ...
78 /// }
79 /// ```
80 ///
81 /// Note: Dim ops are folded only if it can be proven that the runtime type of
82 /// the iter arg does not change with loop iterations.
83 template <typename OpTy>
84 struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
86 
87  LogicalResult matchAndRewrite(OpTy dimOp,
88  PatternRewriter &rewriter) const override {
89  auto blockArg = dimOp.getSource().template dyn_cast<BlockArgument>();
90  if (!blockArg)
91  return failure();
92  auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp());
93  if (!forOp)
94  return failure();
95  if (!isShapePreserving(forOp, blockArg.getArgNumber() - 1))
96  return failure();
97 
98  Value initArg = forOp.getOpOperandForRegionIterArg(blockArg).get();
99  rewriter.updateRootInPlace(
100  dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
101 
102  return success();
103  };
104 };
105 
106 /// Fold dim ops of loop results to dim ops of their respective init args. E.g.:
107 ///
108 /// ```
109 /// %0 = ... : tensor<?x?xf32>
110 /// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
111 /// ...
112 /// }
113 /// %1 = tensor.dim %r, %c0 : tensor<?x?xf32>
114 /// ```
115 ///
116 /// is folded to:
117 ///
118 /// ```
119 /// %0 = ... : tensor<?x?xf32>
120 /// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
121 /// ...
122 /// }
123 /// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
124 /// ```
125 ///
126 /// Note: Dim ops are folded only if it can be proven that the runtime type of
127 /// the iter arg does not change with loop iterations.
128 template <typename OpTy>
129 struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
131 
132  LogicalResult matchAndRewrite(OpTy dimOp,
133  PatternRewriter &rewriter) const override {
134  auto forOp = dimOp.getSource().template getDefiningOp<scf::ForOp>();
135  if (!forOp)
136  return failure();
137  auto opResult = dimOp.getSource().template cast<OpResult>();
138  unsigned resultNumber = opResult.getResultNumber();
139  if (!isShapePreserving(forOp, resultNumber))
140  return failure();
141  rewriter.updateRootInPlace(dimOp, [&]() {
142  dimOp.getSourceMutable().assign(forOp.getIterOperands()[resultNumber]);
143  });
144  return success();
145  }
146 };
147 
148 /// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for
149 /// and scf.parallel loops with a known range.
150 template <typename OpTy, bool IsMin>
151 struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
153 
154  LogicalResult matchAndRewrite(OpTy op,
155  PatternRewriter &rewriter) const override {
156  auto loopMatcher = [](Value iv, OpFoldResult &lb, OpFoldResult &ub,
157  OpFoldResult &step) {
158  if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) {
159  lb = forOp.getLowerBound();
160  ub = forOp.getUpperBound();
161  step = forOp.getStep();
162  return success();
163  }
164  if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) {
165  for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) {
166  if (parOp.getInductionVars()[idx] == iv) {
167  lb = parOp.getLowerBound()[idx];
168  ub = parOp.getUpperBound()[idx];
169  step = parOp.getStep()[idx];
170  return success();
171  }
172  }
173  return failure();
174  }
175  if (scf::ForeachThreadOp foreachThreadOp =
177  for (int64_t idx = 0; idx < foreachThreadOp.getRank(); ++idx) {
178  if (foreachThreadOp.getThreadIndices()[idx] == iv) {
179  lb = OpBuilder(iv.getContext()).getIndexAttr(0);
180  ub = foreachThreadOp.getNumThreads()[idx];
181  step = OpBuilder(iv.getContext()).getIndexAttr(1);
182  return success();
183  }
184  }
185  return failure();
186  }
187  return failure();
188  };
189 
190  return scf::canonicalizeMinMaxOpInLoop(rewriter, op, op.getAffineMap(),
191  op.operands(), IsMin, loopMatcher);
192  }
193 };
194 
195 struct SCFForLoopCanonicalization
196  : public SCFForLoopCanonicalizationBase<SCFForLoopCanonicalization> {
197  void runOnOperation() override {
198  auto *parentOp = getOperation();
199  MLIRContext *ctx = parentOp->getContext();
200  RewritePatternSet patterns(ctx);
202  if (failed(applyPatternsAndFoldGreedily(parentOp, std::move(patterns))))
203  signalPassFailure();
204  }
205 };
206 } // namespace
207 
209  RewritePatternSet &patterns) {
210  MLIRContext *ctx = patterns.getContext();
211  patterns
212  .add<AffineOpSCFCanonicalizationPattern<AffineMinOp, /*IsMin=*/true>,
213  AffineOpSCFCanonicalizationPattern<AffineMaxOp, /*IsMin=*/false>,
214  DimOfIterArgFolder<tensor::DimOp>, DimOfIterArgFolder<memref::DimOp>,
215  DimOfLoopResultFolder<tensor::DimOp>,
216  DimOfLoopResultFolder<memref::DimOp>>(ctx);
217 }
218 
220  return std::make_unique<SCFForLoopCanonicalization>();
221 }
Include the generated interface declarations.
std::unique_ptr< Pass > createSCFForLoopCanonicalizationPass()
Creates a pass that canonicalizes affine.min and affine.max operations inside of scf.for loops with known lower and upper bounds.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:600
This is a value defined by a result of an operation.
Definition: Value.h:425
This class represents a single result from folding an operation.
Definition: OpDefinition.h:235
LogicalResult canonicalizeMinMaxOpInLoop(RewriterBase &rewriter, Operation *op, AffineMap map, ValueRange operands, bool isMin, LoopMatcherFn loopMatcher)
Try to canonicalize a min/max operation in the context of for loops with a known range...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
static bool isShapePreserving(ForOp forOp, int64_t arg)
A simple, conservative analysis to determine if the loop is shape conserving.
static constexpr const bool value
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:434
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
U dyn_cast() const
Definition: Value.h:100
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:437
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
Definition: SCF.cpp:2292
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:499
ForeachThreadOp getForeachThreadOpThreadIndexOwner(Value val)
Returns the ForeachThreadOp parent of a thread index variable.
Definition: SCF.cpp:1205
This class represents an argument of a Block.
Definition: Value.h:300
void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns)
Populate patterns for canonicalizing operations inside SCF loop bodies.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
LogicalResult applyPatternsAndFoldGreedily(MutableArrayRef< Region > regions, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig())
Rewrite the regions of the specified operation, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy work-list driven manner.
This class helps build Operations.
Definition: Builders.h:192
MLIRContext * getContext() const
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:121