MLIR  21.0.0git
ReductionTreePass.cpp
Go to the documentation of this file.
1 //===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Reduction Tree Pass class. It provides a framework for
10 // the implementation of different reduction passes in the MLIR Reduce tool. It
11 // allows for custom specification of the variant generation behavior. It
12 // implements methods that define the different possible traversals of the
13 // reduction tree.
14 //
15 //===----------------------------------------------------------------------===//
16 
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Reducer/Passes.h"
#include "mlir/Reducer/ReductionNode.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Reducer/Tester.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ManagedStatic.h"
30 
31 namespace mlir {
32 #define GEN_PASS_DEF_REDUCTIONTREE
33 #include "mlir/Reducer/Passes.h.inc"
34 } // namespace mlir
35 
36 using namespace mlir;
37 
38 /// We implicitly number each operation in the region and if an operation's
39 /// number falls into rangeToKeep, we need to keep it and apply the given
40 /// rewrite patterns on it.
41 static void applyPatterns(Region &region,
44  bool eraseOpNotInRange) {
45  std::vector<Operation *> opsNotInRange;
46  std::vector<Operation *> opsInRange;
47  size_t keepIndex = 0;
48  for (const auto &op : enumerate(region.getOps())) {
49  int index = op.index();
50  if (keepIndex < rangeToKeep.size() &&
51  index == rangeToKeep[keepIndex].second)
52  ++keepIndex;
53  if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
54  opsNotInRange.push_back(&op.value());
55  else
56  opsInRange.push_back(&op.value());
57  }
58 
59  // `applyOpPatternsGreedily` with folding may erase the ops so we can't do the
60  // pattern matching in above iteration. Besides, erase op not-in-range may end
61  // up in invalid module, so `applyOpPatternsGreedily` with folding should come
62  // before that transform.
63  for (Operation *op : opsInRange) {
64  // `applyOpPatternsGreedily` with folding returns whether the op is
65  // convered. Omit it because we don't have expectation this reduction will
66  // be success or not.
70  }
71 
72  if (eraseOpNotInRange)
73  for (Operation *op : opsNotInRange) {
74  op->dropAllUses();
75  op->erase();
76  }
77 }
78 
79 /// We will apply the reducer patterns to the operations in the ranges specified
80 /// by ReductionNode. Note that we are not able to remove an operation without
81 /// replacing it with another valid operation. However, The validity of module
82 /// reduction is based on the Tester provided by the user and that means certain
83 /// invalid module is still interested by the use. Thus we provide an
84 /// alternative way to remove operations, which is using `eraseOpNotInRange` to
85 /// erase the operations not in the range specified by ReductionNode.
86 template <typename IteratorType>
87 static LogicalResult findOptimal(ModuleOp module, Region &region,
89  const Tester &test, bool eraseOpNotInRange) {
90  std::pair<Tester::Interestingness, size_t> initStatus =
91  test.isInteresting(module);
92  // While exploring the reduction tree, we always branch from an interesting
93  // node. Thus the root node must be interesting.
94  if (initStatus.first != Tester::Interestingness::True)
95  return module.emitWarning() << "uninterested module will not be reduced";
96 
97  llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
98 
99  std::vector<ReductionNode::Range> ranges{
100  {0, std::distance(region.op_begin(), region.op_end())}};
101 
102  ReductionNode *root = allocator.Allocate();
103  new (root) ReductionNode(nullptr, ranges, allocator);
104  // Duplicate the module for root node and locate the region in the copy.
105  if (failed(root->initialize(module, region)))
106  llvm_unreachable("unexpected initialization failure");
107  root->update(initStatus);
108 
109  ReductionNode *smallestNode = root;
110  IteratorType iter(root);
111 
112  while (iter != IteratorType::end()) {
113  ReductionNode &currentNode = *iter;
114  Region &curRegion = currentNode.getRegion();
115 
116  applyPatterns(curRegion, patterns, currentNode.getRanges(),
117  eraseOpNotInRange);
118  currentNode.update(test.isInteresting(currentNode.getModule()));
119 
120  if (currentNode.isInteresting() == Tester::Interestingness::True &&
121  currentNode.getSize() < smallestNode->getSize())
122  smallestNode = &currentNode;
123 
124  ++iter;
125  }
126 
127  // At here, we have found an optimal path to reduce the given region. Retrieve
128  // the path and apply the reducer to it.
130  ReductionNode *curNode = smallestNode;
131  trace.push_back(curNode);
132  while (curNode != root) {
133  curNode = curNode->getParent();
134  trace.push_back(curNode);
135  }
136 
137  // Reduce the region through the optimal path.
138  while (!trace.empty()) {
139  ReductionNode *top = trace.pop_back_val();
140  applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
141  }
142 
143  if (test.isInteresting(module).first != Tester::Interestingness::True)
144  llvm::report_fatal_error("Reduced module is not interesting");
145  if (test.isInteresting(module).second != smallestNode->getSize())
146  llvm::report_fatal_error(
147  "Reduced module doesn't have consistent size with smallestNode");
148  return success();
149 }
150 
151 template <typename IteratorType>
152 static LogicalResult findOptimal(ModuleOp module, Region &region,
154  const Tester &test) {
155  // We separate the reduction process into 2 steps, the first one is to erase
156  // redundant operations and the second one is to apply the reducer patterns.
157 
158  // In the first phase, we don't apply any patterns so that we only select the
159  // range of operations to keep to the module stay interesting.
160  if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
161  /*eraseOpNotInRange=*/true)))
162  return failure();
163  // In the second phase, we suppose that no operation is redundant, so we try
164  // to rewrite the operation into simpler form.
165  return findOptimal<IteratorType>(module, region, patterns, test,
166  /*eraseOpNotInRange=*/false);
167 }
168 
169 namespace {
170 
171 //===----------------------------------------------------------------------===//
172 // Reduction Pattern Interface Collection
173 //===----------------------------------------------------------------------===//
174 
175 class ReductionPatternInterfaceCollection
176  : public DialectInterfaceCollection<DialectReductionPatternInterface> {
177 public:
178  using Base::Base;
179 
180  // Collect the reduce patterns defined by each dialect.
181  void populateReductionPatterns(RewritePatternSet &pattern) const {
182  for (const DialectReductionPatternInterface &interface : *this)
183  interface.populateReductionPatterns(pattern);
184  }
185 };
186 
187 //===----------------------------------------------------------------------===//
188 // ReductionTreePass
189 //===----------------------------------------------------------------------===//
190 
191 /// This class defines the Reduction Tree Pass. It provides a framework to
192 /// to implement a reduction pass using a tree structure to keep track of the
193 /// generated reduced variants.
194 class ReductionTreePass : public impl::ReductionTreeBase<ReductionTreePass> {
195 public:
196  ReductionTreePass() = default;
197  ReductionTreePass(const ReductionTreePass &pass) = default;
198 
199  LogicalResult initialize(MLIRContext *context) override;
200 
201  /// Runs the pass instance in the pass pipeline.
202  void runOnOperation() override;
203 
204 private:
205  LogicalResult reduceOp(ModuleOp module, Region &region);
206 
207  FrozenRewritePatternSet reducerPatterns;
208 };
209 
210 } // namespace
211 
212 LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
213  RewritePatternSet patterns(context);
214  ReductionPatternInterfaceCollection reducePatternCollection(context);
215  reducePatternCollection.populateReductionPatterns(patterns);
216  reducerPatterns = std::move(patterns);
217  return success();
218 }
219 
220 void ReductionTreePass::runOnOperation() {
221  Operation *topOperation = getOperation();
222  while (topOperation->getParentOp() != nullptr)
223  topOperation = topOperation->getParentOp();
224  ModuleOp module = dyn_cast<ModuleOp>(topOperation);
225  if (!module) {
226  emitError(getOperation()->getLoc())
227  << "top-level op must be 'builtin.module'";
228  return signalPassFailure();
229  }
230 
232  workList.push_back(getOperation());
233 
234  do {
235  Operation *op = workList.pop_back_val();
236 
237  for (Region &region : op->getRegions())
238  if (!region.empty())
239  if (failed(reduceOp(module, region)))
240  return signalPassFailure();
241 
242  for (Region &region : op->getRegions())
243  for (Operation &op : region.getOps())
244  if (op.getNumRegions() != 0)
245  workList.push_back(&op);
246  } while (!workList.empty());
247 }
248 
249 LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
250  Tester test(testerName, testerArgs);
251  switch (traversalModeId) {
253  return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
254  module, region, reducerPatterns, test);
255  default:
256  return module.emitError() << "unsupported traversal mode detected";
257  }
258 }
259 
260 std::unique_ptr<Pass> mlir::createReductionTreePass() {
261  return std::make_unique<ReductionTreePass>();
262 }
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static LogicalResult findOptimal(ModuleOp module, Region &region, const FrozenRewritePatternSet &patterns, const Tester &test, bool eraseOpNotInRange)
We will apply the reducer patterns to the operations in the ranges specified by ReductionNode.
A collection of dialect interfaces within a context, for a given concrete interface type.
This is used to report the reduction patterns for a Dialect.
virtual void populateReductionPatterns(RewritePatternSet &patterns) const =0
Patterns provided here are intended to transform operations from a complex form to a simpler form,...
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
ReductionTreePass will build a reduction tree during module reduction and the ReductionNode represent...
Definition: ReductionNode.h:43
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion)
Each Reduction Node contains a copy of module for applying rewrite patterns.
size_t getSize() const
Return the size of the module.
Definition: ReductionNode.h:66
ModuleOp getModule() const
If the ReductionNode hasn't been tested for interestingness, it'll be the same module as the one in t...
Definition: ReductionNode.h:60
Region & getRegion() const
Return the region we're reducing.
Definition: ReductionNode.h:63
Tester::Interestingness isInteresting() const
Returns true if the module exhibits the interesting behavior.
Definition: ReductionNode.h:69
ArrayRef< Range > getRanges() const
Return the range set we are using to generate variants.
Definition: ReductionNode.h:76
ReductionNode * getParent() const
Definition: ReductionNode.h:53
ArrayRef< Range > getStartRanges() const
Return the range information describing how this node is reduced from its parent node.
Definition: ReductionNode.h:73
void update(std::pair< Tester::Interestingness, size_t > result)
Update the interestingness result from tester.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition: Region.h:170
OpIterator op_end()
Definition: Region.h:171
This class is used to keep track of the testing environment of the tool.
Definition: Tester.h:33
std::pair< Interestingness, size_t > isInteresting(ModuleOp module) const
Runs the interestingness testing script on an MLIR test case file.
Definition: Tester.cpp:27
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::unique_ptr< Pass > createReductionTreePass()
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
const FrozenRewritePatternSet & patterns
@ SinglePath
Definition: ReductionNode.h:36
@ ExistingOps
Only pre-existing ops are processed.