MLIR 23.0.0git
ReductionTreePass.cpp
Go to the documentation of this file.
1//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the Reduction Tree Pass class. It provides a framework for
10// the implementation of different reduction passes in the MLIR Reduce tool. It
11// allows for custom specification of the variant generation behavior. It
12// implements methods that define the different possible traversals of the
13// reduction tree.
14//
15//===----------------------------------------------------------------------===//
16
#include "mlir/IR/DialectInterface.h"
#include "mlir/Reducer/Passes.h"
#include "mlir/Reducer/ReductionNode.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Reducer/Tester.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Allocator.h"
27
28namespace mlir {
29#define GEN_PASS_DEF_REDUCTIONTREEPASS
30#include "mlir/Reducer/Passes.h.inc"
31} // namespace mlir
32
33using namespace mlir;
34
35/// We implicitly number each operation in the region and if an operation's
36/// number falls into rangeToKeep, we need to keep it and apply the given
37/// rewrite patterns on it.
38static void applyPatterns(Region &region,
39 const FrozenRewritePatternSet &patterns,
41 bool eraseOpNotInRange) {
42 std::vector<Operation *> opsNotInRange;
43 std::vector<Operation *> opsInRange;
44 size_t keepIndex = 0;
45 for (const auto &op : enumerate(region.getOps())) {
46 int index = op.index();
47 if (keepIndex < rangeToKeep.size() &&
48 index == rangeToKeep[keepIndex].second)
49 ++keepIndex;
50 if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
51 opsNotInRange.push_back(&op.value());
52 else
53 opsInRange.push_back(&op.value());
54 }
55
56 // `applyOpPatternsGreedily` with folding may erase the ops so we can't do the
57 // pattern matching in above iteration. Besides, erase op not-in-range may end
58 // up in invalid module, so `applyOpPatternsGreedily` with folding should come
59 // before that transform.
60 for (Operation *op : opsInRange) {
61 // `applyOpPatternsGreedily` with folding returns whether the op is
62 // converted. Omit it because we don't have expectation this reduction will
63 // be success or not.
64 (void)applyOpPatternsGreedily(op, patterns,
65 GreedyRewriteConfig().setStrictness(
67 }
68
69 if (eraseOpNotInRange)
70 for (Operation *op : opsNotInRange) {
71 op->dropAllUses();
72 op->erase();
73 }
74}
75
76/// We will apply the reducer patterns to the operations in the ranges specified
77/// by ReductionNode. Note that we are not able to remove an operation without
78/// replacing it with another valid operation. However, The validity of module
79/// reduction is based on the Tester provided by the user and that means certain
80/// invalid module is still interested by the use. Thus we provide an
81/// alternative way to remove operations, which is using `eraseOpNotInRange` to
82/// erase the operations not in the range specified by ReductionNode.
83template <typename IteratorType>
84static LogicalResult findOptimal(ModuleOp module, Region &region,
85 const FrozenRewritePatternSet &patterns,
86 const Tester &test, bool eraseOpNotInRange) {
87 std::pair<Tester::Interestingness, size_t> initStatus =
88 test.isInteresting(module);
89 // While exploring the reduction tree, we always branch from an interesting
90 // node. Thus the root node must be interesting.
91 if (initStatus.first != Tester::Interestingness::True)
92 return module.emitWarning() << "uninterested module will not be reduced";
93
94 llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
95
96 std::vector<ReductionNode::Range> ranges{
97 {0, std::distance(region.op_begin(), region.op_end())}};
98
99 ReductionNode *root = allocator.Allocate();
100 new (root) ReductionNode(nullptr, ranges, allocator);
101 // Duplicate the module for root node and locate the region in the copy.
102 if (failed(root->initialize(module, region)))
103 llvm_unreachable("unexpected initialization failure");
104 root->update(initStatus);
105
106 ReductionNode *smallestNode = root;
107 IteratorType iter(root);
108
109 while (iter != IteratorType::end()) {
110 ReductionNode &currentNode = *iter;
111 Region &curRegion = currentNode.getRegion();
112
113 applyPatterns(curRegion, patterns, currentNode.getRanges(),
114 eraseOpNotInRange);
115 currentNode.update(test.isInteresting(currentNode.getModule()));
116
117 if (currentNode.isInteresting() == Tester::Interestingness::True &&
118 currentNode.getSize() < smallestNode->getSize())
119 smallestNode = &currentNode;
120
121 ++iter;
122 }
123
124 // At here, we have found an optimal path to reduce the given region. Retrieve
125 // the path and apply the reducer to it.
127 ReductionNode *curNode = smallestNode;
128 trace.push_back(curNode);
129 while (curNode != root) {
130 curNode = curNode->getParent();
131 trace.push_back(curNode);
132 }
133
134 // Reduce the region through the optimal path.
135 while (!trace.empty()) {
136 ReductionNode *top = trace.pop_back_val();
137 applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
138 }
139
140 if (test.isInteresting(module).first != Tester::Interestingness::True)
141 llvm::report_fatal_error("Reduced module is not interesting");
142 if (test.isInteresting(module).second != smallestNode->getSize())
143 llvm::report_fatal_error(
144 "Reduced module doesn't have consistent size with smallestNode");
145 return success();
146}
147
148/// This function attempts to erase all operations within the region currently
149/// being processed.
150static LogicalResult eraseAllOpsInRegion(ModuleOp module, Region &region,
151 const Tester &test) {
152 std::pair<Tester::Interestingness, size_t> initStatus =
153 test.isInteresting(module);
154
155 // While exploring the reduction tree, we always branch from an interesting
156 // node. Thus the root node must be interesting.
157 if (initStatus.first != Tester::Interestingness::True)
158 return module.emitWarning() << "uninterested module will not be reduced";
159 llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
160
161 // Setting the ranges to {{0, 0}} will result in the deletion of all ops
162 // within the region.
163 std::vector<ReductionNode::Range> ranges{{0, 0}};
164
165 // We allocate memory on the stack, and the 'allocator' is only used to
166 // construct the 'root node'. Since we won't be constructing any child nodes
167 // for emptyRegionNode, it is only used within the current scope.
168 ReductionNode emptyRegionNode(nullptr, ranges, allocator);
169 ReductionNode *root = &emptyRegionNode;
170
171 // Create a copy of the current IR.
172 if (failed(root->initialize(module, region)))
173 llvm_unreachable("unexpected initialization failure");
174
175 // Erase all operations within the corresponding region of the clone.
176 applyPatterns(root->getRegion(), {}, root->getRanges(), true);
177 root->update(test.isInteresting(root->getModule()));
179 // If we can successfully remove all ops in the region, we apply the same
180 // transformation to the original IR and return success.
181 applyPatterns(region, {}, root->getRanges(), true);
182 return success();
183 }
184 return failure();
185}
186
187template <typename IteratorType>
188static LogicalResult findOptimal(ModuleOp module, Region &region,
189 const FrozenRewritePatternSet &patterns,
190 const Tester &test) {
191 // We separate the reduction process into 3 steps, the first one is to erase
192 // redundant operations and the second one is to apply the reducer patterns.
193
194 // In the first phase, we attempt to erase all operations within the entire
195 // region.
196 if (succeeded(eraseAllOpsInRegion(module, region, test)))
197 return success();
198
199 // In the second phase, we don't apply any patterns so that we only select the
200 // range of operations to keep to the module stay interesting.
201 if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
202 /*eraseOpNotInRange=*/true)))
203 return failure();
204 // In the third phase, we suppose that no operation is redundant, so we try
205 // to rewrite the operation into simpler form.
206 return findOptimal<IteratorType>(module, region, patterns, test,
207 /*eraseOpNotInRange=*/false);
208}
209
210namespace {
211
212//===----------------------------------------------------------------------===//
213// Reduction Pattern Interface Collection
214//===----------------------------------------------------------------------===//
215
216class ReductionPatternInterfaceCollection
217 : public DialectInterfaceCollection<DialectReductionPatternInterface> {
218public:
219 using Base::Base;
220
221 // Collect the reduce patterns defined by each dialect.
222 void populateReductionPatterns(RewritePatternSet &pattern,
223 Tester &tester) const {
224 for (const DialectReductionPatternInterface &interface : *this) {
225 interface.populateReductionPatterns(pattern);
226 interface.populateReductionPatternsWithTester(pattern, tester);
227 }
228 }
229};
230
231//===----------------------------------------------------------------------===//
232// ReductionTreePass
233//===----------------------------------------------------------------------===//
234
235/// This class defines the Reduction Tree Pass. It provides a framework to
236/// to implement a reduction pass using a tree structure to keep track of the
237/// generated reduced variants.
238class ReductionTreePass
239 : public impl::ReductionTreePassBase<ReductionTreePass> {
240public:
241 using Base::Base;
242
243 LogicalResult initialize(MLIRContext *context) override;
244
245 /// Runs the pass instance in the pass pipeline.
246 void runOnOperation() override;
247
248private:
249 LogicalResult reduceOp(ModuleOp module, Region &region);
250
251 Tester tester;
252 FrozenRewritePatternSet reducerPatterns;
253};
254
255} // namespace
256
257LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
258 tester.setTestScript(testerName);
259 tester.setTestScriptArgs(testerArgs);
260
261 RewritePatternSet patterns(context);
262
263 ReductionPatternInterfaceCollection reducePatternCollection(context);
264 reducePatternCollection.populateReductionPatterns(patterns, tester);
265
266 reducerPatterns = std::move(patterns);
267 return success();
268}
269
270void ReductionTreePass::runOnOperation() {
271 Operation *topOperation = getOperation();
272 while (topOperation->getParentOp() != nullptr)
273 topOperation = topOperation->getParentOp();
274 ModuleOp module = dyn_cast<ModuleOp>(topOperation);
275 if (!module) {
276 emitError(getOperation()->getLoc())
277 << "top-level op must be 'builtin.module'";
278 return signalPassFailure();
279 }
280
281 SmallVector<Operation *, 8> workList;
282 workList.push_back(getOperation());
283
284 do {
285 Operation *op = workList.pop_back_val();
286
287 for (Region &region : op->getRegions())
288 if (!region.empty())
289 if (failed(reduceOp(module, region)))
290 return signalPassFailure();
291
292 for (Region &region : op->getRegions())
293 for (Operation &op : region.getOps())
294 if (op.getNumRegions() != 0)
295 workList.push_back(&op);
296 } while (!workList.empty());
297}
298
299LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
300 switch (traversalModeId) {
301 case TraversalMode::SinglePath:
303 module, region, reducerPatterns, tester);
304 default:
305 return module.emitError() << "unsupported traversal mode detected";
306 }
307}
return success()
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static LogicalResult eraseAllOpsInRegion(ModuleOp module, Region &region, const Tester &test)
This function attempts to erase all operations within the region currently being processed.
static LogicalResult findOptimal(ModuleOp module, Region &region, const FrozenRewritePatternSet &patterns, const Tester &test, bool eraseOpNotInRange)
We will apply the reducer patterns to the operations in the ranges specified by ReductionNode.
A collection of dialect interfaces within a context, for a given concrete interface type.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:700
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:252
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:703
ReductionTreePass will build a reduction tree during module reduction and the ReductionNode represent...
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion)
Each Reduction Node contains a copy of module for applying rewrite patterns.
ArrayRef< Range > getRanges() const
Return the range set we are using to generate variants.
size_t getSize() const
Return the size of the module.
ModuleOp getModule() const
If the ReductionNode hasn't been tested the interestingness, it'll be the same module as the one in t...
Region & getRegion() const
Return the region we're reducing.
Tester::Interestingness isInteresting() const
Returns true if the module exhibits the interesting behavior.
ReductionNode * getParent() const
ArrayRef< Range > getStartRanges() const
Return the range information that how this node is reduced from the parent node.
void update(std::pair< Tester::Interestingness, size_t > result)
Update the interestingness result from tester.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition Region.h:170
iterator_range< OpIterator > getOps()
Definition Region.h:172
OpIterator op_end()
Definition Region.h:171
This class is used to keep track of the testing environment of the tool.
Definition Tester.h:31
void setTestScriptArgs(ArrayRef< std::string > args)
Definition Tester.h:53
std::pair< Interestingness, size_t > isInteresting(ModuleOp module) const
Runs the interestingness testing script on a MLIR test case file.
Definition Tester.cpp:27
void setTestScript(StringRef script)
Definition Tester.h:52
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
@ ExistingOps
Only pre-existing ops are processed.