MLIR  18.0.0git
EraseUnusedOperandsAndResults.cpp
Go to the documentation of this file.
1 //===- EraseUnusedOperandsAndResults.cpp ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
12 
13 using namespace mlir;
14 using namespace mlir::linalg;
15 
16 /// Return `true` if the `result` of an operation `genericOp` is dead.
17 static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) {
18  if (!result.use_empty())
19  return false;
20  // If out operand not used in payload, we can drop it.
21  OpOperand *outputOpOperand =
22  genericOp.getDpsInitOperand(result.getResultNumber());
23  if (!genericOp.payloadUsesValueFromOperand(outputOpOperand))
24  return true;
25 
26  // The out operand that is part of a payload can be dropped if
27  // these conditions are met:
28  // - Result from out operand is dead.
29  // - User of arg is yield.
30  // - outArg data is not being used by other outArgs.
31 
32  // Check block arg and cycle from out operand has a single use.
33  BlockArgument outputArg =
34  genericOp.getRegionOutputArgs()[result.getResultNumber()];
35  if (!outputArg.hasOneUse())
36  return false;
37  Operation *argUserOp = *outputArg.user_begin();
38 
39  // Check argUser has no other use.
40  if (!argUserOp->use_empty())
41  return false;
42 
43  // Check that argUser is a yield.
44  auto yieldOp = dyn_cast<linalg::YieldOp>(argUserOp);
45  if (!yieldOp)
46  return false;
47 
48  // Check outArg data is not being used by other outArgs.
49  if (yieldOp.getOperand(result.getResultNumber()) != outputArg)
50  return false;
51 
52  return true;
53 }
54 
55 namespace {
56 
57 struct DeduplicateAndRemoveDeadOperandsAndResults
58  : public OpRewritePattern<GenericOp> {
59  DeduplicateAndRemoveDeadOperandsAndResults(MLIRContext *ctx,
60  bool removeOutputs)
61  : OpRewritePattern<GenericOp>(ctx), removeOutputs(removeOutputs) {}
62 
63  LogicalResult matchAndRewrite(GenericOp genericOp,
64  PatternRewriter &rewriter) const override {
65  // Create a map from argument position in the original op to the argument
66  // position in the new op. If the argument is dropped it wont have an entry.
67  SmallVector<OpOperand *> droppedOpOperands;
68 
69  // Information needed to build the new op.
70  SmallVector<Value> newInputOperands, newOutputOperands;
71  SmallVector<AffineMap> newIndexingMaps;
72 
73  // Gather information about duplicate input operands.
74  llvm::SmallDenseMap<unsigned, unsigned> origInsToNewInsPos =
75  deduplicateInputOperands(genericOp, droppedOpOperands, newInputOperands,
76  newIndexingMaps);
77 
78  // Gather information about the dropped outputs.
79  llvm::SmallDenseMap<unsigned, unsigned> origOutsToNewOutsPos =
80  deduplicateOutputOperands(genericOp, droppedOpOperands,
81  newOutputOperands, newIndexingMaps);
82 
83  // Check if there is any change to operands.
84  if (newInputOperands.size() + newOutputOperands.size() ==
85  genericOp->getNumOperands())
86  return failure();
87 
88  // Create the new op with the body being empty.
89  Location loc = genericOp.getLoc();
90  SmallVector<Type> newResultTypes;
91  for (Value v : newOutputOperands)
92  if (isa<TensorType>(v.getType()))
93  newResultTypes.push_back(v.getType());
94  auto newOp = rewriter.create<GenericOp>(
95  loc, newResultTypes, newInputOperands, newOutputOperands,
96  rewriter.getAffineMapArrayAttr(newIndexingMaps),
97  genericOp.getIteratorTypes(), genericOp.getDocAttr(),
98  genericOp.getLibraryCallAttr(),
99  [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
100  return;
101  });
102  // Copy over unknown attributes. They might be load bearing for some flow.
103  ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
104  for (NamedAttribute kv : genericOp->getAttrs())
105  if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
106  newOp->setAttr(kv.getName(), kv.getValue());
107 
108  // Fix up the payload of the canonicalized operation.
109  populateOpPayload(genericOp, newOp, origInsToNewInsPos,
110  origOutsToNewOutsPos, rewriter);
111 
112  // Replace all live uses of the op.
113  SmallVector<Value> replacementsVals(genericOp->getNumResults(), nullptr);
114  for (const auto &result : llvm::enumerate(genericOp.getResults())) {
115  auto it = origOutsToNewOutsPos.find(result.index());
116  if (it == origOutsToNewOutsPos.end())
117  continue;
118  replacementsVals[result.index()] = newOp.getResult(it->second);
119  }
120  rewriter.replaceOp(genericOp, replacementsVals);
121  return success();
122  }
123 
124 private:
125  /// If unset, outputs are not modified by this pattern.
126  bool removeOutputs;
127 
128  // Deduplicate input operands, and return the
129  // - Mapping from operand position in the original op, to operand position in
130  // the canonicalized op.
131  // - The preserved input operands list (by reference).
132  llvm::SmallDenseMap<unsigned, unsigned>
133  deduplicateInputOperands(GenericOp genericOp,
134  SmallVector<OpOperand *> &droppedOpOperands,
135  SmallVector<Value> &newInputOperands,
136  SmallVector<AffineMap> &newIndexingMaps) const {
137  llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
138  llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
139  for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
140  OpOperand *inputOpOperand = en.value();
141  // Check if operand is dead and if dropping the indexing map makes the
142  // loops to shape computation invalid.
143  if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
144  // Add the current operands to the list of potentially droppable
145  // operands. If it cannot be dropped, this needs to be popped back.
146  droppedOpOperands.push_back(inputOpOperand);
147  if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
148  continue;
149  droppedOpOperands.pop_back();
150  }
151 
152  // Check if this operand is a duplicate.
153  AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand);
154  auto it = dedupedInputs.find(
155  std::make_pair(inputOpOperand->get(), indexingMap));
156  if (it != dedupedInputs.end()) {
157  origToNewPos[en.index()] = it->second;
158  droppedOpOperands.push_back(inputOpOperand);
159  continue;
160  }
161 
162  // This is a preserved argument.
163  origToNewPos[en.index()] = newInputOperands.size();
164  dedupedInputs[{inputOpOperand->get(), indexingMap}] =
165  newInputOperands.size();
166  newInputOperands.push_back(inputOpOperand->get());
167  newIndexingMaps.push_back(indexingMap);
168  }
169  return origToNewPos;
170  }
171 
172  // Deduplicate output operands, and return the
173  // - Mapping from operand position in the original op, to operand position in
174  // the canonicalized op.
175  // - The preserved output operands list (by reference).
176  llvm::SmallDenseMap<unsigned, unsigned>
177  deduplicateOutputOperands(GenericOp genericOp,
178  SmallVector<OpOperand *> &droppedOpOperands,
179  SmallVector<Value> &newOutputOperands,
180  SmallVector<AffineMap> &newIndexingMaps) const {
181  llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
182  llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>
183  dedupedOutpts;
184  // If the op doesn't have tensor semantics or outputs should not be removed,
185  // keep all the outputs as preserved.
186  if (!genericOp.hasTensorSemantics() || !removeOutputs) {
187  for (const auto &en : llvm::enumerate(genericOp.getDpsInitsMutable())) {
188  origToNewPos[en.index()] = newOutputOperands.size();
189  newOutputOperands.push_back(en.value().get());
190  newIndexingMaps.push_back(
191  genericOp.getMatchingIndexingMap(&en.value()));
192  }
193  return origToNewPos;
194  }
195  // Output argument can be dropped if the result has
196  // - no users, and
197  // - it is not used in the payload, and
198  // - the corresponding indexing maps are not needed for loop bound
199  // computation.
200  auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
201  for (const auto &outputOpOperand :
202  llvm::enumerate(genericOp.getDpsInitsMutable())) {
203  OpResult result = genericOp.getTiedOpResult(&outputOpOperand.value());
204  AffineMap indexingMap =
205  genericOp.getMatchingIndexingMap(&outputOpOperand.value());
206  auto key = std::make_tuple(outputOpOperand.value().get(), indexingMap,
207  yieldOp->getOperand(outputOpOperand.index()));
208  if (isResultValueDead(genericOp, result)) {
209  // Check if the opoperand can be dropped without affecting loop
210  // bound computation. Add the operand to the list of dropped op
211  // operand for checking. If it cannot be dropped, need to pop the
212  // value back.
213  droppedOpOperands.push_back(&outputOpOperand.value());
214  if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
215  continue;
216  }
217  droppedOpOperands.pop_back();
218  }
219 
220  if (!genericOp.payloadUsesValueFromOperand(&outputOpOperand.value())) {
221  // The out operand can also be dropped if it is computed redundantly
222  // by another result, the conditions for that are
223  // - The same operand is used as the out operand
224  // - The same indexing map is used
225  // - The same yield value is used.
226  auto it = dedupedOutpts.find(key);
227  if (it != dedupedOutpts.end()) {
228  origToNewPos[outputOpOperand.index()] = it->second;
229  droppedOpOperands.push_back(&outputOpOperand.value());
230  continue;
231  }
232  }
233 
234  origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
235  dedupedOutpts[key] = newOutputOperands.size();
236  newOutputOperands.push_back(outputOpOperand.value().get());
237  newIndexingMaps.push_back(
238  genericOp.getMatchingIndexingMap(&outputOpOperand.value()));
239  }
240  return origToNewPos;
241  }
242 
243  // Populate the body of the canonicalized operation.
244  void populateOpPayload(
245  GenericOp genericOp, GenericOp newOp,
246  const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
247  const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
248  PatternRewriter &rewriter) const {
249  // Merge the body of the original op with the new op.
250  Block *newOpBlock = &newOp.getRegion().front();
251  assert(newOpBlock->empty() && "expected new op to have an empty payload");
252  Block *origOpBlock = &genericOp.getRegion().front();
253  SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);
254 
255  // Replace all arguments in the original op, with arguments from the
256  // canonicalized op.
257  auto updateReplacements =
258  [&](SmallVector<OpOperand *> &origOperands,
259  SmallVector<OpOperand *> &newOperands,
260  const llvm::SmallDenseMap<unsigned, unsigned> &map) {
261  for (const auto &origOperand : llvm::enumerate(origOperands)) {
262  auto it = map.find(origOperand.index());
263  if (it == map.end())
264  continue;
265  OpOperand *newOperand = newOperands[it->second];
266  replacements[origOperand.value()->getOperandNumber()] =
267  newOpBlock->getArgument(newOperand->getOperandNumber());
268  }
269  };
270 
271  SmallVector<OpOperand *> origInputOperands =
272  genericOp.getDpsInputOperands();
273  SmallVector<OpOperand *> newInputOperands = newOp.getDpsInputOperands();
274  updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
275 
276  SmallVector<OpOperand *> origOutputOperands =
277  llvm::to_vector(llvm::map_range(genericOp.getDpsInitsMutable(),
278  [](OpOperand &o) { return &o; }));
279  SmallVector<OpOperand *> newOutputOperands =
280  llvm::to_vector(llvm::map_range(newOp.getDpsInitsMutable(),
281  [](OpOperand &o) { return &o; }));
282  updateReplacements(origOutputOperands, newOutputOperands,
283  origOutsToNewOutsPos);
284 
285  // Drop the unused yield args.
286  if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) {
287  OpBuilder::InsertionGuard g(rewriter);
288  YieldOp origYieldOp = cast<YieldOp>(origOpBlock->getTerminator());
289  rewriter.setInsertionPoint(origYieldOp);
290 
291  SmallVector<Value> newYieldVals(newOp.getNumDpsInits(), nullptr);
292  for (const auto &yieldOpOperands :
293  llvm::enumerate(origYieldOp.getValues())) {
294  auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
295  if (it == origOutsToNewOutsPos.end())
296  continue;
297  newYieldVals[it->second] = yieldOpOperands.value();
298  }
299  rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
300  }
301 
302  rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);
303  }
304 };
305 
306 /// Remove unused cycles.
307 /// We can remove unused cycle within a payload of generic region
308 /// if these conditions are met:
309 /// - Result from out operand is dead.
310 /// - Block arg from out operand has a single use in the %cycle
311 /// instruction.
312 /// - Cycle has a single use and it is in yield.
313 struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
315 
316  LogicalResult matchAndRewrite(GenericOp genericOp,
317  PatternRewriter &rewriter) const override {
318 
319  // If the op doesnt have tensor semantics, preserve the outputs as is.
320  if (!genericOp.hasTensorSemantics())
321  return failure();
322 
323  bool hasRemovedCycles = false;
324  // Iterate over output operands and remove any unused cycles.
325  for (const auto &outputOpOperand :
326  llvm::enumerate(genericOp.getDpsInits())) {
327 
328  // Check that result from out operand is dead.
329  Value result = genericOp.getResult(outputOpOperand.index());
330  if (!result.use_empty())
331  continue;
332 
333  // Check that outputArg has one use in cycle.
334  BlockArgument outputArg =
335  genericOp.getRegionOutputArgs()[outputOpOperand.index()];
336  if (!outputArg.hasOneUse())
337  continue;
338 
339  // Check cycle has at most one use.
340  Operation *cycleOp = *outputArg.user_begin();
341  if (!cycleOp->hasOneUse())
342  continue;
343 
344  // Check that the cycleUser is a yield.
345  Operation *cycleUserOp = *cycleOp->user_begin();
346  if (!isa<linalg::YieldOp>(cycleUserOp))
347  continue;
348 
349  // Check that argIndex matches yieldIndex, else data is being used.
350  if (cycleUserOp->getOperand(outputOpOperand.index()) !=
351  cycleOp->getResult(0))
352  continue;
353 
354  // Directly replace the cycle with the blockArg such that
355  // Deduplicate pattern can eliminate it along with unused yield.
356  rewriter.replaceOp(cycleOp, outputArg);
357  rewriter.updateRootInPlace(genericOp, [] {});
358  hasRemovedCycles = true;
359  }
360 
361  if (hasRemovedCycles) {
362  return success();
363  }
364 
365  return failure();
366  }
367 };
368 
369 /// Fold uses of duplicate inputs in the body of a linalg.generic. E.g.:
370 /// ```
371 /// linalg.generic ins(%a, %b, %a, %b) outs(%a)
372 /// ^bb0(%in0, %in1, %in2, %in3, %out1)
373 /// ```
374 /// Assuming that all %a and %b have the same index map:
375 /// * All uses of %in0 and %in2 are replaced with %out1
376 /// * All uses of %in1 are replaced with %in3
377 /// This pattern can enable additional canonicalizations: In the above example,
378 /// %in0, %in1 and %in3 have no uses anymore and their corresponding operands
379 /// can be folded away. This pattern does not modify uses of output block args.
380 struct FoldDuplicateInputBbArgs : public OpRewritePattern<GenericOp> {
382 
383  LogicalResult matchAndRewrite(GenericOp genericOp,
384  PatternRewriter &rewriter) const override {
385  // Find replacement bbArgs for all input bbArg.
386  DenseMap<int, int> replacements;
387  for (int i = 0; i < genericOp.getNumDpsInputs(); ++i) {
388  // Skip bbArgs that have no uses.
389  if (genericOp.getBody()->getArgument(i).getUses().empty())
390  continue;
391  // Find replacement bbArg. This can be an input or an output bbArg.
392  for (int j = genericOp->getNumOperands() - 1; j > i; --j) {
393  if (genericOp->getOperand(i) == genericOp->getOperand(j) &&
394  genericOp.getIndexingMapsArray()[i] ==
395  genericOp.getIndexingMapsArray()[j]) {
396  replacements[i] = j;
397  break;
398  }
399  }
400  }
401 
402  // Stop here if no replacements were found.
403  if (replacements.empty())
404  return failure();
405 
406  // Rewrite the op.
407  rewriter.updateRootInPlace(genericOp, [&]() {
408  for (auto [before, after] : replacements) {
409  BlockArgument bbArg = genericOp.getBody()->getArgument(before);
410  BlockArgument replacement = genericOp.getBody()->getArgument(after);
411  rewriter.replaceAllUsesWith(bbArg, replacement);
412  }
413  });
414 
415  return success();
416  }
417 };
418 
419 } // namespace
420 
422  RewritePatternSet &patterns) {
423  patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults>(
424  patterns.getContext(), /*removeOutputs=*/true);
425  patterns.insert<RemoveUnusedCycleInGenericOp>(patterns.getContext());
426 }
427 
429  RewritePatternSet &patterns) {
430  patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults>(
431  patterns.getContext(), /*removeOutputs=*/false);
432  patterns.insert<FoldDuplicateInputBbArgs>(patterns.getContext());
433 }
static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result)
Return true if the result of an operation genericOp is dead.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:47
This class represents an argument of a Block.
Definition: Value.h:315
Block represents an ordered list of Operations.
Definition: Block.h:30
bool empty()
Definition: Block.h:141
BlockArgument getArgument(unsigned i)
Definition: Block.h:122
unsigned getNumArguments()
Definition: Block.h:121
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:238
Operation & front()
Definition: Block.h:146
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:325
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:198
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
This class helps build Operations.
Definition: Builders.h:206
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:383
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
This class represents an operand of an operation.
Definition: Value.h:263
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:453
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:465
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:831
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition: Operation.h:828
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
user_iterator user_begin()
Definition: Operation.h:848
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:606
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:615
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:539
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:214
user_iterator user_begin() const
Definition: Value.h:222
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:211
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.