MLIR 22.0.0git
ReductionTreePass.cpp
Go to the documentation of this file.
//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the Reduction Tree Pass class. It provides a framework for
// the implementation of different reduction passes in the MLIR Reduce tool. It
// allows for custom specification of the variant generation behavior. It
// implements methods that define the different possible traversals of the
// reduction tree.
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/DialectInterface.h"
#include "mlir/Reducer/Passes.h"
#include "mlir/Reducer/ReductionNode.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Reducer/Tester.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Allocator.h"

#include <iterator>
#include <vector>

namespace mlir {
#define GEN_PASS_DEF_REDUCTIONTREEPASS
#include "mlir/Reducer/Passes.h.inc"
} // namespace mlir

using namespace mlir;

35/// We implicitly number each operation in the region and if an operation's
36/// number falls into rangeToKeep, we need to keep it and apply the given
37/// rewrite patterns on it.
38static void applyPatterns(Region &region,
41 bool eraseOpNotInRange) {
42 std::vector<Operation *> opsNotInRange;
43 std::vector<Operation *> opsInRange;
44 size_t keepIndex = 0;
45 for (const auto &op : enumerate(region.getOps())) {
46 int index = op.index();
47 if (keepIndex < rangeToKeep.size() &&
48 index == rangeToKeep[keepIndex].second)
49 ++keepIndex;
50 if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
51 opsNotInRange.push_back(&op.value());
52 else
53 opsInRange.push_back(&op.value());
54 }
55
56 // `applyOpPatternsGreedily` with folding may erase the ops so we can't do the
57 // pattern matching in above iteration. Besides, erase op not-in-range may end
58 // up in invalid module, so `applyOpPatternsGreedily` with folding should come
59 // before that transform.
60 for (Operation *op : opsInRange) {
61 // `applyOpPatternsGreedily` with folding returns whether the op is
62 // converted. Omit it because we don't have expectation this reduction will
63 // be success or not.
65 GreedyRewriteConfig().setStrictness(
67 }
68
69 if (eraseOpNotInRange)
70 for (Operation *op : opsNotInRange) {
71 op->dropAllUses();
72 op->erase();
73 }
74}
75
76/// We will apply the reducer patterns to the operations in the ranges specified
77/// by ReductionNode. Note that we are not able to remove an operation without
78/// replacing it with another valid operation. However, The validity of module
79/// reduction is based on the Tester provided by the user and that means certain
80/// invalid module is still interested by the use. Thus we provide an
81/// alternative way to remove operations, which is using `eraseOpNotInRange` to
82/// erase the operations not in the range specified by ReductionNode.
83template <typename IteratorType>
84static LogicalResult findOptimal(ModuleOp module, Region &region,
86 const Tester &test, bool eraseOpNotInRange) {
87 std::pair<Tester::Interestingness, size_t> initStatus =
88 test.isInteresting(module);
89 // While exploring the reduction tree, we always branch from an interesting
90 // node. Thus the root node must be interesting.
91 if (initStatus.first != Tester::Interestingness::True)
92 return module.emitWarning() << "uninterested module will not be reduced";
93
94 llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
95
96 std::vector<ReductionNode::Range> ranges{
97 {0, std::distance(region.op_begin(), region.op_end())}};
98
99 ReductionNode *root = allocator.Allocate();
100 new (root) ReductionNode(nullptr, ranges, allocator);
101 // Duplicate the module for root node and locate the region in the copy.
102 if (failed(root->initialize(module, region)))
103 llvm_unreachable("unexpected initialization failure");
104 root->update(initStatus);
105
106 ReductionNode *smallestNode = root;
107 IteratorType iter(root);
108
109 while (iter != IteratorType::end()) {
110 ReductionNode &currentNode = *iter;
111 Region &curRegion = currentNode.getRegion();
112
113 applyPatterns(curRegion, patterns, currentNode.getRanges(),
114 eraseOpNotInRange);
115 currentNode.update(test.isInteresting(currentNode.getModule()));
116
117 if (currentNode.isInteresting() == Tester::Interestingness::True &&
118 currentNode.getSize() < smallestNode->getSize())
119 smallestNode = &currentNode;
120
121 ++iter;
122 }
123
124 // At here, we have found an optimal path to reduce the given region. Retrieve
125 // the path and apply the reducer to it.
127 ReductionNode *curNode = smallestNode;
128 trace.push_back(curNode);
129 while (curNode != root) {
130 curNode = curNode->getParent();
131 trace.push_back(curNode);
132 }
133
134 // Reduce the region through the optimal path.
135 while (!trace.empty()) {
136 ReductionNode *top = trace.pop_back_val();
137 applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
138 }
140 if (test.isInteresting(module).first != Tester::Interestingness::True)
141 llvm::report_fatal_error("Reduced module is not interesting");
142 if (test.isInteresting(module).second != smallestNode->getSize())
143 llvm::report_fatal_error(
144 "Reduced module doesn't have consistent size with smallestNode");
145 return success();
147
148template <typename IteratorType>
149static LogicalResult findOptimal(ModuleOp module, Region &region,
151 const Tester &test) {
152 // We separate the reduction process into 2 steps, the first one is to erase
153 // redundant operations and the second one is to apply the reducer patterns.
155 // In the first phase, we don't apply any patterns so that we only select the
156 // range of operations to keep to the module stay interesting.
157 if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
158 /*eraseOpNotInRange=*/true)))
159 return failure();
160 // In the second phase, we suppose that no operation is redundant, so we try
161 // to rewrite the operation into simpler form.
162 return findOptimal<IteratorType>(module, region, patterns, test,
163 /*eraseOpNotInRange=*/false);
164}
165
166namespace {
167
168//===----------------------------------------------------------------------===//
169// Reduction Pattern Interface Collection
170//===----------------------------------------------------------------------===//
171
172class ReductionPatternInterfaceCollection
173 : public DialectInterfaceCollection<DialectReductionPatternInterface> {
174public:
175 using Base::Base;
176
177 // Collect the reduce patterns defined by each dialect.
178 void populateReductionPatterns(RewritePatternSet &pattern) const {
179 for (const DialectReductionPatternInterface &interface : *this)
180 interface.populateReductionPatterns(pattern);
181 }
183
184//===----------------------------------------------------------------------===//
185// ReductionTreePass
186//===----------------------------------------------------------------------===//
187
188/// This class defines the Reduction Tree Pass. It provides a framework to
189/// to implement a reduction pass using a tree structure to keep track of the
190/// generated reduced variants.
191class ReductionTreePass
192 : public impl::ReductionTreePassBase<ReductionTreePass> {
193public:
194 using Base::Base;
195
196 LogicalResult initialize(MLIRContext *context) override;
198 /// Runs the pass instance in the pass pipeline.
199 void runOnOperation() override;
200
201private:
202 LogicalResult reduceOp(ModuleOp module, Region &region);
204 FrozenRewritePatternSet reducerPatterns;
205};
206
207} // namespace
208
209LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
211 ReductionPatternInterfaceCollection reducePatternCollection(context);
212 reducePatternCollection.populateReductionPatterns(patterns);
213 reducerPatterns = std::move(patterns);
214 return success();
215}
216
217void ReductionTreePass::runOnOperation() {
218 Operation *topOperation = getOperation();
219 while (topOperation->getParentOp() != nullptr)
220 topOperation = topOperation->getParentOp();
221 ModuleOp module = dyn_cast<ModuleOp>(topOperation);
222 if (!module) {
223 emitError(getOperation()->getLoc())
224 << "top-level op must be 'builtin.module'";
225 return signalPassFailure();
226 }
227
228 SmallVector<Operation *, 8> workList;
229 workList.push_back(getOperation());
230
231 do {
232 Operation *op = workList.pop_back_val();
233
234 for (Region &region : op->getRegions())
235 if (!region.empty())
236 if (failed(reduceOp(module, region)))
237 return signalPassFailure();
238
239 for (Region &region : op->getRegions())
240 for (Operation &op : region.getOps())
241 if (op.getNumRegions() != 0)
242 workList.push_back(&op);
243 } while (!workList.empty());
244}
245
246LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
247 Tester test(testerName, testerArgs);
248 switch (traversalModeId) {
249 case TraversalMode::SinglePath:
251 module, region, reducerPatterns, test);
252 default:
253 return module.emitError() << "unsupported traversal mode detected";
254 }
255}
return success()
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static LogicalResult findOptimal(ModuleOp module, Region &region, const FrozenRewritePatternSet &patterns, const Tester &test, bool eraseOpNotInRange)
We will apply the reducer patterns to the operations in the ranges specified by ReductionNode.
A collection of dialect interfaces within a context, for a given concrete interface type.
This is used to report the reduction patterns for a Dialect.
virtual void populateReductionPatterns(RewritePatternSet &patterns) const =0
Patterns provided here are intended to transform operations from a complex form to a simpler form,...
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
virtual void runOnOperation()=0
The polymorphic API that runs the pass over the currently held operation.
ReductionTreePass will build a reduction tree during module reduction and the ReductionNode represent...
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion)
Each Reduction Node contains a copy of module for applying rewrite patterns.
ArrayRef< Range > getRanges() const
Return the range set we are using to generate variants.
size_t getSize() const
Return the size of the module.
ModuleOp getModule() const
If the ReductionNode hasn't been tested the interestingness, it'll be the same module as the one in t...
Region & getRegion() const
Return the region we're reducing.
Tester::Interestingness isInteresting() const
Returns true if the module exhibits the interesting behavior.
ReductionNode * getParent() const
ArrayRef< Range > getStartRanges() const
Return the range information that how this node is reduced from the parent node.
void update(std::pair< Tester::Interestingness, size_t > result)
Update the interestingness result from tester.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition Region.h:170
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
OpIterator op_end()
Definition Region.h:171
This class is used to keep track of the testing environment of the tool.
Definition Tester.h:31
std::pair< Interestingness, size_t > isInteresting(ModuleOp module) const
Runs the interestingness testing script on a MLIR test case file.
Definition Tester.cpp:27
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
const FrozenRewritePatternSet & patterns
@ ExistingOps
Only pre-existing ops are processed.