MLIR  21.0.0git
EmptyTensorElimination.cpp
Go to the documentation of this file.
1 //===- EmptyTensorElimination.cpp - tensor.empty op elimination -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
17 #include "mlir/IR/Dominance.h"
19 
20 namespace mlir {
21 namespace bufferization {
22 #define GEN_PASS_DEF_EMPTYTENSORELIMINATIONPASS
23 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
24 } // namespace bufferization
25 } // namespace mlir
26 
27 using namespace mlir;
28 using namespace mlir::bufferization;
29 
30 /// Return true if all `neededValues` are in scope at the given
31 /// `insertionPoint`.
32 static bool
34  Operation *insertionPoint,
35  const SmallVector<Value> &neededValues) {
36  for (Value val : neededValues) {
37  if (auto bbArg = dyn_cast<BlockArgument>(val)) {
38  Block *owner = bbArg.getOwner();
39  if (!owner->findAncestorOpInBlock(*insertionPoint))
40  return false;
41  } else {
42  auto opResult = cast<OpResult>(val);
43  if (!domInfo.properlyDominates(opResult.getOwner(), insertionPoint))
44  return false;
45  }
46  }
47  return true;
48 }
49 
50 /// Find a valid insertion point for a replacement of `emptyTensorOp`'s
51 /// use of `user` operation, assuming that the replacement may use any
52 /// value from `neededValues`.
53 static Operation *
55  const SmallVector<Value> &neededValues) {
56  DominanceInfo domInfo;
57  Operation *candidateInsertionPoint = emptyTensorOp;
58 
59  // Gather all possible insertion points: the location of
60  // `candidateInsertionPoint` and right after the definition of each value in
61  // `neededValues`.
62  SmallVector<Operation *> insertionPointCandidates;
63  insertionPointCandidates.push_back(candidateInsertionPoint);
64  for (Value val : neededValues) {
65  // Note: The anchor op is using all of `neededValues`, so:
66  // * in case of a block argument: There must be at least one op in the block
67  // (the anchor op or one of its parents).
68  // * in case of an OpResult: There must be at least one op right after the
69  // defining op (the anchor op or one of its
70  // parents).
71  if (auto bbArg = dyn_cast<BlockArgument>(val)) {
72  insertionPointCandidates.push_back(
73  &bbArg.getOwner()->getOperations().front());
74  } else {
75  insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
76  }
77  }
78 
79  // Select first matching insertion point.
80  for (Operation *insertionPoint : insertionPointCandidates) {
81  // Check if all needed values are in scope.
82  if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
83  neededValues))
84  continue;
85  // Check if the insertion point is before the use to be replaced.
86  if (!domInfo.dominates(insertionPoint, user))
87  continue;
88  return insertionPoint;
89  }
90 
91  // No suitable insertion point was found.
92  return nullptr;
93 }
94 
96  SubsetInsertionOpInterface op,
97  tensor::EmptyOp emptyTensorOp,
98  Operation *user) {
99 
100  mlir::OpBuilder::InsertionGuard guard(rewriter);
101  // All values that are needed to create the replacement op.
102  SmallVector<Value> neededValues = op.getValuesNeededToBuildSubsetExtraction();
103  // Find a suitable insertion point. If no suitable insertion point
104  // for the replacement can be found, return an empty value to skip
105  // this replacement.
106  Operation *insertionPoint =
107  findValidInsertionPoint(emptyTensorOp, user, neededValues);
108  if (!insertionPoint)
109  return {};
110 
111  rewriter.setInsertionPoint(insertionPoint);
112  Value replacement =
113  op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
114  return replacement;
115 }
116 
118  RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
119  ControlBuildSubsetExtractionFn subsetsExtractionFn) {
120  OpBuilder::InsertionGuard g(rewriter);
121  llvm::DenseSet<OpOperand *> visitedOpOperands;
122  op->walk([&](SubsetInsertionOpInterface op) {
123  visitedOpOperands.clear();
124  OpOperand &source = op.getSourceOperand();
125  // Skip operands that do not bufferize inplace. "tensor.empty" could still
126  // be replaced, but the transformation may not be beneficial.
127  if (!state.isInPlace(source))
128  return WalkResult::skip();
129 
130  // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
131  // equivalent tensors. I.e., stop when there are ops such as extract_slice
132  // on the path.
134  config.followEquivalentOnly = true;
135  config.alwaysIncludeLeaves = false;
136  // Replace only if the types match or are static <-> dynamic casts. We do
137  // not support slices or reshapes.
138  // TODO: This could be extended to support IR such as:
139  // %0 = tensor.empty() : tensor<128xf32>
140  // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
141  // %2 = tensor.expand_shape %1 ...
142  // %3 = tensor.insert_slice %2 into ...
143  config.followSameTypeOrCastsOnly = true;
144  SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
145  &source, /*condition=*/
146  [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, config,
147  &visitedOpOperands);
148 
149  for (Value v : emptyTensors) {
150  auto emptyTensorOp = v.getDefiningOp<tensor::EmptyOp>();
151  assert(emptyTensorOp && "expected tensor.empty op");
152  // Find the use to be replaced from the use-def chain.
153  auto iter = llvm::find_if(
154  visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) {
155  return llvm::count(emptyTensorOp->getUses(), *opOperand);
156  });
157 
158  assert(iter != visitedOpOperands.end() && "could not find use");
159  OpOperand *useToBeReplaced = *iter;
160  Operation *user = useToBeReplaced->getOwner();
161  auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user);
162  if (!replacement)
163  continue;
164  if (emptyTensorOp == replacement.getDefiningOp())
165  continue;
166  if (replacement.getType() != v.getType()) {
167  if (cast<ShapedType>(replacement.getType()).getElementType() !=
168  cast<ShapedType>(v.getType()).getElementType())
169  continue;
170  rewriter.setInsertionPointAfterValue(replacement);
171  replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
172  replacement);
173  }
174  // Replace the specific use of the tensor::EmptyOp.
175  rewriter.modifyOpInPlace(user, [&]() {
176  user->setOperand(useToBeReplaced->getOperandNumber(), replacement);
177  });
178  state.resetCache();
179  }
180 
181  return WalkResult::advance();
182  });
183 
184  return success();
185 }
186 
187 namespace {
188 struct EmptyTensorElimination
189  : public bufferization::impl::EmptyTensorEliminationPassBase<
190  EmptyTensorElimination> {
191  using Base::Base;
192 
193  void runOnOperation() override;
194 
195  void getDependentDialects(DialectRegistry &registry) const override {
196  registry
197  .insert<bufferization::BufferizationDialect, tensor::TensorDialect>();
198  }
199 };
200 } // namespace
201 
203  Operation *op) {
204  auto moduleOp = dyn_cast<ModuleOp>(op);
206  options.allowReturnAllocsFromLoops = true;
207  if (moduleOp)
208  options.bufferizeFunctionBoundaries = true;
209  OneShotAnalysisState state(op, options);
210  if (moduleOp) {
211  // Module analysis takes into account function boundaries.
212  if (failed(analyzeModuleOp(moduleOp, state)))
213  return failure();
214  } else {
215  // Regular One-Shot Bufferize ignores func.func block arguments, func.call,
216  // func.return.
217  if (failed(analyzeOp(op, state)))
218  return failure();
219  }
220 
221  return bufferization::eliminateEmptyTensors(rewriter, op, state);
222 }
223 
224 void EmptyTensorElimination::runOnOperation() {
225  IRRewriter rewriter(getOperation()->getContext());
226  if (failed(bufferization::eliminateEmptyTensors(rewriter, getOperation())))
227  signalPassFailure();
228 }
static Operation * findValidInsertionPoint(Operation *emptyTensorOp, Operation *user, const SmallVector< Value > &neededValues)
Find a valid insertion point for a replacement of emptyTensorOp's use of user operation, assuming that the replacement may use any value from neededValues.
static bool neededValuesDominateInsertionPoint(const DominanceInfo &domInfo, Operation *insertionPoint, const SmallVector< Value > &neededValues)
Return true if all neededValues are in scope at the given insertionPoint.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies in this block. Returns nullptr if the latter fails.
Definition: Block.cpp:74
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
A class for computing basic dominance information.
Definition: Dominance.h:140
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
Definition: Dominance.cpp:323
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition: Dominance.h:158
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:729
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:419
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:228
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setOperand(unsigned idx, Value value)
Definition: Operation.h:351
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:593
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
static WalkResult skip()
Definition: WalkResult.h:48
static WalkResult advance()
Definition: WalkResult.h:47
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
std::function< Value(RewriterBase &, SubsetInsertionOpInterface, tensor::EmptyOp emptyTensorOp, Operation *user)> ControlBuildSubsetExtractionFn
A function type that defines a callback to control the construction of the subset extraction of the S...
Definition: Transforms.h:49
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
Value buildSubsetExtraction(RewriterBase &rewriter, SubsetInsertionOpInterface op, tensor::EmptyOp emptyTensorOp, Operation *user)
This method builds and returns a subset extraction value for the destination tensor that the given op...
llvm::LogicalResult analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze moduleOp and its nested ops.
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op)
Try to eliminate "tensor.empty" ops inside op.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
Options for analysis-enabled bufferization.
Traversal parameters for findValueInReverseUseDefChain.