MLIR 22.0.0git
ReductionNode.cpp
Go to the documentation of this file.
1//===- ReductionNode.cpp - Reduction Node Implementation -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the reduction nodes which are used to keep track of the
10// metadata for a specific generated variant within a reduction pass and are the
11// building blocks of the reduction tree structure. A reduction tree is used to
12// keep track of the different generated variants throughout a reduction pass in
13// the MLIR Reduce tool.
14//
15//===----------------------------------------------------------------------===//
16
18#include "mlir/IR/IRMapping.h"
19#include "llvm/ADT/STLExtras.h"
20
21#include <limits>
22
23using namespace mlir;
24
26 ReductionNode *parentNode, const std::vector<Range> &ranges,
27 llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator)
28 /// Root node will have the parent pointer point to themselves.
29 : parent(parentNode == nullptr ? this : parentNode),
30 size(std::numeric_limits<size_t>::max()), ranges(ranges),
31 startRanges(ranges), allocator(allocator) {
32 if (parent != this)
33 if (failed(initialize(parent->getModule(), parent->getRegion())))
34 llvm_unreachable("unexpected initialization failure");
35}
36
37LogicalResult ReductionNode::initialize(ModuleOp parentModule,
38 Region &targetRegion) {
39 // Use the mapper help us find the corresponding region after module clone.
40 IRMapping mapper;
41 module = cast<ModuleOp>(parentModule->clone(mapper));
42 // Use the first block of targetRegion to locate the cloned region.
43 Block *block = mapper.lookup(&*targetRegion.begin());
44 region = block->getParent();
45 return success();
46}
47
48/// If we haven't explored any variants from this node, we will create N
49/// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the
50/// max element in `ranges` and create 2 new variants for each call.
52 int oldNumVariant = getVariants().size();
53
54 auto createNewNode = [this](const std::vector<Range> &ranges) {
55 return new (allocator.Allocate()) ReductionNode(this, ranges, allocator);
56 };
57
58 // If we haven't created new variant, then we can create varients by removing
59 // each of them respectively. For example, given {{1, 3}, {4, 9}}, we can
60 // produce variants with range {{1, 3}} and {{4, 9}}.
61 if (variants.empty() && getRanges().size() > 1) {
62 for (const Range &range : getRanges()) {
63 std::vector<Range> subRanges = getRanges();
64 llvm::erase(subRanges, range);
65 variants.push_back(createNewNode(subRanges));
66 }
67
68 return getVariants().drop_front(oldNumVariant);
69 }
70
71 // At here, we have created the type of variants mentioned above. We would
72 // like to split the max range into 2 to create 2 new variants. Continue on
73 // the above example, we split the range {4, 9} into {4, 6}, {6, 9}, and
74 // create two variants with range {{1, 3}, {4, 6}} and {{1, 3}, {6, 9}}. The
75 // final ranges vector will be {{1, 3}, {4, 6}, {6, 9}}.
76 auto maxElement =
77 llvm::max_element(ranges, [](const Range &lhs, const Range &rhs) {
78 return (lhs.second - lhs.first) > (rhs.second - rhs.first);
79 });
80
81 // The length of range is less than 1, we can't split it to create new
82 // variant.
83 if (maxElement->second - maxElement->first <= 1)
84 return {};
85
86 Range maxRange = *maxElement;
87 std::vector<Range> subRanges = getRanges();
88 auto subRangesIter = subRanges.begin() + (maxElement - ranges.begin());
89 int half = (maxRange.first + maxRange.second) / 2;
90 *subRangesIter = std::make_pair(maxRange.first, half);
91 variants.push_back(createNewNode(subRanges));
92 *subRangesIter = std::make_pair(half, maxRange.second);
93 variants.push_back(createNewNode(subRanges));
94
95 auto it = ranges.insert(maxElement, std::make_pair(half, maxRange.second));
96 it = ranges.insert(it, std::make_pair(maxRange.first, half));
97 // Remove the range that has been split.
98 ranges.erase(it + 2);
99
100 return getVariants().drop_front(oldNumVariant);
101}
102
103void ReductionNode::update(std::pair<Tester::Interestingness, size_t> result) {
104 std::tie(interesting, size) = result;
105 // After applying reduction, the number of operation in the region may have
106 // changed. Non-interesting case won't be explored thus it's safe to keep it
107 // in a stale status.
108 if (interesting == Tester::Interestingness::True) {
109 // This module may has been updated. Reset the range.
110 ranges.clear();
111 ranges.emplace_back(0, std::distance(region->op_begin(), region->op_end()));
112 } else {
113 // Release the uninteresting module to save some memory.
114 module.release()->erase();
115 }
116}
117
120 // Single Path: Traverses the smallest successful variant at each level until
121 // no new successful variants can be created at that level.
122 ArrayRef<ReductionNode *> variantsFromParent =
123 node->getParent()->getVariants();
124
125 // The parent node created several variants and they may still be waiting
126 // to have their interestingness examined. In the Single Path approach, we
127 // select the smallest variant to continue our exploration. Thus we should wait
128 // until the last variant has been examined before making the traversal decision.
129 if (!llvm::all_of(variantsFromParent, [](ReductionNode *node) {
131 })) {
132 return {};
133 }
134
135 ReductionNode *smallest = nullptr;
136 for (ReductionNode *node : variantsFromParent) {
138 continue;
139 if (smallest == nullptr || node->getSize() < smallest->getSize())
140 smallest = node;
141 }
142
143 if (smallest != nullptr &&
144 smallest->getSize() < node->getParent()->getSize()) {
145 // We got a smallest one, keep traversing from this node.
146 node = smallest;
147 } else {
148 // None of these variants is interesting, let the parent node to generate
149 // more variants.
150 node = node->getParent();
151 }
152
153 return node->generateNewVariants();
154}
return success()
lhs
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Block represents an ordered list of Operations.
Definition Block.h:33
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:72
ReductionTreePass will build a reduction tree during module reduction and the ReductionNode represent...
std::pair< int, int > Range
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion)
Each Reduction Node contains a copy of module for applying rewrite patterns.
ReductionNode(ReductionNode *parent, const std::vector< Range > &range, llvm::SpecificBumpPtrAllocator< ReductionNode > &allocator)
Root node will have the parent pointer point to themselves.
ArrayRef< Range > getRanges() const
Return the range set we are using to generate variants.
size_t getSize() const
Return the size of the module.
Tester::Interestingness isInteresting() const
Returns true if the module exhibits the interesting behavior.
ArrayRef< ReductionNode * > generateNewVariants()
Split the ranges and generate new variants.
ArrayRef< ReductionNode * > getVariants() const
Return the generated variants(the child nodes).
ReductionNode * getParent() const
void update(std::pair< Tester::Interestingness, size_t > result)
Update the interestingness result from tester.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
iterator begin()
Definition Region.h:55
Include the generated interface declarations.