MLIR  16.0.0git
DecomposeLinalgOps.cpp
//===- DecomposeLinalgOps.cpp - Pattern to break up Linalg ops ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"

using namespace mlir;
using namespace mlir::linalg;

namespace {

/// Pattern to decompose a GenericOp that has more than two statements
/// into one GenericOp with the first statement (i.e. peeled operation), and
/// a second GenericOp with the remaining statements (i.e. residual operations).
///
/// - The result of the first GenericOp has the same shape as the iteration
///   space of the GenericOp. The body of the op yields as many values as the
///   original op plus all the results of the peeled operation.
/// - The second GenericOp has as many operands as the original operation plus
///   all the results of the first GenericOp. It has the same number of yields
///   as the original op.
/// - If the result of the peeled operation was yielded by the original
///   GenericOp, the uses of the corresponding results will be replaced with the
///   result of the first GenericOp created.
///
/// Example
///
/// ```mlir
/// %result:2 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
///     outs(%init0, %init1 : ...) {
///   ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ...):
///     %0 = <s0> %b0, %b1 : ...
///     %1 = <s1> %0, %b2 : ...
///     linalg.yield %0, %1 : ...
/// } -> (..., ...)
/// return %result#0, %result#1
/// ```
///
/// gets split into
///
/// ```mlir
/// %init = linalg.init_tensor ...
/// %op0:3 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
///     outs(%init0, %init1, %init : ...)
///   ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
///     %0 = <s0> %b0, %b1 : ...
///     linalg.yield %0, %..., %0 : ...
/// } -> (..., ..., ...)
/// %op1:2 = linalg.generic ... ins(%arg0, %arg1, %arg2, %op0#2 : ...)
///     outs(%init0, %init1 : ...) {
///   ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
///     %1 = <s1> %b3, %b2 : ...
///     linalg.yield %..., %1 : ...
/// } -> (..., ...)
/// return %op0#0, %op1#1
/// ```
///
/// After canonicalization this is expected to be
///
/// ```mlir
/// %init = linalg.init_tensor ...
/// %op0 = linalg.generic ... ins(%arg0, %arg1 : ...)
///     outs(%init : ...)
///   ^bb0(%b0: ... , %b1: ... , %b2: ...):
///     %0 = <s0> %b0, %b1 : ...
///     linalg.yield %0 : ...
/// } -> ...
/// %op1 = linalg.generic ... ins(%arg2, %op0 : ...)
///     outs(%init1 : ...) {
///   ^bb0(%b0: ... , %b1: ... , %b2: ...):
///     %1 = <s1> %b1, %b0 : ...
///     linalg.yield %1 : ...
/// } -> ...
/// return %op0, %op1
/// ```
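///
/// For a concrete reading of the example above, take `<s0>` to be `arith.addf`
/// and `<s1>` to be `arith.mulf`: the addf is peeled into the first generic op
/// and its value is carried over through the extra result `%op0#2`, where the
/// residual op's mulf reads it back through the newly added input block
/// argument instead of recomputing it.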
struct DecomposeLinalgOp : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override;

private:
  /// Helper method to create a generic op for the peeled scalar operation. The
  /// created op has an empty region.
  GenericOp createPeeledGenericOp(GenericOp genericOp,
                                  PatternRewriter &rewriter) const;

  /// Helper method to create a generic op for the residual scalar operation.
  /// The created op has the same region as the original op.
  GenericOp createResidualGenericOp(GenericOp genericOp,
                                    GenericOp peeledGenericOp,
                                    PatternRewriter &rewriter) const;
};
} // namespace

/// Helper method to compute the range of a generic op.
static SmallVector<OpFoldResult> getGenericOpLoopRange(OpBuilder &b,
                                                       GenericOp op) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(op);
  Location loc = op.getLoc();
  auto allShapesSizes =
      cast<LinalgOp>(op.getOperation()).createFlatListOfOperandDims(b, loc);
  AffineMap map = op.getShapesToLoopsMap();
  IRRewriter rewriter(b);
  return makeComposedFoldedMultiResultAffineApply(rewriter, loc, map,
                                                  allShapesSizes);
}

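// Illustrative note: the vector returned above holds one OpFoldResult per loop
// of the generic op. For an op iterating over a `?x4` tensor, for example, it
// contains the dynamic size of dimension 0 as an SSA value and the constant 4
// folded to an attribute.
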
/// Helper method to permute the list of `values` based on the `map`.
SmallVector<OpFoldResult> permuteValues(ArrayRef<OpFoldResult> values,
                                        AffineMap map) {
  assert(map.isPermutation());
  SmallVector<OpFoldResult> permutedValues(values.size());
  for (const auto &position :
       llvm::enumerate(llvm::map_range(map.getResults(), [](AffineExpr expr) {
         return expr.cast<AffineDimExpr>().getPosition();
       })))
    permutedValues[position.value()] = values[position.index()];
  return permutedValues;
}

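// Illustrative note: `permuteValues` scatters each value to the dimension
// position named by the corresponding map result. For values [a, b, c] and the
// permutation map (d0, d1, d2) -> (d1, d2, d0), the result is [c, a, b].
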
/// Get zero value for an element type.
static Value getZero(OpBuilder &b, Location loc, Type elementType) {
  assert(elementType.isIntOrIndexOrFloat() &&
         "expected scalar type while computing zero value");
  if (elementType.isa<IntegerType>())
    return b.create<arith::ConstantIntOp>(loc, 0, elementType);
  if (elementType.isIndex())
    return b.create<arith::ConstantIndexOp>(loc, 0);
  // Assume float.
  auto floatType = elementType.cast<FloatType>();
  return b.create<arith::ConstantFloatOp>(
      loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
}

GenericOp
DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
                                         PatternRewriter &rewriter) const {
  Block *body = genericOp.getBody();
  Operation *peeledScalarOperation = &(*body->begin());
  SmallVector<AffineMap> peeledGenericOpIndexingMaps =
      genericOp.getIndexingMapsArray();

  /// Compute the loop ranges for the operation. This is the shape of the
  /// result of the generic op for the peeled operation.
  Location loc = genericOp.getLoc();
  SmallVector<OpFoldResult> domain = getGenericOpLoopRange(rewriter, genericOp);
  SmallVector<Value> newInitValues;
  SmallVector<Type> newResultTypes;

  // Add as many new results as the number of results of the peeled scalar op.
  for (auto scalarOpResult : peeledScalarOperation->getResults()) {
    // If the result is yielded by the original op, use the operand, indexing
    // map and result type that correspond to the yielded value.

    Optional<unsigned> resultNumber;
    for (auto *user : scalarOpResult.getUsers()) {
      if (auto yieldOp = dyn_cast<YieldOp>(user)) {
        // Find the first use of the `scalarOpResult` in the yield op.
        for (OpOperand &yieldOperand : yieldOp->getOpOperands()) {
          if (yieldOperand.get() == scalarOpResult) {
            resultNumber = yieldOperand.getOperandNumber();
            break;
          }
        }
        assert(resultNumber && "unable to find use of a value in its user");
        break;
      }
    }
    if (resultNumber) {
      newInitValues.push_back(genericOp.getOutputOperand(*resultNumber)->get());
      OpResult result = genericOp.getResult(*resultNumber).cast<OpResult>();
      newResultTypes.push_back(result.getType());
      peeledGenericOpIndexingMaps.push_back(
          genericOp.getIndexingMapMatchingResult(result));
      continue;
    }

    // Fall back path, use an `init_tensor` and identity indexing map.
    AffineMap indexingMap = rewriter.getMultiDimIdentityMap(domain.size());
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, domain, scalarOpResult.getType());
    newInitValues.push_back(initTensor);
    newResultTypes.push_back(initTensor.getType());
    peeledGenericOpIndexingMaps.push_back(indexingMap);
  }

  /// Create the peeled generic op with an empty body.
  SmallVector<Value> outsOperands = genericOp.getOutputOperands();
  outsOperands.append(newInitValues.begin(), newInitValues.end());
  SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes());
  resultTypes.append(newResultTypes.begin(), newResultTypes.end());
  auto indexingMapAttr =
      rewriter.getAffineMapArrayAttr(peeledGenericOpIndexingMaps);
  return rewriter.create<GenericOp>(
      loc, resultTypes, genericOp.getInputs(), outsOperands, indexingMapAttr,
      genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
      [](OpBuilder, Location, ValueRange) {});
}

GenericOp
DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
                                           GenericOp peeledGenericOp,
                                           PatternRewriter &rewriter) const {
  /// Append all results from the peeledGenericOp as `ins` operands for the
  /// residual generic op.
  SmallVector<Value> residualGenericOpOperands = llvm::to_vector(
      llvm::map_range(genericOp.getInputOperands(),
                      [](OpOperand *operand) { return operand->get(); }));
  unsigned origNumResults = genericOp.getNumResults();
  unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults();
  SmallVector<Value> extraIns;
  for (auto resultNum :
       llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults))
    extraIns.push_back(peeledGenericOp->getResult(resultNum));
  residualGenericOpOperands.append(extraIns);

  /// Add indexing maps for the newly added operands. Use the same maps
  /// as those used for the new results of the peeledGenericOp.
  auto indexingMaps = llvm::to_vector(
      llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *operand) {
        return genericOp.getMatchingIndexingMap(operand);
      }));
  for (auto resultNum :
       llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) {
    OpResult result = peeledGenericOp.getResult(resultNum).cast<OpResult>();
    indexingMaps.push_back(
        peeledGenericOp.getIndexingMapMatchingResult(result));
  }
  for (OpOperand *outOperand : genericOp.getOutputOperands())
    indexingMaps.push_back(genericOp.getMatchingIndexingMap(outOperand));

  auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps);
  return rewriter.create<GenericOp>(
      genericOp->getLoc(), genericOp->getResultTypes(),
      residualGenericOpOperands, genericOp.getOutputs(), indexingMapAttr,
      genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
      [](OpBuilder, Location, ValueRange) {});
}

LogicalResult
DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
                                   PatternRewriter &rewriter) const {
  /// For now only match on operations where the iterator types are all
  /// parallel.
  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "unhandled decomposition of operation "
                                       "with non-parallel iterator types");
  }
  // TODO: this could be generalized to handle `linalg.generic` with buffer
  // operands too but requires allocation for intermediates. Punt on this for
  // now.
  if (!genericOp.hasTensorSemantics()) {
    return rewriter.notifyMatchFailure(
        genericOp, "only operations with tensor semantics are handled");
  }

  if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *outOperand) {
        return !genericOp.getMatchingIndexingMap(outOperand).isPermutation();
      })) {
    return rewriter.notifyMatchFailure(
        genericOp, "unhandled decomposition of generic op with out operand not "
                   "accessed using a permutation");
  }

  /// If the op has only a single statement (apart from the yield), do nothing.
  Block *body = genericOp.getBody();
  if (body->getOperations().size() <= 2) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "operation has less than 3 statements");
  }

  /// Check that the peeled statement has a scalar element type.
  if (llvm::any_of(body->getOperations().begin()->getResultTypes(),
                   [](Type t) { return !t.isIntOrIndexOrFloat(); })) {
    return rewriter.notifyMatchFailure(
        &(*body->getOperations().begin()),
        "expected return type to be only int, index or float");
  }

  GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter);
  GenericOp residualGenericOp =
      createResidualGenericOp(genericOp, peeledGenericOp, rewriter);

  /// Move the first statement of the original operation into the body of the
  /// generic op for the peeled operation.
  Block *peeledGenericOpBody = peeledGenericOp.getBody();
  Block *residualGenericOpBody = residualGenericOp.getBody();
  assert(peeledGenericOpBody->empty() && residualGenericOpBody->empty() &&
         "expected split generic ops to have empty region");
  peeledGenericOpBody->getOperations().splice(
      peeledGenericOpBody->begin(), body->getOperations(), body->begin());
  residualGenericOpBody->getOperations().splice(residualGenericOpBody->begin(),
                                                body->getOperations());

  Operation *peeledScalarOperation = &(*peeledGenericOpBody->begin());
  auto *yieldOp = residualGenericOpBody->getTerminator();
  {
    // Yield all the results of the peeled scalar operation.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(peeledGenericOpBody);
    SmallVector<Value> yieldedVals;
    for (auto origYield : yieldOp->getOperands()) {
      if (origYield.getDefiningOp() == peeledScalarOperation) {
        yieldedVals.push_back(origYield);
      } else {
        yieldedVals.push_back(
            getZero(rewriter, genericOp.getLoc(), origYield.getType()));
      }
    }
    yieldedVals.append(llvm::to_vector(
        llvm::map_range(peeledScalarOperation->getResults(),
                        [](OpResult opr) -> Value { return opr; })));
    rewriter.create<YieldOp>(genericOp.getLoc(), yieldedVals);
  }

  /// In the split operations, replace block argument uses that refer to the
  /// original operation with the block arguments of the newly created
  /// operation.
  unsigned origNumInputs = genericOp.getNumInputs();
  for (const auto &inputBlockArg :
       llvm::enumerate(genericOp.getBody()->getArguments())) {
    Value residualOpReplacementArg =
        residualGenericOpBody->getArgument(inputBlockArg.index());
    inputBlockArg.value().replaceUsesWithIf(
        residualOpReplacementArg, [&](OpOperand &use) {
          return use.getOwner()->getBlock() == residualGenericOpBody;
        });

    Value peeledOpReplacementArg =
        peeledGenericOpBody->getArgument(inputBlockArg.index());
    inputBlockArg.value().replaceUsesWithIf(
        peeledOpReplacementArg, [&](OpOperand &use) {
          return use.getOwner()->getBlock() == peeledGenericOpBody;
        });
  }

  /// Before fixing up the residual operation, track what values are yielded. If
  /// any of those are from the peeled scalar operation, the uses of the
  /// corresponding result have to be remapped to the result of the generic op
  /// for the peeled operation.
  SmallVector<Value> replacements;
  for (const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) {
    OpResult opr = yieldValue.value().dyn_cast<OpResult>();
    if (!opr || opr.getOwner() != peeledScalarOperation)
      replacements.push_back(residualGenericOp.getResult(yieldValue.index()));
    else
      replacements.push_back(peeledGenericOp->getResult(yieldValue.index()));
  }

  /// Update all uses of the peeled scalar operation results in the residual op
  /// to the newly added arguments.
  {
    SmallVector<Value> scalarReplacements;
    unsigned peeledScalarOpNumResults = peeledScalarOperation->getNumResults();
    scalarReplacements.reserve(peeledScalarOpNumResults);
    for (auto num : llvm::seq<unsigned>(0, peeledScalarOpNumResults))
      scalarReplacements.push_back(
          residualGenericOpBody->getArgument(num + origNumInputs));
    bool allUsesReplaced = false;
    rewriter.replaceOpWithinBlock(peeledScalarOperation, scalarReplacements,
                                  residualGenericOpBody, &allUsesReplaced);
    assert(!allUsesReplaced &&
           "peeled scalar operation is erased when it wasn't expected to be");
  }

  // Replace the original operation.
  rewriter.replaceOp(genericOp, replacements);
  return success();
}


void mlir::linalg::populateDecomposeLinalgOpsPattern(
    RewritePatternSet &patterns) {
  patterns.insert<DecomposeLinalgOp>(patterns.getContext());
}
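
A pattern set populated this way is normally handed to a rewrite driver. The
sketch below is not part of the upstream file: `decomposeLinalgOps` is a
hypothetical helper name, and it assumes the greedy driver from
`mlir/Transforms/GreedyPatternRewriteDriver.h` is appropriate for the caller.

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

/// Hypothetical helper: collect the decomposition pattern and apply it
/// greedily to every linalg.generic nested under `op`.
static LogicalResult decomposeLinalgOps(Operation *op) {
  RewritePatternSet patterns(op->getContext());
  linalg::populateDecomposeLinalgOpsPattern(patterns);
  return applyPatternsAndFoldGreedily(op, std::move(patterns));
}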