25#include "llvm/ADT/ArrayRef.h"
26#include "llvm/Support/Allocator.h"
29#define GEN_PASS_DEF_REDUCTIONTREEPASS
30#include "mlir/Reducer/Passes.h.inc"
41 bool eraseOpNotInRange) {
42 std::vector<Operation *> opsNotInRange;
43 std::vector<Operation *> opsInRange;
45 for (
const auto &op : enumerate(region.
getOps())) {
46 int index = op.index();
47 if (keepIndex < rangeToKeep.size() &&
48 index == rangeToKeep[keepIndex].second)
50 if (keepIndex == rangeToKeep.size() ||
index < rangeToKeep[keepIndex].first)
51 opsNotInRange.push_back(&op.value());
53 opsInRange.push_back(&op.value());
69 if (eraseOpNotInRange)
83template <
typename IteratorType>
86 const Tester &test,
bool eraseOpNotInRange) {
87 std::pair<Tester::Interestingness, size_t> initStatus =
92 return module.emitWarning() << "uninterested module will not be reduced";
94 llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
96 std::vector<ReductionNode::Range> ranges{
103 llvm_unreachable(
"unexpected initialization failure");
107 IteratorType iter(root);
109 while (iter != IteratorType::end()) {
119 smallestNode = ¤tNode;
128 trace.push_back(curNode);
129 while (curNode != root) {
131 trace.push_back(curNode);
135 while (!trace.empty()) {
141 llvm::report_fatal_error(
"Reduced module is not interesting");
143 llvm::report_fatal_error(
144 "Reduced module doesn't have consistent size with smallestNode");
148template <
typename IteratorType>
172class ReductionPatternInterfaceCollection
191class ReductionTreePass
202 LogicalResult reduceOp(ModuleOp module,
Region ®ion);
209LogicalResult ReductionTreePass::initialize(
MLIRContext *context) {
211 ReductionPatternInterfaceCollection reducePatternCollection(context);
212 reducePatternCollection.populateReductionPatterns(
patterns);
213 reducerPatterns = std::move(
patterns);
217void ReductionTreePass::runOnOperation() {
218 Operation *topOperation = getOperation();
221 ModuleOp module = dyn_cast<ModuleOp>(topOperation);
224 <<
"top-level op must be 'builtin.module'";
225 return signalPassFailure();
228 SmallVector<Operation *, 8> workList;
229 workList.push_back(getOperation());
232 Operation *op = workList.pop_back_val();
236 if (
failed(reduceOp(module, region)))
237 return signalPassFailure();
240 for (Operation &op : region.
getOps())
242 workList.push_back(&op);
243 }
while (!workList.empty());
246LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region ®ion) {
247 Tester test(testerName, testerArgs);
248 switch (traversalModeId) {
249 case TraversalMode::SinglePath:
251 module, region, reducerPatterns, test);
253 return module.emitError() << "unsupported traversal mode detected";
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
static void applyPatterns(Region ®ion, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static LogicalResult findOptimal(ModuleOp module, Region ®ion, const FrozenRewritePatternSet &patterns, const Tester &test, bool eraseOpNotInRange)
We will apply the reducer patterns to the operations in the ranges specified by ReductionNode.
A collection of dialect interfaces within a context, for a given concrete interface type.
This is used to report the reduction patterns for a Dialect.
virtual void populateReductionPatterns(RewritePatternSet &patterns) const =0
Patterns provided here are intended to transform operations from a complex form to a simpler form,...
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
virtual void runOnOperation()=0
The polymorphic API that runs the pass over the currently held operation.
ReductionTreePass will build a reduction tree during module reduction and the ReductionNode represent...
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion)
Each Reduction Node contains a copy of module for applying rewrite patterns.
ArrayRef< Range > getRanges() const
Return the range set we are using to generate variants.
size_t getSize() const
Return the size of the module.
ModuleOp getModule() const
If the ReductionNode hasn't been tested the interestingness, it'll be the same module as the one in t...
Region & getRegion() const
Return the region we're reducing.
Tester::Interestingness isInteresting() const
Returns true if the module exhibits the interesting behavior.
ReductionNode * getParent() const
ArrayRef< Range > getStartRanges() const
Return the range information that how this node is reduced from the parent node.
void update(std::pair< Tester::Interestingness, size_t > result)
Update the interestingness result from tester.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
iterator_range< OpIterator > getOps()
This class is used to keep track of the testing environment of the tool.
std::pair< Interestingness, size_t > isInteresting(ModuleOp module) const
Runs the interestingness testing script on a MLIR test case file.
ReductionTreePassBase Base
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
const FrozenRewritePatternSet & patterns
@ ExistingOps
Only pre-existing ops are processed.