MLIR  19.0.0git
DecomposeLinalgOps.cpp
Go to the documentation of this file.
1 //===- DecomposeLinalgOps.cpp - Pattern to break up Linalg ops ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
13 #include <optional>
14 
15 using namespace mlir;
16 using namespace mlir::linalg;
17 
18 namespace {
19 
20 /// Pattern to decompose a GenericOp that has more than two statements
21 /// into one GenericOp with the first statement (i.e. peeled operation), and
22 /// a second GenericOp with the remaining statements (i.e. residual operations).
23 
24 /// - The result of the first GenericOp has the same shape as the iteration
25 /// space of the GenericOp. The body of the op yields as many values as the
26 /// original op plus all the results of the peeled operation.
27 /// - The second GenericOp has as many operands as the original operation plus
28 /// all the results of the first Generic Op. It has the same number of yields as
29 /// the original op.
30 /// - If the result of the peeled operation was yielded by the original
31 /// GenericOp the uses of the corresponding results will be replaced with the
32 /// result of the first GenericOp created.
33 ///
34 /// Example
35 ///
36 /// ```mlir
37 /// %result:2 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
38 /// outs(%init0, %init1 : ...) {
39 /// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ...):
40 /// %0 = <s0> %b0, %b1 : ...
41 /// %1 = <s1> %0, %b2 : ...
42 /// linalg.yield %0, %1 : ...
43 /// } -> (..., ...)
44 /// return %result#0, %result#1
45 /// ```
46 ///
47 /// gets split into
48 ///
49 /// ```mlir
50 /// %init = tensor.empty ...
51 /// %op0:3 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
52 /// outs(%init0, %init1, %init : ...)
53 /// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
54 /// %0 = <s0> %b0, %b1 : ...
55 /// linalg.yield %0, %..., %0 : ...
56 /// } -> (..., ..., ...)
57 /// %op1:2 = linalg.generic ... ins(%arg0, %arg1, %arg2, %op0#2 : ...)
58 /// outs(%init0, %init1 : ...) {
59 /// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
60 /// %1 = <s1> %b3, %b2 : ...
61 /// linalg.yield %..., %1 : ...
62 /// } -> (..., ...)
63 /// return %op0#0, %op1#1
64 /// ```
65 ///
66 /// After canonicalization this is expected to be
67 ///
68 /// ```mlir
69 /// %init = tensor.empty ...
70 /// %op0 = linalg.generic ... ins(%arg0, %arg1, : ...)
71 /// outs(%init : ...)
72 /// ^bb0(%b0: ... , %b1: ... , %b2: ...):
73 /// %0 = <s0> %b0, %b1 : ...
74 /// linalg.yield %0 : ...
75 /// } -> ...
76 /// %op1 = linalg.generic ... ins(%arg2, %op0#2 : ...)
77 /// outs(%init1 : ...) {
78 /// ^bb0(%b0: ... , %b1: ... , %b2: ...):
79 /// %1 = <s1> %b1, %b0 : ...
80 /// linalg.yield %..., %1 : ...
81 /// } -> ...
82 /// return %op0, %op1
83 /// ```
84 struct DecomposeLinalgOp : public OpRewritePattern<GenericOp> {
86 
87  LogicalResult matchAndRewrite(GenericOp genericOp,
88  PatternRewriter &rewriter) const override;
89 
90 private:
91  /// Helper method to create a generic op for the peeled scalar operation. The
92  /// created op has an empty region.
93  GenericOp createPeeledGenericOp(GenericOp genericOp,
94  PatternRewriter &rewriter) const;
95 
96  /// Helper method to create a generic op for the residual scalar operation.
97  /// The created op has the same region as the original op.
98  GenericOp createResidualGenericOp(GenericOp genericOp,
99  GenericOp peeledGenericOp,
100  PatternRewriter &rewriter) const;
101 };
102 } // namespace
103 
104 /// Helper method to compute the range of a generic op.
106  GenericOp op) {
108  b.setInsertionPoint(op);
109  Location loc = op.getLoc();
110  auto allShapesSizes =
111  cast<LinalgOp>(op.getOperation()).createFlatListOfOperandDims(b, loc);
112  AffineMap map = op.getShapesToLoopsMap();
113  IRRewriter rewriter(b);
114  return affine::makeComposedFoldedMultiResultAffineApply(rewriter, loc, map,
115  allShapesSizes);
116 }
117 
118 /// Helper method to permute the list of `values` based on the `map`.
120  AffineMap map) {
121  assert(map.isPermutation());
122  SmallVector<OpFoldResult> permutedValues(values.size());
123  for (const auto &position :
124  llvm::enumerate(llvm::map_range(map.getResults(), [](AffineExpr expr) {
125  return cast<AffineDimExpr>(expr).getPosition();
126  })))
127  permutedValues[position.value()] = values[position.index()];
128  return permutedValues;
129 }
130 
131 /// Get zero value for an element type.
132 static Value getZero(OpBuilder &b, Location loc, Type elementType) {
133  assert(elementType.isIntOrIndexOrFloat() &&
134  "expected scalar type while computing zero value");
135  if (isa<IntegerType>(elementType))
136  return b.create<arith::ConstantIntOp>(loc, 0, elementType);
137  if (elementType.isIndex())
138  return b.create<arith::ConstantIndexOp>(loc, 0);
139  // Assume float.
140  auto floatType = cast<FloatType>(elementType);
141  return b.create<arith::ConstantFloatOp>(
142  loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
143 }
144 
145 GenericOp
146 DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
147  PatternRewriter &rewriter) const {
148  Block *body = genericOp.getBody();
149  Operation *peeledScalarOperation = &(*body->begin());
150  SmallVector<AffineMap> peeledGenericOpIndexingMaps =
151  genericOp.getIndexingMapsArray();
152 
153  /// Compute the loop ranges for operation. This is the shape of the result of
154  /// the generic op for the peeled operation.
155  Location loc = genericOp.getLoc();
156  SmallVector<OpFoldResult> domain = getGenericOpLoopRange(rewriter, genericOp);
157  SmallVector<Value> newInitValues;
158  SmallVector<Type> newResultTypes;
159 
160  // Add as many new results as the number of results of the peeled scalar op.
161  for (auto scalarOpResult : peeledScalarOperation->getResults()) {
162  // If the result is yielded by the original op, use the operand, indexing
163  // map and result type that correspond to the yielded value.
164 
165  std::optional<unsigned> resultNumber;
166  for (auto *user : scalarOpResult.getUsers()) {
167  if (auto yieldOp = dyn_cast<YieldOp>(user)) {
168  // Find the first use of the `scalarOpResult` in the yield op.
169  for (OpOperand &yieldOperand : yieldOp->getOpOperands()) {
170  if (yieldOperand.get() == scalarOpResult) {
171  resultNumber = yieldOperand.getOperandNumber();
172  break;
173  }
174  }
175  assert(resultNumber && "unable to find use of a value in its user");
176  break;
177  }
178  }
179  if (resultNumber) {
180  newInitValues.push_back(
181  genericOp.getDpsInitOperand(*resultNumber)->get());
182  OpResult result = cast<OpResult>(genericOp.getResult(*resultNumber));
183  newResultTypes.push_back(result.getType());
184  peeledGenericOpIndexingMaps.push_back(
185  genericOp.getIndexingMapMatchingResult(result));
186  continue;
187  }
188 
189  // Fall back path, use an `init_tensor` and identity indexing map.
190  AffineMap indexingMap = rewriter.getMultiDimIdentityMap(domain.size());
191  Value emptyTensor =
192  rewriter.create<tensor::EmptyOp>(loc, domain, scalarOpResult.getType());
193  newInitValues.push_back(emptyTensor);
194  newResultTypes.push_back(emptyTensor.getType());
195  peeledGenericOpIndexingMaps.push_back(indexingMap);
196  }
197 
198  /// Create the peeled generic op with an empty body.
199  SmallVector<Value> outsOperands = genericOp.getOutputs();
200  outsOperands.append(newInitValues.begin(), newInitValues.end());
201  SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes());
202  resultTypes.append(newResultTypes.begin(), newResultTypes.end());
203  auto indexingMapAttr =
204  rewriter.getAffineMapArrayAttr(peeledGenericOpIndexingMaps);
205  return rewriter.create<GenericOp>(
206  loc, resultTypes, genericOp.getInputs(), outsOperands, indexingMapAttr,
207  genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
208  [](OpBuilder, Location, ValueRange) {});
209 }
210 
211 GenericOp
212 DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
213  GenericOp peeledGenericOp,
214  PatternRewriter &rewriter) const {
215  /// Append all results from the peeledGenericOps as `ins` operand for the
216  /// residual generic op.
217  SmallVector<Value> residualGenericOpOperands = genericOp.getInputs();
218  unsigned origNumResults = genericOp.getNumResults();
219  unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults();
220  SmallVector<Value> extraIns;
221  for (auto resultNum :
222  llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults))
223  extraIns.push_back(peeledGenericOp->getResult(resultNum));
224  residualGenericOpOperands.append(extraIns);
225 
226  /// Add indexing maps for the newly added operands. Use the same map
227  /// as those used for the new results of the peeledGenericOp.
228  auto indexingMaps = llvm::to_vector(
229  llvm::map_range(genericOp.getDpsInputOperands(), [&](OpOperand *operand) {
230  return genericOp.getMatchingIndexingMap(operand);
231  }));
232  for (auto resultNum :
233  llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) {
234  OpResult result = cast<OpResult>(peeledGenericOp.getResult(resultNum));
235  indexingMaps.push_back(
236  peeledGenericOp.getIndexingMapMatchingResult(result));
237  }
238  for (OpOperand &outOperand : genericOp.getDpsInitsMutable())
239  indexingMaps.push_back(genericOp.getMatchingIndexingMap(&outOperand));
240 
241  auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps);
242  return rewriter.create<GenericOp>(
243  genericOp->getLoc(), genericOp->getResultTypes(),
244  residualGenericOpOperands, genericOp.getOutputs(), indexingMapAttr,
245  genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
246  [](OpBuilder, Location, ValueRange) {});
247 }
248 
250 DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
251  PatternRewriter &rewriter) const {
252  /// For now only match on operations where the iterator types are all parallel
253  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) {
254  return rewriter.notifyMatchFailure(genericOp,
255  "unhandled decomposition of operation "
256  "with non-parallel iterator types");
257  }
258  // TODO: this could be generalized to handle `linalg.generic` with buffer
259  // operands too but requires allocation for intermediates. Punt on this for
260  // now.
261  if (!genericOp.hasPureTensorSemantics()) {
262  return rewriter.notifyMatchFailure(
263  genericOp, "only operations with tensor semantics are handled");
264  }
265 
266  if (llvm::any_of(genericOp.getDpsInitsMutable(), [&](OpOperand &outOperand) {
267  return !genericOp.getMatchingIndexingMap(&outOperand).isPermutation();
268  })) {
269  return rewriter.notifyMatchFailure(
270  genericOp, "unhandled decomposition of generic op with out operand not "
271  "accessed using a permutation");
272  }
273 
274  /// If the op has only a single statement (apart from the yield), do nothing.
275  Block *body = genericOp.getBody();
276  if (body->getOperations().size() <= 2) {
277  return rewriter.notifyMatchFailure(genericOp,
278  "operation has less than 3 statements");
279  }
280 
281  /// Check that the peeled statement has a scalar element type.
282  if (llvm::any_of(body->getOperations().begin()->getResultTypes(),
283  [](Type t) { return !t.isIntOrIndexOrFloat(); })) {
284  return rewriter.notifyMatchFailure(
285  &(*body->getOperations().begin()),
286  "expected return type to be only int, index or float");
287  }
288 
289  GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter);
290  GenericOp residualGenericOp =
291  createResidualGenericOp(genericOp, peeledGenericOp, rewriter);
292 
293  /// Move the first statement of the original operation into the body of the
294  /// generic op for the peeled operation.
295  Block *peeledGenericOpBody = peeledGenericOp.getBody();
296  Block *residualGenericOpBody = residualGenericOp.getBody();
297  assert(peeledGenericOpBody->empty() && residualGenericOpBody->empty() &&
298  "expected split generic ops to have empty region");
299  peeledGenericOpBody->getOperations().splice(
300  peeledGenericOpBody->begin(), body->getOperations(), body->begin());
301  residualGenericOpBody->getOperations().splice(residualGenericOpBody->begin(),
302  body->getOperations());
303 
304  Operation *peeledScalarOperation = &(*peeledGenericOpBody->begin());
305  auto *yieldOp = residualGenericOpBody->getTerminator();
306  {
307  // Yield all the result of the peeled scalar operation.
308  OpBuilder::InsertionGuard g(rewriter);
309  rewriter.setInsertionPointToEnd(peeledGenericOpBody);
310  SmallVector<Value> yieldedVals;
311  for (auto origYield : yieldOp->getOperands()) {
312  if (origYield.getDefiningOp() == peeledScalarOperation) {
313  yieldedVals.push_back(origYield);
314  } else {
315  // Do not materialize any new ops inside of the decomposed LinalgOp,
316  // as that would trigger another application of the rewrite pattern
317  // (infinite loop).
318  OpBuilder::InsertionGuard g(rewriter);
319  rewriter.setInsertionPoint(peeledGenericOp);
320  yieldedVals.push_back(
321  getZero(rewriter, genericOp.getLoc(), origYield.getType()));
322  }
323  }
324  yieldedVals.append(llvm::to_vector(
325  llvm::map_range(peeledScalarOperation->getResults(),
326  [](OpResult opr) -> Value { return opr; })));
327  rewriter.create<YieldOp>(genericOp.getLoc(), yieldedVals);
328  }
329 
330  /// In the split operations, replace block arguments uses that refer to
331  /// original operation to the block arguments of the newly created operation.
332  unsigned origNumInputs = genericOp.getNumDpsInputs();
333  for (const auto &inputBlockArg :
334  llvm::enumerate(genericOp.getBody()->getArguments())) {
335  Value residualOpReplacementArg =
336  residualGenericOpBody->getArgument(inputBlockArg.index());
337  rewriter.replaceUsesWithIf(
338  inputBlockArg.value(), residualOpReplacementArg, [&](OpOperand &use) {
339  return use.getOwner()->getBlock() == residualGenericOpBody;
340  });
341 
342  Value peeledOpReplacementArg =
343  peeledGenericOpBody->getArgument(inputBlockArg.index());
344  rewriter.replaceUsesWithIf(
345  inputBlockArg.value(), peeledOpReplacementArg, [&](OpOperand &use) {
346  return use.getOwner()->getBlock() == peeledGenericOpBody;
347  });
348  }
349 
350  /// Before fixing up the residual operation, track what values are yielded. If
351  /// any of those are from the peeled scalar operation, the uses of the
352  /// corresponding result have to be remapped to result of the generic op for
353  /// the peeled operation.
354  SmallVector<Value> replacements;
355  for (const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) {
356  OpResult opr = dyn_cast<OpResult>(yieldValue.value());
357  if (!opr || opr.getOwner() != peeledScalarOperation)
358  replacements.push_back(residualGenericOp.getResult(yieldValue.index()));
359  else
360  replacements.push_back(peeledGenericOp->getResult(yieldValue.index()));
361  }
362 
363  /// Update all uses of the peeled scalar operation results in the residual op
364  /// to the newly added arguments.
365  {
366  SmallVector<Value> scalarReplacements;
367  unsigned peeledScalarOpNumResults = peeledScalarOperation->getNumResults();
368  scalarReplacements.reserve(peeledScalarOpNumResults);
369  for (auto num : llvm::seq<unsigned>(0, peeledScalarOpNumResults))
370  scalarReplacements.push_back(
371  residualGenericOpBody->getArgument(num + origNumInputs));
372  bool allUsesReplaced = false;
373  rewriter.replaceOpUsesWithinBlock(peeledScalarOperation, scalarReplacements,
374  residualGenericOpBody, &allUsesReplaced);
375  assert(!allUsesReplaced &&
376  "peeled scalar operation is erased when it wasnt expected to be");
377  }
378 
379  // Replace the original operation
380  rewriter.replaceOp(genericOp, replacements);
381  return success();
382 }
383 
385  RewritePatternSet &patterns, bool removeDeadArgsAndResults) {
386  patterns.insert<DecomposeLinalgOp>(patterns.getContext());
387  // Add the patterns to clean up the dead operands and results.
388  if (removeDeadArgsAndResults)
390 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
SmallVector< OpFoldResult > permuteValues(ArrayRef< OpFoldResult > values, AffineMap map)
Helper method to permute the list of values based on the map.
static SmallVector< OpFoldResult > getGenericOpLoopRange(OpBuilder &b, GenericOp op)
Helper method to compute the range of a generic op.
Base type for affine expression.
Definition: AffineExpr.h:69
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:47
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:391
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:609
Block represents an ordered list of Operations.
Definition: Block.h:30
bool empty()
Definition: Block.h:145
BlockArgument getArgument(unsigned i)
Definition: Block.h:126
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
OpListType & getOperations()
Definition: Block.h:134
iterator begin()
Definition: Block.h:140
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:394
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:325
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:438
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents an operand of an operation.
Definition: Value.h:263
This is a value defined by a result of an operation.
Definition: Value.h:453
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:462
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:930
MLIRContext * getContext() const
Definition: PatternMatch.h:822
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceOpUsesWithinBlock(Operation *op, ValueRange newValues, Block *block, bool *allUsesReplaced=nullptr)
Find uses of from within block and replace them with to.
Definition: PatternMatch.h:689
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:56
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:123
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1235
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns, bool removeDeadArgsAndResults=true)
Populate patterns for splitting a LinalgOp with multiple statements within its payload into multiple ...
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358