69 return std::tie(type, opName) < std::tie(key.
type, key.
opName);
119 ancestorQueue.push(op);
121 visitedAncestors.insert(op);
132 if (ancestorQueue.empty())
135 Operation *frontAncestor = ancestorQueue.front();
137 key.push_back(frontAncestorKey);
144 if (ancestorQueue.empty())
146 Operation *frontAncestor = ancestorQueue.front();
152 if (!operandDefOp || !visitedAncestors.contains(operandDefOp))
153 pushAncestor(operandDefOp);
243 auto commutativeOperandComparator =
244 [](
const std::unique_ptr<CommutativeOperand> &constCommOperandA,
245 const std::unique_ptr<CommutativeOperand> &constCommOperandB) {
246 if (constCommOperandA->operand == constCommOperandB->operand)
250 const_cast<std::unique_ptr<CommutativeOperand> &
>(
253 const_cast<std::unique_ptr<CommutativeOperand> &
>(
258 unsigned keyIndex = 0;
260 if (commOperandA->key.size() <= keyIndex) {
261 if (commOperandA->ancestorQueue.empty())
263 commOperandA->popFrontAndPushAdjacentUnvisitedAncestors();
264 commOperandA->refreshKey();
266 if (commOperandB->key.size() <= keyIndex) {
267 if (commOperandB->ancestorQueue.empty())
269 commOperandB->popFrontAndPushAdjacentUnvisitedAncestors();
270 commOperandB->refreshKey();
272 if (commOperandA->ancestorQueue.empty() ||
273 commOperandB->ancestorQueue.empty())
274 return commOperandA->key.size() < commOperandB->key.size();
275 if (commOperandA->key[keyIndex] < commOperandB->key[keyIndex])
277 if (commOperandB->key[keyIndex] < commOperandA->key[keyIndex])
290 for (
Value operand : operands) {
291 std::unique_ptr<CommutativeOperand> commOperand =
292 std::make_unique<CommutativeOperand>();
293 commOperand->operand = operand;
294 commOperand->pushAncestor(operand.getDefiningOp());
295 commOperand->refreshKey();
296 commOperands.push_back(std::move(commOperand));
300 std::stable_sort(commOperands.begin(), commOperands.end(),
301 commutativeOperandComparator);
303 for (
const std::unique_ptr<CommutativeOperand> &commOperand : commOperands)
304 sortedOperands.push_back(commOperand->operand);
305 if (sortedOperands == operands)
AncestorType
The possible "types" of ancestors.
@ NON_CONSTANT_OP
Pertains to a non-constant-like op.
@ CONSTANT_OP
Pertains to a constant-like op.
@ BLOCK_ARGUMENT
Pertains to a block argument.
Sorts the operands of op in ascending order of the "key" associated with each operand iff op is commutative.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
Attempt to match against code rooted at the specified operation, which is the same operation code as getRootKind().
SortCommutativeOperands(MLIRContext *context)
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides the API for a sub-set of ops that are known to be constant-like.
This class adds property that the operation is commutative.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Include the generated interface declarations.
void populateCommutativityUtilsPatterns(RewritePatternSet &patterns)
Populates the commutativity utility patterns.
Stores the "key" associated with an ancestor.
AncestorKey(Operation *op)
Constructor for AncestorKey.
StringRef opName
Holds the op name of the ancestor if its type is NON_CONSTANT_OP or CONSTANT_OP.
bool operator<(const AncestorKey &key) const
Overloaded operator < for AncestorKey.
AncestorType type
Holds BLOCK_ARGUMENT, NON_CONSTANT_OP, or CONSTANT_OP, depending on the ancestor.
Stores a commutative operand along with its BFS traversal information.
void popFrontAndPushAdjacentUnvisitedAncestors()
Pop the front ancestor, if any, from the queue and then push its adjacent unvisited ancestors, if any, to the queue.
SmallVector< AncestorKey, 4 > key
Stores the operand's "key".
DenseSet< Operation * > visitedAncestors
Stores the list of ancestors that have been visited by the BFS traversal at a particular point in time.
void refreshKey()
Refresh the key.
Value operand
Stores the operand.
void pushAncestor(Operation *op)
Push an ancestor into the operand's BFS information structure.
std::queue< Operation * > ancestorQueue
Stores the queue of ancestors of the operand's BFS traversal at a particular point in time.
This class acts as a special tag that makes the desire to match "any" operation type explicit.