MLIR  22.0.0git
ReductionTreePass.cpp
Go to the documentation of this file.
1 //===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Reduction Tree Pass class. It provides a framework for
10 // the implementation of different reduction passes in the MLIR Reduce tool. It
11 // allows for custom specification of the variant generation behavior. It
12 // implements methods that define the different possible traversals of the
13 // reduction tree.
14 //
15 //===----------------------------------------------------------------------===//
16 
#include "mlir/IR/DialectInterface.h"
#include "mlir/Reducer/Passes.h"
#include "mlir/Reducer/ReductionNode.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Reducer/Tester.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Allocator.h"

namespace mlir {
#define GEN_PASS_DEF_REDUCTIONTREEPASS
#include "mlir/Reducer/Passes.h.inc"
} // namespace mlir

using namespace mlir;
34 
35 /// We implicitly number each operation in the region and if an operation's
36 /// number falls into rangeToKeep, we need to keep it and apply the given
37 /// rewrite patterns on it.
38 static void applyPatterns(Region &region,
41  bool eraseOpNotInRange) {
42  std::vector<Operation *> opsNotInRange;
43  std::vector<Operation *> opsInRange;
44  size_t keepIndex = 0;
45  for (const auto &op : enumerate(region.getOps())) {
46  int index = op.index();
47  if (keepIndex < rangeToKeep.size() &&
48  index == rangeToKeep[keepIndex].second)
49  ++keepIndex;
50  if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
51  opsNotInRange.push_back(&op.value());
52  else
53  opsInRange.push_back(&op.value());
54  }
55 
56  // `applyOpPatternsGreedily` with folding may erase the ops so we can't do the
57  // pattern matching in above iteration. Besides, erase op not-in-range may end
58  // up in invalid module, so `applyOpPatternsGreedily` with folding should come
59  // before that transform.
60  for (Operation *op : opsInRange) {
61  // `applyOpPatternsGreedily` with folding returns whether the op is
62  // converted. Omit it because we don't have expectation this reduction will
63  // be success or not.
67  }
68 
69  if (eraseOpNotInRange)
70  for (Operation *op : opsNotInRange) {
71  op->dropAllUses();
72  op->erase();
73  }
74 }
75 
76 /// We will apply the reducer patterns to the operations in the ranges specified
77 /// by ReductionNode. Note that we are not able to remove an operation without
78 /// replacing it with another valid operation. However, The validity of module
79 /// reduction is based on the Tester provided by the user and that means certain
80 /// invalid module is still interested by the use. Thus we provide an
81 /// alternative way to remove operations, which is using `eraseOpNotInRange` to
82 /// erase the operations not in the range specified by ReductionNode.
83 template <typename IteratorType>
84 static LogicalResult findOptimal(ModuleOp module, Region &region,
86  const Tester &test, bool eraseOpNotInRange) {
87  std::pair<Tester::Interestingness, size_t> initStatus =
88  test.isInteresting(module);
89  // While exploring the reduction tree, we always branch from an interesting
90  // node. Thus the root node must be interesting.
91  if (initStatus.first != Tester::Interestingness::True)
92  return module.emitWarning() << "uninterested module will not be reduced";
93 
94  llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
95 
96  std::vector<ReductionNode::Range> ranges{
97  {0, std::distance(region.op_begin(), region.op_end())}};
98 
99  ReductionNode *root = allocator.Allocate();
100  new (root) ReductionNode(nullptr, ranges, allocator);
101  // Duplicate the module for root node and locate the region in the copy.
102  if (failed(root->initialize(module, region)))
103  llvm_unreachable("unexpected initialization failure");
104  root->update(initStatus);
105 
106  ReductionNode *smallestNode = root;
107  IteratorType iter(root);
108 
109  while (iter != IteratorType::end()) {
110  ReductionNode &currentNode = *iter;
111  Region &curRegion = currentNode.getRegion();
112 
113  applyPatterns(curRegion, patterns, currentNode.getRanges(),
114  eraseOpNotInRange);
115  currentNode.update(test.isInteresting(currentNode.getModule()));
116 
117  if (currentNode.isInteresting() == Tester::Interestingness::True &&
118  currentNode.getSize() < smallestNode->getSize())
119  smallestNode = &currentNode;
120 
121  ++iter;
122  }
123 
124  // At here, we have found an optimal path to reduce the given region. Retrieve
125  // the path and apply the reducer to it.
127  ReductionNode *curNode = smallestNode;
128  trace.push_back(curNode);
129  while (curNode != root) {
130  curNode = curNode->getParent();
131  trace.push_back(curNode);
132  }
133 
134  // Reduce the region through the optimal path.
135  while (!trace.empty()) {
136  ReductionNode *top = trace.pop_back_val();
137  applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
138  }
139 
140  if (test.isInteresting(module).first != Tester::Interestingness::True)
141  llvm::report_fatal_error("Reduced module is not interesting");
142  if (test.isInteresting(module).second != smallestNode->getSize())
143  llvm::report_fatal_error(
144  "Reduced module doesn't have consistent size with smallestNode");
145  return success();
146 }
147 
148 template <typename IteratorType>
149 static LogicalResult findOptimal(ModuleOp module, Region &region,
151  const Tester &test) {
152  // We separate the reduction process into 2 steps, the first one is to erase
153  // redundant operations and the second one is to apply the reducer patterns.
154 
155  // In the first phase, we don't apply any patterns so that we only select the
156  // range of operations to keep to the module stay interesting.
157  if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
158  /*eraseOpNotInRange=*/true)))
159  return failure();
160  // In the second phase, we suppose that no operation is redundant, so we try
161  // to rewrite the operation into simpler form.
162  return findOptimal<IteratorType>(module, region, patterns, test,
163  /*eraseOpNotInRange=*/false);
164 }
165 
166 namespace {
167 
168 //===----------------------------------------------------------------------===//
169 // Reduction Pattern Interface Collection
170 //===----------------------------------------------------------------------===//
171 
172 class ReductionPatternInterfaceCollection
173  : public DialectInterfaceCollection<DialectReductionPatternInterface> {
174 public:
175  using Base::Base;
176 
177  // Collect the reduce patterns defined by each dialect.
178  void populateReductionPatterns(RewritePatternSet &pattern) const {
179  for (const DialectReductionPatternInterface &interface : *this)
180  interface.populateReductionPatterns(pattern);
181  }
182 };
183 
184 //===----------------------------------------------------------------------===//
185 // ReductionTreePass
186 //===----------------------------------------------------------------------===//
187 
188 /// This class defines the Reduction Tree Pass. It provides a framework to
189 /// to implement a reduction pass using a tree structure to keep track of the
190 /// generated reduced variants.
191 class ReductionTreePass
192  : public impl::ReductionTreePassBase<ReductionTreePass> {
193 public:
194  using Base::Base;
195 
196  LogicalResult initialize(MLIRContext *context) override;
197 
198  /// Runs the pass instance in the pass pipeline.
199  void runOnOperation() override;
200 
201 private:
202  LogicalResult reduceOp(ModuleOp module, Region &region);
203 
204  FrozenRewritePatternSet reducerPatterns;
205 };
206 
207 } // namespace
208 
209 LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
210  RewritePatternSet patterns(context);
211  ReductionPatternInterfaceCollection reducePatternCollection(context);
212  reducePatternCollection.populateReductionPatterns(patterns);
213  reducerPatterns = std::move(patterns);
214  return success();
215 }
216 
217 void ReductionTreePass::runOnOperation() {
218  Operation *topOperation = getOperation();
219  while (topOperation->getParentOp() != nullptr)
220  topOperation = topOperation->getParentOp();
221  ModuleOp module = dyn_cast<ModuleOp>(topOperation);
222  if (!module) {
223  emitError(getOperation()->getLoc())
224  << "top-level op must be 'builtin.module'";
225  return signalPassFailure();
226  }
227 
229  workList.push_back(getOperation());
230 
231  do {
232  Operation *op = workList.pop_back_val();
233 
234  for (Region &region : op->getRegions())
235  if (!region.empty())
236  if (failed(reduceOp(module, region)))
237  return signalPassFailure();
238 
239  for (Region &region : op->getRegions())
240  for (Operation &op : region.getOps())
241  if (op.getNumRegions() != 0)
242  workList.push_back(&op);
243  } while (!workList.empty());
244 }
245 
246 LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
247  Tester test(testerName, testerArgs);
248  switch (traversalModeId) {
250  return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
251  module, region, reducerPatterns, test);
252  default:
253  return module.emitError() << "unsupported traversal mode detected";
254  }
255 }
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static LogicalResult findOptimal(ModuleOp module, Region &region, const FrozenRewritePatternSet &patterns, const Tester &test, bool eraseOpNotInRange)
We will apply the reducer patterns to the operations in the ranges specified by ReductionNode.
A collection of dialect interfaces within a context, for a given concrete interface type.
This is used to report the reduction patterns for a Dialect.
virtual void populateReductionPatterns(RewritePatternSet &patterns) const =0
Patterns provided here are intended to transform operations from a complex form to a simpler form,...
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteConfig & setStrictness(GreedyRewriteStrictness mode)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
ReductionTreePass will build a reduction tree during module reduction and the ReductionNode represent...
Definition: ReductionNode.h:43
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion)
Each Reduction Node contains a copy of module for applying rewrite patterns.
size_t getSize() const
Return the size of the module.
Definition: ReductionNode.h:66
ModuleOp getModule() const
If the ReductionNode hasn't been tested the interestingness, it'll be the same module as the one in t...
Definition: ReductionNode.h:60
Region & getRegion() const
Return the region we're reducing.
Definition: ReductionNode.h:63
Tester::Interestingness isInteresting() const
Returns the interestingness result recorded for this node's module, i.e. whether the module exhibits the interesting behavior.
Definition: ReductionNode.h:69
ArrayRef< Range > getRanges() const
Return the range set we are using to generate variants.
Definition: ReductionNode.h:76
ReductionNode * getParent() const
Definition: ReductionNode.h:53
ArrayRef< Range > getStartRanges() const
Return the range information that describes how this node was reduced from its parent node.
Definition: ReductionNode.h:73
void update(std::pair< Tester::Interestingness, size_t > result)
Update the interestingness result from tester.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition: Region.h:170
OpIterator op_end()
Definition: Region.h:171
This class is used to keep track of the testing environment of the tool.
Definition: Tester.h:31
std::pair< Interestingness, size_t > isInteresting(ModuleOp module) const
Runs the interestingness testing script on a MLIR test case file.
Definition: Tester.cpp:27
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
const FrozenRewritePatternSet & patterns
@ SinglePath
Definition: ReductionNode.h:36
@ ExistingOps
Only pre-existing ops are processed.