22 namespace bufferization {
23 #define GEN_PASS_DEF_EMPTYTENSORELIMINATION
24 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
37 for (
Value val : neededValues) {
38 if (
auto bbArg = dyn_cast<BlockArgument>(val)) {
39 Block *owner = bbArg.getOwner();
43 auto opResult = cast<OpResult>(val);
57 return domInfo.dominates(insertionPoint, user);
71 insertionPointCandidates.push_back(emptyTensorOp);
72 for (
Value val : neededValues) {
79 if (
auto bbArg = dyn_cast<BlockArgument>(val)) {
80 insertionPointCandidates.push_back(
81 &bbArg.getOwner()->getOperations().front());
83 insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
88 for (
Operation *insertionPoint : insertionPointCandidates) {
96 return insertionPoint;
107 op->
walk([&](SubsetInsertionOpInterface op) {
108 OpOperand &source = op.getSourceOperand();
111 if (!state.isInPlace(source))
116 op.getValuesNeededToBuildSubsetExtraction();
123 config.alwaysIncludeLeaves =
false;
131 config.followSameTypeOrCastsOnly =
true;
134 [&](
Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
137 for (
Value v : emptyTensors) {
138 Operation *emptyTensorOp = v.getDefiningOp();
149 op.buildSubsetExtraction(rewriter, emptyTensorOp->
getLoc());
154 if (replacement.
getType() != v.getType()) {
155 if (cast<ShapedType>(replacement.
getType()).getElementType() !=
156 cast<ShapedType>(v.getType()).getElementType())
159 replacement = rewriter.
create<tensor::CastOp>(v.getLoc(), v.getType(),
163 rewriter.
replaceOp(emptyTensorOp, replacement);
174 struct EmptyTensorElimination
175 :
public bufferization::impl::EmptyTensorEliminationBase<
176 EmptyTensorElimination> {
177 EmptyTensorElimination() =
default;
179 void runOnOperation()
override;
183 .
insert<bufferization::BufferizationDialect, tensor::TensorDialect>();
190 auto moduleOp = dyn_cast<ModuleOp>(op);
192 options.allowReturnAllocsFromLoops =
true;
194 options.bufferizeFunctionBoundaries =
true;
210 void EmptyTensorElimination::runOnOperation() {
217 return std::make_unique<EmptyTensorElimination>();
static bool insertionPointDominatesUses(const DominanceInfo &domInfo, Operation *insertionPoint, Operation *emptyTensorOp)
Return true if the given insertionPoint dominates all uses of emptyTensorOp.
static bool neededValuesDominateInsertionPoint(const DominanceInfo &domInfo, Operation *insertionPoint, const SmallVector< Value > &neededValues)
Return true if all neededValues are in scope at the given insertionPoint.
static Operation * findValidInsertionPoint(Operation *emptyTensorOp, const SmallVector< Value > &neededValues)
Find a valid insertion point for a replacement of emptyTensorOp, assuming that the replacement may use any value from neededValues.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Block represents an ordered list of Operations.
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies in this block. Returns nullptr if the latter fails.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
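Return true if operation A properly dominates operation B, i.e. if A and B are in the same block and A properly dominates B within the block, or if the block that contains A properly dominates the block that contains B.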
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
Location getLoc()
The source location the operation was defined or derived from.
user_range getUsers()
Returns a range of all users.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
State for analysis-enabled bufferization.
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
llvm::LogicalResult analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze moduleOp and its nested ops.
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op)
Try to eliminate "tensor.empty" ops inside op.
std::unique_ptr< Pass > createEmptyTensorEliminationPass()
Create a pass that tries to eliminate tensor.empty ops that are anchored on insert_slice ops.
Include the generated interface declarations.
Options for analysis-enabled bufferization.
Traversal parameters for findValueInReverseUseDefChain.
bool followEquivalentOnly
Specifies whether non-equivalent OpOperands should be followed.