MLIR  19.0.0git
EliminateEmptyTensors.cpp
Go to the documentation of this file.
1 //===- EliminateEmptyTensors.cpp - tensor.empty op elimination ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
17 
18 using namespace mlir;
19 using namespace mlir::bufferization;
20 using namespace mlir::linalg;
21 
22 /// Get an output operand that matches the given input operand and can be used
23 /// to eliminate a tensor.empty op.
24 static OpOperand *getUnusedOutOperand(LinalgOp op, OpOperand *in) {
25  for (OpOperand &operand : op.getDpsInitsMutable()) {
26  // Operand must be unused.
27  if (op.payloadUsesValueFromOperand(&operand))
28  continue;
29  // Types must match.
30  if (operand.get().getType() != in->get().getType())
31  continue;
32  // Indexing maps must match.
33  if (op.getMatchingIndexingMap(&operand) != op.getMatchingIndexingMap(in))
34  continue;
35  return &operand;
36  }
37  return nullptr;
38 }
39 
// NOTE(review): the opening line of this signature (return type and function
// name) is not part of this listing; the reference section of this page gives
// the name as linalgOpAnchoredEmptyTensorEliminationStep.
//
// Walks `op` and visits every LinalgOp. For each ranked-tensor input of such an
// op whose reverse use-def chain (equivalent tensors only) leads to tensor.empty
// ops of the same type, and for which getUnusedOutOperand finds a matching init
// operand, all uses of those tensor.empty ops are replaced with the init's
// value and the former input value is installed as the init. The only return
// statement of the function itself is `return success();`.
41  RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
// RAII guard: the rewriter's insertion point is reset when `g` is destroyed.
42  OpBuilder::InsertionGuard g(rewriter);
// Used for the dominance check below. The tensor.empty ops are deliberately not
// erased in this function so that this object is not invalidated.
43  DominanceInfo domInfo;
44 
// Note: the callback parameter `op` shadows the function parameter `op`; inside
// the callback, `op` is the LinalgOp currently being visited.
45  op->walk([&](LinalgOp op) {
46  // Only ops with all "parallel" iterator types are supported.
47  if (op.getNumParallelLoops() != op.getNumLoops())
// NOTE(review): whether skip() behaves differently from advance() here depends
// on the traversal order of walk(), which this listing does not show.
48  return WalkResult::skip();
49 
50  for (OpOperand *in : op.getDpsInputOperands()) {
51  // Skip non-tensor operands.
52  if (!isa<RankedTensorType>(in->get().getType()))
53  continue;
54 
55  // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
56  // equivalent tensors. I.e., stop when there are ops such as extract_slice
57  // on the path.
58  TraversalConfig config;
59  config.followEquivalentOnly = true;
60  config.alwaysIncludeLeaves = false;
// The condition accepts a value only if it is produced by a tensor.empty op and
// has exactly the type of the input operand's value.
61  SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
62  in->get(), /*condition=*/
63  [&](Value val) {
64  return val.getDefiningOp<tensor::EmptyOp>() &&
65  val.getType() == in->get().getType();
66  },
67  config);
68  if (emptyTensors.empty())
69  continue;
70 
71  // Find matching out operand.
72  OpOperand *out = getUnusedOutOperand(op, in);
73  if (!out)
74  continue;
75 
76  // Check if this transform would violate dominance.
// The value of `out` is about to replace every use of the tensor.empty ops, so
// it must properly dominate the defining op of each of them.
77  if (!llvm::all_of(emptyTensors, [&](Value v) {
78  return domInfo.properlyDominates(out->get(), v.getDefiningOp());
79  }))
80  continue;
81 
82  // Replace all uses of the tensor.empty, but do not delete it yet. It will
83  // fold away later (to not invalidate DominanceInfo).
84  for (Value v : emptyTensors) {
85  assert(v.getDefiningOp<tensor::EmptyOp>() && "expected tensor.empty");
86  rewriter.replaceAllUsesWith(v, out->get());
87  }
88 
89  // Turn the "in" into an "out".
90  rewriter.modifyOpInPlace(op, [&]() {
// The value that used to be the input becomes the init operand.
91  out->set(in->get());
92  // The original "in" could be removed entirely here (because it will no
93  // longer have any uses in the payload), but we delegate this to
94  // existing cleanup patterns that remove unused operands.
// The input operand slot is filled with the first tensor.empty that was found;
// its previous uses were already replaced in the loop above.
95  in->set(emptyTensors.front());
96  BlockArgument outArg = op.getMatchingBlockArgument(out);
97  assert(outArg.getUses().empty() && "expected that out has no uses");
98  BlockArgument inArg = op.getMatchingBlockArgument(in);
// Redirect the payload from the input's block argument to the init's block
// argument, so that the payload now reads the former input via the init.
99  rewriter.replaceAllUsesWith(inArg, outArg);
100  assert(!op.payloadUsesValueFromOperand(in) &&
101  "expected that the in operand is now unused");
102  });
103 
// NOTE(review): presumably drops cached analysis results that the IR changes
// above made stale; confirm against OneShotAnalysisState.
104  state.resetCache();
105  }
106 
107  return WalkResult::advance();
108  });
109  return success();
110 }
static OpOperand * getUnusedOutOperand(LinalgOp op, OpOperand *in)
Get an output operand that matches the given input operand and can be used to eliminate a tensor.empty op.
This class represents an argument of a Block.
Definition: Value.h:319
A class for computing basic dominance information.
Definition: Dominance.h:136
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
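Return true if operation A properly dominates operation B, i.e. if A and B are in the same block and A properly dominates B within the block, or if the block that contains A properly dominates the block that contains B.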
Definition: Dominance.h:149
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
void set(IRValueT newValue)
Set the current value being used by this operand.
Definition: UseDefLists.h:163
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
Definition: Operation.h:793
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:400
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult skip()
Definition: Visitors.h:53
static WalkResult advance()
Definition: Visitors.h:52
State for analysis-enabled bufferization.
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26