MLIR 22.0.0git
ReductionTreePass.cpp
Go to the documentation of this file.
1//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the Reduction Tree Pass class. It provides a framework for
10// the implementation of different reduction passes in the MLIR Reduce tool. It
11// allows for custom specification of the variant generation behavior. It
12// implements methods that define the different possible traversals of the
13// reduction tree.
14//
15//===----------------------------------------------------------------------===//
16
18#include "mlir/Reducer/Passes.h"
21#include "mlir/Reducer/Tester.h"
24
25#include "llvm/ADT/ArrayRef.h"
26#include "llvm/Support/Allocator.h"
27
28namespace mlir {
29#define GEN_PASS_DEF_REDUCTIONTREEPASS
30#include "mlir/Reducer/Passes.h.inc"
31} // namespace mlir
32
33using namespace mlir;
34
35/// We implicitly number each operation in the region and if an operation's
36/// number falls into rangeToKeep, we need to keep it and apply the given
37/// rewrite patterns on it.
38static void applyPatterns(Region &region,
41 bool eraseOpNotInRange) {
42 std::vector<Operation *> opsNotInRange;
43 std::vector<Operation *> opsInRange;
44 size_t keepIndex = 0;
45 for (const auto &op : enumerate(region.getOps())) {
46 int index = op.index();
47 if (keepIndex < rangeToKeep.size() &&
48 index == rangeToKeep[keepIndex].second)
49 ++keepIndex;
50 if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
51 opsNotInRange.push_back(&op.value());
52 else
53 opsInRange.push_back(&op.value());
54 }
55
56 // `applyOpPatternsGreedily` with folding may erase the ops so we can't do the
57 // pattern matching in above iteration. Besides, erase op not-in-range may end
58 // up in invalid module, so `applyOpPatternsGreedily` with folding should come
59 // before that transform.
60 for (Operation *op : opsInRange) {
61 // `applyOpPatternsGreedily` with folding returns whether the op is
62 // converted. Omit it because we don't have expectation this reduction will
63 // be success or not.
65 GreedyRewriteConfig().setStrictness(
67 }
68
69 if (eraseOpNotInRange)
70 for (Operation *op : opsNotInRange) {
71 op->dropAllUses();
72 op->erase();
73 }
74}
75
76/// We will apply the reducer patterns to the operations in the ranges specified
77/// by ReductionNode. Note that we are not able to remove an operation without
78/// replacing it with another valid operation. However, The validity of module
79/// reduction is based on the Tester provided by the user and that means certain
80/// invalid module is still interested by the use. Thus we provide an
81/// alternative way to remove operations, which is using `eraseOpNotInRange` to
82/// erase the operations not in the range specified by ReductionNode.
83template <typename IteratorType>
84static LogicalResult findOptimal(ModuleOp module, Region &region,
86 const Tester &test, bool eraseOpNotInRange) {
87 std::pair<Tester::Interestingness, size_t> initStatus =
88 test.isInteresting(module);
89 // While exploring the reduction tree, we always branch from an interesting
90 // node. Thus the root node must be interesting.
91 if (initStatus.first != Tester::Interestingness::True)
92 return module.emitWarning() << "uninterested module will not be reduced";
93
94 llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
95
96 std::vector<ReductionNode::Range> ranges{
97 {0, std::distance(region.op_begin(), region.op_end())}};
98
99 ReductionNode *root = allocator.Allocate();
100 new (root) ReductionNode(nullptr, ranges, allocator);
101 // Duplicate the module for root node and locate the region in the copy.
102 if (failed(root->initialize(module, region)))
103 llvm_unreachable("unexpected initialization failure");
104 root->update(initStatus);
105
106 ReductionNode *smallestNode = root;
107 IteratorType iter(root);
108
109 while (iter != IteratorType::end()) {
110 ReductionNode &currentNode = *iter;
111 Region &curRegion = currentNode.getRegion();
112
113 applyPatterns(curRegion, patterns, currentNode.getRanges(),
114 eraseOpNotInRange);
115 currentNode.update(test.isInteresting(currentNode.getModule()));
116
117 if (currentNode.isInteresting() == Tester::Interestingness::True &&
118 currentNode.getSize() < smallestNode->getSize())
119 smallestNode = &currentNode;
120
121 ++iter;
122 }
123
124 // At here, we have found an optimal path to reduce the given region. Retrieve
125 // the path and apply the reducer to it.
127 ReductionNode *curNode = smallestNode;
128 trace.push_back(curNode);
129 while (curNode != root) {
130 curNode = curNode->getParent();
131 trace.push_back(curNode);
132 }
133
134 // Reduce the region through the optimal path.
135 while (!trace.empty()) {
136 ReductionNode *top = trace.pop_back_val();
137 applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
138 }
140 if (test.isInteresting(module).first != Tester::Interestingness::True)
141 llvm::report_fatal_error("Reduced module is not interesting");
142 if (test.isInteresting(module).second != smallestNode->getSize())
143 llvm::report_fatal_error(
144 "Reduced module doesn't have consistent size with smallestNode");
145 return success();
147
148template <typename IteratorType>
149static LogicalResult findOptimal(ModuleOp module, Region &region,
151 const Tester &test) {
152 // We separate the reduction process into 2 steps, the first one is to erase
153 // redundant operations and the second one is to apply the reducer patterns.
155 // In the first phase, we don't apply any patterns so that we only select the
156 // range of operations to keep to the module stay interesting.
157 if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
158 /*eraseOpNotInRange=*/true)))
159 return failure();
160 // In the second phase, we suppose that no operation is redundant, so we try
161 // to rewrite the operation into simpler form.
162 return findOptimal<IteratorType>(module, region, patterns, test,
163 /*eraseOpNotInRange=*/false);
164}
165
166namespace {
167
168//===----------------------------------------------------------------------===//
169// Reduction Pattern Interface Collection
170//===----------------------------------------------------------------------===//
171
172class ReductionPatternInterfaceCollection
173 : public DialectInterfaceCollection<DialectReductionPatternInterface> {
174public:
175 using Base::Base;
176
177 // Collect the reduce patterns defined by each dialect.
178 void populateReductionPatterns(RewritePatternSet &pattern,
179 Tester &tester) const {
180 for (const DialectReductionPatternInterface &interface : *this) {
181 interface.populateReductionPatterns(pattern);
182 interface.populateReductionPatternsWithTester(pattern, tester);
183 }
184 }
185};
186
187//===----------------------------------------------------------------------===//
188// ReductionTreePass
189//===----------------------------------------------------------------------===//
191/// This class defines the Reduction Tree Pass. It provides a framework to
192/// implement a reduction pass using a tree structure to keep track of the
193/// generated reduced variants.
194class ReductionTreePass
195 : public impl::ReductionTreePassBase<ReductionTreePass> {
196public:
// NOTE(review): original line 197 is missing from this extract. It is
// presumably `using Base::Base;` (as in ReductionPatternInterfaceCollection)
// — confirm against the real source.
198
/// Configures `tester` and collects/freezes the dialect reduction patterns.
199 LogicalResult initialize(MLIRContext *context) override;
200
201 /// Runs the pass instance in the pass pipeline.
202 void runOnOperation() override;
204private:
/// Reduces one region of `module` using the selected traversal mode.
205 LogicalResult reduceOp(ModuleOp module, Region &region);
206
/// Interestingness tester, configured in initialize().
207 Tester tester;
/// Reduction patterns gathered from dialects, frozen once in initialize().
208 FrozenRewritePatternSet reducerPatterns;
209};
210
211} // namespace
212
213LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
214 tester.setTestScript(testerName);
215 tester.setTestScriptArgs(testerArgs);
216
217 RewritePatternSet patterns(context);
218
219 ReductionPatternInterfaceCollection reducePatternCollection(context);
220 reducePatternCollection.populateReductionPatterns(patterns, tester);
221
222 reducerPatterns = std::move(patterns);
223 return success();
224}
225
226void ReductionTreePass::runOnOperation() {
227 Operation *topOperation = getOperation();
228 while (topOperation->getParentOp() != nullptr)
229 topOperation = topOperation->getParentOp();
230 ModuleOp module = dyn_cast<ModuleOp>(topOperation);
231 if (!module) {
232 emitError(getOperation()->getLoc())
233 << "top-level op must be 'builtin.module'";
234 return signalPassFailure();
235 }
236
237 SmallVector<Operation *, 8> workList;
238 workList.push_back(getOperation());
239
240 do {
241 Operation *op = workList.pop_back_val();
242
243 for (Region &region : op->getRegions())
244 if (!region.empty())
245 if (failed(reduceOp(module, region)))
246 return signalPassFailure();
247
248 for (Region &region : op->getRegions())
249 for (Operation &op : region.getOps())
250 if (op.getNumRegions() != 0)
251 workList.push_back(&op);
252 } while (!workList.empty());
253}
254
/// Reduces `region` of `module` with the traversal strategy selected by
/// `traversalModeId`, using `reducerPatterns` and `tester`. Any traversal mode
/// other than SinglePath emits an error on `module` and returns failure.
255LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
256 switch (traversalModeId) {
257 case TraversalMode::SinglePath:
// NOTE(review): the line that opens the call for this case (original line
// 258, presumably `return findOptimal<...>(` with the single-path iterator
// type) is missing from this extract. The arguments below belong to it.
259 module, region, reducerPatterns, tester);
260 default:
261 return module.emitError() << "unsupported traversal mode detected";
262 }
263}
return success()
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static LogicalResult findOptimal(ModuleOp module, Region &region, const FrozenRewritePatternSet &patterns, const Tester &test, bool eraseOpNotInRange)
We will apply the reducer patterns to the operations in the ranges specified by ReductionNode.
A collection of dialect interfaces within a context, for a given concrete interface type.
This is used to report the reduction patterns for a Dialect.
virtual void populateReductionPatterns(RewritePatternSet &patterns) const =0
Patterns provided here are intended to transform operations from a complex form to a simpler form,...
virtual void populateReductionPatternsWithTester(RewritePatternSet &patterns, Tester &tester) const
This method extends populateReductionPatterns by allowing reduction patterns to use a Tester instance...
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
virtual void runOnOperation()=0
The polymorphic API that runs the pass over the currently held operation.
ReductionTreePass will build a reduction tree during module reduction and the ReductionNode represent...
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion)
Each Reduction Node contains a copy of module for applying rewrite patterns.
ArrayRef< Range > getRanges() const
Return the range set we are using to generate variants.
size_t getSize() const
Return the size of the module.
ModuleOp getModule() const
If the ReductionNode hasn't been tested the interestingness, it'll be the same module as the one in t...
Region & getRegion() const
Return the region we're reducing.
Tester::Interestingness isInteresting() const
Returns true if the module exhibits the interesting behavior.
ReductionNode * getParent() const
ArrayRef< Range > getStartRanges() const
Return the range information that how this node is reduced from the parent node.
void update(std::pair< Tester::Interestingness, size_t > result)
Update the interestingness result from tester.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition Region.h:170
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
OpIterator op_end()
Definition Region.h:171
This class is used to keep track of the testing environment of the tool.
Definition Tester.h:31
void setTestScriptArgs(ArrayRef< std::string > args)
Definition Tester.h:53
std::pair< Interestingness, size_t > isInteresting(ModuleOp module) const
Runs the interestingness testing script on a MLIR test case file.
Definition Tester.cpp:27
void setTestScript(StringRef script)
Definition Tester.h:52
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
const FrozenRewritePatternSet & patterns
@ ExistingOps
Only pre-existing ops are processed.