MLIR  19.0.0git
Loops.cpp
Go to the documentation of this file.
1 //===- Loops.cpp - conversion from Linalg named and generic ops to loops --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
20 #include "mlir/IR/AffineExpr.h"
21 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/IRMapping.h"
23 #include "mlir/Support/LLVM.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 
29 namespace mlir {
30 #define GEN_PASS_DEF_CONVERTLINALGTOAFFINELOOPSPASS
31 #define GEN_PASS_DEF_CONVERTLINALGTOLOOPSPASS
32 #define GEN_PASS_DEF_CONVERTLINALGTOPARALLELLOOPSPASS
33 #include "mlir/Dialect/Linalg/Passes.h.inc"
34 } // namespace mlir
35 
36 using namespace mlir;
37 using namespace mlir::linalg;
38 
40  AffineMap map,
41  ArrayRef<Value> vals) {
42  if (map.isEmpty())
43  return {};
44 
45  assert(map.getNumInputs() == vals.size());
47  res.reserve(map.getNumResults());
48  auto dims = map.getNumDims();
49  for (auto e : map.getResults()) {
50  auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e);
51  SmallVector<Value> operands(vals.begin(), vals.end());
52  affine::canonicalizeMapAndOperands(&exprMap, &operands);
53  res.push_back(b.create<affine::AffineApplyOp>(loc, exprMap, operands));
54  }
55  return res;
56 }
57 
58 template <typename LoadOpTy, typename StoreOpTy, typename OpType>
59 static void inlineRegionAndEmitStore(OpBuilder &b, Location loc, OpType op,
60  ArrayRef<Value> indexedValues,
61  ArrayRef<SmallVector<Value>> indexing,
62  ArrayRef<Value> outputBuffers) {
63  auto &block = op->getRegion(0).front();
64  IRMapping map;
65  map.map(block.getArguments(), indexedValues);
66  for (auto &op : block.without_terminator()) {
67  auto *newOp = b.clone(op, map);
68  map.map(op.getResults(), newOp->getResults());
69  }
70 
71  Operation *terminator = block.getTerminator();
72  for (OpOperand &operand : terminator->getOpOperands()) {
73  Value toStore = map.lookupOrDefault(operand.get());
74  b.create<StoreOpTy>(loc, toStore, outputBuffers[operand.getOperandNumber()],
75  indexing[operand.getOperandNumber()]);
76  }
77 }
78 
79 // Returns a pair that contains input indices and output indices of a
80 // SingleInputPoolingOp `op`.
84 };
85 template <typename SingleInputPoolingOp>
88  SingleInputPoolingOp op) {
89  auto mapsRange = op.getIndexingMapsArray();
90  auto maps = llvm::to_vector<8>(
91  llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
92  return InputAndOutputIndices{
93  makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
94  makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
95 }
96 
97 /// Emits the MLIR for the scalar part of the generic op by:
98 /// 1. Emitting load ops for each input and output view in order. This is
99 /// achieved by applying the appropriate input or output map to the
100 /// enclosing induction variables.
101 /// 2. Inlining the op's body region with its block arguments bound to the
102 /// scalars from point 1. above (shown as `call @foo` in the example).
103 /// 3. Emitting store ops to store the results of 2. to the output
104 /// views.
105 ///
106 /// An example output may resemble:
107 ///
108 /// ```
109 /// scf.for %i = %c0 to %0 step %c1 {
110 /// scf.for %j = %c0 to %1 step %c1 {
111 /// scf.for %k = %c0 to %4 step %c1 {
112 /// %11 = load %arg0[%i, %j] :
113 /// memref<?x?xf32, stride_specification>
114 /// %12 = load %arg1[%i, %j, %k] :
115 /// memref<?x?x?xf32, stride_specification>
116 /// %13 = load %arg2[%i, %k, %j] :
117 /// memref<?x?x?xf32, stride_specification>
118 /// %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
119 /// store %14#0, %arg1[%i, %j, %k] :
120 /// memref<?x?x?xf32, stride_specification>
121 /// store %14#1, %arg2[%i, %k, %j] :
122 /// memref<?x?x?xf32, stride_specification>
123 /// }
124 /// }
125 /// }
126 /// ```
127 template <typename LoadOpTy, typename StoreOpTy>
129  ArrayRef<Value> allIvs,
130  LinalgOp linalgOp) {
131  assert(linalgOp.hasPureBufferSemantics() &&
132  "expected linalg op with buffer semantics");
133  SmallVector<Value> indexedValues;
134  indexedValues.reserve(linalgOp->getNumOperands());
135 
136  auto allIvsPlusDims = SmallVector<Value>(allIvs.begin(), allIvs.end());
137 
138  // TODO: Avoid the loads if the corresponding argument of the
139  // region has no uses.
140  // 1.a. Emit load from input operand or for scalars access the operand itself.
141  for (OpOperand *inputOperand : linalgOp.getDpsInputOperands()) {
142  if (linalgOp.isScalar(inputOperand)) {
143  indexedValues.push_back(inputOperand->get());
144  continue;
145  }
146  auto indexing = makeCanonicalAffineApplies(
147  b, loc, linalgOp.getMatchingIndexingMap(inputOperand), allIvsPlusDims);
148  indexedValues.push_back(
149  b.create<LoadOpTy>(loc, inputOperand->get(), indexing));
150  }
151  // 1.b. Emit load from output views.
152  for (OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) {
154  b, loc, linalgOp.getMatchingIndexingMap(&outputOperand),
155  allIvsPlusDims);
156  indexedValues.push_back(
157  b.create<LoadOpTy>(loc, outputOperand.get(), indexing));
158  }
159 
160  // TODO: When a region inliner exists, use it.
161  // 2. Inline region, currently only works for a single basic block.
162  // 3. Emit store.
163  SmallVector<SmallVector<Value>, 8> indexing;
164  SmallVector<Value> outputBuffers;
165  for (OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) {
166  if (!isa<MemRefType>(outputOperand.get().getType()))
167  continue;
168  indexing.push_back(makeCanonicalAffineApplies(
169  b, loc, linalgOp.getMatchingIndexingMap(&outputOperand),
170  allIvsPlusDims));
171  outputBuffers.push_back(outputOperand.get());
172  }
173  inlineRegionAndEmitStore<LoadOpTy, StoreOpTy>(b, loc, linalgOp, indexedValues,
174  indexing, outputBuffers);
175 }
176 
177 /// Replace the index operations in the body of the loop nest by the matching
178 /// induction variables.
180  LinalgOp linalgOp,
181  ArrayRef<Operation *> loopOps) {
182  // Extract the induction variables of the loop nest from outer to inner.
183  SmallVector<Value> allIvs;
184  for (Operation *loopOp : loopOps) {
186  .Case([&](scf::ParallelOp parallelOp) {
187  allIvs.append(parallelOp.getInductionVars().begin(),
188  parallelOp.getInductionVars().end());
189  })
190  .Case([&](scf::ForOp forOp) {
191  allIvs.push_back(forOp.getInductionVar());
192  })
193  .Case([&](affine::AffineForOp affineForOp) {
194  allIvs.push_back(affineForOp.getInductionVar());
195  })
196  .Default([&](Operation *op) { assert(false && "unexpected op"); });
197  }
198  assert(linalgOp.getNumLoops() == allIvs.size() &&
199  "expected the number of loops and induction variables to match");
200  // Replace the index operations in the body of the innermost loop op.
201  if (!loopOps.empty()) {
202  auto loopOp = cast<LoopLikeOpInterface>(loopOps.back());
203  for (Region *r : loopOp.getLoopRegions())
204  for (IndexOp indexOp : llvm::make_early_inc_range(r->getOps<IndexOp>()))
205  rewriter.replaceOp(indexOp, allIvs[indexOp.getDim()]);
206  }
207 }
208 
209 template <typename LoopTy>
211  LinalgOp linalgOp) {
212  using LoadOpTy =
213  std::conditional_t<std::is_same<LoopTy, affine::AffineForOp>::value,
214  affine::AffineLoadOp, memref::LoadOp>;
215  using StoreOpTy =
216  std::conditional_t<std::is_same<LoopTy, affine::AffineForOp>::value,
217  affine::AffineStoreOp, memref::StoreOp>;
218 
219  // The flattened loopToOperandRangesMaps is expected to be an invertible
220  // permutation map (which is asserted in the inverse calculation).
221  assert(linalgOp.hasPureBufferSemantics() &&
222  "expected linalg op with buffer semantics");
223 
224  auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
225  auto iteratorTypes = linalgOp.getIteratorTypesArray();
226 
227  SmallVector<Value> allIvs;
229  rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes,
230  [&](OpBuilder &b, Location loc, ValueRange ivs,
231  ValueRange operandValuesToUse) -> scf::ValueVector {
232  assert(operandValuesToUse == linalgOp->getOperands() &&
233  "expect operands are captured and not passed by loop argument");
234  allIvs.append(ivs.begin(), ivs.end());
235  emitScalarImplementation<LoadOpTy, StoreOpTy>(b, loc, allIvs, linalgOp);
236  return scf::ValueVector{};
237  });
238  // Number of loop ops might be different from the number of ivs since some
239  // loops like affine.parallel and scf.parallel have multiple ivs.
240  SetVector<Operation *> loopSet;
241  for (Value iv : allIvs) {
242  if (!iv)
243  return failure();
244  // The induction variable is a block argument of the entry block of the
245  // loop operation.
246  BlockArgument ivVal = dyn_cast<BlockArgument>(iv);
247  if (!ivVal)
248  return failure();
249  loopSet.insert(ivVal.getOwner()->getParentOp());
250  }
251  LinalgLoops loops(loopSet.begin(), loopSet.end());
252  // Replace all index operations in the loop body.
253  replaceIndexOpsByInductionVariables(rewriter, linalgOp, loops);
254  return loops;
255 }
256 
257 namespace {
258 template <typename LoopType>
259 class LinalgRewritePattern : public RewritePattern {
260 public:
261  LinalgRewritePattern(MLIRContext *context)
262  : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
263 
264  LogicalResult matchAndRewrite(Operation *op,
265  PatternRewriter &rewriter) const override {
266  auto linalgOp = dyn_cast<LinalgOp>(op);
267  if (!isa<LinalgOp>(op) || !linalgOp.hasPureBufferSemantics()) {
268  return rewriter.notifyMatchFailure(
269  op, "expected linalg op with buffer semantics");
270  }
271  if (failed(linalgOpToLoopsImpl<LoopType>(rewriter, linalgOp)))
272  return failure();
273  rewriter.eraseOp(op);
274  return success();
275  }
276 };
277 
278 /// Local folding pattern for AffineApplyOp that we can apply greedily.
279 /// This replaces AffineApplyOp by the proper value in cases where the
280 /// associated map is trivial.
281 /// A trivial map here is defined as a map with a single result and either:
282 /// 1. Zero operand + returns a single AffineConstantExpr
283 /// 2. One operand + returns a single AffineDimExpr
284 /// 3. One operand + returns a single AffineSymbolExpr
285 //
286 /// In the first case, the AffineApplyOp is replaced by a new constant. In the
287 /// other cases, it is replaced by its unique operand.
288 struct FoldAffineOp : public RewritePattern {
289  FoldAffineOp(MLIRContext *context)
290  : RewritePattern(affine::AffineApplyOp::getOperationName(), 0, context) {}
291 
292  LogicalResult matchAndRewrite(Operation *op,
293  PatternRewriter &rewriter) const override {
294  auto affineApplyOp = cast<affine::AffineApplyOp>(op);
295  auto map = affineApplyOp.getAffineMap();
296  if (map.getNumResults() != 1 || map.getNumInputs() > 1)
297  return failure();
298 
299  AffineExpr expr = map.getResult(0);
300  if (map.getNumInputs() == 0) {
301  if (auto val = dyn_cast<AffineConstantExpr>(expr)) {
302  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, val.getValue());
303  return success();
304  }
305  return failure();
306  }
307  if (dyn_cast<AffineDimExpr>(expr) || dyn_cast<AffineSymbolExpr>(expr)) {
308  rewriter.replaceOp(op, op->getOperand(0));
309  return success();
310  }
311  return failure();
312  }
313 };
314 
315 template <typename LoopType>
316 static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
317  MLIRContext *context = enclosingOp->getContext();
318  RewritePatternSet patterns(context);
319  patterns.add<LinalgRewritePattern<LoopType>>(context);
320  memref::DimOp::getCanonicalizationPatterns(patterns, context);
321  tensor::DimOp::getCanonicalizationPatterns(patterns, context);
322  affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
323  patterns.add<FoldAffineOp>(context);
324  // Just apply the patterns greedily.
325  (void)applyPatternsAndFoldGreedily(enclosingOp, std::move(patterns));
326 }
327 
328 struct LowerToAffineLoops
329  : public impl::ConvertLinalgToAffineLoopsPassBase<LowerToAffineLoops> {
330  using impl::ConvertLinalgToAffineLoopsPassBase<
331  LowerToAffineLoops>::ConvertLinalgToAffineLoopsPassBase;
332  void getDependentDialects(DialectRegistry &registry) const override {
333  registry.insert<memref::MemRefDialect>();
334  }
335  void runOnOperation() override {
336  lowerLinalgToLoopsImpl<affine::AffineForOp>(getOperation());
337  }
338 };
339 
340 struct LowerToLoops : public impl::ConvertLinalgToLoopsPassBase<LowerToLoops> {
341  using impl::ConvertLinalgToLoopsPassBase<
342  LowerToLoops>::ConvertLinalgToLoopsPassBase;
343  void getDependentDialects(DialectRegistry &registry) const override {
344  registry.insert<memref::MemRefDialect, scf::SCFDialect>();
345  }
346  void runOnOperation() override {
347  lowerLinalgToLoopsImpl<scf::ForOp>(getOperation());
348  }
349 };
350 
351 struct LowerToParallelLoops
352  : public impl::ConvertLinalgToParallelLoopsPassBase<LowerToParallelLoops> {
353  using impl::ConvertLinalgToParallelLoopsPassBase<
354  LowerToParallelLoops>::ConvertLinalgToParallelLoopsPassBase;
355  void runOnOperation() override {
356  lowerLinalgToLoopsImpl<scf::ParallelOp>(getOperation());
357  }
358 };
359 
360 } // namespace
361 
362 /// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
364 mlir::linalg::linalgOpToAffineLoops(RewriterBase &rewriter, LinalgOp linalgOp) {
365  return linalgOpToLoopsImpl<affine::AffineForOp>(rewriter, linalgOp);
366 }
367 
368 /// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
370  LinalgOp linalgOp) {
371  return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
372 }
373 
374 /// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
377  LinalgOp linalgOp) {
378  return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
379 }
static SmallVector< Value > makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map, ArrayRef< Value > vals)
Definition: Loops.cpp:39
static void replaceIndexOpsByInductionVariables(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< Operation * > loopOps)
Replace the index operations in the body of the loop nest by the matching induction variables.
Definition: Loops.cpp:179
static InputAndOutputIndices getInputAndOutputIndices(OpBuilder &b, Location loc, ArrayRef< Value > allIvs, SingleInputPoolingOp op)
Definition: Loops.cpp:87
static FailureOr< LinalgLoops > linalgOpToLoopsImpl(RewriterBase &rewriter, LinalgOp linalgOp)
Definition: Loops.cpp:210
static void emitScalarImplementation(OpBuilder &b, Location loc, ArrayRef< Value > allIvs, LinalgOp linalgOp)
Emits the MLIR for the scalar part of the generic op by:
Definition: Loops.cpp:128
static void inlineRegionAndEmitStore(OpBuilder &b, Location loc, OpType op, ArrayRef< Value > indexedValues, ArrayRef< SmallVector< Value >> indexing, ArrayRef< Value > outputBuffers)
Definition: Loops.cpp:59
Base type for affine expression.
Definition: AffineExpr.h:69
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:47
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
Definition: AffineMap.cpp:353
unsigned getNumSymbols() const
Definition: AffineMap.cpp:384
unsigned getNumDims() const
Definition: AffineMap.cpp:380
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:393
unsigned getNumResults() const
Definition: AffineMap.cpp:388
unsigned getNumInputs() const
Definition: AffineMap.cpp:389
This class represents an argument of a Block.
Definition: Value.h:315
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:324
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:553
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents an operand of an operation.
Definition: Value.h:263
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:410
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:748
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & front()
Definition: Region.h:65
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:245
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:685
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces `op` with a newly created op of type `OpTy`, built without verification.
Definition: PatternMatch.h:537
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
Definition: AffineOps.cpp:1426
FailureOr< LinalgLoops > linalgOpToLoops(RewriterBase &rewriter, LinalgOp linalgOp)
Emit a loop nest of scf.for with the proper body for linalgOp.
Definition: Loops.cpp:369
FailureOr< LinalgLoops > linalgOpToAffineLoops(RewriterBase &rewriter, LinalgOp linalgOp)
Emit a loop nest of affine.for with the proper body for linalgOp.
Definition: Loops.cpp:364
FailureOr< LinalgLoops > linalgOpToParallelLoops(RewriterBase &rewriter, LinalgOp linalgOp)
Emit a loop nest of scf.parallel with the proper body for linalgOp.
Definition: Loops.cpp:376
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition: SCF.h:70
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
SmallVector< Value > outputs
Definition: Loops.cpp:83
SmallVector< Value > inputs
Definition: Loops.cpp:82
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
static void doit(OpBuilder &b, Location loc, ArrayRef< Range > loopRanges, LinalgOp linalgOp, ArrayRef< utils::IteratorType > iteratorTypes, function_ref< scf::ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilderFn, ArrayRef< linalg::ProcInfo > procInfo={})