MLIR  22.0.0git
ReductionNode.cpp
Go to the documentation of this file.
1 //===- ReductionNode.cpp - Reduction Node Implementation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the reduction nodes which are used to keep track of the
10 // metadata for a specific generated variant within a reduction pass and are the
11 // building blocks of the reduction tree structure. A reduction tree is used to
12 // keep track of the different generated variants throughout a reduction pass in
13 // the MLIR Reduce tool.
14 //
15 //===----------------------------------------------------------------------===//
16 
18 #include "mlir/IR/IRMapping.h"
19 #include "llvm/ADT/STLExtras.h"
20 
21 #include <limits>
22 
23 using namespace mlir;
24 
26  ReductionNode *parentNode, const std::vector<Range> &ranges,
27  llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator)
28  /// Root node will have the parent pointer point to themselves.
29  : parent(parentNode == nullptr ? this : parentNode),
30  size(std::numeric_limits<size_t>::max()), ranges(ranges),
31  startRanges(ranges), allocator(allocator) {
32  if (parent != this)
33  if (failed(initialize(parent->getModule(), parent->getRegion())))
34  llvm_unreachable("unexpected initialization failure");
35 }
36 
37 LogicalResult ReductionNode::initialize(ModuleOp parentModule,
38  Region &targetRegion) {
39  // Use the mapper help us find the corresponding region after module clone.
40  IRMapping mapper;
41  module = cast<ModuleOp>(parentModule->clone(mapper));
42  // Use the first block of targetRegion to locate the cloned region.
43  Block *block = mapper.lookup(&*targetRegion.begin());
44  region = block->getParent();
45  return success();
46 }
47 
48 /// If we haven't explored any variants from this node, we will create N
49 /// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the
50 /// max element in `ranges` and create 2 new variants for each call.
52  int oldNumVariant = getVariants().size();
53 
54  auto createNewNode = [this](const std::vector<Range> &ranges) {
55  return new (allocator.Allocate()) ReductionNode(this, ranges, allocator);
56  };
57 
58  // If we haven't created new variant, then we can create varients by removing
59  // each of them respectively. For example, given {{1, 3}, {4, 9}}, we can
60  // produce variants with range {{1, 3}} and {{4, 9}}.
61  if (variants.empty() && getRanges().size() > 1) {
62  for (const Range &range : getRanges()) {
63  std::vector<Range> subRanges = getRanges();
64  llvm::erase(subRanges, range);
65  variants.push_back(createNewNode(subRanges));
66  }
67 
68  return getVariants().drop_front(oldNumVariant);
69  }
70 
71  // At here, we have created the type of variants mentioned above. We would
72  // like to split the max range into 2 to create 2 new variants. Continue on
73  // the above example, we split the range {4, 9} into {4, 6}, {6, 9}, and
74  // create two variants with range {{1, 3}, {4, 6}} and {{1, 3}, {6, 9}}. The
75  // final ranges vector will be {{1, 3}, {4, 6}, {6, 9}}.
76  auto maxElement =
77  llvm::max_element(ranges, [](const Range &lhs, const Range &rhs) {
78  return (lhs.second - lhs.first) > (rhs.second - rhs.first);
79  });
80 
81  // The length of range is less than 1, we can't split it to create new
82  // variant.
83  if (maxElement->second - maxElement->first <= 1)
84  return {};
85 
86  Range maxRange = *maxElement;
87  std::vector<Range> subRanges = getRanges();
88  auto subRangesIter = subRanges.begin() + (maxElement - ranges.begin());
89  int half = (maxRange.first + maxRange.second) / 2;
90  *subRangesIter = std::make_pair(maxRange.first, half);
91  variants.push_back(createNewNode(subRanges));
92  *subRangesIter = std::make_pair(half, maxRange.second);
93  variants.push_back(createNewNode(subRanges));
94 
95  auto it = ranges.insert(maxElement, std::make_pair(half, maxRange.second));
96  it = ranges.insert(it, std::make_pair(maxRange.first, half));
97  // Remove the range that has been split.
98  ranges.erase(it + 2);
99 
100  return getVariants().drop_front(oldNumVariant);
101 }
102 
103 void ReductionNode::update(std::pair<Tester::Interestingness, size_t> result) {
104  std::tie(interesting, size) = result;
105  // After applying reduction, the number of operation in the region may have
106  // changed. Non-interesting case won't be explored thus it's safe to keep it
107  // in a stale status.
108  if (interesting == Tester::Interestingness::True) {
109  // This module may has been updated. Reset the range.
110  ranges.clear();
111  ranges.emplace_back(0, std::distance(region->op_begin(), region->op_end()));
112  } else {
113  // Release the uninteresting module to save some memory.
114  module.release()->erase();
115  }
116 }
117 
120  // Single Path: Traverses the smallest successful variant at each level until
121  // no new successful variants can be created at that level.
122  ArrayRef<ReductionNode *> variantsFromParent =
123  node->getParent()->getVariants();
124 
125  // The parent node created several variants and they may be waiting for
126  // their interestingness to be examined. In the Single Path approach, we
127  // select the smallest variant to continue our exploration, so we must wait
128  // until the last variant has been examined before making the traversal decision.
129  if (!llvm::all_of(variantsFromParent, [](ReductionNode *node) {
131  })) {
132  return {};
133  }
134 
135  ReductionNode *smallest = nullptr;
136  for (ReductionNode *node : variantsFromParent) {
138  continue;
139  if (smallest == nullptr || node->getSize() < smallest->getSize())
140  smallest = node;
141  }
142 
143  if (smallest != nullptr &&
144  smallest->getSize() < node->getParent()->getSize()) {
145  // We got a smallest one, keep traversing from this node.
146  node = smallest;
147  } else {
148  // None of these variants is interesting, let the parent node to generate
149  // more variants.
150  node = node->getParent();
151  }
152 
153  return node->generateNewVariants();
154 }
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Block represents an ordered list of Operations.
Definition: Block.h:33
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:27
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
OpTy release()
Release the referenced op.
Definition: OwningOpRef.h:67
ReductionTreePass will build a reduction tree during module reduction and the ReductionNode represent...
Definition: ReductionNode.h:43
std::pair< int, int > Range
Definition: ReductionNode.h:48
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion)
Each Reduction Node contains a copy of module for applying rewrite patterns.
ReductionNode(ReductionNode *parent, const std::vector< Range > &range, llvm::SpecificBumpPtrAllocator< ReductionNode > &allocator)
The root node's parent pointer points to itself.
size_t getSize() const
Return the size of the module.
Definition: ReductionNode.h:66
ModuleOp getModule() const
If the ReductionNode hasn't been tested the interestingness, it'll be the same module as the one in t...
Definition: ReductionNode.h:60
Region & getRegion() const
Return the region we're reducing.
Definition: ReductionNode.h:63
Tester::Interestingness isInteresting() const
Returns true if the module exhibits the interesting behavior.
Definition: ReductionNode.h:69
ArrayRef< ReductionNode * > generateNewVariants()
Split the ranges and generate new variants.
ArrayRef< Range > getRanges() const
Return the range set we are using to generate variants.
Definition: ReductionNode.h:76
ArrayRef< ReductionNode * > getVariants() const
Return the generated variants(the child nodes).
Definition: ReductionNode.h:79
ReductionNode * getParent() const
Definition: ReductionNode.h:53
void update(std::pair< Tester::Interestingness, size_t > result)
Update the interestingness result from tester.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition: Region.h:170
iterator begin()
Definition: Region.h:55
OpIterator op_end()
Definition: Region.h:171
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.