MLIR  20.0.0git
LoopCanonicalization.cpp
Go to the documentation of this file.
1 //===- LoopCanonicalization.cpp - Cross-dialect canonicalization patterns -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains cross-dialect canonicalization patterns that cannot be
10 // actual canonicalization patterns due to undesired additional dependencies.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "mlir/Dialect/SCF/Transforms/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
25 
26 namespace mlir {
27 #define GEN_PASS_DEF_SCFFORLOOPCANONICALIZATION
28 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
29 } // namespace mlir
30 
31 using namespace mlir;
32 using namespace mlir::scf;
33 
34 /// A simple, conservative analysis to determine if the loop is shape
35 /// conserving. I.e., the type of the arg-th yielded value is the same as the
36 /// type of the corresponding basic block argument of the loop.
37 /// Note: This function handles only simple cases. Expand as needed.
38 static bool isShapePreserving(ForOp forOp, int64_t arg) {
39  assert(arg < static_cast<int64_t>(forOp.getNumResults()) &&
40  "arg is out of bounds");
41  Value value = forOp.getYieldedValues()[arg];
42  while (value) {
43  if (value == forOp.getRegionIterArgs()[arg])
44  return true;
45  OpResult opResult = dyn_cast<OpResult>(value);
46  if (!opResult)
47  return false;
48 
49  using tensor::InsertSliceOp;
51  .template Case<InsertSliceOp>(
52  [&](InsertSliceOp op) { return op.getDest(); })
53  .template Case<ForOp>([&](ForOp forOp) {
54  return isShapePreserving(forOp, opResult.getResultNumber())
55  ? forOp.getInitArgs()[opResult.getResultNumber()]
56  : Value();
57  })
58  .Default([&](auto op) { return Value(); });
59  }
60  return false;
61 }
62 
63 namespace {
64 /// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
65 ///
66 /// ```
67 /// %0 = ... : tensor<?x?xf32>
68 /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
69 /// %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
70 /// ...
71 /// }
72 /// ```
73 ///
74 /// is folded to:
75 ///
76 /// ```
77 /// %0 = ... : tensor<?x?xf32>
78 /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
79 /// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
80 /// ...
81 /// }
82 /// ```
83 ///
84 /// Note: Dim ops are folded only if it can be proven that the runtime type of
85 /// the iter arg does not change with loop iterations.
86 template <typename OpTy>
87 struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
89 
90  LogicalResult matchAndRewrite(OpTy dimOp,
91  PatternRewriter &rewriter) const override {
92  auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
93  if (!blockArg)
94  return failure();
95  auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp());
96  if (!forOp)
97  return failure();
98  if (!isShapePreserving(forOp, blockArg.getArgNumber() - 1))
99  return failure();
100 
101  Value initArg = forOp.getTiedLoopInit(blockArg)->get();
102  rewriter.modifyOpInPlace(
103  dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
104 
105  return success();
106  };
107 };
108 
109 /// Fold dim ops of loop results to dim ops of their respective init args. E.g.:
110 ///
111 /// ```
112 /// %0 = ... : tensor<?x?xf32>
113 /// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
114 /// ...
115 /// }
116 /// %1 = tensor.dim %r, %c0 : tensor<?x?xf32>
117 /// ```
118 ///
119 /// is folded to:
120 ///
121 /// ```
122 /// %0 = ... : tensor<?x?xf32>
123 /// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
124 /// ...
125 /// }
126 /// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
127 /// ```
128 ///
129 /// Note: Dim ops are folded only if it can be proven that the runtime type of
130 /// the iter arg does not change with loop iterations.
131 template <typename OpTy>
132 struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
134 
135  LogicalResult matchAndRewrite(OpTy dimOp,
136  PatternRewriter &rewriter) const override {
137  auto forOp = dimOp.getSource().template getDefiningOp<scf::ForOp>();
138  if (!forOp)
139  return failure();
140  auto opResult = cast<OpResult>(dimOp.getSource());
141  unsigned resultNumber = opResult.getResultNumber();
142  if (!isShapePreserving(forOp, resultNumber))
143  return failure();
144  rewriter.modifyOpInPlace(dimOp, [&]() {
145  dimOp.getSourceMutable().assign(forOp.getInitArgs()[resultNumber]);
146  });
147  return success();
148  }
149 };
150 
151 /// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for
152 /// and scf.parallel loops with a known range.
153 template <typename OpTy>
154 struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
156 
157  LogicalResult matchAndRewrite(OpTy op,
158  PatternRewriter &rewriter) const override {
160  }
161 };
162 
163 struct SCFForLoopCanonicalization
164  : public impl::SCFForLoopCanonicalizationBase<SCFForLoopCanonicalization> {
165  void runOnOperation() override {
166  auto *parentOp = getOperation();
167  MLIRContext *ctx = parentOp->getContext();
170  if (failed(applyPatternsGreedily(parentOp, std::move(patterns))))
171  signalPassFailure();
172  }
173 };
174 } // namespace
175 
179  patterns
180  .add<AffineOpSCFCanonicalizationPattern<affine::AffineMinOp>,
181  AffineOpSCFCanonicalizationPattern<affine::AffineMaxOp>,
182  DimOfIterArgFolder<tensor::DimOp>, DimOfIterArgFolder<memref::DimOp>,
183  DimOfLoopResultFolder<tensor::DimOp>,
184  DimOfLoopResultFolder<memref::DimOp>>(ctx);
185 }
186 
188  return std::make_unique<SCFForLoopCanonicalization>();
189 }
static bool isShapePreserving(ForOp forOp, int64_t arg)
A simple, conservative analysis to determine if the loop is shape conserving.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:466
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
LogicalResult matchForLikeLoop(Value iv, OpFoldResult &lb, OpFoldResult &ub, OpFoldResult &step)
Match "for loop"-like operations from the SCF dialect.
LogicalResult canonicalizeMinMaxOpInLoop(RewriterBase &rewriter, Operation *op, LoopMatcherFn loopMatcher)
Try to canonicalize the given affine.min/max operation in the context of for loops with a known range...
RewritePatternSet & patterns
Definition: Patterns.h:74
void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns)
Populate patterns for canonicalizing operations inside SCF loop bodies.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
std::unique_ptr< Pass > createSCFForLoopCanonicalizationPass()
Creates a pass that canonicalizes affine.min and affine.max operations inside of scf....
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358