MLIR  19.0.0git
EmptyTensorElimination.cpp
Go to the documentation of this file.
1 //===- EmptyTensorElimination.cpp - tensor.empty op elimination -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
17 #include "mlir/IR/Dominance.h"
19 #include "mlir/Pass/Pass.h"
20 
21 namespace mlir {
22 namespace bufferization {
23 #define GEN_PASS_DEF_EMPTYTENSORELIMINATION
24 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
25 } // namespace bufferization
26 } // namespace mlir
27 
28 using namespace mlir;
29 using namespace mlir::bufferization;
30 
31 /// Return true if all `neededValues` are in scope at the given
32 /// `insertionPoint`.
33 static bool
35  Operation *insertionPoint,
36  const SmallVector<Value> &neededValues) {
37  for (Value val : neededValues) {
38  if (auto bbArg = dyn_cast<BlockArgument>(val)) {
39  Block *owner = bbArg.getOwner();
40  if (!owner->findAncestorOpInBlock(*insertionPoint))
41  return false;
42  } else {
43  auto opResult = cast<OpResult>(val);
44  if (!domInfo.properlyDominates(opResult.getOwner(), insertionPoint))
45  return false;
46  }
47  }
48  return true;
49 }
50 
51 /// Return true if the given `insertionPoint` dominates all uses of
52 /// `emptyTensorOp`.
53 static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
54  Operation *insertionPoint,
55  Operation *emptyTensorOp) {
56  return llvm::all_of(emptyTensorOp->getUsers(), [&](Operation *user) {
57  return domInfo.dominates(insertionPoint, user);
58  });
59 }
60 
61 /// Find a valid insertion point for a replacement of `emptyTensorOp`, assuming
62 /// that the replacement may use any value from `neededValues`.
63 static Operation *
65  const SmallVector<Value> &neededValues) {
66  DominanceInfo domInfo;
67 
68  // Gather all possible insertion points: the location of `emptyTensorOp` and
69  // right after the definition of each value in `neededValues`.
70  SmallVector<Operation *> insertionPointCandidates;
71  insertionPointCandidates.push_back(emptyTensorOp);
72  for (Value val : neededValues) {
73  // Note: The anchor op is using all of `neededValues`, so:
74  // * in case of a block argument: There must be at least one op in the block
75  // (the anchor op or one of its parents).
76  // * in case of an OpResult: There must be at least one op right after the
77  // defining op (the anchor op or one of its
78  // parents).
79  if (auto bbArg = dyn_cast<BlockArgument>(val)) {
80  insertionPointCandidates.push_back(
81  &bbArg.getOwner()->getOperations().front());
82  } else {
83  insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
84  }
85  }
86 
87  // Select first matching insertion point.
88  for (Operation *insertionPoint : insertionPointCandidates) {
89  // Check if all needed values are in scope.
90  if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
91  neededValues))
92  continue;
93  // Check if the insertion point is before all uses.
94  if (!insertionPointDominatesUses(domInfo, insertionPoint, emptyTensorOp))
95  continue;
96  return insertionPoint;
97  }
98 
99  // No suitable insertion point was found.
100  return nullptr;
101 }
102 
104  RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
105  OpBuilder::InsertionGuard g(rewriter);
106 
107  op->walk([&](SubsetInsertionOpInterface op) {
108  OpOperand &source = op.getSourceOperand();
109  // Skip operands that do not bufferize inplace. "tensor.empty" could still
110  // be replaced, but the transformation may not be beneficial.
111  if (!state.isInPlace(source))
112  return WalkResult::skip();
113 
114  // All values that are needed to create the replacement op.
115  SmallVector<Value> neededValues =
116  op.getValuesNeededToBuildSubsetExtraction();
117 
118  // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
119  // equivalent tensors. I.e., stop when there are ops such as extract_slice
120  // on the path.
121  TraversalConfig config;
122  config.followEquivalentOnly = true;
123  config.alwaysIncludeLeaves = false;
124  // Replace only if the types match or are static <-> dynamic casts. We do
125  // not support slices or reshapes.
126  // TODO: This could be extended to support IR such as:
127  // %0 = tensor.empty() : tensor<128xf32>
128  // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
129  // %2 = tensor.expand_shape %1 ...
130  // %3 = tensor.insert_slice %2 into ...
131  config.followSameTypeOrCastsOnly = true;
132  SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
133  source.get(), /*condition=*/
134  [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
135  config);
136 
137  for (Value v : emptyTensors) {
138  Operation *emptyTensorOp = v.getDefiningOp();
139 
140  // Find a suitable insertion point. If no suitable insertion point for
141  // the replacement can be found, skip this replacement.
142  Operation *insertionPoint =
143  findValidInsertionPoint(emptyTensorOp, neededValues);
144  if (!insertionPoint)
145  continue;
146 
147  rewriter.setInsertionPoint(insertionPoint);
148  Value replacement =
149  op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
150  if (!replacement)
151  continue;
152  if (emptyTensorOp == replacement.getDefiningOp())
153  continue;
154  if (replacement.getType() != v.getType()) {
155  rewriter.setInsertionPointAfterValue(replacement);
156  replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
157  replacement);
158  }
159  // Replace the tensor::EmptyOp.
160  rewriter.replaceOp(emptyTensorOp, replacement);
161  state.resetCache();
162  }
163 
164  return WalkResult::advance();
165  });
166 
167  return success();
168 }
169 
170 namespace {
171 struct EmptyTensorElimination
172  : public bufferization::impl::EmptyTensorEliminationBase<
173  EmptyTensorElimination> {
174  EmptyTensorElimination() = default;
175 
176  void runOnOperation() override;
177 
178  void getDependentDialects(DialectRegistry &registry) const override {
179  registry
180  .insert<bufferization::BufferizationDialect, tensor::TensorDialect>();
181  }
182 };
183 } // namespace
184 
186  Operation *op) {
187  auto moduleOp = dyn_cast<ModuleOp>(op);
189  options.allowReturnAllocsFromLoops = true;
190  if (moduleOp)
191  options.bufferizeFunctionBoundaries = true;
192  OneShotAnalysisState state(op, options);
193  if (moduleOp) {
194  // Module analysis takes into account function boundaries.
195  if (failed(analyzeModuleOp(moduleOp, state)))
196  return failure();
197  } else {
198  // Regular One-Shot Bufferize ignores func.func block arguments, func.call,
199  // func.return.
200  if (failed(analyzeOp(op, state)))
201  return failure();
202  }
203 
204  return bufferization::eliminateEmptyTensors(rewriter, op, state);
205 }
206 
207 void EmptyTensorElimination::runOnOperation() {
208  IRRewriter rewriter(getOperation()->getContext());
209  if (failed(bufferization::eliminateEmptyTensors(rewriter, getOperation())))
210  signalPassFailure();
211 }
212 
214  return std::make_unique<EmptyTensorElimination>();
215 }
static bool insertionPointDominatesUses(const DominanceInfo &domInfo, Operation *insertionPoint, Operation *emptyTensorOp)
Return true if the given insertionPoint dominates all uses of emptyTensorOp.
static bool neededValuesDominateInsertionPoint(const DominanceInfo &domInfo, Operation *insertionPoint, const SmallVector< Value > &neededValues)
Return true if all neededValues are in scope at the given insertionPoint.
static Operation * findValidInsertionPoint(Operation *emptyTensorOp, const SmallVector< Value > &neededValues)
Find a valid insertion point for a replacement of emptyTensorOp, assuming that the replacement may use any value from neededValues.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Block represents an ordered list of Operations.
Definition: Block.h:30
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies in this block. Returns null if the latter fails.
Definition: Block.cpp:73
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
A class for computing basic dominance information.
Definition: Dominance.h:136
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
Definition: Dominance.h:149
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:423
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult skip()
Definition: Visitors.h:53
static WalkResult advance()
Definition: Visitors.h:52
State for analysis-enabled bufferization.
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
LogicalResult analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze moduleOp and its nested ops.
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op)
Try to eliminate "tensor.empty" ops inside op.
std::unique_ptr< Pass > createEmptyTensorEliminationPass()
Create a pass that tries to eliminate tensor.empty ops that are anchored on insert_slice ops.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Options for analysis-enabled bufferization.
Traversal parameters for findValueInReverseUseDefChain.
bool followEquivalentOnly
Specifies whether non-equivalent OpOperands should be followed.