MLIR  16.0.0git
ReductionNode.cpp
Go to the documentation of this file.
1 //===- ReductionNode.cpp - Reduction Node Implementation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the reduction nodes which are used to keep track of the
10 // metadata for a specific generated variant within a reduction pass and are the
11 // building blocks of the reduction tree structure. A reduction tree is used to
12 // keep track of the different generated variants throughout a reduction pass in
13 // the MLIR Reduce tool.
14 //
15 //===----------------------------------------------------------------------===//
16 
19 #include "llvm/ADT/STLExtras.h"
20 
21 #include <algorithm>
22 #include <limits>
23 
24 using namespace mlir;
25 
27  ReductionNode *parentNode, const std::vector<Range> &ranges,
28  llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator)
29  /// Root node will have the parent pointer point to themselves.
30  : parent(parentNode == nullptr ? this : parentNode),
31  size(std::numeric_limits<size_t>::max()), ranges(ranges),
32  startRanges(ranges), allocator(allocator) {
33  if (parent != this)
34  if (failed(initialize(parent->getModule(), parent->getRegion())))
35  llvm_unreachable("unexpected initialization failure");
36 }
37 
39  Region &targetRegion) {
40  // Use the mapper help us find the corresponding region after module clone.
41  BlockAndValueMapping mapper;
42  module = cast<ModuleOp>(parentModule->clone(mapper));
43  // Use the first block of targetRegion to locate the cloned region.
44  Block *block = mapper.lookup(&*targetRegion.begin());
45  region = block->getParent();
46  return success();
47 }
48 
49 /// If we haven't explored any variants from this node, we will create N
50 /// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the
51 /// max element in `ranges` and create 2 new variants for each call.
53  int oldNumVariant = getVariants().size();
54 
55  auto createNewNode = [this](const std::vector<Range> &ranges) {
56  return new (allocator.Allocate()) ReductionNode(this, ranges, allocator);
57  };
58 
59  // If we haven't created new variant, then we can create varients by removing
60  // each of them respectively. For example, given {{1, 3}, {4, 9}}, we can
61  // produce variants with range {{1, 3}} and {{4, 9}}.
62  if (variants.empty() && getRanges().size() > 1) {
63  for (const Range &range : getRanges()) {
64  std::vector<Range> subRanges = getRanges();
65  llvm::erase_value(subRanges, range);
66  variants.push_back(createNewNode(subRanges));
67  }
68 
69  return getVariants().drop_front(oldNumVariant);
70  }
71 
72  // At here, we have created the type of variants mentioned above. We would
73  // like to split the max range into 2 to create 2 new variants. Continue on
74  // the above example, we split the range {4, 9} into {4, 6}, {6, 9}, and
75  // create two variants with range {{1, 3}, {4, 6}} and {{1, 3}, {6, 9}}. The
76  // final ranges vector will be {{1, 3}, {4, 6}, {6, 9}}.
77  auto maxElement = std::max_element(
78  ranges.begin(), ranges.end(), [](const Range &lhs, const Range &rhs) {
79  return (lhs.second - lhs.first) > (rhs.second - rhs.first);
80  });
81 
82  // The length of range is less than 1, we can't split it to create new
83  // variant.
84  if (maxElement->second - maxElement->first <= 1)
85  return {};
86 
87  Range maxRange = *maxElement;
88  std::vector<Range> subRanges = getRanges();
89  auto subRangesIter = subRanges.begin() + (maxElement - ranges.begin());
90  int half = (maxRange.first + maxRange.second) / 2;
91  *subRangesIter = std::make_pair(maxRange.first, half);
92  variants.push_back(createNewNode(subRanges));
93  *subRangesIter = std::make_pair(half, maxRange.second);
94  variants.push_back(createNewNode(subRanges));
95 
96  auto it = ranges.insert(maxElement, std::make_pair(half, maxRange.second));
97  it = ranges.insert(it, std::make_pair(maxRange.first, half));
98  // Remove the range that has been split.
99  ranges.erase(it + 2);
100 
101  return getVariants().drop_front(oldNumVariant);
102 }
103 
104 void ReductionNode::update(std::pair<Tester::Interestingness, size_t> result) {
105  std::tie(interesting, size) = result;
106  // After applying reduction, the number of operation in the region may have
107  // changed. Non-interesting case won't be explored thus it's safe to keep it
108  // in a stale status.
109  if (interesting == Tester::Interestingness::True) {
110  // This module may has been updated. Reset the range.
111  ranges.clear();
112  ranges.emplace_back(0, std::distance(region->op_begin(), region->op_end()));
113  } else {
114  // Release the uninteresting module to save some memory.
115  module.release()->erase();
116  }
117 }
118 
// NOTE(review): this extract is missing listing lines 119-120 (the signature of
// this function), 131 (the body of the all_of predicate) and 138 (the condition
// guarding the `continue` in the loop below). The comments here describe only
// the visible code.
//
// Single-path traversal step: decides whether to continue from the smallest
// tested variant among `node`'s parent's variants or from the parent itself,
// and returns the new variants generated at the chosen node.
121  // Single Path: Traverses the smallest successful variant at each level until
122  // no new successful variants can be created at that level.
123  ArrayRef<ReductionNode *> variantsFromParent =
124  node->getParent()->getVariants();
125 
126  // The parent node created several variants and they may be waiting for
127  // their interestingness to be examined. In the Single Path approach, we select the
128  // smallest variant to continue our exploration. Thus we should wait until the
129  // last variant has been examined before making the traversal decision below.
130  if (!llvm::all_of(variantsFromParent, [](ReductionNode *node) {
// NOTE(review): predicate body (listing line 131) is missing from this extract.
132  })) {
133  return {};
134  }
135 
136  ReductionNode *smallest = nullptr;
137  for (ReductionNode *node : variantsFromParent) {
// NOTE(review): the `if` guarding this `continue` (listing line 138) is missing
// from this extract; the file's tooltips list isInteresting(), which no visible
// line uses, so the guard presumably filters on it — confirm.
139  continue;
140  if (smallest == nullptr || node->getSize() < smallest->getSize())
141  smallest = node;
142  }
143 
// Move down only if the best variant is strictly smaller than the parent;
// otherwise stay at the parent so it produces further variants.
144  if (smallest != nullptr &&
145  smallest->getSize() < node->getParent()->getSize()) {
146  // We found a smaller variant; keep traversing from it.
147  node = smallest;
148  } else {
149  // None of these variants is interesting; let the parent node generate
150  // more variants.
151  node = node->getParent();
152  }
153 
154  return node->generateNewVariants();
155 }
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Block * lookup(Block *from) const
Lookup a mapped value within the map.
Block represents an ordered list of Operations.
Definition: Block.h:30
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
OpTy release()
Release the referenced op.
Definition: OwningOpRef.h:59
ReductionTreePass will build a reduction tree during module reduction and the ReductionNode represent...
Definition: ReductionNode.h:44
std::pair< int, int > Range
Definition: ReductionNode.h:49
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion)
Each Reduction Node contains a copy of module for applying rewrite patterns.
ReductionNode(ReductionNode *parent, const std::vector< Range > &range, llvm::SpecificBumpPtrAllocator< ReductionNode > &allocator)
The root node's parent pointer points to the root node itself.
size_t getSize() const
Return the size of the module.
Definition: ReductionNode.h:67
ModuleOp getModule() const
If the ReductionNode hasn't been tested the interestingness, it'll be the same module as the one in t...
Definition: ReductionNode.h:61
Region & getRegion() const
Return the region we're reducing.
Definition: ReductionNode.h:64
Tester::Interestingness isInteresting() const
Returns true if the module exhibits the interesting behavior.
Definition: ReductionNode.h:70
ArrayRef< ReductionNode * > generateNewVariants()
Split the ranges and generate new variants.
ArrayRef< Range > getRanges() const
Return the range set we are using to generate variants.
Definition: ReductionNode.h:77
ArrayRef< ReductionNode * > getVariants() const
Return the generated variants(the child nodes).
Definition: ReductionNode.h:80
ReductionNode * getParent() const
Definition: ReductionNode.h:54
void update(std::pair< Tester::Interestingness, size_t > result)
Update the interestingness result from tester.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition: Region.h:170
iterator begin()
Definition: Region.h:55
OpIterator op_end()
Definition: Region.h:171
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26