26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/Support/Allocator.h"
29 #include "llvm/Support/ManagedStatic.h"
32 #define GEN_PASS_DEF_REDUCTIONTREE
33 #include "mlir/Reducer/Passes.h.inc"
44 bool eraseOpNotInRange) {
45 std::vector<Operation *> opsNotInRange;
46 std::vector<Operation *> opsInRange;
49 int index = op.index();
50 if (keepIndex < rangeToKeep.size() &&
51 index == rangeToKeep[keepIndex].second)
53 if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
54 opsNotInRange.push_back(&op.value());
56 opsInRange.push_back(&op.value());
71 if (eraseOpNotInRange)
85 template <
typename IteratorType>
88 const Tester &test,
bool eraseOpNotInRange) {
89 std::pair<Tester::Interestingness, size_t> initStatus =
94 return module.emitWarning() <<
"uninterested module will not be reduced";
96 llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
98 std::vector<ReductionNode::Range> ranges{
105 llvm_unreachable(
"unexpected initialization failure");
109 IteratorType iter(root);
111 while (iter != IteratorType::end()) {
121 smallestNode = ¤tNode;
130 trace.push_back(curNode);
131 while (curNode != root) {
133 trace.push_back(curNode);
137 while (!trace.empty()) {
143 llvm::report_fatal_error(
"Reduced module is not interesting");
145 llvm::report_fatal_error(
146 "Reduced module doesn't have consistent size with smallestNode");
150 template <
typename IteratorType>
159 if (failed(findOptimal<IteratorType>(module, region, {}, test,
164 return findOptimal<IteratorType>(module, region, patterns, test,
174 class ReductionPatternInterfaceCollection
193 class ReductionTreePass :
public impl::ReductionTreeBase<ReductionTreePass> {
195 ReductionTreePass() =
default;
196 ReductionTreePass(
const ReductionTreePass &pass) =
default;
198 LogicalResult initialize(
MLIRContext *context)
override;
201 void runOnOperation()
override;
204 LogicalResult reduceOp(ModuleOp module,
Region ®ion);
211 LogicalResult ReductionTreePass::initialize(
MLIRContext *context) {
213 ReductionPatternInterfaceCollection reducePatternCollection(context);
214 reducePatternCollection.populateReductionPatterns(patterns);
215 reducerPatterns = std::move(patterns);
219 void ReductionTreePass::runOnOperation() {
220 Operation *topOperation = getOperation();
223 ModuleOp module = dyn_cast<ModuleOp>(topOperation);
226 <<
"top-level op must be 'builtin.module'";
227 return signalPassFailure();
231 workList.push_back(getOperation());
238 if (failed(reduceOp(module, region)))
239 return signalPassFailure();
244 workList.push_back(&op);
245 }
while (!workList.empty());
248 LogicalResult ReductionTreePass::reduceOp(ModuleOp module,
Region ®ion) {
249 Tester test(testerName, testerArgs);
250 switch (traversalModeId) {
252 return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
253 module, region, reducerPatterns, test);
255 return module.emitError() <<
"unsupported traversal mode detected";
260 return std::make_unique<ReductionTreePass>();
static void applyPatterns(Region ®ion, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static LogicalResult findOptimal(ModuleOp module, Region ®ion, const FrozenRewritePatternSet &patterns, const Tester &test, bool eraseOpNotInRange)
We will apply the reducer patterns to the operations in the ranges specified by ReductionNode.
A collection of dialect interfaces within a context, for a given concrete interface type.
This is used to report the reduction patterns for a Dialect.
virtual void populateReductionPatterns(RewritePatternSet &patterns) const =0
Patterns provided here are intended to transform operations from a complex form to a simpler form,...
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteStrictness strictMode
Strict mode can restrict the ops that are added to the worklist during the rewrite.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
ReductionTreePass will build a reduction tree during module reduction and the ReductionNode represent...
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion)
Each Reduction Node contains a copy of module for applying rewrite patterns.
size_t getSize() const
Return the size of the module.
ModuleOp getModule() const
If the ReductionNode hasn't been tested the interestingness, it'll be the same module as the one in t...
Region & getRegion() const
Return the region we're reducing.
Tester::Interestingness isInteresting() const
Returns true if the module exhibits the interesting behavior.
ArrayRef< Range > getRanges() const
Return the range set we are using to generate variants.
ReductionNode * getParent() const
ArrayRef< Range > getStartRanges() const
Return the range information that how this node is reduced from the parent node.
void update(std::pair< Tester::Interestingness, size_t > result)
Update the interestingness result from tester.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
This class is used to keep track of the testing environment of the tool.
std::pair< Interestingness, size_t > isInteresting(ModuleOp module) const
Runs the interestingness testing script on a MLIR test case file.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
LogicalResult applyOpPatternsAndFold(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::unique_ptr< Pass > createReductionTreePass()
@ ExistingOps
Only pre-existing ops are processed.