MLIR 23.0.0git
DecomposeLinalgOps.cpp
Go to the documentation of this file.
1//===- DecomposeLinalgOps.cpp - Pattern to break up Linalg ops ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include <optional>

16using namespace mlir;
17using namespace mlir::linalg;
18
19namespace {
20
21/// Pattern to decompose a GenericOp that has more than two statements
22/// into one GenericOp with the first statement (i.e. peeled operation), and
23/// a second GenericOp with the remaining statements (i.e. residual operations).
24
25/// - The result of the first GenericOp has the same shape as the iteration
26/// space of the GenericOp. The body of the op yields as many values as the
27/// original op plus all the results of the peeled operation.
28/// - The second GenericOp has as many operands as the original operation plus
29/// all the results of the first Generic Op. It has the same number of yields as
30/// the original op.
31/// - If the result of the peeled operation was yielded by the original
32/// GenericOp the uses of the corresponding results will be replaced with the
33/// result of the first GenericOp created.
34///
35/// Example
36///
37/// ```mlir
38/// %result:2 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
39/// outs(%init0, %init1 : ...) {
40/// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ...):
41/// %0 = <s0> %b0, %b1 : ...
42/// %1 = <s1> %0, %b2 : ...
43/// linalg.yield %0, %1 : ...
44/// } -> (..., ...)
45/// return %result#0, %result#1
46/// ```
47///
48/// gets split into
49///
50/// ```mlir
51/// %init = tensor.empty ...
52/// %op0:3 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
53/// outs(%init0, %init1, %init : ...)
54/// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
55/// %0 = <s0> %b0, %b1 : ...
56/// linalg.yield %0, %..., %0 : ...
57/// } -> (..., ..., ...)
58/// %op1:2 = linalg.generic ... ins(%arg0, %arg1, %arg2, %op0#2 : ...)
59/// outs(%init0, %init1 : ...) {
60/// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
61/// %1 = <s1> %b3, %b2 : ...
62/// linalg.yield %..., %1 : ...
63/// } -> (..., ...)
64/// return %op0#0, %op1#1
65/// ```
66///
67/// After canonicalization this is expected to be
68///
69/// ```mlir
70/// %init = tensor.empty ...
71/// %op0 = linalg.generic ... ins(%arg0, %arg1, : ...)
72/// outs(%init : ...)
73/// ^bb0(%b0: ... , %b1: ... , %b2: ...):
74/// %0 = <s0> %b0, %b1 : ...
75/// linalg.yield %0 : ...
76/// } -> ...
77/// %op1 = linalg.generic ... ins(%arg2, %op0#2 : ...)
78/// outs(%init1 : ...) {
79/// ^bb0(%b0: ... , %b1: ... , %b2: ...):
80/// %1 = <s1> %b1, %b0 : ...
81/// linalg.yield %..., %1 : ...
82/// } -> ...
83/// return %op0, %op1
84/// ```
struct DecomposeLinalgOp : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  /// Split `genericOp` into a "peeled" generic op holding the first payload
  /// statement and a "residual" generic op holding the remaining statements.
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override;

private:
  /// Helper method to create a generic op for the peeled scalar operation. The
  /// created op has an empty region.
  GenericOp createPeeledGenericOp(GenericOp genericOp,
                                  PatternRewriter &rewriter) const;

  /// Helper method to create a generic op for the residual scalar operation.
  /// The created op has the same region as the original op.
  GenericOp createResidualGenericOp(GenericOp genericOp,
                                    GenericOp peeledGenericOp,
                                    PatternRewriter &rewriter) const;
};
103} // namespace
104
105/// Helper method to compute the range of a generic op.
107 GenericOp op) {
109 b.setInsertionPoint(op);
110 Location loc = op.getLoc();
111 auto allShapesSizes =
112 cast<LinalgOp>(op.getOperation()).createFlatListOfOperandDims(b, loc);
113 AffineMap map = op.getShapesToLoopsMap();
114 IRRewriter rewriter(b);
115 return affine::makeComposedFoldedMultiResultAffineApply(rewriter, loc, map,
116 allShapesSizes);
117}
118
119/// Get zero value for an element type.
120static Value getZero(OpBuilder &b, Location loc, Type elementType) {
121 assert(elementType.isIntOrIndexOrFloat() &&
122 "expected scalar type while computing zero value");
123 if (isa<IntegerType>(elementType))
124 return arith::ConstantIntOp::create(b, loc, elementType, 0);
125 if (elementType.isIndex())
126 return arith::ConstantIndexOp::create(b, loc, 0);
127 // Assume float.
128 auto floatType = cast<FloatType>(elementType);
130 b, loc, floatType, APFloat::getZero(floatType.getFloatSemantics()));
131}
132
133GenericOp
134DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
135 PatternRewriter &rewriter) const {
136 Block *body = genericOp.getBody();
137 Operation *peeledScalarOperation = &(*body->begin());
138 SmallVector<AffineMap> peeledGenericOpIndexingMaps =
139 genericOp.getIndexingMapsArray();
140
141 /// Compute the loop ranges for operation. This is the shape of the result of
142 /// the generic op for the peeled operation.
143 Location loc = genericOp.getLoc();
144 SmallVector<OpFoldResult> domain = getGenericOpLoopRange(rewriter, genericOp);
145 SmallVector<Value> newInitValues;
146 SmallVector<Type> newResultTypes;
147
148 // Add as many new results as the number of results of the peeled scalar op.
149 for (auto scalarOpResult : peeledScalarOperation->getResults()) {
150 // If the result is yielded by the original op, use the operand, indexing
151 // map and result type that correspond to the yielded value.
152
153 std::optional<unsigned> resultNumber;
154 for (auto *user : scalarOpResult.getUsers()) {
155 if (auto yieldOp = dyn_cast<YieldOp>(user)) {
156 // Find the first use of the `scalarOpResult` in the yield op.
157 for (OpOperand &yieldOperand : yieldOp->getOpOperands()) {
158 if (yieldOperand.get() == scalarOpResult) {
159 resultNumber = yieldOperand.getOperandNumber();
160 break;
161 }
162 }
163 assert(resultNumber && "unable to find use of a value in its user");
164 break;
165 }
166 }
167 if (resultNumber) {
168 newInitValues.push_back(
169 genericOp.getDpsInitOperand(*resultNumber)->get());
170 OpResult result = cast<OpResult>(genericOp.getResult(*resultNumber));
171 newResultTypes.push_back(result.getType());
172 peeledGenericOpIndexingMaps.push_back(
173 genericOp.getIndexingMapMatchingResult(result));
174 continue;
175 }
176
177 // Fall back path, use an `init_tensor` and identity indexing map.
178 AffineMap indexingMap = rewriter.getMultiDimIdentityMap(domain.size());
179 Value emptyTensor = tensor::EmptyOp::create(rewriter, loc, domain,
180 scalarOpResult.getType());
181 newInitValues.push_back(emptyTensor);
182 newResultTypes.push_back(emptyTensor.getType());
183 peeledGenericOpIndexingMaps.push_back(indexingMap);
184 }
185
186 /// Create the peeled generic op with an empty body.
187 SmallVector<Value> outsOperands = genericOp.getOutputs();
188 outsOperands.append(newInitValues.begin(), newInitValues.end());
189 SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes());
190 resultTypes.append(newResultTypes.begin(), newResultTypes.end());
191 auto indexingMapAttr =
192 rewriter.getAffineMapArrayAttr(peeledGenericOpIndexingMaps);
193 return GenericOp::create(
194 rewriter, loc, resultTypes, genericOp.getInputs(), outsOperands,
195 indexingMapAttr, genericOp.getIteratorTypes(), /*doc=*/nullptr,
196 /*libraryCall=*/nullptr, [](OpBuilder, Location, ValueRange) {});
197}
198
199GenericOp
200DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
201 GenericOp peeledGenericOp,
202 PatternRewriter &rewriter) const {
203 /// Append all results from the peeledGenericOps as `ins` operand for the
204 /// residual generic op.
205 SmallVector<Value> residualGenericOpOperands = genericOp.getInputs();
206 unsigned origNumResults = genericOp.getNumResults();
207 unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults();
208 SmallVector<Value> extraIns;
209 for (auto resultNum :
210 llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults))
211 extraIns.push_back(peeledGenericOp->getResult(resultNum));
212 residualGenericOpOperands.append(extraIns);
213
214 /// Add indexing maps for the newly added operands. Use the same map
215 /// as those used for the new results of the peeledGenericOp.
216 auto indexingMaps = llvm::map_to_vector(
217 genericOp.getDpsInputOperands(), [&](OpOperand *operand) {
218 return genericOp.getMatchingIndexingMap(operand);
219 });
220 for (auto resultNum :
221 llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) {
222 OpResult result = cast<OpResult>(peeledGenericOp.getResult(resultNum));
223 indexingMaps.push_back(
224 peeledGenericOp.getIndexingMapMatchingResult(result));
225 }
226 for (OpOperand &outOperand : genericOp.getDpsInitsMutable())
227 indexingMaps.push_back(genericOp.getMatchingIndexingMap(&outOperand));
228
229 auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps);
230 return GenericOp::create(
231 rewriter, genericOp->getLoc(), genericOp->getResultTypes(),
232 residualGenericOpOperands, genericOp.getOutputs(), indexingMapAttr,
233 genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
234 [](OpBuilder, Location, ValueRange) {});
235}
236
LogicalResult
DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
                                   PatternRewriter &rewriter) const {
  /// For now only match on operations where the iterator types are all parallel
  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "unhandled decomposition of operation "
                                       "with non-parallel iterator types");
  }
  // TODO: this could be generalized to handle `linalg.generic` with buffer
  // operands too but requires allocation for intermediates. Punt on this for
  // now.
  if (!genericOp.hasPureTensorSemantics()) {
    return rewriter.notifyMatchFailure(
        genericOp, "only operations with tensor semantics are handled");
  }

  // NOTE(review): presumably the permutation restriction keeps the remapping
  // of results/inits between the two split ops well-defined — confirm.
  if (llvm::any_of(genericOp.getDpsInitsMutable(), [&](OpOperand &outOperand) {
        return !genericOp.getMatchingIndexingMap(&outOperand).isPermutation();
      })) {
    return rewriter.notifyMatchFailure(
        genericOp, "unhandled decomposition of generic op with out operand not "
                   "accessed using a permutation");
  }

  /// If the op has only a single statement (apart from the yield), do nothing.
  // `<= 2` counts one scalar statement plus the terminator.
  Block *body = genericOp.getBody();
  if (body->getOperations().size() <= 2) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "operation has less than 3 statements");
  }

  /// Check that the peeled statement has a scalar element type.
  if (llvm::any_of(body->getOperations().begin()->getResultTypes(),
                   [](Type t) { return !t.isIntOrIndexOrFloat(); })) {
    return rewriter.notifyMatchFailure(
        &(*body->getOperations().begin()),
        "expected return type to be only int, index or float");
  }

  // Both ops are created with empty bodies; they are populated below by
  // splicing statements out of the original op.
  GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter);
  GenericOp residualGenericOp =
      createResidualGenericOp(genericOp, peeledGenericOp, rewriter);

  /// Move the first statement of the original operation into the body of the
  /// generic op for the peeled operation.
  Block *peeledGenericOpBody = peeledGenericOp.getBody();
  Block *residualGenericOpBody = residualGenericOp.getBody();
  assert(peeledGenericOpBody->empty() && residualGenericOpBody->empty() &&
         "expected split generic ops to have empty region");
  peeledGenericOpBody->getOperations().splice(
      peeledGenericOpBody->begin(), body->getOperations(), body->begin());
  // All remaining statements (including the original terminator) move into
  // the residual op.
  residualGenericOpBody->getOperations().splice(residualGenericOpBody->begin(),
                                                body->getOperations());

  Operation *peeledScalarOperation = &(*peeledGenericOpBody->begin());
  auto *yieldOp = residualGenericOpBody->getTerminator();
  {
    // Yield all the result of the peeled scalar operation.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(peeledGenericOpBody);
    SmallVector<Value> yieldedVals;
    for (auto origYield : yieldOp->getOperands()) {
      if (origYield.getDefiningOp() == peeledScalarOperation) {
        // Yielded value is computed by the peeled statement; yield it as-is.
        yieldedVals.push_back(origYield);
      } else {
        // Yielded value is not computed by the peeled statement; yield a
        // placeholder zero (the corresponding result is not used as a
        // replacement below and is expected to be cleaned up later).
        // Do not materialize any new ops inside of the decomposed LinalgOp,
        // as that would trigger another application of the rewrite pattern
        // (infinite loop).
        OpBuilder::InsertionGuard g(rewriter);
        rewriter.setInsertionPoint(peeledGenericOp);
        yieldedVals.push_back(
            getZero(rewriter, genericOp.getLoc(), origYield.getType()));
      }
    }
    // Also yield every result of the peeled scalar statement (appended after
    // the original yields, matching the extra results of the peeled op).
    yieldedVals.append(
        llvm::map_to_vector(peeledScalarOperation->getResults(),
                            [](OpResult opr) -> Value { return opr; }));
    YieldOp::create(rewriter, genericOp.getLoc(), yieldedVals);
  }

  /// In the split operations, replace block arguments uses that refer to
  /// original operation to the block arguments of the newly created operation.
  unsigned origNumInputs = genericOp.getNumDpsInputs();
  for (const auto &inputBlockArg :
       llvm::enumerate(genericOp.getBody()->getArguments())) {
    Value residualOpReplacementArg =
        residualGenericOpBody->getArgument(inputBlockArg.index());
    rewriter.replaceUsesWithIf(
        inputBlockArg.value(), residualOpReplacementArg, [&](OpOperand &use) {
          return use.getOwner()->getBlock() == residualGenericOpBody;
        });

    Value peeledOpReplacementArg =
        peeledGenericOpBody->getArgument(inputBlockArg.index());
    rewriter.replaceUsesWithIf(
        inputBlockArg.value(), peeledOpReplacementArg, [&](OpOperand &use) {
          return use.getOwner()->getBlock() == peeledGenericOpBody;
        });
  }

  /// Before fixing up the residual operation, track what values are yielded. If
  /// any of those are from the peeled scalar operation, the uses of the
  /// corresponding result have to be remapped to result of the generic op for
  /// the peeled operation.
  SmallVector<Value> replacements;
  for (const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) {
    OpResult opr = dyn_cast<OpResult>(yieldValue.value());
    if (!opr || opr.getOwner() != peeledScalarOperation)
      replacements.push_back(residualGenericOp.getResult(yieldValue.index()));
    else
      replacements.push_back(peeledGenericOp->getResult(yieldValue.index()));
  }

  /// Update all uses of the peeled scalar operation results in the residual op
  /// to the newly added arguments.
  {
    SmallVector<Value> scalarReplacements;
    unsigned peeledScalarOpNumResults = peeledScalarOperation->getNumResults();
    scalarReplacements.reserve(peeledScalarOpNumResults);
    // The extra ins of the residual op (and thus its trailing block
    // arguments after the original inputs) correspond, in order, to the
    // results of the peeled scalar statement.
    for (auto num : llvm::seq<unsigned>(0, peeledScalarOpNumResults))
      scalarReplacements.push_back(
          residualGenericOpBody->getArgument(num + origNumInputs));
    bool allUsesReplaced = false;
    rewriter.replaceOpUsesWithinBlock(peeledScalarOperation, scalarReplacements,
                                      residualGenericOpBody, &allUsesReplaced);
    assert(!allUsesReplaced &&
           "peeled scalar operation is erased when it wasnt expected to be");
  }

  // Replace the original operation
  rewriter.replaceOp(genericOp, replacements);
  return success();
}
371
373 RewritePatternSet &patterns, bool removeDeadArgsAndResults) {
374 patterns.insert<DecomposeLinalgOp>(patterns.getContext());
375 // Add the patterns to clean up the dead operands and results.
376 if (removeDeadArgsAndResults)
378}
return success()
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static SmallVector< OpFoldResult > getGenericOpLoopRange(OpBuilder &b, GenericOp op)
Helper method to compute the range of a generic op.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
OpListType & getOperations()
Definition Block.h:147
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
iterator begin()
Definition Block.h:153
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition Builders.cpp:391
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:322
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:438
Operation * getOwner() const
Returns the operation that owns this result.
Definition Value.h:463
result_range getResults()
Definition Operation.h:441
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void replaceOpUsesWithinBlock(Operation *op, ValueRange newValues, Block *block, bool *allUsesReplaced=nullptr)
Find uses of from within block and replace them with to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:56
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
Definition ArithOps.cpp:334
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:262
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns, bool removeDeadArgsAndResults=true)
Populate patterns for splitting a LinalgOp with multiple statements within its payload into multiple ...
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...