22 namespace bufferization {
23 #define GEN_PASS_DEF_EMPTYTENSORELIMINATION
24 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
37 for (
Value val : neededValues) {
38 if (
auto bbArg = dyn_cast<BlockArgument>(val)) {
39 Block *owner = bbArg.getOwner();
43 auto opResult = cast<OpResult>(val);
58 Operation *candidateInsertionPoint = emptyTensorOp;
64 insertionPointCandidates.push_back(candidateInsertionPoint);
65 for (
Value val : neededValues) {
72 if (
auto bbArg = dyn_cast<BlockArgument>(val)) {
73 insertionPointCandidates.push_back(
74 &bbArg.getOwner()->getOperations().front());
76 insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
81 for (
Operation *insertionPoint : insertionPointCandidates) {
87 if (!domInfo.
dominates(insertionPoint, user))
89 return insertionPoint;
97 SubsetInsertionOpInterface op,
98 tensor::EmptyOp emptyTensorOp,
114 op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
123 op->
walk([&](SubsetInsertionOpInterface op) {
124 visitedOpOperands.clear();
125 OpOperand &source = op.getSourceOperand();
128 if (!state.isInPlace(source))
135 config.followEquivalentOnly =
true;
136 config.alwaysIncludeLeaves =
false;
144 config.followSameTypeOrCastsOnly =
true;
147 [&](
Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
config,
150 for (
Value v : emptyTensors) {
151 auto emptyTensorOp = v.getDefiningOp<tensor::EmptyOp>();
152 assert(emptyTensorOp &&
"expected tensor.empty op");
154 auto iter = llvm::find_if(
155 visitedOpOperands, [&emptyTensorOp](
OpOperand *opOperand) {
156 return llvm::count(emptyTensorOp->getUses(), *opOperand);
159 assert(iter != visitedOpOperands.end() &&
"could not find use");
162 auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user);
165 if (emptyTensorOp == replacement.getDefiningOp())
167 if (replacement.getType() != v.getType()) {
168 if (cast<ShapedType>(replacement.getType()).getElementType() !=
169 cast<ShapedType>(v.getType()).getElementType())
172 replacement = rewriter.
create<tensor::CastOp>(v.getLoc(), v.getType(),
189 struct EmptyTensorElimination
190 :
public bufferization::impl::EmptyTensorEliminationBase<
191 EmptyTensorElimination> {
192 EmptyTensorElimination() =
default;
194 void runOnOperation()
override;
198 .
insert<bufferization::BufferizationDialect, tensor::TensorDialect>();
205 auto moduleOp = dyn_cast<ModuleOp>(op);
207 options.allowReturnAllocsFromLoops =
true;
209 options.bufferizeFunctionBoundaries =
true;
225 void EmptyTensorElimination::runOnOperation() {
232 return std::make_unique<EmptyTensorElimination>();
static Operation * findValidInsertionPoint(Operation *emptyTensorOp, Operation *user, const SmallVector< Value > &neededValues)
Find a valid insertion point for a replacement of emptyTensorOp's use of user operation, assuming that the replacement may use any value from neededValues.
static bool neededValuesDominateInsertionPoint(const DominanceInfo &domInfo, Operation *insertionPoint, const SmallVector< Value > &neededValues)
Return true if all neededValues are in scope at the given insertionPoint.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Block represents an ordered list of Operations.
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies in this block.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
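Return true if operation A properly dominates operation B, i.e. if A and B are in the same block and A properly dominates B within the block, or if the block that contains A properly dominates the block that contains B.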
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e. if A and B are the same operation or A properly dominates B.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Operation is the basic unit of execution within MLIR.
void setOperand(unsigned idx, Value value)
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
static WalkResult advance()
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
std::function< Value(RewriterBase &, SubsetInsertionOpInterface, tensor::EmptyOp emptyTensorOp, Operation *user)> ControlBuildSubsetExtractionFn
A function type that defines a callback to control the construction of the subset extraction of the SubsetInsertionOpInterface.
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
Value buildSubsetExtraction(RewriterBase &rewriter, SubsetInsertionOpInterface op, tensor::EmptyOp emptyTensorOp, Operation *user)
This method builds and returns a subset extraction value for the destination tensor that the given op inserts into.
llvm::LogicalResult analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze moduleOp and its nested ops.
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op)
Try to eliminate "tensor.empty" ops inside op.
std::unique_ptr< Pass > createEmptyTensorEliminationPass()
Create a pass that tries to eliminate tensor.empty ops that are anchored on insert_slice ops.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
Options for analysis-enabled bufferization.
Traversal parameters for findValueInReverseUseDefChain.