MLIR  18.0.0git
DecomposeLinalgOps.cpp
Go to the documentation of this file.
1 //===- DecomposeLinalgOps.cpp - Pattern to break up Linalg ops ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
13 #include <optional>
14 
15 using namespace mlir;
16 using namespace mlir::linalg;
17 
18 namespace {
19 
20 /// Pattern to decompose a GenericOp that has more than two statements
21 /// into one GenericOp with the first statement (i.e. peeled operation), and
22 /// a second GenericOp with the remaining statements (i.e. residual operations).
23 
24 /// - The result of the first GenericOp has the same shape as the iteration
25 /// space of the GenericOp. The body of the op yields as many values as the
26 /// original op plus all the results of the peeled operation.
27 /// - The second GenericOp has as many operands as the original operation plus
28 /// all the results of the first Generic Op. It has the same number of yields as
29 /// the original op.
30 /// - If the result of the peeled operation was yielded by the original
31 /// GenericOp the uses of the corresponding results will be replaced with the
32 /// result of the first GenericOp created.
33 ///
34 /// Example
35 ///
36 /// ```mlir
37 /// %result:2 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
38 /// outs(%init0, %init1 : ...) {
39 /// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ...):
40 /// %0 = <s0> %b0, %b1 : ...
41 /// %1 = <s1> %0, %b2 : ...
42 /// linalg.yield %0, %1 : ...
43 /// } -> (..., ...)
44 /// return %result#0, %result#1
45 /// ```
46 ///
47 /// gets split into
48 ///
49 /// ```mlir
50 /// %init = tensor.empty ...
51 /// %op0:3 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
52 /// outs(%init0, %init1, %init : ...)
53 /// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
54 /// %0 = <s0> %b0, %b1 : ...
55 /// linalg.yield %0, %..., %0 : ...
56 /// } -> (..., ..., ...)
57 /// %op1:2 = linalg.generic ... ins(%arg0, %arg1, %arg2, %op0#2 : ...)
58 /// outs(%init0, %init1 : ...) {
59 /// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
60 /// %1 = <s1> %b3, %b2 : ...
61 /// linalg.yield %..., %1 : ...
62 /// } -> (..., ...)
63 /// return %op0#0, %op1#1
64 /// ```
65 ///
66 /// After canonicalization this is expected to be
67 ///
68 /// ```mlir
69 /// %init = tensor.empty ...
70 /// %op0 = linalg.generic ... ins(%arg0, %arg1, : ...)
71 /// outs(%init : ...)
72 /// ^bb0(%b0: ... , %b1: ... , %b2: ...):
73 /// %0 = <s0> %b0, %b1 : ...
74 /// linalg.yield %0 : ...
75 /// } -> ...
76 /// %op1 = linalg.generic ... ins(%arg2, %op0#2 : ...)
77 /// outs(%init1 : ...) {
78 /// ^bb0(%b0: ... , %b1: ... , %b2: ...):
79 /// %1 = <s1> %b1, %b0 : ...
80 /// linalg.yield %..., %1 : ...
81 /// } -> ...
82 /// return %op0, %op1
83 /// ```
84 struct DecomposeLinalgOp : public OpRewritePattern<GenericOp> {
86 
87  LogicalResult matchAndRewrite(GenericOp genericOp,
88  PatternRewriter &rewriter) const override;
89 
90 private:
91  /// Helper method to create a generic op for the peeled scalar operation. The
92  /// created op has an empty region.
93  GenericOp createPeeledGenericOp(GenericOp genericOp,
94  PatternRewriter &rewriter) const;
95 
96  /// Helper method to create a generic op for the residual scalar operation.
97  /// The created op has the same region as the original op.
98  GenericOp createResidualGenericOp(GenericOp genericOp,
99  GenericOp peeledGenericOp,
100  PatternRewriter &rewriter) const;
101 };
102 } // namespace
103 
104 /// Helper method to compute the loop range (iteration-domain sizes) of a
/// generic op, one OpFoldResult per loop.
/// The sizes of all operand dimensions are collected with
/// `createFlatListOfOperandDims` and mapped to loop sizes through
/// `getShapesToLoopsMap`, composed and folded by
/// `makeComposedFoldedMultiResultAffineApply`.
/// NOTE(review): `b.setInsertionPoint(op)` moves the caller's insertion point
/// to just before `op`; confirm it is restored (e.g. by an
/// `OpBuilder::InsertionGuard`) if callers rely on it.
106  GenericOp op) {
108  b.setInsertionPoint(op);
109  Location loc = op.getLoc();
  // `createFlatListOfOperandDims` is reached through the LinalgOp interface.
110  auto allShapesSizes =
111  cast<LinalgOp>(op.getOperation()).createFlatListOfOperandDims(b, loc);
112  AffineMap map = op.getShapesToLoopsMap();
  // NOTE(review): `makeComposedFoldedMultiResultAffineApply` is declared as
  // taking an `OpBuilder &`, so wrapping `b` in an IRRewriter may be
  // unnecessary — confirm before simplifying.
113  IRRewriter rewriter(b);
114  return affine::makeComposedFoldedMultiResultAffineApply(rewriter, loc, map,
115  allShapesSizes);
116 }
117 
118 /// Helper method to permute the list of `values` based on the `map`.
120  AffineMap map) {
121  assert(map.isPermutation());
122  SmallVector<OpFoldResult> permutedValues(values.size());
123  for (const auto &position :
124  llvm::enumerate(llvm::map_range(map.getResults(), [](AffineExpr expr) {
125  return expr.cast<AffineDimExpr>().getPosition();
126  })))
127  permutedValues[position.value()] = values[position.index()];
128  return permutedValues;
129 }
130 
131 /// Get zero value for an element type.
132 static Value getZero(OpBuilder &b, Location loc, Type elementType) {
133  assert(elementType.isIntOrIndexOrFloat() &&
134  "expected scalar type while computing zero value");
135  if (isa<IntegerType>(elementType))
136  return b.create<arith::ConstantIntOp>(loc, 0, elementType);
137  if (elementType.isIndex())
138  return b.create<arith::ConstantIndexOp>(loc, 0);
139  // Assume float.
140  auto floatType = cast<FloatType>(elementType);
141  return b.create<arith::ConstantFloatOp>(
142  loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
143 }
144 
145 GenericOp
146 DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
147  PatternRewriter &rewriter) const {
148  Block *body = genericOp.getBody();
149  Operation *peeledScalarOperation = &(*body->begin());
150  SmallVector<AffineMap> peeledGenericOpIndexingMaps =
151  genericOp.getIndexingMapsArray();
152 
153  /// Compute the loop ranges for operation. This is the shape of the result of
154  /// the generic op for the peeled operation.
155  Location loc = genericOp.getLoc();
156  SmallVector<OpFoldResult> domain = getGenericOpLoopRange(rewriter, genericOp);
157  SmallVector<Value> newInitValues;
158  SmallVector<Type> newResultTypes;
159 
160  // Add as many new results as the number of results of the peeled scalar op.
161  for (auto scalarOpResult : peeledScalarOperation->getResults()) {
162  // If the result is yielded by the original op, use the operand, indexing
163  // map and result type that correspond to the yielded value.
164 
165  std::optional<unsigned> resultNumber;
166  for (auto *user : scalarOpResult.getUsers()) {
167  if (auto yieldOp = dyn_cast<YieldOp>(user)) {
168  // Find the first use of the `scalarOpResult` in the yield op.
169  for (OpOperand &yieldOperand : yieldOp->getOpOperands()) {
170  if (yieldOperand.get() == scalarOpResult) {
171  resultNumber = yieldOperand.getOperandNumber();
172  break;
173  }
174  }
175  assert(resultNumber && "unable to find use of a value in its user");
176  break;
177  }
178  }
179  if (resultNumber) {
180  newInitValues.push_back(
181  genericOp.getDpsInitOperand(*resultNumber)->get());
182  OpResult result = cast<OpResult>(genericOp.getResult(*resultNumber));
183  newResultTypes.push_back(result.getType());
184  peeledGenericOpIndexingMaps.push_back(
185  genericOp.getIndexingMapMatchingResult(result));
186  continue;
187  }
188 
189  // Fall back path, use an `init_tensor` and identity indexing map.
190  AffineMap indexingMap = rewriter.getMultiDimIdentityMap(domain.size());
191  Value emptyTensor =
192  rewriter.create<tensor::EmptyOp>(loc, domain, scalarOpResult.getType());
193  newInitValues.push_back(emptyTensor);
194  newResultTypes.push_back(emptyTensor.getType());
195  peeledGenericOpIndexingMaps.push_back(indexingMap);
196  }
197 
198  /// Create the peeled generic op with an empty body.
199  SmallVector<Value> outsOperands = genericOp.getOutputs();
200  outsOperands.append(newInitValues.begin(), newInitValues.end());
201  SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes());
202  resultTypes.append(newResultTypes.begin(), newResultTypes.end());
203  auto indexingMapAttr =
204  rewriter.getAffineMapArrayAttr(peeledGenericOpIndexingMaps);
205  return rewriter.create<GenericOp>(
206  loc, resultTypes, genericOp.getInputs(), outsOperands, indexingMapAttr,
207  genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
208  [](OpBuilder, Location, ValueRange) {});
209 }
210 
211 GenericOp
212 DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
213  GenericOp peeledGenericOp,
214  PatternRewriter &rewriter) const {
215  /// Append all results from the peeledGenericOps as `ins` operand for the
216  /// residual generic op.
217  SmallVector<Value> residualGenericOpOperands = genericOp.getInputs();
218  unsigned origNumResults = genericOp.getNumResults();
219  unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults();
220  SmallVector<Value> extraIns;
221  for (auto resultNum :
222  llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults))
223  extraIns.push_back(peeledGenericOp->getResult(resultNum));
224  residualGenericOpOperands.append(extraIns);
225 
226  /// Add indexing maps for the newly added operands. Use the same map
227  /// as those used for the new results of the peeledGenericOp.
228  auto indexingMaps = llvm::to_vector(
229  llvm::map_range(genericOp.getDpsInputOperands(), [&](OpOperand *operand) {
230  return genericOp.getMatchingIndexingMap(operand);
231  }));
232  for (auto resultNum :
233  llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) {
234  OpResult result = cast<OpResult>(peeledGenericOp.getResult(resultNum));
235  indexingMaps.push_back(
236  peeledGenericOp.getIndexingMapMatchingResult(result));
237  }
238  for (OpOperand &outOperand : genericOp.getDpsInitsMutable())
239  indexingMaps.push_back(genericOp.getMatchingIndexingMap(&outOperand));
240 
241  auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps);
242  return rewriter.create<GenericOp>(
243  genericOp->getLoc(), genericOp->getResultTypes(),
244  residualGenericOpOperands, genericOp.getOutputs(), indexingMapAttr,
245  genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
246  [](OpBuilder, Location, ValueRange) {});
247 }
248 
250 DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
251  PatternRewriter &rewriter) const {
252  /// For now only match on operations where the iterator types are all parallel
253  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) {
254  return rewriter.notifyMatchFailure(genericOp,
255  "unhandled decomposition of operation "
256  "with non-parallel iterator types");
257  }
258  // TODO: this could be generalized to handle `linalg.generic` with buffer
259  // operands too but requires allocation for intermediates. Punt on this for
260  // now.
261  if (!genericOp.hasTensorSemantics()) {
262  return rewriter.notifyMatchFailure(
263  genericOp, "only operations with tensor semantics are handled");
264  }
265 
266  if (llvm::any_of(genericOp.getDpsInitsMutable(), [&](OpOperand &outOperand) {
267  return !genericOp.getMatchingIndexingMap(&outOperand).isPermutation();
268  })) {
269  return rewriter.notifyMatchFailure(
270  genericOp, "unhandled decomposition of generic op with out operand not "
271  "accessed using a permutation");
272  }
273 
274  /// If the op has only a single statement (apart from the yield), do nothing.
275  Block *body = genericOp.getBody();
276  if (body->getOperations().size() <= 2) {
277  return rewriter.notifyMatchFailure(genericOp,
278  "operation has less than 3 statements");
279  }
280 
281  /// Check that the peeled statement has a scalar element type.
282  if (llvm::any_of(body->getOperations().begin()->getResultTypes(),
283  [](Type t) { return !t.isIntOrIndexOrFloat(); })) {
284  return rewriter.notifyMatchFailure(
285  &(*body->getOperations().begin()),
286  "expected return type to be only int, index or float");
287  }
288 
289  GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter);
290  GenericOp residualGenericOp =
291  createResidualGenericOp(genericOp, peeledGenericOp, rewriter);
292 
293  /// Move the first statement of the original operation into the body of the
294  /// generic op for the peeled operation.
295  Block *peeledGenericOpBody = peeledGenericOp.getBody();
296  Block *residualGenericOpBody = residualGenericOp.getBody();
297  assert(peeledGenericOpBody->empty() && residualGenericOpBody->empty() &&
298  "expected split generic ops to have empty region");
299  peeledGenericOpBody->getOperations().splice(
300  peeledGenericOpBody->begin(), body->getOperations(), body->begin());
301  residualGenericOpBody->getOperations().splice(residualGenericOpBody->begin(),
302  body->getOperations());
303 
304  Operation *peeledScalarOperation = &(*peeledGenericOpBody->begin());
305  auto *yieldOp = residualGenericOpBody->getTerminator();
306  {
307  // Yield all the result of the peeled scalar operation.
308  OpBuilder::InsertionGuard g(rewriter);
309  rewriter.setInsertionPointToEnd(peeledGenericOpBody);
310  SmallVector<Value> yieldedVals;
311  for (auto origYield : yieldOp->getOperands()) {
312  if (origYield.getDefiningOp() == peeledScalarOperation) {
313  yieldedVals.push_back(origYield);
314  } else {
315  yieldedVals.push_back(
316  getZero(rewriter, genericOp.getLoc(), origYield.getType()));
317  }
318  }
319  yieldedVals.append(llvm::to_vector(
320  llvm::map_range(peeledScalarOperation->getResults(),
321  [](OpResult opr) -> Value { return opr; })));
322  rewriter.create<YieldOp>(genericOp.getLoc(), yieldedVals);
323  }
324 
325  /// In the split operations, replace block arguments uses that refer to
326  /// original operation to the block arguments of the newly created operation.
327  unsigned origNumInputs = genericOp.getNumDpsInputs();
328  for (const auto &inputBlockArg :
329  llvm::enumerate(genericOp.getBody()->getArguments())) {
330  Value residualOpReplacementArg =
331  residualGenericOpBody->getArgument(inputBlockArg.index());
332  rewriter.replaceUsesWithIf(
333  inputBlockArg.value(), residualOpReplacementArg, [&](OpOperand &use) {
334  return use.getOwner()->getBlock() == residualGenericOpBody;
335  });
336 
337  Value peeledOpReplacementArg =
338  peeledGenericOpBody->getArgument(inputBlockArg.index());
339  rewriter.replaceUsesWithIf(
340  inputBlockArg.value(), peeledOpReplacementArg, [&](OpOperand &use) {
341  return use.getOwner()->getBlock() == peeledGenericOpBody;
342  });
343  }
344 
345  /// Before fixing up the residual operation, track what values are yielded. If
346  /// any of those are from the peeled scalar operation, the uses of the
347  /// corresponding result have to be remapped to result of the generic op for
348  /// the peeled operation.
349  SmallVector<Value> replacements;
350  for (const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) {
351  OpResult opr = dyn_cast<OpResult>(yieldValue.value());
352  if (!opr || opr.getOwner() != peeledScalarOperation)
353  replacements.push_back(residualGenericOp.getResult(yieldValue.index()));
354  else
355  replacements.push_back(peeledGenericOp->getResult(yieldValue.index()));
356  }
357 
358  /// Update all uses of the peeled scalar operation results in the residual op
359  /// to the newly added arguments.
360  {
361  SmallVector<Value> scalarReplacements;
362  unsigned peeledScalarOpNumResults = peeledScalarOperation->getNumResults();
363  scalarReplacements.reserve(peeledScalarOpNumResults);
364  for (auto num : llvm::seq<unsigned>(0, peeledScalarOpNumResults))
365  scalarReplacements.push_back(
366  residualGenericOpBody->getArgument(num + origNumInputs));
367  bool allUsesReplaced = false;
368  rewriter.replaceOpWithinBlock(peeledScalarOperation, scalarReplacements,
369  residualGenericOpBody, &allUsesReplaced);
370  assert(!allUsesReplaced &&
371  "peeled scalar operation is erased when it wasnt expected to be");
372  }
373 
374  // Replace the original operation
375  rewriter.replaceOp(genericOp, replacements);
376  return success();
377 }
378 
380  RewritePatternSet &patterns, bool removeDeadArgsAndResults) {
  // Register the pattern that splits a multi-statement `linalg.generic` into a
  // peeled op (first statement) and a residual op (remaining statements).
381  patterns.insert<DecomposeLinalgOp>(patterns.getContext());
382  // Add the patterns to clean up the dead operands and results. The split ops
  // carry extra operands/results (e.g. zero placeholders yielded by the peeled
  // op); this function only adds the cleanup patterns when
  // `removeDeadArgsAndResults` is set.
  // NOTE(review): the statement guarded by the `if` below is not shown in this
  // listing; presumably `populateEraseUnusedOperandsAndResultsPatterns(patterns)`
  // — confirm.
383  if (removeDeadArgsAndResults)
385 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
SmallVector< OpFoldResult > permuteValues(ArrayRef< OpFoldResult > values, AffineMap map)
Helper method to permute the list of values based on the map.
static SmallVector< OpFoldResult > getGenericOpLoopRange(OpBuilder &b, GenericOp op)
Helper method to compute the range of a generic op.
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:44
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:350
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:564
Block represents an ordered list of Operations.
Definition: Block.h:30
bool empty()
Definition: Block.h:141
BlockArgument getArgument(unsigned i)
Definition: Block.h:122
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:238
OpListType & getOperations()
Definition: Block.h:130
iterator begin()
Definition: Block.h:136
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:376
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:325
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:710
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
This class helps build Operations.
Definition: Builders.h:206
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:383
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:421
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
This class represents an operand of an operation.
Definition: Value.h:261
This is a value defined by a result of an operation.
Definition: Value.h:448
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:660
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block, bool *allUsesReplaced=nullptr)
This method replaces the uses of the results of op with the values in newValues when a use is nested ...
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor)
Find uses of from and replace them with to if the functor returns true.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:56
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:121
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:372
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1321
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns, bool removeDeadArgsAndResults=true)
Populate patterns for splitting a LinalgOp with multiple statements within its payload into multiple ...
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357