MLIR 22.0.0git
EliminateEmptyTensors.cpp
Go to the documentation of this file.
//===- EliminateEmptyTensors.cpp - tensor.empty op elimination ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
16
17using namespace mlir;
18using namespace mlir::bufferization;
19using namespace mlir::linalg;
20
21/// Get an output operand that matches the given input operand and can be used
22/// to eliminate a tensor.empty op.
23static OpOperand *getUnusedOutOperand(LinalgOp op, OpOperand *in) {
24 for (OpOperand &operand : op.getDpsInitsMutable()) {
25 // Operand must be unused.
26 if (op.payloadUsesValueFromOperand(&operand))
27 continue;
28 // Types must match.
29 if (operand.get().getType() != in->get().getType())
30 continue;
31 // Indexing maps must match.
32 if (op.getMatchingIndexingMap(&operand) != op.getMatchingIndexingMap(in))
33 continue;
34 return &operand;
35 }
36 return nullptr;
37}
38
40 RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
41 OpBuilder::InsertionGuard g(rewriter);
42 DominanceInfo domInfo;
43
44 op->walk([&](LinalgOp op) {
45 // Only ops with all "parallel" iterator types are supported.
46 if (op.getNumParallelLoops() != op.getNumLoops())
47 return WalkResult::skip();
48
49 for (OpOperand *in : op.getDpsInputOperands()) {
50 // Skip non-tensor operands.
51 if (!isa<RankedTensorType>(in->get().getType()))
52 continue;
53
54 // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
55 // equivalent tensors. I.e., stop when there are ops such as extract_slice
56 // on the path.
57 TraversalConfig config;
58 config.followEquivalentOnly = true;
59 config.alwaysIncludeLeaves = false;
60 SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
61 in, /*condition=*/
62 [&](Value val) {
63 return val.getDefiningOp<tensor::EmptyOp>() &&
64 val.getType() == in->get().getType();
65 },
66 config);
67 if (emptyTensors.empty())
68 continue;
69
70 // Find matching out operand.
71 OpOperand *out = getUnusedOutOperand(op, in);
72 if (!out)
73 continue;
74
75 // Check if this transform would violate dominance.
76 if (!llvm::all_of(emptyTensors, [&](Value v) {
77 return domInfo.properlyDominates(out->get(), v.getDefiningOp());
78 }))
79 continue;
80
81 // Replace all uses of the tensor.empty, but do not delete it yet. It will
82 // fold away later (to not invalidate DominanceInfo).
83 for (Value v : emptyTensors) {
84 assert(v.getDefiningOp<tensor::EmptyOp>() && "expected tensor.empty");
85 rewriter.replaceAllUsesWith(v, out->get());
86 }
87
88 // Turn the "in" into an "out".
89 rewriter.modifyOpInPlace(op, [&]() {
90 out->set(in->get());
91 // The original "in" could be removed entirely here (because it will no
92 // longer have any uses in the payload), but we delegate this to
93 // existing cleanup patterns that remove unused operands.
94 in->set(emptyTensors.front());
95 BlockArgument outArg = op.getMatchingBlockArgument(out);
96 assert(outArg.getUses().empty() && "expected that out has no uses");
97 BlockArgument inArg = op.getMatchingBlockArgument(in);
98 rewriter.replaceAllUsesWith(inArg, outArg);
99 assert(!op.payloadUsesValueFromOperand(in) &&
100 "expected that the in operand is now unused");
101 });
102
103 state.resetCache();
104 }
105
106 return WalkResult::advance();
107 });
108 return success();
109}
return success()
static OpOperand * getUnusedOutOperand(LinalgOp op, OpOperand *in)
Get an output operand that matches the given input operand and can be used to eliminate a tensor....
This class represents an argument of a Block.
Definition Value.h:309
A class for computing basic dominance information.
Definition Dominance.h:140
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class represents an operand of an operation.
Definition Value.h:257
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
State for analysis-enabled bufferization.
void resetCache() override
Reset cached data structures.
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131