MLIR  22.0.0git
EliminateEmptyTensors.cpp
Go to the documentation of this file.
1 //===- EliminateEmptyTensors.cpp - tensor.empty op elimination ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
16 
17 using namespace mlir;
18 using namespace mlir::bufferization;
19 using namespace mlir::linalg;
20 
21 /// Get an output operand that matches the given input operand and can be used
22 /// to eliminate a tensor.empty op.
23 static OpOperand *getUnusedOutOperand(LinalgOp op, OpOperand *in) {
24  for (OpOperand &operand : op.getDpsInitsMutable()) {
25  // Operand must be unused.
26  if (op.payloadUsesValueFromOperand(&operand))
27  continue;
28  // Types must match.
29  if (operand.get().getType() != in->get().getType())
30  continue;
31  // Indexing maps must match.
32  if (op.getMatchingIndexingMap(&operand) != op.getMatchingIndexingMap(in))
33  continue;
34  return &operand;
35  }
36  return nullptr;
37 }
38 
40  RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
    // NOTE(review): the first line of this definition (return type and
    // function name) is not present in this listing.
    //
    // Overview: walk every LinalgOp nested under `op`. For each ranked-tensor
    // DPS input whose reverse use-def chain (equivalent values only) reaches
    // tensor.empty ops of the same type, look for an unused init with matching
    // type and indexing map. If one exists and dominance permits, redirect all
    // uses of those tensor.empty ops to the init's current value, then make the
    // former input value the new init and let the payload read it through the
    // init's block argument. This function always returns success().
    //
    // The guard resets the rewriter's insertion point when it is destroyed.
41  OpBuilder::InsertionGuard g(rewriter);
    // One DominanceInfo is shared by the whole walk; see the comment before
    // the replaceAllUsesWith loop for why tensor.empty ops are not erased.
42  DominanceInfo domInfo;
43 
44  op->walk([&](LinalgOp op) {
45  // Only ops with all "parallel" iterator types are supported.
46  if (op.getNumParallelLoops() != op.getNumLoops())
47  return WalkResult::skip();
48 
49  for (OpOperand *in : op.getDpsInputOperands()) {
50  // Skip non-tensor operands.
51  if (!isa<RankedTensorType>(in->get().getType()))
52  continue;
53 
54  // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
55  // equivalent tensors. I.e., stop when there are ops such as extract_slice
56  // on the path.
57  TraversalConfig config;
58  config.followEquivalentOnly = true;
59  config.alwaysIncludeLeaves = false;
60  SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
61  in, /*condition=*/
    // A value matches when it is defined by a tensor.empty op and has exactly
    // the type of the input operand.
62  [&](Value val) {
63  return val.getDefiningOp<tensor::EmptyOp>() &&
64  val.getType() == in->get().getType();
65  },
66  config);
67  if (emptyTensors.empty())
68  continue;
69 
70  // Find matching out operand.
71  OpOperand *out = getUnusedOutOperand(op, in);
72  if (!out)
73  continue;
74 
75  // Check if this transform would violate dominance.
    // Uses of each tensor.empty are about to be redirected to the value
    // currently held by `out`, so that value must properly dominate the
    // defining op of every tensor.empty that was found.
76  if (!llvm::all_of(emptyTensors, [&](Value v) {
77  return domInfo.properlyDominates(out->get(), v.getDefiningOp());
78  }))
79  continue;
80 
81  // Replace all uses of the tensor.empty, but do not delete it yet. It will
82  // fold away later (to not invalidate DominanceInfo).
83  for (Value v : emptyTensors) {
84  assert(v.getDefiningOp<tensor::EmptyOp>() && "expected tensor.empty");
85  rewriter.replaceAllUsesWith(v, out->get());
86  }
87 
88  // Turn the "in" into an "out".
89  rewriter.modifyOpInPlace(op, [&]() {
    // Statement order matters here: `out` must read `in->get()` before `in`
    // is overwritten on the following statement.
90  out->set(in->get());
91  // The original "in" could be removed entirely here (because it will no
92  // longer have any uses in the payload), but we delegate this to
93  // existing cleanup patterns that remove unused operands.
94  in->set(emptyTensors.front());
    // Redirect the payload from the input's block argument to the init's
    // block argument. The init's argument is asserted to have no uses first.
95  BlockArgument outArg = op.getMatchingBlockArgument(out);
96  assert(outArg.getUses().empty() && "expected that out has no uses");
97  BlockArgument inArg = op.getMatchingBlockArgument(in);
98  rewriter.replaceAllUsesWith(inArg, outArg);
99  assert(!op.payloadUsesValueFromOperand(in) &&
100  "expected that the in operand is now unused");
101  });
102 
    // NOTE(review): presumably drops analysis results cached in `state` that
    // the IR changes above made stale; confirm against OneShotAnalysisState.
103  state.resetCache();
104  }
105 
106  return WalkResult::advance();
107  });
108  return success();
109 }
static OpOperand * getUnusedOutOperand(LinalgOp op, OpOperand *in)
Get an output operand that matches the given input operand and can be used to eliminate a tensor.empty op.
This class represents an argument of a Block.
Definition: Value.h:309
A class for computing basic dominance information.
Definition: Dominance.h:140
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
Definition: Dominance.cpp:323
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
void set(IRValueT newValue)
Set the current value being used by this operand.
Definition: UseDefLists.h:163
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class represents an operand of an operation.
Definition: Value.h:257
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:358
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:636
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:628
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
static WalkResult skip()
Definition: WalkResult.h:48
static WalkResult advance()
Definition: WalkResult.h:47
State for analysis-enabled bufferization.
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config