MLIR  19.0.0git
DecomposeLinalgOps.cpp
Go to the documentation of this file.
1 //===- DecomposeLinalgOps.cpp - Pattern to break up Linalg ops ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
13 #include <optional>
14 
15 using namespace mlir;
16 using namespace mlir::linalg;
17 
18 namespace {
19 
20 /// Pattern to decompose a GenericOp that has more than two statements
21 /// into one GenericOp with the first statement (i.e. peeled operation), and
22 /// a second GenericOp with the remaining statements (i.e. residual operations).
23 
24 /// - The result of the first GenericOp has the same shape as the iteration
25 /// space of the GenericOp. The body of the op yields as many values as the
26 /// original op plus all the results of the peeled operation.
27 /// - The second GenericOp has as many operands as the original operation plus
28 /// all the results of the first Generic Op. It has the same number of yields as
29 /// the original op.
30 /// - If the result of the peeled operation was yielded by the original
31 /// GenericOp the uses of the corresponding results will be replaced with the
32 /// result of the first GenericOp created.
33 ///
34 /// Example
35 ///
36 /// ```mlir
37 /// %result:2 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
38 /// outs(%init0, %init1 : ...) {
39 /// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ...):
40 /// %0 = <s0> %b0, %b1 : ...
41 /// %1 = <s1> %0, %b2 : ...
42 /// linalg.yield %0, %1 : ...
43 /// } -> (..., ...)
44 /// return %result#0, %result#1
45 /// ```
46 ///
47 /// gets split into
48 ///
49 /// ```mlir
50 /// %init = tensor.empty ...
51 /// %op0:3 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
52 /// outs(%init0, %init1, %init : ...)
53 /// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
54 /// %0 = <s0> %b0, %b1 : ...
55 /// linalg.yield %0, %..., %0 : ...
56 /// } -> (..., ..., ...)
57 /// %op1:2 = linalg.generic ... ins(%arg0, %arg1, %arg2, %op0#2 : ...)
58 /// outs(%init0, %init1 : ...) {
59 /// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
60 /// %1 = <s1> %b3, %b2 : ...
61 /// linalg.yield %..., %1 : ...
62 /// } -> (..., ...)
63 /// return %op0#0, %op1#1
64 /// ```
65 ///
66 /// After canonicalization this is expected to be
67 ///
68 /// ```mlir
69 /// %init = tensor.empty ...
70 /// %op0 = linalg.generic ... ins(%arg0, %arg1, : ...)
71 /// outs(%init : ...)
72 /// ^bb0(%b0: ... , %b1: ... , %b2: ...):
73 /// %0 = <s0> %b0, %b1 : ...
74 /// linalg.yield %0 : ...
75 /// } -> ...
76 /// %op1 = linalg.generic ... ins(%arg2, %op0#2 : ...)
77 /// outs(%init1 : ...) {
78 /// ^bb0(%b0: ... , %b1: ... , %b2: ...):
79 /// %1 = <s1> %b1, %b0 : ...
80 /// linalg.yield %..., %1 : ...
81 /// } -> ...
82 /// return %op0, %op1
83 /// ```
84 struct DecomposeLinalgOp : public OpRewritePattern<GenericOp> {
86 
87  LogicalResult matchAndRewrite(GenericOp genericOp,
88  PatternRewriter &rewriter) const override;
89 
90 private:
91  /// Helper method to create a generic op for the peeled scalar operation. The
92  /// created op has an empty region.
93  GenericOp createPeeledGenericOp(GenericOp genericOp,
94  PatternRewriter &rewriter) const;
95 
96  /// Helper method to create a generic op for the residual scalar operation.
97  /// The created op has the same region as the original op.
98  GenericOp createResidualGenericOp(GenericOp genericOp,
99  GenericOp peeledGenericOp,
100  PatternRewriter &rewriter) const;
101 };
102 } // namespace
103 
104 /// Helper method to compute the range of a generic op.
106  GenericOp op) {
108  b.setInsertionPoint(op);
109  Location loc = op.getLoc();
110  auto allShapesSizes =
111  cast<LinalgOp>(op.getOperation()).createFlatListOfOperandDims(b, loc);
112  AffineMap map = op.getShapesToLoopsMap();
113  IRRewriter rewriter(b);
114  return affine::makeComposedFoldedMultiResultAffineApply(rewriter, loc, map,
115  allShapesSizes);
116 }
117 
118 /// Helper method to permute the list of `values` based on the `map`.
120  AffineMap map) {
121  assert(map.isPermutation());
122  SmallVector<OpFoldResult> permutedValues(values.size());
123  for (const auto &position :
124  llvm::enumerate(llvm::map_range(map.getResults(), [](AffineExpr expr) {
125  return cast<AffineDimExpr>(expr).getPosition();
126  })))
127  permutedValues[position.value()] = values[position.index()];
128  return permutedValues;
129 }
130 
131 /// Get zero value for an element type.
132 static Value getZero(OpBuilder &b, Location loc, Type elementType) {
133  assert(elementType.isIntOrIndexOrFloat() &&
134  "expected scalar type while computing zero value");
135  if (isa<IntegerType>(elementType))
136  return b.create<arith::ConstantIntOp>(loc, 0, elementType);
137  if (elementType.isIndex())
138  return b.create<arith::ConstantIndexOp>(loc, 0);
139  // Assume float.
140  auto floatType = cast<FloatType>(elementType);
141  return b.create<arith::ConstantFloatOp>(
142  loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
143 }
144 
145 GenericOp
146 DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
147  PatternRewriter &rewriter) const {
148  Block *body = genericOp.getBody();
149  Operation *peeledScalarOperation = &(*body->begin());
150  SmallVector<AffineMap> peeledGenericOpIndexingMaps =
151  genericOp.getIndexingMapsArray();
152 
153  /// Compute the loop ranges for operation. This is the shape of the result of
154  /// the generic op for the peeled operation.
155  Location loc = genericOp.getLoc();
156  SmallVector<OpFoldResult> domain = getGenericOpLoopRange(rewriter, genericOp);
157  SmallVector<Value> newInitValues;
158  SmallVector<Type> newResultTypes;
159 
160  // Add as many new results as the number of results of the peeled scalar op.
161  for (auto scalarOpResult : peeledScalarOperation->getResults()) {
162  // If the result is yielded by the original op, use the operand, indexing
163  // map and result type that correspond to the yielded value.
164 
165  std::optional<unsigned> resultNumber;
166  for (auto *user : scalarOpResult.getUsers()) {
167  if (auto yieldOp = dyn_cast<YieldOp>(user)) {
168  // Find the first use of the `scalarOpResult` in the yield op.
169  for (OpOperand &yieldOperand : yieldOp->getOpOperands()) {
170  if (yieldOperand.get() == scalarOpResult) {
171  resultNumber = yieldOperand.getOperandNumber();
172  break;
173  }
174  }
175  assert(resultNumber && "unable to find use of a value in its user");
176  break;
177  }
178  }
179  if (resultNumber) {
180  newInitValues.push_back(
181  genericOp.getDpsInitOperand(*resultNumber)->get());
182  OpResult result = cast<OpResult>(genericOp.getResult(*resultNumber));
183  newResultTypes.push_back(result.getType());
184  peeledGenericOpIndexingMaps.push_back(
185  genericOp.getIndexingMapMatchingResult(result));
186  continue;
187  }
188 
189  // Fall back path, use an `init_tensor` and identity indexing map.
190  AffineMap indexingMap = rewriter.getMultiDimIdentityMap(domain.size());
191  Value emptyTensor =
192  rewriter.create<tensor::EmptyOp>(loc, domain, scalarOpResult.getType());
193  newInitValues.push_back(emptyTensor);
194  newResultTypes.push_back(emptyTensor.getType());
195  peeledGenericOpIndexingMaps.push_back(indexingMap);
196  }
197 
198  /// Create the peeled generic op with an empty body.
199  SmallVector<Value> outsOperands = genericOp.getOutputs();
200  outsOperands.append(newInitValues.begin(), newInitValues.end());
201  SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes());
202  resultTypes.append(newResultTypes.begin(), newResultTypes.end());
203  auto indexingMapAttr =
204  rewriter.getAffineMapArrayAttr(peeledGenericOpIndexingMaps);
205  return rewriter.create<GenericOp>(
206  loc, resultTypes, genericOp.getInputs(), outsOperands, indexingMapAttr,
207  genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
208  [](OpBuilder, Location, ValueRange) {});
209 }
210 
211 GenericOp
212 DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
213  GenericOp peeledGenericOp,
214  PatternRewriter &rewriter) const {
215  /// Append all results from the peeledGenericOps as `ins` operand for the
216  /// residual generic op.
217  SmallVector<Value> residualGenericOpOperands = genericOp.getInputs();
218  unsigned origNumResults = genericOp.getNumResults();
219  unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults();
220  SmallVector<Value> extraIns;
221  for (auto resultNum :
222  llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults))
223  extraIns.push_back(peeledGenericOp->getResult(resultNum));
224  residualGenericOpOperands.append(extraIns);
225 
226  /// Add indexing maps for the newly added operands. Use the same map
227  /// as those used for the new results of the peeledGenericOp.
228  auto indexingMaps = llvm::to_vector(
229  llvm::map_range(genericOp.getDpsInputOperands(), [&](OpOperand *operand) {
230  return genericOp.getMatchingIndexingMap(operand);
231  }));
232  for (auto resultNum :
233  llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) {
234  OpResult result = cast<OpResult>(peeledGenericOp.getResult(resultNum));
235  indexingMaps.push_back(
236  peeledGenericOp.getIndexingMapMatchingResult(result));
237  }
238  for (OpOperand &outOperand : genericOp.getDpsInitsMutable())
239  indexingMaps.push_back(genericOp.getMatchingIndexingMap(&outOperand));
240 
241  auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps);
242  return rewriter.create<GenericOp>(
243  genericOp->getLoc(), genericOp->getResultTypes(),
244  residualGenericOpOperands, genericOp.getOutputs(), indexingMapAttr,
245  genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
246  [](OpBuilder, Location, ValueRange) {});
247 }
248 
250 DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
251  PatternRewriter &rewriter) const {
252  /// For now only match on operations where the iterator types are all parallel
253  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) {
254  return rewriter.notifyMatchFailure(genericOp,
255  "unhandled decomposition of operation "
256  "with non-parallel iterator types");
257  }
258  // TODO: this could be generalized to handle `linalg.generic` with buffer
259  // operands too but requires allocation for intermediates. Punt on this for
260  // now.
261  if (!genericOp.hasPureTensorSemantics()) {
262  return rewriter.notifyMatchFailure(
263  genericOp, "only operations with tensor semantics are handled");
264  }
265 
266  if (llvm::any_of(genericOp.getDpsInitsMutable(), [&](OpOperand &outOperand) {
267  return !genericOp.getMatchingIndexingMap(&outOperand).isPermutation();
268  })) {
269  return rewriter.notifyMatchFailure(
270  genericOp, "unhandled decomposition of generic op with out operand not "
271  "accessed using a permutation");
272  }
273 
274  /// If the op has only a single statement (apart from the yield), do nothing.
275  Block *body = genericOp.getBody();
276  if (body->getOperations().size() <= 2) {
277  return rewriter.notifyMatchFailure(genericOp,
278  "operation has less than 3 statements");
279  }
280 
281  /// Check that the peeled statement has a scalar element type.
282  if (llvm::any_of(body->getOperations().begin()->getResultTypes(),
283  [](Type t) { return !t.isIntOrIndexOrFloat(); })) {
284  return rewriter.notifyMatchFailure(
285  &(*body->getOperations().begin()),
286  "expected return type to be only int, index or float");
287  }
288 
289  GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter);
290  GenericOp residualGenericOp =
291  createResidualGenericOp(genericOp, peeledGenericOp, rewriter);
292 
293  /// Move the first statement of the original operation into the body of the
294  /// generic op for the peeled operation.
295  Block *peeledGenericOpBody = peeledGenericOp.getBody();
296  Block *residualGenericOpBody = residualGenericOp.getBody();
297  assert(peeledGenericOpBody->empty() && residualGenericOpBody->empty() &&
298  "expected split generic ops to have empty region");
299  peeledGenericOpBody->getOperations().splice(
300  peeledGenericOpBody->begin(), body->getOperations(), body->begin());
301  residualGenericOpBody->getOperations().splice(residualGenericOpBody->begin(),
302  body->getOperations());
303 
304  Operation *peeledScalarOperation = &(*peeledGenericOpBody->begin());
305  auto *yieldOp = residualGenericOpBody->getTerminator();
306  {
307  // Yield all the result of the peeled scalar operation.
308  OpBuilder::InsertionGuard g(rewriter);
309  rewriter.setInsertionPointToEnd(peeledGenericOpBody);
310  SmallVector<Value> yieldedVals;
311  for (auto origYield : yieldOp->getOperands()) {
312  if (origYield.getDefiningOp() == peeledScalarOperation) {
313  yieldedVals.push_back(origYield);
314  } else {
315  // Do not materialize any new ops inside of the decomposed LinalgOp,
316  // as that would trigger another application of the rewrite pattern
317  // (infinite loop).
318  OpBuilder::InsertionGuard g(rewriter);
319  rewriter.setInsertionPoint(peeledGenericOp);
320  yieldedVals.push_back(
321  getZero(rewriter, genericOp.getLoc(), origYield.getType()));
322  }
323  }
324  yieldedVals.append(llvm::to_vector(
325  llvm::map_range(peeledScalarOperation->getResults(),
326  [](OpResult opr) -> Value { return opr; })));
327  rewriter.create<YieldOp>(genericOp.getLoc(), yieldedVals);
328  }
329 
330  /// In the split operations, replace block arguments uses that refer to
331  /// original operation to the block arguments of the newly created operation.
332  unsigned origNumInputs = genericOp.getNumDpsInputs();
333  for (const auto &inputBlockArg :
334  llvm::enumerate(genericOp.getBody()->getArguments())) {
335  Value residualOpReplacementArg =
336  residualGenericOpBody->getArgument(inputBlockArg.index());
337  rewriter.replaceUsesWithIf(
338  inputBlockArg.value(), residualOpReplacementArg, [&](OpOperand &use) {
339  return use.getOwner()->getBlock() == residualGenericOpBody;
340  });
341 
342  Value peeledOpReplacementArg =
343  peeledGenericOpBody->getArgument(inputBlockArg.index());
344  rewriter.replaceUsesWithIf(
345  inputBlockArg.value(), peeledOpReplacementArg, [&](OpOperand &use) {
346  return use.getOwner()->getBlock() == peeledGenericOpBody;
347  });
348  }
349 
350  /// Before fixing up the residual operation, track what values are yielded. If
351  /// any of those are from the peeled scalar operation, the uses of the
352  /// corresponding result have to be remapped to result of the generic op for
353  /// the peeled operation.
354  SmallVector<Value> replacements;
355  for (const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) {
356  OpResult opr = dyn_cast<OpResult>(yieldValue.value());
357  if (!opr || opr.getOwner() != peeledScalarOperation)
358  replacements.push_back(residualGenericOp.getResult(yieldValue.index()));
359  else
360  replacements.push_back(peeledGenericOp->getResult(yieldValue.index()));
361  }
362 
363  /// Update all uses of the peeled scalar operation results in the residual op
364  /// to the newly added arguments.
365  {
366  SmallVector<Value> scalarReplacements;
367  unsigned peeledScalarOpNumResults = peeledScalarOperation->getNumResults();
368  scalarReplacements.reserve(peeledScalarOpNumResults);
369  for (auto num : llvm::seq<unsigned>(0, peeledScalarOpNumResults))
370  scalarReplacements.push_back(
371  residualGenericOpBody->getArgument(num + origNumInputs));
372  bool allUsesReplaced = false;
373  rewriter.replaceOpWithinBlock(peeledScalarOperation, scalarReplacements,
374  residualGenericOpBody, &allUsesReplaced);
375  assert(!allUsesReplaced &&
376  "peeled scalar operation is erased when it wasnt expected to be");
377  }
378 
379  // Replace the original operation
380  rewriter.replaceOp(genericOp, replacements);
381  return success();
382 }
383 
385  RewritePatternSet &patterns, bool removeDeadArgsAndResults) {
386  patterns.insert<DecomposeLinalgOp>(patterns.getContext());
387  // Add the patterns to clean up the dead operands and results.
388  if (removeDeadArgsAndResults)
390 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
SmallVector< OpFoldResult > permuteValues(ArrayRef< OpFoldResult > values, AffineMap map)
Helper method to permute the list of values based on the map.
static SmallVector< OpFoldResult > getGenericOpLoopRange(OpBuilder &b, GenericOp op)
Helper method to compute the range of a generic op.
Base type for affine expression.
Definition: AffineExpr.h:69
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:47
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:393
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:611
Block represents an ordered list of Operations.
Definition: Block.h:30
bool empty()
Definition: Block.h:145
BlockArgument getArgument(unsigned i)
Definition: Block.h:126
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
OpListType & getOperations()
Definition: Block.h:134
iterator begin()
Definition: Block.h:140
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:394
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:325
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:731
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:438
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents an operand of an operation.
Definition: Value.h:263
This is a value defined by a result of an operation.
Definition: Value.h:453
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:462
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:748
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:893
MLIRContext * getContext() const
Definition: PatternMatch.h:785
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:685
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block, bool *allUsesReplaced=nullptr)
This method replaces the uses of the results of op with the values in newValues when a use is nested ...
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor)
Find uses of from and replace them with to if the functor returns true.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:56
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:121
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1232
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns, bool removeDeadArgsAndResults=true)
Populate patterns for splitting a LinalgOp with multiple statements within its payload into multiple ...
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357