MLIR  14.0.0git
ReductionTreePass.cpp
Go to the documentation of this file.
1 //===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Reduction Tree Pass class. It provides a framework for
10 // the implementation of different reduction passes in the MLIR Reduce tool. It
11 // allows for custom specification of the variant generation behavior. It
12 // implements methods that define the different possible traversals of the
13 // reduction tree.
14 //
15 //===----------------------------------------------------------------------===//
16 
18 #include "mlir/IR/OpDefinition.h"
20 #include "mlir/Reducer/Passes.h"
23 #include "mlir/Reducer/Tester.h"
26 
27 #include "llvm/ADT/ArrayRef.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/Support/Allocator.h"
30 #include "llvm/Support/ManagedStatic.h"
31 
32 using namespace mlir;
33 
34 /// We implicitly number each operation in the region and if an operation's
35 /// number falls into rangeToKeep, we need to keep it and apply the given
36 /// rewrite patterns on it.
37 static void applyPatterns(Region &region,
38  const FrozenRewritePatternSet &patterns,
40  bool eraseOpNotInRange) {
41  std::vector<Operation *> opsNotInRange;
42  std::vector<Operation *> opsInRange;
43  size_t keepIndex = 0;
44  for (const auto &op : enumerate(region.getOps())) {
45  int index = op.index();
46  if (keepIndex < rangeToKeep.size() &&
47  index == rangeToKeep[keepIndex].second)
48  ++keepIndex;
49  if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
50  opsNotInRange.push_back(&op.value());
51  else
52  opsInRange.push_back(&op.value());
53  }
54 
55  // `applyOpPatternsAndFold` may erase the ops so we can't do the pattern
56  // matching in above iteration. Besides, erase op not-in-range may end up in
57  // invalid module, so `applyOpPatternsAndFold` should come before that
58  // transform.
59  for (Operation *op : opsInRange)
60  // `applyOpPatternsAndFold` returns whether the op is convered. Omit it
61  // because we don't have expectation this reduction will be success or not.
62  (void)applyOpPatternsAndFold(op, patterns);
63 
64  if (eraseOpNotInRange)
65  for (Operation *op : opsNotInRange) {
66  op->dropAllUses();
67  op->erase();
68  }
69 }
70 
71 /// We will apply the reducer patterns to the operations in the ranges specified
72 /// by ReductionNode. Note that we are not able to remove an operation without
73 /// replacing it with another valid operation. However, The validity of module
74 /// reduction is based on the Tester provided by the user and that means certain
75 /// invalid module is still interested by the use. Thus we provide an
76 /// alternative way to remove operations, which is using `eraseOpNotInRange` to
77 /// erase the operations not in the range specified by ReductionNode.
78 template <typename IteratorType>
79 static LogicalResult findOptimal(ModuleOp module, Region &region,
80  const FrozenRewritePatternSet &patterns,
81  const Tester &test, bool eraseOpNotInRange) {
82  std::pair<Tester::Interestingness, size_t> initStatus =
83  test.isInteresting(module);
84  // While exploring the reduction tree, we always branch from an interesting
85  // node. Thus the root node must be interesting.
86  if (initStatus.first != Tester::Interestingness::True)
87  return module.emitWarning() << "uninterested module will not be reduced";
88 
89  llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
90 
91  std::vector<ReductionNode::Range> ranges{
92  {0, std::distance(region.op_begin(), region.op_end())}};
93 
94  ReductionNode *root = allocator.Allocate();
95  new (root) ReductionNode(nullptr, ranges, allocator);
96  // Duplicate the module for root node and locate the region in the copy.
97  if (failed(root->initialize(module, region)))
98  llvm_unreachable("unexpected initialization failure");
99  root->update(initStatus);
100 
101  ReductionNode *smallestNode = root;
102  IteratorType iter(root);
103 
104  while (iter != IteratorType::end()) {
105  ReductionNode &currentNode = *iter;
106  Region &curRegion = currentNode.getRegion();
107 
108  applyPatterns(curRegion, patterns, currentNode.getRanges(),
109  eraseOpNotInRange);
110  currentNode.update(test.isInteresting(currentNode.getModule()));
111 
112  if (currentNode.isInteresting() == Tester::Interestingness::True &&
113  currentNode.getSize() < smallestNode->getSize())
114  smallestNode = &currentNode;
115 
116  ++iter;
117  }
118 
119  // At here, we have found an optimal path to reduce the given region. Retrieve
120  // the path and apply the reducer to it.
122  ReductionNode *curNode = smallestNode;
123  trace.push_back(curNode);
124  while (curNode != root) {
125  curNode = curNode->getParent();
126  trace.push_back(curNode);
127  }
128 
129  // Reduce the region through the optimal path.
130  while (!trace.empty()) {
131  ReductionNode *top = trace.pop_back_val();
132  applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
133  }
134 
135  if (test.isInteresting(module).first != Tester::Interestingness::True)
136  llvm::report_fatal_error("Reduced module is not interesting");
137  if (test.isInteresting(module).second != smallestNode->getSize())
138  llvm::report_fatal_error(
139  "Reduced module doesn't have consistent size with smallestNode");
140  return success();
141 }
142 
143 template <typename IteratorType>
144 static LogicalResult findOptimal(ModuleOp module, Region &region,
145  const FrozenRewritePatternSet &patterns,
146  const Tester &test) {
147  // We separate the reduction process into 2 steps, the first one is to erase
148  // redundant operations and the second one is to apply the reducer patterns.
149 
150  // In the first phase, we don't apply any patterns so that we only select the
151  // range of operations to keep to the module stay interesting.
152  if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
153  /*eraseOpNotInRange=*/true)))
154  return failure();
155  // In the second phase, we suppose that no operation is redundant, so we try
156  // to rewrite the operation into simpler form.
157  return findOptimal<IteratorType>(module, region, patterns, test,
158  /*eraseOpNotInRange=*/false);
159 }
160 
161 namespace {
162 
163 //===----------------------------------------------------------------------===//
164 // Reduction Pattern Interface Collection
165 //===----------------------------------------------------------------------===//
166 
167 class ReductionPatternInterfaceCollection
168  : public DialectInterfaceCollection<DialectReductionPatternInterface> {
169 public:
170  using Base::Base;
171 
172  // Collect the reduce patterns defined by each dialect.
173  void populateReductionPatterns(RewritePatternSet &pattern) const {
174  for (const DialectReductionPatternInterface &interface : *this)
175  interface.populateReductionPatterns(pattern);
176  }
177 };
178 
179 //===----------------------------------------------------------------------===//
180 // ReductionTreePass
181 //===----------------------------------------------------------------------===//
182 
183 /// This class defines the Reduction Tree Pass. It provides a framework to
184 /// to implement a reduction pass using a tree structure to keep track of the
185 /// generated reduced variants.
186 class ReductionTreePass : public ReductionTreeBase<ReductionTreePass> {
187 public:
188  ReductionTreePass() = default;
189  ReductionTreePass(const ReductionTreePass &pass) = default;
190 
191  LogicalResult initialize(MLIRContext *context) override;
192 
193  /// Runs the pass instance in the pass pipeline.
194  void runOnOperation() override;
195 
196 private:
197  LogicalResult reduceOp(ModuleOp module, Region &region);
198 
199  FrozenRewritePatternSet reducerPatterns;
200 };
201 
202 } // namespace
203 
204 LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
205  RewritePatternSet patterns(context);
206  ReductionPatternInterfaceCollection reducePatternCollection(context);
207  reducePatternCollection.populateReductionPatterns(patterns);
208  reducerPatterns = std::move(patterns);
209  return success();
210 }
211 
212 void ReductionTreePass::runOnOperation() {
213  Operation *topOperation = getOperation();
214  while (topOperation->getParentOp() != nullptr)
215  topOperation = topOperation->getParentOp();
216  ModuleOp module = cast<ModuleOp>(topOperation);
217 
219  workList.push_back(getOperation());
220 
221  do {
222  Operation *op = workList.pop_back_val();
223 
224  for (Region &region : op->getRegions())
225  if (!region.empty())
226  if (failed(reduceOp(module, region)))
227  return signalPassFailure();
228 
229  for (Region &region : op->getRegions())
230  for (Operation &op : region.getOps())
231  if (op.getNumRegions() != 0)
232  workList.push_back(&op);
233  } while (!workList.empty());
234 }
235 
236 LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
237  Tester test(testerName, testerArgs);
238  switch (traversalModeId) {
240  return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
241  module, region, reducerPatterns, test);
242  default:
243  return module.emitError() << "unsupported traversal mode detected";
244  }
245 }
246 
247 std::unique_ptr<Pass> mlir::createReductionTreePass() {
248  return std::make_unique<ReductionTreePass>();
249 }
Include the generated interface declarations.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
size_t getSize() const
Return the size of the module.
Definition: ReductionNode.h:67
OpIterator op_end()
Definition: Region.h:171
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:423
ReductionTreePass will build a reduction tree during module reduction and the ReductionNode represent...
Definition: ReductionNode.h:44
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:420
This class represents a frozen set of patterns that can be processed by a pattern applicator...
LogicalResult applyOpPatternsAndFold(Operation *op, const FrozenRewritePatternSet &patterns, bool *erased=nullptr)
Applies the specified patterns on op alone while also trying to fold it, by selecting the highest ben...
ModuleOp getModule() const
If the ReductionNode hasn't been tested for interestingness yet, it'll be the same module as the one in t...
Definition: ReductionNode.h:61
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
ArrayRef< Range > getStartRanges() const
Return the range information describing how this node is reduced from the parent node.
Definition: ReductionNode.h:74
A collection of dialect interfaces within a context, for a given concrete interface type...
IteratorType
Typed representation for loop type strings.
This is used to report the reduction patterns for a Dialect.
std::unique_ptr< Pass > createReductionTreePass()
virtual void populateReductionPatterns(RewritePatternSet &patterns) const =0
Patterns provided here are intended to transform operations from a complex form to a simpler form...
ArrayRef< Range > getRanges() const
Return the range set we are using to generate variants.
Definition: ReductionNode.h:77
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
iterator_range< OpIterator > getOps()
Definition: Region.h:172
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:117
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:206
This class is used to keep track of the testing environment of the tool.
Definition: Tester.h:33
std::pair< Interestingness, size_t > isInteresting(ModuleOp module) const
Runs the interestingness testing script on a MLIR test case file.
Definition: Tester.cpp:27
Region & getRegion() const
Return the region we're reducing.
Definition: ReductionNode.h:64
Tester::Interestingness isInteresting() const
Returns true if the module exhibits the interesting behavior.
Definition: ReductionNode.h:70
ReductionNode * getParent() const
Definition: ReductionNode.h:54
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition: Region.h:170
static LogicalResult findOptimal(ModuleOp module, Region &region, const FrozenRewritePatternSet &patterns, const Tester &test, bool eraseOpNotInRange)
We will apply the reducer patterns to the operations in the ranges specified by ReductionNode.
void update(std::pair< Tester::Interestingness, size_t > result)
Update the interestingness result from tester.