MLIR  22.0.0git
Loops.cpp
Go to the documentation of this file.
1 //===- Loops.cpp - conversion from Linalg named and generic ops to loops --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
19 #include "mlir/IR/AffineExpr.h"
20 #include "mlir/IR/AffineMap.h"
21 #include "mlir/IR/IRMapping.h"
22 #include "mlir/Support/LLVM.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 
28 namespace mlir {
29 #define GEN_PASS_DEF_CONVERTLINALGTOAFFINELOOPSPASS
30 #define GEN_PASS_DEF_CONVERTLINALGTOLOOPSPASS
31 #define GEN_PASS_DEF_CONVERTLINALGTOPARALLELLOOPSPASS
32 #include "mlir/Dialect/Linalg/Passes.h.inc"
33 } // namespace mlir
34 
35 using namespace mlir;
36 using namespace mlir::linalg;
37 
39  AffineMap map,
40  ArrayRef<Value> vals) {
41  if (map.isEmpty())
42  return {};
43 
44  assert(map.getNumInputs() == vals.size());
46  res.reserve(map.getNumResults());
47  auto dims = map.getNumDims();
48  for (auto e : map.getResults()) {
49  auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e);
50  SmallVector<Value> operands(vals);
51  affine::canonicalizeMapAndOperands(&exprMap, &operands);
52  res.push_back(b.create<affine::AffineApplyOp>(loc, exprMap, operands));
53  }
54  return res;
55 }
56 
57 template <typename LoadOpTy, typename StoreOpTy, typename OpType>
58 static void inlineRegionAndEmitStore(OpBuilder &b, Location loc, OpType op,
59  ArrayRef<Value> indexedValues,
60  ArrayRef<SmallVector<Value>> indexing,
61  ArrayRef<Value> outputBuffers) {
62  auto &block = op->getRegion(0).front();
63  IRMapping map;
64  map.map(block.getArguments(), indexedValues);
65  for (auto &op : block.without_terminator()) {
66  auto *newOp = b.clone(op, map);
67  map.map(op.getResults(), newOp->getResults());
68  }
69 
70  Operation *terminator = block.getTerminator();
71  for (OpOperand &operand : terminator->getOpOperands()) {
72  Value toStore = map.lookupOrDefault(operand.get());
73  b.create<StoreOpTy>(loc, toStore, outputBuffers[operand.getOperandNumber()],
74  indexing[operand.getOperandNumber()]);
75  }
76 }
77 
78 // Returns an InputAndOutputIndices struct holding the input indices and the
79 // output indices of a SingleInputPoolingOp `op`.
83 };
84 template <typename SingleInputPoolingOp>
87  SingleInputPoolingOp op) {
88  auto mapsRange = op.getIndexingMapsArray();
89  auto maps = llvm::to_vector<8>(
90  llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
91  return InputAndOutputIndices{
92  makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
93  makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
94 }
95 
96 /// Emits the MLIR for the scalar part of the generic op by:
97 /// 1. Emitting load ops for each input and output view in order. This is
98 /// achieved by applying the appropriate input or output map to the
99 /// enclosing induction variables.
100 /// 2. Emitting a call to `op.fun()` that takes as arguments the scalars
101 /// from point 1. above.
102 /// 3. Emitting store ops to store the results of 2. to the output
103 /// views.
104 ///
105 /// An example output may resemble:
106 ///
107 /// ```
108 /// scf.for %i = %c0 to %0 step %c1 {
109 /// scf.for %j = %c0 to %1 step %c1 {
110 /// scf.for %k = %c0 to %4 step %c1 {
111 /// %11 = load %arg0[%i, %j] :
112 /// memref<?x?xf32, stride_specification>
113 /// %12 = load %arg1[%i, %j, %k] :
114 /// memref<?x?x?xf32, stride_specification>
115 /// %13 = load %arg2[%i, %k, %j] :
116 /// memref<?x?x?xf32, stride_specification>
117 /// %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
118 /// store %14#0, %arg1[%i, %j, %k] :
119 /// memref<?x?x?xf32, stride_specification>
120 /// store %14#1, %arg2[%i, %k, %j] :
121 /// memref<?x?x?xf32, stride_specification>
122 /// }
123 /// }
124 /// }
125 /// ```
126 template <typename LoadOpTy, typename StoreOpTy>
128  ArrayRef<Value> allIvs,
129  LinalgOp linalgOp) {
130  assert(linalgOp.hasPureBufferSemantics() &&
131  "expected linalg op with buffer semantics");
132  SmallVector<Value> indexedValues;
133  indexedValues.reserve(linalgOp->getNumOperands());
134 
135  auto allIvsPlusDims = SmallVector<Value>(allIvs);
136 
137  // TODO: Avoid the loads if the corresponding argument of the
138  // region has no uses.
139  // 1.a. Emit load from input operand or for scalars access the operand itself.
140  for (OpOperand *inputOperand : linalgOp.getDpsInputOperands()) {
141  if (linalgOp.isScalar(inputOperand)) {
142  indexedValues.push_back(inputOperand->get());
143  continue;
144  }
145  auto indexing = makeCanonicalAffineApplies(
146  b, loc, linalgOp.getMatchingIndexingMap(inputOperand), allIvsPlusDims);
147  indexedValues.push_back(
148  b.create<LoadOpTy>(loc, inputOperand->get(), indexing));
149  }
150  // 1.b. Emit load from output views.
151  for (OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) {
153  b, loc, linalgOp.getMatchingIndexingMap(&outputOperand),
154  allIvsPlusDims);
155  indexedValues.push_back(
156  b.create<LoadOpTy>(loc, outputOperand.get(), indexing));
157  }
158 
159  // TODO: When a region inliner exists, use it.
160  // 2. Inline region, currently only works for a single basic block.
161  // 3. Emit store.
162  SmallVector<SmallVector<Value>, 8> indexing;
163  SmallVector<Value> outputBuffers;
164  for (OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) {
165  if (!isa<MemRefType>(outputOperand.get().getType()))
166  continue;
167  indexing.push_back(makeCanonicalAffineApplies(
168  b, loc, linalgOp.getMatchingIndexingMap(&outputOperand),
169  allIvsPlusDims));
170  outputBuffers.push_back(outputOperand.get());
171  }
172  inlineRegionAndEmitStore<LoadOpTy, StoreOpTy>(b, loc, linalgOp, indexedValues,
173  indexing, outputBuffers);
174 }
175 
176 /// Replace the index operations in the body of the loop nest by the matching
177 /// induction variables.
179  LinalgOp linalgOp,
180  ArrayRef<Operation *> loopOps) {
181  // Extract the induction variables of the loop nest from outer to inner.
182  SmallVector<Value> allIvs;
183  for (Operation *loopOp : loopOps) {
185  .Case([&](scf::ParallelOp parallelOp) {
186  allIvs.append(parallelOp.getInductionVars());
187  })
188  .Case([&](scf::ForOp forOp) {
189  allIvs.push_back(forOp.getInductionVar());
190  })
191  .Case([&](affine::AffineForOp affineForOp) {
192  allIvs.push_back(affineForOp.getInductionVar());
193  })
194  .Default([&](Operation *op) { assert(false && "unexpected op"); });
195  }
196  assert(linalgOp.getNumLoops() == allIvs.size() &&
197  "expected the number of loops and induction variables to match");
198  // Replace the index operations in the body of the innermost loop op.
199  if (!loopOps.empty()) {
200  auto loopOp = cast<LoopLikeOpInterface>(loopOps.back());
201  for (Region *r : loopOp.getLoopRegions())
202  for (IndexOp indexOp : llvm::make_early_inc_range(r->getOps<IndexOp>()))
203  rewriter.replaceOp(indexOp, allIvs[indexOp.getDim()]);
204  }
205 }
206 
207 template <typename LoopTy>
208 static FailureOr<LinalgLoops> linalgOpToLoopsImpl(RewriterBase &rewriter,
209  LinalgOp linalgOp) {
210  using LoadOpTy =
211  std::conditional_t<std::is_same<LoopTy, affine::AffineForOp>::value,
212  affine::AffineLoadOp, memref::LoadOp>;
213  using StoreOpTy =
214  std::conditional_t<std::is_same<LoopTy, affine::AffineForOp>::value,
215  affine::AffineStoreOp, memref::StoreOp>;
216 
217  // The flattened loopToOperandRangesMaps is expected to be an invertible
218  // permutation map (which is asserted in the inverse calculation).
219  assert(linalgOp.hasPureBufferSemantics() &&
220  "expected linalg op with buffer semantics");
221 
222  auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
223  auto iteratorTypes = linalgOp.getIteratorTypesArray();
224 
225  SmallVector<Value> allIvs;
227  rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes,
228  [&](OpBuilder &b, Location loc, ValueRange ivs,
229  ValueRange operandValuesToUse) -> scf::ValueVector {
230  assert(operandValuesToUse == linalgOp->getOperands() &&
231  "expect operands are captured and not passed by loop argument");
232  allIvs.append(ivs.begin(), ivs.end());
233  emitScalarImplementation<LoadOpTy, StoreOpTy>(b, loc, allIvs, linalgOp);
234  return scf::ValueVector{};
235  });
236  // Number of loop ops might be different from the number of ivs since some
237  // loops like affine.parallel and scf.parallel have multiple ivs.
238  SetVector<Operation *> loopSet;
239  for (Value iv : allIvs) {
240  if (!iv)
241  return failure();
242  // The induction variable is a block argument of the entry block of the
243  // loop operation.
244  BlockArgument ivVal = dyn_cast<BlockArgument>(iv);
245  if (!ivVal)
246  return failure();
247  loopSet.insert(ivVal.getOwner()->getParentOp());
248  }
249  LinalgLoops loops(loopSet.begin(), loopSet.end());
250  // Replace all index operations in the loop body.
251  replaceIndexOpsByInductionVariables(rewriter, linalgOp, loops);
252  return loops;
253 }
254 
255 namespace {
256 template <typename LoopType>
257 class LinalgRewritePattern : public RewritePattern {
258 public:
259  LinalgRewritePattern(MLIRContext *context)
260  : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
261 
262  LogicalResult matchAndRewrite(Operation *op,
263  PatternRewriter &rewriter) const override {
264  auto linalgOp = dyn_cast<LinalgOp>(op);
265  if (!isa<LinalgOp>(op) || !linalgOp.hasPureBufferSemantics()) {
266  return rewriter.notifyMatchFailure(
267  op, "expected linalg op with buffer semantics");
268  }
269  if (failed(linalgOpToLoopsImpl<LoopType>(rewriter, linalgOp)))
270  return failure();
271  rewriter.eraseOp(op);
272  return success();
273  }
274 };
275 
276 /// Local folding pattern for AffineApplyOp that we can apply greedily.
277 /// This replaces AffineApplyOp by the proper value in cases where the
278 /// associated map is trivial.
279 /// A trivial map here is defined as a map with a single result and either:
280 /// 1. Zero operand + returns a single AffineConstantExpr
281 /// 2. One operand + returns a single AffineDimExpr
282 /// 3. One operand + returns a single AffineSymbolExpr
283 //
284 /// In the first case, the AffineApplyOp is replaced by a new constant. In the
285 /// other cases, it is replaced by its unique operand.
286 struct FoldAffineOp : public RewritePattern {
287  FoldAffineOp(MLIRContext *context)
288  : RewritePattern(affine::AffineApplyOp::getOperationName(), 0, context) {}
289 
290  LogicalResult matchAndRewrite(Operation *op,
291  PatternRewriter &rewriter) const override {
292  auto affineApplyOp = cast<affine::AffineApplyOp>(op);
293  auto map = affineApplyOp.getAffineMap();
294  if (map.getNumResults() != 1 || map.getNumInputs() > 1)
295  return failure();
296 
297  AffineExpr expr = map.getResult(0);
298  if (map.getNumInputs() == 0) {
299  if (auto val = dyn_cast<AffineConstantExpr>(expr)) {
300  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, val.getValue());
301  return success();
302  }
303  return failure();
304  }
305  if (isa<AffineDimExpr, AffineSymbolExpr>(expr)) {
306  rewriter.replaceOp(op, op->getOperand(0));
307  return success();
308  }
309  return failure();
310  }
311 };
312 
313 template <typename LoopType>
314 static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
315  MLIRContext *context = enclosingOp->getContext();
316  RewritePatternSet patterns(context);
317  patterns.add<LinalgRewritePattern<LoopType>>(context);
318  memref::DimOp::getCanonicalizationPatterns(patterns, context);
319  tensor::DimOp::getCanonicalizationPatterns(patterns, context);
320  affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
321  patterns.add<FoldAffineOp>(context);
322  // Just apply the patterns greedily.
323  (void)applyPatternsGreedily(enclosingOp, std::move(patterns));
324 }
325 
326 struct LowerToAffineLoops
327  : public impl::ConvertLinalgToAffineLoopsPassBase<LowerToAffineLoops> {
328  using impl::ConvertLinalgToAffineLoopsPassBase<
329  LowerToAffineLoops>::ConvertLinalgToAffineLoopsPassBase;
330  void getDependentDialects(DialectRegistry &registry) const override {
331  registry.insert<memref::MemRefDialect>();
332  }
333  void runOnOperation() override {
334  lowerLinalgToLoopsImpl<affine::AffineForOp>(getOperation());
335  }
336 };
337 
338 struct LowerToLoops : public impl::ConvertLinalgToLoopsPassBase<LowerToLoops> {
339  using impl::ConvertLinalgToLoopsPassBase<
340  LowerToLoops>::ConvertLinalgToLoopsPassBase;
341  void getDependentDialects(DialectRegistry &registry) const override {
342  registry.insert<memref::MemRefDialect, scf::SCFDialect>();
343  }
344  void runOnOperation() override {
345  lowerLinalgToLoopsImpl<scf::ForOp>(getOperation());
346  }
347 };
348 
349 struct LowerToParallelLoops
350  : public impl::ConvertLinalgToParallelLoopsPassBase<LowerToParallelLoops> {
351  using impl::ConvertLinalgToParallelLoopsPassBase<
352  LowerToParallelLoops>::ConvertLinalgToParallelLoopsPassBase;
353  void runOnOperation() override {
354  lowerLinalgToLoopsImpl<scf::ParallelOp>(getOperation());
355  }
356 };
357 
358 } // namespace
359 
360 /// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
361 FailureOr<LinalgLoops>
362 mlir::linalg::linalgOpToAffineLoops(RewriterBase &rewriter, LinalgOp linalgOp) {
363  return linalgOpToLoopsImpl<affine::AffineForOp>(rewriter, linalgOp);
364 }
365 
366 /// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
367 FailureOr<LinalgLoops> mlir::linalg::linalgOpToLoops(RewriterBase &rewriter,
368  LinalgOp linalgOp) {
369  return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
370 }
371 
372 /// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
373 FailureOr<LinalgLoops>
375  LinalgOp linalgOp) {
376  return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
377 }
static SmallVector< Value > makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map, ArrayRef< Value > vals)
Definition: Loops.cpp:38
static void replaceIndexOpsByInductionVariables(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< Operation * > loopOps)
Replace the index operations in the body of the loop nest by the matching induction variables.
Definition: Loops.cpp:178
static InputAndOutputIndices getInputAndOutputIndices(OpBuilder &b, Location loc, ArrayRef< Value > allIvs, SingleInputPoolingOp op)
Definition: Loops.cpp:86
static FailureOr< LinalgLoops > linalgOpToLoopsImpl(RewriterBase &rewriter, LinalgOp linalgOp)
Definition: Loops.cpp:208
static void emitScalarImplementation(OpBuilder &b, Location loc, ArrayRef< Value > allIvs, LinalgOp linalgOp)
Emits the MLIR for the scalar part of the generic op by:
Definition: Loops.cpp:127
static void inlineRegionAndEmitStore(OpBuilder &b, Location loc, OpType op, ArrayRef< Value > indexedValues, ArrayRef< SmallVector< Value >> indexing, ArrayRef< Value > outputBuffers)
Definition: Loops.cpp:58
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
Definition: AffineMap.cpp:365
unsigned getNumSymbols() const
Definition: AffineMap.cpp:394
unsigned getNumDims() const
Definition: AffineMap.cpp:390
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:403
unsigned getNumResults() const
Definition: AffineMap.cpp:398
unsigned getNumInputs() const
Definition: AffineMap.cpp:399
This class represents an argument of a Block.
Definition: Value.h:309
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:318
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:548
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
This class represents an operand of an operation.
Definition: Value.h:257
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:767
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:238
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:700
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
Definition: AffineOps.cpp:1621
FailureOr< LinalgLoops > linalgOpToLoops(RewriterBase &rewriter, LinalgOp linalgOp)
Emit a loop nest of scf.for with the proper body for linalgOp.
Definition: Loops.cpp:367
FailureOr< LinalgLoops > linalgOpToAffineLoops(RewriterBase &rewriter, LinalgOp linalgOp)
Emit a loop nest of affine.for with the proper body for linalgOp.
Definition: Loops.cpp:362
FailureOr< LinalgLoops > linalgOpToParallelLoops(RewriterBase &rewriter, LinalgOp linalgOp)
Emit a loop nest of scf.parallel with the proper body for linalgOp.
Definition: Loops.cpp:374
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition: SCF.h:64
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
SmallVector< Value > outputs
Definition: Loops.cpp:82
SmallVector< Value > inputs
Definition: Loops.cpp:81
static void doit(OpBuilder &b, Location loc, ArrayRef< Range > loopRanges, LinalgOp linalgOp, ArrayRef< utils::IteratorType > iteratorTypes, function_ref< scf::ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilderFn, ArrayRef< linalg::ProcInfo > procInfo={})