MLIR  19.0.0git
EraseUnusedOperandsAndResults.cpp
Go to the documentation of this file.
1 //===- EraseUnusedOperandsAndResults.cpp ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
12 
13 using namespace mlir;
14 using namespace mlir::linalg;
15 
16 /// Return `true` if the `result` of an operation `genericOp` is dead.
17 static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) {
18  if (!result.use_empty())
19  return false;
20  // If out operand not used in payload, we can drop it.
21  OpOperand *outputOpOperand =
22  genericOp.getDpsInitOperand(result.getResultNumber());
23  if (!genericOp.payloadUsesValueFromOperand(outputOpOperand))
24  return true;
25 
26  // The out operand that is part of a payload can be dropped if
27  // these conditions are met:
28  // - Result from out operand is dead.
29  // - User of arg is yield.
30  // - outArg data is not being used by other outArgs.
31 
32  // Check block arg and cycle from out operand has a single use.
33  BlockArgument outputArg =
34  genericOp.getRegionOutputArgs()[result.getResultNumber()];
35  if (!outputArg.hasOneUse())
36  return false;
37  Operation *argUserOp = *outputArg.user_begin();
38 
39  // Check argUser has no other use.
40  if (!argUserOp->use_empty())
41  return false;
42 
43  // Check that argUser is a yield.
44  auto yieldOp = dyn_cast<linalg::YieldOp>(argUserOp);
45  if (!yieldOp)
46  return false;
47 
48  // Check outArg data is not being used by other outArgs.
49  if (yieldOp.getOperand(result.getResultNumber()) != outputArg)
50  return false;
51 
52  return true;
53 }
54 
55 namespace {
56 
57 struct DeduplicateAndRemoveDeadOperandsAndResults
58  : public OpRewritePattern<GenericOp> {
59  DeduplicateAndRemoveDeadOperandsAndResults(MLIRContext *ctx,
60  bool removeOutputs)
61  : OpRewritePattern<GenericOp>(ctx), removeOutputs(removeOutputs) {}
62 
63  LogicalResult matchAndRewrite(GenericOp genericOp,
64  PatternRewriter &rewriter) const override {
65  // Create a map from argument position in the original op to the argument
66  // position in the new op. If the argument is dropped it wont have an entry.
67  SmallVector<OpOperand *> droppedOpOperands;
68 
69  // Information needed to build the new op.
70  SmallVector<Value> newInputOperands, newOutputOperands;
71  SmallVector<AffineMap> newIndexingMaps;
72 
73  // Gather information about duplicate input operands.
74  llvm::SmallDenseMap<unsigned, unsigned> origInsToNewInsPos =
75  deduplicateInputOperands(genericOp, droppedOpOperands, newInputOperands,
76  newIndexingMaps);
77 
78  // Gather information about the dropped outputs.
79  llvm::SmallDenseMap<unsigned, unsigned> origOutsToNewOutsPos =
80  deduplicateOutputOperands(genericOp, droppedOpOperands,
81  newOutputOperands, newIndexingMaps);
82 
83  // Check if there is any change to operands.
84  if (newInputOperands.size() + newOutputOperands.size() ==
85  genericOp->getNumOperands())
86  return failure();
87 
88  // Create the new op with the body being empty.
89  Location loc = genericOp.getLoc();
90  SmallVector<Type> newResultTypes;
91  for (Value v : newOutputOperands)
92  if (isa<TensorType>(v.getType()))
93  newResultTypes.push_back(v.getType());
94  auto newOp = rewriter.create<GenericOp>(
95  loc, newResultTypes, newInputOperands, newOutputOperands,
96  rewriter.getAffineMapArrayAttr(newIndexingMaps),
97  genericOp.getIteratorTypes(), genericOp.getDocAttr(),
98  genericOp.getLibraryCallAttr(),
99  [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
100  return;
101  });
102  // Copy over unknown attributes. They might be load bearing for some flow.
103  ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
104  for (NamedAttribute kv : genericOp->getAttrs())
105  if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
106  newOp->setAttr(kv.getName(), kv.getValue());
107 
108  // Fix up the payload of the canonicalized operation.
109  populateOpPayload(genericOp, newOp, origInsToNewInsPos,
110  origOutsToNewOutsPos, rewriter);
111 
112  // Replace all live uses of the op.
113  SmallVector<Value> replacementsVals(genericOp->getNumResults(), nullptr);
114  for (const auto &result : llvm::enumerate(genericOp.getResults())) {
115  auto it = origOutsToNewOutsPos.find(result.index());
116  if (it == origOutsToNewOutsPos.end())
117  continue;
118  replacementsVals[result.index()] = newOp.getResult(it->second);
119  }
120  rewriter.replaceOp(genericOp, replacementsVals);
121  return success();
122  }
123 
124 private:
125  /// If unset, outputs are not modified by this pattern.
126  bool removeOutputs;
127 
128  // Deduplicate input operands, and return the
129  // - Mapping from operand position in the original op, to operand position in
130  // the canonicalized op.
131  // - The preserved input operands list (by reference).
132  llvm::SmallDenseMap<unsigned, unsigned>
133  deduplicateInputOperands(GenericOp genericOp,
134  SmallVector<OpOperand *> &droppedOpOperands,
135  SmallVector<Value> &newInputOperands,
136  SmallVector<AffineMap> &newIndexingMaps) const {
137  llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
138  llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
139  for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
140  OpOperand *inputOpOperand = en.value();
141  // Check if operand is dead and if dropping the indexing map makes the
142  // loops to shape computation invalid.
143  if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
144  // Add the current operands to the list of potentially droppable
145  // operands. If it cannot be dropped, this needs to be popped back.
146  droppedOpOperands.push_back(inputOpOperand);
147  if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
148  continue;
149  droppedOpOperands.pop_back();
150  }
151 
152  // Check if this operand is a duplicate.
153  AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand);
154  auto it = dedupedInputs.find(
155  std::make_pair(inputOpOperand->get(), indexingMap));
156  if (it != dedupedInputs.end()) {
157  origToNewPos[en.index()] = it->second;
158  droppedOpOperands.push_back(inputOpOperand);
159  continue;
160  }
161 
162  // This is a preserved argument.
163  origToNewPos[en.index()] = newInputOperands.size();
164  dedupedInputs[{inputOpOperand->get(), indexingMap}] =
165  newInputOperands.size();
166  newInputOperands.push_back(inputOpOperand->get());
167  newIndexingMaps.push_back(indexingMap);
168  }
169  return origToNewPos;
170  }
171 
172  // Deduplicate output operands, and return the
173  // - Mapping from operand position in the original op, to operand position in
174  // the canonicalized op.
175  // - The preserved output operands list (by reference).
176  llvm::SmallDenseMap<unsigned, unsigned>
177  deduplicateOutputOperands(GenericOp genericOp,
178  SmallVector<OpOperand *> &droppedOpOperands,
179  SmallVector<Value> &newOutputOperands,
180  SmallVector<AffineMap> &newIndexingMaps) const {
181  llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
182  llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>
183  dedupedOutpts;
184  // If the op doesn't have tensor semantics or outputs should not be removed,
185  // keep all the outputs as preserved.
186  if (!genericOp.hasPureTensorSemantics() || !removeOutputs) {
187  for (const auto &en : llvm::enumerate(genericOp.getDpsInitsMutable())) {
188  origToNewPos[en.index()] = newOutputOperands.size();
189  newOutputOperands.push_back(en.value().get());
190  newIndexingMaps.push_back(
191  genericOp.getMatchingIndexingMap(&en.value()));
192  }
193  return origToNewPos;
194  }
195  // Output argument can be dropped if the result has
196  // - no users, and
197  // - it is not used in the payload, and
198  // - the corresponding indexing maps are not needed for loop bound
199  // computation.
200  auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
201  for (const auto &outputOpOperand :
202  llvm::enumerate(genericOp.getDpsInitsMutable())) {
203  OpResult result = genericOp.getTiedOpResult(&outputOpOperand.value());
204  AffineMap indexingMap =
205  genericOp.getMatchingIndexingMap(&outputOpOperand.value());
206  auto key = std::make_tuple(outputOpOperand.value().get(), indexingMap,
207  yieldOp->getOperand(outputOpOperand.index()));
208  if (isResultValueDead(genericOp, result)) {
209  // Check if the opoperand can be dropped without affecting loop
210  // bound computation. Add the operand to the list of dropped op
211  // operand for checking. If it cannot be dropped, need to pop the
212  // value back.
213  droppedOpOperands.push_back(&outputOpOperand.value());
214  if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
215  continue;
216  }
217  droppedOpOperands.pop_back();
218  }
219 
220  if (!genericOp.payloadUsesValueFromOperand(&outputOpOperand.value())) {
221  // The out operand can also be dropped if it is computed redundantly
222  // by another result, the conditions for that are
223  // - The same operand is used as the out operand
224  // - The same indexing map is used
225  // - The same yield value is used.
226  auto it = dedupedOutpts.find(key);
227  if (it != dedupedOutpts.end()) {
228  origToNewPos[outputOpOperand.index()] = it->second;
229  droppedOpOperands.push_back(&outputOpOperand.value());
230  continue;
231  }
232  }
233 
234  origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
235  dedupedOutpts[key] = newOutputOperands.size();
236  newOutputOperands.push_back(outputOpOperand.value().get());
237  newIndexingMaps.push_back(
238  genericOp.getMatchingIndexingMap(&outputOpOperand.value()));
239  }
240  return origToNewPos;
241  }
242 
243  // Populate the body of the canonicalized operation.
244  void populateOpPayload(
245  GenericOp genericOp, GenericOp newOp,
246  const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
247  const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
248  PatternRewriter &rewriter) const {
249  // Merge the body of the original op with the new op.
250  Block *newOpBlock = &newOp.getRegion().front();
251  assert(newOpBlock->empty() && "expected new op to have an empty payload");
252  Block *origOpBlock = &genericOp.getRegion().front();
253  SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);
254 
255  // Replace all arguments in the original op, with arguments from the
256  // canonicalized op.
257  auto updateReplacements =
258  [&](SmallVector<OpOperand *> &origOperands,
259  SmallVector<OpOperand *> &newOperands,
260  const llvm::SmallDenseMap<unsigned, unsigned> &map) {
261  for (const auto &origOperand : llvm::enumerate(origOperands)) {
262  auto it = map.find(origOperand.index());
263  if (it == map.end())
264  continue;
265  OpOperand *newOperand = newOperands[it->second];
266  replacements[origOperand.value()->getOperandNumber()] =
267  newOpBlock->getArgument(newOperand->getOperandNumber());
268  }
269  };
270 
271  SmallVector<OpOperand *> origInputOperands =
272  genericOp.getDpsInputOperands();
273  SmallVector<OpOperand *> newInputOperands = newOp.getDpsInputOperands();
274  updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
275 
276  SmallVector<OpOperand *> origOutputOperands =
277  llvm::to_vector(llvm::map_range(genericOp.getDpsInitsMutable(),
278  [](OpOperand &o) { return &o; }));
279  SmallVector<OpOperand *> newOutputOperands =
280  llvm::to_vector(llvm::map_range(newOp.getDpsInitsMutable(),
281  [](OpOperand &o) { return &o; }));
282  updateReplacements(origOutputOperands, newOutputOperands,
283  origOutsToNewOutsPos);
284 
285  // Drop the unused yield args.
286  if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) {
287  OpBuilder::InsertionGuard g(rewriter);
288  YieldOp origYieldOp = cast<YieldOp>(origOpBlock->getTerminator());
289  rewriter.setInsertionPoint(origYieldOp);
290 
291  SmallVector<Value> newYieldVals(newOp.getNumDpsInits(), nullptr);
292  for (const auto &yieldOpOperands :
293  llvm::enumerate(origYieldOp.getValues())) {
294  auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
295  if (it == origOutsToNewOutsPos.end())
296  continue;
297  newYieldVals[it->second] = yieldOpOperands.value();
298  }
299  rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
300  }
301 
302  rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);
303  }
304 };
305 
306 /// Remove unused cycles.
307 /// We can remove unused cycle within a payload of generic region
308 /// if these conditions are met:
309 /// - Result from out operand is dead.
310 /// - Block arg from out operand has a single use in the %cycle
311 /// instruction.
312 /// - Cycle has a single use and it is in yield.
313 struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
315 
316  LogicalResult matchAndRewrite(GenericOp genericOp,
317  PatternRewriter &rewriter) const override {
318 
319  // If the op doesnt have tensor semantics, preserve the outputs as is.
320  if (!genericOp.hasPureTensorSemantics())
321  return failure();
322 
323  bool hasRemovedCycles = false;
324  // Iterate over output operands and remove any unused cycles.
325  for (const auto &outputOpOperand :
326  llvm::enumerate(genericOp.getDpsInits())) {
327 
328  // Check that result from out operand is dead.
329  Value result = genericOp.getResult(outputOpOperand.index());
330  if (!result.use_empty())
331  continue;
332 
333  // Check that outputArg has one use in cycle.
334  BlockArgument outputArg =
335  genericOp.getRegionOutputArgs()[outputOpOperand.index()];
336  if (!outputArg.hasOneUse())
337  continue;
338 
339  // Check cycle has at most one use.
340  Operation *cycleOp = *outputArg.user_begin();
341  if (!cycleOp->hasOneUse())
342  continue;
343 
344  // Check that the cycleUser is a yield.
345  Operation *cycleUserOp = *cycleOp->user_begin();
346  if (!isa<linalg::YieldOp>(cycleUserOp))
347  continue;
348 
349  // Check that argIndex matches yieldIndex, else data is being used.
350  if (cycleUserOp->getOperand(outputOpOperand.index()) !=
351  cycleOp->getResult(0))
352  continue;
353 
354  // Directly replace the cycle with the blockArg such that
355  // Deduplicate pattern can eliminate it along with unused yield.
356  rewriter.replaceOp(cycleOp, outputArg);
357  rewriter.modifyOpInPlace(genericOp, [] {});
358  hasRemovedCycles = true;
359  }
360 
361  if (hasRemovedCycles) {
362  return success();
363  }
364 
365  return failure();
366  }
367 };
368 
369 /// Fold uses of duplicate inputs in the body of a linalg.generic. E.g.:
370 /// ```
371 /// linalg.generic ins(%a, %b, %a, %b) outs(%a)
372 /// ^bb0(%in0, %in1, %in2, %in3, %out1)
373 /// ```
374 /// Assuming that all %a and %b have the same index map:
375 /// * All uses of %in0 and %in2 are replaced with %out1
376 /// * All uses of %in1 are replaced with %in3
377 /// This pattern can enable additional canonicalizations: In the above example,
378 /// %in0, %in1 and %in3 have no uses anymore and their corresponding operands
379 /// can be folded away. This pattern does not modify uses of output block args.
380 struct FoldDuplicateInputBbArgs : public OpRewritePattern<GenericOp> {
382 
383  LogicalResult matchAndRewrite(GenericOp genericOp,
384  PatternRewriter &rewriter) const override {
385  // Find replacement bbArgs for all input bbArg.
386  DenseMap<int, int> replacements;
387  for (int i = 0; i < genericOp.getNumDpsInputs(); ++i) {
388  // Skip bbArgs that have no uses.
389  if (genericOp.getBody()->getArgument(i).getUses().empty())
390  continue;
391  // Find replacement bbArg. This can be an input or an output bbArg.
392  for (int j = genericOp->getNumOperands() - 1; j > i; --j) {
393  if (genericOp->getOperand(i) == genericOp->getOperand(j) &&
394  genericOp.getIndexingMapsArray()[i] ==
395  genericOp.getIndexingMapsArray()[j]) {
396  replacements[i] = j;
397  break;
398  }
399  }
400  }
401 
402  // Stop here if no replacements were found.
403  if (replacements.empty())
404  return failure();
405 
406  // Rewrite the op.
407  rewriter.modifyOpInPlace(genericOp, [&]() {
408  for (auto [before, after] : replacements) {
409  BlockArgument bbArg = genericOp.getBody()->getArgument(before);
410  BlockArgument replacement = genericOp.getBody()->getArgument(after);
411  rewriter.replaceAllUsesWith(bbArg, replacement);
412  }
413  });
414 
415  return success();
416  }
417 };
418 
419 } // namespace
420 
422  RewritePatternSet &patterns) {
423  patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults>(
424  patterns.getContext(), /*removeOutputs=*/true);
425  patterns.insert<RemoveUnusedCycleInGenericOp>(patterns.getContext());
426 }
427 
429  RewritePatternSet &patterns) {
430  patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults>(
431  patterns.getContext(), /*removeOutputs=*/false);
432  patterns.insert<FoldDuplicateInputBbArgs>(patterns.getContext());
433 }
static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result)
Return true if the result of an operation genericOp is dead.
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
Definition: AffineMap.h:47
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:30
bool empty()
Definition: Block.h:145
BlockArgument getArgument(unsigned i)
Definition: Block.h:126
unsigned getNumArguments()
Definition: Block.h:125
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
Operation & front()
Definition: Block.h:150
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:325
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:202
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:848
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition: Operation.h:845
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
user_iterator user_begin()
Definition: Operation.h:865
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:930
MLIRContext * getContext() const
Definition: PatternMatch.h:822
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:218
user_iterator user_begin() const
Definition: Value.h:226
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.