MLIR  20.0.0git
Go to the documentation of this file.
1 //===- EmptyTensorElimination.cpp - tensor.empty op elimination -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
17 #include "mlir/IR/Dominance.h"
19 #include "mlir/Pass/Pass.h"
21 namespace mlir {
22 namespace bufferization {
24 #include "mlir/Dialect/Bufferization/Transforms/"
25 } // namespace bufferization
26 } // namespace mlir
28 using namespace mlir;
29 using namespace mlir::bufferization;
31 /// Return true if all `neededValues` are in scope at the given
32 /// `insertionPoint`.
33 static bool
35  Operation *insertionPoint,
36  const SmallVector<Value> &neededValues) {
37  for (Value val : neededValues) {
38  if (auto bbArg = dyn_cast<BlockArgument>(val)) {
39  Block *owner = bbArg.getOwner();
40  if (!owner->findAncestorOpInBlock(*insertionPoint))
41  return false;
42  } else {
43  auto opResult = cast<OpResult>(val);
44  if (!domInfo.properlyDominates(opResult.getOwner(), insertionPoint))
45  return false;
46  }
47  }
48  return true;
49 }
51 /// Return true if the given `insertionPoint` dominates all uses of
52 /// `emptyTensorOp`.
53 static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
54  Operation *insertionPoint,
55  Operation *emptyTensorOp) {
56  return llvm::all_of(emptyTensorOp->getUsers(), [&](Operation *user) {
57  return domInfo.dominates(insertionPoint, user);
58  });
59 }
61 /// Find a valid insertion point for a replacement of `emptyTensorOp`, assuming
62 /// that the replacement may use any value from `neededValues`.
63 static Operation *
65  const SmallVector<Value> &neededValues) {
66  DominanceInfo domInfo;
68  // Gather all possible insertion points: the location of `emptyTensorOp` and
69  // right after the definition of each value in `neededValues`.
70  SmallVector<Operation *> insertionPointCandidates;
71  insertionPointCandidates.push_back(emptyTensorOp);
72  for (Value val : neededValues) {
73  // Note: The anchor op is using all of `neededValues`, so:
74  // * in case of a block argument: There must be at least one op in the block
75  // (the anchor op or one of its parents).
76  // * in case of an OpResult: There must be at least one op right after the
77  // defining op (the anchor op or one of its
78  // parents).
79  if (auto bbArg = dyn_cast<BlockArgument>(val)) {
80  insertionPointCandidates.push_back(
81  &bbArg.getOwner()->getOperations().front());
82  } else {
83  insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
84  }
85  }
87  // Select first matching insertion point.
88  for (Operation *insertionPoint : insertionPointCandidates) {
89  // Check if all needed values are in scope.
90  if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
91  neededValues))
92  continue;
93  // Check if the insertion point is before all uses.
94  if (!insertionPointDominatesUses(domInfo, insertionPoint, emptyTensorOp))
95  continue;
96  return insertionPoint;
97  }
99  // No suitable insertion point was found.
100  return nullptr;
101 }
    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
  // Restore the rewriter's insertion point on exit; the loop below moves it.
  OpBuilder::InsertionGuard g(rewriter);

  // Visit each op nested in `op` (including `op` itself) that implements
  // SubsetInsertionOpInterface. Such an op is the "anchor": its subset
  // extraction may stand in for a tensor.empty that flows into its source.
  // NOTE(review): the lambda parameter `op` shadows the outer `op`.
  op->walk([&](SubsetInsertionOpInterface op) {
    OpOperand &source = op.getSourceOperand();
    // Skip operands that do not bufferize inplace. "tensor.empty" could still
    // be replaced, but the transformation may not be beneficial.
    if (!state.isInPlace(source))
      return WalkResult::skip();

    // All values that are needed to create the replacement op.
    SmallVector<Value> neededValues =
        op.getValuesNeededToBuildSubsetExtraction();

    // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
    // equivalent tensors. I.e., stop when there are ops such as extract_slice
    // on the path.
    TraversalConfig config;
    config.followEquivalentOnly = true;
    // Presumably: do not report chain leaves unless they satisfy the
    // condition below — TODO confirm against TraversalConfig docs.
    config.alwaysIncludeLeaves = false;
    // Replace only if the types match or are static <-> dynamic casts. We do
    // not support slices or reshapes.
    // TODO: This could be extended to support IR such as:
    // %0 = tensor.empty() : tensor<128xf32>
    // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
    // %2 = tensor.expand_shape %1 ...
    // %3 = tensor.insert_slice %2 into ...
    config.followSameTypeOrCastsOnly = true;
    SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
        source.get(), /*condition=*/
        [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
        config);

    for (Value v : emptyTensors) {
      Operation *emptyTensorOp = v.getDefiningOp();

      // Find a suitable insertion point. If no suitable insertion point for
      // the replacement can be found, skip this replacement.
      Operation *insertionPoint =
          findValidInsertionPoint(emptyTensorOp, neededValues);
      if (!insertionPoint)
        continue;

      // Build the anchor's subset extraction right before the chosen point;
      // its result is the candidate replacement for the tensor.empty.
      rewriter.setInsertionPoint(insertionPoint);
      Value replacement =
          op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
      // The interface implementation may return a null value; skip then.
      if (!replacement)
        continue;
      // Nothing to replace if the result is the tensor.empty itself.
      if (emptyTensorOp == replacement.getDefiningOp())
        continue;
      if (replacement.getType() != v.getType()) {
        // Differing element types cannot be bridged; otherwise insert a
        // tensor.cast to `v`'s type right after the replacement value.
        // NOTE(review): on this `continue`, any op created by
        // buildSubsetExtraction above is left in the IR unused.
        if (cast<ShapedType>(replacement.getType()).getElementType() !=
            cast<ShapedType>(v.getType()).getElementType())
          continue;
        rewriter.setInsertionPointAfterValue(replacement);
        replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
                                                      replacement);
      }
      // Replace the tensor::EmptyOp.
      rewriter.replaceOp(emptyTensorOp, replacement);
      // The IR changed; presumably this drops analysis results cached in
      // `state` that could now be stale.
      state.resetCache();
    }

    return WalkResult::advance();
  });

  return success();
}
173 namespace {
174 struct EmptyTensorElimination
175  : public bufferization::impl::EmptyTensorEliminationBase<
176  EmptyTensorElimination> {
177  EmptyTensorElimination() = default;
179  void runOnOperation() override;
181  void getDependentDialects(DialectRegistry &registry) const override {
182  registry
183  .insert<bufferization::BufferizationDialect, tensor::TensorDialect>();
184  }
185 };
186 } // namespace
    Operation *op) {
  // Convenience overload: run One-Shot analysis on `op` first, then eliminate
  // tensor.empty ops using the resulting analysis state.
  auto moduleOp = dyn_cast<ModuleOp>(op);
  // NOTE(review): the declaration of `options` is not visible in this listing;
  // presumably a One-Shot bufferization options object declared just above.
  options.allowReturnAllocsFromLoops = true;
  // Function boundaries are only analyzed when `op` is a module.
  if (moduleOp)
    options.bufferizeFunctionBoundaries = true;
  OneShotAnalysisState state(op, options);
  if (moduleOp) {
    // Module analysis takes into account function boundaries.
    if (failed(analyzeModuleOp(moduleOp, state)))
      return failure();
  } else {
    // Regular One-Shot Bufferize ignores func.func block arguments and
    // func.return.
    if (failed(analyzeOp(op, state)))
      return failure();
  }

  // Analysis succeeded; perform the actual elimination.
  return bufferization::eliminateEmptyTensors(rewriter, op, state);
}
210 void EmptyTensorElimination::runOnOperation() {
211  IRRewriter rewriter(getOperation()->getContext());
212  if (failed(bufferization::eliminateEmptyTensors(rewriter, getOperation())))
213  signalPassFailure();
214 }
  // Factory body: return a newly constructed EmptyTensorElimination pass,
  // owned by the caller through the returned unique_ptr.
  return std::make_unique<EmptyTensorElimination>();
}
static bool insertionPointDominatesUses(const DominanceInfo &domInfo, Operation *insertionPoint, Operation *emptyTensorOp)
Return true if the given insertionPoint dominates all uses of emptyTensorOp.
static bool neededValuesDominateInsertionPoint(const DominanceInfo &domInfo, Operation *insertionPoint, const SmallVector< Value > &neededValues)
Return true if all neededValues are in scope at the given insertionPoint.
static Operation * findValidInsertionPoint(Operation *emptyTensorOp, const SmallVector< Value > &neededValues)
Find a valid insertion point for a replacement of emptyTensorOp, assuming that the replacement may us...
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Block represents an ordered list of Operations.
Definition: Block.h:31
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
Definition: Block.cpp:73
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
A class for computing basic dominance information.
Definition: Dominance.h:136
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
Definition: Dominance.h:149
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:351
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:401
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:424
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult skip()
Definition: Visitors.h:52
static WalkResult advance()
Definition: Visitors.h:51
State for analysis-enabled bufferization.
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
llvm::LogicalResult analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze moduleOp and its nested ops.
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op)
Try to eliminate "tensor.empty" ops inside op.
std::unique_ptr< Pass > createEmptyTensorEliminationPass()
Create a pass that tries to eliminate tensor.empty ops that are anchored on insert_slice ops.
Include the generated interface declarations.
Options for analysis-enabled bufferization.
Traversal parameters for findValueInReverseUseDefChain.
bool followEquivalentOnly
Specifies whether non-equivalent OpOperands should be followed.