MLIR  20.0.0git
CommutativityUtils.cpp
Go to the documentation of this file.
1 //===- CommutativityUtils.cpp - Commutativity utilities ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a commutativity utility pattern and a function to
10 // populate this pattern. The function is intended to be used inside passes to
11 // simplify the matching of commutative operations by fixing the order of their
12 // operands.
13 //
14 //===----------------------------------------------------------------------===//
15 
17 
18 #include <queue>
19 
20 using namespace mlir;
21 
/// The possible "types" of ancestors. Here, an ancestor is an op or a block
/// argument present in the backward slice of a value.
///
/// The enumerators are declared in ascending sort order:
/// `AncestorKey::operator<` compares `AncestorType` values directly, so
/// `BLOCK_ARGUMENT` sorts first and `CONSTANT_OP` sorts last. Preserve this
/// order if enumerators are ever added or moved.
enum AncestorType {
  /// Pertains to a block argument.
  BLOCK_ARGUMENT,

  /// Pertains to a non-constant-like op.
  NON_CONSTANT_OP,

  /// Pertains to a constant-like op.
  CONSTANT_OP
};
34 
35 /// Stores the "key" associated with an ancestor.
36 struct AncestorKey {
37  /// Holds `BLOCK_ARGUMENT`, `NON_CONSTANT_OP`, or `CONSTANT_OP`, depending on
38  /// the ancestor.
40 
41  /// Holds the op name of the ancestor if its `type` is `NON_CONSTANT_OP` or
42  /// `CONSTANT_OP`. Else, holds "".
43  StringRef opName;
44 
45  /// Constructor for `AncestorKey`.
47  if (!op) {
48  type = BLOCK_ARGUMENT;
49  } else {
50  type =
52  opName = op->getName().getStringRef();
53  }
54  }
55 
56  /// Overloaded operator `<` for `AncestorKey`.
57  ///
58  /// AncestorKeys of type `BLOCK_ARGUMENT` are considered the smallest, those
59  /// of type `CONSTANT_OP`, the largest, and `NON_CONSTANT_OP` types come in
60  /// between. Within the types `NON_CONSTANT_OP` and `CONSTANT_OP`, the smaller
61  /// ones are the ones with smaller op names (lexicographically).
62  ///
63  /// TODO: Include other information like attributes, value type, etc., to
64  /// enhance this comparison. For example, currently this comparison doesn't
65  /// differentiate between `cmpi sle` and `cmpi sgt` or `addi (in i32)` and
66  /// `addi (in i64)`. Such an enhancement should only be done if the need
67  /// arises.
68  bool operator<(const AncestorKey &key) const {
69  return std::tie(type, opName) < std::tie(key.type, key.opName);
70  }
71 };
72 
73 /// Stores a commutative operand along with its BFS traversal information.
75  /// Stores the operand.
77 
78  /// Stores the queue of ancestors of the operand's BFS traversal at a
79  /// particular point in time.
80  std::queue<Operation *> ancestorQueue;
81 
82  /// Stores the list of ancestors that have been visited by the BFS traversal
83  /// at a particular point in time.
85 
86  /// Stores the operand's "key". This "key" is defined as a list of the
87  /// "AncestorKeys" associated with the ancestors of this operand, in a
88  /// breadth-first order.
89  ///
90  /// So, if an operand, say `A`, was produced as follows:
91  ///
92  /// `<block argument>` `<block argument>`
93  /// \ /
94  /// \ /
95  /// `arith.subi` `arith.constant`
96  /// \ /
97  /// `arith.addi`
98  /// |
99  /// returns `A`
100  ///
101  /// Then, the ancestors of `A`, in the breadth-first order are:
102  /// `arith.addi`, `arith.subi`, `arith.constant`, `<block argument>`, and
103  /// `<block argument>`.
104  ///
105  /// Thus, the "key" associated with operand `A` is:
106  /// {
107  /// {type: `NON_CONSTANT_OP`, opName: "arith.addi"},
108  /// {type: `NON_CONSTANT_OP`, opName: "arith.subi"},
109  /// {type: `CONSTANT_OP`, opName: "arith.constant"},
110  /// {type: `BLOCK_ARGUMENT`, opName: ""},
111  /// {type: `BLOCK_ARGUMENT`, opName: ""}
112  /// }
114 
115  /// Push an ancestor into the operand's BFS information structure. This
116  /// entails it being pushed into the queue (always) and inserted into the
117  /// "visited ancestors" list (iff it is an op rather than a block argument).
119  ancestorQueue.push(op);
120  if (op)
121  visitedAncestors.insert(op);
122  }
123 
124  /// Refresh the key.
125  ///
126  /// Refreshing a key entails making it up-to-date with the operand's BFS
127  /// traversal that has happened till that point in time, i.e, appending the
128  /// existing key with the front ancestor's "AncestorKey". Note that a key
129  /// directly reflects the BFS and thus needs to be refreshed during the
130  /// progression of the traversal.
131  void refreshKey() {
132  if (ancestorQueue.empty())
133  return;
134 
135  Operation *frontAncestor = ancestorQueue.front();
136  AncestorKey frontAncestorKey(frontAncestor);
137  key.push_back(frontAncestorKey);
138  }
139 
140  /// Pop the front ancestor, if any, from the queue and then push its adjacent
141  /// unvisited ancestors, if any, to the queue (this is the main body of the
142  /// BFS algorithm).
144  if (ancestorQueue.empty())
145  return;
146  Operation *frontAncestor = ancestorQueue.front();
147  ancestorQueue.pop();
148  if (!frontAncestor)
149  return;
150  for (Value operand : frontAncestor->getOperands()) {
151  Operation *operandDefOp = operand.getDefiningOp();
152  if (!operandDefOp || !visitedAncestors.contains(operandDefOp))
153  pushAncestor(operandDefOp);
154  }
155  }
156 };
157 
158 /// Sorts the operands of `op` in ascending order of the "key" associated with
159 /// each operand iff `op` is commutative. This is a stable sort.
160 ///
161 /// After the application of this pattern, since the commutative operands now
162 /// have a deterministic order in which they occur in an op, the matching of
163 /// large DAGs becomes much simpler, i.e., requires much less number of checks
164 /// to be written by a user in her/his pattern matching function.
165 ///
166 /// Some examples of such a sorting:
167 ///
168 /// Assume that the sorting is being applied to `foo.commutative`, which is a
169 /// commutative op.
170 ///
171 /// Example 1:
172 ///
173 /// %1 = foo.const 0
174 /// %2 = foo.mul <block argument>, <block argument>
175 /// %3 = foo.commutative %1, %2
176 ///
177 /// Here,
178 /// 1. The key associated with %1 is:
179 /// `{
180 /// {CONSTANT_OP, "foo.const"}
181 /// }`
182 /// 2. The key associated with %2 is:
183 /// `{
184 /// {NON_CONSTANT_OP, "foo.mul"},
185 /// {BLOCK_ARGUMENT, ""},
186 /// {BLOCK_ARGUMENT, ""}
187 /// }`
188 ///
189 /// The key of %2 < the key of %1
190 /// Thus, the sorted `foo.commutative` is:
191 /// %3 = foo.commutative %2, %1
192 ///
193 /// Example 2:
194 ///
195 /// %1 = foo.const 0
196 /// %2 = foo.mul <block argument>, <block argument>
197 /// %3 = foo.mul %2, %1
198 /// %4 = foo.add %2, %1
199 /// %5 = foo.commutative %1, %2, %3, %4
200 ///
201 /// Here,
202 /// 1. The key associated with %1 is:
203 /// `{
204 /// {CONSTANT_OP, "foo.const"}
205 /// }`
206 /// 2. The key associated with %2 is:
207 /// `{
208 /// {NON_CONSTANT_OP, "foo.mul"},
209 /// {BLOCK_ARGUMENT, ""}
210 /// }`
211 /// 3. The key associated with %3 is:
212 /// `{
213 /// {NON_CONSTANT_OP, "foo.mul"},
214 /// {NON_CONSTANT_OP, "foo.mul"},
215 /// {CONSTANT_OP, "foo.const"},
216 /// {BLOCK_ARGUMENT, ""},
217 /// {BLOCK_ARGUMENT, ""}
218 /// }`
219 /// 4. The key associated with %4 is:
220 /// `{
221 /// {NON_CONSTANT_OP, "foo.add"},
222 /// {NON_CONSTANT_OP, "foo.mul"},
223 /// {CONSTANT_OP, "foo.const"},
224 /// {BLOCK_ARGUMENT, ""},
225 /// {BLOCK_ARGUMENT, ""}
226 /// }`
227 ///
228 /// Thus, the sorted `foo.commutative` is:
229 /// %5 = foo.commutative %4, %3, %2, %1
231 public:
233  : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/5, context) {}
234  LogicalResult matchAndRewrite(Operation *op,
235  PatternRewriter &rewriter) const override {
236  // Custom comparator for two commutative operands, which returns true iff
237  // the "key" of `constCommOperandA` < the "key" of `constCommOperandB`,
238  // i.e.,
239  // 1. In the first unequal pair of corresponding AncestorKeys, the
240  // AncestorKey in `constCommOperandA` is smaller, or,
241  // 2. Both the AncestorKeys in every pair are the same and the size of
242  // `constCommOperandA`'s "key" is smaller.
243  auto commutativeOperandComparator =
244  [](const std::unique_ptr<CommutativeOperand> &constCommOperandA,
245  const std::unique_ptr<CommutativeOperand> &constCommOperandB) {
246  if (constCommOperandA->operand == constCommOperandB->operand)
247  return false;
248 
249  auto &commOperandA =
250  const_cast<std::unique_ptr<CommutativeOperand> &>(
251  constCommOperandA);
252  auto &commOperandB =
253  const_cast<std::unique_ptr<CommutativeOperand> &>(
254  constCommOperandB);
255 
256  // Iteratively perform the BFS's of both operands until an order among
257  // them can be determined.
258  unsigned keyIndex = 0;
259  while (true) {
260  if (commOperandA->key.size() <= keyIndex) {
261  if (commOperandA->ancestorQueue.empty())
262  return true;
263  commOperandA->popFrontAndPushAdjacentUnvisitedAncestors();
264  commOperandA->refreshKey();
265  }
266  if (commOperandB->key.size() <= keyIndex) {
267  if (commOperandB->ancestorQueue.empty())
268  return false;
269  commOperandB->popFrontAndPushAdjacentUnvisitedAncestors();
270  commOperandB->refreshKey();
271  }
272  if (commOperandA->ancestorQueue.empty() ||
273  commOperandB->ancestorQueue.empty())
274  return commOperandA->key.size() < commOperandB->key.size();
275  if (commOperandA->key[keyIndex] < commOperandB->key[keyIndex])
276  return true;
277  if (commOperandB->key[keyIndex] < commOperandA->key[keyIndex])
278  return false;
279  keyIndex++;
280  }
281  };
282 
283  // If `op` is not commutative, do nothing.
284  if (!op->hasTrait<OpTrait::IsCommutative>())
285  return failure();
286 
287  // Populate the list of commutative operands.
288  SmallVector<Value, 2> operands = op->getOperands();
290  for (Value operand : operands) {
291  std::unique_ptr<CommutativeOperand> commOperand =
292  std::make_unique<CommutativeOperand>();
293  commOperand->operand = operand;
294  commOperand->pushAncestor(operand.getDefiningOp());
295  commOperand->refreshKey();
296  commOperands.push_back(std::move(commOperand));
297  }
298 
299  // Sort the operands.
300  std::stable_sort(commOperands.begin(), commOperands.end(),
301  commutativeOperandComparator);
302  SmallVector<Value, 2> sortedOperands;
303  for (const std::unique_ptr<CommutativeOperand> &commOperand : commOperands)
304  sortedOperands.push_back(commOperand->operand);
305  if (sortedOperands == operands)
306  return failure();
307  rewriter.modifyOpInPlace(op, [&] { op->setOperands(sortedOperands); });
308  return success();
309  }
310 };
311 
313  patterns.add<SortCommutativeOperands>(patterns.getContext());
314 }
AncestorType
The possible "types" of ancestors.
@ NON_CONSTANT_OP
Pertains to a non-constant-like op.
@ CONSTANT_OP
Pertains to a constant-like op.
@ BLOCK_ARGUMENT
Pertains to a block argument.
Sorts the operands of op in ascending order of the "key" associated with each operand iff op is commu...
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
SortCommutativeOperands(MLIRContext *context)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides the API for a sub-set of ops that are known to be constant-like.
This class adds property that the operation is commutative.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:246
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Include the generated interface declarations.
void populateCommutativityUtilsPatterns(RewritePatternSet &patterns)
Populates the commutativity utility patterns.
const FrozenRewritePatternSet & patterns
Stores the "key" associated with an ancestor.
AncestorKey(Operation *op)
Constructor for AncestorKey.
StringRef opName
Holds the op name of the ancestor if its type is NON_CONSTANT_OP or CONSTANT_OP.
bool operator<(const AncestorKey &key) const
Overloaded operator < for AncestorKey.
AncestorType type
Holds BLOCK_ARGUMENT, NON_CONSTANT_OP, or CONSTANT_OP, depending on the ancestor.
Stores a commutative operand along with its BFS traversal information.
void popFrontAndPushAdjacentUnvisitedAncestors()
Pop the front ancestor, if any, from the queue and then push its adjacent unvisited ancestors,...
SmallVector< AncestorKey, 4 > key
Stores the operand's "key".
DenseSet< Operation * > visitedAncestors
Stores the list of ancestors that have been visited by the BFS traversal at a particular point in tim...
void refreshKey()
Refresh the key.
Value operand
Stores the operand.
void pushAncestor(Operation *op)
Push an ancestor into the operand's BFS information structure.
std::queue< Operation * > ancestorQueue
Stores the queue of ancestors of the operand's BFS traversal at a particular point in time.
This class acts as a special tag that makes the desire to match "any" operation type explicit.
Definition: PatternMatch.h:159