21 namespace bufferization {
22 #define GEN_PASS_DEF_EMPTYTENSORELIMINATIONPASS
23 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
36 for (
Value val : neededValues) {
37 if (
auto bbArg = dyn_cast<BlockArgument>(val)) {
38 Block *owner = bbArg.getOwner();
42 auto opResult = cast<OpResult>(val);
57 Operation *candidateInsertionPoint = emptyTensorOp;
63 insertionPointCandidates.push_back(candidateInsertionPoint);
64 for (
Value val : neededValues) {
71 if (
auto bbArg = dyn_cast<BlockArgument>(val)) {
72 insertionPointCandidates.push_back(
73 &bbArg.getOwner()->getOperations().front());
75 insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
80 for (
Operation *insertionPoint : insertionPointCandidates) {
86 if (!domInfo.
dominates(insertionPoint, user))
88 return insertionPoint;
96 SubsetInsertionOpInterface op,
97 tensor::EmptyOp emptyTensorOp,
113 op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
122 op->
walk([&](SubsetInsertionOpInterface op) {
123 visitedOpOperands.clear();
124 OpOperand &source = op.getSourceOperand();
127 if (!state.isInPlace(source))
134 config.followEquivalentOnly =
true;
135 config.alwaysIncludeLeaves =
false;
143 config.followSameTypeOrCastsOnly =
true;
146 [&](
Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
config,
149 for (
Value v : emptyTensors) {
150 auto emptyTensorOp = v.getDefiningOp<tensor::EmptyOp>();
151 assert(emptyTensorOp &&
"expected tensor.empty op");
153 auto iter = llvm::find_if(
154 visitedOpOperands, [&emptyTensorOp](
OpOperand *opOperand) {
155 return llvm::count(emptyTensorOp->getUses(), *opOperand);
158 assert(iter != visitedOpOperands.end() &&
"could not find use");
161 auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user);
164 if (emptyTensorOp == replacement.getDefiningOp())
166 if (replacement.getType() != v.getType()) {
167 if (cast<ShapedType>(replacement.getType()).getElementType() !=
168 cast<ShapedType>(v.getType()).getElementType())
171 replacement = rewriter.
create<tensor::CastOp>(v.getLoc(), v.getType(),
188 struct EmptyTensorElimination
189 :
public bufferization::impl::EmptyTensorEliminationPassBase<
190 EmptyTensorElimination> {
193 void runOnOperation()
override;
197 .
insert<bufferization::BufferizationDialect, tensor::TensorDialect>();
204 auto moduleOp = dyn_cast<ModuleOp>(op);
206 options.allowReturnAllocsFromLoops =
true;
208 options.bufferizeFunctionBoundaries =
true;
224 void EmptyTensorElimination::runOnOperation() {
static Operation * findValidInsertionPoint(Operation *emptyTensorOp, Operation *user, const SmallVector< Value > &neededValues)
Find a valid insertion point for a replacement of emptyTensorOp's use of user operation,...
static bool neededValuesDominateInsertionPoint(const DominanceInfo &domInfo, Operation *insertionPoint, const SmallVector< Value > &neededValues)
Return true if all neededValues are in scope at the given insertionPoint.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Block represents an ordered list of Operations.
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Operation is the basic unit of execution within MLIR.
void setOperand(unsigned idx, Value value)
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static WalkResult advance()
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
std::function< Value(RewriterBase &, SubsetInsertionOpInterface, tensor::EmptyOp emptyTensorOp, Operation *user)> ControlBuildSubsetExtractionFn
A function type that defines a callback to control the construction of the subset extraction of the S...
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
Value buildSubsetExtraction(RewriterBase &rewriter, SubsetInsertionOpInterface op, tensor::EmptyOp emptyTensorOp, Operation *user)
This method builds and returns a subset extraction value for the destination tensor that the given op...
llvm::LogicalResult analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze moduleOp and its nested ops.
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op)
Try to eliminate "tensor.empty" ops inside op.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
Options for analysis-enabled bufferization.
Traversal parameters for findValueInReverseUseDefChain.