MLIR 23.0.0git
DecomposeLinalgOps.cpp
Go to the documentation of this file.
//===- DecomposeLinalgOps.cpp - Pattern to break up Linalg ops ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

19namespace {
20
21/// Pattern to decompose a GenericOp that has more than two statements
22/// into one GenericOp with the first statement (i.e. peeled operation), and
23/// a second GenericOp with the remaining statements (i.e. residual operations).
24
25/// - The result of the first GenericOp has the same shape as the iteration
26/// space of the GenericOp. The body of the op yields as many values as the
27/// original op plus all the results of the peeled operation.
28/// - The second GenericOp has as many operands as the original operation plus
29/// all the results of the first Generic Op. It has the same number of yields as
30/// the original op.
31/// - If the result of the peeled operation was yielded by the original
32/// GenericOp the uses of the corresponding results will be replaced with the
33/// result of the first GenericOp created.
34///
35/// Example
36///
37/// ```mlir
38/// %result:2 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
39/// outs(%init0, %init1 : ...) {
40/// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ...):
41/// %0 = <s0> %b0, %b1 : ...
42/// %1 = <s1> %0, %b2 : ...
43/// linalg.yield %0, %1 : ...
44/// } -> (..., ...)
45/// return %result#0, %result#1
46/// ```
47///
48/// gets split into
49///
50/// ```mlir
51/// %init = tensor.empty ...
52/// %op0:3 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
53/// outs(%init0, %init1, %init : ...)
54/// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
55/// %0 = <s0> %b0, %b1 : ...
56/// linalg.yield %0, %..., %0 : ...
57/// } -> (..., ..., ...)
58/// %op1:2 = linalg.generic ... ins(%arg0, %arg1, %arg2, %op0#2 : ...)
59/// outs(%init0, %init1 : ...) {
60/// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
61/// %1 = <s1> %b3, %b2 : ...
62/// linalg.yield %..., %1 : ...
63/// } -> (..., ...)
64/// return %op0#0, %op1#1
65/// ```
66///
67/// After canonicalization this is expected to be
68///
69/// ```mlir
70/// %init = tensor.empty ...
71/// %op0 = linalg.generic ... ins(%arg0, %arg1, : ...)
72/// outs(%init : ...)
73/// ^bb0(%b0: ... , %b1: ... , %b2: ...):
74/// %0 = <s0> %b0, %b1 : ...
75/// linalg.yield %0 : ...
76/// } -> ...
77/// %op1 = linalg.generic ... ins(%arg2, %op0#2 : ...)
78/// outs(%init1 : ...) {
79/// ^bb0(%b0: ... , %b1: ... , %b2: ...):
80/// %1 = <s1> %b1, %b0 : ...
81/// linalg.yield %..., %1 : ...
82/// } -> ...
83/// return %op0, %op1
84/// ```
85struct DecomposeLinalgOp : public OpRewritePattern<GenericOp> {
86 using OpRewritePattern<GenericOp>::OpRewritePattern;
87
88 LogicalResult matchAndRewrite(GenericOp genericOp,
89 PatternRewriter &rewriter) const override;
90
91private:
92 /// Helper method to create a generic op for the peeled scalar operation. The
93 /// created op has an empty region.
94 GenericOp createPeeledGenericOp(GenericOp genericOp,
95 PatternRewriter &rewriter) const;
96
97 /// Helper method to create a generic op for the residual scalar operation.
98 /// The created op has the same region as the original op.
99 GenericOp createResidualGenericOp(GenericOp genericOp,
100 GenericOp peeledGenericOp,
101 PatternRewriter &rewriter) const;
102};
103} // namespace
104
105/// Helper method to compute the range of a generic op.
107 GenericOp op) {
109 b.setInsertionPoint(op);
110 Location loc = op.getLoc();
111 auto allShapesSizes =
112 cast<LinalgOp>(op.getOperation()).createFlatListOfOperandDims(b, loc);
113 AffineMap map = op.getShapesToLoopsMap();
114 IRRewriter rewriter(b);
115 return affine::makeComposedFoldedMultiResultAffineApply(rewriter, loc, map,
116 allShapesSizes);
117}
118
119/// Helper method to permute the list of `values` based on the `map`.
121 AffineMap map) {
122 assert(map.isPermutation());
123 SmallVector<OpFoldResult> permutedValues(values.size());
124 for (const auto &position :
125 llvm::enumerate(llvm::map_range(map.getResults(), [](AffineExpr expr) {
126 return cast<AffineDimExpr>(expr).getPosition();
127 })))
128 permutedValues[position.value()] = values[position.index()];
129 return permutedValues;
130}
131
132/// Get zero value for an element type.
133static Value getZero(OpBuilder &b, Location loc, Type elementType) {
134 assert(elementType.isIntOrIndexOrFloat() &&
135 "expected scalar type while computing zero value");
136 if (isa<IntegerType>(elementType))
137 return arith::ConstantIntOp::create(b, loc, elementType, 0);
138 if (elementType.isIndex())
139 return arith::ConstantIndexOp::create(b, loc, 0);
140 // Assume float.
141 auto floatType = cast<FloatType>(elementType);
143 b, loc, floatType, APFloat::getZero(floatType.getFloatSemantics()));
144}
145
146GenericOp
147DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
148 PatternRewriter &rewriter) const {
149 Block *body = genericOp.getBody();
150 Operation *peeledScalarOperation = &(*body->begin());
151 SmallVector<AffineMap> peeledGenericOpIndexingMaps =
152 genericOp.getIndexingMapsArray();
153
154 /// Compute the loop ranges for operation. This is the shape of the result of
155 /// the generic op for the peeled operation.
156 Location loc = genericOp.getLoc();
157 SmallVector<OpFoldResult> domain = getGenericOpLoopRange(rewriter, genericOp);
158 SmallVector<Value> newInitValues;
159 SmallVector<Type> newResultTypes;
160
161 // Add as many new results as the number of results of the peeled scalar op.
162 for (auto scalarOpResult : peeledScalarOperation->getResults()) {
163 // If the result is yielded by the original op, use the operand, indexing
164 // map and result type that correspond to the yielded value.
165
166 std::optional<unsigned> resultNumber;
167 for (auto *user : scalarOpResult.getUsers()) {
168 if (auto yieldOp = dyn_cast<YieldOp>(user)) {
169 // Find the first use of the `scalarOpResult` in the yield op.
170 for (OpOperand &yieldOperand : yieldOp->getOpOperands()) {
171 if (yieldOperand.get() == scalarOpResult) {
172 resultNumber = yieldOperand.getOperandNumber();
173 break;
174 }
175 }
176 assert(resultNumber && "unable to find use of a value in its user");
177 break;
178 }
179 }
180 if (resultNumber) {
181 newInitValues.push_back(
182 genericOp.getDpsInitOperand(*resultNumber)->get());
183 OpResult result = cast<OpResult>(genericOp.getResult(*resultNumber));
184 newResultTypes.push_back(result.getType());
185 peeledGenericOpIndexingMaps.push_back(
186 genericOp.getIndexingMapMatchingResult(result));
187 continue;
188 }
189
190 // Fall back path, use an `init_tensor` and identity indexing map.
191 AffineMap indexingMap = rewriter.getMultiDimIdentityMap(domain.size());
192 Value emptyTensor = tensor::EmptyOp::create(rewriter, loc, domain,
193 scalarOpResult.getType());
194 newInitValues.push_back(emptyTensor);
195 newResultTypes.push_back(emptyTensor.getType());
196 peeledGenericOpIndexingMaps.push_back(indexingMap);
197 }
198
199 /// Create the peeled generic op with an empty body.
200 SmallVector<Value> outsOperands = genericOp.getOutputs();
201 outsOperands.append(newInitValues.begin(), newInitValues.end());
202 SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes());
203 resultTypes.append(newResultTypes.begin(), newResultTypes.end());
204 auto indexingMapAttr =
205 rewriter.getAffineMapArrayAttr(peeledGenericOpIndexingMaps);
206 return GenericOp::create(
207 rewriter, loc, resultTypes, genericOp.getInputs(), outsOperands,
208 indexingMapAttr, genericOp.getIteratorTypes(), /*doc=*/nullptr,
209 /*libraryCall=*/nullptr, [](OpBuilder, Location, ValueRange) {});
210}
211
212GenericOp
213DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
214 GenericOp peeledGenericOp,
215 PatternRewriter &rewriter) const {
216 /// Append all results from the peeledGenericOps as `ins` operand for the
217 /// residual generic op.
218 SmallVector<Value> residualGenericOpOperands = genericOp.getInputs();
219 unsigned origNumResults = genericOp.getNumResults();
220 unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults();
221 SmallVector<Value> extraIns;
222 for (auto resultNum :
223 llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults))
224 extraIns.push_back(peeledGenericOp->getResult(resultNum));
225 residualGenericOpOperands.append(extraIns);
226
227 /// Add indexing maps for the newly added operands. Use the same map
228 /// as those used for the new results of the peeledGenericOp.
229 auto indexingMaps = llvm::map_to_vector(
230 genericOp.getDpsInputOperands(), [&](OpOperand *operand) {
231 return genericOp.getMatchingIndexingMap(operand);
232 });
233 for (auto resultNum :
234 llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) {
235 OpResult result = cast<OpResult>(peeledGenericOp.getResult(resultNum));
236 indexingMaps.push_back(
237 peeledGenericOp.getIndexingMapMatchingResult(result));
238 }
239 for (OpOperand &outOperand : genericOp.getDpsInitsMutable())
240 indexingMaps.push_back(genericOp.getMatchingIndexingMap(&outOperand));
241
242 auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps);
243 return GenericOp::create(
244 rewriter, genericOp->getLoc(), genericOp->getResultTypes(),
245 residualGenericOpOperands, genericOp.getOutputs(), indexingMapAttr,
246 genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
247 [](OpBuilder, Location, ValueRange) {});
248}
249
250LogicalResult
251DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
252 PatternRewriter &rewriter) const {
253 /// For now only match on operations where the iterator types are all parallel
254 if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) {
255 return rewriter.notifyMatchFailure(genericOp,
256 "unhandled decomposition of operation "
257 "with non-parallel iterator types");
258 }
259 // TODO: this could be generalized to handle `linalg.generic` with buffer
260 // operands too but requires allocation for intermediates. Punt on this for
261 // now.
262 if (!genericOp.hasPureTensorSemantics()) {
263 return rewriter.notifyMatchFailure(
264 genericOp, "only operations with tensor semantics are handled");
265 }
266
267 if (llvm::any_of(genericOp.getDpsInitsMutable(), [&](OpOperand &outOperand) {
268 return !genericOp.getMatchingIndexingMap(&outOperand).isPermutation();
269 })) {
270 return rewriter.notifyMatchFailure(
271 genericOp, "unhandled decomposition of generic op with out operand not "
272 "accessed using a permutation");
273 }
274
275 /// If the op has only a single statement (apart from the yield), do nothing.
276 Block *body = genericOp.getBody();
277 if (body->getOperations().size() <= 2) {
278 return rewriter.notifyMatchFailure(genericOp,
279 "operation has less than 3 statements");
280 }
281
282 /// Check that the peeled statement has a scalar element type.
283 if (llvm::any_of(body->getOperations().begin()->getResultTypes(),
284 [](Type t) { return !t.isIntOrIndexOrFloat(); })) {
285 return rewriter.notifyMatchFailure(
286 &(*body->getOperations().begin()),
287 "expected return type to be only int, index or float");
288 }
289
290 GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter);
291 GenericOp residualGenericOp =
292 createResidualGenericOp(genericOp, peeledGenericOp, rewriter);
293
294 /// Move the first statement of the original operation into the body of the
295 /// generic op for the peeled operation.
296 Block *peeledGenericOpBody = peeledGenericOp.getBody();
297 Block *residualGenericOpBody = residualGenericOp.getBody();
298 assert(peeledGenericOpBody->empty() && residualGenericOpBody->empty() &&
299 "expected split generic ops to have empty region");
300 peeledGenericOpBody->getOperations().splice(
301 peeledGenericOpBody->begin(), body->getOperations(), body->begin());
302 residualGenericOpBody->getOperations().splice(residualGenericOpBody->begin(),
303 body->getOperations());
304
305 Operation *peeledScalarOperation = &(*peeledGenericOpBody->begin());
306 auto *yieldOp = residualGenericOpBody->getTerminator();
307 {
308 // Yield all the result of the peeled scalar operation.
309 OpBuilder::InsertionGuard g(rewriter);
310 rewriter.setInsertionPointToEnd(peeledGenericOpBody);
311 SmallVector<Value> yieldedVals;
312 for (auto origYield : yieldOp->getOperands()) {
313 if (origYield.getDefiningOp() == peeledScalarOperation) {
314 yieldedVals.push_back(origYield);
315 } else {
316 // Do not materialize any new ops inside of the decomposed LinalgOp,
317 // as that would trigger another application of the rewrite pattern
318 // (infinite loop).
319 OpBuilder::InsertionGuard g(rewriter);
320 rewriter.setInsertionPoint(peeledGenericOp);
321 yieldedVals.push_back(
322 getZero(rewriter, genericOp.getLoc(), origYield.getType()));
323 }
324 }
325 yieldedVals.append(
326 llvm::map_to_vector(peeledScalarOperation->getResults(),
327 [](OpResult opr) -> Value { return opr; }));
328 YieldOp::create(rewriter, genericOp.getLoc(), yieldedVals);
329 }
330
331 /// In the split operations, replace block arguments uses that refer to
332 /// original operation to the block arguments of the newly created operation.
333 unsigned origNumInputs = genericOp.getNumDpsInputs();
334 for (const auto &inputBlockArg :
335 llvm::enumerate(genericOp.getBody()->getArguments())) {
336 Value residualOpReplacementArg =
337 residualGenericOpBody->getArgument(inputBlockArg.index());
338 rewriter.replaceUsesWithIf(
339 inputBlockArg.value(), residualOpReplacementArg, [&](OpOperand &use) {
340 return use.getOwner()->getBlock() == residualGenericOpBody;
341 });
342
343 Value peeledOpReplacementArg =
344 peeledGenericOpBody->getArgument(inputBlockArg.index());
345 rewriter.replaceUsesWithIf(
346 inputBlockArg.value(), peeledOpReplacementArg, [&](OpOperand &use) {
347 return use.getOwner()->getBlock() == peeledGenericOpBody;
348 });
349 }
350
351 /// Before fixing up the residual operation, track what values are yielded. If
352 /// any of those are from the peeled scalar operation, the uses of the
353 /// corresponding result have to be remapped to result of the generic op for
354 /// the peeled operation.
355 SmallVector<Value> replacements;
356 for (const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) {
357 OpResult opr = dyn_cast<OpResult>(yieldValue.value());
358 if (!opr || opr.getOwner() != peeledScalarOperation)
359 replacements.push_back(residualGenericOp.getResult(yieldValue.index()));
360 else
361 replacements.push_back(peeledGenericOp->getResult(yieldValue.index()));
362 }
363
364 /// Update all uses of the peeled scalar operation results in the residual op
365 /// to the newly added arguments.
366 {
367 SmallVector<Value> scalarReplacements;
368 unsigned peeledScalarOpNumResults = peeledScalarOperation->getNumResults();
369 scalarReplacements.reserve(peeledScalarOpNumResults);
370 for (auto num : llvm::seq<unsigned>(0, peeledScalarOpNumResults))
371 scalarReplacements.push_back(
372 residualGenericOpBody->getArgument(num + origNumInputs));
373 bool allUsesReplaced = false;
374 rewriter.replaceOpUsesWithinBlock(peeledScalarOperation, scalarReplacements,
375 residualGenericOpBody, &allUsesReplaced);
376 assert(!allUsesReplaced &&
377 "peeled scalar operation is erased when it wasnt expected to be");
378 }
379
380 // Replace the original operation
381 rewriter.replaceOp(genericOp, replacements);
382 return success();
383}
384
386 RewritePatternSet &patterns, bool removeDeadArgsAndResults) {
387 patterns.insert<DecomposeLinalgOp>(patterns.getContext());
388 // Add the patterns to clean up the dead operands and results.
389 if (removeDeadArgsAndResults)
391}
return success()
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static SmallVector< OpFoldResult > getGenericOpLoopRange(OpBuilder &b, GenericOp op)
Helper method to compute the range of a generic op.
SmallVector< OpFoldResult > permuteValues(ArrayRef< OpFoldResult > values, AffineMap map)
Helper method to permute the list of values based on the map.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
ArrayRef< AffineExpr > getResults() const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
OpListType & getOperations()
Definition Block.h:147
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
iterator begin()
Definition Block.h:153
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition Builders.cpp:387
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:318
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
Operation * getOwner() const
Returns the operation that owns this result.
Definition Value.h:466
result_range getResults()
Definition Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void replaceOpUsesWithinBlock(Operation *op, ValueRange newValues, Block *block, bool *allUsesReplaced=nullptr)
Find uses of from within block and replace them with to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:54
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:120
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
Definition ArithOps.cpp:333
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:362
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:261
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns, bool removeDeadArgsAndResults=true)
Populate patterns for splitting a LinalgOp with multiple statements within its payload into multiple ...
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...