23#define GEN_PASS_DEF_EMPTYTENSORELIMINATIONPASS
24#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
37 for (
Value val : neededValues) {
38 if (
auto bbArg = dyn_cast<BlockArgument>(val)) {
39 Block *owner = bbArg.getOwner();
43 auto opResult = cast<OpResult>(val);
58 Operation *candidateInsertionPoint = emptyTensorOp;
64 insertionPointCandidates.push_back(candidateInsertionPoint);
65 for (
Value val : neededValues) {
72 if (
auto bbArg = dyn_cast<BlockArgument>(val)) {
73 insertionPointCandidates.push_back(
74 &bbArg.getOwner()->getOperations().front());
76 insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
81 for (
Operation *insertionPoint : insertionPointCandidates) {
87 if (!domInfo.
dominates(insertionPoint, user))
89 return insertionPoint;
97 SubsetInsertionOpInterface op,
98 tensor::EmptyOp emptyTensorOp,
109 if (!insertionPoint) {
114 insertionPoint = user;
119 op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
128 op->
walk([&](SubsetInsertionOpInterface op) {
129 visitedOpOperands.clear();
130 OpOperand &source = op.getSourceOperand();
140 config.followEquivalentOnly =
true;
141 config.alwaysIncludeLeaves =
false;
149 config.followSameTypeOrCastsOnly =
true;
155 for (
Value v : emptyTensors) {
156 auto emptyTensorOp = v.getDefiningOp<tensor::EmptyOp>();
157 assert(emptyTensorOp &&
"expected tensor.empty op");
159 auto iter = llvm::find_if(
160 visitedOpOperands, [&emptyTensorOp](
OpOperand *opOperand) {
161 return llvm::count(emptyTensorOp->getUses(), *opOperand);
164 assert(iter != visitedOpOperands.end() &&
"could not find use");
167 auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user);
173 if (cast<ShapedType>(
replacement.getType()).getElementType() !=
174 cast<ShapedType>(v.getType()).getElementType())
194struct EmptyTensorElimination
196 EmptyTensorElimination> {
199 void runOnOperation()
override;
203 .
insert<bufferization::BufferizationDialect, tensor::TensorDialect>();
210 auto moduleOp = dyn_cast<ModuleOp>(op);
212 options.allowReturnAllocsFromLoops =
true;
214 options.bufferizeFunctionBoundaries =
true;
230void EmptyTensorElimination::runOnOperation() {
static Operation * findValidInsertionPoint(Operation *emptyTensorOp, Operation *user, const SmallVector< Value > &neededValues)
Find a valid insertion point for a replacement of emptyTensorOp's use of user operation,...
static bool neededValuesDominateInsertionPoint(const DominanceInfo &domInfo, Operation *insertionPoint, const SmallVector< Value > &neededValues)
Return true if all neededValues are in scope at the given insertionPoint.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static llvm::ManagedStatic< PassManagerOptions > options
Block represents an ordered list of Operations.
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Operation is the basic unit of execution within MLIR.
void setOperand(unsigned idx, Value value)
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
State for analysis-enabled bufferization.
bool isInPlace(OpOperand &opOperand) const override
Return true if the given OpResult has been decided to bufferize inplace.
void resetCache() override
Reset cached data structures.
Operation * getOwner() const
Return the owner of this operand.
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
std::function< Value(RewriterBase &, SubsetInsertionOpInterface, tensor::EmptyOp emptyTensorOp, Operation *user)> ControlBuildSubsetExtractionFn
A function type that defines a callback to control the construction of the subset extraction of the S...
Value buildSubsetExtraction(RewriterBase &rewriter, SubsetInsertionOpInterface op, tensor::EmptyOp emptyTensorOp, Operation *user)
This method builds and returns a subset extraction value for the destination tensor that the given op...
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op)
Try to eliminate "tensor.empty" ops inside op.
llvm::LogicalResult analyzeModuleOp(Operation *moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze moduleOp and its nested ops.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
llvm::SetVector< T, Vector, Set, N > SetVector
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values, Operation *insertionPoint, DominanceInfo &dominance)
Move definitions of values before an insertion point.
Options for analysis-enabled bufferization.