MLIR 22.0.0git
CommutativityUtils.cpp
Go to the documentation of this file.
1//===- CommutativityUtils.cpp - Commutativity utilities ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a commutativity utility pattern and a function to
10// populate this pattern. The function is intended to be used inside passes to
11// simplify the matching of commutative operations by fixing the order of their
12// operands.
13//
14//===----------------------------------------------------------------------===//
15
17
18#include <queue>
19
20using namespace mlir;
21
/// Categories of "ancestors", i.e. of the ops and block arguments that occur
/// in the backward slice of a value.
///
/// The enumerators are listed in ascending sort order on purpose: ancestor
/// keys are ordered by comparing these underlying integer values, so block
/// arguments sort first and constant-like ops sort last.
enum AncestorType {
  /// The ancestor is a block argument (it has no defining op).
  BLOCK_ARGUMENT = 0,

  /// The ancestor is an op that is not constant-like.
  NON_CONSTANT_OP = 1,

  /// The ancestor is a constant-like op.
  CONSTANT_OP = 2
};
34
35/// Stores the "key" associated with an ancestor.
37 /// Holds `BLOCK_ARGUMENT`, `NON_CONSTANT_OP`, or `CONSTANT_OP`, depending on
38 /// the ancestor.
40
41 /// Holds the op name of the ancestor if its `type` is `NON_CONSTANT_OP` or
42 /// `CONSTANT_OP`. Else, holds "".
43 StringRef opName;
44
45 /// Constructor for `AncestorKey`.
47 if (!op) {
49 } else {
50 type =
52 opName = op->getName().getStringRef();
53 }
54 }
55
56 /// Overloaded operator `<` for `AncestorKey`.
57 ///
58 /// AncestorKeys of type `BLOCK_ARGUMENT` are considered the smallest, those
59 /// of type `CONSTANT_OP`, the largest, and `NON_CONSTANT_OP` types come in
60 /// between. Within the types `NON_CONSTANT_OP` and `CONSTANT_OP`, the smaller
61 /// ones are the ones with smaller op names (lexicographically).
62 ///
63 /// TODO: Include other information like attributes, value type, etc., to
64 /// enhance this comparison. For example, currently this comparison doesn't
65 /// differentiate between `cmpi sle` and `cmpi sgt` or `addi (in i32)` and
66 /// `addi (in i64)`. Such an enhancement should only be done if the need
67 /// arises.
68 bool operator<(const AncestorKey &key) const {
69 return std::tie(type, opName) < std::tie(key.type, key.opName);
70 }
71};
72
73/// Stores a commutative operand along with its BFS traversal information.
75 /// Stores the operand.
77
78 /// Stores the queue of ancestors of the operand's BFS traversal at a
79 /// particular point in time.
80 std::queue<Operation *> ancestorQueue;
81
82 /// Stores the list of ancestors that have been visited by the BFS traversal
83 /// at a particular point in time.
85
86 /// Stores the operand's "key". This "key" is defined as a list of the
87 /// "AncestorKeys" associated with the ancestors of this operand, in a
88 /// breadth-first order.
89 ///
90 /// So, if an operand, say `A`, was produced as follows:
91 ///
92 /// `<block argument>` `<block argument>`
93 /// \ /
94 /// \ /
95 /// `arith.subi` `arith.constant`
96 /// \ /
97 /// `arith.addi`
98 /// |
99 /// returns `A`
100 ///
101 /// Then, the ancestors of `A`, in the breadth-first order are:
102 /// `arith.addi`, `arith.subi`, `arith.constant`, `<block argument>`, and
103 /// `<block argument>`.
104 ///
105 /// Thus, the "key" associated with operand `A` is:
106 /// {
107 /// {type: `NON_CONSTANT_OP`, opName: "arith.addi"},
108 /// {type: `NON_CONSTANT_OP`, opName: "arith.subi"},
109 /// {type: `CONSTANT_OP`, opName: "arith.constant"},
110 /// {type: `BLOCK_ARGUMENT`, opName: ""},
111 /// {type: `BLOCK_ARGUMENT`, opName: ""}
112 /// }
114
115 /// Push an ancestor into the operand's BFS information structure. This
116 /// entails it being pushed into the queue (always) and inserted into the
117 /// "visited ancestors" list (iff it is an op rather than a block argument).
119 ancestorQueue.push(op);
120 if (op)
121 visitedAncestors.insert(op);
122 }
123
124 /// Refresh the key.
125 ///
126 /// Refreshing a key entails making it up-to-date with the operand's BFS
127 /// traversal that has happened till that point in time, i.e, appending the
128 /// existing key with the front ancestor's "AncestorKey". Note that a key
129 /// directly reflects the BFS and thus needs to be refreshed during the
130 /// progression of the traversal.
131 void refreshKey() {
132 if (ancestorQueue.empty())
133 return;
134
135 Operation *frontAncestor = ancestorQueue.front();
136 AncestorKey frontAncestorKey(frontAncestor);
137 key.push_back(frontAncestorKey);
138 }
139
140 /// Pop the front ancestor, if any, from the queue and then push its adjacent
141 /// unvisited ancestors, if any, to the queue (this is the main body of the
142 /// BFS algorithm).
144 if (ancestorQueue.empty())
145 return;
146 Operation *frontAncestor = ancestorQueue.front();
147 ancestorQueue.pop();
148 if (!frontAncestor)
149 return;
150 for (Value operand : frontAncestor->getOperands()) {
151 Operation *operandDefOp = operand.getDefiningOp();
152 if (!operandDefOp || !visitedAncestors.contains(operandDefOp))
153 pushAncestor(operandDefOp);
154 }
155 }
156};
157
158/// Sorts the operands of `op` in ascending order of the "key" associated with
159/// each operand iff `op` is commutative. This is a stable sort.
160///
161/// After the application of this pattern, since the commutative operands now
162/// have a deterministic order in which they occur in an op, the matching of
163/// large DAGs becomes much simpler, i.e., requires much less number of checks
164/// to be written by a user in her/his pattern matching function.
165///
166/// Some examples of such a sorting:
167///
168/// Assume that the sorting is being applied to `foo.commutative`, which is a
169/// commutative op.
170///
171/// Example 1:
172///
173/// %1 = foo.const 0
174/// %2 = foo.mul <block argument>, <block argument>
175/// %3 = foo.commutative %1, %2
176///
177/// Here,
178/// 1. The key associated with %1 is:
179/// `{
180/// {CONSTANT_OP, "foo.const"}
181/// }`
182/// 2. The key associated with %2 is:
183/// `{
184/// {NON_CONSTANT_OP, "foo.mul"},
185/// {BLOCK_ARGUMENT, ""},
186/// {BLOCK_ARGUMENT, ""}
187/// }`
188///
189/// The key of %2 < the key of %1
190/// Thus, the sorted `foo.commutative` is:
191/// %3 = foo.commutative %2, %1
192///
193/// Example 2:
194///
195/// %1 = foo.const 0
196/// %2 = foo.mul <block argument>, <block argument>
197/// %3 = foo.mul %2, %1
198/// %4 = foo.add %2, %1
199/// %5 = foo.commutative %1, %2, %3, %4
200///
201/// Here,
202/// 1. The key associated with %1 is:
203/// `{
204/// {CONSTANT_OP, "foo.const"}
205/// }`
206/// 2. The key associated with %2 is:
207/// `{
208/// {NON_CONSTANT_OP, "foo.mul"},
209/// {BLOCK_ARGUMENT, ""}
210/// }`
211/// 3. The key associated with %3 is:
212/// `{
213/// {NON_CONSTANT_OP, "foo.mul"},
214/// {NON_CONSTANT_OP, "foo.mul"},
215/// {CONSTANT_OP, "foo.const"},
216/// {BLOCK_ARGUMENT, ""},
217/// {BLOCK_ARGUMENT, ""}
218/// }`
219/// 4. The key associated with %4 is:
220/// `{
221/// {NON_CONSTANT_OP, "foo.add"},
222/// {NON_CONSTANT_OP, "foo.mul"},
223/// {CONSTANT_OP, "foo.const"},
224/// {BLOCK_ARGUMENT, ""},
225/// {BLOCK_ARGUMENT, ""}
226/// }`
227///
228/// Thus, the sorted `foo.commutative` is:
229/// %5 = foo.commutative %4, %3, %2, %1
231public:
233 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/5, context) {}
234 LogicalResult matchAndRewrite(Operation *op,
235 PatternRewriter &rewriter) const override {
236 // Custom comparator for two commutative operands, which returns true iff
237 // the "key" of `constCommOperandA` < the "key" of `constCommOperandB`,
238 // i.e.,
239 // 1. In the first unequal pair of corresponding AncestorKeys, the
240 // AncestorKey in `constCommOperandA` is smaller, or,
241 // 2. Both the AncestorKeys in every pair are the same and the size of
242 // `constCommOperandA`'s "key" is smaller.
243 auto commutativeOperandComparator =
244 [](const std::unique_ptr<CommutativeOperand> &constCommOperandA,
245 const std::unique_ptr<CommutativeOperand> &constCommOperandB) {
246 if (constCommOperandA->operand == constCommOperandB->operand)
247 return false;
248
249 auto &commOperandA =
250 const_cast<std::unique_ptr<CommutativeOperand> &>(
251 constCommOperandA);
252 auto &commOperandB =
253 const_cast<std::unique_ptr<CommutativeOperand> &>(
254 constCommOperandB);
255
256 // Iteratively perform the BFS's of both operands until an order among
257 // them can be determined.
258 unsigned keyIndex = 0;
259 while (true) {
260 if (commOperandA->key.size() <= keyIndex) {
261 if (commOperandA->ancestorQueue.empty())
262 return true;
263 commOperandA->popFrontAndPushAdjacentUnvisitedAncestors();
264 commOperandA->refreshKey();
265 }
266 if (commOperandB->key.size() <= keyIndex) {
267 if (commOperandB->ancestorQueue.empty())
268 return false;
269 commOperandB->popFrontAndPushAdjacentUnvisitedAncestors();
270 commOperandB->refreshKey();
271 }
272 if (commOperandA->ancestorQueue.empty() ||
273 commOperandB->ancestorQueue.empty())
274 return commOperandA->key.size() < commOperandB->key.size();
275 if (commOperandA->key[keyIndex] < commOperandB->key[keyIndex])
276 return true;
277 if (commOperandB->key[keyIndex] < commOperandA->key[keyIndex])
278 return false;
279 keyIndex++;
280 }
281 };
282
283 // If `op` is not commutative, do nothing.
285 return failure();
286
287 // Populate the list of commutative operands.
288 SmallVector<Value, 2> operands = op->getOperands();
290 for (Value operand : operands) {
291 std::unique_ptr<CommutativeOperand> commOperand =
292 std::make_unique<CommutativeOperand>();
293 commOperand->operand = operand;
294 commOperand->pushAncestor(operand.getDefiningOp());
295 commOperand->refreshKey();
296 commOperands.push_back(std::move(commOperand));
297 }
298
299 // Sort the operands.
300 llvm::stable_sort(commOperands, commutativeOperandComparator);
301 SmallVector<Value, 2> sortedOperands;
302 for (const std::unique_ptr<CommutativeOperand> &commOperand : commOperands)
303 sortedOperands.push_back(commOperand->operand);
304 if (sortedOperands == operands)
305 return failure();
306 rewriter.modifyOpInPlace(op, [&] { op->setOperands(sortedOperands); });
307 return success();
308 }
309};
310
return success()
AncestorType
The possible "types" of ancestors.
@ NON_CONSTANT_OP
Pertains to a non-constant-like op.
@ CONSTANT_OP
Pertains to a constant-like op.
@ BLOCK_ARGUMENT
Pertains to a block argument.
Sorts the operands of op in ascending order of the "key" associated with each operand iff op is commutative.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
SortCommutativeOperands(MLIRContext *context)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides the API for a sub-set of ops that are known to be constant-like.
This class adds property that the operation is commutative.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
void populateCommutativityUtilsPatterns(RewritePatternSet &patterns)
Populates the commutativity utility patterns.
const FrozenRewritePatternSet & patterns
Stores the "key" associated with an ancestor.
AncestorKey(Operation *op)
Constructor for AncestorKey.
StringRef opName
Holds the op name of the ancestor if its type is NON_CONSTANT_OP or CONSTANT_OP.
bool operator<(const AncestorKey &key) const
Overloaded operator < for AncestorKey.
AncestorType type
Holds BLOCK_ARGUMENT, NON_CONSTANT_OP, or CONSTANT_OP, depending on the ancestor.
Stores a commutative operand along with its BFS traversal information.
void popFrontAndPushAdjacentUnvisitedAncestors()
Pop the front ancestor, if any, from the queue and then push its adjacent unvisited ancestors, if any, to the queue.
SmallVector< AncestorKey, 4 > key
Stores the operand's "key".
DenseSet< Operation * > visitedAncestors
Stores the list of ancestors that have been visited by the BFS traversal at a particular point in time.
void refreshKey()
Refresh the key.
Value operand
Stores the operand.
void pushAncestor(Operation *op)
Push an ancestor into the operand's BFS information structure.
std::queue< Operation * > ancestorQueue
Stores the queue of ancestors of the operand's BFS traversal at a particular point in time.
This class acts as a special tag that makes the desire to match "any" operation type explicit.