22 namespace bufferization {
23 #define GEN_PASS_DEF_EMPTYTENSORELIMINATION
24 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
37 for (
Value val : neededValues) {
38 if (
auto bbArg = dyn_cast<BlockArgument>(val)) {
39 Block *owner = bbArg.getOwner();
43 auto opResult = cast<OpResult>(val);
57 if (!domInfo.
dominates(insertionPoint, user))
72 insertionPointCandidates.push_back(emptyTensorOp);
73 for (
Value val : neededValues) {
80 if (
auto bbArg = dyn_cast<BlockArgument>(val)) {
81 insertionPointCandidates.push_back(
82 &bbArg.getOwner()->getOperations().front());
84 insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
89 for (
Operation *insertionPoint : insertionPointCandidates) {
97 return insertionPoint;
108 op->
walk([&](SubsetInsertionOpInterface op) {
109 OpOperand &source = op.getSourceOperand();
112 if (!state.isInPlace(source))
117 op.getValuesNeededToBuildSubsetExtraction();
124 config.alwaysIncludeLeaves =
false;
132 config.followSameTypeOrCastsOnly =
true;
135 [&](
Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
138 for (
Value v : emptyTensors) {
139 Operation *emptyTensorOp = v.getDefiningOp();
150 op.buildSubsetExtraction(rewriter, emptyTensorOp->
getLoc());
155 if (replacement.
getType() != v.getType()) {
157 replacement = rewriter.
create<tensor::CastOp>(v.getLoc(), v.getType(),
161 rewriter.
replaceOp(emptyTensorOp, replacement);
172 struct EmptyTensorElimination
173 :
public bufferization::impl::EmptyTensorEliminationBase<
174 EmptyTensorElimination> {
175 EmptyTensorElimination() =
default;
177 void runOnOperation()
override;
181 .
insert<bufferization::BufferizationDialect, tensor::TensorDialect>();
188 auto moduleOp = dyn_cast<ModuleOp>(op);
190 options.allowReturnAllocsFromLoops =
true;
192 options.bufferizeFunctionBoundaries =
true;
208 void EmptyTensorElimination::runOnOperation() {
215 return std::make_unique<EmptyTensorElimination>();
static bool insertionPointDominatesUses(const DominanceInfo &domInfo, Operation *insertionPoint, Operation *emptyTensorOp)
Return true if the given insertionPoint dominates all uses of emptyTensorOp.
static bool neededValuesDominateInsertionPoint(const DominanceInfo &domInfo, Operation *insertionPoint, const SmallVector< Value > &neededValues)
Return true if all neededValues are in scope at the given insertionPoint.
static Operation * findValidInsertionPoint(Operation *emptyTensorOp, const SmallVector< Value > &neededValues)
Find a valid insertion point for a replacement of emptyTensorOp, assuming that the replacement may use any value from neededValues.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Block represents an ordered list of Operations.
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies in this block.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
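Return true if operation A properly dominates operation B, i.e. if A and B are in the same block and A properly dominates B within the block, or if the block that contains A properly dominates the block that contains B.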
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e. if A and B are the same operation or A properly dominates B.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
Location getLoc()
The source location the operation was defined or derived from.
user_range getUsers()
Returns a range of all users.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
State for analysis-enabled bufferization.
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
LogicalResult analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze moduleOp and its nested ops.
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op)
Try to eliminate "tensor.empty" ops inside op.
std::unique_ptr< Pass > createEmptyTensorEliminationPass()
Create a pass that tries to eliminate tensor.empty ops that are anchored on insert_slice ops.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
Options for analysis-enabled bufferization.
Traversal parameters for findValueInReverseUseDefChain.
bool followEquivalentOnly
Specifies whether non-equivalent OpOperands should be followed.