MLIR  19.0.0git
Loops.cpp
Go to the documentation of this file.
1 //===- Loops.cpp - conversion from Linalg named and generic ops to loops --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
20 #include "mlir/IR/AffineExpr.h"
21 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/IRMapping.h"
23 #include "mlir/Support/LLVM.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 
29 namespace mlir {
30 #define GEN_PASS_DEF_CONVERTLINALGTOAFFINELOOPSPASS
31 #define GEN_PASS_DEF_CONVERTLINALGTOLOOPSPASS
32 #define GEN_PASS_DEF_CONVERTLINALGTOPARALLELLOOPSPASS
33 #include "mlir/Dialect/Linalg/Passes.h.inc"
34 } // namespace mlir
35 
36 using namespace mlir;
37 using namespace mlir::linalg;
38 
40  AffineMap map,
41  ArrayRef<Value> vals) {
42  if (map.isEmpty())
43  return {};
44 
45  assert(map.getNumInputs() == vals.size());
47  res.reserve(map.getNumResults());
48  auto dims = map.getNumDims();
49  for (auto e : map.getResults()) {
50  auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e);
51  SmallVector<Value> operands(vals.begin(), vals.end());
52  affine::canonicalizeMapAndOperands(&exprMap, &operands);
53  res.push_back(b.create<affine::AffineApplyOp>(loc, exprMap, operands));
54  }
55  return res;
56 }
57 
58 template <typename LoadOpTy, typename StoreOpTy, typename OpType>
59 static void inlineRegionAndEmitStore(OpBuilder &b, Location loc, OpType op,
60  ArrayRef<Value> indexedValues,
61  ArrayRef<SmallVector<Value>> indexing,
62  ArrayRef<Value> outputBuffers) {
63  auto &block = op->getRegion(0).front();
64  IRMapping map;
65  map.map(block.getArguments(), indexedValues);
66  for (auto &op : block.without_terminator()) {
67  auto *newOp = b.clone(op, map);
68  map.map(op.getResults(), newOp->getResults());
69  }
70 
71  Operation *terminator = block.getTerminator();
72  for (OpOperand &operand : terminator->getOpOperands()) {
73  Value toStore = map.lookupOrDefault(operand.get());
74  b.create<StoreOpTy>(loc, toStore, outputBuffers[operand.getOperandNumber()],
75  indexing[operand.getOperandNumber()]);
76  }
77 }
78 
79 // Returns a pair that contains input indices and output indices of a
80 // SingleInputPoolingOp `op`.
84 };
85 template <typename SingleInputPoolingOp>
88  SingleInputPoolingOp op) {
89  auto mapsRange = op.getIndexingMapsArray();
90  auto maps = llvm::to_vector<8>(
91  llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
92  return InputAndOutputIndices{
93  makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
94  makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
95 }
96 
97 /// Emits the MLIR for the scalar part of the generic op by:
98 /// 1. Emitting load ops for each input and output view in order. This is
99 /// achieved by applying the appropriate input or output map to the
100 /// enclosing induction variables.
101 /// 2. Emitting a call to `op.fun()` that takes as arguments the scalars
102 /// from point 1. above.
103 /// 3. Emitting store ops to store the results of 2. to the output
104 /// views.
105 ///
106 /// An example output may resemble:
107 ///
108 /// ```
109 /// scf.for %i = %c0 to %0 step %c1 {
110 /// scf.for %j = %c0 to %1 step %c1 {
111 /// scf.for %k = %c0 to %4 step %c1 {
112 /// %11 = load %arg0[%i, %j] :
113 /// memref<?x?xf32, stride_specification>
114 /// %12 = load %arg1[%i, %j, %k] :
115 /// memref<?x?x?xf32, stride_specification>
116 /// %13 = load %arg2[%i, %k, %j] :
117 /// memref<?x?x?xf32, stride_specification>
118 /// %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
119 /// store %14#0, %arg1[%i, %j, %k] :
120 /// memref<?x?x?Xf32, stride_specification>
121 /// store %14#1, %arg2[%i, %k, %j] :
122 /// memref<?x?x?Xf32, stride_specification>
123 /// }
124 /// }
125 /// }
126 /// ```
127 template <typename LoadOpTy, typename StoreOpTy>
129  ArrayRef<Value> allIvs,
130  LinalgOp linalgOp) {
131  assert(linalgOp.hasPureBufferSemantics() &&
132  "expected linalg op with buffer semantics");
133  SmallVector<Value> indexedValues;
134  indexedValues.reserve(linalgOp->getNumOperands());
135 
136  auto allIvsPlusDims = SmallVector<Value>(allIvs.begin(), allIvs.end());
137 
138  // TODO: Avoid the loads if the corresponding argument of the
139  // region has no uses.
140  // 1.a. Emit load from input operand or for scalars access the operand itself.
141  for (OpOperand *inputOperand : linalgOp.getDpsInputOperands()) {
142  if (linalgOp.isScalar(inputOperand)) {
143  indexedValues.push_back(inputOperand->get());
144  continue;
145  }
146  auto indexing = makeCanonicalAffineApplies(
147  b, loc, linalgOp.getMatchingIndexingMap(inputOperand), allIvsPlusDims);
148  indexedValues.push_back(
149  b.create<LoadOpTy>(loc, inputOperand->get(), indexing));
150  }
151  // 1.b. Emit load from output views.
152  for (OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) {
154  b, loc, linalgOp.getMatchingIndexingMap(&outputOperand),
155  allIvsPlusDims);
156  indexedValues.push_back(
157  b.create<LoadOpTy>(loc, outputOperand.get(), indexing));
158  }
159 
160  // TODO: When a region inliner exists, use it.
161  // 2. Inline region, currently only works for a single basic block.
162  // 3. Emit store.
163  SmallVector<SmallVector<Value>, 8> indexing;
164  SmallVector<Value> outputBuffers;
165  for (OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) {
166  if (!isa<MemRefType>(outputOperand.get().getType()))
167  continue;
168  indexing.push_back(makeCanonicalAffineApplies(
169  b, loc, linalgOp.getMatchingIndexingMap(&outputOperand),
170  allIvsPlusDims));
171  outputBuffers.push_back(outputOperand.get());
172  }
173  inlineRegionAndEmitStore<LoadOpTy, StoreOpTy>(b, loc, linalgOp, indexedValues,
174  indexing, outputBuffers);
175 }
176 
177 /// Replace the index operations in the body of the loop nest by the matching
178 /// induction variables.
180  LinalgOp linalgOp,
181  ArrayRef<Operation *> loopOps) {
182  // Extract the induction variables of the loop nest from outer to inner.
183  SmallVector<Value> allIvs;
184  for (Operation *loopOp : loopOps) {
186  .Case([&](scf::ParallelOp parallelOp) {
187  allIvs.append(parallelOp.getInductionVars());
188  })
189  .Case([&](scf::ForOp forOp) {
190  allIvs.push_back(forOp.getInductionVar());
191  })
192  .Case([&](affine::AffineForOp affineForOp) {
193  allIvs.push_back(affineForOp.getInductionVar());
194  })
195  .Default([&](Operation *op) { assert(false && "unexpected op"); });
196  }
197  assert(linalgOp.getNumLoops() == allIvs.size() &&
198  "expected the number of loops and induction variables to match");
199  // Replace the index operations in the body of the innermost loop op.
200  if (!loopOps.empty()) {
201  auto loopOp = cast<LoopLikeOpInterface>(loopOps.back());
202  for (Region *r : loopOp.getLoopRegions())
203  for (IndexOp indexOp : llvm::make_early_inc_range(r->getOps<IndexOp>()))
204  rewriter.replaceOp(indexOp, allIvs[indexOp.getDim()]);
205  }
206 }
207 
208 template <typename LoopTy>
209 static FailureOr<LinalgLoops> linalgOpToLoopsImpl(RewriterBase &rewriter,
210  LinalgOp linalgOp) {
211  using LoadOpTy =
212  std::conditional_t<std::is_same<LoopTy, affine::AffineForOp>::value,
213  affine::AffineLoadOp, memref::LoadOp>;
214  using StoreOpTy =
215  std::conditional_t<std::is_same<LoopTy, affine::AffineForOp>::value,
216  affine::AffineStoreOp, memref::StoreOp>;
217 
218  // The flattened loopToOperandRangesMaps is expected to be an invertible
219  // permutation map (which is asserted in the inverse calculation).
220  assert(linalgOp.hasPureBufferSemantics() &&
221  "expected linalg op with buffer semantics");
222 
223  auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
224  auto iteratorTypes = linalgOp.getIteratorTypesArray();
225 
226  SmallVector<Value> allIvs;
228  rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes,
229  [&](OpBuilder &b, Location loc, ValueRange ivs,
230  ValueRange operandValuesToUse) -> scf::ValueVector {
231  assert(operandValuesToUse == linalgOp->getOperands() &&
232  "expect operands are captured and not passed by loop argument");
233  allIvs.append(ivs.begin(), ivs.end());
234  emitScalarImplementation<LoadOpTy, StoreOpTy>(b, loc, allIvs, linalgOp);
235  return scf::ValueVector{};
236  });
237  // Number of loop ops might be different from the number of ivs since some
238  // loops like affine.parallel and scf.parallel have multiple ivs.
239  SetVector<Operation *> loopSet;
240  for (Value iv : allIvs) {
241  if (!iv)
242  return failure();
243  // The induction variable is a block argument of the entry block of the
244  // loop operation.
245  BlockArgument ivVal = dyn_cast<BlockArgument>(iv);
246  if (!ivVal)
247  return failure();
248  loopSet.insert(ivVal.getOwner()->getParentOp());
249  }
250  LinalgLoops loops(loopSet.begin(), loopSet.end());
251  // Replace all index operations in the loop body.
252  replaceIndexOpsByInductionVariables(rewriter, linalgOp, loops);
253  return loops;
254 }
255 
256 namespace {
257 template <typename LoopType>
258 class LinalgRewritePattern : public RewritePattern {
259 public:
260  LinalgRewritePattern(MLIRContext *context)
261  : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
262 
263  LogicalResult matchAndRewrite(Operation *op,
264  PatternRewriter &rewriter) const override {
265  auto linalgOp = dyn_cast<LinalgOp>(op);
266  if (!isa<LinalgOp>(op) || !linalgOp.hasPureBufferSemantics()) {
267  return rewriter.notifyMatchFailure(
268  op, "expected linalg op with buffer semantics");
269  }
270  if (failed(linalgOpToLoopsImpl<LoopType>(rewriter, linalgOp)))
271  return failure();
272  rewriter.eraseOp(op);
273  return success();
274  }
275 };
276 
277 /// Local folding pattern for AffineApplyOp that we can apply greedily.
278 /// This replaces AffineApplyOp by the proper value in cases where the
279 /// associated map is trivial.
280 /// A trivial map here is defined as a map with a single result and either:
281 /// 1. Zero operand + returns a single AffineConstantExpr
282 /// 2. One operand + returns a single AffineDimExpr
283 /// 3. One operand + returns a single AffineSymbolExpr
284 //
285 /// In the first case, the AffineApplyOp is replaced by a new constant. In the
286 /// other cases, it is replaced by its unique operand.
287 struct FoldAffineOp : public RewritePattern {
288  FoldAffineOp(MLIRContext *context)
289  : RewritePattern(affine::AffineApplyOp::getOperationName(), 0, context) {}
290 
291  LogicalResult matchAndRewrite(Operation *op,
292  PatternRewriter &rewriter) const override {
293  auto affineApplyOp = cast<affine::AffineApplyOp>(op);
294  auto map = affineApplyOp.getAffineMap();
295  if (map.getNumResults() != 1 || map.getNumInputs() > 1)
296  return failure();
297 
298  AffineExpr expr = map.getResult(0);
299  if (map.getNumInputs() == 0) {
300  if (auto val = dyn_cast<AffineConstantExpr>(expr)) {
301  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, val.getValue());
302  return success();
303  }
304  return failure();
305  }
306  if (dyn_cast<AffineDimExpr>(expr) || dyn_cast<AffineSymbolExpr>(expr)) {
307  rewriter.replaceOp(op, op->getOperand(0));
308  return success();
309  }
310  return failure();
311  }
312 };
313 
314 template <typename LoopType>
315 static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
316  MLIRContext *context = enclosingOp->getContext();
317  RewritePatternSet patterns(context);
318  patterns.add<LinalgRewritePattern<LoopType>>(context);
319  memref::DimOp::getCanonicalizationPatterns(patterns, context);
320  tensor::DimOp::getCanonicalizationPatterns(patterns, context);
321  affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
322  patterns.add<FoldAffineOp>(context);
323  // Just apply the patterns greedily.
324  (void)applyPatternsAndFoldGreedily(enclosingOp, std::move(patterns));
325 }
326 
327 struct LowerToAffineLoops
328  : public impl::ConvertLinalgToAffineLoopsPassBase<LowerToAffineLoops> {
329  using impl::ConvertLinalgToAffineLoopsPassBase<
330  LowerToAffineLoops>::ConvertLinalgToAffineLoopsPassBase;
331  void getDependentDialects(DialectRegistry &registry) const override {
332  registry.insert<memref::MemRefDialect>();
333  }
334  void runOnOperation() override {
335  lowerLinalgToLoopsImpl<affine::AffineForOp>(getOperation());
336  }
337 };
338 
339 struct LowerToLoops : public impl::ConvertLinalgToLoopsPassBase<LowerToLoops> {
340  using impl::ConvertLinalgToLoopsPassBase<
341  LowerToLoops>::ConvertLinalgToLoopsPassBase;
342  void getDependentDialects(DialectRegistry &registry) const override {
343  registry.insert<memref::MemRefDialect, scf::SCFDialect>();
344  }
345  void runOnOperation() override {
346  lowerLinalgToLoopsImpl<scf::ForOp>(getOperation());
347  }
348 };
349 
350 struct LowerToParallelLoops
351  : public impl::ConvertLinalgToParallelLoopsPassBase<LowerToParallelLoops> {
352  using impl::ConvertLinalgToParallelLoopsPassBase<
353  LowerToParallelLoops>::ConvertLinalgToParallelLoopsPassBase;
354  void runOnOperation() override {
355  lowerLinalgToLoopsImpl<scf::ParallelOp>(getOperation());
356  }
357 };
358 
359 } // namespace
360 
361 /// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
362 FailureOr<LinalgLoops>
363 mlir::linalg::linalgOpToAffineLoops(RewriterBase &rewriter, LinalgOp linalgOp) {
364  return linalgOpToLoopsImpl<affine::AffineForOp>(rewriter, linalgOp);
365 }
366 
367 /// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
368 FailureOr<LinalgLoops> mlir::linalg::linalgOpToLoops(RewriterBase &rewriter,
369  LinalgOp linalgOp) {
370  return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
371 }
372 
373 /// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
374 FailureOr<LinalgLoops>
376  LinalgOp linalgOp) {
377  return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
378 }
static SmallVector< Value > makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map, ArrayRef< Value > vals)
Definition: Loops.cpp:39
static void replaceIndexOpsByInductionVariables(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< Operation * > loopOps)
Replace the index operations in the body of the loop nest by the matching induction variables.
Definition: Loops.cpp:179
static InputAndOutputIndices getInputAndOutputIndices(OpBuilder &b, Location loc, ArrayRef< Value > allIvs, SingleInputPoolingOp op)
Definition: Loops.cpp:87
static FailureOr< LinalgLoops > linalgOpToLoopsImpl(RewriterBase &rewriter, LinalgOp linalgOp)
Definition: Loops.cpp:209
static void emitScalarImplementation(OpBuilder &b, Location loc, ArrayRef< Value > allIvs, LinalgOp linalgOp)
Emits the MLIR for the scalar part of the generic op by:
Definition: Loops.cpp:128
static void inlineRegionAndEmitStore(OpBuilder &b, Location loc, OpType op, ArrayRef< Value > indexedValues, ArrayRef< SmallVector< Value >> indexing, ArrayRef< Value > outputBuffers)
Definition: Loops.cpp:59
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
Definition: AffineMap.cpp:356
unsigned getNumSymbols() const
Definition: AffineMap.cpp:385
unsigned getNumDims() const
Definition: AffineMap.cpp:381
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:394
unsigned getNumResults() const
Definition: AffineMap.cpp:389
unsigned getNumInputs() const
Definition: AffineMap.cpp:390
This class represents an argument of a Block.
Definition: Value.h:319
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:328
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:555
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:410
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & front()
Definition: Region.h:65
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:246
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
Definition: AffineOps.cpp:1433
FailureOr< LinalgLoops > linalgOpToLoops(RewriterBase &rewriter, LinalgOp linalgOp)
Emit a loop nest of scf.for with the proper body for linalgOp.
Definition: Loops.cpp:368
FailureOr< LinalgLoops > linalgOpToAffineLoops(RewriterBase &rewriter, LinalgOp linalgOp)
Emit a loop nest of affine.for with the proper body for linalgOp.
Definition: Loops.cpp:363
FailureOr< LinalgLoops > linalgOpToParallelLoops(RewriterBase &rewriter, LinalgOp linalgOp)
Emit a loop nest of scf.parallel with the proper body for linalgOp.
Definition: Loops.cpp:375
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition: SCF.h:70
Include the generated interface declarations.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
SmallVector< Value > outputs
Definition: Loops.cpp:83
SmallVector< Value > inputs
Definition: Loops.cpp:82
static void doit(OpBuilder &b, Location loc, ArrayRef< Range > loopRanges, LinalgOp linalgOp, ArrayRef< utils::IteratorType > iteratorTypes, function_ref< scf::ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilderFn, ArrayRef< linalg::ProcInfo > procInfo={})