MLIR 22.0.0git
SinkVectorProducerOps.cpp
Go to the documentation of this file.
1//===- SinkVectorProducerOps.cpp ------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13
15#include "mlir/IR/Dominance.h"
17
18#include "mlir/Pass/Pass.h"
20
21using namespace mlir;
22using namespace mlir::vector;
23using namespace mlir::x86vector;
24
25static FailureOr<llvm::SmallVector<Operation *>>
28 for (OpResult result : op->getResults()) {
29 for (Operation *user : result.getUsers()) {
30 // Check prod and users belongs to same block.
31 if (op->getBlock() != user->getBlock())
32 return failure();
33 opUsers.push_back(user);
34 }
35 }
36
37 return opUsers;
38}
39
40// Prevent pathological looping:
41// If two/three producers are used by same consumer, will end in looping of
42// moving the producers.
43// For example:
44// %1 = prod1
45// %2 = prod2
46// %3 = prod3
47// %4 = op %1, %2, %3
48static bool checkLooping(Operation *op) {
50 operations.push_back(op);
51
52 // Retrive the next immediate operation until it is a vector.load or
53 // a vector.transfer_read
54 Operation *nextOp = op->getNextNode();
55 while (nextOp) {
56 if (isa<vector::LoadOp>(nextOp) || isa<vector::TransferReadOp>(nextOp)) {
57 operations.push_back(op);
58 } else {
59 break;
60 }
61 nextOp = nextOp->getNextNode();
62 }
63
64 // If all the loads or transfer_reads have same immediate nextOp as its
65 // user, then it loops.
66 for (Operation *op : operations) {
67 FailureOr<llvm::SmallVector<Operation *>> users = getSameBlockUsers(op);
68 if (failed(users))
69 return false;
70
71 if (!llvm::is_contained(*users, nextOp))
72 return false;
73 }
74
75 return true;
76}
77
78/// Sink vector producers forward to reduce live ranges.
79/// This pattern applies to ops such as vector.load and vector.transfer_read.
80template <typename producerOp>
81struct SinkVectorProducerOps final : public OpRewritePattern<producerOp> {
82 using OpRewritePattern<producerOp>::OpRewritePattern;
83
84 LogicalResult matchAndRewrite(producerOp op,
85 PatternRewriter &rewriter) const override {
86
87 auto users = getSameBlockUsers(op);
88 if (failed(users))
89 return failure();
90
91 if (checkLooping(op))
92 return failure();
93
96
97 llvm::SmallVector<Operation *> opUsers = *users;
98 prodsAllUsers.try_emplace(op, opUsers);
99
100 // Iterate until the last instruction to find the first users of all
101 // producers within the block.
102 Operation *nextOp = op;
103
104 while ((nextOp = nextOp->getNextNode())) {
105
106 if (isa<vector::LoadOp>(nextOp) || isa<vector::TransferReadOp>(nextOp)) {
107 auto nextUsers = getSameBlockUsers(nextOp);
108
109 if (failed(nextUsers))
110 continue;
111 llvm::SmallVector<Operation *> nextOpUsers = *nextUsers;
112 prodsAllUsers.try_emplace(nextOp, nextOpUsers);
113 } else {
115
116 for (auto &entry : prodsAllUsers) {
117 llvm::SmallVector<Operation *> &users = entry.second;
118
119 if (llvm::is_contained(users, nextOp)) {
120 Operation *operation = entry.first;
121 operations.push_back(operation);
122 prodsFirstUser.try_emplace(operation, nextOp);
123 }
124 }
125
126 for (Operation *op : operations) {
127 prodsAllUsers.erase(op);
128 }
129 }
130 }
131
132 // Move all the loads or transfer_reads before its first use.
133 for (auto &entry : prodsFirstUser) {
134 Operation *prod = entry.first;
135 Operation *consumer = entry.second;
136
137 prod->moveBefore(consumer);
138 }
139
140 return success();
141 }
142};
143
return success()
static bool checkLooping(Operation *op)
static FailureOr< llvm::SmallVector< Operation * > > getSameBlockUsers(Operation *op)
This is a value defined by a result of an operation.
Definition Value.h:457
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in the same or another block in the same function.
result_range getResults()
Definition Operation.h:415
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
void populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
Sink vector producers forward to reduce live ranges.
LogicalResult matchAndRewrite(producerOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})