MLIR 23.0.0git
ReductionTreePass.cpp
Go to the documentation of this file.
1//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the Reduction Tree Pass class. It provides a framework for
10// the implementation of different reduction passes in the MLIR Reduce tool. It
11// allows for custom specification of the variant generation behavior. It
12// implements methods that define the different possible traversals of the
13// reduction tree.
14//
15//===----------------------------------------------------------------------===//
16
#include "mlir/IR/DialectInterface.h"
#include "mlir/Reducer/Passes.h"
#include "mlir/Reducer/ReductionNode.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Reducer/Tester.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Allocator.h"

#include <iterator>
#include <utility>
#include <vector>

namespace mlir {
#define GEN_PASS_DEF_REDUCTIONTREEPASS
#include "mlir/Reducer/Passes.h.inc"
} // namespace mlir

using namespace mlir;
34
35/// We implicitly number each operation in the region and if an operation's
36/// number falls into rangeToKeep, we need to keep it and apply the given
37/// rewrite patterns on it.
38static void applyPatterns(Region &region,
39 const FrozenRewritePatternSet &patterns,
41 bool eraseOpNotInRange) {
42 std::vector<Operation *> opsNotInRange;
43 size_t keepIndex = 0;
44 for (const auto &op : enumerate(region.getOps())) {
45 int index = op.index();
46 if (keepIndex < rangeToKeep.size() &&
47 index == rangeToKeep[keepIndex].second)
48 ++keepIndex;
49 if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
50 opsNotInRange.push_back(&op.value());
51 }
52
53 // `applyOpPatternsGreedily` with folding may erase the ops so we can't do the
54 // pattern matching in above iteration. Besides, erase op not-in-range may end
55 // up in invalid module, so `applyOpPatternsGreedily` with folding should come
56 // before that transform.
57 if (!eraseOpNotInRange)
58 for (Operation *op : opsNotInRange) {
59 // `applyOpPatternsGreedily` with folding returns whether the op is
60 // converted. Omit it because we don't have expectation this reduction
61 // will be success or not.
62 (void)applyOpPatternsGreedily(op, patterns,
63 GreedyRewriteConfig().setStrictness(
65 }
66
67 if (eraseOpNotInRange)
68 for (Operation *op : opsNotInRange) {
69 op->dropAllUses();
70 op->erase();
71 }
72}
73
74/// We will apply the reducer patterns to the operations in the ranges specified
75/// by ReductionNode. Note that we are not able to remove an operation without
76/// replacing it with another valid operation. However, The validity of module
77/// reduction is based on the Tester provided by the user and that means certain
78/// invalid module is still interested by the use. Thus we provide an
79/// alternative way to remove operations, which is using `eraseOpNotInRange` to
80/// erase the operations not in the range specified by ReductionNode.
81template <typename IteratorType>
82static LogicalResult findOptimal(ModuleOp module, Region &region,
83 const FrozenRewritePatternSet &patterns,
84 const Tester &test, bool eraseOpNotInRange) {
85 std::pair<Tester::Interestingness, size_t> initStatus =
86 test.isInteresting(module);
87 // While exploring the reduction tree, we always branch from an interesting
88 // node. Thus the root node must be interesting.
89 if (initStatus.first != Tester::Interestingness::True)
90 return module.emitError() << "uninterested module will not be reduced";
91
92 llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
93
94 std::vector<ReductionNode::Range> ranges{
95 {0, std::distance(region.op_begin(), region.op_end())}};
96
97 ReductionNode *root = allocator.Allocate();
98 new (root) ReductionNode(nullptr, ranges, allocator);
99 // Duplicate the module for root node and locate the region in the copy.
100 if (failed(root->initialize(module, region)))
101 llvm_unreachable("unexpected initialization failure");
102 root->update(initStatus);
103
104 ReductionNode *smallestNode = root;
105 IteratorType iter(root);
106
107 while (iter != IteratorType::end()) {
108 ReductionNode &currentNode = *iter;
109 Region &curRegion = currentNode.getRegion();
110
111 applyPatterns(curRegion, patterns, currentNode.getRanges(),
112 eraseOpNotInRange);
113 currentNode.update(test.isInteresting(currentNode.getModule()));
114
115 if (currentNode.isInteresting() == Tester::Interestingness::True &&
116 currentNode.getSize() < smallestNode->getSize())
117 smallestNode = &currentNode;
118
119 ++iter;
120 }
121
122 // At here, we have found an optimal path to reduce the given region. Retrieve
123 // the path and apply the reducer to it.
125 ReductionNode *curNode = smallestNode;
126 trace.push_back(curNode);
127 while (curNode != root) {
128 curNode = curNode->getParent();
129 trace.push_back(curNode);
130 }
131
132 // Reduce the region through the optimal path.
133 while (!trace.empty()) {
134 ReductionNode *top = trace.pop_back_val();
135 applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
136 }
137
138 if (test.isInteresting(module).first != Tester::Interestingness::True)
139 llvm::report_fatal_error("Reduced module is not interesting");
140 if (test.isInteresting(module).second != smallestNode->getSize())
141 llvm::report_fatal_error(
142 "Reduced module doesn't have consistent size with smallestNode");
143 return success();
146/// This function attempts to erase all operations within the region currently
147/// being processed.
148static LogicalResult eraseAllOpsInRegion(ModuleOp module, Region &region,
149 const Tester &test) {
150 std::pair<Tester::Interestingness, size_t> initStatus =
151 test.isInteresting(module);
153 // While exploring the reduction tree, we always branch from an interesting
154 // node. Thus the root node must be interesting.
155 if (initStatus.first != Tester::Interestingness::True)
156 return module.emitError() << "uninterested module will not be reduced";
157 llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
158
159 // Setting the ranges to {{0, 0}} will result in the deletion of all ops
160 // within the region.
161 std::vector<ReductionNode::Range> ranges{{0, 0}};
162
163 // We allocate memory on the stack, and the 'allocator' is only used to
164 // construct the 'root node'. Since we won't be constructing any child nodes
165 // for emptyRegionNode, it is only used within the current scope.
166 ReductionNode emptyRegionNode(nullptr, ranges, allocator);
167 ReductionNode *root = &emptyRegionNode;
168
169 // Create a copy of the current IR.
170 if (failed(root->initialize(module, region)))
171 llvm_unreachable("unexpected initialization failure");
172
173 // Erase all operations within the corresponding region of the clone.
174 applyPatterns(root->getRegion(), {}, root->getRanges(), true);
175 root->update(test.isInteresting(root->getModule()));
177 // If we can successfully remove all ops in the region, we apply the same
178 // transformation to the original IR and return success.
179 applyPatterns(region, {}, root->getRanges(), true);
180 return success();
181 }
182 return failure();
183}
184
185template <typename IteratorType>
186static LogicalResult findOptimal(ModuleOp module, Region &region,
187 const FrozenRewritePatternSet &patterns,
188 const Tester &test) {
189 // We separate the reduction process into 3 steps, the first one is to erase
190 // redundant operations and the second one is to apply the reducer patterns.
192 // In the first phase, we attempt to erase all operations within the entire
193 // region.
194 if (succeeded(eraseAllOpsInRegion(module, region, test)))
195 return success();
197 // In the second phase, we don't apply any patterns so that we only select the
198 // range of operations to keep to the module stay interesting.
199 if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
200 /*eraseOpNotInRange=*/true)))
201 return failure();
202 // In the third phase, we suppose that no operation is redundant, so we try
203 // to rewrite the operation into simpler form.
204 return findOptimal<IteratorType>(module, region, patterns, test,
205 /*eraseOpNotInRange=*/false);
207
208namespace {
209
210//===----------------------------------------------------------------------===//
211// Reduction Pattern Interface Collection
212//===----------------------------------------------------------------------===//
213
214class ReductionPatternInterfaceCollection
215 : public DialectInterfaceCollection<DialectReductionPatternInterface> {
216public:
217 using Base::Base;
218
219 // Collect the reduce patterns defined by each dialect.
220 void populateReductionPatterns(RewritePatternSet &pattern,
221 Tester &tester) const {
222 for (const DialectReductionPatternInterface &interface : *this) {
223 interface.populateReductionPatterns(pattern);
224 interface.populateReductionPatternsWithTester(pattern, tester);
225 }
226 }
227};
228
229//===----------------------------------------------------------------------===//
230// ReductionTreePass
231//===----------------------------------------------------------------------===//
232
233/// This class defines the Reduction Tree Pass. It provides a framework to
234/// to implement a reduction pass using a tree structure to keep track of the
235/// generated reduced variants.
236class ReductionTreePass
237 : public impl::ReductionTreePassBase<ReductionTreePass> {
238public:
239 using Base::Base;
240
241 LogicalResult initialize(MLIRContext *context) override;
242
243 /// Runs the pass instance in the pass pipeline.
244 void runOnOperation() override;
245
246private:
247 LogicalResult reduceOp(ModuleOp module, Region &region);
248
249 Tester tester;
250 FrozenRewritePatternSet reducerPatterns;
251};
252
253} // namespace
254
255LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
256 tester.setTestScript(testerName);
257 tester.setTestScriptArgs(testerArgs);
258
259 RewritePatternSet patterns(context);
260
261 ReductionPatternInterfaceCollection reducePatternCollection(context);
262 reducePatternCollection.populateReductionPatterns(patterns, tester);
263
264 reducerPatterns = std::move(patterns);
265 return success();
266}
267
268void ReductionTreePass::runOnOperation() {
269 Operation *topOperation = getOperation();
270 while (topOperation->getParentOp() != nullptr)
271 topOperation = topOperation->getParentOp();
272 ModuleOp module = dyn_cast<ModuleOp>(topOperation);
273 if (!module) {
274 emitError(getOperation()->getLoc())
275 << "top-level op must be 'builtin.module'";
276 return signalPassFailure();
277 }
278
279 SmallVector<Operation *, 8> workList;
280 workList.push_back(getOperation());
281
282 do {
283 Operation *op = workList.pop_back_val();
284
285 for (Region &region : op->getRegions())
286 if (!region.empty())
287 if (failed(reduceOp(module, region)))
288 return signalPassFailure();
289
290 for (Region &region : op->getRegions())
291 for (Operation &op : region.getOps())
292 if (op.getNumRegions() != 0)
293 workList.push_back(&op);
294 } while (!workList.empty());
295}
296
297LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
298 switch (traversalModeId) {
299 case TraversalMode::SinglePath:
301 module, region, reducerPatterns, tester);
302 default:
303 return module.emitError() << "unsupported traversal mode detected";
304 }
305}
return success()
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static LogicalResult eraseAllOpsInRegion(ModuleOp module, Region &region, const Tester &test)
This function attempts to erase all operations within the region currently being processed.
static LogicalResult findOptimal(ModuleOp module, Region &region, const FrozenRewritePatternSet &patterns, const Tester &test, bool eraseOpNotInRange)
We will apply the reducer patterns to the operations in the ranges specified by ReductionNode.
A collection of dialect interfaces within a context, for a given concrete interface type.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:700
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:252
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:703
ReductionTreePass will build a reduction tree during module reduction and the ReductionNode represent...
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion)
Each Reduction Node contains a copy of module for applying rewrite patterns.
ArrayRef< Range > getRanges() const
Return the range set we are using to generate variants.
size_t getSize() const
Return the size of the module.
ModuleOp getModule() const
If the ReductionNode hasn't been tested the interestingness, it'll be the same module as the one in t...
Region & getRegion() const
Return the region we're reducing.
Tester::Interestingness isInteresting() const
Returns true if the module exhibits the interesting behavior.
ReductionNode * getParent() const
ArrayRef< Range > getStartRanges() const
Return the range information that how this node is reduced from the parent node.
void update(std::pair< Tester::Interestingness, size_t > result)
Update the interestingness result from tester.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition Region.h:170
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
OpIterator op_end()
Definition Region.h:171
This class is used to keep track of the testing environment of the tool.
Definition Tester.h:31
void setTestScriptArgs(ArrayRef< std::string > args)
Definition Tester.h:53
std::pair< Interestingness, size_t > isInteresting(Operation *topOp) const
Runs the interestingness testing script on a MLIR test case file.
Definition Tester.cpp:27
void setTestScript(StringRef script)
Definition Tester.h:52
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
@ ExistingOps
Only pre-existing ops are processed.