MLIR  16.0.0git
ReductionTreePass.cpp
Go to the documentation of this file.
1 //===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Reduction Tree Pass class. It provides a framework for
10 // the implementation of different reduction passes in the MLIR Reduce tool. It
11 // allows for custom specification of the variant generation behavior. It
12 // implements methods that define the different possible traversals of the
13 // reduction tree.
14 //
15 //===----------------------------------------------------------------------===//
16 
18 #include "mlir/IR/OpDefinition.h"
19 #include "mlir/Reducer/Passes.h"
22 #include "mlir/Reducer/Tester.h"
25 
26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/Support/Allocator.h"
29 #include "llvm/Support/ManagedStatic.h"
30 
31 namespace mlir {
32 #define GEN_PASS_DEF_REDUCTIONTREE
33 #include "mlir/Reducer/Passes.h.inc"
34 } // namespace mlir
35 
36 using namespace mlir;
37 
38 /// We implicitly number each operation in the region and if an operation's
39 /// number falls into rangeToKeep, we need to keep it and apply the given
40 /// rewrite patterns on it.
41 static void applyPatterns(Region &region,
42  const FrozenRewritePatternSet &patterns,
44  bool eraseOpNotInRange) {
45  std::vector<Operation *> opsNotInRange;
46  std::vector<Operation *> opsInRange;
47  size_t keepIndex = 0;
48  for (const auto &op : enumerate(region.getOps())) {
49  int index = op.index();
50  if (keepIndex < rangeToKeep.size() &&
51  index == rangeToKeep[keepIndex].second)
52  ++keepIndex;
53  if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
54  opsNotInRange.push_back(&op.value());
55  else
56  opsInRange.push_back(&op.value());
57  }
58 
59  // `applyOpPatternsAndFold` may erase the ops so we can't do the pattern
60  // matching in above iteration. Besides, erase op not-in-range may end up in
61  // invalid module, so `applyOpPatternsAndFold` should come before that
62  // transform.
63  for (Operation *op : opsInRange)
64  // `applyOpPatternsAndFold` returns whether the op is convered. Omit it
65  // because we don't have expectation this reduction will be success or not.
66  (void)applyOpPatternsAndFold(op, patterns);
67 
68  if (eraseOpNotInRange)
69  for (Operation *op : opsNotInRange) {
70  op->dropAllUses();
71  op->erase();
72  }
73 }
74 
75 /// We will apply the reducer patterns to the operations in the ranges specified
76 /// by ReductionNode. Note that we are not able to remove an operation without
77 /// replacing it with another valid operation. However, The validity of module
78 /// reduction is based on the Tester provided by the user and that means certain
79 /// invalid module is still interested by the use. Thus we provide an
80 /// alternative way to remove operations, which is using `eraseOpNotInRange` to
81 /// erase the operations not in the range specified by ReductionNode.
82 template <typename IteratorType>
83 static LogicalResult findOptimal(ModuleOp module, Region &region,
84  const FrozenRewritePatternSet &patterns,
85  const Tester &test, bool eraseOpNotInRange) {
86  std::pair<Tester::Interestingness, size_t> initStatus =
87  test.isInteresting(module);
88  // While exploring the reduction tree, we always branch from an interesting
89  // node. Thus the root node must be interesting.
90  if (initStatus.first != Tester::Interestingness::True)
91  return module.emitWarning() << "uninterested module will not be reduced";
92 
93  llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
94 
95  std::vector<ReductionNode::Range> ranges{
96  {0, std::distance(region.op_begin(), region.op_end())}};
97 
98  ReductionNode *root = allocator.Allocate();
99  new (root) ReductionNode(nullptr, ranges, allocator);
100  // Duplicate the module for root node and locate the region in the copy.
101  if (failed(root->initialize(module, region)))
102  llvm_unreachable("unexpected initialization failure");
103  root->update(initStatus);
104 
105  ReductionNode *smallestNode = root;
106  IteratorType iter(root);
107 
108  while (iter != IteratorType::end()) {
109  ReductionNode &currentNode = *iter;
110  Region &curRegion = currentNode.getRegion();
111 
112  applyPatterns(curRegion, patterns, currentNode.getRanges(),
113  eraseOpNotInRange);
114  currentNode.update(test.isInteresting(currentNode.getModule()));
115 
116  if (currentNode.isInteresting() == Tester::Interestingness::True &&
117  currentNode.getSize() < smallestNode->getSize())
118  smallestNode = &currentNode;
119 
120  ++iter;
121  }
122 
123  // At here, we have found an optimal path to reduce the given region. Retrieve
124  // the path and apply the reducer to it.
126  ReductionNode *curNode = smallestNode;
127  trace.push_back(curNode);
128  while (curNode != root) {
129  curNode = curNode->getParent();
130  trace.push_back(curNode);
131  }
132 
133  // Reduce the region through the optimal path.
134  while (!trace.empty()) {
135  ReductionNode *top = trace.pop_back_val();
136  applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
137  }
138 
139  if (test.isInteresting(module).first != Tester::Interestingness::True)
140  llvm::report_fatal_error("Reduced module is not interesting");
141  if (test.isInteresting(module).second != smallestNode->getSize())
142  llvm::report_fatal_error(
143  "Reduced module doesn't have consistent size with smallestNode");
144  return success();
145 }
146 
147 template <typename IteratorType>
148 static LogicalResult findOptimal(ModuleOp module, Region &region,
149  const FrozenRewritePatternSet &patterns,
150  const Tester &test) {
151  // We separate the reduction process into 2 steps, the first one is to erase
152  // redundant operations and the second one is to apply the reducer patterns.
153 
154  // In the first phase, we don't apply any patterns so that we only select the
155  // range of operations to keep to the module stay interesting.
156  if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
157  /*eraseOpNotInRange=*/true)))
158  return failure();
159  // In the second phase, we suppose that no operation is redundant, so we try
160  // to rewrite the operation into simpler form.
161  return findOptimal<IteratorType>(module, region, patterns, test,
162  /*eraseOpNotInRange=*/false);
163 }
164 
165 namespace {
166 
167 //===----------------------------------------------------------------------===//
168 // Reduction Pattern Interface Collection
169 //===----------------------------------------------------------------------===//
170 
171 class ReductionPatternInterfaceCollection
172  : public DialectInterfaceCollection<DialectReductionPatternInterface> {
173 public:
174  using Base::Base;
175 
176  // Collect the reduce patterns defined by each dialect.
177  void populateReductionPatterns(RewritePatternSet &pattern) const {
178  for (const DialectReductionPatternInterface &interface : *this)
179  interface.populateReductionPatterns(pattern);
180  }
181 };
182 
183 //===----------------------------------------------------------------------===//
184 // ReductionTreePass
185 //===----------------------------------------------------------------------===//
186 
187 /// This class defines the Reduction Tree Pass. It provides a framework to
188 /// to implement a reduction pass using a tree structure to keep track of the
189 /// generated reduced variants.
190 class ReductionTreePass : public impl::ReductionTreeBase<ReductionTreePass> {
191 public:
192  ReductionTreePass() = default;
193  ReductionTreePass(const ReductionTreePass &pass) = default;
194 
195  LogicalResult initialize(MLIRContext *context) override;
196 
197  /// Runs the pass instance in the pass pipeline.
198  void runOnOperation() override;
199 
200 private:
201  LogicalResult reduceOp(ModuleOp module, Region &region);
202 
203  FrozenRewritePatternSet reducerPatterns;
204 };
205 
206 } // namespace
207 
208 LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
209  RewritePatternSet patterns(context);
210  ReductionPatternInterfaceCollection reducePatternCollection(context);
211  reducePatternCollection.populateReductionPatterns(patterns);
212  reducerPatterns = std::move(patterns);
213  return success();
214 }
215 
216 void ReductionTreePass::runOnOperation() {
217  Operation *topOperation = getOperation();
218  while (topOperation->getParentOp() != nullptr)
219  topOperation = topOperation->getParentOp();
220  ModuleOp module = dyn_cast<ModuleOp>(topOperation);
221  if (!module) {
222  emitError(getOperation()->getLoc())
223  << "top-level op must be 'builtin.module'";
224  return signalPassFailure();
225  }
226 
228  workList.push_back(getOperation());
229 
230  do {
231  Operation *op = workList.pop_back_val();
232 
233  for (Region &region : op->getRegions())
234  if (!region.empty())
235  if (failed(reduceOp(module, region)))
236  return signalPassFailure();
237 
238  for (Region &region : op->getRegions())
239  for (Operation &op : region.getOps())
240  if (op.getNumRegions() != 0)
241  workList.push_back(&op);
242  } while (!workList.empty());
243 }
244 
245 LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
246  Tester test(testerName, testerArgs);
247  switch (traversalModeId) {
249  return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
250  module, region, reducerPatterns, test);
251  default:
252  return module.emitError() << "unsupported traversal mode detected";
253  }
254 }
255 
256 std::unique_ptr<Pass> mlir::createReductionTreePass() {
257  return std::make_unique<ReductionTreePass>();
258 }
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static LogicalResult findOptimal(ModuleOp module, Region &region, const FrozenRewritePatternSet &patterns, const Tester &test, bool eraseOpNotInRange)
We will apply the reducer patterns to the operations in the ranges specified by ReductionNode.
A collection of dialect interfaces within a context, for a given concrete interface type.
This is used to report the reduction patterns for a Dialect.
virtual void populateReductionPatterns(RewritePatternSet &patterns) const =0
Patterns provided here are intended to transform operations from a complex form to a simpler form,...
This class represents a frozen set of patterns that can be processed by a pattern applicator.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:477
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:165
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:480
ReductionTreePass will build a reduction tree during module reduction and the ReductionNode represent...
Definition: ReductionNode.h:44
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion)
Each Reduction Node contains a copy of module for applying rewrite patterns.
size_t getSize() const
Return the size of the module.
Definition: ReductionNode.h:67
ModuleOp getModule() const
If the ReductionNode hasn't been tested the interestingness, it'll be the same module as the one in t...
Definition: ReductionNode.h:61
Region & getRegion() const
Return the region we're reducing.
Definition: ReductionNode.h:64
Tester::Interestingness isInteresting() const
Returns true if the module exhibits the interesting behavior.
Definition: ReductionNode.h:70
ArrayRef< Range > getRanges() const
Return the range set we are using to generate variants.
Definition: ReductionNode.h:77
ReductionNode * getParent() const
Definition: ReductionNode.h:54
ArrayRef< Range > getStartRanges() const
Return the range information describing how this node was reduced from its parent node.
Definition: ReductionNode.h:74
void update(std::pair< Tester::Interestingness, size_t > result)
Update the interestingness result from tester.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition: Region.h:170
OpIterator op_end()
Definition: Region.h:171
This class is used to keep track of the testing environment of the tool.
Definition: Tester.h:33
std::pair< Interestingness, size_t > isInteresting(ModuleOp module) const
Runs the interestingness testing script on a MLIR test case file.
Definition: Tester.cpp:27
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:230
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::unique_ptr< Pass > createReductionTreePass()
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
@ SinglePath
Definition: ReductionNode.h:37
LogicalResult applyOpPatternsAndFold(Operation *op, const FrozenRewritePatternSet &patterns, bool *erased=nullptr)
Applies the specified patterns on op alone while also trying to fold it, by selecting the highest ben...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26