MLIR  20.0.0git
ReductionTreePass.cpp
Go to the documentation of this file.
1 //===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Reduction Tree Pass class. It provides a framework for
10 // the implementation of different reduction passes in the MLIR Reduce tool. It
11 // allows for custom specification of the variant generation behavior. It
12 // implements methods that define the different possible traversals of the
13 // reduction tree.
14 //
15 //===----------------------------------------------------------------------===//
16 
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Reducer/Passes.h"
#include "mlir/Reducer/ReductionNode.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Reducer/Tester.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ManagedStatic.h"
30 
31 namespace mlir {
32 #define GEN_PASS_DEF_REDUCTIONTREE
33 #include "mlir/Reducer/Passes.h.inc"
34 } // namespace mlir
35 
36 using namespace mlir;
37 
38 /// We implicitly number each operation in the region and if an operation's
39 /// number falls into rangeToKeep, we need to keep it and apply the given
40 /// rewrite patterns on it.
41 static void applyPatterns(Region &region,
44  bool eraseOpNotInRange) {
45  std::vector<Operation *> opsNotInRange;
46  std::vector<Operation *> opsInRange;
47  size_t keepIndex = 0;
48  for (const auto &op : enumerate(region.getOps())) {
49  int index = op.index();
50  if (keepIndex < rangeToKeep.size() &&
51  index == rangeToKeep[keepIndex].second)
52  ++keepIndex;
53  if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
54  opsNotInRange.push_back(&op.value());
55  else
56  opsInRange.push_back(&op.value());
57  }
58 
59  // `applyOpPatternsAndFold` may erase the ops so we can't do the pattern
60  // matching in above iteration. Besides, erase op not-in-range may end up in
61  // invalid module, so `applyOpPatternsAndFold` should come before that
62  // transform.
63  for (Operation *op : opsInRange) {
64  // `applyOpPatternsAndFold` returns whether the op is convered. Omit it
65  // because we don't have expectation this reduction will be success or not.
69  }
70 
71  if (eraseOpNotInRange)
72  for (Operation *op : opsNotInRange) {
73  op->dropAllUses();
74  op->erase();
75  }
76 }
77 
78 /// We will apply the reducer patterns to the operations in the ranges specified
79 /// by ReductionNode. Note that we are not able to remove an operation without
80 /// replacing it with another valid operation. However, The validity of module
81 /// reduction is based on the Tester provided by the user and that means certain
82 /// invalid module is still interested by the use. Thus we provide an
83 /// alternative way to remove operations, which is using `eraseOpNotInRange` to
84 /// erase the operations not in the range specified by ReductionNode.
85 template <typename IteratorType>
86 static LogicalResult findOptimal(ModuleOp module, Region &region,
88  const Tester &test, bool eraseOpNotInRange) {
89  std::pair<Tester::Interestingness, size_t> initStatus =
90  test.isInteresting(module);
91  // While exploring the reduction tree, we always branch from an interesting
92  // node. Thus the root node must be interesting.
93  if (initStatus.first != Tester::Interestingness::True)
94  return module.emitWarning() << "uninterested module will not be reduced";
95 
96  llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
97 
98  std::vector<ReductionNode::Range> ranges{
99  {0, std::distance(region.op_begin(), region.op_end())}};
100 
101  ReductionNode *root = allocator.Allocate();
102  new (root) ReductionNode(nullptr, ranges, allocator);
103  // Duplicate the module for root node and locate the region in the copy.
104  if (failed(root->initialize(module, region)))
105  llvm_unreachable("unexpected initialization failure");
106  root->update(initStatus);
107 
108  ReductionNode *smallestNode = root;
109  IteratorType iter(root);
110 
111  while (iter != IteratorType::end()) {
112  ReductionNode &currentNode = *iter;
113  Region &curRegion = currentNode.getRegion();
114 
115  applyPatterns(curRegion, patterns, currentNode.getRanges(),
116  eraseOpNotInRange);
117  currentNode.update(test.isInteresting(currentNode.getModule()));
118 
119  if (currentNode.isInteresting() == Tester::Interestingness::True &&
120  currentNode.getSize() < smallestNode->getSize())
121  smallestNode = &currentNode;
122 
123  ++iter;
124  }
125 
126  // At here, we have found an optimal path to reduce the given region. Retrieve
127  // the path and apply the reducer to it.
129  ReductionNode *curNode = smallestNode;
130  trace.push_back(curNode);
131  while (curNode != root) {
132  curNode = curNode->getParent();
133  trace.push_back(curNode);
134  }
135 
136  // Reduce the region through the optimal path.
137  while (!trace.empty()) {
138  ReductionNode *top = trace.pop_back_val();
139  applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
140  }
141 
142  if (test.isInteresting(module).first != Tester::Interestingness::True)
143  llvm::report_fatal_error("Reduced module is not interesting");
144  if (test.isInteresting(module).second != smallestNode->getSize())
145  llvm::report_fatal_error(
146  "Reduced module doesn't have consistent size with smallestNode");
147  return success();
148 }
149 
150 template <typename IteratorType>
151 static LogicalResult findOptimal(ModuleOp module, Region &region,
153  const Tester &test) {
154  // We separate the reduction process into 2 steps, the first one is to erase
155  // redundant operations and the second one is to apply the reducer patterns.
156 
157  // In the first phase, we don't apply any patterns so that we only select the
158  // range of operations to keep to the module stay interesting.
159  if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
160  /*eraseOpNotInRange=*/true)))
161  return failure();
162  // In the second phase, we suppose that no operation is redundant, so we try
163  // to rewrite the operation into simpler form.
164  return findOptimal<IteratorType>(module, region, patterns, test,
165  /*eraseOpNotInRange=*/false);
166 }
167 
168 namespace {
169 
170 //===----------------------------------------------------------------------===//
171 // Reduction Pattern Interface Collection
172 //===----------------------------------------------------------------------===//
173 
174 class ReductionPatternInterfaceCollection
175  : public DialectInterfaceCollection<DialectReductionPatternInterface> {
176 public:
177  using Base::Base;
178 
179  // Collect the reduce patterns defined by each dialect.
180  void populateReductionPatterns(RewritePatternSet &pattern) const {
181  for (const DialectReductionPatternInterface &interface : *this)
182  interface.populateReductionPatterns(pattern);
183  }
184 };
185 
186 //===----------------------------------------------------------------------===//
187 // ReductionTreePass
188 //===----------------------------------------------------------------------===//
189 
190 /// This class defines the Reduction Tree Pass. It provides a framework to
191 /// to implement a reduction pass using a tree structure to keep track of the
192 /// generated reduced variants.
193 class ReductionTreePass : public impl::ReductionTreeBase<ReductionTreePass> {
194 public:
195  ReductionTreePass() = default;
196  ReductionTreePass(const ReductionTreePass &pass) = default;
197 
198  LogicalResult initialize(MLIRContext *context) override;
199 
200  /// Runs the pass instance in the pass pipeline.
201  void runOnOperation() override;
202 
203 private:
204  LogicalResult reduceOp(ModuleOp module, Region &region);
205 
206  FrozenRewritePatternSet reducerPatterns;
207 };
208 
209 } // namespace
210 
211 LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
212  RewritePatternSet patterns(context);
213  ReductionPatternInterfaceCollection reducePatternCollection(context);
214  reducePatternCollection.populateReductionPatterns(patterns);
215  reducerPatterns = std::move(patterns);
216  return success();
217 }
218 
219 void ReductionTreePass::runOnOperation() {
220  Operation *topOperation = getOperation();
221  while (topOperation->getParentOp() != nullptr)
222  topOperation = topOperation->getParentOp();
223  ModuleOp module = dyn_cast<ModuleOp>(topOperation);
224  if (!module) {
225  emitError(getOperation()->getLoc())
226  << "top-level op must be 'builtin.module'";
227  return signalPassFailure();
228  }
229 
231  workList.push_back(getOperation());
232 
233  do {
234  Operation *op = workList.pop_back_val();
235 
236  for (Region &region : op->getRegions())
237  if (!region.empty())
238  if (failed(reduceOp(module, region)))
239  return signalPassFailure();
240 
241  for (Region &region : op->getRegions())
242  for (Operation &op : region.getOps())
243  if (op.getNumRegions() != 0)
244  workList.push_back(&op);
245  } while (!workList.empty());
246 }
247 
248 LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
249  Tester test(testerName, testerArgs);
250  switch (traversalModeId) {
252  return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
253  module, region, reducerPatterns, test);
254  default:
255  return module.emitError() << "unsupported traversal mode detected";
256  }
257 }
258 
259 std::unique_ptr<Pass> mlir::createReductionTreePass() {
260  return std::make_unique<ReductionTreePass>();
261 }
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static LogicalResult findOptimal(ModuleOp module, Region &region, const FrozenRewritePatternSet &patterns, const Tester &test, bool eraseOpNotInRange)
We will apply the reducer patterns to the operations in the ranges specified by ReductionNode.
A collection of dialect interfaces within a context, for a given concrete interface type.
This is used to report the reduction patterns for a Dialect.
virtual void populateReductionPatterns(RewritePatternSet &patterns) const =0
Patterns provided here are intended to transform operations from a complex form to a simpler form,...
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
ReductionTreePass will build a reduction tree during module reduction and the ReductionNode represent...
Definition: ReductionNode.h:43
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion)
Each Reduction Node contains a copy of module for applying rewrite patterns.
size_t getSize() const
Return the size of the module.
Definition: ReductionNode.h:66
ModuleOp getModule() const
If the ReductionNode hasn't been tested the interestingness, it'll be the same module as the one in t...
Definition: ReductionNode.h:60
Region & getRegion() const
Return the region we're reducing.
Definition: ReductionNode.h:63
Tester::Interestingness isInteresting() const
Returns true if the module exhibits the interesting behavior.
Definition: ReductionNode.h:69
ArrayRef< Range > getRanges() const
Return the range set we are using to generate variants.
Definition: ReductionNode.h:76
ReductionNode * getParent() const
Definition: ReductionNode.h:53
ArrayRef< Range > getStartRanges() const
Return the range information that how this node is reduced from the parent node.
Definition: ReductionNode.h:73
void update(std::pair< Tester::Interestingness, size_t > result)
Update the interestingness result from tester.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition: Region.h:170
OpIterator op_end()
Definition: Region.h:171
This class is used to keep track of the testing environment of the tool.
Definition: Tester.h:33
std::pair< Interestingness, size_t > isInteresting(ModuleOp module) const
Runs the interestingness testing script on a MLIR test case file.
Definition: Tester.cpp:27
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::unique_ptr< Pass > createReductionTreePass()
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
const FrozenRewritePatternSet & patterns
@ SinglePath
Definition: ReductionNode.h:36
@ ExistingOps
Only pre-existing ops are processed.