MLIR  19.0.0git
ReductionTreePass.cpp
Go to the documentation of this file.
1 //===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Reduction Tree Pass class. It provides a framework for
10 // the implementation of different reduction passes in the MLIR Reduce tool. It
11 // allows for custom specification of the variant generation behavior. It
12 // implements methods that define the different possible traversals of the
13 // reduction tree.
14 //
15 //===----------------------------------------------------------------------===//
16 
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Reducer/Passes.h"
#include "mlir/Reducer/ReductionNode.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Reducer/Tester.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ManagedStatic.h"
30 
31 namespace mlir {
32 #define GEN_PASS_DEF_REDUCTIONTREE
33 #include "mlir/Reducer/Passes.h.inc"
34 } // namespace mlir
35 
36 using namespace mlir;
37 
38 /// We implicitly number each operation in the region and if an operation's
39 /// number falls into rangeToKeep, we need to keep it and apply the given
40 /// rewrite patterns on it.
41 static void applyPatterns(Region &region,
42  const FrozenRewritePatternSet &patterns,
44  bool eraseOpNotInRange) {
45  std::vector<Operation *> opsNotInRange;
46  std::vector<Operation *> opsInRange;
47  size_t keepIndex = 0;
48  for (const auto &op : enumerate(region.getOps())) {
49  int index = op.index();
50  if (keepIndex < rangeToKeep.size() &&
51  index == rangeToKeep[keepIndex].second)
52  ++keepIndex;
53  if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
54  opsNotInRange.push_back(&op.value());
55  else
56  opsInRange.push_back(&op.value());
57  }
58 
59  // `applyOpPatternsAndFold` may erase the ops so we can't do the pattern
60  // matching in above iteration. Besides, erase op not-in-range may end up in
61  // invalid module, so `applyOpPatternsAndFold` should come before that
62  // transform.
63  for (Operation *op : opsInRange) {
64  // `applyOpPatternsAndFold` returns whether the op is convered. Omit it
65  // because we don't have expectation this reduction will be success or not.
66  GreedyRewriteConfig config;
68  (void)applyOpPatternsAndFold(op, patterns, config);
69  }
70 
71  if (eraseOpNotInRange)
72  for (Operation *op : opsNotInRange) {
73  op->dropAllUses();
74  op->erase();
75  }
76 }
77 
78 /// We will apply the reducer patterns to the operations in the ranges specified
79 /// by ReductionNode. Note that we are not able to remove an operation without
80 /// replacing it with another valid operation. However, The validity of module
81 /// reduction is based on the Tester provided by the user and that means certain
82 /// invalid module is still interested by the use. Thus we provide an
83 /// alternative way to remove operations, which is using `eraseOpNotInRange` to
84 /// erase the operations not in the range specified by ReductionNode.
85 template <typename IteratorType>
86 static LogicalResult findOptimal(ModuleOp module, Region &region,
87  const FrozenRewritePatternSet &patterns,
88  const Tester &test, bool eraseOpNotInRange) {
89  std::pair<Tester::Interestingness, size_t> initStatus =
90  test.isInteresting(module);
91  // While exploring the reduction tree, we always branch from an interesting
92  // node. Thus the root node must be interesting.
93  if (initStatus.first != Tester::Interestingness::True)
94  return module.emitWarning() << "uninterested module will not be reduced";
95 
96  llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
97 
98  std::vector<ReductionNode::Range> ranges{
99  {0, std::distance(region.op_begin(), region.op_end())}};
100 
101  ReductionNode *root = allocator.Allocate();
102  new (root) ReductionNode(nullptr, ranges, allocator);
103  // Duplicate the module for root node and locate the region in the copy.
104  if (failed(root->initialize(module, region)))
105  llvm_unreachable("unexpected initialization failure");
106  root->update(initStatus);
107 
108  ReductionNode *smallestNode = root;
109  IteratorType iter(root);
110 
111  while (iter != IteratorType::end()) {
112  ReductionNode &currentNode = *iter;
113  Region &curRegion = currentNode.getRegion();
114 
115  applyPatterns(curRegion, patterns, currentNode.getRanges(),
116  eraseOpNotInRange);
117  currentNode.update(test.isInteresting(currentNode.getModule()));
118 
119  if (currentNode.isInteresting() == Tester::Interestingness::True &&
120  currentNode.getSize() < smallestNode->getSize())
121  smallestNode = &currentNode;
122 
123  ++iter;
124  }
125 
126  // At here, we have found an optimal path to reduce the given region. Retrieve
127  // the path and apply the reducer to it.
129  ReductionNode *curNode = smallestNode;
130  trace.push_back(curNode);
131  while (curNode != root) {
132  curNode = curNode->getParent();
133  trace.push_back(curNode);
134  }
135 
136  // Reduce the region through the optimal path.
137  while (!trace.empty()) {
138  ReductionNode *top = trace.pop_back_val();
139  applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
140  }
141 
142  if (test.isInteresting(module).first != Tester::Interestingness::True)
143  llvm::report_fatal_error("Reduced module is not interesting");
144  if (test.isInteresting(module).second != smallestNode->getSize())
145  llvm::report_fatal_error(
146  "Reduced module doesn't have consistent size with smallestNode");
147  return success();
148 }
149 
150 template <typename IteratorType>
151 static LogicalResult findOptimal(ModuleOp module, Region &region,
152  const FrozenRewritePatternSet &patterns,
153  const Tester &test) {
154  // We separate the reduction process into 2 steps, the first one is to erase
155  // redundant operations and the second one is to apply the reducer patterns.
156 
157  // In the first phase, we don't apply any patterns so that we only select the
158  // range of operations to keep to the module stay interesting.
159  if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
160  /*eraseOpNotInRange=*/true)))
161  return failure();
162  // In the second phase, we suppose that no operation is redundant, so we try
163  // to rewrite the operation into simpler form.
164  return findOptimal<IteratorType>(module, region, patterns, test,
165  /*eraseOpNotInRange=*/false);
166 }
167 
168 namespace {
169 
170 //===----------------------------------------------------------------------===//
171 // Reduction Pattern Interface Collection
172 //===----------------------------------------------------------------------===//
173 
174 class ReductionPatternInterfaceCollection
175  : public DialectInterfaceCollection<DialectReductionPatternInterface> {
176 public:
177  using Base::Base;
178 
179  // Collect the reduce patterns defined by each dialect.
180  void populateReductionPatterns(RewritePatternSet &pattern) const {
181  for (const DialectReductionPatternInterface &interface : *this)
182  interface.populateReductionPatterns(pattern);
183  }
184 };
185 
186 //===----------------------------------------------------------------------===//
187 // ReductionTreePass
188 //===----------------------------------------------------------------------===//
189 
190 /// This class defines the Reduction Tree Pass. It provides a framework to
191 /// to implement a reduction pass using a tree structure to keep track of the
192 /// generated reduced variants.
193 class ReductionTreePass : public impl::ReductionTreeBase<ReductionTreePass> {
194 public:
195  ReductionTreePass() = default;
196  ReductionTreePass(const ReductionTreePass &pass) = default;
197 
198  LogicalResult initialize(MLIRContext *context) override;
199 
200  /// Runs the pass instance in the pass pipeline.
201  void runOnOperation() override;
202 
203 private:
204  LogicalResult reduceOp(ModuleOp module, Region &region);
205 
206  FrozenRewritePatternSet reducerPatterns;
207 };
208 
209 } // namespace
210 
211 LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
212  RewritePatternSet patterns(context);
213  ReductionPatternInterfaceCollection reducePatternCollection(context);
214  reducePatternCollection.populateReductionPatterns(patterns);
215  reducerPatterns = std::move(patterns);
216  return success();
217 }
218 
219 void ReductionTreePass::runOnOperation() {
220  Operation *topOperation = getOperation();
221  while (topOperation->getParentOp() != nullptr)
222  topOperation = topOperation->getParentOp();
223  ModuleOp module = dyn_cast<ModuleOp>(topOperation);
224  if (!module) {
225  emitError(getOperation()->getLoc())
226  << "top-level op must be 'builtin.module'";
227  return signalPassFailure();
228  }
229 
231  workList.push_back(getOperation());
232 
233  do {
234  Operation *op = workList.pop_back_val();
235 
236  for (Region &region : op->getRegions())
237  if (!region.empty())
238  if (failed(reduceOp(module, region)))
239  return signalPassFailure();
240 
241  for (Region &region : op->getRegions())
242  for (Operation &op : region.getOps())
243  if (op.getNumRegions() != 0)
244  workList.push_back(&op);
245  } while (!workList.empty());
246 }
247 
248 LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
249  Tester test(testerName, testerArgs);
250  switch (traversalModeId) {
252  return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
253  module, region, reducerPatterns, test);
254  default:
255  return module.emitError() << "unsupported traversal mode detected";
256  }
257 }
258 
259 std::unique_ptr<Pass> mlir::createReductionTreePass() {
260  return std::make_unique<ReductionTreePass>();
261 }
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static LogicalResult findOptimal(ModuleOp module, Region &region, const FrozenRewritePatternSet &patterns, const Tester &test, bool eraseOpNotInRange)
We will apply the reducer patterns to the operations in the ranges specified by ReductionNode.
A collection of dialect interfaces within a context, for a given concrete interface type.
This is used to report the reduction patterns for a Dialect.
virtual void populateReductionPatterns(RewritePatternSet &patterns) const =0
Patterns provided here are intended to transform operations from a complex form to a simpler form,...
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteStrictness strictMode
Strict mode can restrict the ops that are added to the worklist during the rewrite.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:830
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
ReductionTreePass will build a reduction tree during module reduction and the ReductionNode represent...
Definition: ReductionNode.h:44
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion)
Each Reduction Node contains a copy of module for applying rewrite patterns.
size_t getSize() const
Return the size of the module.
Definition: ReductionNode.h:67
ModuleOp getModule() const
If the ReductionNode hasn't been tested the interestingness, it'll be the same module as the one in t...
Definition: ReductionNode.h:61
Region & getRegion() const
Return the region we're reducing.
Definition: ReductionNode.h:64
Tester::Interestingness isInteresting() const
Returns true if the module exhibits the interesting behavior.
Definition: ReductionNode.h:70
ArrayRef< Range > getRanges() const
Return the range set we are using to generate variants.
Definition: ReductionNode.h:77
ReductionNode * getParent() const
Definition: ReductionNode.h:54
ArrayRef< Range > getStartRanges() const
Return the range information that how this node is reduced from the parent node.
Definition: ReductionNode.h:74
void update(std::pair< Tester::Interestingness, size_t > result)
Update the interestingness result from tester.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition: Region.h:170
OpIterator op_end()
Definition: Region.h:171
This class is used to keep track of the testing environment of the tool.
Definition: Tester.h:33
std::pair< Interestingness, size_t > isInteresting(ModuleOp module) const
Runs the interestingness testing script on a MLIR test case file.
Definition: Tester.cpp:27
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult applyOpPatternsAndFold(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::unique_ptr< Pass > createReductionTreePass()
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
@ SinglePath
Definition: ReductionNode.h:37
@ ExistingOps
Only pre-existing ops are processed.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26