MLIR 22.0.0git
EmptyTensorElimination.cpp
Go to the documentation of this file.
1//===- EmptyTensorElimination.cpp - tensor.empty op elimination -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
17#include "mlir/IR/Dominance.h"
20
21namespace mlir {
22namespace bufferization {
23#define GEN_PASS_DEF_EMPTYTENSORELIMINATIONPASS
24#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
25} // namespace bufferization
26} // namespace mlir
27
28using namespace mlir;
29using namespace mlir::bufferization;
30
31/// Return true if all `neededValues` are in scope at the given
32/// `insertionPoint`.
33static bool
35 Operation *insertionPoint,
36 const SmallVector<Value> &neededValues) {
37 for (Value val : neededValues) {
38 if (auto bbArg = dyn_cast<BlockArgument>(val)) {
39 Block *owner = bbArg.getOwner();
40 if (!owner->findAncestorOpInBlock(*insertionPoint))
41 return false;
42 } else {
43 auto opResult = cast<OpResult>(val);
44 if (!domInfo.properlyDominates(opResult.getOwner(), insertionPoint))
45 return false;
46 }
47 }
48 return true;
49}
50
51/// Find a valid insertion point for a replacement of `emptyTensorOp`'s
52/// use of `user` operation, assuming that the replacement may use any
53/// value from `neededValues`.
54static Operation *
56 const SmallVector<Value> &neededValues) {
57 DominanceInfo domInfo;
58 Operation *candidateInsertionPoint = emptyTensorOp;
59
60 // Gather all possible insertion points: the location of
61 // `candidateInsertionPoint` and right after the definition of each value in
62 // `neededValues`.
63 SmallVector<Operation *> insertionPointCandidates;
64 insertionPointCandidates.push_back(candidateInsertionPoint);
65 for (Value val : neededValues) {
66 // Note: The anchor op is using all of `neededValues`, so:
67 // * in case of a block argument: There must be at least one op in the block
68 // (the anchor op or one of its parents).
69 // * in case of an OpResult: There must be at least one op right after the
70 // defining op (the anchor op or one of its
71 // parents).
72 if (auto bbArg = dyn_cast<BlockArgument>(val)) {
73 insertionPointCandidates.push_back(
74 &bbArg.getOwner()->getOperations().front());
75 } else {
76 insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
77 }
78 }
79
80 // Select first matching insertion point.
81 for (Operation *insertionPoint : insertionPointCandidates) {
82 // Check if all needed values are in scope.
83 if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
84 neededValues))
85 continue;
86 // Check if the insertion point is before the use to be replaced.
87 if (!domInfo.dominates(insertionPoint, user))
88 continue;
89 return insertionPoint;
90 }
91
92 // No suitable insertion point was found.
93 return nullptr;
94}
95
97 SubsetInsertionOpInterface op,
98 tensor::EmptyOp emptyTensorOp,
99 Operation *user) {
100
101 mlir::OpBuilder::InsertionGuard guard(rewriter);
102 // All values that are needed to create the replacement op.
103 SmallVector<Value> neededValues = op.getValuesNeededToBuildSubsetExtraction();
104 // Find a suitable insertion point. If no suitable insertion point
105 // for the replacement can be found, return an empty value to skip
106 // this replacement.
107 Operation *insertionPoint =
108 findValidInsertionPoint(emptyTensorOp, user, neededValues);
109 if (!insertionPoint) {
110 // If no already suitable insertion point was found, attempt to move all
111 // needed values before the user.
112 if (failed(moveValueDefinitions(rewriter, neededValues, user)))
113 return {};
114 insertionPoint = user;
115 }
116
117 rewriter.setInsertionPoint(insertionPoint);
119 op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
120 return replacement;
121}
122
124 RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
125 ControlBuildSubsetExtractionFn subsetsExtractionFn) {
126 OpBuilder::InsertionGuard g(rewriter);
127 llvm::DenseSet<OpOperand *> visitedOpOperands;
128 op->walk([&](SubsetInsertionOpInterface op) {
129 visitedOpOperands.clear();
130 OpOperand &source = op.getSourceOperand();
131 // Skip operands that do not bufferize inplace. "tensor.empty" could still
132 // be replaced, but the transformation may not be beneficial.
133 if (!state.isInPlace(source))
134 return WalkResult::skip();
135
136 // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
137 // equivalent tensors. I.e., stop when there are ops such as extract_slice
138 // on the path.
139 TraversalConfig config;
140 config.followEquivalentOnly = true;
141 config.alwaysIncludeLeaves = false;
142 // Replace only if the types match or are static <-> dynamic casts. We do
143 // not support slices or reshapes.
144 // TODO: This could be extended to support IR such as:
145 // %0 = tensor.empty() : tensor<128xf32>
146 // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
147 // %2 = tensor.expand_shape %1 ...
148 // %3 = tensor.insert_slice %2 into ...
149 config.followSameTypeOrCastsOnly = true;
150 SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
151 &source, /*condition=*/
152 [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, config,
153 &visitedOpOperands);
154
155 for (Value v : emptyTensors) {
156 auto emptyTensorOp = v.getDefiningOp<tensor::EmptyOp>();
157 assert(emptyTensorOp && "expected tensor.empty op");
158 // Find the use to be replaced from the use-def chain.
159 auto iter = llvm::find_if(
160 visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) {
161 return llvm::count(emptyTensorOp->getUses(), *opOperand);
162 });
163
164 assert(iter != visitedOpOperands.end() && "could not find use");
165 OpOperand *useToBeReplaced = *iter;
166 Operation *user = useToBeReplaced->getOwner();
167 auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user);
168 if (!replacement)
169 continue;
170 if (emptyTensorOp == replacement.getDefiningOp())
171 continue;
172 if (replacement.getType() != v.getType()) {
173 if (cast<ShapedType>(replacement.getType()).getElementType() !=
174 cast<ShapedType>(v.getType()).getElementType())
175 continue;
177 replacement = tensor::CastOp::create(rewriter, v.getLoc(), v.getType(),
179 }
180 // Replace the specific use of the tensor::EmptyOp.
181 rewriter.modifyOpInPlace(user, [&]() {
182 user->setOperand(useToBeReplaced->getOperandNumber(), replacement);
183 });
184 state.resetCache();
185 }
186
187 return WalkResult::advance();
188 });
189
190 return success();
191}
192
193namespace {
194struct EmptyTensorElimination
196 EmptyTensorElimination> {
197 using Base::Base;
198
199 void runOnOperation() override;
200
201 void getDependentDialects(DialectRegistry &registry) const override {
202 registry
203 .insert<bufferization::BufferizationDialect, tensor::TensorDialect>();
204 }
205};
206} // namespace
207
209 Operation *op) {
210 auto moduleOp = dyn_cast<ModuleOp>(op);
212 options.allowReturnAllocsFromLoops = true;
213 if (moduleOp)
214 options.bufferizeFunctionBoundaries = true;
215 OneShotAnalysisState state(op, options);
216 if (moduleOp) {
217 // Module analysis takes into account function boundaries.
218 if (failed(analyzeModuleOp(moduleOp, state)))
219 return failure();
220 } else {
221 // Regular One-Shot Bufferize ignores func.func block arguments, func.call,
222 // func.return.
223 if (failed(analyzeOp(op, state)))
224 return failure();
225 }
226
227 return bufferization::eliminateEmptyTensors(rewriter, op, state);
228}
229
230void EmptyTensorElimination::runOnOperation() {
231 IRRewriter rewriter(getOperation()->getContext());
232 if (failed(bufferization::eliminateEmptyTensors(rewriter, getOperation())))
233 signalPassFailure();
234}
return success()
static Operation * findValidInsertionPoint(Operation *emptyTensorOp, Operation *user, const SmallVector< Value > &neededValues)
Find a valid insertion point for a replacement of emptyTensorOp's use of user operation,...
static bool neededValuesDominateInsertionPoint(const DominanceInfo &domInfo, Operation *insertionPoint, const SmallVector< Value > &neededValues)
Return true if all neededValues are in scope at the given insertionPoint.
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static llvm::ManagedStatic< PassManagerOptions > options
Block represents an ordered list of Operations.
Definition Block.h:33
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
Definition Block.cpp:74
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:91
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
A class for computing basic dominance information.
Definition Dominance.h:140
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition Dominance.h:158
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition Builders.h:421
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void setOperand(unsigned idx, Value value)
Definition Operation.h:351
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
State for analysis-enabled bufferization.
bool isInPlace(OpOperand &opOperand) const override
Return true if the given OpResult has been decided to bufferize inplace.
void resetCache() override
Reset cached data structures.
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
std::function< Value(RewriterBase &, SubsetInsertionOpInterface, tensor::EmptyOp emptyTensorOp, Operation *user)> ControlBuildSubsetExtractionFn
A function type that defines a callback to control the construction of the subset extraction of the S...
Definition Transforms.h:47
Value buildSubsetExtraction(RewriterBase &rewriter, SubsetInsertionOpInterface op, tensor::EmptyOp emptyTensorOp, Operation *user)
This method builds and returns a subset extraction value for the destination tensor that the given op...
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op)
Try to eliminate "tensor.empty" ops inside op.
llvm::LogicalResult analyzeModuleOp(Operation *moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze moduleOp and its nested ops.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values, Operation *insertionPoint, DominanceInfo &dominance)
Move definitions of values before an insertion point.
Options for analysis-enabled bufferization.