MLIR  18.0.0git
EmptyTensorElimination.cpp
Go to the documentation of this file.
1 //===- EmptyTensorElimination.cpp - tensor.empty op elimination -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
17 #include "mlir/IR/Dominance.h"
19 #include "mlir/Pass/Pass.h"
20 
21 namespace mlir {
22 namespace bufferization {
23 #define GEN_PASS_DEF_EMPTYTENSORELIMINATION
24 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
25 } // namespace bufferization
26 } // namespace mlir
27 
28 using namespace mlir;
29 using namespace mlir::bufferization;
30 
31 /// Return true if all `neededValues` are in scope at the given
32 /// `insertionPoint`.
33 static bool
35  Operation *insertionPoint,
36  const SmallVector<Value> &neededValues) {
37  for (Value val : neededValues) {
38  if (auto bbArg = dyn_cast<BlockArgument>(val)) {
39  Block *owner = bbArg.getOwner();
40  if (!owner->findAncestorOpInBlock(*insertionPoint))
41  return false;
42  } else {
43  auto opResult = cast<OpResult>(val);
44  if (!domInfo.properlyDominates(opResult.getOwner(), insertionPoint))
45  return false;
46  }
47  }
48  return true;
49 }
50 
51 /// Return true if the given `insertionPoint` dominates all uses of
52 /// `emptyTensorOp`.
53 static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
54  Operation *insertionPoint,
55  Operation *emptyTensorOp) {
56  for (Operation *user : emptyTensorOp->getUsers())
57  if (!domInfo.dominates(insertionPoint, user))
58  return false;
59  return true;
60 }
61 
62 /// Find a valid insertion point for a replacement of `emptyTensorOp`, assuming
63 /// that the replacement may use any value from `neededValues`.
64 static Operation *
66  const SmallVector<Value> &neededValues) {
67  DominanceInfo domInfo;
68 
69  // Gather all possible insertion points: the location of `emptyTensorOp` and
70  // right after the definition of each value in `neededValues`.
71  SmallVector<Operation *> insertionPointCandidates;
72  insertionPointCandidates.push_back(emptyTensorOp);
73  for (Value val : neededValues) {
74  // Note: The anchor op is using all of `neededValues`, so:
75  // * in case of a block argument: There must be at least one op in the block
76  // (the anchor op or one of its parents).
77  // * in case of an OpResult: There must be at least one op right after the
78  // defining op (the anchor op or one of its
79  // parents).
80  if (auto bbArg = dyn_cast<BlockArgument>(val)) {
81  insertionPointCandidates.push_back(
82  &bbArg.getOwner()->getOperations().front());
83  } else {
84  insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
85  }
86  }
87 
88  // Select first matching insertion point.
89  for (Operation *insertionPoint : insertionPointCandidates) {
90  // Check if all needed values are in scope.
91  if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
92  neededValues))
93  continue;
94  // Check if the insertion point is before all uses.
95  if (!insertionPointDominatesUses(domInfo, insertionPoint, emptyTensorOp))
96  continue;
97  return insertionPoint;
98  }
99 
100  // No suitable insertion point was found.
101  return nullptr;
102 }
103 
105  RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
106  OpBuilder::InsertionGuard g(rewriter);
107 
108  op->walk([&](SubsetInsertionOpInterface op) {
109  OpOperand &source = op.getSourceOperand();
110  // Skip operands that do not bufferize inplace. "tensor.empty" could still
111  // be replaced, but the transformation may not be beneficial.
112  if (!state.isInPlace(source))
113  return WalkResult::skip();
114 
115  // All values that are needed to create the replacement op.
116  SmallVector<Value> neededValues =
117  op.getValuesNeededToBuildSubsetExtraction();
118 
119  // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
120  // equivalent tensors. I.e., stop when there are ops such as extract_slice
121  // on the path.
122  TraversalConfig config;
123  config.followEquivalentOnly = true;
124  config.alwaysIncludeLeaves = false;
125  // Replace only if the types match or are static <-> dynamic casts. We do
126  // not support slices or reshapes.
127  // TODO: This could be extended to support IR such as:
128  // %0 = tensor.empty() : tensor<128xf32>
129  // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
130  // %2 = tensor.expand_shape %1 ...
131  // %3 = tensor.insert_slice %2 into ...
132  config.followSameTypeOrCastsOnly = true;
133  SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
134  source.get(), /*condition=*/
135  [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
136  config);
137 
138  for (Value v : emptyTensors) {
139  Operation *emptyTensorOp = v.getDefiningOp();
140 
141  // Find a suitable insertion point. If no suitable insertion point for
142  // the replacement can be found, skip this replacement.
143  Operation *insertionPoint =
144  findValidInsertionPoint(emptyTensorOp, neededValues);
145  if (!insertionPoint)
146  continue;
147 
148  rewriter.setInsertionPoint(insertionPoint);
149  Value replacement =
150  op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
151  if (!replacement)
152  continue;
153  if (emptyTensorOp == replacement.getDefiningOp())
154  continue;
155  if (replacement.getType() != v.getType()) {
156  rewriter.setInsertionPointAfterValue(replacement);
157  replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
158  replacement);
159  }
160  // Replace the tensor::EmptyOp.
161  rewriter.replaceOp(emptyTensorOp, replacement);
162  state.resetCache();
163  }
164 
165  return WalkResult::advance();
166  });
167 
168  return success();
169 }
170 
171 namespace {
172 struct EmptyTensorElimination
173  : public bufferization::impl::EmptyTensorEliminationBase<
174  EmptyTensorElimination> {
175  EmptyTensorElimination() = default;
176 
177  void runOnOperation() override;
178 
179  void getDependentDialects(DialectRegistry &registry) const override {
180  registry
181  .insert<bufferization::BufferizationDialect, tensor::TensorDialect>();
182  }
183 };
184 } // namespace
185 
187  Operation *op) {
188  auto moduleOp = dyn_cast<ModuleOp>(op);
190  options.allowReturnAllocsFromLoops = true;
191  if (moduleOp)
192  options.bufferizeFunctionBoundaries = true;
193  OneShotAnalysisState state(op, options);
194  if (moduleOp) {
195  // Module analysis takes into account function boundaries.
196  if (failed(analyzeModuleOp(moduleOp, state)))
197  return failure();
198  } else {
199  // Regular One-Shot Bufferize ignores func.func block arguments, func.call,
200  // func.return.
201  if (failed(analyzeOp(op, state)))
202  return failure();
203  }
204 
205  return bufferization::eliminateEmptyTensors(rewriter, op, state);
206 }
207 
208 void EmptyTensorElimination::runOnOperation() {
209  IRRewriter rewriter(getOperation()->getContext());
210  if (failed(bufferization::eliminateEmptyTensors(rewriter, getOperation())))
211  signalPassFailure();
212 }
213 
215  return std::make_unique<EmptyTensorElimination>();
216 }
static bool insertionPointDominatesUses(const DominanceInfo &domInfo, Operation *insertionPoint, Operation *emptyTensorOp)
Return true if the given insertionPoint dominates all uses of emptyTensorOp.
static bool neededValuesDominateInsertionPoint(const DominanceInfo &domInfo, Operation *insertionPoint, const SmallVector< Value > &neededValues)
Return true if all neededValues are in scope at the given insertionPoint.
static Operation * findValidInsertionPoint(Operation *emptyTensorOp, const SmallVector< Value > &neededValues)
Find a valid insertion point for a replacement of emptyTensorOp, assuming that the replacement may use any value from neededValues.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Block represents an ordered list of Operations.
Definition: Block.h:30
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies in this block. Returns nullptr if no such ancestor exists.
Definition: Block.cpp:68
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
A class for computing basic dominance information.
Definition: Dominance.h:121
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
Definition: Dominance.h:134
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e. if A and B are the same operation or A properly dominates B.
Definition: Dominance.h:141
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:710
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:383
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:406
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
This class represents an operand of an operation.
Definition: Value.h:263
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:776
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:852
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult skip()
Definition: Visitors.h:53
static WalkResult advance()
Definition: Visitors.h:52
State for analysis-enabled bufferization.
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
LogicalResult analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze moduleOp and its nested ops.
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op)
Try to eliminate "tensor.empty" ops inside op.
std::unique_ptr< Pass > createEmptyTensorEliminationPass()
Create a pass that tries to eliminate tensor.empty ops that are anchored on insert_slice ops.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Options for analysis-enabled bufferization.
Traversal parameters for findValueInReverseUseDefChain.
bool followEquivalentOnly
Specifies whether non-equivalent OpOperands should be followed.