MLIR 22.0.0git
Loops.cpp
Go to the documentation of this file.
1//===- Loops.cpp - conversion from Linalg named and generic ops to loops --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
19#include "mlir/IR/AffineExpr.h"
20#include "mlir/IR/AffineMap.h"
21#include "mlir/IR/IRMapping.h"
22#include "mlir/Support/LLVM.h"
26#include "llvm/ADT/TypeSwitch.h"
27
28namespace mlir {
29#define GEN_PASS_DEF_CONVERTLINALGTOAFFINELOOPSPASS
30#define GEN_PASS_DEF_CONVERTLINALGTOLOOPSPASS
31#define GEN_PASS_DEF_CONVERTLINALGTOPARALLELLOOPSPASS
32#include "mlir/Dialect/Linalg/Passes.h.inc"
33} // namespace mlir
34
35using namespace mlir;
36using namespace mlir::linalg;
37
39 AffineMap map,
40 ArrayRef<Value> vals) {
41 if (map.isEmpty())
42 return {};
43
44 assert(map.getNumInputs() == vals.size());
46 res.reserve(map.getNumResults());
47 auto dims = map.getNumDims();
48 for (auto e : map.getResults()) {
49 auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e);
50 SmallVector<Value> operands(vals);
51 affine::canonicalizeMapAndOperands(&exprMap, &operands);
52 res.push_back(affine::AffineApplyOp::create(b, loc, exprMap, operands));
53 }
54 return res;
55}
56
57template <typename LoadOpTy, typename StoreOpTy, typename OpType>
58static void inlineRegionAndEmitStore(OpBuilder &b, Location loc, OpType op,
59 ArrayRef<Value> indexedValues,
61 ArrayRef<Value> outputBuffers) {
62 auto &block = op->getRegion(0).front();
63 IRMapping map;
64 map.map(block.getArguments(), indexedValues);
65 for (auto &op : block.without_terminator()) {
66 auto *newOp = b.clone(op, map);
67 map.map(op.getResults(), newOp->getResults());
68 }
69
70 Operation *terminator = block.getTerminator();
71 for (OpOperand &operand : terminator->getOpOperands()) {
72 Value toStore = map.lookupOrDefault(operand.get());
73 StoreOpTy::create(b, loc, toStore,
74 outputBuffers[operand.getOperandNumber()],
75 indexing[operand.getOperandNumber()]);
76 }
77}
78
79// Returns a pair that contains input indices and output indices of a
80// SingleInputPoolingOp `op`.
85template <typename SingleInputPoolingOp>
88 SingleInputPoolingOp op) {
89 auto mapsRange = op.getIndexingMapsArray();
90 auto maps = llvm::to_vector<8>(
91 llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
93 makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
94 makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
95}
96
97/// Emits the MLIR for the scalar part of the generic op by:
98/// 1. Emitting load ops for each input and output view in order. This is
99/// achieved by applying the appropriate input or output map to the
100/// enclosing induction variables.
101/// 2. Emitting a call to `op.fun()` that takes as arguments the scalars
102/// from point 1. above.
103/// 3. Emitting store ops to store the results of 2. to the output
104/// views.
105///
106/// An example output may resemble:
107///
108/// ```
109/// scf.for %i = %c0 to %0 step %c1 {
110/// scf.for %j = %c0 to %1 step %c1 {
111/// scf.for %k = %c0 to %4 step %c1 {
112/// %11 = load %arg0[%i, %j] :
113/// memref<?x?xf32, stride_specification>
114/// %12 = load %arg1[%i, %j, %k] :
115/// memref<?x?x?xf32, stride_specification>
116/// %13 = load %arg2[%i, %k, %j] :
117/// memref<?x?x?xf32, stride_specification>
118/// %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
119/// store %14#0, %arg1[%i, %j, %k] :
120/// memref<?x?x?Xf32, stride_specification>
121/// store %14#1, %arg2[%i, %k, %j] :
122/// memref<?x?x?Xf32, stride_specification>
123/// }
124/// }
125/// }
126/// ```
127template <typename LoadOpTy, typename StoreOpTy>
130 LinalgOp linalgOp) {
131 assert(linalgOp.hasPureBufferSemantics() &&
132 "expected linalg op with buffer semantics");
133 SmallVector<Value> indexedValues;
134 indexedValues.reserve(linalgOp->getNumOperands());
135
136 auto allIvsPlusDims = SmallVector<Value>(allIvs);
138 // TODO: Avoid the loads if the corresponding argument of the
139 // region has no uses.
140 // 1.a. Emit load from input operand or for scalars access the operand itself.
141 for (OpOperand *inputOperand : linalgOp.getDpsInputOperands()) {
142 if (linalgOp.isScalar(inputOperand)) {
143 indexedValues.push_back(inputOperand->get());
144 continue;
146 auto indexing = makeCanonicalAffineApplies(
147 b, loc, linalgOp.getMatchingIndexingMap(inputOperand), allIvsPlusDims);
148 indexedValues.push_back(
149 LoadOpTy::create(b, loc, inputOperand->get(), indexing));
151 // 1.b. Emit load from output views.
152 for (OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) {
154 b, loc, linalgOp.getMatchingIndexingMap(&outputOperand),
155 allIvsPlusDims);
156 indexedValues.push_back(
157 LoadOpTy::create(b, loc, outputOperand.get(), indexing));
158 }
159
160 // TODO: When a region inliner exists, use it.
161 // 2. Inline region, currently only works for a single basic block.
162 // 3. Emit store.
164 SmallVector<Value> outputBuffers;
165 for (OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) {
166 if (!isa<MemRefType>(outputOperand.get().getType()))
167 continue;
168 indexing.push_back(makeCanonicalAffineApplies(
169 b, loc, linalgOp.getMatchingIndexingMap(&outputOperand),
170 allIvsPlusDims));
171 outputBuffers.push_back(outputOperand.get());
172 }
173 inlineRegionAndEmitStore<LoadOpTy, StoreOpTy>(b, loc, linalgOp, indexedValues,
174 indexing, outputBuffers);
175}
176
177/// Replace the index operations in the body of the loop nest by the matching
178/// induction variables.
180 LinalgOp linalgOp,
181 ArrayRef<Operation *> loopOps) {
182 // Extract the induction variables of the loop nest from outer to inner.
183 SmallVector<Value> allIvs;
184 for (Operation *loopOp : loopOps) {
186 .Case([&](scf::ParallelOp parallelOp) {
187 allIvs.append(parallelOp.getInductionVars());
188 })
189 .Case([&](scf::ForOp forOp) {
190 allIvs.push_back(forOp.getInductionVar());
191 })
192 .Case([&](affine::AffineForOp affineForOp) {
193 allIvs.push_back(affineForOp.getInductionVar());
194 })
195 .DefaultUnreachable("unexpected op");
197 assert(linalgOp.getNumLoops() == allIvs.size() &&
198 "expected the number of loops and induction variables to match");
199 // Replace the index operations in the body of the innermost loop op.
200 if (!loopOps.empty()) {
201 auto loopOp = cast<LoopLikeOpInterface>(loopOps.back());
202 for (Region *r : loopOp.getLoopRegions())
203 for (IndexOp indexOp : llvm::make_early_inc_range(r->getOps<IndexOp>()))
204 rewriter.replaceOp(indexOp, allIvs[indexOp.getDim()]);
205 }
207
208template <typename LoopTy>
209static FailureOr<LinalgLoops> linalgOpToLoopsImpl(RewriterBase &rewriter,
210 LinalgOp linalgOp) {
211 using LoadOpTy =
212 std::conditional_t<std::is_same<LoopTy, affine::AffineForOp>::value,
213 affine::AffineLoadOp, memref::LoadOp>;
214 using StoreOpTy =
215 std::conditional_t<std::is_same<LoopTy, affine::AffineForOp>::value,
216 affine::AffineStoreOp, memref::StoreOp>;
218 // The flattened loopToOperandRangesMaps is expected to be an invertible
219 // permutation map (which is asserted in the inverse calculation).
220 assert(linalgOp.hasPureBufferSemantics() &&
221 "expected linalg op with buffer semantics");
223 auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
224 auto iteratorTypes = linalgOp.getIteratorTypesArray();
225
226 SmallVector<Value> allIvs;
228 rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes,
229 [&](OpBuilder &b, Location loc, ValueRange ivs,
230 ValueRange operandValuesToUse) -> scf::ValueVector {
231 assert(operandValuesToUse == linalgOp->getOperands() &&
232 "expect operands are captured and not passed by loop argument");
233 allIvs.append(ivs.begin(), ivs.end());
234 emitScalarImplementation<LoadOpTy, StoreOpTy>(b, loc, allIvs, linalgOp);
235 return scf::ValueVector{};
236 });
237 // Number of loop ops might be different from the number of ivs since some
238 // loops like affine.parallel and scf.parallel have multiple ivs.
240 for (Value iv : allIvs) {
241 if (!iv)
242 return failure();
243 // The induction variable is a block argument of the entry block of the
244 // loop operation.
245 BlockArgument ivVal = dyn_cast<BlockArgument>(iv);
246 if (!ivVal)
247 return failure();
248 loopSet.insert(ivVal.getOwner()->getParentOp());
249 }
250 LinalgLoops loops(loopSet.begin(), loopSet.end());
251 // Replace all index operations in the loop body.
252 replaceIndexOpsByInductionVariables(rewriter, linalgOp, loops);
253 return loops;
254}
255
256namespace {
257template <typename LoopType>
258class LinalgRewritePattern : public RewritePattern {
259public:
260 LinalgRewritePattern(MLIRContext *context)
261 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
262
263 LogicalResult matchAndRewrite(Operation *op,
264 PatternRewriter &rewriter) const override {
265 auto linalgOp = dyn_cast<LinalgOp>(op);
266 if (!isa<LinalgOp>(op) || !linalgOp.hasPureBufferSemantics()) {
267 return rewriter.notifyMatchFailure(
268 op, "expected linalg op with buffer semantics");
269 }
270 if (failed(linalgOpToLoopsImpl<LoopType>(rewriter, linalgOp)))
271 return failure();
272 rewriter.eraseOp(op);
273 return success();
277/// Local folding pattern for AffineApplyOp that we can apply greedily.
278/// This replaces AffineApplyOp by the proper value in cases where the
279/// associated map is trivial.
280/// A trivial map here is defined as a map with a single result and either:
281/// 1. Zero operand + returns a single AffineConstantExpr
282/// 2. One operand + returns a single AffineDimExpr
283/// 3. One operand + returns a single AffineSymbolExpr
284//
285/// In the first case, the AffineApplyOp is replaced by a new constant. In the
286/// other cases, it is replaced by its unique operand.
287struct FoldAffineOp : public RewritePattern {
288 FoldAffineOp(MLIRContext *context)
289 : RewritePattern(affine::AffineApplyOp::getOperationName(), 0, context) {}
290
291 LogicalResult matchAndRewrite(Operation *op,
292 PatternRewriter &rewriter) const override {
293 auto affineApplyOp = cast<affine::AffineApplyOp>(op);
294 auto map = affineApplyOp.getAffineMap();
295 if (map.getNumResults() != 1 || map.getNumInputs() > 1)
296 return failure();
297
298 AffineExpr expr = map.getResult(0);
299 if (map.getNumInputs() == 0) {
300 if (auto val = dyn_cast<AffineConstantExpr>(expr)) {
301 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, val.getValue());
302 return success();
303 }
304 return failure();
305 }
306 if (isa<AffineDimExpr, AffineSymbolExpr>(expr)) {
307 rewriter.replaceOp(op, op->getOperand(0));
308 return success();
309 }
310 return failure();
311 }
312};
313
314template <typename LoopType>
315static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
316 MLIRContext *context = enclosingOp->getContext();
318 patterns.add<LinalgRewritePattern<LoopType>>(context);
319 memref::DimOp::getCanonicalizationPatterns(patterns, context);
320 tensor::DimOp::getCanonicalizationPatterns(patterns, context);
321 affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
322 patterns.add<FoldAffineOp>(context);
323 // Just apply the patterns greedily.
324 (void)applyPatternsGreedily(enclosingOp, std::move(patterns));
326
327struct LowerToAffineLoops
328 : public impl::ConvertLinalgToAffineLoopsPassBase<LowerToAffineLoops> {
330 LowerToAffineLoops>::ConvertLinalgToAffineLoopsPassBase;
331 void getDependentDialects(DialectRegistry &registry) const override {
332 registry.insert<memref::MemRefDialect>();
333 }
334 void runOnOperation() override {
335 lowerLinalgToLoopsImpl<affine::AffineForOp>(getOperation());
336 }
337};
338
339struct LowerToLoops : public impl::ConvertLinalgToLoopsPassBase<LowerToLoops> {
340 using impl::ConvertLinalgToLoopsPassBase<
341 LowerToLoops>::ConvertLinalgToLoopsPassBase;
342 void getDependentDialects(DialectRegistry &registry) const override {
343 registry.insert<memref::MemRefDialect, scf::SCFDialect>();
344 }
345 void runOnOperation() override {
346 lowerLinalgToLoopsImpl<scf::ForOp>(getOperation());
347 }
348};
349
350struct LowerToParallelLoops
351 : public impl::ConvertLinalgToParallelLoopsPassBase<LowerToParallelLoops> {
352 using impl::ConvertLinalgToParallelLoopsPassBase<
353 LowerToParallelLoops>::ConvertLinalgToParallelLoopsPassBase;
354 void runOnOperation() override {
355 lowerLinalgToLoopsImpl<scf::ParallelOp>(getOperation());
356 }
357};
358
359} // namespace
360
361/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
362FailureOr<LinalgLoops>
364 return linalgOpToLoopsImpl<affine::AffineForOp>(rewriter, linalgOp);
365}
366
367/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
368FailureOr<LinalgLoops> mlir::linalg::linalgOpToLoops(RewriterBase &rewriter,
369 LinalgOp linalgOp) {
370 return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
371}
372
373/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
374FailureOr<LinalgLoops>
376 LinalgOp linalgOp) {
377 return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
378}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static SmallVector< Value > makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map, ArrayRef< Value > vals)
Definition Loops.cpp:38
static void replaceIndexOpsByInductionVariables(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< Operation * > loopOps)
Replace the index operations in the body of the loop nest by the matching induction variables.
Definition Loops.cpp:179
static InputAndOutputIndices getInputAndOutputIndices(OpBuilder &b, Location loc, ArrayRef< Value > allIvs, SingleInputPoolingOp op)
Definition Loops.cpp:87
static void inlineRegionAndEmitStore(OpBuilder &b, Location loc, OpType op, ArrayRef< Value > indexedValues, ArrayRef< SmallVector< Value > > indexing, ArrayRef< Value > outputBuffers)
Definition Loops.cpp:58
static FailureOr< LinalgLoops > linalgOpToLoopsImpl(RewriterBase &rewriter, LinalgOp linalgOp)
Definition Loops.cpp:209
static void emitScalarImplementation(OpBuilder &b, Location loc, ArrayRef< Value > allIvs, LinalgOp linalgOp)
Emits the MLIR for the scalar part of the generic op by:
Definition Loops.cpp:128
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
unsigned getNumInputs() const
This class represents an argument of a Block.
Definition Value.h:309
Block * getOwner() const
Returns the block that owns this argument.
Definition Value.h:318
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
This class represents an operand of an operation.
Definition Value.h:257
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:350
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:383
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:113
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
SmallVector< Operation *, 4 > LinalgLoops
Definition Transforms.h:517
FailureOr< LinalgLoops > linalgOpToLoops(RewriterBase &rewriter, LinalgOp linalgOp)
Emit a loop nest of scf.for with the proper body for linalgOp.
Definition Loops.cpp:368
FailureOr< LinalgLoops > linalgOpToAffineLoops(RewriterBase &rewriter, LinalgOp linalgOp)
Emit a loop nest of affine.for with the proper body for linalgOp.
Definition Loops.cpp:363
FailureOr< LinalgLoops > linalgOpToParallelLoops(RewriterBase &rewriter, LinalgOp linalgOp)
Emit a loop nest of scf.parallel with the proper body for linalgOp.
Definition Loops.cpp:375
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition SCF.h:64
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
const FrozenRewritePatternSet & patterns
SmallVector< Value > outputs
Definition Loops.cpp:83
SmallVector< Value > inputs
Definition Loops.cpp:82
static void doit(OpBuilder &b, Location loc, ArrayRef< Range > loopRanges, LinalgOp linalgOp, ArrayRef< utils::IteratorType > iteratorTypes, function_ref< scf::ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilderFn, ArrayRef< linalg::ProcInfo > procInfo={})