MLIR 22.0.0git
ReductionNode.h
Go to the documentation of this file.
1//===- ReductionNode.h - Reduction Node Implementation ----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the reduction nodes which are used to keep track of the metadata
10// for a specific generated variant within a reduction pass and are the building
11// blocks of the reduction tree structure. A reduction tree is used to keep
12// track of the different generated variants throughout a reduction pass in the
13// MLIR Reduce tool.
14//
15//===----------------------------------------------------------------------===//
16
17#ifndef MLIR_REDUCER_REDUCTIONNODE_H
18#define MLIR_REDUCER_REDUCTIONNODE_H
19
20#include <queue>
21#include <vector>
22
23#include "mlir/IR/OwningOpRef.h"
24#include "mlir/Reducer/Tester.h"
25#include "llvm/ADT/ArrayRef.h"
26#include "llvm/Support/Allocator.h"
27#include "llvm/Support/ToolOutputFile.h"
28
29namespace mlir {
30
31class ModuleOp;
32class Region;
33
34/// Defines the traversal method options to be used in the reduction tree
35/// traversal.
37
38/// ReductionTreePass builds a reduction tree during module reduction, and
39/// each ReductionNode represents a vertex of that tree. A ReductionNode records
40/// information such as the reduced module and how this node was reduced from
41/// the parent node. This information is used to construct the reduction
42/// path for reducing a given module.
44public:
/// Iterator over the reduction tree; the traversal strategy is selected by the
/// TraversalMode template argument (e.g. the iterator<SinglePath> specialization).
45 template <TraversalMode mode>
46 class iterator;
47
/// A range of operations in the region being reduced, stored as a pair of ints
/// (presumably [begin, end) operation indices; TODO confirm).
48 using Range = std::pair<int, int>;
49
/// Create a node reduced from `parent`. The root node's parent pointer points
/// to itself. NOTE(review): `range` presumably becomes this node's
/// `startRanges`, and `allocator` is presumably used to allocate child
/// variants; confirm against the constructor definition.
50 ReductionNode(ReductionNode *parent, const std::vector<Range> &range,
51 llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator);
52 /// Return the node this node was reduced from.
53 ReductionNode *getParent() const { return parent; }
54
55 /// If this node's interestingness hasn't been tested yet, this is the same
56 /// module as the one in the parent node. Otherwise, certain reduction
57 /// strategies will have been applied to it. Note that it is not necessarily
58 /// an interesting case or a reduced module (one with a smaller size than the
59 /// parent's).
60 ModuleOp getModule() const { return module.get(); }
61
62 /// Return the region we're reducing. `region` starts out as nullptr, so this
/// is presumably only valid after initialize() has located the region.
63 Region &getRegion() const { return *region; }
64
65 /// Return the size of the module. Only meaningful once the interestingness
/// has been tested.
66 size_t getSize() const { return size; }
67
68 /// Return the tester's interestingness verdict for this node's module.
69 Tester::Interestingness isInteresting() const { return interesting; }
70
71 /// Return the ranges selected from the parent node to produce this node,
72 /// i.e. how this node was reduced from its parent.
73 ArrayRef<Range> getStartRanges() const { return startRanges; }
74
75 /// Return the range set we are using to generate variants.
76 ArrayRef<Range> getRanges() const { return ranges; }
77
78 /// Return the generated variants (the child nodes).
79 ArrayRef<ReductionNode *> getVariants() const { return variants; }
80
81 /// Split the ranges and generate new variants.
/// NOTE(review): the declaration this comment documents (listing line 82,
/// `ArrayRef<ReductionNode *> generateNewVariants()` per the member summary)
/// is missing from this listing.
83
84 /// Update the interestingness result from the tester. `result` pairs the
/// verdict with a size value (presumably the module size; TODO confirm).
85 void update(std::pair<Tester::Interestingness, size_t> result);
86
87 /// Each ReductionNode contains a copy of the module for applying rewrite
88 /// patterns, which are only applied within a certain region. initialize()
89 /// duplicates the module from the parent node and locates the corresponding
90 /// region.
91 LogicalResult initialize(ModuleOp parentModule, Region &parentRegion);
92
93private:
94 /// A custom BFS iterator. The difference between
95 /// llvm/ADT/BreadthFirstIterator.h is the graph we're exploring is dynamic.
96 /// We may explore more neighbors at certain node if we didn't find interested
97 /// event. As a result, we defer pushing adjacent nodes until poping the last
98 /// visited node. The graph exploration strategy will be put in
99 /// getNeighbors().
100 ///
101 /// Subclass BaseIterator and implement traversal strategy in getNeighbors().
102 template <typename T>
103 class BaseIterator {
104 public:
105 BaseIterator(ReductionNode *node) { visitQueue.push(node); }
106 BaseIterator(const BaseIterator &) = default;
107 BaseIterator() = default;
108
109 static BaseIterator end() { return BaseIterator(); }
110
111 bool operator==(const BaseIterator &i) {
112 return visitQueue == i.visitQueue;
113 }
114 bool operator!=(const BaseIterator &i) { return !(*this == i); }
115
116 BaseIterator &operator++() {
117 ReductionNode *top = visitQueue.front();
118 visitQueue.pop();
119 for (ReductionNode *node : getNeighbors(top))
120 visitQueue.push(node);
121 return *this;
122 }
123
124 BaseIterator operator++(int) {
125 BaseIterator tmp = *this;
126 ++*this;
127 return tmp;
128 }
129
130 ReductionNode &operator*() const { return *(visitQueue.front()); }
131 ReductionNode *operator->() const { return visitQueue.front(); }
132
133 protected:
134 ArrayRef<ReductionNode *> getNeighbors(ReductionNode *node) {
135 return static_cast<T *>(this)->getNeighbors(node);
136 }
137
138 private:
139 std::queue<ReductionNode *> visitQueue;
140 };
141
142 /// A copy of the parent node's module. All the reducer patterns are
143 /// applied to this instance.
144 OwningOpRef<ModuleOp> module;
145
146 /// The region (of some operation in `module`) that we're reducing.
147 Region *region = nullptr;
148
149 /// The node we were reduced from; this node is one of the parent's
150 /// `variants`.
151 ReductionNode *parent = nullptr;
152
153 /// The size of the module after applying the reducer patterns with range
154 /// constraints. Only valid once the interestingness has been tested.
155 size_t size = 0;
156
157 /// The tester's verdict on whether the module, once evaluated, exhibits the
158 /// interesting behavior.
/// NOTE(review): the member documented above (listing line 159, returned by
/// isInteresting() as `interesting`, of type Tester::Interestingness) is
/// missing from this listing.
160
161 /// `ranges` represents the selected subset of operations in the region. We
162 /// implicitly number each operation in the region, and ReductionTreePass
163 /// applies reducer patterns to the operations that fall into `ranges`. We
164 /// generate new ReductionNodes with subsets of `ranges` to see if we can do
165 /// further reduction. We may split the elements of `ranges` so that we can
166 /// have more subset variants from `ranges`.
167 /// Note that applying the reducer patterns may change the number of
168 /// operations in the region, so `ranges` needs to be updated afterwards.
169 std::vector<Range> ranges;
170
171 /// `startRanges` records the ranges of operations selected from the parent
172 /// node to produce this ReductionNode. It can be used to construct the
173 /// reduction path from the root. I.e., if we apply the same reducer patterns
174 /// and `startRanges` selection to the parent region, we will get the same
175 /// module as this node.
176 const std::vector<Range> startRanges;
177
178 /// The child variants that were created using this node as a starting
179 /// point.
180 std::vector<ReductionNode *> variants;
181
/// Allocator held by reference; presumably used to allocate child variants.
182 llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator;
183};
184
185// Specialized iterator for SinglePath traversal. It supplies its traversal
// strategy to BaseIterator through getNeighbors(); BaseIterator is declared a
// friend, presumably so it can call that member.
// NOTE(review): the class-head line (listing line 187, presumably
// `class ReductionNode::iterator<SinglePath>`) is missing from this listing.
186template <>
188 : public BaseIterator<iterator<SinglePath>> {
189 friend BaseIterator<iterator<SinglePath>>;
190 using BaseIterator::BaseIterator;
191 ArrayRef<ReductionNode *> getNeighbors(ReductionNode *node);
192};
193
194} // namespace mlir
195
196#endif // MLIR_REDUCER_REDUCTIONNODE_H
ReductionTreePass will build a reduction tree during module reduction and the ReductionNode represent...
std::pair< int, int > Range
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion)
Each ReductionNode contains a copy of the module for applying rewrite patterns.
ReductionNode(ReductionNode *parent, const std::vector< Range > &range, llvm::SpecificBumpPtrAllocator< ReductionNode > &allocator)
The root node's parent pointer points to itself.
ArrayRef< Range > getRanges() const
Return the range set we are using to generate variants.
size_t getSize() const
Return the size of the module.
ModuleOp getModule() const
If the ReductionNode hasn't been tested the interestingness, it'll be the same module as the one in t...
Region & getRegion() const
Return the region we're reducing.
Tester::Interestingness isInteresting() const
Returns the tester's verdict on whether the module exhibits the interesting behavior.
ArrayRef< ReductionNode * > generateNewVariants()
Split the ranges and generate new variants.
ArrayRef< ReductionNode * > getVariants() const
Return the generated variants(the child nodes).
ReductionNode * getParent() const
ArrayRef< Range > getStartRanges() const
Return the ranges selected from the parent node to produce this node, i.e. how this node was reduced from its parent.
void update(std::pair< Tester::Interestingness, size_t > result)
Update the interestingness result from tester.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Include the generated interface declarations.
TraversalMode
Defines the traversal method options to be used in the reduction tree traversal.
@ MultiPath
@ SinglePath
@ Backtrack