22 namespace bufferization {
23 #define GEN_PASS_DEF_EMPTYTENSORELIMINATION
24 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
37 for (
Value val : neededValues) {
38 if (
auto bbArg = dyn_cast<BlockArgument>(val)) {
39 Block *owner = bbArg.getOwner();
43 auto opResult = cast<OpResult>(val);
58 Operation *candidateInsertionPoint = emptyTensorOp;
64 insertionPointCandidates.push_back(candidateInsertionPoint);
65 for (
Value val : neededValues) {
72 if (
auto bbArg = dyn_cast<BlockArgument>(val)) {
73 insertionPointCandidates.push_back(
74 &bbArg.getOwner()->getOperations().front());
76 insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
81 for (
Operation *insertionPoint : insertionPointCandidates) {
87 if (!domInfo.
dominates(insertionPoint, user))
89 return insertionPoint;
100 op->
walk([&](SubsetInsertionOpInterface op) {
101 visitedOpOperands.clear();
102 OpOperand &source = op.getSourceOperand();
105 if (!state.isInPlace(source))
110 op.getValuesNeededToBuildSubsetExtraction();
116 config.followEquivalentOnly =
true;
117 config.alwaysIncludeLeaves =
false;
125 config.followSameTypeOrCastsOnly =
true;
128 [&](
Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
config,
131 for (
Value v : emptyTensors) {
132 Operation *emptyTensorOp = v.getDefiningOp();
135 auto iter = llvm::find_if(
136 visitedOpOperands, [&emptyTensorOp](
OpOperand *opOperand) {
137 return llvm::count(emptyTensorOp->
getUses(), *opOperand);
141 if (iter == visitedOpOperands.end())
155 op.buildSubsetExtraction(rewriter, emptyTensorOp->
getLoc());
160 if (replacement.
getType() != v.getType()) {
161 if (cast<ShapedType>(replacement.
getType()).getElementType() !=
162 cast<ShapedType>(v.getType()).getElementType())
165 replacement = rewriter.
create<tensor::CastOp>(v.getLoc(), v.getType(),
182 struct EmptyTensorElimination
183 :
public bufferization::impl::EmptyTensorEliminationBase<
184 EmptyTensorElimination> {
185 EmptyTensorElimination() =
default;
187 void runOnOperation()
override;
191 .
insert<bufferization::BufferizationDialect, tensor::TensorDialect>();
198 auto moduleOp = dyn_cast<ModuleOp>(op);
200 options.allowReturnAllocsFromLoops =
true;
202 options.bufferizeFunctionBoundaries =
true;
218 void EmptyTensorElimination::runOnOperation() {
225 return std::make_unique<EmptyTensorElimination>();
static Operation * findValidInsertionPoint(Operation *emptyTensorOp, Operation *user, const SmallVector< Value > &neededValues)
Find a valid insertion point for a replacement of emptyTensorOp's use of user operation,...
static bool neededValuesDominateInsertionPoint(const DominanceInfo &domInfo, Operation *insertionPoint, const SmallVector< Value > &neededValues)
Return true if all neededValues are in scope at the given insertionPoint.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Block represents an ordered list of Operations.
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Operation is the basic unit of execution within MLIR.
void setOperand(unsigned idx, Value value)
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Location getLoc()
The source location the operation was defined or derived from.
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
llvm::LogicalResult analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze moduleOp and its nested ops.
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op)
Try to eliminate "tensor.empty" ops inside op.
std::unique_ptr< Pass > createEmptyTensorEliminationPass()
Create a pass that tries to eliminate tensor.empty ops that are anchored on insert_slice ops.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
Options for analysis-enabled bufferization.
Traversal parameters for findValueInReverseUseDefChain.