MLIR  18.0.0git
EliminateEmptyTensors.cpp
Go to the documentation of this file.
1 //===- EmptyTensorElimination.cpp - tensor.empty op elimination -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
17 
18 using namespace mlir;
19 using namespace mlir::bufferization;
20 using namespace mlir::linalg;
21 
22 /// Get an output operand that matches the given input operand and can be used
23 /// to eliminate a tensor.empty op.
24 static OpOperand *getUnusedOutOperand(LinalgOp op, OpOperand *in) {
25  for (OpOperand &operand : op.getDpsInitsMutable()) {
26  // Operand must be unused.
27  if (op.payloadUsesValueFromOperand(&operand))
28  continue;
29  // Types must match.
30  if (operand.get().getType() != in->get().getType())
31  continue;
32  // Indexing maps must match.
33  if (op.getMatchingIndexingMap(&operand) != op.getMatchingIndexingMap(in))
34  continue;
35  return &operand;
36  }
37  return nullptr;
38 }
39 
41  RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
42  OpBuilder::InsertionGuard g(rewriter);
43  DominanceInfo domInfo;
44 
45  op->walk([&](LinalgOp op) {
46  // Only ops with all "parallel" iterator types are supported.
47  if (op.getNumParallelLoops() != op.getNumLoops())
48  return WalkResult::skip();
49 
50  for (OpOperand *in : op.getDpsInputOperands()) {
51  // Skip non-tensor operands.
52  if (!in->get().getType().isa<RankedTensorType>())
53  continue;
54 
55  // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
56  // equivalent tensors. I.e., stop when there are ops such as extract_slice
57  // on the path.
58  TraversalConfig config;
59  config.followEquivalentOnly = true;
60  config.alwaysIncludeLeaves = false;
61  SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
62  in->get(), /*condition=*/
63  [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
64  config);
65  if (emptyTensors.empty())
66  continue;
67 
68  // Find matching out operand.
69  OpOperand *out = getUnusedOutOperand(op, in);
70  if (!out)
71  continue;
72 
73  // Check if this transform would violate dominance.
74  if (!llvm::all_of(emptyTensors, [&](Value v) {
75  return domInfo.properlyDominates(out->get(), v.getDefiningOp());
76  }))
77  continue;
78 
79  // Replace all uses of the tensor.empty, but do not delete it yet. It will
80  // fold away later (to not invalidate DominanceInfo).
81  for (Value v : emptyTensors) {
82  assert(v.getDefiningOp<tensor::EmptyOp>() && "expected tensor.empty");
83  rewriter.replaceAllUsesWith(v, out->get());
84  }
85 
86  // Turn the "in" into an "out".
87  rewriter.updateRootInPlace(op, [&]() {
88  out->set(in->get());
89  // The original "in" could be removed entirely here (because it will no
90  // longer have any uses in the payload), but we delegate this to
91  // existing cleanup patterns that remove unused operands.
92  in->set(emptyTensors.front());
93  BlockArgument outArg = op.getMatchingBlockArgument(out);
94  assert(outArg.getUses().empty() && "expected that out has no uses");
95  BlockArgument inArg = op.getMatchingBlockArgument(in);
96  rewriter.replaceAllUsesWith(inArg, outArg);
97  assert(!op.payloadUsesValueFromOperand(in) &&
98  "expected that the in operand is now unused");
99  });
100 
101  state.resetCache();
102  }
103 
104  return WalkResult::advance();
105  });
106  return success();
107 }
static OpOperand * getUnusedOutOperand(LinalgOp op, OpOperand *in)
Get an output operand that matches the given input operand and can be used to eliminate a tensor.empty op.
This class represents an argument of a Block.
Definition: Value.h:315
A class for computing basic dominance information.
Definition: Dominance.h:121
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
Definition: Dominance.h:134
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
void set(IRValueT newValue)
Set the current value being used by this operand.
Definition: UseDefLists.h:163
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
This class represents an operand of an operation.
Definition: Value.h:263
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:776
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:606
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:615
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult skip()
Definition: Visitors.h:53
static WalkResult advance()
Definition: Visitors.h:52
State for analysis-enabled bufferization.
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26